Commit 5d5265e8 authored by Leon Wu's avatar Leon Wu
Browse files

fix sampling rate of demo wavs

parent c23d483d
Loading
Loading
Loading
Loading
+28 −7
Original line number Diff line number Diff line
@@ -9,6 +9,8 @@ from audiolm_pytorch import AudioLMSoundStream, SoundStreamTrainer, HubertWithKm
from torch import nn
import torch
import torchaudio
from torch.profiler import profile, record_function, ProfilerActivity
import datetime
# import boto3
# import datetime
# from botocore.errorfactory import ClientError
@@ -20,7 +22,7 @@ hubert_ckpt = f'hubert/hubert_base_ls960.pt'
hubert_quantizer = f'hubert/hubert_base_ls960_L9_km500.bin' # listed in row "HuBERT Base (~95M params)", column Quantizer

# Placeholder data generation
def get_sinewave(freq=440.0, duration_ms=200, volume=1.0, sample_rate=44100.0):
def get_sinewave(freq=440.0, duration_ms=200, volume=1.0, sample_rate=16000.0):
  # code adapted from https://stackoverflow.com/a/33913403
  audio = []
  num_samples = duration_ms * (sample_rate / 1000.0)
@@ -28,13 +30,13 @@ def get_sinewave(freq=440.0, duration_ms=200, volume=1.0, sample_rate=44100.0):
    audio.append(volume * math.sin(2 * math.pi * freq * (x / sample_rate)))
  return audio

def save_wav(file_name, audio, sample_rate=44100.0):
def save_wav(file_name, audio, sample_rate=16000.0):
  # Open up a wav file
  wav_file=wave.open(file_name,"w")
  # wav params
  nchannels = 1
  sampwidth = 2
  # 44100 is the industry standard sample rate - CD quality.  If you need to
  # 16000 is the industry standard sample rate - CD quality.  If you need to
  # save on file size you can adjust it downwards. The stanard for low quality
  # is 8000 or 8kHz.
  nframes = len(audio)
@@ -175,7 +177,26 @@ audiolm = AudioLM(
    fine_transformer = fine_transformer
)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True, use_cuda=True) as prof:
    with record_function("model_inference"):
        generated_wav = audiolm(batch_size = 1)
        output_path = f"{prefix}/out.wav"
sample_rate = 44100
        sample_rate = 16000
        torchaudio.save(output_path, generated_wav.cpu(), sample_rate)

filename = f"{prefix}/profile-{datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")}.txt"
with open(filename, "w") as f:
    f.write("cpu time sorted:\n")
    f.write(f"{prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10)}")
    f.write("\n cuda time sorted:\n")
    f.write(f"{prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)}")
    f.write("\ncpu memory self\n") # excludes children memory allocated
    f.write(f"{prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10)}")
    f.write("\ncpu memory\n") # includes children memory allocated
    f.write(f"{prof.key_averages().table(sort_by="cpu_memory_usage", row_limit=10)}\n")