Commit 15f83db0 authored by Ilya Borovik's avatar Ilya Borovik
Browse files

apply the same gradient penalty schedule for stft and multiscale discriminators

parent 151edc9a
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -585,6 +585,7 @@ class SoundStream(nn.Module):
                one_discr_loss = hinge_discr_loss(fake_logits, real_logits)

                discr_losses.append(one_discr_loss)
                if apply_grad_penalty:
                    discr_grad_penalties.append(gradient_penalty(scaled_real, one_discr_loss))

            if not return_discr_losses_separately:
+1 −1
Original line number Diff line number Diff line
@@ -299,7 +299,7 @@ class SoundStreamTrainer(nn.Module):
        device = self.device

        steps = int(self.steps.item())
        apply_grad_penalty = not (steps % self.apply_grad_penalty_every)
        apply_grad_penalty = self.apply_grad_penalty_every > 0 and not (steps % self.apply_grad_penalty_every)
        log_losses = self.log_losses_every > 0 and not (steps % self.log_losses_every)

        self.soundstream.train()