Unverified commit 5fa0e3c6, authored by Phil Wang, committed by GitHub
Browse files

Merge pull request #93 from ilya16/trainer-tweaks

SoundStreamTrainer tweaks
parents 20ac524a 95ebad6c
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -588,6 +588,7 @@ class SoundStream(nn.Module):
                one_discr_loss = hinge_discr_loss(fake_logits, real_logits)

                discr_losses.append(one_discr_loss)
                if apply_grad_penalty:
                    discr_grad_penalties.append(gradient_penalty(scaled_real, one_discr_loss))

            if not return_discr_losses_separately:
+47 −16
Original line number Diff line number Diff line
@@ -125,9 +125,11 @@ class SoundStreamTrainer(nn.Module):
        discr_max_grad_norm = None,
        save_results_every = 100,
        save_model_every = 1000,
        log_losses_every = 1,
        results_folder = './results',
        valid_frac = 0.05,
        random_split_seed = 42,
        use_ema = True,
        ema_beta = 0.995,
        ema_update_after_step = 500,
        ema_update_every = 10,
@@ -143,6 +145,9 @@ class SoundStreamTrainer(nn.Module):
        self.accelerator = Accelerator(kwargs_handlers = [kwargs], **accelerate_kwargs)

        self.soundstream = soundstream

        self.use_ema = use_ema
        if self.use_ema:
            self.ema_soundstream = EMA(soundstream, beta = ema_beta, update_after_step = ema_update_after_step, update_every = ema_update_every)

        self.register_buffer('steps', torch.Tensor([0]))
@@ -220,6 +225,7 @@ class SoundStreamTrainer(nn.Module):

        self.save_model_every = save_model_every
        self.save_results_every = save_results_every
        self.log_losses_every = log_losses_every

        self.apply_grad_penalty_every = apply_grad_penalty_every

@@ -236,11 +242,13 @@ class SoundStreamTrainer(nn.Module):
    def save(self, path):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.soundstream),
            ema_model = self.ema_soundstream.state_dict(),
            optim = self.optim.state_dict(),
            discr_optim = self.discr_optim.state_dict()
        )

        if self.use_ema:
            pkg['ema_model'] = self.ema_soundstream.state_dict()

        for key, _ in self.multiscale_discriminator_iter():
            discr_optim = getattr(self, key)
            pkg[key] = discr_optim.state_dict()
@@ -258,7 +266,10 @@ class SoundStreamTrainer(nn.Module):

        self.unwrapped_soundstream.load_state_dict(pkg['model'])

        if self.use_ema:
            assert 'ema_model' in pkg
            self.ema_soundstream.load_state_dict(pkg['ema_model'])

        self.optim.load_state_dict(pkg['optim'])
        self.discr_optim.load_state_dict(pkg['discr_optim'])

@@ -297,7 +308,8 @@ class SoundStreamTrainer(nn.Module):
        device = self.device

        steps = int(self.steps.item())
        apply_grad_penalty = not (steps % self.apply_grad_penalty_every)
        apply_grad_penalty = self.apply_grad_penalty_every > 0 and not (steps % self.apply_grad_penalty_every)
        log_losses = self.log_losses_every > 0 and not (steps % self.log_losses_every)

        self.soundstream.train()

@@ -311,14 +323,21 @@ class SoundStreamTrainer(nn.Module):
            wave, = next(self.dl_iter)
            wave = wave.to(device)

            loss, (recon_loss, multi_spectral_recon_loss, *_) = self.soundstream(wave, return_loss_breakdown = True)
            loss, (recon_loss, multi_spectral_recon_loss, adversarial_loss, feature_loss, all_commitment_loss) = self.soundstream(wave, return_loss_breakdown = True)

            self.accelerator.backward(loss / self.grad_accum_every)

            accum_log(logs, dict(
                loss = loss.item() / self.grad_accum_every,
                recon_loss = recon_loss.item() / self.grad_accum_every,
                multi_spectral_recon_loss = multi_spectral_recon_loss.item() / self.grad_accum_every
            ))

            if log_losses:
                accum_log(logs, dict(
                    multi_spectral_recon_loss = multi_spectral_recon_loss.item() / self.grad_accum_every,
                    adversarial_loss = adversarial_loss.item() / self.grad_accum_every,
                    feature_loss = feature_loss.item() / self.grad_accum_every,
                    all_commitment_loss = all_commitment_loss.item() / self.grad_accum_every,
                ))

        if exists(self.max_grad_norm):
@@ -362,7 +381,16 @@ class SoundStreamTrainer(nn.Module):
        # build pretty printed losses

        losses_str = f"{steps}: soundstream total loss: {logs['loss']:.3f}, soundstream recon loss: {logs['recon_loss']:.3f}"
        self.accelerator.log({"total_loss": logs['loss'], "recon_loss": logs['recon_loss']}, step=steps)
        if log_losses:
            self.accelerator.log({
                "total_loss": logs['loss'],
                "recon_loss": logs['recon_loss'],
                "multi_spectral_recon_loss": logs['multi_spectral_recon_loss'],
                "adversarial_loss": logs['adversarial_loss'],
                "feature_loss": logs['feature_loss'],
                "all_commitment_loss": logs['all_commitment_loss'],
                "stft_discr_loss": logs['stft']
            }, step=steps)

        for key, loss in logs.items():
            if not key.startswith('scale:'):
@@ -370,6 +398,7 @@ class SoundStreamTrainer(nn.Module):
            _, scale_factor = key.split(':')

            losses_str += f" | discr (scale {scale_factor}) loss: {loss:.3f}"
            if log_losses:
                self.accelerator.log({f"discr_loss (scale {scale_factor})": loss}, step=steps)

        # log
@@ -380,7 +409,7 @@ class SoundStreamTrainer(nn.Module):

        self.accelerator.wait_for_everyone()

        if self.is_main:
        if self.is_main and self.use_ema:
            self.ema_soundstream.update()

        # sample results every so often
@@ -388,19 +417,21 @@ class SoundStreamTrainer(nn.Module):
        self.accelerator.wait_for_everyone()

        if self.is_main and not (steps % self.save_results_every):
            for model, filename in ((self.ema_soundstream.ema_model, f'{steps}.ema'), (self.unwrapped_soundstream, str(steps))):
                model.eval()
            models = [(self.unwrapped_soundstream, str(steps))]
            if self.use_ema:
                models.append((self.ema_soundstream.ema_model if self.use_ema else self.unwrapped_soundstream, f'{steps}.ema'))

            wave, = next(self.valid_dl_iter)
            wave = wave.to(device)

            for model, label in models:
                model.eval()

                with torch.no_grad():
                    recons = model(wave, return_recons_only = True)

                milestone = steps // self.save_results_every

                for ind, recon in enumerate(recons.unbind(dim = 0)):
                    filename = str(self.results_folder / f'sample_{steps}.flac')
                    filename = str(self.results_folder / f'sample_{label}.flac')
                    torchaudio.save(filename, recon.cpu().detach(), self.unwrapped_soundstream.target_sample_hz)

            self.print(f'{steps}: saving to {str(self.results_folder)}')