Loading audiolm_pytorch/audiolm_pytorch.py +16 −4 Original line number Diff line number Diff line Loading @@ -119,19 +119,27 @@ class STFTDiscriminator(nn.Module): self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1), dtype = torch.complex64) # todo: remove hardcoded 16 def forward(self, x): def forward(self, x, return_intermediates = False): x = rearrange(x, 'b 1 n -> b n') x = torch.view_as_complex(torch.stft(x, 256)) x = rearrange(x, 'b ... -> b 1 ...') intermediates = [] x = self.init_conv(x) intermediates.append(x) for layer in self.layers: x = layer(x) intermediates.append(x) complex_logits = self.final_conv(x) if not return_intermediates: return complex_logits return complex_logits, intermediates # sound stream class Residual(nn.Module): Loading Loading @@ -318,10 +326,15 @@ class SoundStream(nn.Module): discr_intermediates = [] # adversarial loss for multi-scale discriminators real, fake = orig_x, recon_x # features from stft (stft_real_logits, stft_real_intermediates), (stft_fake_logits, stft_fake_intermediates) = map(partial(self.stft_discriminator, return_intermediates=True), (real, fake)) discr_intermediates.append((stft_real_intermediates, stft_fake_intermediates)) for discr, scale in zip(self.discriminators, self.discr_multi_scales): scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake)) (real_logits, real_intermediates), (fake_logits, fake_intermediates) = map(partial(discr, return_intermediates = True), (scaled_real, scaled_fake)) Loading @@ -337,12 +350,11 @@ class SoundStream(nn.Module): losses = [F.mse_loss(real_intermediate, fake_intermediate) for real_intermediate, fake_intermediate in zip(real_intermediates, fake_intermediates)] feature_losses.extend(losses) feature_loss = torch.stack(feature_losses).mean() # adversarial loss for stft discriminator stft_fake_logits = self.stft_discriminator(recon_x) adversarial_losses.append(hinge_gen_loss(stft_fake_logits.real)) adversarial_losses.append(hinge_gen_loss(stft_fake_logits.imag)) Loading Loading
audiolm_pytorch/audiolm_pytorch.py +16 −4 Original line number Diff line number Diff line Loading @@ -119,19 +119,27 @@ class STFTDiscriminator(nn.Module): self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1), dtype = torch.complex64) # todo: remove hardcoded 16 def forward(self, x): def forward(self, x, return_intermediates = False): x = rearrange(x, 'b 1 n -> b n') x = torch.view_as_complex(torch.stft(x, 256)) x = rearrange(x, 'b ... -> b 1 ...') intermediates = [] x = self.init_conv(x) intermediates.append(x) for layer in self.layers: x = layer(x) intermediates.append(x) complex_logits = self.final_conv(x) if not return_intermediates: return complex_logits return complex_logits, intermediates # sound stream class Residual(nn.Module): Loading Loading @@ -318,10 +326,15 @@ class SoundStream(nn.Module): discr_intermediates = [] # adversarial loss for multi-scale discriminators real, fake = orig_x, recon_x # features from stft (stft_real_logits, stft_real_intermediates), (stft_fake_logits, stft_fake_intermediates) = map(partial(self.stft_discriminator, return_intermediates=True), (real, fake)) discr_intermediates.append((stft_real_intermediates, stft_fake_intermediates)) for discr, scale in zip(self.discriminators, self.discr_multi_scales): scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake)) (real_logits, real_intermediates), (fake_logits, fake_intermediates) = map(partial(discr, return_intermediates = True), (scaled_real, scaled_fake)) Loading @@ -337,12 +350,11 @@ class SoundStream(nn.Module): losses = [F.mse_loss(real_intermediate, fake_intermediate) for real_intermediate, fake_intermediate in zip(real_intermediates, fake_intermediates)] feature_losses.extend(losses) feature_loss = torch.stack(feature_losses).mean() # adversarial loss for stft discriminator stft_fake_logits = self.stft_discriminator(recon_x) adversarial_losses.append(hinge_gen_loss(stft_fake_logits.real)) adversarial_losses.append(hinge_gen_loss(stft_fake_logits.imag)) Loading