Commit 6df317f8 authored by Hayk Martiros's avatar Hayk Martiros
Browse files

Allow passing custom DataLoaders to SoundStreamTrainer

This allows customizing the data loading from the default provided
by SoundDataset, which is a folder of audio files on disk. For example,
it can allow someone to connect a WebDataset loader instead.

Also adds types to SoundStreamTrainer's constructor.
parent acb37195
Loading
Loading
Loading
Loading
+75 −48
Original line number Diff line number Diff line
@@ -118,33 +118,39 @@ class SoundStreamTrainer(nn.Module):
        self,
        soundstream: SoundStream,
        *,
        num_train_steps,
        batch_size,
        data_max_length = None,
        data_max_length_seconds = None,
        folder,
        lr = 2e-4,
        grad_accum_every = 4,
        wd = 0.,
        max_grad_norm = 0.5,
        discr_max_grad_norm = None,
        save_results_every = 100,
        save_model_every = 1000,
        log_losses_every = 1,
        results_folder = './results',
        valid_frac = 0.05,
        random_split_seed = 42,
        use_ema = True,
        ema_beta = 0.995,
        ema_update_after_step = 500,
        ema_update_every = 10,
        apply_grad_penalty_every = 4,
        dl_num_workers = 0,
        num_train_steps: int,
        batch_size: int,
        data_max_length: int = None,
        data_max_length_seconds: float = None,
        folder: str = None,
        train_dataloader: DataLoader = None,
        val_dataloader: DataLoader = None,
        lr: float = 2e-4,
        grad_accum_every: int = 4,
        wd: float = 0.,
        max_grad_norm: float = 0.5,
        discr_max_grad_norm: float = None,
        save_results_every: int = 100,
        save_model_every: int= 1000,
        log_losses_every: int= 1,
        results_folder: str = './results',
        valid_frac: float = 0.05,
        random_split_seed: int = 42,
        use_ema: bool = True,
        ema_beta: float = 0.995,
        ema_update_after_step: int = 500,
        ema_update_every: int = 10,
        apply_grad_penalty_every: int = 4,
        dl_num_workers: int = 0,
        accelerator: Accelerator = None,
        accelerate_kwargs: dict = dict(),
        use_lion = False,
        force_clear_prev_results = None  # set to True | False to skip the prompt
        use_lion: bool = False,
        force_clear_prev_results: bool = None  # set to True | False to skip the prompt
    ):
        """
        Initialize with a SoundStream instance and either a folder containing audio data or
        train/val DataLoader instances.
        """
        super().__init__()

        if accelerator:
@@ -166,6 +172,14 @@ class SoundStreamTrainer(nn.Module):
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        hyperparameters = {
            "num_train_steps": num_train_steps,
            "batch_size": batch_size,
            "gradient_accum_every": grad_accum_every,
            "learning_rate": lr,
            "target_sample_hz": soundstream.target_sample_hz,
        }

        # optimizers

        self.optim = get_optimizer(soundstream.non_discr_parameters(), lr = lr, wd = wd)
@@ -181,12 +195,24 @@ class SoundStreamTrainer(nn.Module):
        self.max_grad_norm = max_grad_norm
        self.discr_max_grad_norm = discr_max_grad_norm

        # create dataset
        if folder is None:
            assert train_dataloader is not None
            assert val_dataloader is not None
            self.dl = train_dataloader
            self.valid_dl = val_dataloader
        else:
            assert train_dataloader is None
            assert val_dataloader is None

        assert not (exists(data_max_length) and exists(data_max_length_seconds))
            # create dataset

            if exists(data_max_length_seconds):
            data_max_length = data_max_length_seconds * soundstream.target_sample_hz
                assert not exists(data_max_length)
                data_max_length = int(data_max_length_seconds * soundstream.target_sample_hz)
            else:
                assert exists(data_max_length)

            hyperparameters['data_max_length'] = data_max_length

            self.ds = SoundDataset(
                folder,
@@ -251,8 +277,9 @@ class SoundStreamTrainer(nn.Module):

        self.results_folder.mkdir(parents = True, exist_ok = True)

        hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr}
        self.accelerator.init_trackers("soundstream", config=hps)        
        # Initialize experiment trackers if an external Accelerator is not passed in
        if not accelerator:
            self.accelerator.init_trackers("soundstream", config=hyperparameters)        

    def set_model_as_ema_model_(self):
        """ this will force the main 'online' model to have same parameters as the exponentially moving averaged model """