Commit 9b62899e authored by Phil Wang's avatar Phil Wang
Browse files

less is more

parent ba7dcd68
Loading
Loading
Loading
Loading
+2 −116
Original line number Diff line number Diff line
@@ -28,28 +28,11 @@ def exists(val):
def default(val, d):
    return val if exists(val) else d

# decorators

def auto_handle_complex(fn):
    @wraps(fn)
    def inner(*args):
        if args[0].dtype not in (torch.complex64, torch.complex32):
            return fn(*args)

        real_args = tuple(arg.real for arg in args)
        imag_args = tuple(arg.imag for arg in args)

        return (fn(*real_args) + fn(*imag_args)) * 0.5

    return inner

# gan losses

@auto_handle_complex
def hinge_discr_loss(fake, real):
    return (F.relu(1 + fake) + F.relu(1 - real)).mean()

@auto_handle_complex
def hinge_gen_loss(fake):
    return -fake.mean()

@@ -120,101 +103,6 @@ class MultiScaleDiscriminator(nn.Module):

        return out, intermediates

# complex stft discriminator

class ModReLU(nn.Module):
    """
    https://arxiv.org/abs/1705.09792
    https://github.com/pytorch/pytorch/issues/47052#issuecomment-718948801
    """
    def __init__(self):
        super().__init__()
        self.b = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        return F.relu(torch.abs(x) + self.b) * torch.exp(1.j * torch.angle(x))

def ComplexSTFTResidualUnit(chan_in, chan_out, strides):
    kernel_sizes = tuple(map(lambda t: t + 2, strides))
    paddings = tuple(map(lambda t: t // 2, kernel_sizes))

    return nn.Sequential(
        nn.Conv2d(chan_in, chan_in, 3, padding = 1, dtype = torch.complex64),
        ModReLU(),
        nn.Conv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings, dtype = torch.complex64)
    )

class ComplexSTFTDiscriminator(nn.Module):
    def __init__(
        self,
        *,
        channels = 32,
        strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)),
        chan_mults = (1, 2, 4, 4, 8, 8),
        input_channels = 1,
        n_fft = 1024,
        hop_length = 256,
        win_length = 1024
    ):
        super().__init__()
        self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3, dtype = torch.complex64)

        layer_channels = tuple(map(lambda mult: mult * channels, chan_mults))
        layer_channels = (channels, *layer_channels)
        layer_channels_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:]))

        curr_channels = channels

        self.layers = nn.ModuleList([])

        for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs):
            self.layers.append(ComplexSTFTResidualUnit(chan_in, chan_out, layer_stride))

        self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1), dtype = torch.complex64) # todo: remove hardcoded 16

        # stft settings

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length

    def forward(self, x, return_intermediates = False):
        x = rearrange(x, 'b 1 n -> b n')

        '''
        reference: The content of the paper( https://arxiv.org/pdf/2107.03312.pdf)is as follows:
        The STFT-based discriminator is illustrated in Figure 4
        and operates on a single scale, computing the STFT with a
        window length of W = 1024 samples and a hop length of
        H = 256 samples
        '''

        x = torch.stft(
            x,
            self.n_fft,
            hop_length = self.hop_length,
            win_length = self.win_length,
            return_complex = True
        )

        x = rearrange(x, 'b ... -> b 1 ...')

        intermediates = []

        x = self.init_conv(x)
        intermediates.append(x)

        for layer in self.layers:
            x = layer(x)
            intermediates.append(x)

        complex_logits = self.final_conv(x)

        if not return_intermediates:
            return complex_logits

        return complex_logits, intermediates

# simulated complex stft discriminator

class ComplexConv2d(nn.Module):
@@ -461,8 +349,7 @@ class SoundStream(nn.Module):
        mhesa_dim_head = 32,
        attn_window_size = 128,
        attn_dim_head = 64,
        attn_heads = 8,
        use_complex_stft_discriminator = False
        attn_heads = 8
    ):
        super().__init__()
        self.target_sample_hz = target_sample_hz # for resampling on the fly
@@ -536,8 +423,7 @@ class SoundStream(nn.Module):
        self.discr_multi_scales = discr_multi_scales
        self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))])

        stft_klass = ComplexSTFTDiscriminator if use_complex_stft_discriminator else STFTDiscriminator
        self.stft_discriminator = stft_klass()
        self.stft_discriminator = STFTDiscriminator()

        # loss weights