Loading audiolm_pytorch/soundstream.py +14 −5 Original line number Diff line number Diff line Loading @@ -502,14 +502,23 @@ class SoundStream(nn.Module): def load(self, path): path = Path(path) assert path.exists() self.load_state_dict(torch.load(str(path))) pkg = torch.load(str(path)) def load_from_trainer_saved_obj(self, path, ema = False): key = 'ema_model' if ema else 'model' # some hacky logic to remove confusion around loading trainer vs main model maybe_trainer_pkg = len(pkg.keys()) < 15 if maybe_trainer_pkg: self.load_from_trainer_saved_obj(str(path)) return self.load_state_dict() def load_from_trainer_saved_obj(self, path): path = Path(path) assert path.exists() trainer_obj = torch.load(str(path)) self.load_state_dict(trainer_obj[key]) obj = torch.load(str(path)) self.load_state_dict(obj['model']) exit() def non_discr_parameters(self): return [ Loading audiolm_pytorch/trainer.py +5 −0 Original line number Diff line number Diff line Loading @@ -245,6 +245,11 @@ class SoundStreamTrainer(nn.Module): hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr} self.accelerator.init_trackers("soundstream", config=hps) def set_model_as_ema_model_(self): """ this will force the main 'online' model to have same parameters as the exponentially moving averaged model """ assert self.use_ema self.ema_soundstream.ema_model.load_state_dict(self.soundstream.state_dict()) def save(self, path): pkg = dict( model = self.accelerator.get_state_dict(self.soundstream), Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.15.0', version = '0.15.1', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
audiolm_pytorch/soundstream.py +14 −5 Original line number Diff line number Diff line Loading @@ -502,14 +502,23 @@ class SoundStream(nn.Module): def load(self, path): path = Path(path) assert path.exists() self.load_state_dict(torch.load(str(path))) pkg = torch.load(str(path)) def load_from_trainer_saved_obj(self, path, ema = False): key = 'ema_model' if ema else 'model' # some hacky logic to remove confusion around loading trainer vs main model maybe_trainer_pkg = len(pkg.keys()) < 15 if maybe_trainer_pkg: self.load_from_trainer_saved_obj(str(path)) return self.load_state_dict() def load_from_trainer_saved_obj(self, path): path = Path(path) assert path.exists() trainer_obj = torch.load(str(path)) self.load_state_dict(trainer_obj[key]) obj = torch.load(str(path)) self.load_state_dict(obj['model']) exit() def non_discr_parameters(self): return [ Loading
audiolm_pytorch/trainer.py +5 −0 Original line number Diff line number Diff line Loading @@ -245,6 +245,11 @@ class SoundStreamTrainer(nn.Module): hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr} self.accelerator.init_trackers("soundstream", config=hps) def set_model_as_ema_model_(self): """ this will force the main 'online' model to have same parameters as the exponentially moving averaged model """ assert self.use_ema self.ema_soundstream.ema_model.load_state_dict(self.soundstream.state_dict()) def save(self, path): pkg = dict( model = self.accelerator.get_state_dict(self.soundstream), Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.15.0', version = '0.15.1', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading