Commit 1b9081f0 authored by Leon Wu
Browse files

switch to encodec

parent c46df011
Loading
Loading
Loading
Loading
+29 −25
Original line number Diff line number Diff line
@@ -5,7 +5,9 @@ import struct
import os
import urllib.request
# import tarfile
from audiolm_pytorch import AudioLMSoundStream, SoundStreamTrainer, HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer, HubertWithKmeans, CoarseTransformer, CoarseTransformerWrapper, CoarseTransformerTrainer, FineTransformer, FineTransformerWrapper, FineTransformerTrainer, AudioLM
# from audiolm_pytorch import AudioLMSoundStream, SoundStreamTrainer
from audiolm_pytorch import EncodecWrapper
from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer, HubertWithKmeans, CoarseTransformer, CoarseTransformerWrapper, CoarseTransformerTrainer, FineTransformer, FineTransformerWrapper, FineTransformerTrainer, AudioLM
from torch import nn
import torch
import torchaudio
@@ -68,27 +70,29 @@ make_placeholder_dataset()

#######

# Build a SoundStream codec and run a very short training job on it.
# NOTE(review): in this diff view the same code appears again, commented out,
# directly below and is followed by `encodec = EncodecWrapper()`. These lines
# look like the pre-change side of the hunk -- confirm against the real file.
soundstream = AudioLMSoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
    attn_window_size = 128,       # local attention receptive field at bottleneck
    attn_depth = 2                # 2 local attention transformer blocks - the soundstream folks were not experts with attention, so i took the liberty to add some. encodec went with lstms, but attention should be better
)

# Trainer for the codec above. It reads audio from `dataset_folder`, writes its
# outputs under f"{prefix}/soundstream_results", and is moved to the GPU with
# .cuda(). num_train_steps = 9 keeps this to a handful of steps.
soundstream_trainer = SoundStreamTrainer(
    soundstream,
    folder = dataset_folder,
    lr=3e-4,
    batch_size = 4,
    grad_accum_every = 8, # effective batch size of batch_size * grad_accum_every = 32
    data_max_length_seconds = 2,  # train on 2 second audio
    results_folder = f"{prefix}/soundstream_results",
    save_results_every = 4,
    save_model_every = 4,
    num_train_steps = 9
).cuda()

soundstream_trainer.train() # NOTE(review): the previous comment here said "skip soundstream for now", but this call does run training
# SoundStream construction and training are disabled here; the codec used by
# the rest of the script is the `encodec` object created at the end of this
# section. The old code is kept commented out for reference.
#
# soundstream = AudioLMSoundStream(
#     codebook_size = 1024,
#     rq_num_quantizers = 8,
#     attn_window_size = 128,       # local attention receptive field at bottleneck
#     attn_depth = 2                # 2 local attention transformer blocks - the soundstream folks were not experts with attention, so i took the liberty to add some. encodec went with lstms, but attention should be better
# )

# soundstream_trainer = SoundStreamTrainer(
#     soundstream,
#     folder = dataset_folder,
#     lr=3e-4,
#     batch_size = 4,
#     grad_accum_every = 8, # effective batch size of batch_size * grad_accum_every = 32
#     data_max_length_seconds = 2,  # train on 2 second audio
#     results_folder = f"{prefix}/soundstream_results",
#     save_results_every = 4,
#     save_model_every = 4,
#     num_train_steps = 9
# ).cuda()

# soundstream_trainer.train()

# Codec instance built with default arguments; no trainer is created for it in
# this section. It is passed as the `soundstream=` argument to the coarse/fine
# trainers and to AudioLM further down.
# NOTE(review): presumably EncodecWrapper ships with pretrained weights, which
# would explain the missing training step -- confirm against audiolm_pytorch.
encodec = EncodecWrapper()

#############

@@ -129,7 +133,7 @@ coarse_transformer = CoarseTransformer(

coarse_trainer = CoarseTransformerTrainer(
    transformer = coarse_transformer,
    soundstream = soundstream,
    soundstream = encodec,
    wav2vec = wav2vec,
    folder = dataset_folder,
    batch_size = 1,
@@ -156,7 +160,7 @@ fine_transformer = FineTransformer(

fine_trainer = FineTransformerTrainer(
    transformer = fine_transformer,
    soundstream = soundstream,
    soundstream = encodec,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
@@ -175,7 +179,7 @@ fine_trainer.train()

audiolm = AudioLM(
    wav2vec = wav2vec,
    soundstream = soundstream,
    soundstream = encodec,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer