Commit 22fd63bb authored by Phil Wang's avatar Phil Wang
Browse files

map location to cpu when loading state dict

parent c9b10bbc
Loading
Loading
Loading
Loading
+6 −2
Original line number Diff line number Diff line
@@ -18,6 +18,8 @@ from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence

from lion_pytorch import Lion

from musiclm_pytorch import MuLaN

from einops import rearrange
@@ -161,6 +163,7 @@ class MuLaNTrainer(nn.Module):
        save_model_every = 1000,
        results_folder = './results',
        accelerate_kwargs: dict = dict(),
        use_lion = False,
        force_clear_prev_results = None  # set to True | False to skip the prompt
    ):
        super().__init__()
@@ -178,7 +181,8 @@ class MuLaNTrainer(nn.Module):

        # optimizers

        self.optim = Adam(mulan.parameters(), lr = lr, betas = betas)
        optim_klass = Lion if use_lion else Adam
        self.optim = optim_klass(mulan.parameters(), lr = lr, betas = betas)

        # max grad norm

@@ -260,7 +264,7 @@ class MuLaNTrainer(nn.Module):
    def load(self, path):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path))
        pkg = torch.load(str(path), map_location = 'cpu')

        mulan = self.accelerator.unwrap_model(self.mulan)
        mulan.load_state_dict(pkg['model'])
+2 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'musiclm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.25',
  version = '0.0.26',
  license='MIT',
  description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
  author = 'Phil Wang',
@@ -23,6 +23,7 @@ setup(
    'audiolm-pytorch>=0.10.4',
    'beartype',
    'einops>=0.6',
    'lion-pytorch',
    'vector-quantize-pytorch>=1.0.0',
    'x-clip',
    'torch>=1.12',