Loading audiolm_pytorch/trainer.py +75 −48 Original line number Diff line number Diff line Loading @@ -118,33 +118,39 @@ class SoundStreamTrainer(nn.Module): self, soundstream: SoundStream, *, num_train_steps, batch_size, data_max_length = None, data_max_length_seconds = None, folder, lr = 2e-4, grad_accum_every = 4, wd = 0., max_grad_norm = 0.5, discr_max_grad_norm = None, save_results_every = 100, save_model_every = 1000, log_losses_every = 1, results_folder = './results', valid_frac = 0.05, random_split_seed = 42, use_ema = True, ema_beta = 0.995, ema_update_after_step = 500, ema_update_every = 10, apply_grad_penalty_every = 4, dl_num_workers = 0, num_train_steps: int, batch_size: int, data_max_length: int = None, data_max_length_seconds: float = None, folder: str = None, train_dataloader: DataLoader = None, val_dataloader: DataLoader = None, lr: float = 2e-4, grad_accum_every: int = 4, wd: float = 0., max_grad_norm: float = 0.5, discr_max_grad_norm: float = None, save_results_every: int = 100, save_model_every: int= 1000, log_losses_every: int= 1, results_folder: str = './results', valid_frac: float = 0.05, random_split_seed: int = 42, use_ema: bool = True, ema_beta: float = 0.995, ema_update_after_step: int = 500, ema_update_every: int = 10, apply_grad_penalty_every: int = 4, dl_num_workers: int = 0, accelerator: Accelerator = None, accelerate_kwargs: dict = dict(), use_lion = False, force_clear_prev_results = None # set to True | False to skip the prompt use_lion: bool = False, force_clear_prev_results: bool = None # set to True | False to skip the prompt ): """ Initialize with a SoundStream instance and either a folder containing audio data or train/val DataLoader instances. """ super().__init__() if accelerator: Loading @@ -166,6 +172,14 @@ class SoundStreamTrainer(nn.Module): self.batch_size = batch_size self.grad_accum_every = grad_accum_every hyperparameters = { "num_train_steps": num_train_steps, "batch_size": batch_size, "gradient_accum_every": grad_accum_every, "learning_rate": lr, "target_sample_hz": soundstream.target_sample_hz, } # optimizers self.optim = get_optimizer(soundstream.non_discr_parameters(), lr = lr, wd = wd) Loading @@ -181,12 +195,24 @@ class SoundStreamTrainer(nn.Module): self.max_grad_norm = max_grad_norm self.discr_max_grad_norm = discr_max_grad_norm # create dataset if folder is None: assert train_dataloader is not None assert val_dataloader is not None self.dl = train_dataloader self.valid_dl = val_dataloader else: assert train_dataloader is None assert val_dataloader is None assert not (exists(data_max_length) and exists(data_max_length_seconds)) # create dataset if exists(data_max_length_seconds): data_max_length = data_max_length_seconds * soundstream.target_sample_hz assert not exists(data_max_length) data_max_length = int(data_max_length_seconds * soundstream.target_sample_hz) else: assert exists(data_max_length) hyperparameters['data_max_length'] = data_max_length self.ds = SoundDataset( folder, Loading Loading @@ -251,8 +277,9 @@ class SoundStreamTrainer(nn.Module): self.results_folder.mkdir(parents = True, exist_ok = True) hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr} self.accelerator.init_trackers("soundstream", config=hps) # Initialize experiment trackers if an external Accelerator is not passed in if not accelerator: self.accelerator.init_trackers("soundstream", config=hyperparameters) def set_model_as_ema_model_(self): """ this will force the main 'online' model to have same parameters as the exponentially moving averaged model """ Loading Loading
audiolm_pytorch/trainer.py +75 −48 Original line number Diff line number Diff line Loading @@ -118,33 +118,39 @@ class SoundStreamTrainer(nn.Module): self, soundstream: SoundStream, *, num_train_steps, batch_size, data_max_length = None, data_max_length_seconds = None, folder, lr = 2e-4, grad_accum_every = 4, wd = 0., max_grad_norm = 0.5, discr_max_grad_norm = None, save_results_every = 100, save_model_every = 1000, log_losses_every = 1, results_folder = './results', valid_frac = 0.05, random_split_seed = 42, use_ema = True, ema_beta = 0.995, ema_update_after_step = 500, ema_update_every = 10, apply_grad_penalty_every = 4, dl_num_workers = 0, num_train_steps: int, batch_size: int, data_max_length: int = None, data_max_length_seconds: float = None, folder: str = None, train_dataloader: DataLoader = None, val_dataloader: DataLoader = None, lr: float = 2e-4, grad_accum_every: int = 4, wd: float = 0., max_grad_norm: float = 0.5, discr_max_grad_norm: float = None, save_results_every: int = 100, save_model_every: int= 1000, log_losses_every: int= 1, results_folder: str = './results', valid_frac: float = 0.05, random_split_seed: int = 42, use_ema: bool = True, ema_beta: float = 0.995, ema_update_after_step: int = 500, ema_update_every: int = 10, apply_grad_penalty_every: int = 4, dl_num_workers: int = 0, accelerator: Accelerator = None, accelerate_kwargs: dict = dict(), use_lion = False, force_clear_prev_results = None # set to True | False to skip the prompt use_lion: bool = False, force_clear_prev_results: bool = None # set to True | False to skip the prompt ): """ Initialize with a SoundStream instance and either a folder containing audio data or train/val DataLoader instances. """ super().__init__() if accelerator: Loading @@ -166,6 +172,14 @@ class SoundStreamTrainer(nn.Module): self.batch_size = batch_size self.grad_accum_every = grad_accum_every hyperparameters = { "num_train_steps": num_train_steps, "batch_size": batch_size, "gradient_accum_every": grad_accum_every, "learning_rate": lr, "target_sample_hz": soundstream.target_sample_hz, } # optimizers self.optim = get_optimizer(soundstream.non_discr_parameters(), lr = lr, wd = wd) Loading @@ -181,12 +195,24 @@ class SoundStreamTrainer(nn.Module): self.max_grad_norm = max_grad_norm self.discr_max_grad_norm = discr_max_grad_norm # create dataset if folder is None: assert train_dataloader is not None assert val_dataloader is not None self.dl = train_dataloader self.valid_dl = val_dataloader else: assert train_dataloader is None assert val_dataloader is None assert not (exists(data_max_length) and exists(data_max_length_seconds)) # create dataset if exists(data_max_length_seconds): data_max_length = data_max_length_seconds * soundstream.target_sample_hz assert not exists(data_max_length) data_max_length = int(data_max_length_seconds * soundstream.target_sample_hz) else: assert exists(data_max_length) hyperparameters['data_max_length'] = data_max_length self.ds = SoundDataset( folder, Loading Loading @@ -251,8 +277,9 @@ class SoundStreamTrainer(nn.Module): self.results_folder.mkdir(parents = True, exist_ok = True) hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr} self.accelerator.init_trackers("soundstream", config=hps) # Initialize experiment trackers if an external Accelerator is not passed in if not accelerator: self.accelerator.init_trackers("soundstream", config=hyperparameters) def set_model_as_ema_model_(self): """ this will force the main 'online' model to have same parameters as the exponentially moving averaged model """ Loading