Loading audiolm_pytorch/soundstream.py +7 −9 Original line number Diff line number Diff line Loading @@ -501,7 +501,12 @@ class SoundStream(nn.Module): def save(self, path): path = Path(path) torch.save(self.state_dict(), str(path)) pkg = dict( model = self.state_dict(), version = __version__ ) torch.save(pkg, str(path)) def load(self, path): path = Path(path) Loading @@ -513,14 +518,7 @@ class SoundStream(nn.Module): if 'version' in pkg and version.parse(pkg['version']) < parsed_version: print(f'soundstream model being loaded was trained on an older version of audiolm-pytorch ({pkg["version"]})') # some hacky logic to remove confusion around loading trainer vs main model maybe_trainer_pkg = len(pkg.keys()) < 15 if maybe_trainer_pkg: self.load_from_trainer_saved_obj(str(path)) return self.load_state_dict() self.load_state_dict(pkg['model']) def load_from_trainer_saved_obj(self, path): path = Path(path) Loading audiolm_pytorch/version.py +1 −1 Original line number Diff line number Diff line __version__ = '0.15.5' __version__ = '0.15.7' Loading
audiolm_pytorch/soundstream.py +7 −9 Original line number Diff line number Diff line Loading @@ -501,7 +501,12 @@ class SoundStream(nn.Module): def save(self, path): path = Path(path) torch.save(self.state_dict(), str(path)) pkg = dict( model = self.state_dict(), version = __version__ ) torch.save(pkg, str(path)) def load(self, path): path = Path(path) Loading @@ -513,14 +518,7 @@ class SoundStream(nn.Module): if 'version' in pkg and version.parse(pkg['version']) < parsed_version: print(f'soundstream model being loaded was trained on an older version of audiolm-pytorch ({pkg["version"]})') # some hacky logic to remove confusion around loading trainer vs main model maybe_trainer_pkg = len(pkg.keys()) < 15 if maybe_trainer_pkg: self.load_from_trainer_saved_obj(str(path)) return self.load_state_dict() self.load_state_dict(pkg['model']) def load_from_trainer_saved_obj(self, path): path = Path(path) Loading
audiolm_pytorch/version.py +1 −1 Original line number Diff line number Diff line __version__ = '0.15.5' __version__ = '0.15.7'