Loading README.md +1 −0 Original line number Diff line number Diff line Loading @@ -76,6 +76,7 @@ loss.backward() - [ ] test with speech synthesis for starters - [ ] abstract out conditioning + classifier free guidance into external module or potentially a package - [ ] add option to use flash attention - [ ] function for pretty printing all discriminator losses to log ## Citations Loading audiolm_pytorch/audiolm_pytorch.py +30 −12 Original line number Diff line number Diff line Loading @@ -350,6 +350,9 @@ class SoundStream(nn.Module): self.adversarial_loss_weight = adversarial_loss_weight self.feature_loss_weight = feature_loss_weight def non_discr_parameters(self): return [*self.encoder.parameters(), *self.decoder.parameters()] @property def seq_len_multiple_of(self): return functools.reduce(lambda x, y: x * y, self.strides) Loading @@ -359,8 +362,8 @@ class SoundStream(nn.Module): x, return_encoded = False, return_discr_loss = False, return_recons_only = False, return_stft_discr_loss = False return_discr_losses_separately = False, return_recons_only = False ): if x.ndim == 2: x = rearrange(x, 'b n -> b 1 n') Loading @@ -381,21 +384,19 @@ class SoundStream(nn.Module): if return_recons_only: return recon_x # stft discr loss if return_stft_discr_loss: assert self.single_channel real, fake = orig_x, recon_x.detach() stft_real_logits, stft_fake_logits = map(self.stft_discriminator, (real, fake)) stft_discr_loss = (hinge_discr_loss(stft_fake_logits.real, stft_real_logits.real) + hinge_discr_loss(stft_fake_logits.imag, stft_real_logits.imag)) / 2 return stft_discr_loss # multi-scale discriminator loss if return_discr_loss: real, fake = orig_x, recon_x.detach() stft_discr_loss = None discr_losses = [] if self.single_channel: real, fake = orig_x, recon_x.detach() stft_real_logits, stft_fake_logits = map(self.stft_discriminator, (real, fake)) stft_discr_loss = (hinge_discr_loss(stft_fake_logits.real, stft_real_logits.real) + hinge_discr_loss(stft_fake_logits.imag, stft_real_logits.imag)) / 2 for discr, scale in zip(self.discriminators, self.discr_multi_scales): scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake)) Loading @@ -403,7 +404,24 @@ class SoundStream(nn.Module): one_discr_loss = hinge_discr_loss(fake_logits, real_logits) discr_losses.append(one_discr_loss) return torch.stack(discr_losses).mean() if not return_discr_losses_separately: all_discr_losses = torch.stack(discr_losses).mean() if exists(stft_discr_loss): all_discr_losses = all_discr_losses + stft_discr_loss return all_discr_losses # return a list of discriminator losses with List[Tuple[str, Tensor]] discr_losses_pkg = [] discr_losses_pkg.extend([(f'scale:{scale}', multi_scale_loss) for scale, multi_scale_loss in zip(self.discr_multi_scales, discr_losses)]) if exists(stft_discr_loss): discr_losses_pkg.append(('stft', stft_discr_loss)) return discr_losses_pkg # recon loss Loading audiolm_pytorch/data.py +5 −0 Original line number Diff line number Diff line Loading @@ -18,6 +18,7 @@ class SoundDataset(Dataset): self, folder, exts = ['flac', 'wav'], max_length = None, seq_len_multiple_of = None ): super().__init__() Loading @@ -28,6 +29,7 @@ class SoundDataset(Dataset): assert len(files) > 0, 'no sound files found' self.files = files self.max_length = max_length self.seq_len_multiple_of = seq_len_multiple_of def __len__(self): Loading @@ -39,6 +41,9 @@ class SoundDataset(Dataset): data = rearrange(data, '1 ... -> ...') if exists(self.max_length): data = data[:self.max_length] if exists(self.seq_len_multiple_of): mult = self.seq_len_multiple_of data_len = len(data) Loading audiolm_pytorch/trainer.py +34 −19 Original line number Diff line number Diff line Loading @@ -64,6 +64,7 @@ class SoundStreamTrainer(nn.Module): *, num_train_steps, batch_size, data_max_length = None, folder, lr = 3e-4, grad_accum_every = 1, Loading Loading @@ -93,14 +94,17 @@ class SoundStreamTrainer(nn.Module): self.batch_size = batch_size self.grad_accum_every = grad_accum_every all_parameters = set(soundstream.parameters()) discr_parameters = set(soundstream.stft_discriminator.parameters()) soundstream_parameters = all_parameters - discr_parameters # optimizers self.soundstream_parameters = soundstream_parameters self.optim = get_optimizer(soundstream.non_discr_parameters(), lr = lr, wd = wd) self.optim = get_optimizer(soundstream_parameters, lr = lr, wd = wd) self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd) for ind, discr in enumerate(soundstream.discriminators): one_multiscale_discr_optimizer = get_optimizer(discr.parameters(), lr = lr, wd = wd) setattr(self, f'multiscale_discr_optimizer_{ind}', one_multiscale_discr_optimizer) self.discr_optim = get_optimizer(soundstream.stft_discriminator.parameters(), lr = lr, wd = wd) # max grad norm self.max_grad_norm = max_grad_norm self.discr_max_grad_norm = discr_max_grad_norm Loading @@ -109,6 +113,7 @@ class SoundStreamTrainer(nn.Module): self.ds = SoundDataset( folder, max_length = data_max_length, seq_len_multiple_of = soundstream.seq_len_multiple_of ) Loading Loading @@ -211,26 +216,36 @@ class SoundStreamTrainer(nn.Module): # update discriminator if exists(self.soundstream.stft_discriminator): for _ in range(self.grad_accum_every): wave = next(self.dl_iter) wave = wave.to(device) loss = self.soundstream(wave, return_discr_loss = True) self.accelerator.backward(loss / self.grad_accum_every) discr_losses = self.soundstream( wave, return_discr_loss = True, return_discr_losses_separately = True ) accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every}) for name, discr_loss in discr_losses: self.accelerator.backward(discr_loss / self.grad_accum_every) accum_log(logs, {name: discr_loss.item() / self.grad_accum_every}) if exists(self.discr_max_grad_norm): self.accelerator.clip_grad_norm_(self.soundstream.stft_discriminator.parameters(), self.discr_max_grad_norm) # gradient step for all discriminators self.discr_optim.step() self.discr_optim.zero_grad() for ind in range(len(self.soundstream.discriminators)): discr_optimizer = getattr(self, f'multiscale_discr_optimizer_{ind}') discr_optimizer.step() discr_optimizer.zero_grad() # log self.print(f"{steps}: soundstream loss: {logs['loss']} - discr loss: {logs['discr_loss']}") self.print(f"{steps}: soundstream loss: {logs['loss']}") # update exponential moving averaged generator Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.17', version = '0.0.18', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
README.md +1 −0 Original line number Diff line number Diff line Loading @@ -76,6 +76,7 @@ loss.backward() - [ ] test with speech synthesis for starters - [ ] abstract out conditioning + classifier free guidance into external module or potentially a package - [ ] add option to use flash attention - [ ] function for pretty printing all discriminator losses to log ## Citations Loading
audiolm_pytorch/audiolm_pytorch.py +30 −12 Original line number Diff line number Diff line Loading @@ -350,6 +350,9 @@ class SoundStream(nn.Module): self.adversarial_loss_weight = adversarial_loss_weight self.feature_loss_weight = feature_loss_weight def non_discr_parameters(self): return [*self.encoder.parameters(), *self.decoder.parameters()] @property def seq_len_multiple_of(self): return functools.reduce(lambda x, y: x * y, self.strides) Loading @@ -359,8 +362,8 @@ class SoundStream(nn.Module): x, return_encoded = False, return_discr_loss = False, return_recons_only = False, return_stft_discr_loss = False return_discr_losses_separately = False, return_recons_only = False ): if x.ndim == 2: x = rearrange(x, 'b n -> b 1 n') Loading @@ -381,21 +384,19 @@ class SoundStream(nn.Module): if return_recons_only: return recon_x # stft discr loss if return_stft_discr_loss: assert self.single_channel real, fake = orig_x, recon_x.detach() stft_real_logits, stft_fake_logits = map(self.stft_discriminator, (real, fake)) stft_discr_loss = (hinge_discr_loss(stft_fake_logits.real, stft_real_logits.real) + hinge_discr_loss(stft_fake_logits.imag, stft_real_logits.imag)) / 2 return stft_discr_loss # multi-scale discriminator loss if return_discr_loss: real, fake = orig_x, recon_x.detach() stft_discr_loss = None discr_losses = [] if self.single_channel: real, fake = orig_x, recon_x.detach() stft_real_logits, stft_fake_logits = map(self.stft_discriminator, (real, fake)) stft_discr_loss = (hinge_discr_loss(stft_fake_logits.real, stft_real_logits.real) + hinge_discr_loss(stft_fake_logits.imag, stft_real_logits.imag)) / 2 for discr, scale in zip(self.discriminators, self.discr_multi_scales): scaled_real, scaled_fake = map(lambda t: F.interpolate(t, scale_factor = scale), (real, fake)) Loading @@ -403,7 +404,24 @@ class SoundStream(nn.Module): one_discr_loss = hinge_discr_loss(fake_logits, real_logits) discr_losses.append(one_discr_loss) return torch.stack(discr_losses).mean() if not return_discr_losses_separately: all_discr_losses = torch.stack(discr_losses).mean() if exists(stft_discr_loss): all_discr_losses = all_discr_losses + stft_discr_loss return all_discr_losses # return a list of discriminator losses with List[Tuple[str, Tensor]] discr_losses_pkg = [] discr_losses_pkg.extend([(f'scale:{scale}', multi_scale_loss) for scale, multi_scale_loss in zip(self.discr_multi_scales, discr_losses)]) if exists(stft_discr_loss): discr_losses_pkg.append(('stft', stft_discr_loss)) return discr_losses_pkg # recon loss Loading
audiolm_pytorch/data.py +5 −0 Original line number Diff line number Diff line Loading @@ -18,6 +18,7 @@ class SoundDataset(Dataset): self, folder, exts = ['flac', 'wav'], max_length = None, seq_len_multiple_of = None ): super().__init__() Loading @@ -28,6 +29,7 @@ class SoundDataset(Dataset): assert len(files) > 0, 'no sound files found' self.files = files self.max_length = max_length self.seq_len_multiple_of = seq_len_multiple_of def __len__(self): Loading @@ -39,6 +41,9 @@ class SoundDataset(Dataset): data = rearrange(data, '1 ... -> ...') if exists(self.max_length): data = data[:self.max_length] if exists(self.seq_len_multiple_of): mult = self.seq_len_multiple_of data_len = len(data) Loading
audiolm_pytorch/trainer.py +34 −19 Original line number Diff line number Diff line Loading @@ -64,6 +64,7 @@ class SoundStreamTrainer(nn.Module): *, num_train_steps, batch_size, data_max_length = None, folder, lr = 3e-4, grad_accum_every = 1, Loading Loading @@ -93,14 +94,17 @@ class SoundStreamTrainer(nn.Module): self.batch_size = batch_size self.grad_accum_every = grad_accum_every all_parameters = set(soundstream.parameters()) discr_parameters = set(soundstream.stft_discriminator.parameters()) soundstream_parameters = all_parameters - discr_parameters # optimizers self.soundstream_parameters = soundstream_parameters self.optim = get_optimizer(soundstream.non_discr_parameters(), lr = lr, wd = wd) self.optim = get_optimizer(soundstream_parameters, lr = lr, wd = wd) self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd) for ind, discr in enumerate(soundstream.discriminators): one_multiscale_discr_optimizer = get_optimizer(discr.parameters(), lr = lr, wd = wd) setattr(self, f'multiscale_discr_optimizer_{ind}', one_multiscale_discr_optimizer) self.discr_optim = get_optimizer(soundstream.stft_discriminator.parameters(), lr = lr, wd = wd) # max grad norm self.max_grad_norm = max_grad_norm self.discr_max_grad_norm = discr_max_grad_norm Loading @@ -109,6 +113,7 @@ class SoundStreamTrainer(nn.Module): self.ds = SoundDataset( folder, max_length = data_max_length, seq_len_multiple_of = soundstream.seq_len_multiple_of ) Loading Loading @@ -211,26 +216,36 @@ class SoundStreamTrainer(nn.Module): # update discriminator if exists(self.soundstream.stft_discriminator): for _ in range(self.grad_accum_every): wave = next(self.dl_iter) wave = wave.to(device) loss = self.soundstream(wave, return_discr_loss = True) self.accelerator.backward(loss / self.grad_accum_every) discr_losses = self.soundstream( wave, return_discr_loss = True, return_discr_losses_separately = True ) accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every}) for name, discr_loss in discr_losses: self.accelerator.backward(discr_loss / self.grad_accum_every) accum_log(logs, {name: discr_loss.item() / self.grad_accum_every}) if exists(self.discr_max_grad_norm): self.accelerator.clip_grad_norm_(self.soundstream.stft_discriminator.parameters(), self.discr_max_grad_norm) # gradient step for all discriminators self.discr_optim.step() self.discr_optim.zero_grad() for ind in range(len(self.soundstream.discriminators)): discr_optimizer = getattr(self, f'multiscale_discr_optimizer_{ind}') discr_optimizer.step() discr_optimizer.zero_grad() # log self.print(f"{steps}: soundstream loss: {logs['loss']} - discr loss: {logs['discr_loss']}") self.print(f"{steps}: soundstream loss: {logs['loss']}") # update exponential moving averaged generator Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.0.17', version = '0.0.18', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading