Commit dfe9a1e8 authored by Phil Wang
Browse files

attempt to fix distributed training for soundstream in...

attempt to fix distributed training for soundstream in https://github.com/lucidrains/audiolm-pytorch/issues/31
parent 49499085
Loading
Loading
Loading
Loading
+61 −32
Original line number Diff line number Diff line
@@ -154,8 +154,7 @@ class ComplexSTFTDiscriminator(nn.Module):
        input_channels = 1,
        n_fft = 1024,
        hop_length = 256,
        win_length = 1024,
        **kwargs
        win_length = 1024
    ):
        super().__init__()
        self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3, dtype = torch.complex64)
@@ -216,16 +215,48 @@ class ComplexSTFTDiscriminator(nn.Module):

        return complex_logits, intermediates

# non-complex stft discriminator
# simulated complex stft discriminator

class ComplexConv2d(nn.Module):
    """Complex-valued 2d convolution emulated with two real-valued ``nn.Conv2d``.

    The input is expected to hold its real and imaginary components stacked
    along dimension 1 (exactly two entries); the output is stacked the same
    way. All constructor arguments are forwarded unchanged to both underlying
    ``nn.Conv2d`` modules, one holding the real part of the kernel and the
    other the imaginary part.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # real kernel first, imaginary kernel second - both built from the same arguments
        self.conv_real = nn.Conv2d(*args, **kwargs)
        self.conv_imag = nn.Conv2d(*args, **kwargs)

    def forward(self, x):
        # split the simulated complex tensor into its two components
        x_re, x_im = torch.unbind(x, dim = 1)

        # (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
        out_re = self.conv_real(x_re) - self.conv_imag(x_im)
        out_im = self.conv_real(x_im) + self.conv_imag(x_re)

        # restack so dimension 1 again indexes (real, imag)
        return torch.stack([out_re, out_im], dim = 1)

class ComplexModReLU(nn.Module):
    """modReLU activation for complex values simulated along dimension 1.

    For a complex number z with magnitude |z| and phase theta:

        modReLU(z) = ReLU(|z| + b) * exp(i * theta)

    i.e. the magnitude is shifted by the learned scalar bias ``b`` and
    rectified, while the phase is left untouched.

    The input holds (real, imag) stacked along dimension 1 and the output
    uses the same layout, so it can be fed directly into ``ComplexConv2d``.
    """

    def __init__(self):
        super().__init__()
        # learned scalar offset added to the magnitude before rectification
        self.b = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        real, imag = x.unbind(dim = 1)

        # clamp before the sqrt so the magnitude (and its gradient) stays finite at z = 0
        x_abs = (real ** 2 + imag ** 2).clamp(min = 1e-5).sqrt()

        # exp(i * theta) == z / |z|, so scaling both components by
        # relu(|z| + b) / |z| applies the new magnitude while preserving the
        # phase in every quadrant - no atan (which loses the quadrant and
        # needed a sign-destroying clamp on the real part) is required
        scale = F.relu(x_abs + self.b) / x_abs

        new_real = real * scale
        new_imag = imag * scale

        return torch.stack((new_real, new_imag), dim = 1)

# Builds the nn.Sequential that STFTDiscriminator appends to its layer list:
# a 3x3 convolution keeping chan_in channels, an activation, then a strided
# convolution mapping chan_in -> chan_out.
#
# chan_in  - channels entering the unit
# chan_out - channels produced by the final strided convolution
# strides  - iterable of per-axis strides for the final convolution
def STFTResidualUnit(chan_in, chan_out, strides):
    # kernel size along each axis is that axis' stride plus 2
    kernel_sizes = tuple(map(lambda t: t + 2, strides))
    # padding along each axis is half the kernel size (floor division)
    paddings = tuple(map(lambda t: t // 2, kernel_sizes))

    # NOTE(review): the nn.Conv2d / leaky_relu lines and the ComplexConv2d /
    # ComplexModReLU lines below appear to be the pre- and post-change versions
    # of the same three layers (this text is a commit diff rendered without
    # +/- markers) - confirm which set is live before relying on this listing
    return nn.Sequential(
        nn.Conv2d(chan_in, chan_in, 3, padding = 1),
        leaky_relu(),
        nn.Conv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings)
        ComplexConv2d(chan_in, chan_in, 3, padding = 1),
        ComplexModReLU(),
        ComplexConv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings)
    )

class STFTDiscriminator(nn.Module):
@@ -238,23 +269,11 @@ class STFTDiscriminator(nn.Module):
        input_channels = 1,
        n_fft = 1024,
        hop_length = 256,
        win_length = 1024,
        stft_normalized = True
        win_length = 1024
    ):
        super().__init__()
        self.stft = T.Spectrogram(
            n_fft = n_fft,
            hop_length = hop_length,
            win_length = win_length,
            window_fn = torch.hann_window,
            normalized = stft_normalized,
            center = False,
            pad_mode = None,
            power = None
        )

        input_channels *= 2
        self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3)
        self.init_conv = ComplexConv2d(input_channels, channels, 7, padding = 3)

        layer_channels = tuple(map(lambda mult: mult * channels, chan_mults))
        layer_channels = (channels, *layer_channels)
@@ -267,23 +286,33 @@ class STFTDiscriminator(nn.Module):
        for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs):
            self.layers.append(STFTResidualUnit(chan_in, chan_out, layer_stride))

        self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1)) # todo: remove hardcoded 16
        self.final_conv = ComplexConv2d(layer_channels[-1], 1, (16, 1)) # todo: remove hardcoded 16

        # stft settings

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length

    def forward(self, x, return_intermediates = False):
        x = rearrange(x, 'b 1 n -> b n')
        """
        einstein notation
        b - batch
        n - sequence length
        x - dimension of 2 that holds the real and imaginary values
        """

        '''
        reference: The content of the paper( https://arxiv.org/pdf/2107.03312.pdf)is as follows:
        x = rearrange(x, 'b 1 n -> b n')

        The STFT-based discriminator is illustrated in Figure 4
        and operates on a single scale, computing the STFT with a
        window length of W = 1024 samples and a hop length of
        H = 256 samples
        '''
        x = torch.stft(
            x,
            self.n_fft,
            hop_length = self.hop_length,
            win_length = self.win_length,
            return_complex = False  # going to simulate complex using dimension 1
        )

        x = self.stft(x)
        x = rearrange(x, 'b ... -> b 1 ...')
        x = torch.cat((x.real, x.imag), dim = 1)
        x = rearrange(x, 'b ... x -> b x 1 ...')

        intermediates = []

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.7.3',
  version = '0.7.4',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',