Unverified Commit 5afd7651 authored by Phil Wang's avatar Phil Wang Committed by GitHub
Browse files

Merge pull request #145 from hmartiro/soundstream_custom_dataloader

Allow passing custom DataLoaders to SoundStreamTrainer
parents acb37195 6df317f8
Loading
Loading
Loading
Loading
+75 −48
Original line number Diff line number Diff line
@@ -118,33 +118,39 @@ class SoundStreamTrainer(nn.Module):
        self,
        soundstream: SoundStream,
        *,
        num_train_steps,
        batch_size,
        data_max_length = None,
        data_max_length_seconds = None,
        folder,
        lr = 2e-4,
        grad_accum_every = 4,
        wd = 0.,
        max_grad_norm = 0.5,
        discr_max_grad_norm = None,
        save_results_every = 100,
        save_model_every = 1000,
        log_losses_every = 1,
        results_folder = './results',
        valid_frac = 0.05,
        random_split_seed = 42,
        use_ema = True,
        ema_beta = 0.995,
        ema_update_after_step = 500,
        ema_update_every = 10,
        apply_grad_penalty_every = 4,
        dl_num_workers = 0,
        num_train_steps: int,
        batch_size: int,
        data_max_length: int = None,
        data_max_length_seconds: float = None,
        folder: str = None,
        train_dataloader: DataLoader = None,
        val_dataloader: DataLoader = None,
        lr: float = 2e-4,
        grad_accum_every: int = 4,
        wd: float = 0.,
        max_grad_norm: float = 0.5,
        discr_max_grad_norm: float = None,
        save_results_every: int = 100,
        save_model_every: int= 1000,
        log_losses_every: int= 1,
        results_folder: str = './results',
        valid_frac: float = 0.05,
        random_split_seed: int = 42,
        use_ema: bool = True,
        ema_beta: float = 0.995,
        ema_update_after_step: int = 500,
        ema_update_every: int = 10,
        apply_grad_penalty_every: int = 4,
        dl_num_workers: int = 0,
        accelerator: Accelerator = None,
        accelerate_kwargs: dict = dict(),
        use_lion = False,
        force_clear_prev_results = None  # set to True | False to skip the prompt
        use_lion: bool = False,
        force_clear_prev_results: bool = None  # set to True | False to skip the prompt
    ):
        """
        Initialize with a SoundStream instance and either a folder containing audio data or
        train/val DataLoader instances.
        """
        super().__init__()

        if accelerator:
@@ -166,6 +172,14 @@ class SoundStreamTrainer(nn.Module):
        self.batch_size = batch_size
        self.grad_accum_every = grad_accum_every

        hyperparameters = {
            "num_train_steps": num_train_steps,
            "batch_size": batch_size,
            "gradient_accum_every": grad_accum_every,
            "learning_rate": lr,
            "target_sample_hz": soundstream.target_sample_hz,
        }

        # optimizers

        self.optim = get_optimizer(soundstream.non_discr_parameters(), lr = lr, wd = wd)
@@ -181,12 +195,24 @@ class SoundStreamTrainer(nn.Module):
        self.max_grad_norm = max_grad_norm
        self.discr_max_grad_norm = discr_max_grad_norm

        # create dataset
        if folder is None:
            assert train_dataloader is not None
            assert val_dataloader is not None
            self.dl = train_dataloader
            self.valid_dl = val_dataloader
        else:
            assert train_dataloader is None
            assert val_dataloader is None

        assert not (exists(data_max_length) and exists(data_max_length_seconds))
            # create dataset

            if exists(data_max_length_seconds):
            data_max_length = data_max_length_seconds * soundstream.target_sample_hz
                assert not exists(data_max_length)
                data_max_length = int(data_max_length_seconds * soundstream.target_sample_hz)
            else:
                assert exists(data_max_length)

            hyperparameters['data_max_length'] = data_max_length

            self.ds = SoundDataset(
                folder,
@@ -251,8 +277,9 @@ class SoundStreamTrainer(nn.Module):

        self.results_folder.mkdir(parents = True, exist_ok = True)

        hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr}
        self.accelerator.init_trackers("soundstream", config=hps)        
        # Initialize experiment trackers if an external Accelerator is not passed in
        if not accelerator:
            self.accelerator.init_trackers("soundstream", config=hyperparameters)        

    def set_model_as_ema_model_(self):
        """ this will force the main 'online' model to have same parameters as the exponentially moving averaged model """