Loading audiolm_pytorch/soundstream.py +40 −1 Original line number Diff line number Diff line Loading @@ -30,6 +30,9 @@ def default(val, d): # gan losses def log(t, eps = 1e-20): return torch.log(t.clamp(min = eps)) def hinge_discr_loss(fake, real): return (F.relu(1 + fake) + F.relu(1 - real)).mean() Loading Loading @@ -338,7 +341,10 @@ class SoundStream(nn.Module): discr_multi_scales = (1, 0.5, 0.25), enc_cycle_dilations = (1, 3, 9), dec_cycle_dilations = (1, 3, 9), multi_spectral_window_powers_of_two = tuple(range(6, 12)), num_mel_bins = 64, recon_loss_weight = 1., multi_spectral_recon_loss_weight = 1., adversarial_loss_weight = 1., feature_loss_weight = 100, quantize_dropout_cutoff_index = 1, Loading Loading @@ -426,9 +432,32 @@ class SoundStream(nn.Module): self.stft_discriminator = STFTDiscriminator() # multi spectral reconstruction self.mel_spec_transforms = nn.ModuleList([]) self.mel_spec_recon_alphas = [] max_win_length = 2 ** max(multi_spectral_window_powers_of_two) for powers in multi_spectral_window_powers_of_two: win_length = 2 ** powers alpha = (win_length / 2) ** 0.5 melspec_transform = T.MelSpectrogram( sample_rate = target_sample_hz, n_fft = max_win_length, win_length = win_length, hop_length = win_length // 4, n_mels = num_mel_bins ) self.mel_spec_transforms.append(melspec_transform) self.mel_spec_recon_alphas.append(alpha) # loss weights self.recon_loss_weight = recon_loss_weight self.multi_spectral_recon_loss_weight = multi_spectral_recon_loss_weight self.adversarial_loss_weight = adversarial_loss_weight self.feature_loss_weight = feature_loss_weight Loading Loading @@ -558,6 +587,16 @@ class SoundStream(nn.Module): recon_loss = F.mse_loss(orig_x, recon_x) # multispectral recon loss - eq (4) and (5) in https://arxiv.org/abs/2107.03312 multi_spectral_recon_loss = 0 for mel_transform, alpha in zip(self.mel_spec_transforms, self.mel_spec_recon_alphas): orig_mel, recon_mel = map(mel_transform, (orig_x, recon_x)) log_orig_mel, log_recon_mel = map(log, (orig_mel, recon_mel)) multi_spectral_recon_loss = multi_spectral_recon_loss + (orig_mel - recon_mel).abs().sum() + alpha * ((log_orig_mel - log_recon_mel) ** 2).sum() # adversarial loss adversarial_losses = [] Loading Loading @@ -599,7 +638,7 @@ class SoundStream(nn.Module): all_commitment_loss = commit_loss.sum() total_loss = recon_loss * self.recon_loss_weight + adversarial_loss * self.adversarial_loss_weight + feature_loss * self.feature_loss_weight + all_commitment_loss total_loss = recon_loss * self.recon_loss_weight + multi_spectral_recon_loss * self.multi_spectral_recon_loss_weight + adversarial_loss * self.adversarial_loss_weight + feature_loss * self.feature_loss_weight + all_commitment_loss if return_loss_breakdown: return total_loss, (recon_loss, adversarial_loss, feature_loss, all_commitment_loss) Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.10.4', version = '0.11.0', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
audiolm_pytorch/soundstream.py +40 −1 Original line number Diff line number Diff line Loading @@ -30,6 +30,9 @@ def default(val, d): # gan losses def log(t, eps = 1e-20): return torch.log(t.clamp(min = eps)) def hinge_discr_loss(fake, real): return (F.relu(1 + fake) + F.relu(1 - real)).mean() Loading Loading @@ -338,7 +341,10 @@ class SoundStream(nn.Module): discr_multi_scales = (1, 0.5, 0.25), enc_cycle_dilations = (1, 3, 9), dec_cycle_dilations = (1, 3, 9), multi_spectral_window_powers_of_two = tuple(range(6, 12)), num_mel_bins = 64, recon_loss_weight = 1., multi_spectral_recon_loss_weight = 1., adversarial_loss_weight = 1., feature_loss_weight = 100, quantize_dropout_cutoff_index = 1, Loading Loading @@ -426,9 +432,32 @@ class SoundStream(nn.Module): self.stft_discriminator = STFTDiscriminator() # multi spectral reconstruction self.mel_spec_transforms = nn.ModuleList([]) self.mel_spec_recon_alphas = [] max_win_length = 2 ** max(multi_spectral_window_powers_of_two) for powers in multi_spectral_window_powers_of_two: win_length = 2 ** powers alpha = (win_length / 2) ** 0.5 melspec_transform = T.MelSpectrogram( sample_rate = target_sample_hz, n_fft = max_win_length, win_length = win_length, hop_length = win_length // 4, n_mels = num_mel_bins ) self.mel_spec_transforms.append(melspec_transform) self.mel_spec_recon_alphas.append(alpha) # loss weights self.recon_loss_weight = recon_loss_weight self.multi_spectral_recon_loss_weight = multi_spectral_recon_loss_weight self.adversarial_loss_weight = adversarial_loss_weight self.feature_loss_weight = feature_loss_weight Loading Loading @@ -558,6 +587,16 @@ class SoundStream(nn.Module): recon_loss = F.mse_loss(orig_x, recon_x) # multispectral recon loss - eq (4) and (5) in https://arxiv.org/abs/2107.03312 multi_spectral_recon_loss = 0 for mel_transform, alpha in zip(self.mel_spec_transforms, self.mel_spec_recon_alphas): orig_mel, recon_mel = map(mel_transform, (orig_x, recon_x)) log_orig_mel, log_recon_mel = map(log, (orig_mel, recon_mel)) multi_spectral_recon_loss = multi_spectral_recon_loss + (orig_mel - recon_mel).abs().sum() + alpha * ((log_orig_mel - log_recon_mel) ** 2).sum() # adversarial loss adversarial_losses = [] Loading Loading @@ -599,7 +638,7 @@ class SoundStream(nn.Module): all_commitment_loss = commit_loss.sum() total_loss = recon_loss * self.recon_loss_weight + adversarial_loss * self.adversarial_loss_weight + feature_loss * self.feature_loss_weight + all_commitment_loss total_loss = recon_loss * self.recon_loss_weight + multi_spectral_recon_loss * self.multi_spectral_recon_loss_weight + adversarial_loss * self.adversarial_loss_weight + feature_loss * self.feature_loss_weight + all_commitment_loss if return_loss_breakdown: return total_loss, (recon_loss, adversarial_loss, feature_loss, all_commitment_loss) Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.10.4', version = '0.11.0', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading