Loading audiolm_pytorch/audiolm_pytorch.py +32 −6 Original line number Diff line number Diff line import math from functools import partial import torch from torch import nn, einsum Loading Loading @@ -57,13 +58,21 @@ class MultiScaleDiscriminator(nn.Module): nn.Conv1d(curr_channels, 1, 1), ) def forward(self, x): def forward(self, x, return_intermediates = False): x = self.init_conv(x) intermediates = [] for layer in self.conv_layers: x = layer(x) intermediates.append(x) return self.final_conv(x) out = self.final_conv(x) if not return_intermediates: return out return out, intermediates class ComplexLeakyReLU(nn.Module): """ just do nonlinearity on imag and real component separately for now """ Loading Loading @@ -202,7 +211,8 @@ class SoundStream(nn.Module): input_channels = 1, discr_multi_scales = (1, 0.5, 0.25), recon_loss_weight = 1., adversarial_loss_weight = 1. adversarial_loss_weight = 1., feature_loss_weight = 100 ): super().__init__() self.single_channel = input_channels == 1 Loading Loading @@ -253,6 +263,7 @@ class SoundStream(nn.Module): self.recon_loss_weight = recon_loss_weight self.adversarial_loss_weight = adversarial_loss_weight self.feature_loss_weight = feature_loss_weight def forward( self, Loading Loading @@ -305,14 +316,29 @@ class SoundStream(nn.Module): adversarial_losses = [] discr_intermediates = [] # adversarial loss for multi-scale discriminators real, fake = orig_x, recon_x for discr, scale in zip(self.discriminators, self.discr_multi_scales): scaled_fake = F.interpolate(recon_x, scale_factor = scale) fake_logits = discr(scaled_fake) scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake)) (real_logits, real_intermediates), (fake_logits, fake_intermediates) = map(partial(discr, return_intermediates = True), (scaled_real, scaled_fake)) discr_intermediates.append((real_intermediates, fake_intermediates)) one_adversarial_loss = hinge_gen_loss(fake_logits) adversarial_losses.append(one_adversarial_loss) feature_losses = [] for real_intermediates, fake_intermediates in discr_intermediates: losses = [F.mse_loss(real_intermediate, fake_intermediate) for real_intermediate, fake_intermediate in zip(real_intermediates, fake_intermediates)] feature_losses.extend(losses) feature_loss = torch.stack(feature_losses).mean() # adversarial loss for stft discriminator stft_fake_logits = self.stft_discriminator(recon_x) Loading @@ -322,7 +348,7 @@ class SoundStream(nn.Module): adversarial_loss = torch.stack(adversarial_losses).mean() return recon_loss * self.recon_loss_weight + adversarial_loss * self.adversarial_loss_weight return recon_loss * self.recon_loss_weight + adversarial_loss * self.adversarial_loss_weight + feature_loss * self.feature_loss_weight # relative positional bias Loading Loading
audiolm_pytorch/audiolm_pytorch.py +32 −6 Original line number Diff line number Diff line import math from functools import partial import torch from torch import nn, einsum Loading Loading @@ -57,13 +58,21 @@ class MultiScaleDiscriminator(nn.Module): nn.Conv1d(curr_channels, 1, 1), ) def forward(self, x): def forward(self, x, return_intermediates = False): x = self.init_conv(x) intermediates = [] for layer in self.conv_layers: x = layer(x) intermediates.append(x) return self.final_conv(x) out = self.final_conv(x) if not return_intermediates: return out return out, intermediates class ComplexLeakyReLU(nn.Module): """ just do nonlinearity on imag and real component separately for now """ Loading Loading @@ -202,7 +211,8 @@ class SoundStream(nn.Module): input_channels = 1, discr_multi_scales = (1, 0.5, 0.25), recon_loss_weight = 1., adversarial_loss_weight = 1. adversarial_loss_weight = 1., feature_loss_weight = 100 ): super().__init__() self.single_channel = input_channels == 1 Loading Loading @@ -253,6 +263,7 @@ class SoundStream(nn.Module): self.recon_loss_weight = recon_loss_weight self.adversarial_loss_weight = adversarial_loss_weight self.feature_loss_weight = feature_loss_weight def forward( self, Loading Loading @@ -305,14 +316,29 @@ class SoundStream(nn.Module): adversarial_losses = [] discr_intermediates = [] # adversarial loss for multi-scale discriminators real, fake = orig_x, recon_x for discr, scale in zip(self.discriminators, self.discr_multi_scales): scaled_fake = F.interpolate(recon_x, scale_factor = scale) fake_logits = discr(scaled_fake) scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake)) (real_logits, real_intermediates), (fake_logits, fake_intermediates) = map(partial(discr, return_intermediates = True), (scaled_real, scaled_fake)) discr_intermediates.append((real_intermediates, fake_intermediates)) one_adversarial_loss = hinge_gen_loss(fake_logits) adversarial_losses.append(one_adversarial_loss) feature_losses = [] for real_intermediates, fake_intermediates in discr_intermediates: losses = [F.mse_loss(real_intermediate, fake_intermediate) for real_intermediate, fake_intermediate in zip(real_intermediates, fake_intermediates)] feature_losses.extend(losses) feature_loss = torch.stack(feature_losses).mean() # adversarial loss for stft discriminator stft_fake_logits = self.stft_discriminator(recon_x) Loading @@ -322,7 +348,7 @@ class SoundStream(nn.Module): adversarial_loss = torch.stack(adversarial_losses).mean() return recon_loss * self.recon_loss_weight + adversarial_loss * self.adversarial_loss_weight return recon_loss * self.recon_loss_weight + adversarial_loss * self.adversarial_loss_weight + feature_loss * self.feature_loss_weight # relative positional bias Loading