Commit 0863e897 authored by Phil Wang's avatar Phil Wang
Browse files

stop using deprecated torch.norm, as suggested by @zhvng

parent c8b4b748
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@ import torch
from torch import nn, einsum
from torch.autograd import grad as torch_grad
import torch.nn.functional as F
from torch.linalg import vector_norm

import torchaudio.transforms as T

@@ -60,7 +61,7 @@ def gradient_penalty(wave, output, weight = 10):
    )[0]

    gradients = rearrange(gradients, 'b ... -> b (...)')
    return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
    return weight * ((vector_norm(gradients, dim = 1) - 1) ** 2).mean()

# discriminators

@@ -612,7 +613,7 @@ class SoundStream(nn.Module):
                log_orig_mel, log_recon_mel = map(log, (orig_mel, recon_mel))

                l1_mel_loss = (orig_mel - recon_mel).abs().sum(dim = -2).mean()
                l2_log_mel_loss = alpha * (log_orig_mel - log_recon_mel).norm(p = 2, dim = -2).mean()
                l2_log_mel_loss = alpha * vector_norm(log_orig_mel - log_recon_mel, dim = -2).mean()

                multi_spectral_recon_loss = multi_spectral_recon_loss + l1_mel_loss + l2_log_mel_loss