Commit cb442af8 authored by Phil Wang
Browse files

try one more time to fix complex stft discriminator network distributed training

parent 92376359
Loading
Loading
Loading
Loading
+27 −129
Original line number Diff line number Diff line
@@ -130,14 +130,35 @@ class ModReLU(nn.Module):
    def forward(self, x):
        """
        shift the magnitude of x by the offset self.b, rectify it, and re-attach the phase of x
        (assumes x is a complex tensor - TODO confirm against callers)
        """
        # rectified magnitude: zero wherever |x| + b is negative
        magnitude = F.relu(torch.abs(x) + self.b)

        # exp(i * angle) has modulus 1, so it carries only the phase of x
        phase = torch.exp(1.j * torch.angle(x))

        return magnitude * phase

class ComplexConv2d(nn.Module):
    """
    2d convolution with a complex kernel whose parameters are stored as real tensors

    weight and bias are registered with a trailing dimension of 2 (the output of
    torch.view_as_real) and are only viewed as complex again inside forward
    """

    def __init__(
        self,
        dim,
        dim_out,
        kernel_size,
        stride = 1,
        padding = 0
    ):
        super().__init__()

        # throwaway complex conv, only used to obtain an initialized kernel and bias
        complex_conv = nn.Conv2d(dim, dim_out, kernel_size, dtype = torch.complex64)

        self.weight = nn.Parameter(torch.view_as_real(complex_conv.weight))
        self.bias = nn.Parameter(torch.view_as_real(complex_conv.bias))

        # stride and padding are applied functionally at call time
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        """
        convolve x with the complex view of the stored kernel and bias
        """
        complex_weight = torch.view_as_complex(self.weight)
        complex_bias = torch.view_as_complex(self.bias)

        return F.conv2d(
            x,
            complex_weight,
            complex_bias,
            stride = self.stride,
            padding = self.padding
        )

def ComplexSTFTResidualUnit(chan_in, chan_out, strides):
    """
    build a conv -> ModReLU -> strided conv stack as an nn.Sequential

    chan_in  - channels entering the unit (kept by the first 3x3 conv)
    chan_out - channels produced by the second, strided conv
    strides  - per-axis strides of the second conv; its kernel size is stride + 2
               and its padding is kernel size // 2, on each axis
    """
    kernel_sizes = tuple(map(lambda t: t + 2, strides))
    paddings = tuple(map(lambda t: t // 2, kernel_sizes))

    # only the ComplexConv2d layers are kept: the nn.Conv2d(..., dtype = torch.complex64)
    # entries were the lines this commit replaced, and listing both made the call invalid
    return nn.Sequential(
        ComplexConv2d(chan_in, chan_in, 3, padding = 1),
        ModReLU(),
        ComplexConv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings)
    )

class ComplexSTFTDiscriminator(nn.Module):
@@ -153,7 +174,7 @@ class ComplexSTFTDiscriminator(nn.Module):
        win_length = 1024
    ):
        super().__init__()
        self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3, dtype = torch.complex64)
        self.init_conv = ComplexConv2d(input_channels, channels, 7, padding = 3)

        layer_channels = tuple(map(lambda mult: mult * channels, chan_mults))
        layer_channels = (channels, *layer_channels)
@@ -166,7 +187,7 @@ class ComplexSTFTDiscriminator(nn.Module):
        for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs):
            self.layers.append(ComplexSTFTResidualUnit(chan_in, chan_out, layer_stride))

        self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1), dtype = torch.complex64) # todo: remove hardcoded 16
        self.final_conv = ComplexConv2d(layer_channels[-1], 1, (16, 1)) # todo: remove hardcoded 16

        # stft settings

@@ -213,127 +234,6 @@ class ComplexSTFTDiscriminator(nn.Module):

        return complex_logits_abs, intermediates

# simulated complex stft discriminator

class ComplexConv2d(nn.Module):
    """
    complex 2d convolution simulated with two real nn.Conv2d layers

    the input holds its real and imaginary parts along dimension 1
    all constructor arguments are forwarded unchanged to both convolutions
    """

    def __init__(
        self,
        *args,
        **kwargs
    ):
        super().__init__()
        self.conv_real = nn.Conv2d(*args, **kwargs)
        self.conv_imag = nn.Conv2d(*args, **kwargs)

    def forward(self, x):
        """
        split x into (real, imag) along dimension 1, combine the four real
        convolutions and stack the two results back along dimension 1
        """
        x_real, x_imag = torch.unbind(x, dim = 1)

        # real output: real-kernel on real part minus imag-kernel on imag part
        out_real = self.conv_real(x_real) - self.conv_imag(x_imag)

        # imag output: real-kernel on imag part plus imag-kernel on real part
        out_imag = self.conv_real(x_imag) + self.conv_imag(x_real)

        return torch.stack([out_real, out_imag], dim = 1)

def complex_abs(t, dim = 1, eps = 1e-8):
    """
    magnitude of a simulated complex tensor

    t   - real tensor whose dimension `dim` has size 2, holding (real, imag)
    dim - dimension that carries the real and imaginary parts
    eps - lower clamp applied to the squared magnitude before the square root

    returns sqrt(max(real ** 2 + imag ** 2, eps)), with dimension `dim` removed
    """
    # unbind along the requested dimension (it was previously fixed to 1, ignoring `dim`)
    real, imag = t.unbind(dim = dim)
    return (real ** 2 + imag ** 2).clamp(min = eps).sqrt()

class ComplexModReLU(nn.Module):
    """
    modReLU for simulated complex tensors, with (real, imag) held along dimension 1

    the magnitude is shifted by the learned scalar b (initialized to 0) and
    rectified, while the phase of the input is kept
    """

    def __init__(self):
        super().__init__()
        self.b = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        """
        returns a tensor with the same layout as x, holding
        relu(|x| + b) * (cos(angle), sin(angle)) along dimension 1
        """
        x_abs = complex_abs(x)

        real, imag = x.unbind(dim = 1)
        x_angle = torch.atan2(imag, real)

        magnitude = F.relu(x_abs + self.b)

        # polar -> rectangular: magnitude * exp(i * angle) = magnitude * (cos(angle) + i * sin(angle))
        # (previously the rectified magnitude and exp(angle) were stacked directly,
        #  which is not a real / imaginary pair)
        new_real = magnitude * x_angle.cos()
        new_imag = magnitude * x_angle.sin()

        return torch.stack((new_real, new_imag), dim = 1)

def STFTResidualUnit(chan_in, chan_out, strides):
    """
    assemble conv -> ComplexModReLU -> strided conv into an nn.Sequential

    the second conv uses, per axis, kernel size = stride + 2 and padding = kernel size // 2
    """
    kernel_sizes = tuple(stride + 2 for stride in strides)
    paddings = tuple(kernel_size // 2 for kernel_size in kernel_sizes)

    # modules are created in the same order in which they are applied
    first_conv = ComplexConv2d(chan_in, chan_in, 3, padding = 1)
    activation = ComplexModReLU()
    strided_conv = ComplexConv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings)

    return nn.Sequential(first_conv, activation, strided_conv)

class STFTDiscriminator(nn.Module):
    """
    stft discriminator that simulates complex arithmetic with real tensors

    the waveform is turned into an stft whose real and imaginary parts are
    carried along dimension 1, passed through an initial conv, one
    STFTResidualUnit per entry of `strides` / `chan_mults`, and a final conv
    the magnitude of the final conv output is returned as the logits

    channels       - channels produced by the initial conv
    strides        - per-layer (axis, axis) strides for the residual units
    chan_mults     - per-layer multipliers of `channels` giving each unit's output channels
    input_channels - channels fed to the initial conv
    n_fft, hop_length, win_length - forwarded to torch.stft
    """

    def __init__(
        self,
        *,
        channels = 32,
        strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)),
        chan_mults = (1, 2, 4, 4, 8, 8),
        input_channels = 1,
        n_fft = 1024,
        hop_length = 256,
        win_length = 1024
    ):
        super().__init__()

        self.init_conv = ComplexConv2d(input_channels, channels, 7, padding = 3)

        # (channels, channels * mult_1, channels * mult_2, ...) paired up as (in, out) per layer
        layer_channels = tuple(map(lambda mult: mult * channels, chan_mults))
        layer_channels = (channels, *layer_channels)
        layer_channels_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:]))

        self.layers = nn.ModuleList([])

        for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs):
            self.layers.append(STFTResidualUnit(chan_in, chan_out, layer_stride))

        self.final_conv = ComplexConv2d(layer_channels[-1], 1, (16, 1)) # todo: remove hardcoded 16

        # stft settings

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length

    def forward(self, x, return_intermediates = False):
        """
        einstein notation
        b - batch
        n - sequence length
        x - dimension of 2 that holds the real and imaginary values

        takes a waveform of shape (b, 1, n) and returns the magnitude of the final
        logits, plus the list of intermediate activations (output of the initial
        conv and of every layer) when return_intermediates is True
        """

        x = rearrange(x, 'b 1 n -> b n')

        # return_complex = False is deprecated in torch.stft, so take the complex
        # result and view it as real, which yields the same trailing (real, imag) dimension

        x = torch.stft(
            x,
            self.n_fft,
            hop_length = self.hop_length,
            win_length = self.win_length,
            return_complex = True
        )

        x = torch.view_as_real(x)  # going to simulate complex using dimension 1

        x = rearrange(x, 'b ... x -> b x 1 ...')

        intermediates = []

        x = self.init_conv(x)
        intermediates.append(x)

        for layer in self.layers:
            x = layer(x)
            intermediates.append(x)

        logits = self.final_conv(x)

        logits_abs = complex_abs(logits)

        if not return_intermediates:
            return logits_abs

        return logits_abs, intermediates

# learned EMA blocks

class MultiHeadEMABlock(nn.Module):
@@ -470,8 +370,7 @@ class SoundStream(nn.Module):
        attn_window_size = 128,
        attn_dim_head = 64,
        attn_heads = 8,
        attn_depth = 1,
        use_complex_stft_discriminator = True
        attn_depth = 1
    ):
        super().__init__()
        self.target_sample_hz = target_sample_hz # for resampling on the fly
@@ -545,8 +444,7 @@ class SoundStream(nn.Module):
        self.discr_multi_scales = discr_multi_scales
        self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))])

        stft_discr_klass = ComplexSTFTDiscriminator if use_complex_stft_discriminator else STFTDiscriminator
        self.stft_discriminator = stft_discr_klass()
        self.stft_discriminator = ComplexSTFTDiscriminator()

        # multi spectral reconstruction

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.11.9',
  version = '0.11.11',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',