# audiolm_pytorch/audiolm_pytorch.py -- hunks of a rendered diff (+83 -3), re-flowed into Python.
#
# NOTE(review): only the definitions that the diff shows in full are live code below.
# Classes that the diff shows only in part (MultiScaleDiscriminator, Residual, SoundStream)
# are preserved as commented excerpts, token for token, because their missing lines cannot
# be recovered from this view. Names used but not defined in these hunks (torch, nn, F,
# rearrange, hinge_discr_loss, hinge_gen_loss, ...) are presumably provided by the part of
# the file that the diff does not show -- confirm against the real file.

# @@ -24,7 +24,7 @@ def hinge_gen_loss(fake):

def leaky_relu(p = 0.1):
    """Return an ``nn.LeakyReLU`` whose negative slope is ``p``.

    The slope used to be hard-coded to 0.1 regardless of ``p``; the default
    call ``leaky_relu()`` still yields a slope of 0.1.
    """
    return nn.LeakyReLU(p)

# gans
# discriminators

# Excerpt (partial class; the diff elides the lines marked below):
#
#   class MultiScaleDiscriminator(nn.Module):
#       def __init__(
#   [... lines not shown in the diff ...]
#   @@ -65,6 +65,64 @@ class MultiScaleDiscriminator(nn.Module):
#           return self.final_conv(x)

class ComplexLeakyReLU(nn.Module):
    """LeakyReLU for complex tensors.

    For now the nonlinearity is simply applied to the real and the imaginary
    component separately, and the two results are recombined.

    Args:
        p: negative slope handed to ``leaky_relu``.
    """

    def __init__(self, p = 0.1):
        super().__init__()
        self.nonlin = leaky_relu(p)

    def forward(self, x):
        """Apply the activation component-wise to complex tensor ``x``."""
        real, imag = map(self.nonlin, (x.real, x.imag))
        # view_as_complex reads the trailing dimension as (real, imag), so the
        # stacking order must be real first -- stacking (imag, real) swaps them.
        return torch.view_as_complex(torch.stack((real, imag), dim = -1))

def STFTResidualUnit(chan_in, chan_out, strides):
    """Build one complex-valued conv stage of the STFT discriminator.

    The stage is: 3x3 conv (chan_in -> chan_in), ComplexLeakyReLU, then a
    strided conv (chan_in -> chan_out) whose kernel is ``stride + 2`` and whose
    padding is ``kernel // 2`` along each axis.

    NOTE(review): despite the name, no skip connection is added here; the
    returned module is a plain ``nn.Sequential``.

    Args:
        chan_in: number of input channels.
        chan_out: number of output channels.
        strides: pair of strides for the second convolution.
    """
    kernel_sizes = tuple(stride + 2 for stride in strides)
    paddings = tuple(kernel_size // 2 for kernel_size in kernel_sizes)

    return nn.Sequential(
        nn.Conv2d(chan_in, chan_in, 3, padding = 1, dtype = torch.complex64),
        ComplexLeakyReLU(),
        nn.Conv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings, dtype = torch.complex64)
    )

class STFTDiscriminator(nn.Module):
    """Discriminator that runs complex convolutions over the STFT of a waveform.

    Args (keyword-only):
        channels: base channel count of the first convolution.
        strides: one stride pair per stage.
        chan_mults: per-stage multipliers of ``channels`` for the stage output width.
        input_channels: input channels of the first convolution.
            NOTE(review): ``forward`` always feeds exactly one channel
            (it inserts a singleton channel axis), so values other than 1
            look unsupported at the moment -- confirm.

    ``strides`` and ``chan_mults`` are zipped together, so the shorter of the
    two decides how many stages are built.
    """

    def __init__(
        self,
        *,
        channels = 32,
        strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)),
        chan_mults = (1, 2, 4, 4, 8, 8),
        input_channels = 1
    ):
        super().__init__()
        self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3, dtype = torch.complex64)

        # (in, out) channel pairs for consecutive stages, starting from `channels`
        layer_channels = tuple(map(lambda mult: mult * channels, chan_mults))
        layer_channels = (channels, *layer_channels)
        layer_channels_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:]))

        self.layers = nn.ModuleList([])

        for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs):
            self.layers.append(STFTResidualUnit(chan_in, chan_out, layer_stride))

        self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1), dtype = torch.complex64) # todo: remove hardcoded 16

    def forward(self, x):
        """Return complex logits for waveform ``x`` laid out as (batch, 1, samples)."""
        x = rearrange(x, 'b 1 n -> b n')

        # Ask stft for a complex tensor directly: recent PyTorch refuses real inputs
        # without an explicit return_complex, and this equals view_as_complex of the
        # old real-valued (..., 2) output.
        x = torch.stft(x, 256, return_complex = True)

        # add the channel axis expected by Conv2d
        x = rearrange(x, 'b ... -> b 1 ...')

        x = self.init_conv(x)

        for layer in self.layers:
            x = layer(x)

        complex_logits = self.final_conv(x)
        return complex_logits

# sound stream

# Excerpts (partial classes; old and new lines of a changed hunk appear back to back,
# exactly as the diff rendered them):
#
#   class Residual(nn.Module):
#   [... lines not shown in the diff ...]
#   @@ -148,6 +206,7 @@ class SoundStream(nn.Module):
#           adversarial_loss_weight = 1.
#       ):
#           super().__init__()
#           self.single_channel = input_channels == 1
#
#           layer_channels = tuple(map(lambda t: t * channels, channel_mults))
#           layer_channels = (channels, *layer_channels)
#   [... lines not shown in the diff ...]
#   @@ -189,6 +248,8 @@ class SoundStream(nn.Module):
#           self.discr_multi_scales = discr_multi_scales
#           self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))])
#
#           self.stft_discriminator = STFTDiscriminator()
#
#           # loss weights
#
#           self.recon_loss_weight = recon_loss_weight
#   [... lines not shown in the diff ...]
#   @@ -197,7 +258,8 @@ class SoundStream(nn.Module):
#       def forward(
#           self,
#           x,
#           return_discr_loss = False
#           return_discr_loss = False,
#           return_stft_discr_loss = False
#       ):
#           if x.ndim == 2:
#               x = rearrange(x, 'b n -> b 1 n')
#   [... lines not shown in the diff ...]
#   @@ -212,6 +274,15 @@ class SoundStream(nn.Module):
#           recon_x = self.decoder(x)
#
#           # stft discr loss
#
#           if return_stft_discr_loss:
#               assert self.single_channel
#               real, fake = orig_x, recon_x.detach()
#               stft_real_logits, stft_fake_logits = map(self.stft_discriminator, (real, fake))
#               stft_discr_loss = (hinge_discr_loss(stft_fake_logits.real, stft_real_logits.real) + hinge_discr_loss(stft_fake_logits.imag, stft_real_logits.imag)) / 2
#               return stft_discr_loss
#
#           # multi-scale discriminator loss
#
#           if return_discr_loss:
#   [... lines not shown in the diff ...]
#   @@ -231,16 +302,25 @@ class SoundStream(nn.Module):
#           recon_loss = F.mse_loss(orig_x, recon_x)
#
#           # generator loss
#           # adversarial loss
#
#           adversarial_losses = []
#
#           # adversarial loss for multi-scale discriminators
#
#           for discr, scale in zip(self.discriminators, self.discr_multi_scales):
#               scaled_fake = F.interpolate(recon_x, scale_factor = scale)
#               fake_logits = discr(scaled_fake)
#               one_adversarial_loss = hinge_gen_loss(fake_logits)
#               adversarial_losses.append(one_adversarial_loss)
#
#           # adversarial loss for stft discriminator
#
#           stft_fake_logits = self.stft_discriminator(recon_x)
#           adversarial_losses.append(hinge_gen_loss(stft_fake_logits.real))
#           adversarial_losses.append(hinge_gen_loss(stft_fake_logits.imag))
#
#           adversarial_loss = torch.stack(adversarial_losses).mean()
#
#           return recon_loss * self.recon_loss_weight + adversarial_loss * self.adversarial_loss_weight
#   [... lines not shown in the diff ...]
#
# NOTE(review): in the excerpt above, the generator path calls self.stft_discriminator(recon_x)
# without the `assert self.single_channel` guard used in the return_stft_discr_loss branch --
# confirm whether multi-channel input is meant to reach it.
# audiolm_pytorch/audiolm_pytorch.py -- hunks of a rendered diff (+83 -3), re-flowed into Python.
#
# NOTE(review): only the definitions that the diff shows in full are live code below.
# Classes that the diff shows only in part (MultiScaleDiscriminator, Residual, SoundStream)
# are preserved as commented excerpts, token for token, because their missing lines cannot
# be recovered from this view. Names used but not defined in these hunks (torch, nn, F,
# rearrange, hinge_discr_loss, hinge_gen_loss, ...) are presumably provided by the part of
# the file that the diff does not show -- confirm against the real file.

