Loading audiolm_pytorch/trainer.py +24 −4 Original line number Diff line number Diff line Loading @@ -125,6 +125,7 @@ class SoundStreamTrainer(nn.Module): discr_max_grad_norm = None, save_results_every = 100, save_model_every = 1000, log_losses_every = 1, results_folder = './results', valid_frac = 0.05, random_split_seed = 42, Loading Loading @@ -220,6 +221,7 @@ class SoundStreamTrainer(nn.Module): self.save_model_every = save_model_every self.save_results_every = save_results_every self.log_losses_every = log_losses_every self.apply_grad_penalty_every = apply_grad_penalty_every Loading Loading @@ -298,6 +300,7 @@ class SoundStreamTrainer(nn.Module): steps = int(self.steps.item()) apply_grad_penalty = not (steps % self.apply_grad_penalty_every) log_losses = self.log_losses_every > 0 and not (steps % self.log_losses_every) self.soundstream.train() Loading @@ -311,14 +314,21 @@ class SoundStreamTrainer(nn.Module): wave, = next(self.dl_iter) wave = wave.to(device) loss, (recon_loss, multi_spectral_recon_loss, *_) = self.soundstream(wave, return_loss_breakdown = True) loss, (recon_loss, multi_spectral_recon_loss, adversarial_loss, feature_loss, all_commitment_loss) = self.soundstream(wave, return_loss_breakdown = True) self.accelerator.backward(loss / self.grad_accum_every) accum_log(logs, dict( loss = loss.item() / self.grad_accum_every, recon_loss = recon_loss.item() / self.grad_accum_every, multi_spectral_recon_loss = multi_spectral_recon_loss.item() / self.grad_accum_every )) if log_losses: accum_log(logs, dict( multi_spectral_recon_loss = multi_spectral_recon_loss.item() / self.grad_accum_every, adversarial_loss = adversarial_loss.item() / self.grad_accum_every, feature_loss = feature_loss.item() / self.grad_accum_every, all_commitment_loss = all_commitment_loss.item() / self.grad_accum_every, )) if exists(self.max_grad_norm): Loading Loading @@ -362,7 +372,16 @@ class SoundStreamTrainer(nn.Module): # build pretty printed losses losses_str = f"{steps}: soundstream total loss: {logs['loss']:.3f}, soundstream recon loss: {logs['recon_loss']:.3f}" self.accelerator.log({"total_loss": logs['loss'], "recon_loss": logs['recon_loss']}, step=steps) if log_losses: self.accelerator.log({ "total_loss": logs['loss'], "recon_loss": logs['recon_loss'], "multi_spectral_recon_loss": logs['multi_spectral_recon_loss'], "adversarial_loss": logs['adversarial_loss'], "feature_loss": logs['feature_loss'], "all_commitment_loss": logs['all_commitment_loss'], "stft_discr_loss": logs['stft'] }, step=steps) for key, loss in logs.items(): if not key.startswith('scale:'): Loading @@ -370,6 +389,7 @@ class SoundStreamTrainer(nn.Module): _, scale_factor = key.split(':') losses_str += f" | discr (scale {scale_factor}) loss: {loss:.3f}" if log_losses: self.accelerator.log({f"discr_loss (scale {scale_factor})": loss}, step=steps) # log Loading Loading
audiolm_pytorch/trainer.py +24 −4 Original line number Diff line number Diff line Loading @@ -125,6 +125,7 @@ class SoundStreamTrainer(nn.Module): discr_max_grad_norm = None, save_results_every = 100, save_model_every = 1000, log_losses_every = 1, results_folder = './results', valid_frac = 0.05, random_split_seed = 42, Loading Loading @@ -220,6 +221,7 @@ class SoundStreamTrainer(nn.Module): self.save_model_every = save_model_every self.save_results_every = save_results_every self.log_losses_every = log_losses_every self.apply_grad_penalty_every = apply_grad_penalty_every Loading Loading @@ -298,6 +300,7 @@ class SoundStreamTrainer(nn.Module): steps = int(self.steps.item()) apply_grad_penalty = not (steps % self.apply_grad_penalty_every) log_losses = self.log_losses_every > 0 and not (steps % self.log_losses_every) self.soundstream.train() Loading @@ -311,14 +314,21 @@ class SoundStreamTrainer(nn.Module): wave, = next(self.dl_iter) wave = wave.to(device) loss, (recon_loss, multi_spectral_recon_loss, *_) = self.soundstream(wave, return_loss_breakdown = True) loss, (recon_loss, multi_spectral_recon_loss, adversarial_loss, feature_loss, all_commitment_loss) = self.soundstream(wave, return_loss_breakdown = True) self.accelerator.backward(loss / self.grad_accum_every) accum_log(logs, dict( loss = loss.item() / self.grad_accum_every, recon_loss = recon_loss.item() / self.grad_accum_every, multi_spectral_recon_loss = multi_spectral_recon_loss.item() / self.grad_accum_every )) if log_losses: accum_log(logs, dict( multi_spectral_recon_loss = multi_spectral_recon_loss.item() / self.grad_accum_every, adversarial_loss = adversarial_loss.item() / self.grad_accum_every, feature_loss = feature_loss.item() / self.grad_accum_every, all_commitment_loss = all_commitment_loss.item() / self.grad_accum_every, )) if exists(self.max_grad_norm): Loading Loading @@ -362,7 +372,16 @@ class SoundStreamTrainer(nn.Module): # build pretty printed losses losses_str = f"{steps}: soundstream total loss: {logs['loss']:.3f}, soundstream recon loss: {logs['recon_loss']:.3f}" self.accelerator.log({"total_loss": logs['loss'], "recon_loss": logs['recon_loss']}, step=steps) if log_losses: self.accelerator.log({ "total_loss": logs['loss'], "recon_loss": logs['recon_loss'], "multi_spectral_recon_loss": logs['multi_spectral_recon_loss'], "adversarial_loss": logs['adversarial_loss'], "feature_loss": logs['feature_loss'], "all_commitment_loss": logs['all_commitment_loss'], "stft_discr_loss": logs['stft'] }, step=steps) for key, loss in logs.items(): if not key.startswith('scale:'): Loading @@ -370,6 +389,7 @@ class SoundStreamTrainer(nn.Module): _, scale_factor = key.split(':') losses_str += f" | discr (scale {scale_factor}) loss: {loss:.3f}" if log_losses: self.accelerator.log({f"discr_loss (scale {scale_factor})": loss}, step=steps) # log Loading