Commit 1b5fdefc authored by Phil Wang's avatar Phil Wang
Browse files

always save version wherever possible

parent 7596ae0c
Loading
Loading
Loading
Loading
+7 −9
Original line number Diff line number Diff line
@@ -501,7 +501,12 @@ class SoundStream(nn.Module):

    def save(self, path):
        path = Path(path)
        torch.save(self.state_dict(), str(path))
        pkg = dict(
            model = self.state_dict(),
            version = __version__
        )

        torch.save(pkg, str(path))

    def load(self, path):
        path = Path(path)
@@ -513,14 +518,7 @@ class SoundStream(nn.Module):
        if 'version' in pkg and version.parse(pkg['version']) < parsed_version:
            print(f'soundstream model being loaded was trained on an older version of audiolm-pytorch ({pkg["version"]})')

        # some hacky logic to remove confusion around loading trainer vs main model

        maybe_trainer_pkg = len(pkg.keys()) < 15
        if maybe_trainer_pkg:
            self.load_from_trainer_saved_obj(str(path))
            return

        self.load_state_dict()
        self.load_state_dict(pkg['model'])

    def load_from_trainer_saved_obj(self, path):
        path = Path(path)
+1 −1
Original line number Diff line number Diff line
__version__ = '0.15.5'
__version__ = '0.15.7'