Commit bcf77f91 authored by Phil Wang's avatar Phil Wang
Browse files

fix issue with accelerator wrapped models when saving and loading

parent e4d31ce4
Loading
Loading
Loading
Loading
+15 −8
Original line number Diff line number Diff line
@@ -227,7 +227,7 @@ class SoundStreamTrainer(nn.Module):

    def save(self, path):
        pkg = dict(
            model = self.soundstream.state_dict(),
            model = self.accelerator.get_state_dict(self.soundstream),
            ema_model = self.ema_soundstream.state_dict(),
            optim = self.optim.state_dict(),
            discr_optim = self.discr_optim.state_dict()
@@ -244,7 +244,9 @@ class SoundStreamTrainer(nn.Module):
        assert path.exists()
        pkg = torch.load(str(path))

        self.soundstream.load_state_dict(pkg['model'])
        soundstream = self.accelerator.unwrap_model(self.soundstream)
        soundstream.load_state_dict(pkg['model'])

        self.ema_soundstream.load_state_dict(pkg['ema_model'])
        self.optim.load_state_dict(pkg['optim'])
        self.discr_optim.load_state_dict(pkg['discr_optim'])
@@ -515,7 +517,7 @@ class SemanticTransformerTrainer(nn.Module):

    def save(self, path):
        pkg = dict(
            model = self.transformer.state_dict(),
            model = self.accelerator.get_state_dict(self.transformer),
            optim = self.optim.state_dict()
        )
        torch.save(pkg, path)
@@ -525,7 +527,8 @@ class SemanticTransformerTrainer(nn.Module):
        assert path.exists()
        pkg = torch.load(str(path))

        self.transformer.load_state_dict(pkg['model'])
        transformer = self.accelerator.unwrap_model(self.transformer)
        transformer.load_state_dict(pkg['model'])
        self.optim.load_state_dict(pkg['optim'])

    def print(self, msg):
@@ -743,7 +746,7 @@ class CoarseTransformerTrainer(nn.Module):

    def save(self, path):
        pkg = dict(
            model = self.transformer.state_dict(),
            model = self.accelerator.get_state_dict(self.transformer),
            optim = self.optim.state_dict()
        )
        torch.save(pkg, path)
@@ -753,7 +756,9 @@ class CoarseTransformerTrainer(nn.Module):
        assert path.exists()
        pkg = torch.load(str(path))

        self.transformer.load_state_dict(pkg['model'])
        transformer = self.accelerator.unwrap_model(self.transformer)
        transformer.load_state_dict(pkg['model'])

        self.optim.load_state_dict(pkg['optim'])

    def print(self, msg):
@@ -964,7 +969,7 @@ class FineTransformerTrainer(nn.Module):

    def save(self, path):
        pkg = dict(
            model = self.transformer.state_dict(),
            model = self.accelerator.get_state_dict(self.transformer),
            optim = self.optim.state_dict()
        )
        torch.save(pkg, path)
@@ -974,7 +979,9 @@ class FineTransformerTrainer(nn.Module):
        assert path.exists()
        pkg = torch.load(str(path))

        self.transformer.load_state_dict(pkg['model'])
        transformer = self.accelerator.unwrap_model(self.transformer)
        transformer.load_state_dict(pkg['model'])

        self.optim.load_state_dict(pkg['optim'])

    def print(self, msg):
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.2.2',
  version = '0.2.3',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',