Commit 151edc9a authored by Ilya Borovik
Browse files

Log all SoundStream losses at a configurable frequency (every `log_losses_every` steps)

parent 51cef900
Loading
Loading
Loading
Loading
+24 −4
Original line number Diff line number Diff line
@@ -125,6 +125,7 @@ class SoundStreamTrainer(nn.Module):
        discr_max_grad_norm = None,
        save_results_every = 100,
        save_model_every = 1000,
        log_losses_every = 1,
        results_folder = './results',
        valid_frac = 0.05,
        random_split_seed = 42,
@@ -220,6 +221,7 @@ class SoundStreamTrainer(nn.Module):

        self.save_model_every = save_model_every
        self.save_results_every = save_results_every
        self.log_losses_every = log_losses_every

        self.apply_grad_penalty_every = apply_grad_penalty_every

@@ -298,6 +300,7 @@ class SoundStreamTrainer(nn.Module):

        steps = int(self.steps.item())
        apply_grad_penalty = not (steps % self.apply_grad_penalty_every)
        log_losses = self.log_losses_every > 0 and not (steps % self.log_losses_every)

        self.soundstream.train()

@@ -311,14 +314,21 @@ class SoundStreamTrainer(nn.Module):
            wave, = next(self.dl_iter)
            wave = wave.to(device)

            loss, (recon_loss, multi_spectral_recon_loss, *_) = self.soundstream(wave, return_loss_breakdown = True)
            loss, (recon_loss, multi_spectral_recon_loss, adversarial_loss, feature_loss, all_commitment_loss) = self.soundstream(wave, return_loss_breakdown = True)

            self.accelerator.backward(loss / self.grad_accum_every)

            accum_log(logs, dict(
                loss = loss.item() / self.grad_accum_every,
                recon_loss = recon_loss.item() / self.grad_accum_every,
                multi_spectral_recon_loss = multi_spectral_recon_loss.item() / self.grad_accum_every
            ))

            if log_losses:
                accum_log(logs, dict(
                    multi_spectral_recon_loss = multi_spectral_recon_loss.item() / self.grad_accum_every,
                    adversarial_loss = adversarial_loss.item() / self.grad_accum_every,
                    feature_loss = feature_loss.item() / self.grad_accum_every,
                    all_commitment_loss = all_commitment_loss.item() / self.grad_accum_every,
                ))

        if exists(self.max_grad_norm):
@@ -362,7 +372,16 @@ class SoundStreamTrainer(nn.Module):
        # build pretty printed losses

        losses_str = f"{steps}: soundstream total loss: {logs['loss']:.3f}, soundstream recon loss: {logs['recon_loss']:.3f}"
        self.accelerator.log({"total_loss": logs['loss'], "recon_loss": logs['recon_loss']}, step=steps)
        if log_losses:
            self.accelerator.log({
                "total_loss": logs['loss'],
                "recon_loss": logs['recon_loss'],
                "multi_spectral_recon_loss": logs['multi_spectral_recon_loss'],
                "adversarial_loss": logs['adversarial_loss'],
                "feature_loss": logs['feature_loss'],
                "all_commitment_loss": logs['all_commitment_loss'],
                "stft_discr_loss": logs['stft']
            }, step=steps)

        for key, loss in logs.items():
            if not key.startswith('scale:'):
@@ -370,6 +389,7 @@ class SoundStreamTrainer(nn.Module):
            _, scale_factor = key.split(':')

            losses_str += f" | discr (scale {scale_factor}) loss: {loss:.3f}"
            if log_losses:
                self.accelerator.log({f"discr_loss (scale {scale_factor})": loss}, step=steps)

        # log