audiolm_pytorch/soundstream.py (+3 −2)

@@ -9,6 +9,7 @@ import torch
  from torch import nn, einsum
  from torch.autograd import grad as torch_grad
  import torch.nn.functional as F
+ from torch.linalg import vector_norm
  import torchaudio.transforms as T

@@ -60,7 +61,7 @@ def gradient_penalty(wave, output, weight = 10):
  )[0]
  gradients = rearrange(gradients, 'b ... -> b (...)')
- return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
+ return weight * ((vector_norm(gradients, dim = 1) - 1) ** 2).mean()
  # discriminators

@@ -612,7 +613,7 @@ class SoundStream(nn.Module):
  log_orig_mel, log_recon_mel = map(log, (orig_mel, recon_mel))
  l1_mel_loss = (orig_mel - recon_mel).abs().sum(dim = -2).mean()
- l2_log_mel_loss = alpha * (log_orig_mel - log_recon_mel).norm(p = 2, dim = -2).mean()
+ l2_log_mel_loss = alpha * vector_norm(log_orig_mel - log_recon_mel, dim = -2).mean()
  multi_spectral_recon_loss = multi_spectral_recon_loss + l1_mel_loss + l2_log_mel_loss
audiolm_pytorch/soundstream.py (+3 −2)

@@ -9,6 +9,7 @@ import torch
  from torch import nn, einsum
  from torch.autograd import grad as torch_grad
  import torch.nn.functional as F
+ from torch.linalg import vector_norm
  import torchaudio.transforms as T

@@ -60,7 +61,7 @@ def gradient_penalty(wave, output, weight = 10):
  )[0]
  gradients = rearrange(gradients, 'b ... -> b (...)')
- return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
+ return weight * ((vector_norm(gradients, dim = 1) - 1) ** 2).mean()
  # discriminators

@@ -612,7 +613,7 @@ class SoundStream(nn.Module):
  log_orig_mel, log_recon_mel = map(log, (orig_mel, recon_mel))
  l1_mel_loss = (orig_mel - recon_mel).abs().sum(dim = -2).mean()
- l2_log_mel_loss = alpha * (log_orig_mel - log_recon_mel).norm(p = 2, dim = -2).mean()
+ l2_log_mel_loss = alpha * vector_norm(log_orig_mel - log_recon_mel, dim = -2).mean()
  multi_spectral_recon_loss = multi_spectral_recon_loss + l1_mel_loss + l2_log_mel_loss