Commit 3ae3a0dd authored by Phil Wang's avatar Phil Wang
Browse files

will use gradient penalty on discriminator, as had much success with that in stylegan

parent 0a714012
Loading
Loading
Loading
Loading
+16 −0
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@ from functools import partial

import torch
from torch import nn, einsum
from torch.autograd import grad as torch_grad
import torch.nn.functional as F

from einops import rearrange
@@ -25,6 +26,21 @@ def hinge_gen_loss(fake):
def leaky_relu(p = 0.1):
    return nn.LeakyReLU(0.1)

def gradient_penalty(images, output, weight = 10):
    batch_size = images.shape[0]

    gradients = torch_grad(
        outputs = output,
        inputs = images,
        grad_outputs = torch.ones(output.size(), device = images.device),
        create_graph = True,
        retain_graph = True,
        only_inputs = True
    )[0]

    gradients = rearrange(gradients, 'b ... -> b (...)')
    return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()

# discriminators

class MultiScaleDiscriminator(nn.Module):