Loading audiolm_pytorch/data.py +2 −2 Original line number Diff line number Diff line Loading @@ -93,9 +93,9 @@ def curtail_to_shortest_collate(data): return torch.stack(data) @collate_one_or_multiple_tensors def pad_to_longest(data): def pad_to_longest_fn(data): return pad_sequence(data, batch_first = True) def get_dataloader(ds, pad_to_longest = True, **kwargs): collate_fn = pad_to_longest if pad_to_longest else curtail_to_shortest_collate collate_fn = pad_to_longest_fn if pad_to_longest else curtail_to_shortest_collate return DataLoader(ds, collate_fn = collate_fn, **kwargs) audiolm_pytorch/soundstream.py +25 −8 Original line number Diff line number Diff line Loading @@ -4,6 +4,7 @@ from functools import partial import torch from torch import nn, einsum from torch.autograd import grad as torch_grad import torch.nn.functional as F from einops import rearrange, reduce Loading @@ -30,13 +31,13 @@ def hinge_gen_loss(fake): def leaky_relu(p = 0.1): return nn.LeakyReLU(p) def gradient_penalty(images, output, weight = 10): batch_size = images.shape[0] def gradient_penalty(wave, output, weight = 10): batch_size, device = wave.shape[0], wave.device gradients = torch_grad( outputs = output, inputs = images, grad_outputs = torch.ones(output.size(), device = images.device), inputs = wave, grad_outputs = torch.ones_like(output), create_graph = True, retain_graph = True, only_inputs = True Loading Loading @@ -334,7 +335,8 @@ class SoundStream(nn.Module): return_discr_loss = False, return_discr_losses_separately = False, return_recons_only = False, input_sample_hz = None input_sample_hz = None, apply_grad_penalty = False ): if exists(input_sample_hz): x = resample(x, input_sample_hz, self.target_sample_hz) Loading Loading @@ -366,19 +368,26 @@ class SoundStream(nn.Module): real, fake = orig_x, recon_x.detach() stft_discr_loss = None stft_grad_penalty = None discr_losses = [] discr_grad_penalties = [] if self.single_channel: real, fake = orig_x, recon_x.detach() stft_real_logits, stft_fake_logits = map(self.stft_discriminator, (real, fake)) real, fake = orig_x.clone(), recon_x.detach() stft_real_logits, stft_fake_logits = map(self.stft_discriminator, (real.requires_grad_(), fake)) stft_discr_loss = (hinge_discr_loss(stft_fake_logits.real, stft_real_logits.real) + hinge_discr_loss(stft_fake_logits.imag, stft_real_logits.imag)) / 2 if apply_grad_penalty: stft_grad_penalty = gradient_penalty(real, stft_discr_loss) for discr, scale in zip(self.discriminators, self.discr_multi_scales): scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake)) real_logits, fake_logits = map(discr, (scaled_real, scaled_fake)) real_logits, fake_logits = map(discr, (scaled_real.requires_grad_(), scaled_fake)) one_discr_loss = hinge_discr_loss(fake_logits, real_logits) discr_losses.append(one_discr_loss) discr_grad_penalties.append(gradient_penalty(scaled_real, one_discr_loss)) if not return_discr_losses_separately: all_discr_losses = torch.stack(discr_losses).mean() Loading @@ -386,6 +395,9 @@ class SoundStream(nn.Module): if exists(stft_discr_loss): all_discr_losses = all_discr_losses + stft_discr_loss if exists(stft_grad_penalty): all_discr_losses = all_discr_losses + stft_grad_penalty return all_discr_losses # return a list of discriminator losses with List[Tuple[str, Tensor]] Loading @@ -394,9 +406,14 @@ class SoundStream(nn.Module): discr_losses_pkg.extend([(f'scale:{scale}', multi_scale_loss) for scale, multi_scale_loss in zip(self.discr_multi_scales, discr_losses)]) discr_losses_pkg.extend([(f'scale_grad_penalty:{scale}', discr_grad_penalty) for scale, discr_grad_penalty in zip(self.discr_multi_scales, discr_grad_penalties)]) if exists(stft_discr_loss): discr_losses_pkg.append(('stft', stft_discr_loss)) if exists(stft_grad_penalty): discr_losses_pkg.append(('stft_grad_penalty', stft_grad_penalty)) return discr_losses_pkg # recon loss Loading audiolm_pytorch/trainer.py +2 −1 Original line number Diff line number Diff line Loading @@ -236,12 +236,13 @@ class SoundStreamTrainer(nn.Module): discr_losses = self.soundstream( wave, apply_grad_penalty = apply_grad_penalty, return_discr_loss = True, return_discr_losses_separately = True ) for name, discr_loss in discr_losses: self.accelerator.backward(discr_loss / self.grad_accum_every) self.accelerator.backward(discr_loss / self.grad_accum_every, retain_graph = True) accum_log(logs, {name: discr_loss.item() / self.grad_accum_every}) if exists(self.discr_max_grad_norm): Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.49', version = '0.0.50', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
audiolm_pytorch/data.py +2 −2 Original line number Diff line number Diff line Loading @@ -93,9 +93,9 @@ def curtail_to_shortest_collate(data): return torch.stack(data) @collate_one_or_multiple_tensors def pad_to_longest(data): def pad_to_longest_fn(data): return pad_sequence(data, batch_first = True) def get_dataloader(ds, pad_to_longest = True, **kwargs): collate_fn = pad_to_longest if pad_to_longest else curtail_to_shortest_collate collate_fn = pad_to_longest_fn if pad_to_longest else curtail_to_shortest_collate return DataLoader(ds, collate_fn = collate_fn, **kwargs)
audiolm_pytorch/soundstream.py +25 −8 Original line number Diff line number Diff line Loading @@ -4,6 +4,7 @@ from functools import partial import torch from torch import nn, einsum from torch.autograd import grad as torch_grad import torch.nn.functional as F from einops import rearrange, reduce Loading @@ -30,13 +31,13 @@ def hinge_gen_loss(fake): def leaky_relu(p = 0.1): return nn.LeakyReLU(p) def gradient_penalty(images, output, weight = 10): batch_size = images.shape[0] def gradient_penalty(wave, output, weight = 10): batch_size, device = wave.shape[0], wave.device gradients = torch_grad( outputs = output, inputs = images, grad_outputs = torch.ones(output.size(), device = images.device), inputs = wave, grad_outputs = torch.ones_like(output), create_graph = True, retain_graph = True, only_inputs = True Loading Loading @@ -334,7 +335,8 @@ class SoundStream(nn.Module): return_discr_loss = False, return_discr_losses_separately = False, return_recons_only = False, input_sample_hz = None input_sample_hz = None, apply_grad_penalty = False ): if exists(input_sample_hz): x = resample(x, input_sample_hz, self.target_sample_hz) Loading Loading @@ -366,19 +368,26 @@ class SoundStream(nn.Module): real, fake = orig_x, recon_x.detach() stft_discr_loss = None stft_grad_penalty = None discr_losses = [] discr_grad_penalties = [] if self.single_channel: real, fake = orig_x, recon_x.detach() stft_real_logits, stft_fake_logits = map(self.stft_discriminator, (real, fake)) real, fake = orig_x.clone(), recon_x.detach() stft_real_logits, stft_fake_logits = map(self.stft_discriminator, (real.requires_grad_(), fake)) stft_discr_loss = (hinge_discr_loss(stft_fake_logits.real, stft_real_logits.real) + hinge_discr_loss(stft_fake_logits.imag, stft_real_logits.imag)) / 2 if apply_grad_penalty: stft_grad_penalty = gradient_penalty(real, stft_discr_loss) for discr, scale in zip(self.discriminators, self.discr_multi_scales): scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake)) real_logits, fake_logits = map(discr, (scaled_real, scaled_fake)) real_logits, fake_logits = map(discr, (scaled_real.requires_grad_(), scaled_fake)) one_discr_loss = hinge_discr_loss(fake_logits, real_logits) discr_losses.append(one_discr_loss) discr_grad_penalties.append(gradient_penalty(scaled_real, one_discr_loss)) if not return_discr_losses_separately: all_discr_losses = torch.stack(discr_losses).mean() Loading @@ -386,6 +395,9 @@ class SoundStream(nn.Module): if exists(stft_discr_loss): all_discr_losses = all_discr_losses + stft_discr_loss if exists(stft_grad_penalty): all_discr_losses = all_discr_losses + stft_grad_penalty return all_discr_losses # return a list of discriminator losses with List[Tuple[str, Tensor]] Loading @@ -394,9 +406,14 @@ class SoundStream(nn.Module): discr_losses_pkg.extend([(f'scale:{scale}', multi_scale_loss) for scale, multi_scale_loss in zip(self.discr_multi_scales, discr_losses)]) discr_losses_pkg.extend([(f'scale_grad_penalty:{scale}', discr_grad_penalty) for scale, discr_grad_penalty in zip(self.discr_multi_scales, discr_grad_penalties)]) if exists(stft_discr_loss): discr_losses_pkg.append(('stft', stft_discr_loss)) if exists(stft_grad_penalty): discr_losses_pkg.append(('stft_grad_penalty', stft_grad_penalty)) return discr_losses_pkg # recon loss Loading
audiolm_pytorch/trainer.py +2 −1 Original line number Diff line number Diff line Loading @@ -236,12 +236,13 @@ class SoundStreamTrainer(nn.Module): discr_losses = self.soundstream( wave, apply_grad_penalty = apply_grad_penalty, return_discr_loss = True, return_discr_losses_separately = True ) for name, discr_loss in discr_losses: self.accelerator.backward(discr_loss / self.grad_accum_every) self.accelerator.backward(discr_loss / self.grad_accum_every, retain_graph = True) accum_log(logs, {name: discr_loss.item() / self.grad_accum_every}) if exists(self.discr_max_grad_norm): Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.49', version = '0.0.50', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading