Commit e3a4e897 authored by Phil Wang's avatar Phil Wang
Browse files

should use absolute value of complex value for final loss

parent 45bb030a
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord

This repository now also contains an MIT-licensed version of <a href="https://arxiv.org/abs/2107.03312">SoundStream</a>. Once <a href="https://github.com/facebookresearch/encodec">EnCodec</a> becomes MIT-licensed, I will consider adding a wrapper for it here as well.

Update: AudioLM was essentially used to 'solve' music generation in the new <a href="https://github.com/lucidrains/audiolm-pytorch">MusicLM</a>
Update: AudioLM was essentially used to 'solve' music generation in the new <a href="https://github.com/lucidrains/musiclm-pytorch">MusicLM</a>

## Appreciation

+11 −22
Original line number Diff line number Diff line
@@ -28,21 +28,6 @@ def exists(val):
def default(val, d):
    """Return `val` when `exists(val)` holds, otherwise the fallback `d`."""
    if exists(val):
        return val
    return d

# decorators

def auto_handle_complex(fn):
    """Decorator that lets a real-valued tensor function accept complex inputs.

    The dtype of the first positional argument selects the path:

    * not complex - `fn` is called unchanged;
    * complex - `fn` is evaluated once on the `.real` parts of every
      positional argument and once on their `.imag` parts, and the two
      results are averaged.

    Any complex dtype is recognised (via `torch.is_complex`), including
    `torch.complex128`, which an explicit dtype whitelist would miss.
    Keyword arguments are forwarded to `fn` untouched on both paths.
    """
    @wraps(fn)
    def inner(*args, **kwargs):
        if not torch.is_complex(args[0]):
            return fn(*args, **kwargs)

        real_args = tuple(arg.real for arg in args)
        imag_args = tuple(arg.imag for arg in args)

        # equal-weight average of the real-part and imaginary-part results
        return (fn(*real_args, **kwargs) + fn(*imag_args, **kwargs)) * 0.5

    return inner

# tensor helpers

def l2norm(t, dim = -1):
@@ -53,11 +38,9 @@ def l2norm(t, dim = -1):
def log(t, eps = 1e-20):
    """Natural log of `t` after flooring every element at `eps`, so zero or negative entries never reach `torch.log`."""
    floored = t.clamp(min = eps)
    return torch.log(floored)

@auto_handle_complex
def hinge_discr_loss(fake, real):
    """Discriminator hinge loss: mean over all elements of relu(1 + fake) + relu(1 - real)."""
    fake_penalty = F.relu(1 + fake)
    real_penalty = F.relu(1 - real)
    return (fake_penalty + real_penalty).mean()

@auto_handle_complex
def hinge_gen_loss(fake):
    """Generator hinge loss: the negated mean of the `fake` logits."""
    mean_logit = fake.mean()
    return -mean_logit

@@ -221,7 +204,7 @@ class ComplexSTFTDiscriminator(nn.Module):
        if not return_intermediates:
            return complex_logits

        return complex_logits, intermediates
        return torch.abs(complex_logits), intermediates

# simulated complex stft discriminator

@@ -241,15 +224,19 @@ class ComplexConv2d(nn.Module):
        new_imag = self.conv_real(imag) + self.conv_imag(real)
        return torch.stack((new_real, new_imag), dim = 1)

def complex_abs(t, dim = 1, eps = 1e-8):
    """Magnitude of a complex tensor stored as stacked (real, imag) parts.

    Args:
        t: tensor whose `dim` axis has exactly two entries; index 0 is
            taken as the real part and index 1 as the imaginary part.
        dim: axis holding the (real, imag) pair. Defaults to 1.
        eps: lower bound applied to the squared magnitude before the square
            root, so the root is never taken of exactly zero.

    Returns:
        Tensor with the `dim` axis removed, equal to
        sqrt(clamp(real ** 2 + imag ** 2, min = eps)).
    """
    # honour the caller's `dim` instead of always splitting along axis 1
    real, imag = t.unbind(dim = dim)
    return (real ** 2 + imag ** 2).clamp(min = eps).sqrt()

class ComplexModReLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.b = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        real, imag = x.unbind(dim = 1)
        x_abs = complex_abs(x)

        x_abs = (real ** 2 + imag ** 2).clamp(min = 1e-8).sqrt()
        real, imag = x.unbind(dim = 1)
        x_angle = torch.atan2(imag, real)

        new_real = F.relu(x_abs + self.b)
@@ -333,10 +320,12 @@ class STFTDiscriminator(nn.Module):

        logits = self.final_conv(x)

        logits_abs = complex_abs(logits)

        if not return_intermediates:
            return logits
            return logits_abs

        return logits, intermediates
        return logits_abs, intermediates

# learned EMA blocks

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.11.2',
  version = '0.11.3',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',