Commit c53245de authored by Phil Wang's avatar Phil Wang
Browse files

do something hacky to resolve issue from past saved models

parent c37cd2d9
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -495,6 +495,10 @@ class SoundStream(nn.Module):
        x = rearrange(x, 'b n c -> b c n')
        return self.decoder(x)

    def save(self, path):
        path = Path(path)
        torch.save(self.state_dict(), str(path))

    def load(self, path):
        path = Path(path)
        assert path.exists()
+11 −0
Original line number Diff line number Diff line
@@ -264,6 +264,17 @@ class SoundStreamTrainer(nn.Module):
        assert path.exists()
        pkg = torch.load(str(path), map_location = 'cpu')

        # if loading from old version, make a hacky guess

        if len(pkg.keys()) > 20:
            self.unwrapped_soundstream.load_state_dict(pkg)

            if self.use_ema:
                self.ema_soundstream.ema_model.load_state_dict(pkg)
            return

        # otherwise load things normally

        self.unwrapped_soundstream.load_state_dict(pkg['model'])

        if self.use_ema:
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.14.1',
  version = '0.14.2',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',