Loading audiolm_pytorch/soundstream.py +31 −25 Original line number Diff line number Diff line Loading @@ -6,6 +6,9 @@ import torch from torch import nn, einsum from torch.autograd import grad as torch_grad import torch.nn.functional as F import torchaudio.transforms as T from einops import rearrange, reduce from vector_quantize_pytorch import ResidualVQ Loading Loading @@ -96,26 +99,14 @@ class MultiScaleDiscriminator(nn.Module): return out, intermediates class ModReLU(nn.Module): """ https://arxiv.org/abs/1705.09792 https://github.com/pytorch/pytorch/issues/47052#issuecomment-718948801 """ def __init__(self): super().__init__() self.b = nn.Parameter(torch.tensor(0.)) def forward(self, x): return F.relu(torch.abs(x) + self.b) * torch.exp(1.j * torch.angle(x)) def STFTResidualUnit(chan_in, chan_out, strides): kernel_sizes = tuple(map(lambda t: t + 2, strides)) paddings = tuple(map(lambda t: t // 2, kernel_sizes)) return nn.Sequential( nn.Conv2d(chan_in, chan_in, 3, padding = 1, dtype = torch.complex64), ModReLU(), nn.Conv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings, dtype = torch.complex64) nn.Conv2d(chan_in, chan_in, 3, padding = 1), leaky_relu(), nn.Conv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings) ) class STFTDiscriminator(nn.Module): Loading @@ -125,10 +116,26 @@ class STFTDiscriminator(nn.Module): channels = 32, strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)), chan_mults = (1, 2, 4, 4, 8, 8), input_channels = 1 input_channels = 1, n_fft = 1024, hop_length = 256, win_length = 1024, stft_normalized = True ): super().__init__() self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3, dtype = torch.complex64) self.stft = T.Spectrogram( n_fft = n_fft, hop_length = hop_length, win_length = win_length, window_fn = torch.hann_window, normalized = stft_normalized, center = False, pad_mode = None, power = None ) input_channels *= 2 self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3) layer_channels = tuple(map(lambda mult: mult * channels, chan_mults)) layer_channels = (channels, *layer_channels) Loading @@ -141,7 +148,7 @@ class STFTDiscriminator(nn.Module): for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs): self.layers.append(STFTResidualUnit(chan_in, chan_out, layer_stride)) self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1), dtype = torch.complex64) # todo: remove hardcoded 16 self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1)) # todo: remove hardcoded 16 def forward(self, x, return_intermediates = False): x = rearrange(x, 'b 1 n -> b n') Loading @@ -155,8 +162,9 @@ class STFTDiscriminator(nn.Module): H = 256 samples ''' x = torch.view_as_complex(torch.stft(x,1024, hop_length=256,win_length=1024, return_complex=False)) x = self.stft(x) x = rearrange(x, 'b ... -> b 1 ...') x = torch.cat((x.real, x.imag), dim = 1).detach() intermediates = [] Loading @@ -167,12 +175,12 @@ class STFTDiscriminator(nn.Module): x = layer(x) intermediates.append(x) complex_logits = self.final_conv(x) logits = self.final_conv(x) if not return_intermediates: return complex_logits return logits return complex_logits, intermediates return logits, intermediates # sound stream Loading Loading @@ -484,9 +492,7 @@ class SoundStream(nn.Module): # adversarial loss for stft discriminator adversarial_losses.append(hinge_gen_loss(stft_fake_logits.real)) adversarial_losses.append(hinge_gen_loss(stft_fake_logits.imag)) adversarial_losses.append(hinge_gen_loss(stft_fake_logits)) adversarial_loss = torch.stack(adversarial_losses).mean() # sum commitment loss Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.2.3', version = '0.3.0', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
audiolm_pytorch/soundstream.py +31 −25 Original line number Diff line number Diff line Loading @@ -6,6 +6,9 @@ import torch from torch import nn, einsum from torch.autograd import grad as torch_grad import torch.nn.functional as F import torchaudio.transforms as T from einops import rearrange, reduce from vector_quantize_pytorch import ResidualVQ Loading Loading @@ -96,26 +99,14 @@ class MultiScaleDiscriminator(nn.Module): return out, intermediates class ModReLU(nn.Module): """ https://arxiv.org/abs/1705.09792 https://github.com/pytorch/pytorch/issues/47052#issuecomment-718948801 """ def __init__(self): super().__init__() self.b = nn.Parameter(torch.tensor(0.)) def forward(self, x): return F.relu(torch.abs(x) + self.b) * torch.exp(1.j * torch.angle(x)) def STFTResidualUnit(chan_in, chan_out, strides): kernel_sizes = tuple(map(lambda t: t + 2, strides)) paddings = tuple(map(lambda t: t // 2, kernel_sizes)) return nn.Sequential( nn.Conv2d(chan_in, chan_in, 3, padding = 1, dtype = torch.complex64), ModReLU(), nn.Conv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings, dtype = torch.complex64) nn.Conv2d(chan_in, chan_in, 3, padding = 1), leaky_relu(), nn.Conv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings) ) class STFTDiscriminator(nn.Module): Loading @@ -125,10 +116,26 @@ class STFTDiscriminator(nn.Module): channels = 32, strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)), chan_mults = (1, 2, 4, 4, 8, 8), input_channels = 1 input_channels = 1, n_fft = 1024, hop_length = 256, win_length = 1024, stft_normalized = True ): super().__init__() self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3, dtype = torch.complex64) self.stft = T.Spectrogram( n_fft = n_fft, hop_length = hop_length, win_length = win_length, window_fn = torch.hann_window, normalized = stft_normalized, center = False, pad_mode = None, power = None ) input_channels *= 2 self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3) layer_channels = tuple(map(lambda mult: mult * channels, chan_mults)) layer_channels = (channels, *layer_channels) Loading @@ -141,7 +148,7 @@ class STFTDiscriminator(nn.Module): for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs): self.layers.append(STFTResidualUnit(chan_in, chan_out, layer_stride)) self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1), dtype = torch.complex64) # todo: remove hardcoded 16 self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1)) # todo: remove hardcoded 16 def forward(self, x, return_intermediates = False): x = rearrange(x, 'b 1 n -> b n') Loading @@ -155,8 +162,9 @@ class STFTDiscriminator(nn.Module): H = 256 samples ''' x = torch.view_as_complex(torch.stft(x,1024, hop_length=256,win_length=1024, return_complex=False)) x = self.stft(x) x = rearrange(x, 'b ... -> b 1 ...') x = torch.cat((x.real, x.imag), dim = 1).detach() intermediates = [] Loading @@ -167,12 +175,12 @@ class STFTDiscriminator(nn.Module): x = layer(x) intermediates.append(x) complex_logits = self.final_conv(x) logits = self.final_conv(x) if not return_intermediates: return complex_logits return logits return complex_logits, intermediates return logits, intermediates # sound stream Loading Loading @@ -484,9 +492,7 @@ class SoundStream(nn.Module): # adversarial loss for stft discriminator adversarial_losses.append(hinge_gen_loss(stft_fake_logits.real)) adversarial_losses.append(hinge_gen_loss(stft_fake_logits.imag)) adversarial_losses.append(hinge_gen_loss(stft_fake_logits)) adversarial_loss = torch.stack(adversarial_losses).mean() # sum commitment loss Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.2.3', version = '0.3.0', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading