Commit 5889e518 authored by Phil Wang's avatar Phil Wang
Browse files

make sure soundstream load function loads the ema version if available, also...

make sure soundstream load function loads the ema version if available, also allow for turning of strict loading, so if network has changed, it can be loaded and retrained for an epoch thus salvaged
parent 21a07ab9
Loading
Loading
Loading
Loading
+15 −2
Original line number Diff line number Diff line
@@ -39,6 +39,12 @@ def default(val, d):
def cast_tuple(t, l = 1):
    return ((t,) * l) if not isinstance(t, tuple) else t

def filter_by_keys(fn, d):
    return {k: v for k, v in d.items() if fn(k)}

def map_keys(fn, d):
    return {fn(k): v for k, v in d.items()}

# gan losses

def log(t, eps = 1e-20):
@@ -509,7 +515,7 @@ class SoundStream(nn.Module):

        torch.save(pkg, str(path))

    def load(self, path):
    def load(self, path, strict = True):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path))
@@ -519,7 +525,14 @@ class SoundStream(nn.Module):
        if 'version' in pkg and version.parse(pkg['version']) < parsed_version:
            print(f'soundstream model being loaded was trained on an older version of audiolm-pytorch ({pkg["version"]})')

        self.load_state_dict(pkg['model'])
        has_ema = 'ema_model' in pkg
        model_pkg = pkg['ema_model'] if has_ema else pkg['model']

        if has_ema:
            model_pkg = filter_by_keys(lambda k: k.startswith('ema_model.'), model_pkg)
            model_pkg = map_keys(lambda k: k[len('ema_model.'):], model_pkg)

        self.load_state_dict(model_pkg, strict = strict)

    def load_from_trainer_saved_obj(self, path):
        path = Path(path)
+1 −1
Original line number Diff line number Diff line
__version__ = '0.16.1'
__version__ = '0.16.2'