Commit cd94fae5 authored by Phil Wang's avatar Phil Wang
Browse files

start saving version numbers along with model, so researchers can trace back...

start saving version numbers along with model, so researchers can trace back to working package version if model was trained in the past
parent bee80b5e
Loading
Loading
Loading
Loading
+9 −0
Original line number Diff line number Diff line
@@ -23,6 +23,10 @@ from local_attention.transformer import FeedForward
from mega_pytorch import MultiHeadedEMA
from audiolm_pytorch.utils import curtail_to_multiple

from audiolm_pytorch.version import __version__
from packaging import version
parsed_version = version.parse(__version__)

# helper functions

def exists(val):
@@ -504,6 +508,11 @@ class SoundStream(nn.Module):
        assert path.exists()
        pkg = torch.load(str(path))

        # check version

        if 'version' in pkg and version.parse(pkg['version']) < parsed_version:
            print(f'soundstream model being loaded was trained on an older version of audiolm-pytorch ({pkg["version"]})')

        # some hacky logic to remove confusion around loading trainer vs main model

        maybe_trainer_pkg = len(pkg.keys()) < 15
+31 −4
Original line number Diff line number Diff line
@@ -38,6 +38,9 @@ from audiolm_pytorch.audiolm_pytorch import (
from audiolm_pytorch.data import SoundDataset, get_dataloader
from audiolm_pytorch.utils import AudioConditionerBase

from audiolm_pytorch.version import __version__
from packaging import version

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

@@ -254,7 +257,8 @@ class SoundStreamTrainer(nn.Module):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.soundstream),
            optim = self.optim.state_dict(),
            discr_optim = self.discr_optim.state_dict()
            discr_optim = self.discr_optim.state_dict(),
            version = __version__
        )

        if self.use_ema:
@@ -284,6 +288,11 @@ class SoundStreamTrainer(nn.Module):
                self.ema_soundstream.ema_model.load_state_dict(pkg)
            return

        # check version

        if 'version' in pkg and version.parse(pkg['version']) < version.parse(__version__):
            print(f'model was trained on older version {pkg["version"]} of audiolm-pytorch')

        # otherwise load things normally

        self.unwrapped_soundstream.load_state_dict(pkg['model'])
@@ -606,7 +615,8 @@ class SemanticTransformerTrainer(nn.Module):
    def save(self, path):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.transformer),
            optim = self.optim.state_dict()
            optim = self.optim.state_dict(),
            version = __version__
        )
        torch.save(pkg, path)

@@ -615,6 +625,11 @@ class SemanticTransformerTrainer(nn.Module):
        assert path.exists()
        pkg = torch.load(str(path), map_location = 'cpu')

        # check version

        if 'version' in pkg and version.parse(pkg['version']) < version.parse(__version__):
            print(f'model was trained on older version {pkg["version"]} of audiolm-pytorch')

        transformer = self.accelerator.unwrap_model(self.transformer)
        transformer.load_state_dict(pkg['model'])
        self.optim.load_state_dict(pkg['optim'])
@@ -849,7 +864,8 @@ class CoarseTransformerTrainer(nn.Module):
    def save(self, path):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.transformer),
            optim = self.optim.state_dict()
            optim = self.optim.state_dict(),
            version = __version__
        )
        torch.save(pkg, path)

@@ -858,6 +874,11 @@ class CoarseTransformerTrainer(nn.Module):
        assert path.exists()
        pkg = torch.load(str(path), map_location = 'cpu')

        # check version

        if 'version' in pkg and version.parse(pkg['version']) < version.parse(__version__):
            print(f'model was trained on older version {pkg["version"]} of audiolm-pytorch')

        transformer = self.accelerator.unwrap_model(self.transformer)
        transformer.load_state_dict(pkg['model'])

@@ -1087,7 +1108,8 @@ class FineTransformerTrainer(nn.Module):
    def save(self, path):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.transformer),
            optim = self.optim.state_dict()
            optim = self.optim.state_dict(),
            version = __version__
        )
        torch.save(pkg, path)

@@ -1096,6 +1118,11 @@ class FineTransformerTrainer(nn.Module):
        assert path.exists()
        pkg = torch.load(str(path), map_location = 'cpu')

        # check version

        if 'version' in pkg and version.parse(pkg['version']) < version.parse(__version__):
            print(f'model was trained on older version {pkg["version"]} of audiolm-pytorch')

        transformer = self.accelerator.unwrap_model(self.transformer)
        transformer.load_state_dict(pkg['model'])

+1 −0
Original line number Diff line number Diff line
__version__ = '0.15.4'
+2 −1
Original line number Diff line number Diff line
from setuptools import setup, find_packages
exec(open('audiolm_pytorch/version.py').read())

setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.15.3',
  version = __version__,
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',