Commit 8259b0d0 authored by Phil Wang
Browse files

Do what EnCodec does: concatenate the real and imaginary parts after the STFT transform...

Do what EnCodec does: concatenate the real and imaginary parts after the STFT transform in the STFT discriminator.
parent 86b5e4a0
Loading
Loading
Loading
Loading
+31 −25
Original line number Diff line number Diff line
@@ -6,6 +6,9 @@ import torch
from torch import nn, einsum
from torch.autograd import grad as torch_grad
import torch.nn.functional as F

import torchaudio.transforms as T

from einops import rearrange, reduce

from vector_quantize_pytorch import ResidualVQ
@@ -96,26 +99,14 @@ class MultiScaleDiscriminator(nn.Module):

        return out, intermediates

class ModReLU(nn.Module):
    """Complex-valued activation with one learnable scalar bias ``b``.

    The magnitude of the input is shifted by ``b`` and passed through ReLU;
    the result is then multiplied by the unit phasor ``exp(i * angle(x))``
    computed from the input.

    References:
        https://arxiv.org/abs/1705.09792
        https://github.com/pytorch/pytorch/issues/47052#issuecomment-718948801
    """

    def __init__(self):
        super().__init__()
        # single scalar bias shared by every element, initialised to zero
        self.b = nn.Parameter(torch.tensor(0.))

    def forward(self, x):
        magnitude = torch.abs(x)
        gated_magnitude = F.relu(magnitude + self.b)

        # exp(i * theta) has modulus one and carries the phase of x
        unit_phasor = torch.exp(1.j * torch.angle(x))
        return gated_magnitude * unit_phasor

def STFTResidualUnit(chan_in, chan_out, strides):
    """Build one real-valued convolutional stage of the STFT discriminator.

    The stage is an ``nn.Sequential`` of:
      1. a 3x3 convolution that keeps ``chan_in`` channels (padding 1),
      2. the activation returned by ``leaky_relu()`` (helper defined
         elsewhere in this file),
      3. a strided convolution mapping ``chan_in`` -> ``chan_out``.

    Despite the name, no skip connection is added here; the layers are
    applied strictly in sequence.

    Args:
        chan_in: number of input channels.
        chan_out: number of output channels.
        strides: iterable of per-dimension strides for the last convolution,
            e.g. ``(1, 2)``.

    Returns:
        The assembled ``nn.Sequential`` module.
    """
    # the strided conv uses kernel = stride + 2 in each dimension,
    # padded by half the kernel size
    kernel_sizes = tuple(stride + 2 for stride in strides)
    paddings = tuple(kernel_size // 2 for kernel_size in kernel_sizes)

    return nn.Sequential(
        nn.Conv2d(chan_in, chan_in, 3, padding = 1),
        leaky_relu(),
        nn.Conv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings)
    )

class STFTDiscriminator(nn.Module):
@@ -125,10 +116,26 @@ class STFTDiscriminator(nn.Module):
        channels = 32,
        strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)),
        chan_mults = (1, 2, 4, 4, 8, 8),
        input_channels = 1
        input_channels = 1,
        n_fft = 1024,
        hop_length = 256,
        win_length = 1024,
        stft_normalized = True
    ):
        super().__init__()
        self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3, dtype = torch.complex64)
        self.stft = T.Spectrogram(
            n_fft = n_fft,
            hop_length = hop_length,
            win_length = win_length,
            window_fn = torch.hann_window,
            normalized = stft_normalized,
            center = False,
            pad_mode = None,
            power = None
        )

        input_channels *= 2
        self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3)

        layer_channels = tuple(map(lambda mult: mult * channels, chan_mults))
        layer_channels = (channels, *layer_channels)
@@ -141,7 +148,7 @@ class STFTDiscriminator(nn.Module):
        for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs):
            self.layers.append(STFTResidualUnit(chan_in, chan_out, layer_stride))

        self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1), dtype = torch.complex64) # todo: remove hardcoded 16
        self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1)) # todo: remove hardcoded 16

    def forward(self, x, return_intermediates = False):
        x = rearrange(x, 'b 1 n -> b n')
@@ -155,8 +162,9 @@ class STFTDiscriminator(nn.Module):
        H = 256 samples
        '''

        x = torch.view_as_complex(torch.stft(x,1024, hop_length=256,win_length=1024, return_complex=False))
        x = self.stft(x)
        x = rearrange(x, 'b ... -> b 1 ...')
        x = torch.cat((x.real, x.imag), dim = 1).detach()

        intermediates = []

@@ -167,12 +175,12 @@ class STFTDiscriminator(nn.Module):
            x = layer(x)
            intermediates.append(x)

        complex_logits = self.final_conv(x)
        logits = self.final_conv(x)

        if not return_intermediates:
            return complex_logits
            return logits

        return complex_logits, intermediates
        return logits, intermediates

# sound stream

@@ -484,9 +492,7 @@ class SoundStream(nn.Module):

        # adversarial loss for stft discriminator

        adversarial_losses.append(hinge_gen_loss(stft_fake_logits.real))
        adversarial_losses.append(hinge_gen_loss(stft_fake_logits.imag))

        adversarial_losses.append(hinge_gen_loss(stft_fake_logits))
        adversarial_loss = torch.stack(adversarial_losses).mean()

        # sum commitment loss
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.2.3',
  version = '0.3.0',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',