Commit 8910dbfe authored by Phil Wang
Browse files

use l1 loss for "feature" loss

parent 532d8e95
Loading
Loading
Loading
Loading
+2 −3
Original line number Diff line number Diff line
@@ -326,12 +326,12 @@ class SoundStream(nn.Module):

        discr_intermediates = []


        # adversarial loss for multi-scale discriminators

        real, fake = orig_x, recon_x

        # features from stft

        (stft_real_logits, stft_real_intermediates), (stft_fake_logits, stft_fake_intermediates) = map(partial(self.stft_discriminator, return_intermediates=True), (real, fake))
        discr_intermediates.append((stft_real_intermediates, stft_fake_intermediates))

@@ -347,10 +347,9 @@ class SoundStream(nn.Module):
        feature_losses = []

        for real_intermediates, fake_intermediates in discr_intermediates:
-            losses = [F.mse_loss(real_intermediate, fake_intermediate) for real_intermediate, fake_intermediate in zip(real_intermediates, fake_intermediates)]
+            losses = [F.l1_loss(real_intermediate, fake_intermediate) for real_intermediate, fake_intermediate in zip(real_intermediates, fake_intermediates)]
            feature_losses.extend(losses)


        feature_loss = torch.stack(feature_losses).mean()

        # adversarial loss for stft discriminator
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
-  version = '0.0.1',
+  version = '0.0.2',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',