Commit 255b578f authored by Phil Wang's avatar Phil Wang
Browse files

more hacky logic to autoresolve confusion around loading from trainer vs soundstream model directly

parent f475b095
Loading
Loading
Loading
Loading
+14 −5
Original line number Diff line number Diff line
@@ -502,14 +502,23 @@ class SoundStream(nn.Module):
    def load(self, path):
        path = Path(path)
        assert path.exists()
        self.load_state_dict(torch.load(str(path)))
        pkg = torch.load(str(path))

    def load_from_trainer_saved_obj(self, path, ema = False):
        key = 'ema_model' if ema else 'model'
        # some hacky logic to remove confusion around loading trainer vs main model

        maybe_trainer_pkg = len(pkg.keys()) < 15
        if maybe_trainer_pkg:
            self.load_from_trainer_saved_obj(str(path))
            return

        self.load_state_dict()

    def load_from_trainer_saved_obj(self, path):
        path = Path(path)
        assert path.exists()
        trainer_obj = torch.load(str(path))
        self.load_state_dict(trainer_obj[key])
        obj = torch.load(str(path))
        self.load_state_dict(obj['model'])
        exit()

    def non_discr_parameters(self):
        return [
+5 −0
Original line number Diff line number Diff line
@@ -245,6 +245,11 @@ class SoundStreamTrainer(nn.Module):
        hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr}
        self.accelerator.init_trackers("soundstream", config=hps)        

    def set_model_as_ema_model_(self):
        """ this will force the main 'online' model to have same parameters as the exponentially moving averaged model """
        assert self.use_ema
        self.ema_soundstream.ema_model.load_state_dict(self.soundstream.state_dict())

    def save(self, path):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.soundstream),
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.15.0',
  version = '0.15.1',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',