Commit 5739c8f0 authored by Phil Wang's avatar Phil Wang
Browse files

wire up gradient penalty for all discriminators

parent 37d9efab
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -93,9 +93,9 @@ def curtail_to_shortest_collate(data):
    return torch.stack(data)

@collate_one_or_multiple_tensors
def pad_to_longest(data):
def pad_to_longest_fn(data):
    return pad_sequence(data, batch_first = True)

def get_dataloader(ds, pad_to_longest = True, **kwargs):
    collate_fn = pad_to_longest if pad_to_longest else curtail_to_shortest_collate
    collate_fn = pad_to_longest_fn if pad_to_longest else curtail_to_shortest_collate
    return DataLoader(ds, collate_fn = collate_fn, **kwargs)
+25 −8
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ from functools import partial

import torch
from torch import nn, einsum
from torch.autograd import grad as torch_grad
import torch.nn.functional as F
from einops import rearrange, reduce

@@ -30,13 +31,13 @@ def hinge_gen_loss(fake):
def leaky_relu(p = 0.1):
    return nn.LeakyReLU(p)

def gradient_penalty(images, output, weight = 10):
    batch_size = images.shape[0]
def gradient_penalty(wave, output, weight = 10):
    batch_size, device = wave.shape[0], wave.device

    gradients = torch_grad(
        outputs = output,
        inputs = images,
        grad_outputs = torch.ones(output.size(), device = images.device),
        inputs = wave,
        grad_outputs = torch.ones_like(output),
        create_graph = True,
        retain_graph = True,
        only_inputs = True
@@ -334,7 +335,8 @@ class SoundStream(nn.Module):
        return_discr_loss = False,
        return_discr_losses_separately = False,
        return_recons_only = False,
        input_sample_hz = None
        input_sample_hz = None,
        apply_grad_penalty = False
    ):
        if exists(input_sample_hz):
            x = resample(x, input_sample_hz, self.target_sample_hz)
@@ -366,19 +368,26 @@ class SoundStream(nn.Module):
            real, fake = orig_x, recon_x.detach()

            stft_discr_loss = None
            stft_grad_penalty = None
            discr_losses = []
            discr_grad_penalties = []

            if self.single_channel:
                real, fake = orig_x, recon_x.detach()
                stft_real_logits, stft_fake_logits = map(self.stft_discriminator, (real, fake))
                real, fake = orig_x.clone(), recon_x.detach()
                stft_real_logits, stft_fake_logits = map(self.stft_discriminator, (real.requires_grad_(), fake))
                stft_discr_loss = (hinge_discr_loss(stft_fake_logits.real, stft_real_logits.real) + hinge_discr_loss(stft_fake_logits.imag, stft_real_logits.imag)) / 2

                if apply_grad_penalty:
                    stft_grad_penalty = gradient_penalty(real, stft_discr_loss)

            for discr, scale in zip(self.discriminators, self.discr_multi_scales):
                scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake))

                real_logits, fake_logits = map(discr, (scaled_real, scaled_fake))
                real_logits, fake_logits = map(discr, (scaled_real.requires_grad_(), scaled_fake))
                one_discr_loss = hinge_discr_loss(fake_logits, real_logits)

                discr_losses.append(one_discr_loss)
                discr_grad_penalties.append(gradient_penalty(scaled_real, one_discr_loss))

            if not return_discr_losses_separately:
                all_discr_losses = torch.stack(discr_losses).mean()
@@ -386,6 +395,9 @@ class SoundStream(nn.Module):
                if exists(stft_discr_loss):
                    all_discr_losses = all_discr_losses + stft_discr_loss

                if exists(stft_grad_penalty):
                    all_discr_losses = all_discr_losses + stft_grad_penalty

                return all_discr_losses

            # return a list of discriminator losses with List[Tuple[str, Tensor]]
@@ -394,9 +406,14 @@ class SoundStream(nn.Module):

            discr_losses_pkg.extend([(f'scale:{scale}', multi_scale_loss) for scale, multi_scale_loss in zip(self.discr_multi_scales, discr_losses)])

            discr_losses_pkg.extend([(f'scale_grad_penalty:{scale}', discr_grad_penalty) for scale, discr_grad_penalty in zip(self.discr_multi_scales, discr_grad_penalties)])

            if exists(stft_discr_loss):
                discr_losses_pkg.append(('stft', stft_discr_loss))

            if exists(stft_grad_penalty):
                discr_losses_pkg.append(('stft_grad_penalty', stft_grad_penalty))

            return discr_losses_pkg

        # recon loss
+2 −1
Original line number Diff line number Diff line
@@ -236,12 +236,13 @@ class SoundStreamTrainer(nn.Module):

            discr_losses = self.soundstream(
                wave,
                apply_grad_penalty = apply_grad_penalty,
                return_discr_loss = True,
                return_discr_losses_separately = True
            )

            for name, discr_loss in discr_losses:
                self.accelerator.backward(discr_loss / self.grad_accum_every)
                self.accelerator.backward(discr_loss / self.grad_accum_every, retain_graph = True)
                accum_log(logs, {name: discr_loss.item() / self.grad_accum_every})

        if exists(self.discr_max_grad_norm):
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.49',
  version = '0.0.50',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',