Unverified Commit 6734f7f8 authored by Phil Wang's avatar Phil Wang Committed by GitHub
Browse files

Merge pull request #142 from hmartiro/accelerate_arg

Option to pass in Accelerator to SoundStreamTrainer
parents 6d943ebe b04d207d
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -36,7 +36,7 @@ class SoundDataset(Dataset):
    def __init__(
        self,
        folder,
        exts = ['flac', 'wav'],
        exts = ['flac', 'wav', 'mp3', 'webm'],
        max_length: OptionalIntOrTupleInt = None,
        target_sample_hz: OptionalIntOrTupleInt = None,
        seq_len_multiple_of: OptionalIntOrTupleInt = None
+7 −2
Original line number Diff line number Diff line
@@ -140,12 +140,17 @@ class SoundStreamTrainer(nn.Module):
        ema_update_every = 10,
        apply_grad_penalty_every = 4,
        dl_num_workers = 0,
        accelerator: Accelerator = None,
        accelerate_kwargs: dict = dict(),
        use_lion = False,
        force_clear_prev_results = None  # set to True | False to skip the prompt
    ):
        super().__init__()

        if accelerator:
            self.accelerator = accelerator
            assert len(accelerate_kwargs) == 0
        else:
            kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
            self.accelerator = Accelerator(kwargs_handlers = [kwargs], **accelerate_kwargs)