# @@ -24,7 +24,7 @@ def hinge_gen_loss(fake):

def leaky_relu(p = 0.1):
    """Return an ``nn.LeakyReLU`` whose negative slope is ``p``.

    The slope used to be hard-coded to 0.1 regardless of ``p``; the default
    call ``leaky_relu()`` still yields a slope of 0.1.
    """
    return nn.LeakyReLU(p)

# gans
# discriminators

# Excerpt (partial class; the diff elides the lines marked below):
#
#   class MultiScaleDiscriminator(nn.Module):
#       def __init__(
#   [... lines not shown in the diff ...]
#   @@ -65,6 +65,64 @@ class MultiScaleDiscriminator(nn.Module):
#           return self.final_conv(x)

class ComplexLeakyReLU(nn.Module):
    """LeakyReLU for complex tensors.

    For now the nonlinearity is simply applied to the real and the imaginary
    component separately, and the two results are recombined.

    Args:
        p: negative slope handed to ``leaky_relu``.
    """

    def __init__(self, p = 0.1):
        super().__init__()
        self.nonlin = leaky_relu(p)

    def forward(self, x):
        """Apply the activation component-wise to complex tensor ``x``."""
        real, imag = map(self.nonlin, (x.real, x.imag))
        # view_as_complex reads the trailing dimension as (real, imag), so the
        # stacking order must be real first -- stacking (imag, real) swaps them.
        return torch.view_as_complex(torch.stack((real, imag), dim = -1))

def STFTResidualUnit(chan_in, chan_out, strides):
    """Build one complex-valued conv stage of the STFT discriminator.

    The stage is: 3x3 conv (chan_in -> chan_in), ComplexLeakyReLU, then a
    strided conv (chan_in -> chan_out) whose kernel is ``stride + 2`` and whose
    padding is ``kernel // 2`` along each axis.

    NOTE(review): despite the name, no skip connection is added here; the
    returned module is a plain ``nn.Sequential``.

    Args:
        chan_in: number of input channels.
        chan_out: number of output channels.
        strides: pair of strides for the second convolution.
    """
    kernel_sizes = tuple(stride + 2 for stride in strides)
    paddings = tuple(kernel_size // 2 for kernel_size in kernel_sizes)

    return nn.Sequential(
        nn.Conv2d(chan_in, chan_in, 3, padding = 1, dtype = torch.complex64),
        ComplexLeakyReLU(),
        nn.Conv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings, dtype = torch.complex64)
    )

class STFTDiscriminator(nn.Module):
    """Discriminator that runs complex convolutions over the STFT of a waveform.

    Args (keyword-only):
        channels: base channel count of the first convolution.
        strides: one stride pair per stage.
        chan_mults: per-stage multipliers of ``channels`` for the stage output width.
        input_channels: input channels of the first convolution.
            NOTE(review): ``forward`` always feeds exactly one channel
            (it inserts a singleton channel axis), so values other than 1
            look unsupported at the moment -- confirm.

    ``strides`` and ``chan_mults`` are zipped together, so the shorter of the
    two decides how many stages are built.
    """

    def __init__(
        self,
        *,
        channels = 32,
        strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)),
        chan_mults = (1, 2, 4, 4, 8, 8),
        input_channels = 1
    ):
        super().__init__()
        self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3, dtype = torch.complex64)

        # (in, out) channel pairs for consecutive stages, starting from `channels`
        layer_channels = tuple(map(lambda mult: mult * channels, chan_mults))
        layer_channels = (channels, *layer_channels)
        layer_channels_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:]))

        self.layers = nn.ModuleList([])

        for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs):
            self.layers.append(STFTResidualUnit(chan_in, chan_out, layer_stride))

        self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1), dtype = torch.complex64) # todo: remove hardcoded 16

    def forward(self, x):
        """Return complex logits for waveform ``x`` laid out as (batch, 1, samples)."""
        x = rearrange(x, 'b 1 n -> b n')

        # Ask stft for a complex tensor directly: recent PyTorch refuses real inputs
        # without an explicit return_complex, and this equals view_as_complex of the
        # old real-valued (..., 2) output.
        x = torch.stft(x, 256, return_complex = True)

        # add the channel axis expected by Conv2d
        x = rearrange(x, 'b ... -> b 1 ...')

        x = self.init_conv(x)

        for layer in self.layers:
            x = layer(x)

        complex_logits = self.final_conv(x)
        return complex_logits

