Commit 8d8b255c authored by Phil Wang's avatar Phil Wang
Browse files

release saving of soundstream configurations within checkpoint, and...

release saving of soundstream configurations within checkpoint, and reconstitution with a classmethod
parent 5889e518
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -40,6 +40,8 @@ $ pip install audiolm-pytorch

## Usage

### SoundStream

First, `SoundStream` needs to be trained on a large corpus of audio data.

```python
@@ -79,6 +81,16 @@ soundstream = AudioLMSoundStream(...) # say you want the hyperparameters as in A
# rest is the same as above
```

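As of version `0.17.0`, you can call the `init_and_load_from` class method on `SoundStream` to rebuild a model directly from a checkpoint file, without having to remember the configuration it was created with. The configuration is stored in the checkpoint as a pickle, so only load checkpoints from sources you trust.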

```python
from audiolm_pytorch import SoundStream

soundstream = SoundStream.init_and_load_from('./path/to/checkpoint.pt')
```

### Hierarchical Transformers

Then three separate transformers (`SemanticTransformer`, `CoarseTransformer`, `FineTransformer`) need to be trained.


+49 −27
Original line number Diff line number Diff line
@@ -28,6 +28,8 @@ from audiolm_pytorch.version import __version__
from packaging import version
parsed_version = version.parse(__version__)

import pickle

# helper functions

def exists(val):
@@ -385,6 +387,16 @@ class SoundStream(nn.Module):
        attn_depth = 1
    ):
        super().__init__()

        # for autosaving the config

        _locals = locals()
        _locals.pop('self', None)
        _locals.pop('__class__', None)
        self._configs = pickle.dumps(_locals)

        # rest of the class

        self.target_sample_hz = target_sample_hz # for resampling on the fly

        self.single_channel = input_channels == 1
@@ -510,15 +522,29 @@ class SoundStream(nn.Module):
        path = Path(path)
        pkg = dict(
            model = self.state_dict(),
            config = self._configs,
            version = __version__
        )

        torch.save(pkg, str(path))

    @classmethod
    def init_and_load_from(cls, path, strict = True):
        """
        Build a new instance from the configuration stored inside a checkpoint,
        then load that same checkpoint into it.

        path   - location of a checkpoint file whose package has a `config` entry
                 (the pickled constructor keyword arguments)
        strict - passed through unchanged to `load`

        Returns the constructed instance.
        Raises AssertionError if the file does not exist or holds no `config` entry.
        """
        path = Path(path)
        assert path.exists(), f'checkpoint not found at {str(path)}'
        pkg = torch.load(str(path), map_location = 'cpu')

        assert 'config' in pkg, 'model configs were not found in this saved checkpoint'

        # SECURITY: unpickling can execute arbitrary code embedded in the file -
        # only call this on checkpoints that come from a trusted source

        config = pickle.loads(pkg['config'])
        soundstream = cls(**config)

        # `load` takes a path rather than the already-read package, so the file is read a second time here

        soundstream.load(path, strict = strict)
        return soundstream

    def load(self, path, strict = True):
        path = Path(path)
        assert path.exists()
        pkg = torch.load(str(path))
        pkg = torch.load(str(path), map_location = 'cpu')

        # check version

@@ -724,30 +750,26 @@ class SoundStream(nn.Module):

# some default soundstreams

class AudioLMSoundStream(SoundStream):
    def __init__(
        self,
def AudioLMSoundStream(
    strides = (2, 4, 5, 8),
    target_sample_hz = 16000,
    rq_num_quantizers = 12,
    **kwargs
):
        super().__init__(
    return SoundStream(
        strides = strides,
        target_sample_hz = target_sample_hz,
        rq_num_quantizers = rq_num_quantizers,
        **kwargs
    )

class MusicLMSoundStream(SoundStream):
    def __init__(
        self,
def MusicLMSoundStream(
    strides = (3, 4, 5, 8),
    target_sample_hz = 24000,
    rq_num_quantizers = 12,
    **kwargs
):
        super().__init__(
    return SoundStream(
        strides = strides,
        target_sample_hz = target_sample_hz,
        rq_num_quantizers = rq_num_quantizers,
+1 −0
Original line number Diff line number Diff line
@@ -257,6 +257,7 @@ class SoundStreamTrainer(nn.Module):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.soundstream),
            optim = self.optim.state_dict(),
            config = self.soundstream._configs,
            discr_optim = self.discr_optim.state_dict(),
            version = __version__
        )
+1 −1
Original line number Diff line number Diff line
__version__ = '0.16.2'
__version__ = '0.17.0'