Commit 6bd8b8fa authored by Phil Wang's avatar Phil Wang
Browse files

first pass at stft discriminator

parent d9c6ba40
Loading
Loading
Loading
Loading
+83 −3
Original line number Diff line number Diff line
@@ -24,7 +24,7 @@ def hinge_gen_loss(fake):
def leaky_relu(p = 0.1):
    """
    Build a LeakyReLU activation module.

    p: negative slope applied to inputs below zero (defaults to 0.1).
    """
    # forward `p` to the module - the slope used to be hard-coded to 0.1,
    # which silently discarded whatever the caller passed in
    return nn.LeakyReLU(p)

# gans
# discriminators

class MultiScaleDiscriminator(nn.Module):
    def __init__(
@@ -65,6 +65,64 @@ class MultiScaleDiscriminator(nn.Module):

        return self.final_conv(x)

class ComplexLeakyReLU(nn.Module):
    """
    Leaky ReLU for complex tensors.

    For now the nonlinearity is simply applied to the real and imaginary
    components independently, and the two results are recombined into a
    complex tensor.

    p: negative slope handed to `leaky_relu`.
    """
    def __init__(self, p = 0.1):
        super().__init__()
        self.nonlin = leaky_relu(p)

    def forward(self, x):
        """Apply the nonlinearity to `x.real` and `x.imag` separately and return a complex tensor."""
        real, imag = map(self.nonlin, (x.real, x.imag))

        # torch.view_as_complex reads index 0 of the last dimension as the real part
        # and index 1 as the imaginary part, so the stacking order must be (real, imag);
        # stacking (imag, real) would swap the two components on every call
        return torch.view_as_complex(torch.stack((real, imag), dim = -1))

def STFTResidualUnit(chan_in, chan_out, strides):
    """
    Build a complex-valued convolutional unit as an `nn.Sequential`:

        3x3 conv (chan_in -> chan_in, padding 1)
        -> ComplexLeakyReLU
        -> strided conv (chan_in -> chan_out)

    chan_in: number of input channels
    chan_out: number of output channels of the strided conv
    strides: per-dimension strides of the second conv; each kernel dimension is
             the matching stride plus 2, and each padding is that kernel size
             floor-divided by 2

    NOTE(review): despite the name, no skip connection is added in this function.
    """
    down_kernel = tuple(stride + 2 for stride in strides)
    down_padding = tuple(size // 2 for size in down_kernel)

    # modules are created in the same order they run in
    same_conv = nn.Conv2d(chan_in, chan_in, 3, padding = 1, dtype = torch.complex64)
    activation = ComplexLeakyReLU()
    down_conv = nn.Conv2d(
        chan_in,
        chan_out,
        down_kernel,
        stride = strides,
        padding = down_padding,
        dtype = torch.complex64
    )

    return nn.Sequential(same_conv, activation, down_conv)

class STFTDiscriminator(nn.Module):
    """
    Discriminator that operates on the complex STFT of a waveform.

    The waveform is transformed with `torch.stft` (n_fft = 256), given a channel
    dimension of size 1, then passed through a complex 7x7 conv, a stack of
    `STFTResidualUnit` layers and a final complex conv with one output channel.

    channels: base channel count produced by the initial conv
    strides: one 2d stride per layer, zipped with the channel pairs below
    chan_mults: per-layer multipliers of `channels` giving each layer's output channels
    input_channels: input channels of the initial conv; `forward` always feeds a
                    single-channel spectrogram, so only 1 matches what it produces
    """
    def __init__(
        self,
        *,
        channels = 32,
        strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)),
        chan_mults = (1, 2, 4, 4, 8, 8),
        input_channels = 1
    ):
        super().__init__()
        self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3, dtype = torch.complex64)

        # consecutive (chan_in, chan_out) pairs: channels -> channels * mult_1 -> channels * mult_2 ...
        layer_channels = tuple(map(lambda mult: mult * channels, chan_mults))
        layer_channels = (channels, *layer_channels)
        layer_channels_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:]))

        self.layers = nn.ModuleList([])

        # zip stops at the shorter of `strides` and the channel pairs
        for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs):
            self.layers.append(STFTResidualUnit(chan_in, chan_out, layer_stride))

        self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1), dtype = torch.complex64) # todo: remove hardcoded 16

    def forward(self, x):
        """
        x: waveform whose second dimension has size 1 (pattern 'b 1 n')

        Returns the complex-valued logits from the final conv.
        """
        x = rearrange(x, 'b 1 n -> b n')

        # request the complex spectrogram directly - the implicit real-valued (..., 2)
        # output that had to go through view_as_complex is deprecated, and recent
        # PyTorch versions require return_complex to be given for real inputs
        x = torch.stft(x, 256, return_complex = True)

        # add the singleton channel dimension expected by the 2d convs
        x = rearrange(x, 'b ... -> b 1 ...')

        x = self.init_conv(x)

        for layer in self.layers:
            x = layer(x)

        complex_logits = self.final_conv(x)
        return complex_logits

# sound stream

class Residual(nn.Module):
@@ -148,6 +206,7 @@ class SoundStream(nn.Module):
        adversarial_loss_weight = 1.
    ):
        super().__init__()
        self.single_channel = input_channels == 1

        layer_channels = tuple(map(lambda t: t * channels, channel_mults))
        layer_channels = (channels, *layer_channels)
@@ -189,6 +248,8 @@ class SoundStream(nn.Module):
        self.discr_multi_scales = discr_multi_scales
        self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))])

        self.stft_discriminator = STFTDiscriminator()

        # loss weights

        self.recon_loss_weight = recon_loss_weight
@@ -197,7 +258,8 @@ class SoundStream(nn.Module):
    def forward(
        self,
        x,
        return_discr_loss = False
        return_discr_loss = False,
        return_stft_discr_loss = False
    ):
        if x.ndim == 2:
            x = rearrange(x, 'b n -> b 1 n')
@@ -212,6 +274,15 @@ class SoundStream(nn.Module):

        recon_x = self.decoder(x)

        # stft discr loss

        if return_stft_discr_loss:
            assert self.single_channel
            real, fake = orig_x, recon_x.detach()
            stft_real_logits, stft_fake_logits = map(self.stft_discriminator, (real, fake))
            stft_discr_loss = (hinge_discr_loss(stft_fake_logits.real, stft_real_logits.real) + hinge_discr_loss(stft_fake_logits.imag, stft_real_logits.imag)) / 2
            return stft_discr_loss

        # multi-scale discriminator loss

        if return_discr_loss:
@@ -231,16 +302,25 @@ class SoundStream(nn.Module):

        recon_loss = F.mse_loss(orig_x, recon_x)

        # generator loss
        # adversarial loss

        adversarial_losses = []

        # adversarial loss for multi-scale discriminators

        for discr, scale in zip(self.discriminators, self.discr_multi_scales):
            scaled_fake = F.interpolate(recon_x, scale_factor = scale)
            fake_logits = discr(scaled_fake)
            one_adversarial_loss = hinge_gen_loss(fake_logits)
            adversarial_losses.append(one_adversarial_loss)

        # adversarial loss for stft discriminator

        stft_fake_logits = self.stft_discriminator(recon_x)

        adversarial_losses.append(hinge_gen_loss(stft_fake_logits.real))
        adversarial_losses.append(hinge_gen_loss(stft_fake_logits.imag))

        adversarial_loss = torch.stack(adversarial_losses).mean()

        return recon_loss * self.recon_loss_weight + adversarial_loss * self.adversarial_loss_weight