Commit 73aa6a4e authored by Leon Wu's avatar Leon Wu
Browse files

Start full training conditions

parent eef98894
Loading
Loading
Loading
Loading
+15 −15
Original line number Diff line number Diff line
@@ -97,7 +97,6 @@ codec = EncodecWrapper()

#############

raise AssertionError("note to self, try larger batch size and grad update https://github.com/lucidrains/audiolm-pytorch/discussions/107#discussioncomment-5373414")
wav2vec = HubertWithKmeans(
    # use_mert = True,
    checkpoint_path = f"{prefix}/{hubert_ckpt}",
@@ -115,9 +114,12 @@ semantic_trainer = SemanticTransformerTrainer(
    transformer = semantic_transformer,
    wav2vec = wav2vec,
    folder = dataset_folder,
    batch_size = 1,
    batch_size = 32,
    grad_accum_every = 16,
    data_max_length = 320 * 32,
    num_train_steps = 1,
    num_train_steps = 200001,
    save_results_every = 20000,
    save_model_every = 20000,
    results_folder = f"{prefix}/semantic_results",
)

@@ -138,15 +140,14 @@ coarse_trainer = CoarseTransformerTrainer(
    codec = codec,
    wav2vec = wav2vec,
    folder = dataset_folder,
    batch_size = 1,
    batch_size = 32,
    grad_accum_every = 16,
    data_max_length = 320 * 32,
    results_folder = f"{prefix}/coarse_results",
    save_results_every = 4,
    save_model_every = 4,
    num_train_steps = 9
    num_train_steps = 200001,
    save_results_every = 20000,
    save_model_every = 20000,
)
# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes
# adjusting save_*_every variables for the same reason

coarse_trainer.train()

@@ -164,15 +165,14 @@ fine_trainer = FineTransformerTrainer(
    transformer = fine_transformer,
    codec = codec,
    folder = dataset_folder,
    batch_size = 1,
    batch_size = 32,
    grad_accum_every = 16,
    data_max_length = 320 * 32,
    save_results_every = 4,
    save_model_every = 4,
    num_train_steps = 9,
    num_train_steps = 200001,
    save_results_every = 20000,
    save_model_every = 20000,
    results_folder = f"{prefix}/fine_results",
)
# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes
# adjusting save_*_every variables for the same reason

fine_trainer.train()