# audiolm_pytorch/audiolm_pytorch.py -- excerpt recovered from a collapsed diff view (+16 -0)

# hunk @@ -3,6 +3,7 @@

from functools import partial

import torch
from torch import nn, einsum
from torch.autograd import grad as torch_grad
import torch.nn.functional as F

from einops import rearrange

# hunk @@ -25,6 +26,21 @@ (hunk context: `def hinge_gen_loss(fake):` -- its body is not part of this excerpt)

def leaky_relu(p = 0.1):
    """Return an `nn.LeakyReLU` module whose negative slope is `p`."""
    # the original ignored `p` and always used 0.1; the default keeps the old behavior
    return nn.LeakyReLU(p)

def gradient_penalty(images, output, weight = 10):
    """Penalize deviation of the per-sample input-gradient norm from 1.

    Differentiates `output` with respect to `images`, flattens the gradient of
    each batch element into a vector, and returns
    `weight * mean((||grad||_2 - 1) ** 2)` (the same form as the WGAN-GP penalty).

    Args:
        images: tensor the gradient is taken with respect to; the first axis is
            treated as the batch axis. It must be part of the autograd graph
            that produced `output` (i.e. require grad).
        output: tensor computed from `images`.
        weight: scalar multiplier applied to the penalty.

    Returns:
        A scalar tensor. The graph is kept (`create_graph = True`) so the
        penalty itself can be backpropagated through.
    """
    gradients = torch_grad(
        outputs = output,
        inputs = images,
        # ones_like matches the dtype as well as the device of `output`
        grad_outputs = torch.ones_like(output),
        create_graph = True,
        retain_graph = True
    )[0]

    # collapse every non-batch axis so the norm is taken per sample
    gradients = rearrange(gradients, 'b ... -> b (...)')
    return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()

# discriminators

# class MultiScaleDiscriminator(nn.Module):   <- definition continues beyond this excerpt
# audiolm_pytorch/audiolm_pytorch.py -- excerpt recovered from a collapsed diff view (+16 -0)

# hunk @@ -3,6 +3,7 @@

from functools import partial

import torch
from torch import nn, einsum
from torch.autograd import grad as torch_grad
import torch.nn.functional as F

from einops import rearrange

# hunk @@ -25,6 +26,21 @@ (hunk context: `def hinge_gen_loss(fake):` -- its body is not part of this excerpt)

def leaky_relu(p = 0.1):
    """Return an `nn.LeakyReLU` module whose negative slope is `p`."""
    # the original ignored `p` and always used 0.1; the default keeps the old behavior
    return nn.LeakyReLU(p)

def gradient_penalty(images, output, weight = 10):
    """Penalize deviation of the per-sample input-gradient norm from 1.

    Differentiates `output` with respect to `images`, flattens the gradient of
    each batch element into a vector, and returns
    `weight * mean((||grad||_2 - 1) ** 2)` (the same form as the WGAN-GP penalty).

    Args:
        images: tensor the gradient is taken with respect to; the first axis is
            treated as the batch axis. It must be part of the autograd graph
            that produced `output` (i.e. require grad).
        output: tensor computed from `images`.
        weight: scalar multiplier applied to the penalty.

    Returns:
        A scalar tensor. The graph is kept (`create_graph = True`) so the
        penalty itself can be backpropagated through.
    """
    gradients = torch_grad(
        outputs = output,
        inputs = images,
        # ones_like matches the dtype as well as the device of `output`
        grad_outputs = torch.ones_like(output),
        create_graph = True,
        retain_graph = True
    )[0]

    # collapse every non-batch axis so the norm is taken per sample
    gradients = rearrange(gradients, 'b ... -> b (...)')
    return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()

# discriminators

# class MultiScaleDiscriminator(nn.Module):   <- definition continues beyond this excerpt