Commit 4acee04e authored by Phil Wang's avatar Phil Wang
Browse files

complete the multi-scale discriminators for soundstream

parent ee51e21c
Loading
Loading
Loading
Loading
+97 −4
Original line number Diff line number Diff line
@@ -21,6 +21,50 @@ def hinge_discr_loss(fake, real):
def hinge_gen_loss(fake):
    return -fake.mean()

def leaky_relu(p = 0.1):
    return nn.LeakyReLU(0.1)

# gans

class MultiScaleDiscriminator(nn.Module):
    def __init__(
        self,
        channels = 16,
        layers = 4,
        groups = 4,
        chan_max = 1024,
        input_channels = 1
    ):
        super().__init__()
        self.init_conv = nn.Conv1d(input_channels, channels, 7)
        self.conv_layers = nn.ModuleList([])

        curr_channels = channels

        for _ in range(layers):
            chan_out = min(curr_channels * 4, chan_max)

            self.conv_layers.append(nn.Sequential(
                nn.Conv1d(curr_channels, chan_out, 8, stride = 4, padding = 4),
                leaky_relu()
            ))

            curr_channels = chan_out

        self.final_conv = nn.Sequential(
            nn.Conv1d(curr_channels, curr_channels, 3),
            leaky_relu(),
            nn.Conv1d(curr_channels, 1, 1),
        )

    def forward(self, x):
        x = self.init_conv(x)

        for layer in self.conv_layers:
            x = layer(x)

        return self.final_conv(x)

# sound stream

class Residual(nn.Module):
@@ -47,7 +91,9 @@ class CausalConv1d(nn.Module):
def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7):
    return Residual(nn.Sequential(
        CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation),
        CausalConv1d(chan_out, chan_out, 1)
        nn.ELU(),
        CausalConv1d(chan_out, chan_out, 1),
        nn.ELU()
    ))

def EncoderBlock(chan_in, chan_out, stride):
@@ -80,7 +126,10 @@ class SoundStream(nn.Module):
        codebook_dim = 512,
        codebook_size = 1024,
        rq_num_quantizers = 8,
        input_channels = 1
        input_channels = 1,
        discr_multi_scales = (1, 0.5, 0.25),
        recon_loss_weight = 1.,
        adversarial_loss_weight = 1.
    ):
        super().__init__()

@@ -119,7 +168,21 @@ class SoundStream(nn.Module):
            CausalConv1d(channels, input_channels, 7)
        )

    def forward(self, x):
        # discriminators

        self.discr_multi_scales = discr_multi_scales
        self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))])

        # loss weights

        self.recon_loss_weight = recon_loss_weight
        self.adversarial_loss_weight = adversarial_loss_weight

    def forward(
        self,
        x,
        return_discr_loss = False
    ):
        if x.ndim == 2:
            x = rearrange(x, 'b n -> b 1 n')

@@ -133,8 +196,38 @@ class SoundStream(nn.Module):

        recon_x = self.decoder(x)

        # multi-scale discriminator loss

        if return_discr_loss:
            real, fake = orig_x, recon_x.detach()
            discr_losses = []

            for discr, scale in zip(self.discriminators, self.discr_multi_scales):
                scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake))

                real_logits, fake_logits = map(discr, (scaled_real, scaled_fake))
                one_discr_loss = hinge_discr_loss(fake_logits, real_logits)
                discr_losses.append(one_discr_loss)

            return torch.stack(discr_losses).mean()

        # recon loss

        recon_loss = F.mse_loss(orig_x, recon_x)
        return recon_loss

        # generator loss

        adversarial_losses = []

        for discr, scale in zip(self.discriminators, self.discr_multi_scales):
            scaled_fake = F.interpolate(recon_x, scale_factor = scale)
            fake_logits = discr(scaled_fake)
            one_adversarial_loss = hinge_gen_loss(fake_logits)
            adversarial_losses.append(one_adversarial_loss)

        adversarial_loss = torch.stack(adversarial_losses).mean()

        return recon_loss * self.recon_loss_weight + adversarial_loss * self.adversarial_loss_weight

# relative positional bias