Commit 4c3c3534 authored by Phil Wang's avatar Phil Wang
Browse files

need to access unwrapped soundstream in trainer

parent 36dc9e0d
Loading
Loading
Loading
Loading
+6 −3
Original line number Diff line number Diff line
@@ -242,13 +242,16 @@ class SoundStreamTrainer(nn.Module):

        torch.save(pkg, path)

    @property
    def unwrapped_soundstream(self):
        return self.accelerator.unwrap_model(self.soundstream)

    def load(self, path):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path))

        soundstream = self.accelerator.unwrap_model(self.soundstream)
        soundstream.load_state_dict(pkg['model'])
        self.unwrapped_soundstream.load_state_dict(pkg['model'])

        self.ema_soundstream.load_state_dict(pkg['ema_model'])
        self.optim.load_state_dict(pkg['optim'])
@@ -259,7 +262,7 @@ class SoundStreamTrainer(nn.Module):
            discr_optim.load_state_dict(pkg[key])

    def multiscale_discriminator_iter(self):
        for ind, discr in enumerate(self.soundstream.discriminators):
        for ind, discr in enumerate(self.unwrapped_soundstream.discriminators):
            yield f'multiscale_discr_optimizer_{ind}', discr

    def print(self, msg):
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.3.4',
  version = '0.3.5',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',