Commit 47871901 authored by Phil Wang
Browse files

address equations 4 and 5 in the soundstream paper - "multi spectral recon loss"

address equations 4 and 5 in the soundstream paper - "multi spectral recon loss", raised in issue https://github.com/lucidrains/audiolm-pytorch/issues/76
parent c229f62a
Loading
Loading
Loading
Loading
+40 −1
Original line number Diff line number Diff line
@@ -30,6 +30,9 @@ def default(val, d):

# gan losses

def log(t, eps = 1e-20):
    """Return the elementwise natural log of `t`, flooring inputs at `eps`.

    Values of `t` below `eps` are raised to `eps` before the logarithm is
    taken, so zeros and negatives yield `log(eps)` instead of -inf / nan.
    """
    floored = torch.clamp(t, min = eps)
    return torch.log(floored)

def hinge_discr_loss(fake, real):
    """Return the mean of `relu(1 + fake) + relu(1 - real)`.

    The `fake` term is non-zero where `fake > -1`, and the `real` term is
    non-zero where `real < 1`; the two are summed and averaged to a scalar.
    """
    fake_term = F.relu(1 + fake)
    real_term = F.relu(1 - real)
    return (fake_term + real_term).mean()

@@ -338,7 +341,10 @@ class SoundStream(nn.Module):
        discr_multi_scales = (1, 0.5, 0.25),
        enc_cycle_dilations = (1, 3, 9),
        dec_cycle_dilations = (1, 3, 9),
        multi_spectral_window_powers_of_two = tuple(range(6, 12)),
        num_mel_bins = 64,
        recon_loss_weight = 1.,
        multi_spectral_recon_loss_weight = 1.,
        adversarial_loss_weight = 1.,
        feature_loss_weight = 100,
        quantize_dropout_cutoff_index = 1,
@@ -426,9 +432,32 @@ class SoundStream(nn.Module):

        self.stft_discriminator = STFTDiscriminator()

        # multi spectral reconstruction

        self.mel_spec_transforms = nn.ModuleList([])
        self.mel_spec_recon_alphas = []

        max_win_length = 2 ** max(multi_spectral_window_powers_of_two)

        for powers in multi_spectral_window_powers_of_two:
            win_length = 2 ** powers
            alpha = (win_length / 2) ** 0.5

            melspec_transform = T.MelSpectrogram(
                sample_rate = target_sample_hz,
                n_fft = max_win_length,
                win_length = win_length,
                hop_length = win_length // 4,
                n_mels = num_mel_bins
            )

            self.mel_spec_transforms.append(melspec_transform)
            self.mel_spec_recon_alphas.append(alpha)

        # loss weights

        self.recon_loss_weight = recon_loss_weight
        self.multi_spectral_recon_loss_weight = multi_spectral_recon_loss_weight
        self.adversarial_loss_weight = adversarial_loss_weight
        self.feature_loss_weight = feature_loss_weight

@@ -558,6 +587,16 @@ class SoundStream(nn.Module):

        recon_loss = F.mse_loss(orig_x, recon_x)

        # multispectral recon loss - eq (4) and (5) in https://arxiv.org/abs/2107.03312

        multi_spectral_recon_loss = 0

        for mel_transform, alpha in zip(self.mel_spec_transforms, self.mel_spec_recon_alphas):
            orig_mel, recon_mel = map(mel_transform, (orig_x, recon_x))
            log_orig_mel, log_recon_mel = map(log, (orig_mel, recon_mel))

            multi_spectral_recon_loss = multi_spectral_recon_loss + (orig_mel - recon_mel).abs().sum() + alpha * ((log_orig_mel - log_recon_mel) ** 2).sum()

        # adversarial loss

        adversarial_losses = []
@@ -599,7 +638,7 @@ class SoundStream(nn.Module):

        all_commitment_loss = commit_loss.sum()

        total_loss = recon_loss * self.recon_loss_weight + adversarial_loss * self.adversarial_loss_weight + feature_loss * self.feature_loss_weight + all_commitment_loss
        total_loss = recon_loss * self.recon_loss_weight + multi_spectral_recon_loss * self.multi_spectral_recon_loss_weight + adversarial_loss * self.adversarial_loss_weight + feature_loss * self.feature_loss_weight + all_commitment_loss

        if return_loss_breakdown:
            return total_loss, (recon_loss, adversarial_loss, feature_loss, all_commitment_loss)
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.10.4',
  version = '0.11.0',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',