Commit c46df011 authored by Leon Wu's avatar Leon Wu
Browse files

switch to 24khz for encodec purposes (thanks tchambs#6840 on LAION Discord!)

parent 4750ff07
Loading
Loading
Loading
Loading
+5 −5
Original line number Diff line number Diff line
@@ -22,7 +22,7 @@ hubert_ckpt = f'hubert/hubert_base_ls960.pt'
hubert_quantizer = f'hubert/hubert_base_ls960_L9_km500.bin' # listed in row "HuBERT Base (~95M params)", column Quantizer

# Placeholder data generation
def get_sinewave(freq=440.0, duration_ms=200, volume=1.0, sample_rate=16000.0):
def get_sinewave(freq=440.0, duration_ms=200, volume=1.0, sample_rate=24000.0):
  # code adapted from https://stackoverflow.com/a/33913403
  audio = []
  num_samples = duration_ms * (sample_rate / 1000.0)
@@ -30,13 +30,13 @@ def get_sinewave(freq=440.0, duration_ms=200, volume=1.0, sample_rate=16000.0):
    audio.append(volume * math.sin(2 * math.pi * freq * (x / sample_rate)))
  return audio

def save_wav(file_name, audio, sample_rate=16000.0):
def save_wav(file_name, audio, sample_rate=24000.0):
  # Open up a wav file
  wav_file=wave.open(file_name,"w")
  # wav params
  nchannels = 1
  sampwidth = 2
  # 16000 is the industry standard sample rate - CD quality.  If you need to
  # 24000 is the industry standard sample rate - CD quality.  If you need to
  # save on file size you can adjust it downwards. The stanard for low quality
  # is 8000 or 8kHz.
  nframes = len(audio)
@@ -183,14 +183,14 @@ audiolm = AudioLM(

generated_wav = audiolm(batch_size = 1)
output_path = f"{prefix}/out.wav"
sample_rate = 16000
sample_rate = 24000
torchaudio.save(output_path, generated_wav.cpu(), sample_rate)

# with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True, use_cuda=True) as prof:
#     with record_function("model_inference"):
#         generated_wav = audiolm(batch_size = 1)
#         output_path = f"{prefix}/out.wav"
#         sample_rate = 16000
#         sample_rate = 24000
#         torchaudio.save(output_path, generated_wav.cpu(), sample_rate)

# filename = f"{prefix}/profile-{datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}.txt"