Commit 4d081c7a authored by Phil Wang
Browse files

switch to modrelu, seems to be recommended by a lot of torch complex packages

parent aa053305
Loading
Loading
Loading
Loading
+9 −7
Original line number Diff line number Diff line
@@ -96,15 +96,17 @@ class MultiScaleDiscriminator(nn.Module):

        return out, intermediates

class ComplexLeakyReLU(nn.Module):
    """
    leaky relu applied to the real and imaginary components separately

    this is the pre-commit activation; the commit swaps its use for ModReLU.
    it is kept as its own well-formed class so that both names referenced
    further down in this diff still resolve.
    """
    def __init__(self, p = 0.1):
        super().__init__()
        self.nonlin = leaky_relu(p)

    def forward(self, x):
        # torch.view_as_complex reads the trailing pair as (real, imag), so the
        # components have to be stacked in that order - stacking (imag, real)
        # would swap the real and imaginary parts of the output
        real, imag = map(self.nonlin, (x.real, x.imag))
        return torch.view_as_complex(torch.stack((real, imag), dim = -1))

class ModReLU(nn.Module):
    """
    relu(|x| + b) * exp(i * angle(x)) for a complex tensor x

    the magnitude is shifted by a learned scalar bias and rectified,
    while the phase of the input is passed through unchanged

    https://arxiv.org/abs/2102.13092
    https://github.com/pytorch/pytorch/issues/47052#issuecomment-718948801
    """
    def __init__(self):
        super().__init__()
        # learned scalar bias on the magnitude, starts at zero
        # (with b = 0 the activation is the identity)
        self.b = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        return F.relu(torch.abs(x) + self.b) * torch.exp(1.j * torch.angle(x))

def STFTResidualUnit(chan_in, chan_out, strides):
    kernel_sizes = tuple(map(lambda t: t + 2, strides))
@@ -112,7 +114,7 @@ def STFTResidualUnit(chan_in, chan_out, strides):

    return nn.Sequential(
        nn.Conv2d(chan_in, chan_in, 3, padding = 1, dtype = torch.complex64),
        ComplexLeakyReLU(),
        ModReLU(),
        nn.Conv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings, dtype = torch.complex64)
    )

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.1.15',
  version = '0.1.16',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',