Loading audiolm_pytorch/trainer.py +15 −8 Original line number Diff line number Diff line Loading @@ -227,7 +227,7 @@ class SoundStreamTrainer(nn.Module): def save(self, path): pkg = dict( model = self.soundstream.state_dict(), model = self.accelerator.get_state_dict(self.soundstream), ema_model = self.ema_soundstream.state_dict(), optim = self.optim.state_dict(), discr_optim = self.discr_optim.state_dict() Loading @@ -244,7 +244,9 @@ class SoundStreamTrainer(nn.Module): assert path.exists() pkg = torch.load(str(path)) self.soundstream.load_state_dict(pkg['model']) soundstream = self.accelerator.unwrap_model(self.soundstream) soundstream.load_state_dict(pkg['model']) self.ema_soundstream.load_state_dict(pkg['ema_model']) self.optim.load_state_dict(pkg['optim']) self.discr_optim.load_state_dict(pkg['discr_optim']) Loading Loading @@ -515,7 +517,7 @@ class SemanticTransformerTrainer(nn.Module): def save(self, path): pkg = dict( model = self.transformer.state_dict(), model = self.accelerator.get_state_dict(self.transformer), optim = self.optim.state_dict() ) torch.save(pkg, path) Loading @@ -525,7 +527,8 @@ class SemanticTransformerTrainer(nn.Module): assert path.exists() pkg = torch.load(str(path)) self.transformer.load_state_dict(pkg['model']) transformer = self.accelerator.unwrap_model(self.transformer) transformer.load_state_dict(pkg['model']) self.optim.load_state_dict(pkg['optim']) def print(self, msg): Loading Loading @@ -743,7 +746,7 @@ class CoarseTransformerTrainer(nn.Module): def save(self, path): pkg = dict( model = self.transformer.state_dict(), model = self.accelerator.get_state_dict(self.transformer), optim = self.optim.state_dict() ) torch.save(pkg, path) Loading @@ -753,7 +756,9 @@ class CoarseTransformerTrainer(nn.Module): assert path.exists() pkg = torch.load(str(path)) self.transformer.load_state_dict(pkg['model']) transformer = self.accelerator.unwrap_model(self.transformer) transformer.load_state_dict(pkg['model']) self.optim.load_state_dict(pkg['optim']) def print(self, msg): Loading Loading @@ -964,7 +969,7 @@ class FineTransformerTrainer(nn.Module): def save(self, path): pkg = dict( model = self.transformer.state_dict(), model = self.accelerator.get_state_dict(self.transformer), optim = self.optim.state_dict() ) torch.save(pkg, path) Loading @@ -974,7 +979,9 @@ class FineTransformerTrainer(nn.Module): assert path.exists() pkg = torch.load(str(path)) self.transformer.load_state_dict(pkg['model']) transformer = self.accelerator.unwrap_model(self.transformer) transformer.load_state_dict(pkg['model']) self.optim.load_state_dict(pkg['optim']) def print(self, msg): Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.2.2', version = '0.2.3', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
audiolm_pytorch/trainer.py +15 −8 Original line number Diff line number Diff line Loading @@ -227,7 +227,7 @@ class SoundStreamTrainer(nn.Module): def save(self, path): pkg = dict( model = self.soundstream.state_dict(), model = self.accelerator.get_state_dict(self.soundstream), ema_model = self.ema_soundstream.state_dict(), optim = self.optim.state_dict(), discr_optim = self.discr_optim.state_dict() Loading @@ -244,7 +244,9 @@ class SoundStreamTrainer(nn.Module): assert path.exists() pkg = torch.load(str(path)) self.soundstream.load_state_dict(pkg['model']) soundstream = self.accelerator.unwrap_model(self.soundstream) soundstream.load_state_dict(pkg['model']) self.ema_soundstream.load_state_dict(pkg['ema_model']) self.optim.load_state_dict(pkg['optim']) self.discr_optim.load_state_dict(pkg['discr_optim']) Loading Loading @@ -515,7 +517,7 @@ class SemanticTransformerTrainer(nn.Module): def save(self, path): pkg = dict( model = self.transformer.state_dict(), model = self.accelerator.get_state_dict(self.transformer), optim = self.optim.state_dict() ) torch.save(pkg, path) Loading @@ -525,7 +527,8 @@ class SemanticTransformerTrainer(nn.Module): assert path.exists() pkg = torch.load(str(path)) self.transformer.load_state_dict(pkg['model']) transformer = self.accelerator.unwrap_model(self.transformer) transformer.load_state_dict(pkg['model']) self.optim.load_state_dict(pkg['optim']) def print(self, msg): Loading Loading @@ -743,7 +746,7 @@ class CoarseTransformerTrainer(nn.Module): def save(self, path): pkg = dict( model = self.transformer.state_dict(), model = self.accelerator.get_state_dict(self.transformer), optim = self.optim.state_dict() ) torch.save(pkg, path) Loading @@ -753,7 +756,9 @@ class CoarseTransformerTrainer(nn.Module): assert path.exists() pkg = torch.load(str(path)) self.transformer.load_state_dict(pkg['model']) transformer = self.accelerator.unwrap_model(self.transformer) transformer.load_state_dict(pkg['model']) self.optim.load_state_dict(pkg['optim']) def print(self, msg): Loading Loading @@ -964,7 +969,7 @@ class FineTransformerTrainer(nn.Module): def save(self, path): pkg = dict( model = self.transformer.state_dict(), model = self.accelerator.get_state_dict(self.transformer), optim = self.optim.state_dict() ) torch.save(pkg, path) Loading @@ -974,7 +979,9 @@ class FineTransformerTrainer(nn.Module): assert path.exists() pkg = torch.load(str(path)) self.transformer.load_state_dict(pkg['model']) transformer = self.accelerator.unwrap_model(self.transformer) transformer.load_state_dict(pkg['model']) self.optim.load_state_dict(pkg['optim']) def print(self, msg): Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.2.2', version = '0.2.3', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading