Unverified Commit 532d8e95 authored by Phil Wang's avatar Phil Wang Committed by GitHub
Browse files

Merge pull request #9 from aabzaliev/main

Include STFT discriminator features in the feature loss
parents d0e1f681 42b334cf
Loading
Loading
Loading
Loading
+16 −4
Original line number Diff line number Diff line
@@ -119,19 +119,27 @@ class STFTDiscriminator(nn.Module):

        self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1), dtype = torch.complex64) # todo: remove hardcoded 16

    def forward(self, x):
    def forward(self, x, return_intermediates = False):
        """
        Run the STFT discriminator on a waveform and return complex-valued logits.

        x is rearranged with the pattern 'b 1 n -> b n', so it must carry a
        singleton second axis.

        When return_intermediates is False (the default) only the output of
        final_conv is returned. When it is True, a tuple
        (complex_logits, intermediates) is returned, where intermediates is a
        list holding the output of init_conv followed by the output of every
        module in self.layers, in order.
        """
        x = rearrange(x, 'b 1 n -> b n')

        # request the complex spectrogram directly - torch.stft on a real input
        # without an explicit return_complex raises on current PyTorch, and the
        # real-valued (..., 2) output that view_as_complex relied on is deprecated
        # the resulting tensor is the same as view_as_complex of the real output
        x = torch.stft(x, 256, return_complex = True) # todo: remove hardcoded n_fft of 256

        # add a channel axis right after batch for the 2d convolutions
        x = rearrange(x, 'b ... -> b 1 ...')

        intermediates = []

        x = self.init_conv(x)
        intermediates.append(x)

        for layer in self.layers:
            x = layer(x)
            intermediates.append(x)

        complex_logits = self.final_conv(x)

        if not return_intermediates:
            return complex_logits

        return complex_logits, intermediates

# sound stream

class Residual(nn.Module):
@@ -318,10 +326,15 @@ class SoundStream(nn.Module):

        discr_intermediates = []


        # adversarial loss for multi-scale discriminators

        real, fake = orig_x, recon_x

        # features from stft
        (stft_real_logits, stft_real_intermediates), (stft_fake_logits, stft_fake_intermediates) = map(partial(self.stft_discriminator, return_intermediates=True), (real, fake))
        discr_intermediates.append((stft_real_intermediates, stft_fake_intermediates))

        for discr, scale in zip(self.discriminators, self.discr_multi_scales):
            scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake))
            (real_logits, real_intermediates), (fake_logits, fake_intermediates) = map(partial(discr, return_intermediates = True), (scaled_real, scaled_fake))
@@ -337,12 +350,11 @@ class SoundStream(nn.Module):
            losses = [F.mse_loss(real_intermediate, fake_intermediate) for real_intermediate, fake_intermediate in zip(real_intermediates, fake_intermediates)]
            feature_losses.extend(losses)


        feature_loss = torch.stack(feature_losses).mean()

        # adversarial loss for stft discriminator

        stft_fake_logits = self.stft_discriminator(recon_x)

        adversarial_losses.append(hinge_gen_loss(stft_fake_logits.real))
        adversarial_losses.append(hinge_gen_loss(stft_fake_logits.imag))