Loading audiolm_pytorch/soundstream.py +2 −1 Original line number Diff line number Diff line Loading @@ -588,6 +588,7 @@ class SoundStream(nn.Module): one_discr_loss = hinge_discr_loss(fake_logits, real_logits) discr_losses.append(one_discr_loss) if apply_grad_penalty: discr_grad_penalties.append(gradient_penalty(scaled_real, one_discr_loss)) if not return_discr_losses_separately: Loading audiolm_pytorch/trainer.py +47 −16 Original line number Diff line number Diff line Loading @@ -125,9 +125,11 @@ class SoundStreamTrainer(nn.Module): discr_max_grad_norm = None, save_results_every = 100, save_model_every = 1000, log_losses_every = 1, results_folder = './results', valid_frac = 0.05, random_split_seed = 42, use_ema = True, ema_beta = 0.995, ema_update_after_step = 500, ema_update_every = 10, Loading @@ -143,6 +145,9 @@ class SoundStreamTrainer(nn.Module): self.accelerator = Accelerator(kwargs_handlers = [kwargs], **accelerate_kwargs) self.soundstream = soundstream self.use_ema = use_ema if self.use_ema: self.ema_soundstream = EMA(soundstream, beta = ema_beta, update_after_step = ema_update_after_step, update_every = ema_update_every) self.register_buffer('steps', torch.Tensor([0])) Loading Loading @@ -220,6 +225,7 @@ class SoundStreamTrainer(nn.Module): self.save_model_every = save_model_every self.save_results_every = save_results_every self.log_losses_every = log_losses_every self.apply_grad_penalty_every = apply_grad_penalty_every Loading @@ -236,11 +242,13 @@ class SoundStreamTrainer(nn.Module): def save(self, path): pkg = dict( model = self.accelerator.get_state_dict(self.soundstream), ema_model = self.ema_soundstream.state_dict(), optim = self.optim.state_dict(), discr_optim = self.discr_optim.state_dict() ) if self.use_ema: pkg['ema_model'] = self.ema_soundstream.state_dict() for key, _ in self.multiscale_discriminator_iter(): discr_optim = getattr(self, key) pkg[key] = discr_optim.state_dict() Loading @@ -258,7 +266,10 @@ class SoundStreamTrainer(nn.Module): self.unwrapped_soundstream.load_state_dict(pkg['model']) if self.use_ema: assert 'ema_model' in pkg self.ema_soundstream.load_state_dict(pkg['ema_model']) self.optim.load_state_dict(pkg['optim']) self.discr_optim.load_state_dict(pkg['discr_optim']) Loading Loading @@ -297,7 +308,8 @@ class SoundStreamTrainer(nn.Module): device = self.device steps = int(self.steps.item()) apply_grad_penalty = not (steps % self.apply_grad_penalty_every) apply_grad_penalty = self.apply_grad_penalty_every > 0 and not (steps % self.apply_grad_penalty_every) log_losses = self.log_losses_every > 0 and not (steps % self.log_losses_every) self.soundstream.train() Loading @@ -311,14 +323,21 @@ class SoundStreamTrainer(nn.Module): wave, = next(self.dl_iter) wave = wave.to(device) loss, (recon_loss, multi_spectral_recon_loss, *_) = self.soundstream(wave, return_loss_breakdown = True) loss, (recon_loss, multi_spectral_recon_loss, adversarial_loss, feature_loss, all_commitment_loss) = self.soundstream(wave, return_loss_breakdown = True) self.accelerator.backward(loss / self.grad_accum_every) accum_log(logs, dict( loss = loss.item() / self.grad_accum_every, recon_loss = recon_loss.item() / self.grad_accum_every, multi_spectral_recon_loss = multi_spectral_recon_loss.item() / self.grad_accum_every )) if log_losses: accum_log(logs, dict( multi_spectral_recon_loss = multi_spectral_recon_loss.item() / self.grad_accum_every, adversarial_loss = adversarial_loss.item() / self.grad_accum_every, feature_loss = feature_loss.item() / self.grad_accum_every, all_commitment_loss = all_commitment_loss.item() / self.grad_accum_every, )) if exists(self.max_grad_norm): Loading Loading @@ -362,7 +381,16 @@ class SoundStreamTrainer(nn.Module): # build pretty printed losses losses_str = f"{steps}: soundstream total loss: {logs['loss']:.3f}, soundstream recon loss: {logs['recon_loss']:.3f}" self.accelerator.log({"total_loss": logs['loss'], "recon_loss": logs['recon_loss']}, step=steps) if log_losses: self.accelerator.log({ "total_loss": logs['loss'], "recon_loss": logs['recon_loss'], "multi_spectral_recon_loss": logs['multi_spectral_recon_loss'], "adversarial_loss": logs['adversarial_loss'], "feature_loss": logs['feature_loss'], "all_commitment_loss": logs['all_commitment_loss'], "stft_discr_loss": logs['stft'] }, step=steps) for key, loss in logs.items(): if not key.startswith('scale:'): Loading @@ -370,6 +398,7 @@ class SoundStreamTrainer(nn.Module): _, scale_factor = key.split(':') losses_str += f" | discr (scale {scale_factor}) loss: {loss:.3f}" if log_losses: self.accelerator.log({f"discr_loss (scale {scale_factor})": loss}, step=steps) # log Loading @@ -380,7 +409,7 @@ class SoundStreamTrainer(nn.Module): self.accelerator.wait_for_everyone() if self.is_main: if self.is_main and self.use_ema: self.ema_soundstream.update() # sample results every so often Loading @@ -388,19 +417,21 @@ class SoundStreamTrainer(nn.Module): self.accelerator.wait_for_everyone() if self.is_main and not (steps % self.save_results_every): for model, filename in ((self.ema_soundstream.ema_model, f'{steps}.ema'), (self.unwrapped_soundstream, str(steps))): model.eval() models = [(self.unwrapped_soundstream, str(steps))] if self.use_ema: models.append((self.ema_soundstream.ema_model if self.use_ema else self.unwrapped_soundstream, f'{steps}.ema')) wave, = next(self.valid_dl_iter) wave = wave.to(device) for model, label in models: model.eval() with torch.no_grad(): recons = model(wave, return_recons_only = True) milestone = steps // self.save_results_every for ind, recon in enumerate(recons.unbind(dim = 0)): filename = str(self.results_folder / f'sample_{steps}.flac') filename = str(self.results_folder / f'sample_{label}.flac') torchaudio.save(filename, recon.cpu().detach(), self.unwrapped_soundstream.target_sample_hz) self.print(f'{steps}: saving to {str(self.results_folder)}') Loading Loading
audiolm_pytorch/soundstream.py +2 −1 Original line number Diff line number Diff line Loading @@ -588,6 +588,7 @@ class SoundStream(nn.Module): one_discr_loss = hinge_discr_loss(fake_logits, real_logits) discr_losses.append(one_discr_loss) if apply_grad_penalty: discr_grad_penalties.append(gradient_penalty(scaled_real, one_discr_loss)) if not return_discr_losses_separately: Loading
audiolm_pytorch/trainer.py +47 −16 Original line number Diff line number Diff line Loading @@ -125,9 +125,11 @@ class SoundStreamTrainer(nn.Module): discr_max_grad_norm = None, save_results_every = 100, save_model_every = 1000, log_losses_every = 1, results_folder = './results', valid_frac = 0.05, random_split_seed = 42, use_ema = True, ema_beta = 0.995, ema_update_after_step = 500, ema_update_every = 10, Loading @@ -143,6 +145,9 @@ class SoundStreamTrainer(nn.Module): self.accelerator = Accelerator(kwargs_handlers = [kwargs], **accelerate_kwargs) self.soundstream = soundstream self.use_ema = use_ema if self.use_ema: self.ema_soundstream = EMA(soundstream, beta = ema_beta, update_after_step = ema_update_after_step, update_every = ema_update_every) self.register_buffer('steps', torch.Tensor([0])) Loading Loading @@ -220,6 +225,7 @@ class SoundStreamTrainer(nn.Module): self.save_model_every = save_model_every self.save_results_every = save_results_every self.log_losses_every = log_losses_every self.apply_grad_penalty_every = apply_grad_penalty_every Loading @@ -236,11 +242,13 @@ class SoundStreamTrainer(nn.Module): def save(self, path): pkg = dict( model = self.accelerator.get_state_dict(self.soundstream), ema_model = self.ema_soundstream.state_dict(), optim = self.optim.state_dict(), discr_optim = self.discr_optim.state_dict() ) if self.use_ema: pkg['ema_model'] = self.ema_soundstream.state_dict() for key, _ in self.multiscale_discriminator_iter(): discr_optim = getattr(self, key) pkg[key] = discr_optim.state_dict() Loading @@ -258,7 +266,10 @@ class SoundStreamTrainer(nn.Module): self.unwrapped_soundstream.load_state_dict(pkg['model']) if self.use_ema: assert 'ema_model' in pkg self.ema_soundstream.load_state_dict(pkg['ema_model']) self.optim.load_state_dict(pkg['optim']) self.discr_optim.load_state_dict(pkg['discr_optim']) Loading Loading @@ -297,7 +308,8 @@ class SoundStreamTrainer(nn.Module): device = self.device steps = int(self.steps.item()) apply_grad_penalty = not (steps % self.apply_grad_penalty_every) apply_grad_penalty = self.apply_grad_penalty_every > 0 and not (steps % self.apply_grad_penalty_every) log_losses = self.log_losses_every > 0 and not (steps % self.log_losses_every) self.soundstream.train() Loading @@ -311,14 +323,21 @@ class SoundStreamTrainer(nn.Module): wave, = next(self.dl_iter) wave = wave.to(device) loss, (recon_loss, multi_spectral_recon_loss, *_) = self.soundstream(wave, return_loss_breakdown = True) loss, (recon_loss, multi_spectral_recon_loss, adversarial_loss, feature_loss, all_commitment_loss) = self.soundstream(wave, return_loss_breakdown = True) self.accelerator.backward(loss / self.grad_accum_every) accum_log(logs, dict( loss = loss.item() / self.grad_accum_every, recon_loss = recon_loss.item() / self.grad_accum_every, multi_spectral_recon_loss = multi_spectral_recon_loss.item() / self.grad_accum_every )) if log_losses: accum_log(logs, dict( multi_spectral_recon_loss = multi_spectral_recon_loss.item() / self.grad_accum_every, adversarial_loss = adversarial_loss.item() / self.grad_accum_every, feature_loss = feature_loss.item() / self.grad_accum_every, all_commitment_loss = all_commitment_loss.item() / self.grad_accum_every, )) if exists(self.max_grad_norm): Loading Loading @@ -362,7 +381,16 @@ class SoundStreamTrainer(nn.Module): # build pretty printed losses losses_str = f"{steps}: soundstream total loss: {logs['loss']:.3f}, soundstream recon loss: {logs['recon_loss']:.3f}" self.accelerator.log({"total_loss": logs['loss'], "recon_loss": logs['recon_loss']}, step=steps) if log_losses: self.accelerator.log({ "total_loss": logs['loss'], "recon_loss": logs['recon_loss'], "multi_spectral_recon_loss": logs['multi_spectral_recon_loss'], "adversarial_loss": logs['adversarial_loss'], "feature_loss": logs['feature_loss'], "all_commitment_loss": logs['all_commitment_loss'], "stft_discr_loss": logs['stft'] }, step=steps) for key, loss in logs.items(): if not key.startswith('scale:'): Loading @@ -370,6 +398,7 @@ class SoundStreamTrainer(nn.Module): _, scale_factor = key.split(':') losses_str += f" | discr (scale {scale_factor}) loss: {loss:.3f}" if log_losses: self.accelerator.log({f"discr_loss (scale {scale_factor})": loss}, step=steps) # log Loading @@ -380,7 +409,7 @@ class SoundStreamTrainer(nn.Module): self.accelerator.wait_for_everyone() if self.is_main: if self.is_main and self.use_ema: self.ema_soundstream.update() # sample results every so often Loading @@ -388,19 +417,21 @@ class SoundStreamTrainer(nn.Module): self.accelerator.wait_for_everyone() if self.is_main and not (steps % self.save_results_every): for model, filename in ((self.ema_soundstream.ema_model, f'{steps}.ema'), (self.unwrapped_soundstream, str(steps))): model.eval() models = [(self.unwrapped_soundstream, str(steps))] if self.use_ema: models.append((self.ema_soundstream.ema_model if self.use_ema else self.unwrapped_soundstream, f'{steps}.ema')) wave, = next(self.valid_dl_iter) wave = wave.to(device) for model, label in models: model.eval() with torch.no_grad(): recons = model(wave, return_recons_only = True) milestone = steps // self.save_results_every for ind, recon in enumerate(recons.unbind(dim = 0)): filename = str(self.results_folder / f'sample_{steps}.flac') filename = str(self.results_folder / f'sample_{label}.flac') torchaudio.save(filename, recon.cpu().detach(), self.unwrapped_soundstream.target_sample_hz) self.print(f'{steps}: saving to {str(self.results_folder)}') Loading