Loading audiolm_pytorch/__init__.py +1 −1 Original line number Diff line number Diff line from audiolm_pytorch.audiolm_pytorch import AudioLM from audiolm_pytorch.audiolm_pytorch import SoundStream from audiolm_pytorch.soundstream import SoundStream from audiolm_pytorch.audiolm_pytorch import SemanticTransformer, CoarseTransformer, FineTransformer from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransformerWrapper Loading audiolm_pytorch/audiolm_pytorch.py +2 −422 Original line number Diff line number Diff line import math import functools from functools import partial from typing import Optional, Union Loading @@ -11,16 +10,15 @@ from torch.nn.utils.rnn import pad_sequence from einops import rearrange, repeat from vector_quantize_pytorch import ResidualVQ from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec from audiolm_pytorch.hubert_kmeans import HubertWithKmeans from audiolm_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME from audiolm_pytorch.utils import curtail_to_multiple from torchaudio.functional import resample from audiolm_pytorch.soundstream import SoundStream # helper functions def exists(val): Loading @@ -38,32 +36,6 @@ def remainder_needed_until_multiple(n, mult): def round_down_nearest_multiple(val, mult): return (val // mult) * mult # gan losses def hinge_discr_loss(fake, real): return (F.relu(1 + fake) + F.relu(1 - real)).mean() def hinge_gen_loss(fake): return -fake.mean() def leaky_relu(p = 0.1): return nn.LeakyReLU(p) def gradient_penalty(images, output, weight = 10): batch_size = images.shape[0] gradients = torch_grad( outputs = output, inputs = images, grad_outputs = torch.ones(output.size(), device = images.device), create_graph = True, retain_graph = True, only_inputs = True )[0] gradients = rearrange(gradients, 'b ... -> b (...)') return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean() # attention related utils def grad_shrink(t, alpha = 0.1): Loading @@ -71,9 +43,6 @@ def grad_shrink(t, alpha = 0.1): # classifier free guidance functions def uniform(shape, device): return torch.zeros(shape, device = device).float().uniform_(0, 1) def prob_mask_like(shape, prob, device): if prob == 1: return torch.ones(shape, device = device, dtype = torch.bool) Loading @@ -96,395 +65,6 @@ def batch_unique_consecutive(t, pad_value = 0.): unique_arr = [torch.unique_consecutive(el) for el in t.unbind(dim = 0)] return pad_sequence(unique_arr, batch_first = True, padding_value = pad_value) # discriminators class MultiScaleDiscriminator(nn.Module): def __init__( self, channels = 16, layers = 4, groups = 4, chan_max = 1024, input_channels = 1 ): super().__init__() self.init_conv = nn.Conv1d(input_channels, channels, 7) self.conv_layers = nn.ModuleList([]) curr_channels = channels for _ in range(layers): chan_out = min(curr_channels * 4, chan_max) self.conv_layers.append(nn.Sequential( nn.Conv1d(curr_channels, chan_out, 8, stride = 4, padding = 4, groups = groups), leaky_relu() )) curr_channels = chan_out self.final_conv = nn.Sequential( nn.Conv1d(curr_channels, curr_channels, 3), leaky_relu(), nn.Conv1d(curr_channels, 1, 1), ) def forward(self, x, return_intermediates = False): x = self.init_conv(x) intermediates = [] for layer in self.conv_layers: x = layer(x) intermediates.append(x) out = self.final_conv(x) if not return_intermediates: return out return out, intermediates class ComplexLeakyReLU(nn.Module): """ just do nonlinearity on imag and real component separately for now """ def __init__(self, p = 0.1): super().__init__() self.nonlin = leaky_relu(p) def forward(self, x): imag, real = map(self.nonlin, (x.imag, x.real)) return torch.view_as_complex(torch.stack((imag, real), dim = -1)) def STFTResidualUnit(chan_in, chan_out, strides): kernel_sizes = tuple(map(lambda t: t + 2, strides)) paddings = tuple(map(lambda t: t // 2, kernel_sizes)) return nn.Sequential( nn.Conv2d(chan_in, chan_in, 3, padding = 1, dtype = torch.complex64), ComplexLeakyReLU(), nn.Conv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings, dtype = torch.complex64) ) class STFTDiscriminator(nn.Module): def __init__( self, *, channels = 32, strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)), chan_mults = (1, 2, 4, 4, 8, 8), input_channels = 1 ): super().__init__() self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3, dtype = torch.complex64) layer_channels = tuple(map(lambda mult: mult * channels, chan_mults)) layer_channels = (channels, *layer_channels) layer_channels_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:])) curr_channels = channels self.layers = nn.ModuleList([]) for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs): self.layers.append(STFTResidualUnit(chan_in, chan_out, layer_stride)) self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1), dtype = torch.complex64) # todo: remove hardcoded 16 def forward(self, x, return_intermediates = False): x = rearrange(x, 'b 1 n -> b n') ''' reference: The content of the paper( https://arxiv.org/pdf/2107.03312.pdf)is as follows: The STFT-based discriminator is illustrated in Figure 4 and operates on a single scale, computing the STFT with a window length of W = 1024 samples and a hop length of H = 256 samples ''' x = torch.view_as_complex(torch.stft(x,1024, hop_length=256,win_length=1024)) x = rearrange(x, 'b ... -> b 1 ...') intermediates = [] x = self.init_conv(x) intermediates.append(x) for layer in self.layers: x = layer(x) intermediates.append(x) complex_logits = self.final_conv(x) if not return_intermediates: return complex_logits return complex_logits, intermediates # sound stream class Residual(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x, **kwargs): return self.fn(x, **kwargs) + x class CausalConv1d(nn.Module): def __init__(self, chan_in, chan_out, kernel_size, **kwargs): super().__init__() kernel_size = kernel_size dilation = kwargs.get('dilation', 1) self.causal_padding = dilation * (kernel_size - 1) self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, **kwargs) def forward(self, x): x = F.pad(x, (self.causal_padding, 0)) return self.conv(x) class CausalConvTranspose1d(nn.Module): def __init__(self, chan_in, chan_out, kernel_size, stride, **kwargs): super().__init__() self.upsample_factor = stride self.padding = kernel_size - 1 self.conv = nn.ConvTranspose1d(chan_in, chan_out, kernel_size, stride, **kwargs) def forward(self, x): n = x.shape[-1] out = self.conv(x) out = out[..., :(n * self.upsample_factor)] return out def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7): return Residual(nn.Sequential( CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation), nn.ELU(), CausalConv1d(chan_out, chan_out, 1), nn.ELU() )) def EncoderBlock(chan_in, chan_out, stride): return nn.Sequential( ResidualUnit(chan_in, chan_in, 1), ResidualUnit(chan_in, chan_in, 3), ResidualUnit(chan_in, chan_in, 9), CausalConv1d(chan_in, chan_out, 2 * stride, stride = stride) ) def DecoderBlock(chan_in, chan_out, stride): even_stride = (stride % 2 == 0) padding = (stride + (0 if even_stride else 1)) // 2 output_padding = 0 if even_stride else 1 return nn.Sequential( CausalConvTranspose1d(chan_in, chan_out, 2 * stride, stride = stride), ResidualUnit(chan_out, chan_out, 1), ResidualUnit(chan_out, chan_out, 3), ResidualUnit(chan_out, chan_out, 9), ) class SoundStream(nn.Module): def __init__( self, *, channels = 32, strides = (2, 4, 5, 8), channel_mults = (2, 4, 8, 16), codebook_dim = 512, codebook_size = 1024, rq_num_quantizers = 8, input_channels = 1, discr_multi_scales = (1, 0.5, 0.25), recon_loss_weight = 1., adversarial_loss_weight = 1., feature_loss_weight = 100, quantize_dropout = True, quantize_dropout_cutoff_index = 0, target_sample_hz = 24000 ): super().__init__() self.target_sample_hz = target_sample_hz # for resampling on the fly self.single_channel = input_channels == 1 self.strides = strides layer_channels = tuple(map(lambda t: t * channels, channel_mults)) layer_channels = (channels, *layer_channels) chan_in_out_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:])) encoder_blocks = [] for ((chan_in, chan_out), layer_stride) in zip(chan_in_out_pairs, strides): encoder_blocks.append(EncoderBlock(chan_in, chan_out, layer_stride)) self.encoder = nn.Sequential( CausalConv1d(input_channels, channels, 7), *encoder_blocks, CausalConv1d(layer_channels[-1], codebook_dim, 3) ) self.rq = ResidualVQ( dim = codebook_dim, num_quantizers = rq_num_quantizers, codebook_size = codebook_size, kmeans_init = True, threshold_ema_dead_code = 2, quantize_dropout = quantize_dropout, quantize_dropout_cutoff_index = quantize_dropout_cutoff_index ) decoder_blocks = [] for ((chan_in, chan_out), layer_stride) in zip(reversed(chan_in_out_pairs), reversed(strides)): decoder_blocks.append(DecoderBlock(chan_out, chan_in, layer_stride)) self.decoder = nn.Sequential( CausalConv1d(codebook_dim, layer_channels[-1], 7), *decoder_blocks, CausalConv1d(channels, input_channels, 7) ) # discriminators self.discr_multi_scales = discr_multi_scales self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))]) self.stft_discriminator = STFTDiscriminator() # loss weights self.recon_loss_weight = recon_loss_weight self.adversarial_loss_weight = adversarial_loss_weight self.feature_loss_weight = feature_loss_weight def non_discr_parameters(self): return [*self.encoder.parameters(), *self.decoder.parameters()] @property def seq_len_multiple_of(self): return functools.reduce(lambda x, y: x * y, self.strides) def forward( self, x, return_encoded = False, return_discr_loss = False, return_discr_losses_separately = False, return_recons_only = False, input_sample_hz = None ): if exists(input_sample_hz): x = resample(x, input_sample_hz, self.target_sample_hz) x = curtail_to_multiple(x, self.seq_len_multiple_of) if x.ndim == 2: x = rearrange(x, 'b n -> b 1 n') orig_x = x.clone() x = self.encoder(x) x = rearrange(x, 'b c n -> b n c') x, indices, commit_loss = self.rq(x) x = rearrange(x, 'b n c -> b c n') if return_encoded: return x, indices, commit_loss recon_x = self.decoder(x) if return_recons_only: return recon_x # multi-scale discriminator loss if return_discr_loss: real, fake = orig_x, recon_x.detach() stft_discr_loss = None discr_losses = [] if self.single_channel: real, fake = orig_x, recon_x.detach() stft_real_logits, stft_fake_logits = map(self.stft_discriminator, (real, fake)) stft_discr_loss = (hinge_discr_loss(stft_fake_logits.real, stft_real_logits.real) + hinge_discr_loss(stft_fake_logits.imag, stft_real_logits.imag)) / 2 for discr, scale in zip(self.discriminators, self.discr_multi_scales): scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake)) real_logits, fake_logits = map(discr, (scaled_real, scaled_fake)) one_discr_loss = hinge_discr_loss(fake_logits, real_logits) discr_losses.append(one_discr_loss) if not return_discr_losses_separately: all_discr_losses = torch.stack(discr_losses).mean() if exists(stft_discr_loss): all_discr_losses = all_discr_losses + stft_discr_loss return all_discr_losses # return a list of discriminator losses with List[Tuple[str, Tensor]] discr_losses_pkg = [] discr_losses_pkg.extend([(f'scale:{scale}', multi_scale_loss) for scale, multi_scale_loss in zip(self.discr_multi_scales, discr_losses)]) if exists(stft_discr_loss): discr_losses_pkg.append(('stft', stft_discr_loss)) return discr_losses_pkg # recon loss recon_loss = F.mse_loss(orig_x, recon_x) # adversarial loss adversarial_losses = [] discr_intermediates = [] # adversarial loss for multi-scale discriminators real, fake = orig_x, recon_x # features from stft (stft_real_logits, stft_real_intermediates), (stft_fake_logits, stft_fake_intermediates) = map(partial(self.stft_discriminator, return_intermediates=True), (real, fake)) discr_intermediates.append((stft_real_intermediates, stft_fake_intermediates)) for discr, scale in zip(self.discriminators, self.discr_multi_scales): scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake)) (real_logits, real_intermediates), (fake_logits, fake_intermediates) = map(partial(discr, return_intermediates = True), (scaled_real, scaled_fake)) discr_intermediates.append((real_intermediates, fake_intermediates)) one_adversarial_loss = hinge_gen_loss(fake_logits) adversarial_losses.append(one_adversarial_loss) feature_losses = [] for real_intermediates, fake_intermediates in discr_intermediates: losses = [F.l1_loss(real_intermediate, fake_intermediate) for real_intermediate, fake_intermediate in zip(real_intermediates, fake_intermediates)] feature_losses.extend(losses) feature_loss = torch.stack(feature_losses).mean() # adversarial loss for stft discriminator adversarial_losses.append(hinge_gen_loss(stft_fake_logits.real)) adversarial_losses.append(hinge_gen_loss(stft_fake_logits.imag)) adversarial_loss = torch.stack(adversarial_losses).mean() return recon_loss * self.recon_loss_weight + adversarial_loss * self.adversarial_loss_weight + feature_loss * self.feature_loss_weight # relative positional bias class RelativePositionBias(nn.Module): Loading audiolm_pytorch/soundstream.py 0 → 100644 +434 −0 File added.Preview size limit exceeded, changes collapsed. Show changes audiolm_pytorch/trainer.py +1 −1 Original line number Diff line number Diff line Loading @@ -20,7 +20,7 @@ from audiolm_pytorch.optimizer import get_optimizer from ema_pytorch import EMA from audiolm_pytorch.audiolm_pytorch import SoundStream from audiolm_pytorch.soundstream import SoundStream from audiolm_pytorch.data import SoundDataset, get_dataloader from accelerate import Accelerator Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.25', version = '0.0.26', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
audiolm_pytorch/__init__.py +1 −1 Original line number Diff line number Diff line from audiolm_pytorch.audiolm_pytorch import AudioLM from audiolm_pytorch.audiolm_pytorch import SoundStream from audiolm_pytorch.soundstream import SoundStream from audiolm_pytorch.audiolm_pytorch import SemanticTransformer, CoarseTransformer, FineTransformer from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper, CoarseTransformerWrapper Loading
audiolm_pytorch/audiolm_pytorch.py +2 −422 Original line number Diff line number Diff line import math import functools from functools import partial from typing import Optional, Union Loading @@ -11,16 +10,15 @@ from torch.nn.utils.rnn import pad_sequence from einops import rearrange, repeat from vector_quantize_pytorch import ResidualVQ from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec from audiolm_pytorch.hubert_kmeans import HubertWithKmeans from audiolm_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME from audiolm_pytorch.utils import curtail_to_multiple from torchaudio.functional import resample from audiolm_pytorch.soundstream import SoundStream # helper functions def exists(val): Loading @@ -38,32 +36,6 @@ def remainder_needed_until_multiple(n, mult): def round_down_nearest_multiple(val, mult): return (val // mult) * mult # gan losses def hinge_discr_loss(fake, real): return (F.relu(1 + fake) + F.relu(1 - real)).mean() def hinge_gen_loss(fake): return -fake.mean() def leaky_relu(p = 0.1): return nn.LeakyReLU(p) def gradient_penalty(images, output, weight = 10): batch_size = images.shape[0] gradients = torch_grad( outputs = output, inputs = images, grad_outputs = torch.ones(output.size(), device = images.device), create_graph = True, retain_graph = True, only_inputs = True )[0] gradients = rearrange(gradients, 'b ... -> b (...)') return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean() # attention related utils def grad_shrink(t, alpha = 0.1): Loading @@ -71,9 +43,6 @@ def grad_shrink(t, alpha = 0.1): # classifier free guidance functions def uniform(shape, device): return torch.zeros(shape, device = device).float().uniform_(0, 1) def prob_mask_like(shape, prob, device): if prob == 1: return torch.ones(shape, device = device, dtype = torch.bool) Loading @@ -96,395 +65,6 @@ def batch_unique_consecutive(t, pad_value = 0.): unique_arr = [torch.unique_consecutive(el) for el in t.unbind(dim = 0)] return pad_sequence(unique_arr, batch_first = True, padding_value = pad_value) # discriminators class MultiScaleDiscriminator(nn.Module): def __init__( self, channels = 16, layers = 4, groups = 4, chan_max = 1024, input_channels = 1 ): super().__init__() self.init_conv = nn.Conv1d(input_channels, channels, 7) self.conv_layers = nn.ModuleList([]) curr_channels = channels for _ in range(layers): chan_out = min(curr_channels * 4, chan_max) self.conv_layers.append(nn.Sequential( nn.Conv1d(curr_channels, chan_out, 8, stride = 4, padding = 4, groups = groups), leaky_relu() )) curr_channels = chan_out self.final_conv = nn.Sequential( nn.Conv1d(curr_channels, curr_channels, 3), leaky_relu(), nn.Conv1d(curr_channels, 1, 1), ) def forward(self, x, return_intermediates = False): x = self.init_conv(x) intermediates = [] for layer in self.conv_layers: x = layer(x) intermediates.append(x) out = self.final_conv(x) if not return_intermediates: return out return out, intermediates class ComplexLeakyReLU(nn.Module): """ just do nonlinearity on imag and real component separately for now """ def __init__(self, p = 0.1): super().__init__() self.nonlin = leaky_relu(p) def forward(self, x): imag, real = map(self.nonlin, (x.imag, x.real)) return torch.view_as_complex(torch.stack((imag, real), dim = -1)) def STFTResidualUnit(chan_in, chan_out, strides): kernel_sizes = tuple(map(lambda t: t + 2, strides)) paddings = tuple(map(lambda t: t // 2, kernel_sizes)) return nn.Sequential( nn.Conv2d(chan_in, chan_in, 3, padding = 1, dtype = torch.complex64), ComplexLeakyReLU(), nn.Conv2d(chan_in, chan_out, kernel_sizes, stride = strides, padding = paddings, dtype = torch.complex64) ) class STFTDiscriminator(nn.Module): def __init__( self, *, channels = 32, strides = ((1, 2), (2, 2), (1, 2), (2, 2), (1, 2), (2, 2)), chan_mults = (1, 2, 4, 4, 8, 8), input_channels = 1 ): super().__init__() self.init_conv = nn.Conv2d(input_channels, channels, 7, padding = 3, dtype = torch.complex64) layer_channels = tuple(map(lambda mult: mult * channels, chan_mults)) layer_channels = (channels, *layer_channels) layer_channels_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:])) curr_channels = channels self.layers = nn.ModuleList([]) for layer_stride, (chan_in, chan_out) in zip(strides, layer_channels_pairs): self.layers.append(STFTResidualUnit(chan_in, chan_out, layer_stride)) self.final_conv = nn.Conv2d(layer_channels[-1], 1, (16, 1), dtype = torch.complex64) # todo: remove hardcoded 16 def forward(self, x, return_intermediates = False): x = rearrange(x, 'b 1 n -> b n') ''' reference: The content of the paper( https://arxiv.org/pdf/2107.03312.pdf)is as follows: The STFT-based discriminator is illustrated in Figure 4 and operates on a single scale, computing the STFT with a window length of W = 1024 samples and a hop length of H = 256 samples ''' x = torch.view_as_complex(torch.stft(x,1024, hop_length=256,win_length=1024)) x = rearrange(x, 'b ... -> b 1 ...') intermediates = [] x = self.init_conv(x) intermediates.append(x) for layer in self.layers: x = layer(x) intermediates.append(x) complex_logits = self.final_conv(x) if not return_intermediates: return complex_logits return complex_logits, intermediates # sound stream class Residual(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x, **kwargs): return self.fn(x, **kwargs) + x class CausalConv1d(nn.Module): def __init__(self, chan_in, chan_out, kernel_size, **kwargs): super().__init__() kernel_size = kernel_size dilation = kwargs.get('dilation', 1) self.causal_padding = dilation * (kernel_size - 1) self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, **kwargs) def forward(self, x): x = F.pad(x, (self.causal_padding, 0)) return self.conv(x) class CausalConvTranspose1d(nn.Module): def __init__(self, chan_in, chan_out, kernel_size, stride, **kwargs): super().__init__() self.upsample_factor = stride self.padding = kernel_size - 1 self.conv = nn.ConvTranspose1d(chan_in, chan_out, kernel_size, stride, **kwargs) def forward(self, x): n = x.shape[-1] out = self.conv(x) out = out[..., :(n * self.upsample_factor)] return out def ResidualUnit(chan_in, chan_out, dilation, kernel_size = 7): return Residual(nn.Sequential( CausalConv1d(chan_in, chan_out, kernel_size, dilation = dilation), nn.ELU(), CausalConv1d(chan_out, chan_out, 1), nn.ELU() )) def EncoderBlock(chan_in, chan_out, stride): return nn.Sequential( ResidualUnit(chan_in, chan_in, 1), ResidualUnit(chan_in, chan_in, 3), ResidualUnit(chan_in, chan_in, 9), CausalConv1d(chan_in, chan_out, 2 * stride, stride = stride) ) def DecoderBlock(chan_in, chan_out, stride): even_stride = (stride % 2 == 0) padding = (stride + (0 if even_stride else 1)) // 2 output_padding = 0 if even_stride else 1 return nn.Sequential( CausalConvTranspose1d(chan_in, chan_out, 2 * stride, stride = stride), ResidualUnit(chan_out, chan_out, 1), ResidualUnit(chan_out, chan_out, 3), ResidualUnit(chan_out, chan_out, 9), ) class SoundStream(nn.Module): def __init__( self, *, channels = 32, strides = (2, 4, 5, 8), channel_mults = (2, 4, 8, 16), codebook_dim = 512, codebook_size = 1024, rq_num_quantizers = 8, input_channels = 1, discr_multi_scales = (1, 0.5, 0.25), recon_loss_weight = 1., adversarial_loss_weight = 1., feature_loss_weight = 100, quantize_dropout = True, quantize_dropout_cutoff_index = 0, target_sample_hz = 24000 ): super().__init__() self.target_sample_hz = target_sample_hz # for resampling on the fly self.single_channel = input_channels == 1 self.strides = strides layer_channels = tuple(map(lambda t: t * channels, channel_mults)) layer_channels = (channels, *layer_channels) chan_in_out_pairs = tuple(zip(layer_channels[:-1], layer_channels[1:])) encoder_blocks = [] for ((chan_in, chan_out), layer_stride) in zip(chan_in_out_pairs, strides): encoder_blocks.append(EncoderBlock(chan_in, chan_out, layer_stride)) self.encoder = nn.Sequential( CausalConv1d(input_channels, channels, 7), *encoder_blocks, CausalConv1d(layer_channels[-1], codebook_dim, 3) ) self.rq = ResidualVQ( dim = codebook_dim, num_quantizers = rq_num_quantizers, codebook_size = codebook_size, kmeans_init = True, threshold_ema_dead_code = 2, quantize_dropout = quantize_dropout, quantize_dropout_cutoff_index = quantize_dropout_cutoff_index ) decoder_blocks = [] for ((chan_in, chan_out), layer_stride) in zip(reversed(chan_in_out_pairs), reversed(strides)): decoder_blocks.append(DecoderBlock(chan_out, chan_in, layer_stride)) self.decoder = nn.Sequential( CausalConv1d(codebook_dim, layer_channels[-1], 7), *decoder_blocks, CausalConv1d(channels, input_channels, 7) ) # discriminators self.discr_multi_scales = discr_multi_scales self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))]) self.stft_discriminator = STFTDiscriminator() # loss weights self.recon_loss_weight = recon_loss_weight self.adversarial_loss_weight = adversarial_loss_weight self.feature_loss_weight = feature_loss_weight def non_discr_parameters(self): return [*self.encoder.parameters(), *self.decoder.parameters()] @property def seq_len_multiple_of(self): return functools.reduce(lambda x, y: x * y, self.strides) def forward( self, x, return_encoded = False, return_discr_loss = False, return_discr_losses_separately = False, return_recons_only = False, input_sample_hz = None ): if exists(input_sample_hz): x = resample(x, input_sample_hz, self.target_sample_hz) x = curtail_to_multiple(x, self.seq_len_multiple_of) if x.ndim == 2: x = rearrange(x, 'b n -> b 1 n') orig_x = x.clone() x = self.encoder(x) x = rearrange(x, 'b c n -> b n c') x, indices, commit_loss = self.rq(x) x = rearrange(x, 'b n c -> b c n') if return_encoded: return x, indices, commit_loss recon_x = self.decoder(x) if return_recons_only: return recon_x # multi-scale discriminator loss if return_discr_loss: real, fake = orig_x, recon_x.detach() stft_discr_loss = None discr_losses = [] if self.single_channel: real, fake = orig_x, recon_x.detach() stft_real_logits, stft_fake_logits = map(self.stft_discriminator, (real, fake)) stft_discr_loss = (hinge_discr_loss(stft_fake_logits.real, stft_real_logits.real) + hinge_discr_loss(stft_fake_logits.imag, stft_real_logits.imag)) / 2 for discr, scale in zip(self.discriminators, self.discr_multi_scales): scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake)) real_logits, fake_logits = map(discr, (scaled_real, scaled_fake)) one_discr_loss = hinge_discr_loss(fake_logits, real_logits) discr_losses.append(one_discr_loss) if not return_discr_losses_separately: all_discr_losses = torch.stack(discr_losses).mean() if exists(stft_discr_loss): all_discr_losses = all_discr_losses + stft_discr_loss return all_discr_losses # return a list of discriminator losses with List[Tuple[str, Tensor]] discr_losses_pkg = [] discr_losses_pkg.extend([(f'scale:{scale}', multi_scale_loss) for scale, multi_scale_loss in zip(self.discr_multi_scales, discr_losses)]) if exists(stft_discr_loss): discr_losses_pkg.append(('stft', stft_discr_loss)) return discr_losses_pkg # recon loss recon_loss = F.mse_loss(orig_x, recon_x) # adversarial loss adversarial_losses = [] discr_intermediates = [] # adversarial loss for multi-scale discriminators real, fake = orig_x, recon_x # features from stft (stft_real_logits, stft_real_intermediates), (stft_fake_logits, stft_fake_intermediates) = map(partial(self.stft_discriminator, return_intermediates=True), (real, fake)) discr_intermediates.append((stft_real_intermediates, stft_fake_intermediates)) for discr, scale in zip(self.discriminators, self.discr_multi_scales): scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake)) (real_logits, real_intermediates), (fake_logits, fake_intermediates) = map(partial(discr, return_intermediates = True), (scaled_real, scaled_fake)) discr_intermediates.append((real_intermediates, fake_intermediates)) one_adversarial_loss = hinge_gen_loss(fake_logits) adversarial_losses.append(one_adversarial_loss) feature_losses = [] for real_intermediates, fake_intermediates in discr_intermediates: losses = [F.l1_loss(real_intermediate, fake_intermediate) for real_intermediate, fake_intermediate in zip(real_intermediates, fake_intermediates)] feature_losses.extend(losses) feature_loss = torch.stack(feature_losses).mean() # adversarial loss for stft discriminator adversarial_losses.append(hinge_gen_loss(stft_fake_logits.real)) adversarial_losses.append(hinge_gen_loss(stft_fake_logits.imag)) adversarial_loss = torch.stack(adversarial_losses).mean() return recon_loss * self.recon_loss_weight + adversarial_loss * self.adversarial_loss_weight + feature_loss * self.feature_loss_weight # relative positional bias class RelativePositionBias(nn.Module): Loading
audiolm_pytorch/soundstream.py 0 → 100644 +434 −0 File added.Preview size limit exceeded, changes collapsed. Show changes
audiolm_pytorch/trainer.py +1 −1 Original line number Diff line number Diff line Loading @@ -20,7 +20,7 @@ from audiolm_pytorch.optimizer import get_optimizer from ema_pytorch import EMA from audiolm_pytorch.audiolm_pytorch import SoundStream from audiolm_pytorch.soundstream import SoundStream from audiolm_pytorch.data import SoundDataset, get_dataloader from accelerate import Accelerator Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.25', version = '0.0.26', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading