Loading audiolm_pytorch/audiolm_pytorch.py +97 −4 Original line number Diff line number Diff line Loading @@ -21,6 +21,50 @@ def hinge_discr_loss(fake, real): def hinge_gen_loss(fake): return -fake.mean() def leaky_relu(p = 0.1): return nn.LeakyReLU(0.1) # gans class MultiScaleDiscriminator(nn.Module): def __init__( self, channels = 16, layers = 4, groups = 4, chan_max = 1024, input_channels = 1 ): super().__init__() self.init_conv = nn.Conv1d(input_channels, channels, 7) self.conv_layers = nn.ModuleList([]) curr_channels = channels for _ in range(layers): chan_out = min(curr_channels * 4, chan_max) self.conv_layers.append(nn.Sequential( nn.Conv1d(curr_channels, chan_out, 8, stride = 4, padding = 4), leaky_relu() )) curr_channels = chan_out self.final_conv = nn.Sequential( nn.Conv1d(curr_channels, curr_channels, 3), leaky_relu(), nn.Conv1d(curr_channels, 1, 1), ) def forward(self, x): x = self.init_conv(x) for layer in self.conv_layers: x = layer(x) return self.final_conv(x) # sound stream class Residual(nn.Module): Loading @@ -47,7 +91,9 @@ class CausalConv1d(nn.Module): def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7): return Residual(nn.Sequential( CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation), CausalConv1d(chan_out, chan_out, 1) nn.ELU(), CausalConv1d(chan_out, chan_out, 1), nn.ELU() )) def EncoderBlock(chan_in, chan_out, stride): Loading Loading @@ -80,7 +126,10 @@ class SoundStream(nn.Module): codebook_dim = 512, codebook_size = 1024, rq_num_quantizers = 8, input_channels = 1 input_channels = 1, discr_multi_scales = (1, 0.5, 0.25), recon_loss_weight = 1., adversarial_loss_weight = 1. ): super().__init__() Loading Loading @@ -119,7 +168,21 @@ class SoundStream(nn.Module): CausalConv1d(channels, input_channels, 7) ) def forward(self, x): # discriminators self.discr_multi_scales = discr_multi_scales self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))]) # loss weights self.recon_loss_weight = recon_loss_weight self.adversarial_loss_weight = adversarial_loss_weight def forward( self, x, return_discr_loss = False ): if x.ndim == 2: x = rearrange(x, 'b n -> b 1 n') Loading @@ -133,8 +196,38 @@ class SoundStream(nn.Module): recon_x = self.decoder(x) # multi-scale discriminator loss if return_discr_loss: real, fake = orig_x, recon_x.detach() discr_losses = [] for discr, scale in zip(self.discriminators, self.discr_multi_scales): scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake)) real_logits, fake_logits = map(discr, (scaled_real, scaled_fake)) one_discr_loss = hinge_discr_loss(fake_logits, real_logits) discr_losses.append(one_discr_loss) return torch.stack(discr_losses).mean() # recon loss recon_loss = F.mse_loss(orig_x, recon_x) return recon_loss # generator loss adversarial_losses = [] for discr, scale in zip(self.discriminators, self.discr_multi_scales): scaled_fake = F.interpolate(recon_x, scale_factor = scale) fake_logits = discr(scaled_fake) one_adversarial_loss = hinge_gen_loss(fake_logits) adversarial_losses.append(one_adversarial_loss) adversarial_loss = torch.stack(adversarial_losses).mean() return recon_loss * self.recon_loss_weight + adversarial_loss * self.adversarial_loss_weight # relative positional bias Loading Loading
audiolm_pytorch/audiolm_pytorch.py +97 −4 Original line number Diff line number Diff line Loading @@ -21,6 +21,50 @@ def hinge_discr_loss(fake, real): def hinge_gen_loss(fake): return -fake.mean() def leaky_relu(p = 0.1): return nn.LeakyReLU(0.1) # gans class MultiScaleDiscriminator(nn.Module): def __init__( self, channels = 16, layers = 4, groups = 4, chan_max = 1024, input_channels = 1 ): super().__init__() self.init_conv = nn.Conv1d(input_channels, channels, 7) self.conv_layers = nn.ModuleList([]) curr_channels = channels for _ in range(layers): chan_out = min(curr_channels * 4, chan_max) self.conv_layers.append(nn.Sequential( nn.Conv1d(curr_channels, chan_out, 8, stride = 4, padding = 4), leaky_relu() )) curr_channels = chan_out self.final_conv = nn.Sequential( nn.Conv1d(curr_channels, curr_channels, 3), leaky_relu(), nn.Conv1d(curr_channels, 1, 1), ) def forward(self, x): x = self.init_conv(x) for layer in self.conv_layers: x = layer(x) return self.final_conv(x) # sound stream class Residual(nn.Module): Loading @@ -47,7 +91,9 @@ class CausalConv1d(nn.Module): def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7): return Residual(nn.Sequential( CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation), CausalConv1d(chan_out, chan_out, 1) nn.ELU(), CausalConv1d(chan_out, chan_out, 1), nn.ELU() )) def EncoderBlock(chan_in, chan_out, stride): Loading Loading @@ -80,7 +126,10 @@ class SoundStream(nn.Module): codebook_dim = 512, codebook_size = 1024, rq_num_quantizers = 8, input_channels = 1 input_channels = 1, discr_multi_scales = (1, 0.5, 0.25), recon_loss_weight = 1., adversarial_loss_weight = 1. ): super().__init__() Loading Loading @@ -119,7 +168,21 @@ class SoundStream(nn.Module): CausalConv1d(channels, input_channels, 7) ) def forward(self, x): # discriminators self.discr_multi_scales = discr_multi_scales self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))]) # loss weights self.recon_loss_weight = recon_loss_weight self.adversarial_loss_weight = adversarial_loss_weight def forward( self, x, return_discr_loss = False ): if x.ndim == 2: x = rearrange(x, 'b n -> b 1 n') Loading @@ -133,8 +196,38 @@ class SoundStream(nn.Module): recon_x = self.decoder(x) # multi-scale discriminator loss if return_discr_loss: real, fake = orig_x, recon_x.detach() discr_losses = [] for discr, scale in zip(self.discriminators, self.discr_multi_scales): scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake)) real_logits, fake_logits = map(discr, (scaled_real, scaled_fake)) one_discr_loss = hinge_discr_loss(fake_logits, real_logits) discr_losses.append(one_discr_loss) return torch.stack(discr_losses).mean() # recon loss recon_loss = F.mse_loss(orig_x, recon_x) return recon_loss # generator loss adversarial_losses = [] for discr, scale in zip(self.discriminators, self.discr_multi_scales): scaled_fake = F.interpolate(recon_x, scale_factor = scale) fake_logits = discr(scaled_fake) one_adversarial_loss = hinge_gen_loss(fake_logits) adversarial_losses.append(one_adversarial_loss) adversarial_loss = torch.stack(adversarial_losses).mean() return recon_loss * self.recon_loss_weight + adversarial_loss * self.adversarial_loss_weight # relative positional bias Loading