Loading audiolm_pytorch/soundstream.py +9 −4 Original line number Diff line number Diff line Loading @@ -444,6 +444,8 @@ class SoundStream(nn.Module): self.discr_multi_scales = discr_multi_scales self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))]) discr_rel_factors = [int(s1 / s2) for s1, s2 in zip(discr_multi_scales[:-1], discr_multi_scales[1:])] self.downsamples = nn.ModuleList([nn.Identity()] + [nn.AvgPool1d(2 * factor, stride = factor, padding = factor) for factor in discr_rel_factors]) self.stft_discriminator = ComplexSTFTDiscriminator( stft_normalized = stft_normalized Loading Loading @@ -578,8 +580,9 @@ class SoundStream(nn.Module): if apply_grad_penalty: stft_grad_penalty = gradient_penalty(real, stft_discr_loss) for discr, scale in zip(self.discriminators, self.discr_multi_scales): scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake)) scaled_real, scaled_fake = real, fake for discr, downsample in zip(self.discriminators, self.downsamples): scaled_real, scaled_fake = map(downsample, (scaled_real, scaled_fake)) real_logits, fake_logits = map(discr, (scaled_real.requires_grad_(), scaled_fake)) one_discr_loss = hinge_discr_loss(fake_logits, real_logits) Loading Loading @@ -647,8 +650,10 @@ class SoundStream(nn.Module): (stft_real_logits, stft_real_intermediates), (stft_fake_logits, stft_fake_intermediates) = map(partial(self.stft_discriminator, return_intermediates=True), (real, fake)) discr_intermediates.append((stft_real_intermediates, stft_fake_intermediates)) for discr, scale in zip(self.discriminators, self.discr_multi_scales): scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake)) scaled_real, scaled_fake = real, fake for discr, downsample in zip(self.discriminators, self.downsamples): scaled_real, scaled_fake = map(downsample, (scaled_real, scaled_fake)) (real_logits, real_intermediates), (fake_logits, fake_intermediates) = map(partial(discr, return_intermediates = True), (scaled_real, scaled_fake)) discr_intermediates.append((real_intermediates, fake_intermediates)) Loading Loading
audiolm_pytorch/soundstream.py +9 −4 Original line number Diff line number Diff line Loading @@ -444,6 +444,8 @@ class SoundStream(nn.Module): self.discr_multi_scales = discr_multi_scales self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))]) discr_rel_factors = [int(s1 / s2) for s1, s2 in zip(discr_multi_scales[:-1], discr_multi_scales[1:])] self.downsamples = nn.ModuleList([nn.Identity()] + [nn.AvgPool1d(2 * factor, stride = factor, padding = factor) for factor in discr_rel_factors]) self.stft_discriminator = ComplexSTFTDiscriminator( stft_normalized = stft_normalized Loading Loading @@ -578,8 +580,9 @@ class SoundStream(nn.Module): if apply_grad_penalty: stft_grad_penalty = gradient_penalty(real, stft_discr_loss) for discr, scale in zip(self.discriminators, self.discr_multi_scales): scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake)) scaled_real, scaled_fake = real, fake for discr, downsample in zip(self.discriminators, self.downsamples): scaled_real, scaled_fake = map(downsample, (scaled_real, scaled_fake)) real_logits, fake_logits = map(discr, (scaled_real.requires_grad_(), scaled_fake)) one_discr_loss = hinge_discr_loss(fake_logits, real_logits) Loading Loading @@ -647,8 +650,10 @@ class SoundStream(nn.Module): (stft_real_logits, stft_real_intermediates), (stft_fake_logits, stft_fake_intermediates) = map(partial(self.stft_discriminator, return_intermediates=True), (real, fake)) discr_intermediates.append((stft_real_intermediates, stft_fake_intermediates)) for discr, scale in zip(self.discriminators, self.discr_multi_scales): scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake)) scaled_real, scaled_fake = real, fake for discr, downsample in zip(self.discriminators, self.downsamples): scaled_real, scaled_fake = map(downsample, (scaled_real, scaled_fake)) (real_logits, real_intermediates), (fake_logits, fake_intermediates) = map(partial(discr, return_intermediates = True), (scaled_real, scaled_fake)) discr_intermediates.append((real_intermediates, fake_intermediates)) Loading