Loading audiolm_pytorch/trainer.py +22 −11 Original line number Diff line number Diff line Loading @@ -129,6 +129,7 @@ class SoundStreamTrainer(nn.Module): results_folder = './results', valid_frac = 0.05, random_split_seed = 42, use_ema = True, ema_beta = 0.995, ema_update_after_step = 500, ema_update_every = 10, Loading @@ -144,6 +145,9 @@ class SoundStreamTrainer(nn.Module): self.accelerator = Accelerator(kwargs_handlers = [kwargs], **accelerate_kwargs) self.soundstream = soundstream self.use_ema = use_ema if self.use_ema: self.ema_soundstream = EMA(soundstream, beta = ema_beta, update_after_step = ema_update_after_step, update_every = ema_update_every) self.register_buffer('steps', torch.Tensor([0])) Loading Loading @@ -238,11 +242,13 @@ class SoundStreamTrainer(nn.Module): def save(self, path): pkg = dict( model = self.accelerator.get_state_dict(self.soundstream), ema_model = self.ema_soundstream.state_dict(), optim = self.optim.state_dict(), discr_optim = self.discr_optim.state_dict() ) if self.use_ema: pkg['ema_model'] = self.ema_soundstream.state_dict() for key, _ in self.multiscale_discriminator_iter(): discr_optim = getattr(self, key) pkg[key] = discr_optim.state_dict() Loading @@ -260,7 +266,10 @@ class SoundStreamTrainer(nn.Module): self.unwrapped_soundstream.load_state_dict(pkg['model']) if self.use_ema: assert 'ema_model' in pkg self.ema_soundstream.load_state_dict(pkg['ema_model']) self.optim.load_state_dict(pkg['optim']) self.discr_optim.load_state_dict(pkg['discr_optim']) Loading Loading @@ -400,7 +409,7 @@ class SoundStreamTrainer(nn.Module): self.accelerator.wait_for_everyone() if self.is_main: if self.is_main and self.use_ema: self.ema_soundstream.update() # sample results every so often Loading @@ -408,19 +417,21 @@ class SoundStreamTrainer(nn.Module): self.accelerator.wait_for_everyone() if self.is_main and not (steps % self.save_results_every): for model, filename in ((self.ema_soundstream.ema_model, f'{steps}.ema'), (self.unwrapped_soundstream, str(steps))): model.eval() models = [(self.unwrapped_soundstream, str(steps))] if self.use_ema: models.append((self.ema_soundstream.ema_model if self.use_ema else self.unwrapped_soundstream, f'{steps}.ema')) wave, = next(self.valid_dl_iter) wave = wave.to(device) for model, label in models: model.eval() with torch.no_grad(): recons = model(wave, return_recons_only = True) milestone = steps // self.save_results_every for ind, recon in enumerate(recons.unbind(dim = 0)): filename = str(self.results_folder / f'sample_{steps}.flac') filename = str(self.results_folder / f'sample_{label}.flac') torchaudio.save(filename, recon.cpu().detach(), self.unwrapped_soundstream.target_sample_hz) self.print(f'{steps}: saving to {str(self.results_folder)}') Loading Loading
audiolm_pytorch/trainer.py +22 −11 Original line number Diff line number Diff line Loading @@ -129,6 +129,7 @@ class SoundStreamTrainer(nn.Module): results_folder = './results', valid_frac = 0.05, random_split_seed = 42, use_ema = True, ema_beta = 0.995, ema_update_after_step = 500, ema_update_every = 10, Loading @@ -144,6 +145,9 @@ class SoundStreamTrainer(nn.Module): self.accelerator = Accelerator(kwargs_handlers = [kwargs], **accelerate_kwargs) self.soundstream = soundstream self.use_ema = use_ema if self.use_ema: self.ema_soundstream = EMA(soundstream, beta = ema_beta, update_after_step = ema_update_after_step, update_every = ema_update_every) self.register_buffer('steps', torch.Tensor([0])) Loading Loading @@ -238,11 +242,13 @@ class SoundStreamTrainer(nn.Module): def save(self, path): pkg = dict( model = self.accelerator.get_state_dict(self.soundstream), ema_model = self.ema_soundstream.state_dict(), optim = self.optim.state_dict(), discr_optim = self.discr_optim.state_dict() ) if self.use_ema: pkg['ema_model'] = self.ema_soundstream.state_dict() for key, _ in self.multiscale_discriminator_iter(): discr_optim = getattr(self, key) pkg[key] = discr_optim.state_dict() Loading @@ -260,7 +266,10 @@ class SoundStreamTrainer(nn.Module): self.unwrapped_soundstream.load_state_dict(pkg['model']) if self.use_ema: assert 'ema_model' in pkg self.ema_soundstream.load_state_dict(pkg['ema_model']) self.optim.load_state_dict(pkg['optim']) self.discr_optim.load_state_dict(pkg['discr_optim']) Loading Loading @@ -400,7 +409,7 @@ class SoundStreamTrainer(nn.Module): self.accelerator.wait_for_everyone() if self.is_main: if self.is_main and self.use_ema: self.ema_soundstream.update() # sample results every so often Loading @@ -408,19 +417,21 @@ class SoundStreamTrainer(nn.Module): self.accelerator.wait_for_everyone() if self.is_main and not (steps % self.save_results_every): for model, filename in ((self.ema_soundstream.ema_model, f'{steps}.ema'), (self.unwrapped_soundstream, str(steps))): model.eval() models = [(self.unwrapped_soundstream, str(steps))] if self.use_ema: models.append((self.ema_soundstream.ema_model if self.use_ema else self.unwrapped_soundstream, f'{steps}.ema')) wave, = next(self.valid_dl_iter) wave = wave.to(device) for model, label in models: model.eval() with torch.no_grad(): recons = model(wave, return_recons_only = True) milestone = steps // self.save_results_every for ind, recon in enumerate(recons.unbind(dim = 0)): filename = str(self.results_folder / f'sample_{steps}.flac') filename = str(self.results_folder / f'sample_{label}.flac') torchaudio.save(filename, recon.cpu().detach(), self.unwrapped_soundstream.target_sample_hz) self.print(f'{steps}: saving to {str(self.results_folder)}') Loading