Commit ae15bdd0 authored by Leon Wu's avatar Leon Wu
Browse files

try other codec

parent e006097f
Loading
Loading
Loading
Loading
+26 −26
Original line number Diff line number Diff line
@@ -70,29 +70,29 @@ make_placeholder_dataset()

#######

# codec = AudioLMSoundStream(
#     codebook_size = 1024,
#     rq_num_quantizers = 8,
#     attn_window_size = 128,       # local attention receptive field at bottleneck
#     attn_depth = 2                # 2 local attention transformer blocks - the soundstream folks were not experts with attention, so i took the liberty to add some. encodec went with lstms, but attention should be better
# )

# soundstream_trainer = SoundStreamTrainer(
#     soundstream,
#     folder = dataset_folder,
#     lr=3e-4,
#     batch_size = 4,
#     grad_accum_every = 8, # effective batch size of batch_size * grad_accum_every = 32
#     data_max_length_seconds = 2,  # train on 2 second audio
#     results_folder = f"{prefix}/soundstream_results",
#     save_results_every = 4,
#     save_model_every = 4,
#     num_train_steps = 9
# ).cuda()

# soundstream_trainer.train()

encodec = EncodecWrapper()
codec = AudioLMSoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
    attn_window_size = 128,       # local attention receptive field at bottleneck
    attn_depth = 2                # 2 local attention transformer blocks - the soundstream folks were not experts with attention, so i took the liberty to add some. encodec went with lstms, but attention should be better
)

soundstream_trainer = SoundStreamTrainer(
    soundstream,
    folder = dataset_folder,
    lr=3e-4,
    batch_size = 4,
    grad_accum_every = 8, # effective batch size of batch_size * grad_accum_every = 32
    data_max_length_seconds = 2,  # train on 2 second audio
    results_folder = f"{prefix}/soundstream_results",
    save_results_every = 4,
    save_model_every = 4,
    num_train_steps = 9
).cuda()

soundstream_trainer.train()

# codec = EncodecWrapper()

#############

@@ -134,7 +134,7 @@ coarse_transformer = CoarseTransformer(

coarse_trainer = CoarseTransformerTrainer(
    transformer = coarse_transformer,
    codec = encodec,
    codec = codec,
    wav2vec = wav2vec,
    folder = dataset_folder,
    batch_size = 1,
@@ -161,7 +161,7 @@ fine_transformer = FineTransformer(

fine_trainer = FineTransformerTrainer(
    transformer = fine_transformer,
    codec = encodec,
    codec = codec,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
@@ -180,7 +180,7 @@ fine_trainer.train()

audiolm = AudioLM(
    wav2vec = wav2vec,
    codec = encodec,
    codec = codec,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer