Commit 2ebfd1c1 authored by Phil Wang
Browse files

bring back the complex stft discriminator, to try to figure out where the...

bring back the complex stft discriminator, to try to figure out where the regression for soundstream happened. default to complex stft discriminator
parent e6eaabdb
Loading
Loading
Loading
Loading
+120 −3
Original line number Diff line number Diff line
import functools
from itertools import cycle
from pathlib import Path
from functools import partial
from functools import partial, wraps

import torch
from torch import nn, einsum
@@ -28,11 +28,28 @@ def exists(val):
def default(val, d):
    return val if exists(val) else d

# decorators

def auto_handle_complex(fn):
    """Let a real-valued loss function accept complex tensors.

    The returned wrapper looks at its first positional argument. When that
    tensor is real, ``fn`` is called unchanged. When it is complex (any
    complex dtype), ``fn`` is evaluated once on the real parts and once on
    the imaginary parts of every positional argument, and the two results
    are averaged.

    Only positional arguments are forwarded, and every one of them must
    expose ``.real`` / ``.imag`` when the first argument is complex.
    """
    @wraps(fn)
    def inner(*args):
        # torch.is_complex covers every complex dtype; an explicit dtype
        # whitelist would let complex128 slip into the real-only path
        if not torch.is_complex(args[0]):
            return fn(*args)

        real_args = tuple(arg.real for arg in args)
        imag_args = tuple(arg.imag for arg in args)

        return (fn(*real_args) + fn(*imag_args)) * 0.5

    return inner

# gan losses

@auto_handle_complex
def hinge_discr_loss(fake, real):
    """Hinge loss for the discriminator.

    Penalises fake logits above -1 and real logits below +1, then averages
    over all elements. Complex logits are handled by the decorator.
    """
    fake_term = F.relu(1 + fake)
    real_term = F.relu(1 - real)
    return (fake_term + real_term).mean()

@auto_handle_complex
def hinge_gen_loss(fake):
    """Hinge loss for the generator: the negated mean of the fake logits.

    Complex logits are handled by the decorator.
    """
    mean_score = fake.mean()
    return -mean_score

@@ -103,6 +120,104 @@ class MultiScaleDiscriminator(nn.Module):

        return out, intermediates

# complex stft discriminator

class ModReLU(nn.Module):
    """
    modReLU activation for complex tensors: the magnitude is shifted by a
    learned scalar bias and rectified, while the phase is left untouched

    https://arxiv.org/abs/1705.09792
    https://github.com/pytorch/pytorch/issues/47052#issuecomment-718948801
    """
    def __init__(self):
        super().__init__()
        # single learnable bias added to the magnitude, initialised to zero
        self.b = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        magnitude = x.abs()
        phase = x.angle()

        # rectify the biased magnitude, then re-attach the original phase
        rectified = F.relu(magnitude + self.b)
        unit_phasor = torch.exp(1.j * phase)

        return rectified * unit_phasor

def ComplexSTFTResidualUnit(chan_in, chan_out, strides):
    """
    Build one complex-valued conv stage: a 3x3 conv that keeps `chan_in`
    channels, a ModReLU, then a strided conv mapping `chan_in` -> `chan_out`.

    `strides` is a pair of per-axis strides; the strided conv uses a kernel
    of (stride + 2) and padding of (kernel // 2) along each axis.

    Note: despite the name, the returned nn.Sequential holds only these three
    modules - no skip connection is added here.
    """
    kernel_sizes = tuple(stride + 2 for stride in strides)
    paddings = tuple(kernel_size // 2 for kernel_size in kernel_sizes)

    # convs are created in the same order as they run, so parameter
    # initialisation consumes the RNG in a fixed order
    same_size_conv = nn.Conv2d(chan_in, chan_in, 3, padding = 1, dtype = torch.complex64)
    strided_conv = nn.Conv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings, dtype = torch.complex64)

    return nn.Sequential(same_size_conv, ModReLU(), strided_conv)

class ComplexSTFTDiscriminator(nn.Module):
    """
    Discriminator built from complex-valued 2d convolutions applied to the
    complex STFT of a waveform.

    Keyword arguments:
        channels: base channel count produced by the initial conv
        strides: one pair of strides per conv stage
        chan_mults: per-stage multiplier of `channels` giving each stage's
            output channels; stages are built by zipping `strides` with the
            consecutive channel pairs, so the shorter of the two decides how
            many stages exist
        input_channels: channels fed into the initial conv
        n_fft, hop_length, win_length: forwarded to torch.stft
        **kwargs: accepted and ignored
    """
    def __init__(
        self,
        *,
        channels = 32,
        strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)),
        chan_mults = (1, 2, 4, 4, 8, 8),
        input_channels = 1,
        n_fft = 1024,
        hop_length = 256,
        win_length = 1024,
        **kwargs
    ):
        super().__init__()
        self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3, dtype = torch.complex64)

        # (channels, channels * m0, channels * m1, ...) -> consecutive (in, out) pairs
        layer_channels = tuple(map(lambda mult: mult * channels, chan_mults))
        layer_channels = (channels, *layer_channels)
        layer_channels_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:]))

        self.layers = nn.ModuleList([])

        for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs):
            self.layers.append(ComplexSTFTResidualUnit(chan_in, chan_out, layer_stride))

        self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1), dtype = torch.complex64) # todo: remove hardcoded 16

        # stft settings

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length

    def forward(self, x, return_intermediates = False):
        """
        x: waveform batch laid out as (batch, 1, samples) - the singleton
           axis is dropped before the STFT

        Returns the complex logits from the final conv. When
        `return_intermediates` is True, returns `(complex_logits, intermediates)`
        where `intermediates` holds the initial conv output followed by the
        output of every stage, in order.
        """
        x = rearrange(x, 'b 1 n -> b n')

        # reference - SoundStream paper (https://arxiv.org/pdf/2107.03312.pdf):
        # "The STFT-based discriminator is illustrated in Figure 4 and operates
        # on a single scale, computing the STFT with a window length of
        # W = 1024 samples and a hop length of H = 256 samples"

        # NOTE(review): no `window` is passed, so torch.stft falls back to its
        # default window - confirm this is intended
        x = torch.stft(
            x,
            self.n_fft,
            hop_length = self.hop_length,
            win_length = self.win_length,
            return_complex = True
        )

        # add a singleton channel axis for the 2d convs
        x = rearrange(x, 'b ... -> b 1 ...')

        intermediates = []

        x = self.init_conv(x)
        intermediates.append(x)

        for layer in self.layers:
            x = layer(x)
            intermediates.append(x)

        complex_logits = self.final_conv(x)

        if not return_intermediates:
            return complex_logits

        return complex_logits, intermediates

# non-complex stft discriminator

def STFTResidualUnit(chan_in, chan_out, strides):
    kernel_sizes = tuple(map(lambda t: t + 2, strides))
    paddings = tuple(map(lambda t: t // 2, kernel_sizes))
@@ -317,7 +432,8 @@ class SoundStream(nn.Module):
        mhesa_dim_head = 32,
        attn_window_size = 128,
        attn_dim_head = 64,
        attn_heads = 8
        attn_heads = 8,
        use_complex_stft_discriminator = True
    ):
        super().__init__()
        self.target_sample_hz = target_sample_hz # for resampling on the fly
@@ -391,7 +507,8 @@ class SoundStream(nn.Module):
        self.discr_multi_scales = discr_multi_scales
        self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))])

        self.stft_discriminator = STFTDiscriminator()
        stft_klass = ComplexSTFTDiscriminator if use_complex_stft_discriminator else STFTDiscriminator
        self.stft_discriminator = stft_klass()

        # loss weights

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.7.2',
  version = '0.7.3',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',