Commit d0e1f681 authored by Phil Wang's avatar Phil Wang
Browse files

add the "feature" loss as mentioned in the paper with customizable weight

parent 6810d515
Loading
Loading
Loading
Loading
+32 −6
Original line number Diff line number Diff line
import math
from functools import partial

import torch
from torch import nn, einsum
@@ -57,13 +58,21 @@ class MultiScaleDiscriminator(nn.Module):
            nn.Conv1d(curr_channels, 1, 1),
        )

    def forward(self, x):
    def forward(self, x, return_intermediates = False):
        x = self.init_conv(x)

        intermediates = []

        for layer in self.conv_layers:
            x = layer(x)
            intermediates.append(x)

        return self.final_conv(x)
        out = self.final_conv(x)

        if not return_intermediates:
            return out

        return out, intermediates

class ComplexLeakyReLU(nn.Module):
    """ just do nonlinearity on imag and real component separately for now """
@@ -202,7 +211,8 @@ class SoundStream(nn.Module):
        input_channels = 1,
        discr_multi_scales = (1, 0.5, 0.25),
        recon_loss_weight = 1.,
        adversarial_loss_weight = 1.
        adversarial_loss_weight = 1.,
        feature_loss_weight = 100
    ):
        super().__init__()
        self.single_channel = input_channels == 1
@@ -253,6 +263,7 @@ class SoundStream(nn.Module):

        self.recon_loss_weight = recon_loss_weight
        self.adversarial_loss_weight = adversarial_loss_weight
        self.feature_loss_weight = feature_loss_weight

    def forward(
        self,
@@ -305,14 +316,29 @@ class SoundStream(nn.Module):

        adversarial_losses = []

        discr_intermediates = []

        # adversarial loss for multi-scale discriminators

        real, fake = orig_x, recon_x

        for discr, scale in zip(self.discriminators, self.discr_multi_scales):
            scaled_fake = F.interpolate(recon_x, scale_factor = scale)
            fake_logits = discr(scaled_fake)
            scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake))
            (real_logits, real_intermediates), (fake_logits, fake_intermediates) = map(partial(discr, return_intermediates = True), (scaled_real, scaled_fake))

            discr_intermediates.append((real_intermediates, fake_intermediates))

            one_adversarial_loss = hinge_gen_loss(fake_logits)
            adversarial_losses.append(one_adversarial_loss)

        feature_losses = []

        for real_intermediates, fake_intermediates in discr_intermediates:
            losses = [F.mse_loss(real_intermediate, fake_intermediate) for real_intermediate, fake_intermediate in zip(real_intermediates, fake_intermediates)]
            feature_losses.extend(losses)

        feature_loss = torch.stack(feature_losses).mean()

        # adversarial loss for stft discriminator

        stft_fake_logits = self.stft_discriminator(recon_x)
@@ -322,7 +348,7 @@ class SoundStream(nn.Module):

        adversarial_loss = torch.stack(adversarial_losses).mean()

        return recon_loss * self.recon_loss_weight + adversarial_loss * self.adversarial_loss_weight
        return recon_loss * self.recon_loss_weight + adversarial_loss * self.adversarial_loss_weight + feature_loss * self.feature_loss_weight

# relative positional bias