Commit 95ebad6c authored by Ilya Borovik
Browse files

control ema model with `use_ema`

+ fix bug with saving decoded ema model samples
parent 15f83db0
Loading
Loading
Loading
Loading
+22 −11
Original line number Diff line number Diff line
@@ -129,6 +129,7 @@ class SoundStreamTrainer(nn.Module):
        results_folder = './results',
        valid_frac = 0.05,
        random_split_seed = 42,
        use_ema = True,
        ema_beta = 0.995,
        ema_update_after_step = 500,
        ema_update_every = 10,
@@ -144,6 +145,9 @@ class SoundStreamTrainer(nn.Module):
        self.accelerator = Accelerator(kwargs_handlers = [kwargs], **accelerate_kwargs)

        self.soundstream = soundstream

        self.use_ema = use_ema
        if self.use_ema:
            self.ema_soundstream = EMA(soundstream, beta = ema_beta, update_after_step = ema_update_after_step, update_every = ema_update_every)

        self.register_buffer('steps', torch.Tensor([0]))
@@ -238,11 +242,13 @@ class SoundStreamTrainer(nn.Module):
    def save(self, path):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.soundstream),
            ema_model = self.ema_soundstream.state_dict(),
            optim = self.optim.state_dict(),
            discr_optim = self.discr_optim.state_dict()
        )

        if self.use_ema:
            pkg['ema_model'] = self.ema_soundstream.state_dict()

        for key, _ in self.multiscale_discriminator_iter():
            discr_optim = getattr(self, key)
            pkg[key] = discr_optim.state_dict()
@@ -260,7 +266,10 @@ class SoundStreamTrainer(nn.Module):

        self.unwrapped_soundstream.load_state_dict(pkg['model'])

        if self.use_ema:
            assert 'ema_model' in pkg
            self.ema_soundstream.load_state_dict(pkg['ema_model'])

        self.optim.load_state_dict(pkg['optim'])
        self.discr_optim.load_state_dict(pkg['discr_optim'])

@@ -400,7 +409,7 @@ class SoundStreamTrainer(nn.Module):

        self.accelerator.wait_for_everyone()

        if self.is_main:
        if self.is_main and self.use_ema:
            self.ema_soundstream.update()

        # sample results every so often
@@ -408,19 +417,21 @@ class SoundStreamTrainer(nn.Module):
        self.accelerator.wait_for_everyone()

        if self.is_main and not (steps % self.save_results_every):
            for model, filename in ((self.ema_soundstream.ema_model, f'{steps}.ema'), (self.unwrapped_soundstream, str(steps))):
                model.eval()
            models = [(self.unwrapped_soundstream, str(steps))]
            if self.use_ema:
                models.append((self.ema_soundstream.ema_model if self.use_ema else self.unwrapped_soundstream, f'{steps}.ema'))

            wave, = next(self.valid_dl_iter)
            wave = wave.to(device)

            for model, label in models:
                model.eval()

                with torch.no_grad():
                    recons = model(wave, return_recons_only = True)

                milestone = steps // self.save_results_every

                for ind, recon in enumerate(recons.unbind(dim = 0)):
                    filename = str(self.results_folder / f'sample_{steps}.flac')
                    filename = str(self.results_folder / f'sample_{label}.flac')
                    torchaudio.save(filename, recon.cpu().detach(), self.unwrapped_soundstream.target_sample_hz)

            self.print(f'{steps}: saving to {str(self.results_folder)}')