Commit 535fa4a1 authored by Phil Wang's avatar Phil Wang
Browse files

fix issue with trainers not saving optimizer states

parent 0d96c1af
Loading
Loading
Loading
Loading
+4 −12
Original line number Diff line number Diff line
@@ -402,13 +402,8 @@ class SoundStreamTrainer(nn.Module):
        # save model every so often

        if self.is_main and not (steps % self.save_model_every):
            state_dict = self.soundstream.state_dict()
            model_path = str(self.results_folder / f'soundstream.{steps}.pt')
            torch.save(state_dict, model_path)

            ema_state_dict = self.ema_soundstream.state_dict()
            model_path = str(self.results_folder / f'soundstream.{steps}.ema.pt')
            torch.save(ema_state_dict, model_path)
            self.save(model_path)

            self.print(f'{steps}: saving model to {str(self.results_folder)}')

@@ -634,9 +629,8 @@ class SemanticTransformerTrainer(nn.Module):
        # save model every so often

        if self.is_main and not (steps % self.save_model_every):
            state_dict = self.transformer.state_dict()
            model_path = str(self.results_folder / f'semantic.transformer.{steps}.pt')
            torch.save(state_dict, model_path)
            self.save(model_path)

            self.print(f'{steps}: saving model to {str(self.results_folder)}')

@@ -873,9 +867,8 @@ class CoarseTransformerTrainer(nn.Module):
        # save model every so often

        if self.is_main and not (steps % self.save_model_every):
            state_dict = self.transformer.state_dict()
            model_path = str(self.results_folder / f'fine.transformer.{steps}.pt')
            torch.save(state_dict, model_path)
            self.save(model_path)

            self.print(f'{steps}: saving model to {str(self.results_folder)}')

@@ -1105,9 +1098,8 @@ class FineTransformerTrainer(nn.Module):
        # save model every so often

        if self.is_main and not (steps % self.save_model_every):
            state_dict = self.transformer.state_dict()
            model_path = str(self.results_folder / f'fine.transformer.{steps}.pt')
            torch.save(state_dict, model_path)
            self.save(model_path)

            self.print(f'{steps}: saving model to {str(self.results_folder)}')

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.12.2',
  version = '0.12.3',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',