# sound stream

# Excerpts (partial classes; old and new lines of a changed hunk appear back to back,
# exactly as the diff rendered them):
#
#   class Residual(nn.Module):
#   [... lines not shown in the diff ...]
#   @@ -148,6 +206,7 @@ class SoundStream(nn.Module):
#           adversarial_loss_weight = 1.
#       ):
#           super().__init__()
#           self.single_channel = input_channels == 1
#
#           layer_channels = tuple(map(lambda t: t * channels, channel_mults))
#           layer_channels = (channels, *layer_channels)
#   [... lines not shown in the diff ...]
#   @@ -189,6 +248,8 @@ class SoundStream(nn.Module):
#           self.discr_multi_scales = discr_multi_scales
#           self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))])
#
#           self.stft_discriminator = STFTDiscriminator()
#
#           # loss weights
#
#           self.recon_loss_weight = recon_loss_weight
#   [... lines not shown in the diff ...]
#   @@ -197,7 +258,8 @@ class SoundStream(nn.Module):
#       def forward(
#           self,
#           x,
#           return_discr_loss = False
#           return_discr_loss = False,
#           return_stft_discr_loss = False
#       ):
#           if x.ndim == 2:
#               x = rearrange(x, 'b n -> b 1 n')
#   [... lines not shown in the diff ...]
#   @@ -212,6 +274,15 @@ class SoundStream(nn.Module):
#           recon_x = self.decoder(x)
#
#           # stft discr loss
#
#           if return_stft_discr_loss:
#               assert self.single_channel
#               real, fake = orig_x, recon_x.detach()
#               stft_real_logits, stft_fake_logits = map(self.stft_discriminator, (real, fake))
#               stft_discr_loss = (hinge_discr_loss(stft_fake_logits.real, stft_real_logits.real) + hinge_discr_loss(stft_fake_logits.imag, stft_real_logits.imag)) / 2
#               return stft_discr_loss
#
#           # multi-scale discriminator loss
#
#           if return_discr_loss:
#   [... lines not shown in the diff ...]
#   @@ -231,16 +302,25 @@ class SoundStream(nn.Module):
#           recon_loss = F.mse_loss(orig_x, recon_x)
#
#           # generator loss
#           # adversarial loss
#
#           adversarial_losses = []
#
#           # adversarial loss for multi-scale discriminators
#
#           for discr, scale in zip(self.discriminators, self.discr_multi_scales):
#               scaled_fake = F.interpolate(recon_x, scale_factor = scale)
#               fake_logits = discr(scaled_fake)
#               one_adversarial_loss = hinge_gen_loss(fake_logits)
#               adversarial_losses.append(one_adversarial_loss)
#
#           # adversarial loss for stft discriminator
#
#           stft_fake_logits = self.stft_discriminator(recon_x)
#           adversarial_losses.append(hinge_gen_loss(stft_fake_logits.real))
#           adversarial_losses.append(hinge_gen_loss(stft_fake_logits.imag))
#
#           adversarial_loss = torch.stack(adversarial_losses).mean()
#
#           return recon_loss * self.recon_loss_weight + adversarial_loss * self.adversarial_loss_weight
#   [... lines not shown in the diff ...]
#
# NOTE(review): in the excerpt above, the generator path calls self.stft_discriminator(recon_x)
# without the `assert self.single_channel` guard used in the return_stft_discr_loss branch --
# confirm whether multi-channel input is meant to reach it.