Loading audiolm_pytorch/data.py +1 −1 Original line number Diff line number Diff line Loading @@ -36,7 +36,7 @@ class SoundDataset(Dataset): def __init__( self, folder, exts = ['flac', 'wav'], exts = ['flac', 'wav', 'mp3', 'webm'], max_length: OptionalIntOrTupleInt = None, target_sample_hz: OptionalIntOrTupleInt = None, seq_len_multiple_of: OptionalIntOrTupleInt = None Loading audiolm_pytorch/trainer.py +7 −2 Original line number Diff line number Diff line Loading @@ -140,12 +140,17 @@ class SoundStreamTrainer(nn.Module): ema_update_every = 10, apply_grad_penalty_every = 4, dl_num_workers = 0, accelerator: Accelerator = None, accelerate_kwargs: dict = dict(), use_lion = False, force_clear_prev_results = None # set to True | False to skip the prompt ): super().__init__() if accelerator: self.accelerator = accelerator assert len(accelerate_kwargs) == 0 else: kwargs = DistributedDataParallelKwargs(find_unused_parameters = True) self.accelerator = Accelerator(kwargs_handlers = [kwargs], **accelerate_kwargs) Loading Loading
audiolm_pytorch/data.py +1 −1 Original line number Diff line number Diff line Loading @@ -36,7 +36,7 @@ class SoundDataset(Dataset): def __init__( self, folder, exts = ['flac', 'wav'], exts = ['flac', 'wav', 'mp3', 'webm'], max_length: OptionalIntOrTupleInt = None, target_sample_hz: OptionalIntOrTupleInt = None, seq_len_multiple_of: OptionalIntOrTupleInt = None Loading
audiolm_pytorch/trainer.py +7 −2 Original line number Diff line number Diff line Loading @@ -140,12 +140,17 @@ class SoundStreamTrainer(nn.Module): ema_update_every = 10, apply_grad_penalty_every = 4, dl_num_workers = 0, accelerator: Accelerator = None, accelerate_kwargs: dict = dict(), use_lion = False, force_clear_prev_results = None # set to True | False to skip the prompt ): super().__init__() if accelerator: self.accelerator = accelerator assert len(accelerate_kwargs) == 0 else: kwargs = DistributedDataParallelKwargs(find_unused_parameters = True) self.accelerator = Accelerator(kwargs_handlers = [kwargs], **accelerate_kwargs) Loading