Loading audiolm_pytorch/trainer.py +72 −1 Original line number Diff line number Diff line Loading @@ -225,6 +225,32 @@ class SoundStreamTrainer(nn.Module): self.results_folder.mkdir(parents = True, exist_ok = True) def save(self, path): pkg = dict( model = self.soundstream.state_dict(), ema_model = self.ema_soundstream.state_dict(), discr_optim = self.discr_optim.state_dict() ) for key, _ in self.multiscale_discriminator_iter(): discr_optim = getattr(self, key) pkg[key] = discr_optim.state_dict() torch.save(pkg, path) def load(self, path): path = Path(path) assert path.exists() pkg = torch.load(str(path)) self.soundstream.load_state_dict(pkg['model']) self.ema_soundstream.load_state_dict(pkg['ema_model']) self.discr_optim.load_state_dict(pkg['discr_optim']) for key, _ in self.multiscale_discriminator_iter(): discr_optim = getattr(self, key) discr_optim.load_state_dict(pkg[key]) def multiscale_discriminator_iter(self): for ind, discr in enumerate(self.soundstream.discriminators): yield f'multiscale_discr_optimizer_{ind}', discr Loading Loading @@ -485,6 +511,21 @@ class SemanticTransformerTrainer(nn.Module): self.results_folder.mkdir(parents = True, exist_ok = True) def save(self, path): pkg = dict( model = self.transformer.state_dict(), optim = self.optim.state_dict() ) torch.save(pkg, path) def load(self, path): path = Path(path) assert path.exists() pkg = torch.load(str(path)) self.transformer.load_state_dict(pkg['model']) self.optim.load_state_dict(pkg['optim']) def print(self, msg): self.accelerator.print(msg) Loading Loading @@ -698,6 +739,21 @@ class CoarseTransformerTrainer(nn.Module): self.train_wrapper.to(self.device) def save(self, path): pkg = dict( model = self.transformer.state_dict(), optim = self.optim.state_dict() ) torch.save(pkg, path) def load(self, path): path = Path(path) assert path.exists() pkg = torch.load(str(path)) self.transformer.load_state_dict(pkg['model']) self.optim.load_state_dict(pkg['optim']) def print(self, msg): self.accelerator.print(msg) Loading Loading @@ -904,6 +960,21 @@ class FineTransformerTrainer(nn.Module): self.train_wrapper.to(self.device) def save(self, path): pkg = dict( model = self.transformer.state_dict(), optim = self.optim.state_dict() ) torch.save(pkg, path) def load(self, path): path = Path(path) assert path.exists() pkg = torch.load(str(path)) self.transformer.load_state_dict(pkg['model']) self.optim.load_state_dict(pkg['optim']) def print(self, msg): self.accelerator.print(msg) Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.2.0', version = '0.2.1', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
audiolm_pytorch/trainer.py +72 −1 Original line number Diff line number Diff line Loading @@ -225,6 +225,32 @@ class SoundStreamTrainer(nn.Module): self.results_folder.mkdir(parents = True, exist_ok = True) def save(self, path): pkg = dict( model = self.soundstream.state_dict(), ema_model = self.ema_soundstream.state_dict(), discr_optim = self.discr_optim.state_dict() ) for key, _ in self.multiscale_discriminator_iter(): discr_optim = getattr(self, key) pkg[key] = discr_optim.state_dict() torch.save(pkg, path) def load(self, path): path = Path(path) assert path.exists() pkg = torch.load(str(path)) self.soundstream.load_state_dict(pkg['model']) self.ema_soundstream.load_state_dict(pkg['ema_model']) self.discr_optim.load_state_dict(pkg['discr_optim']) for key, _ in self.multiscale_discriminator_iter(): discr_optim = getattr(self, key) discr_optim.load_state_dict(pkg[key]) def multiscale_discriminator_iter(self): for ind, discr in enumerate(self.soundstream.discriminators): yield f'multiscale_discr_optimizer_{ind}', discr Loading Loading @@ -485,6 +511,21 @@ class SemanticTransformerTrainer(nn.Module): self.results_folder.mkdir(parents = True, exist_ok = True) def save(self, path): pkg = dict( model = self.transformer.state_dict(), optim = self.optim.state_dict() ) torch.save(pkg, path) def load(self, path): path = Path(path) assert path.exists() pkg = torch.load(str(path)) self.transformer.load_state_dict(pkg['model']) self.optim.load_state_dict(pkg['optim']) def print(self, msg): self.accelerator.print(msg) Loading Loading @@ -698,6 +739,21 @@ class CoarseTransformerTrainer(nn.Module): self.train_wrapper.to(self.device) def save(self, path): pkg = dict( model = self.transformer.state_dict(), optim = self.optim.state_dict() ) torch.save(pkg, path) def load(self, path): path = Path(path) assert path.exists() pkg = torch.load(str(path)) self.transformer.load_state_dict(pkg['model']) self.optim.load_state_dict(pkg['optim']) def print(self, msg): self.accelerator.print(msg) Loading Loading @@ -904,6 +960,21 @@ class FineTransformerTrainer(nn.Module): self.train_wrapper.to(self.device) def save(self, path): pkg = dict( model = self.transformer.state_dict(), optim = self.optim.state_dict() ) torch.save(pkg, path) def load(self, path): path = Path(path) assert path.exists() pkg = torch.load(str(path)) self.transformer.load_state_dict(pkg['model']) self.optim.load_state_dict(pkg['optim']) def print(self, msg): self.accelerator.print(msg) Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.2.0', version = '0.2.1', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading