Commit 89a35002 authored by Phil Wang's avatar Phil Wang
Browse files

add helper save load methods to all the trainers, for saving optimizer states,...

add helper save load methods to all the trainers, for saving optimizer states, including all multiscale discriminator optimizers for soundstream training
parent deceb7a0
Loading
Loading
Loading
Loading
+72 −1
Original line number Diff line number Diff line
@@ -225,6 +225,32 @@ class SoundStreamTrainer(nn.Module):

        self.results_folder.mkdir(parents = True, exist_ok = True)

    def save(self, path):
        pkg = dict(
            model = self.soundstream.state_dict(),
            ema_model = self.ema_soundstream.state_dict(),
            discr_optim = self.discr_optim.state_dict()
        )

        for key, _ in self.multiscale_discriminator_iter():
            discr_optim = getattr(self, key)
            pkg[key] = discr_optim.state_dict()

        torch.save(pkg, path)

    def load(self, path):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path))

        self.soundstream.load_state_dict(pkg['model'])
        self.ema_soundstream.load_state_dict(pkg['ema_model'])
        self.discr_optim.load_state_dict(pkg['discr_optim'])

        for key, _ in self.multiscale_discriminator_iter():
            discr_optim = getattr(self, key)
            discr_optim.load_state_dict(pkg[key])

    def multiscale_discriminator_iter(self):
        for ind, discr in enumerate(self.soundstream.discriminators):
            yield f'multiscale_discr_optimizer_{ind}', discr
@@ -485,6 +511,21 @@ class SemanticTransformerTrainer(nn.Module):

        self.results_folder.mkdir(parents = True, exist_ok = True)

    def save(self, path):
        pkg = dict(
            model = self.transformer.state_dict(),
            optim = self.optim.state_dict()
        )
        torch.save(pkg, path)

    def load(self, path):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path))

        self.transformer.load_state_dict(pkg['model'])
        self.optim.load_state_dict(pkg['optim'])

    def print(self, msg):
        self.accelerator.print(msg)

@@ -698,6 +739,21 @@ class CoarseTransformerTrainer(nn.Module):

        self.train_wrapper.to(self.device)

    def save(self, path):
        pkg = dict(
            model = self.transformer.state_dict(),
            optim = self.optim.state_dict()
        )
        torch.save(pkg, path)

    def load(self, path):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path))

        self.transformer.load_state_dict(pkg['model'])
        self.optim.load_state_dict(pkg['optim'])

    def print(self, msg):
        self.accelerator.print(msg)

@@ -904,6 +960,21 @@ class FineTransformerTrainer(nn.Module):

        self.train_wrapper.to(self.device)

    def save(self, path):
        pkg = dict(
            model = self.transformer.state_dict(),
            optim = self.optim.state_dict()
        )
        torch.save(pkg, path)

    def load(self, path):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path))

        self.transformer.load_state_dict(pkg['model'])
        self.optim.load_state_dict(pkg['optim'])

    def print(self, msg):
        self.accelerator.print(msg)

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.2.0',
  version = '0.2.1',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',