Commit aaf0d2d5 authored by Phil Wang's avatar Phil Wang
Browse files

use map location to cpu when torch.load

parent 0abb04d4
Loading
Loading
Loading
Loading
+4 −4
Original line number Diff line number Diff line
@@ -262,7 +262,7 @@ class SoundStreamTrainer(nn.Module):
    def load(self, path):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path))
        pkg = torch.load(str(path), map_location = 'cpu')

        self.unwrapped_soundstream.load_state_dict(pkg['model'])

@@ -585,7 +585,7 @@ class SemanticTransformerTrainer(nn.Module):
    def load(self, path):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path))
        pkg = torch.load(str(path), map_location = 'cpu')

        transformer = self.accelerator.unwrap_model(self.transformer)
        transformer.load_state_dict(pkg['model'])
@@ -822,7 +822,7 @@ class CoarseTransformerTrainer(nn.Module):
    def load(self, path):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path))
        pkg = torch.load(str(path), map_location = 'cpu')

        transformer = self.accelerator.unwrap_model(self.transformer)
        transformer.load_state_dict(pkg['model'])
@@ -1054,7 +1054,7 @@ class FineTransformerTrainer(nn.Module):
    def load(self, path):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path))
        pkg = torch.load(str(path), map_location = 'cpu')

        transformer = self.accelerator.unwrap_model(self.transformer)
        transformer.load_state_dict(pkg['model'])
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.14.0',
  version = '0.14.1',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',