Commit b6ff57d9 authored by Phil Wang
Browse files

return the multi spectral recon loss

parent f84bf917
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -748,6 +748,6 @@ class SoundStream(nn.Module):
        total_loss = recon_loss * self.recon_loss_weight + multi_spectral_recon_loss * self.multi_spectral_recon_loss_weight + adversarial_loss * self.adversarial_loss_weight + feature_loss * self.feature_loss_weight + all_commitment_loss

        if return_loss_breakdown:
-            return total_loss, (recon_loss, adversarial_loss, feature_loss, all_commitment_loss)
+            return total_loss, (recon_loss, multi_spectral_recon_loss, adversarial_loss, feature_loss, all_commitment_loss)

        return total_loss
+3 −2
Original line number Diff line number Diff line
@@ -308,13 +308,14 @@ class SoundStreamTrainer(nn.Module):
            wave, = next(self.dl_iter)
            wave = wave.to(device)

-            loss, (recon_loss, *_) = self.soundstream(wave, return_loss_breakdown = True)
+            loss, (recon_loss, multi_spectral_recon_loss, *_) = self.soundstream(wave, return_loss_breakdown = True)

            self.accelerator.backward(loss / self.grad_accum_every)

            accum_log(logs, dict(
                loss = loss.item() / self.grad_accum_every,
-                recon_loss = recon_loss.item() / self.grad_accum_every
+                recon_loss = recon_loss.item() / self.grad_accum_every,
+                multi_spectral_recon_loss = multi_spectral_recon_loss.item() / self.grad_accum_every
            ))

        if exists(self.max_grad_norm):
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
-  version = '0.11.4',
+  version = '0.11.5',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',