Loading audiolm_pytorch/soundstream.py +120 −3 Original line number Diff line number Diff line import functools from itertools import cycle from pathlib import Path from functools import partial from functools import partial, wraps import torch from torch import nn, einsum Loading @@ -28,11 +28,28 @@ def exists(val): def default(val, d): return val if exists(val) else d # decorators def auto_handle_complex(fn): @wraps(fn) def inner(*args): if args[0].dtype not in (torch.complex64, torch.complex32): return fn(*args) real_args = tuple(arg.real for arg in args) imag_args = tuple(arg.imag for arg in args) return (fn(*real_args) + fn(*imag_args)) * 0.5 return inner # gan losses @auto_handle_complex def hinge_discr_loss(fake, real): return (F.relu(1 + fake) + F.relu(1 - real)).mean() @auto_handle_complex def hinge_gen_loss(fake): return -fake.mean() Loading Loading @@ -103,6 +120,104 @@ class MultiScaleDiscriminator(nn.Module): return out, intermediates # complex stft discriminator class ModReLU(nn.Module): """ https://arxiv.org/abs/1705.09792 https://github.com/pytorch/pytorch/issues/47052#issuecomment-718948801 """ def __init__(self): super().__init__() self.b = nn.Parameter(torch.tensor(0.)) def forward(self, x): return F.relu(torch.abs(x) + self.b) * torch.exp(1.j * torch.angle(x)) def ComplexSTFTResidualUnit(chan_in, chan_out, strides): kernel_sizes = tuple(map(lambda t: t + 2, strides)) paddings = tuple(map(lambda t: t // 2, kernel_sizes)) return nn.Sequential( nn.Conv2d(chan_in, chan_in, 3, padding = 1, dtype = torch.complex64), ModReLU(), nn.Conv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings, dtype = torch.complex64) ) class ComplexSTFTDiscriminator(nn.Module): def __init__( self, *, channels = 32, strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)), chan_mults = (1, 2, 4, 4, 8, 8), input_channels = 1, n_fft = 1024, hop_length = 256, win_length = 1024, **kwargs ): super().__init__() self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3, dtype = torch.complex64) layer_channels = tuple(map(lambda mult: mult * channels, chan_mults)) layer_channels = (channels, *layer_channels) layer_channels_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:])) curr_channels = channels self.layers = nn.ModuleList([]) for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs): self.layers.append(ComplexSTFTResidualUnit(chan_in, chan_out, layer_stride)) self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1), dtype = torch.complex64) # todo: remove hardcoded 16 # stft settings self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length def forward(self, x, return_intermediates = False): x = rearrange(x, 'b 1 n -> b n') ''' reference: The content of the paper( https://arxiv.org/pdf/2107.03312.pdf)is as follows: The STFT-based discriminator is illustrated in Figure 4 and operates on a single scale, computing the STFT with a window length of W = 1024 samples and a hop length of H = 256 samples ''' x = torch.stft( x, self.n_fft, hop_length = self.hop_length, win_length = self.win_length, return_complex = True ) x = rearrange(x, 'b ... -> b 1 ...') intermediates = [] x = self.init_conv(x) intermediates.append(x) for layer in self.layers: x = layer(x) intermediates.append(x) complex_logits = self.final_conv(x) if not return_intermediates: return complex_logits return complex_logits, intermediates # non-complex stft discriminator def STFTResidualUnit(chan_in, chan_out, strides): kernel_sizes = tuple(map(lambda t: t + 2, strides)) paddings = tuple(map(lambda t: t // 2, kernel_sizes)) Loading Loading @@ -317,7 +432,8 @@ class SoundStream(nn.Module): mhesa_dim_head = 32, attn_window_size = 128, attn_dim_head = 64, attn_heads = 8 attn_heads = 8, use_complex_stft_discriminator = True ): super().__init__() self.target_sample_hz = target_sample_hz # for resampling on the fly Loading Loading @@ -391,7 +507,8 @@ class SoundStream(nn.Module): self.discr_multi_scales = discr_multi_scales self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))]) self.stft_discriminator = STFTDiscriminator() stft_klass = ComplexSTFTDiscriminator if use_complex_stft_discriminator else STFTDiscriminator self.stft_discriminator = stft_klass() # loss weights Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.7.2', version = '0.7.3', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
audiolm_pytorch/soundstream.py +120 −3 Original line number Diff line number Diff line import functools from itertools import cycle from pathlib import Path from functools import partial from functools import partial, wraps import torch from torch import nn, einsum Loading @@ -28,11 +28,28 @@ def exists(val): def default(val, d): return val if exists(val) else d # decorators def auto_handle_complex(fn): @wraps(fn) def inner(*args): if args[0].dtype not in (torch.complex64, torch.complex32): return fn(*args) real_args = tuple(arg.real for arg in args) imag_args = tuple(arg.imag for arg in args) return (fn(*real_args) + fn(*imag_args)) * 0.5 return inner # gan losses @auto_handle_complex def hinge_discr_loss(fake, real): return (F.relu(1 + fake) + F.relu(1 - real)).mean() @auto_handle_complex def hinge_gen_loss(fake): return -fake.mean() Loading Loading @@ -103,6 +120,104 @@ class MultiScaleDiscriminator(nn.Module): return out, intermediates # complex stft discriminator class ModReLU(nn.Module): """ https://arxiv.org/abs/1705.09792 https://github.com/pytorch/pytorch/issues/47052#issuecomment-718948801 """ def __init__(self): super().__init__() self.b = nn.Parameter(torch.tensor(0.)) def forward(self, x): return F.relu(torch.abs(x) + self.b) * torch.exp(1.j * torch.angle(x)) def ComplexSTFTResidualUnit(chan_in, chan_out, strides): kernel_sizes = tuple(map(lambda t: t + 2, strides)) paddings = tuple(map(lambda t: t // 2, kernel_sizes)) return nn.Sequential( nn.Conv2d(chan_in, chan_in, 3, padding = 1, dtype = torch.complex64), ModReLU(), nn.Conv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings, dtype = torch.complex64) ) class ComplexSTFTDiscriminator(nn.Module): def __init__( self, *, channels = 32, strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)), chan_mults = (1, 2, 4, 4, 8, 8), input_channels = 1, n_fft = 1024, hop_length = 256, win_length = 1024, **kwargs ): super().__init__() self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3, dtype = torch.complex64) layer_channels = tuple(map(lambda mult: mult * channels, chan_mults)) layer_channels = (channels, *layer_channels) layer_channels_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:])) curr_channels = channels self.layers = nn.ModuleList([]) for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs): self.layers.append(ComplexSTFTResidualUnit(chan_in, chan_out, layer_stride)) self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1), dtype = torch.complex64) # todo: remove hardcoded 16 # stft settings self.n_fft = n_fft self.hop_length = hop_length self.win_length = win_length def forward(self, x, return_intermediates = False): x = rearrange(x, 'b 1 n -> b n') ''' reference: The content of the paper( https://arxiv.org/pdf/2107.03312.pdf)is as follows: The STFT-based discriminator is illustrated in Figure 4 and operates on a single scale, computing the STFT with a window length of W = 1024 samples and a hop length of H = 256 samples ''' x = torch.stft( x, self.n_fft, hop_length = self.hop_length, win_length = self.win_length, return_complex = True ) x = rearrange(x, 'b ... -> b 1 ...') intermediates = [] x = self.init_conv(x) intermediates.append(x) for layer in self.layers: x = layer(x) intermediates.append(x) complex_logits = self.final_conv(x) if not return_intermediates: return complex_logits return complex_logits, intermediates # non-complex stft discriminator def STFTResidualUnit(chan_in, chan_out, strides): kernel_sizes = tuple(map(lambda t: t + 2, strides)) paddings = tuple(map(lambda t: t // 2, kernel_sizes)) Loading Loading @@ -317,7 +432,8 @@ class SoundStream(nn.Module): mhesa_dim_head = 32, attn_window_size = 128, attn_dim_head = 64, attn_heads = 8 attn_heads = 8, use_complex_stft_discriminator = True ): super().__init__() self.target_sample_hz = target_sample_hz # for resampling on the fly Loading Loading @@ -391,7 +507,8 @@ class SoundStream(nn.Module): self.discr_multi_scales = discr_multi_scales self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))]) self.stft_discriminator = STFTDiscriminator() stft_klass = ComplexSTFTDiscriminator if use_complex_stft_discriminator else STFTDiscriminator self.stft_discriminator = stft_klass() # loss weights Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.7.2', version = '0.7.3', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading