Commit 6a0c7e73 authored by Phil Wang's avatar Phil Wang
Browse files

get some working scripts into readme

parent 416aaf7a
Loading
Loading
Loading
Loading
+101 −3
Original line number Diff line number Diff line
@@ -14,15 +14,110 @@ $ pip install audiolm-pytorch

## Usage

First, `SoundStream` needs to be trained on a large corpus of audio data

```python
from audiolm_pytorch import SoundStream, SoundStreamTrainer

soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

trainer = SoundStreamTrainer(
    soundstream,
    folder = '/path/to/librispeech',
    batch_size = 4,
    data_max_length = 320 * 32,
    num_train_steps = 10000
).cuda()

trainer.train()
```

Then three separate transformers (`SemanticTransformer`, `CoarseTransformer`, `FineTransformer`) need to be trained


ex. `SemanticTransformer`

```python
import torch
from audiolm_pytorch import HubertWithKmeans, SemanticTransformer

wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
    kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)

semantic_transformer = SemanticTransformer(
    wav2vec = wav2vec,
    dim = 1024,
    depth = 6
).cuda()

wave = torch.randn(1, 320 * 512).cuda()

loss = semantic_transformer(
    raw_wave = wave,
    return_loss = True
)

loss.backward()
```

ex. `CoarseTransformer`

```python
import torch
from audiolm_pytorch import HubertWithKmeans, SoundStream, CoarseTransformer, CoarseTransformerWrapper

wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
    kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)

soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

coarse_transformer = CoarseTransformer(
    wav2vec = wav2vec,
    codebook_size = 1024,
    num_coarse_quantizers = 3,
    dim = 512,
    depth = 6
)

coarse_wrapper = CoarseTransformerWrapper(
    wav2vec = wav2vec,
    soundstream = soundstream,
    transformer = coarse_transformer
).cuda()

wave = torch.randn(1, 32 * 320).cuda()

loss = coarse_wrapper(
    raw_wave = wave,
    return_loss = True
)

loss.backward()
```

ex. `FineTransformer`

```python
import torch
from audiolm_pytorch.audiolm_pytorch import SoundStream, AudioLM, FineTransformer, FineTransformerWrapper
from audiolm_pytorch import SoundStream, FineTransformer, FineTransformerWrapper

soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

soundstream.load('/path/to/trained/soundstream.pt')

transformer = FineTransformer(
    num_coarse_quantizers = 3,
    num_fine_quantizers = 5,
@@ -36,16 +131,18 @@ train_wrapper = FineTransformerWrapper(
    transformer = transformer
).cuda()

raw_waveform = torch.randn(1, 320 * 512).cuda()
wave = torch.randn(1, 320 * 512).cuda()

loss = train_wrapper(
    raw_wave = raw_waveform,
    raw_wave = wave,
    return_loss = True
)

loss.backward()
```

- [ ] show how to generate from prompt tensor or file

## Appreciation

- <a href="https://stability.ai/">Stability.ai</a> for the generous sponsorship to work and open source cutting edge artificial intelligence research
@@ -79,6 +176,7 @@ loss.backward()
- [ ] abstract out conditioning + classifier free guidance into external module or potentially a package
- [ ] add option to use flash attention
- [ ] function for pretty printing all discriminator losses to log
- [ ] simplify training even more within AudioLM class

## Citations

+14 −2
Original line number Diff line number Diff line
@@ -267,8 +267,8 @@ class SemanticTransformer(nn.Module):
    def __init__(
        self,
        *,
        num_semantic_tokens,
        dim,
        num_semantic_tokens = None,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        cond_drop_prob = 0.5,
@@ -279,6 +279,12 @@ class SemanticTransformer(nn.Module):
        **kwargs
    ):
        super().__init__()
        assert exists(wav2vec) or exists(num_semantic_tokens)

        if exists(wav2vec):
            num_semantic_tokens = default(num_semantic_tokens, wav2vec.codebook_size)
            assert num_semantic_tokens == wav2vec.codebook_size

        self.has_condition = has_condition
        self.embed_text = partial(t5_encode_text, name = t5_name)
        self.cond_drop_prob = cond_drop_prob
@@ -362,10 +368,10 @@ class CoarseTransformer(nn.Module):
    def __init__(
        self,
        *,
        num_semantic_tokens,
        codebook_size,
        num_coarse_quantizers,
        dim,
        num_semantic_tokens = None,
        t5_name = DEFAULT_T5_NAME,
        has_condition = False,
        cond_drop_prob = 0.5,
@@ -374,6 +380,12 @@ class CoarseTransformer(nn.Module):
        **kwargs
    ):
        super().__init__()
        assert exists(wav2vec) or exists(num_semantic_tokens)

        if exists(wav2vec):
            num_semantic_tokens = default(num_semantic_tokens, wav2vec.codebook_size)
            assert num_semantic_tokens == wav2vec.codebook_size

        self.has_condition = has_condition
        self.embed_text = partial(t5_encode_text, name = t5_name)
        self.cond_drop_prob = cond_drop_prob
+6 −0
Original line number Diff line number Diff line
import functools
from pathlib import Path
from functools import partial

import torch
@@ -309,6 +310,11 @@ class SoundStream(nn.Module):
        self.adversarial_loss_weight = adversarial_loss_weight
        self.feature_loss_weight = feature_loss_weight

    def load(self, path):
        path = Path(path)
        assert path.exists()
        self.load_state_dict(torch.load(str(path)))

    def non_discr_parameters(self):
        return [*self.encoder.parameters(), *self.decoder.parameters()]

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.26',
  version = '0.0.27',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',