Commit 310deea3 authored by Leon Wu's avatar Leon Wu
Browse files

fix readme and demo jupyter notebook

parent 2e301432
Loading
Loading
Loading
Loading
+12 −5
Original line number Diff line number Diff line
@@ -42,9 +42,16 @@ $ pip install audiolm-pytorch

## Usage

### SoundStream
### SoundStream & Encodec

First, `SoundStream` needs to be trained on a large corpus of audio data
There are two options for the neural codec. If you want to use the pretrained 24kHz Encodec, just create an Encodec object as follows:
```python
from audiolm_pytorch import EncodecWrapper
encodec = EncodecWrapper()
# Now you can use the encodec variable in the same way you'd use the soundstream variables below.
```

Otherwise, to stay more true to the original paper, you can use `SoundStream`. First, `SoundStream` needs to be trained on a large corpus of audio data

```python
from audiolm_pytorch import SoundStream, SoundStreamTrainer
@@ -152,7 +159,7 @@ coarse_transformer = CoarseTransformer(

trainer = CoarseTransformerTrainer(
    transformer = coarse_transformer,
    soundstream = soundstream,
    codec = soundstream,
    wav2vec = wav2vec,
    folder = '/path/to/audio/files',
    batch_size = 1,
@@ -181,7 +188,7 @@ fine_transformer = FineTransformer(

trainer = FineTransformerTrainer(
    transformer = fine_transformer,
    soundstream = soundstream,
    codec = soundstream,
    folder = '/path/to/audio/files',
    batch_size = 1,
    data_max_length = 320 * 32,
@@ -198,7 +205,7 @@ from audiolm_pytorch import AudioLM

audiolm = AudioLM(
    wav2vec = wav2vec,
    soundstream = soundstream,
    codec = soundstream,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
+15 −15
Original line number Diff line number Diff line
%% Cell type:code id: tags:

``` 
``` python
!nvidia-smi

# If this doesn't work, there's no GPU available or detected
```

%% Output

    Mon Jan 30 20:47:47 2023
    +-----------------------------------------------------------------------------+
    | NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
    |-------------------------------+----------------------+----------------------+
    | GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
    | Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
    |                               |                      |               MIG M. |
    |===============================+======================+======================|
    |   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
    | N/A   73C    P0    32W /  70W |  10692MiB / 15360MiB |      0%      Default |
    |                               |                      |                  N/A |
    +-------------------------------+----------------------+----------------------+
    
    +-----------------------------------------------------------------------------+
    | Processes:                                                                  |
    |  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
    |        ID   ID                                                   Usage      |
    |=============================================================================|
    |    0   N/A  N/A      5896      C                                   10689MiB |
    +-----------------------------------------------------------------------------+

%% Cell type:code id: tags:

``` 
``` python
!pip install audiolm-pytorch
```

%% Output

    Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
    Requirement already satisfied: audiolm-pytorch in /usr/local/lib/python3.8/dist-packages (0.7.5)
    Requirement already satisfied: ema-pytorch in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.1.4)
    Requirement already satisfied: sentencepiece in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.1.97)
    Requirement already satisfied: beartype in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.12.0)
    Requirement already satisfied: scikit-learn in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (1.0.2)
    Requirement already satisfied: torchaudio in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.13.1+cu116)
    Requirement already satisfied: joblib in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (1.2.0)
    Requirement already satisfied: torch>=1.6 in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (1.13.1+cu116)
    Requirement already satisfied: transformers in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (4.26.0)
    Requirement already satisfied: Mega-pytorch in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.0.12)
    Requirement already satisfied: tqdm in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (4.64.1)
    Requirement already satisfied: accelerate in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.15.0)
    Requirement already satisfied: vector-quantize-pytorch>=0.10.15 in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.10.15)
    Requirement already satisfied: einops>=0.6 in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.6.0)
    Requirement already satisfied: local-attention>=1.5.7 in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (1.5.8)
    Requirement already satisfied: fairseq in /usr/local/lib/python3.8/dist-packages (from audiolm-pytorch) (0.12.2)
    Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch>=1.6->audiolm-pytorch) (4.4.0)
    Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from accelerate->audiolm-pytorch) (21.3)
    Requirement already satisfied: pyyaml in /usr/local/lib/python3.8/dist-packages (from accelerate->audiolm-pytorch) (6.0)
    Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from accelerate->audiolm-pytorch) (1.21.6)
    Requirement already satisfied: psutil in /usr/local/lib/python3.8/dist-packages (from accelerate->audiolm-pytorch) (5.4.8)
    Requirement already satisfied: bitarray in /usr/local/lib/python3.8/dist-packages (from fairseq->audiolm-pytorch) (2.6.2)
    Requirement already satisfied: regex in /usr/local/lib/python3.8/dist-packages (from fairseq->audiolm-pytorch) (2022.6.2)
    Requirement already satisfied: hydra-core<1.1,>=1.0.7 in /usr/local/lib/python3.8/dist-packages (from fairseq->audiolm-pytorch) (1.0.7)
    Requirement already satisfied: cffi in /usr/local/lib/python3.8/dist-packages (from fairseq->audiolm-pytorch) (1.15.1)
    Requirement already satisfied: cython in /usr/local/lib/python3.8/dist-packages (from fairseq->audiolm-pytorch) (0.29.33)
    Requirement already satisfied: omegaconf<2.1 in /usr/local/lib/python3.8/dist-packages (from fairseq->audiolm-pytorch) (2.0.6)
    Requirement already satisfied: sacrebleu>=1.4.12 in /usr/local/lib/python3.8/dist-packages (from fairseq->audiolm-pytorch) (2.3.1)
    Requirement already satisfied: scipy in /usr/local/lib/python3.8/dist-packages (from Mega-pytorch->audiolm-pytorch) (1.7.3)
    Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.8/dist-packages (from scikit-learn->audiolm-pytorch) (3.1.0)
    Requirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from transformers->audiolm-pytorch) (2.25.1)
    Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /usr/local/lib/python3.8/dist-packages (from transformers->audiolm-pytorch) (0.12.0)
    Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.8/dist-packages (from transformers->audiolm-pytorch) (0.13.2)
    Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from transformers->audiolm-pytorch) (3.9.0)
    Requirement already satisfied: antlr4-python3-runtime==4.8 in /usr/local/lib/python3.8/dist-packages (from hydra-core<1.1,>=1.0.7->fairseq->audiolm-pytorch) (4.8)
    Requirement already satisfied: importlib-resources in /usr/local/lib/python3.8/dist-packages (from hydra-core<1.1,>=1.0.7->fairseq->audiolm-pytorch) (5.10.2)
    Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /usr/local/lib/python3.8/dist-packages (from packaging>=20.0->accelerate->audiolm-pytorch) (3.0.9)
    Requirement already satisfied: portalocker in /usr/local/lib/python3.8/dist-packages (from sacrebleu>=1.4.12->fairseq->audiolm-pytorch) (2.7.0)
    Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.8/dist-packages (from sacrebleu>=1.4.12->fairseq->audiolm-pytorch) (0.8.10)
    Requirement already satisfied: colorama in /usr/local/lib/python3.8/dist-packages (from sacrebleu>=1.4.12->fairseq->audiolm-pytorch) (0.4.6)
    Requirement already satisfied: lxml in /usr/local/lib/python3.8/dist-packages (from sacrebleu>=1.4.12->fairseq->audiolm-pytorch) (4.9.2)
    Requirement already satisfied: pycparser in /usr/local/lib/python3.8/dist-packages (from cffi->fairseq->audiolm-pytorch) (2.21)
    Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests->transformers->audiolm-pytorch) (2.10)
    Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests->transformers->audiolm-pytorch) (2022.12.7)
    Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->transformers->audiolm-pytorch) (1.24.3)
    Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests->transformers->audiolm-pytorch) (4.0.0)
    Requirement already satisfied: zipp>=3.1.0 in /usr/local/lib/python3.8/dist-packages (from importlib-resources->hydra-core<1.1,>=1.0.7->fairseq->audiolm-pytorch) (3.11.0)

%% Cell type:markdown id: tags:

## Setup

Includes:

- How to generate a placeholder dataset if you haven't already, just the basics to run "training" e2e on a tiny dataset
- How to download a dataset from OpenSLR

%% Cell type:markdown id: tags:

### Imports & paths

%% Cell type:code id: tags:

``` 
``` python
# imports
import math
import wave
import struct
import os
import urllib.request
import tarfile
from audiolm_pytorch import SoundStream, SoundStreamTrainer, HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer, HubertWithKmeans, CoarseTransformer, CoarseTransformerWrapper, CoarseTransformerTrainer, FineTransformer, FineTransformerWrapper, FineTransformerTrainer, AudioLM
from torch import nn
import torch
import torchaudio


# define all dataset paths, checkpoints, etc
dataset_folder = "placeholder_dataset"
soundstream_ckpt = "results/soundstream.8.pt" # this can change depending on number of steps
hubert_ckpt = 'hubert/hubert_base_ls960.pt'
hubert_quantizer = f'hubert/hubert_base_ls960_L9_km500.bin' # listed in row "HuBERT Base (~95M params)", column Quantizer
```

%% Cell type:markdown id: tags:

### Data

%% Cell type:code id: tags:

``` 
``` python
# Placeholder data generation
def get_sinewave(freq=440.0, duration_ms=200, volume=1.0, sample_rate=44100.0):
  # code adapted from https://stackoverflow.com/a/33913403
  audio = []
  num_samples = duration_ms * (sample_rate / 1000.0)
  for x in range(int(num_samples)):
    audio.append(volume * math.sin(2 * math.pi * freq * (x / sample_rate)))
  return audio

def save_wav(file_name, audio, sample_rate=44100.0):
  # Open up a wav file
  wav_file=wave.open(file_name,"w")
  # wav params
  nchannels = 1
  sampwidth = 2
  # 44100 is the industry standard sample rate - CD quality.  If you need to
  # save on file size you can adjust it downwards. The stanard for low quality
  # is 8000 or 8kHz.
  nframes = len(audio)
  comptype = "NONE"
  compname = "not compressed"
  wav_file.setparams((nchannels, sampwidth, sample_rate, nframes, comptype, compname))
  # WAV files here are using short, 16 bit, signed integers for the
  # sample size.  So we multiply the floating point data we have by 32767, the
  # maximum value for a short integer.  NOTE: It is theortically possible to
  # use the floating point -1.0 to 1.0 data directly in a WAV file but not
  # obvious how to do that using the wave module in python.
  for sample in audio:
      wav_file.writeframes(struct.pack('h', int( sample * 32767.0 )))
  wav_file.close()
  return

def make_placeholder_dataset():
  # Make a placeholder dataset with a few .wav files that you can "train" on, just to verify things work e2e
  if os.path.isdir(dataset_folder):
    return
  os.makedirs(dataset_folder)
  save_wav(f"{dataset_folder}/example.wav", get_sinewave())
  save_wav(f"{dataset_folder}/example2.wav", get_sinewave(duration_ms=500))
  os.makedirs(f"{dataset_folder}/subdirectory")
  save_wav(f"{dataset_folder}/subdirectory/example.wav", get_sinewave(freq=330.0))

make_placeholder_dataset()
```

%% Cell type:code id: tags:

``` 
``` python
# Get actual dataset. Uncomment this if you want to try training on real data

# full dataset: https://www.openslr.org/12
# We'll use https://us.openslr.org/resources/12/dev-clean.tar.gz development set, "clean" speech.
# We *should* train on, well, training, but this is just to demo running things end-to-end at all so I just picked a small clean set.

# url = "https://us.openslr.org/resources/12/dev-clean.tar.gz"
# filename = "dev-clean"
# filename_targz = filename + ".tar.gz"
# if not os.path.isfile(filename_targz):
#   urllib.request.urlretrieve(url, filename_targz)
# if not os.path.isdir(filename):
#   # open file
#   with tarfile.open(filename_targz) as t:
#     t.extractall(filename)
```

%% Cell type:markdown id: tags:

## Training

Now that we have a dataset, we can train AudioLM.

**Note**: do NOT type "y" to overwrite previous experiments/ checkpoints when running through the cells here unless you're ready to the entire results folder! Otherwise you will end up erasing things (e.g. you train SoundStream first, and if you choose "overwrite" then you lose the SoundStream checkpoint when you then train SemanticTransformer).

%% Cell type:markdown id: tags:

### SoundStream

%% Cell type:code id: tags:

``` 
``` python
soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

trainer = SoundStreamTrainer(
    soundstream,
    folder = dataset_folder,
    batch_size = 4,
    grad_accum_every = 8,         # effective batch size of 32
    data_max_length = 320 * 32,
    save_results_every = 2,
    save_model_every = 4,
    num_train_steps = 9
).cuda()
# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes
# adjusting save_*_every variables for the same reason

trainer.train()
```

%% Output

    training with dataset of 2 samples and validating with randomly splitted 1 samples
    0: soundstream total loss: 167.262, soundstream recon loss: 1.123 | discr (scale 1) loss: 2.003 | discr (scale 0.5) loss: 1.999 | discr (scale 0.25) loss: 1.999
    0: saving to results
    0: saving model to results
    1: soundstream total loss: 182.282, soundstream recon loss: 1.389 | discr (scale 1) loss: 1.938 | discr (scale 0.5) loss: 1.928 | discr (scale 0.25) loss: 1.928
    2: soundstream total loss: 196.668, soundstream recon loss: 1.450 | discr (scale 1) loss: 1.845 | discr (scale 0.5) loss: 1.842 | discr (scale 0.25) loss: 1.843
    2: saving to results
    3: soundstream total loss: 216.329, soundstream recon loss: 1.451 | discr (scale 1) loss: 1.751 | discr (scale 0.5) loss: 1.750 | discr (scale 0.25) loss: 1.757
    4: soundstream total loss: 206.804, soundstream recon loss: 1.167 | discr (scale 1) loss: 1.671 | discr (scale 0.5) loss: 1.706 | discr (scale 0.25) loss: 1.724
    4: saving to results
    4: saving model to results
    5: soundstream total loss: 195.325, soundstream recon loss: 0.929 | discr (scale 1) loss: 1.348 | discr (scale 0.5) loss: 1.372 | discr (scale 0.25) loss: 1.482
    6: soundstream total loss: 245.195, soundstream recon loss: 1.054 | discr (scale 1) loss: 1.060 | discr (scale 0.5) loss: 1.244 | discr (scale 0.25) loss: 1.288
    6: saving to results
    7: soundstream total loss: 245.724, soundstream recon loss: 0.970 | discr (scale 1) loss: 1.092 | discr (scale 0.5) loss: 1.358 | discr (scale 0.25) loss: 1.079
    8: soundstream total loss: 202.707, soundstream recon loss: 0.786 | discr (scale 1) loss: 0.733 | discr (scale 0.5) loss: 0.687 | discr (scale 0.25) loss: 0.790
    8: saving to results
    8: saving model to results
    training complete

%% Cell type:markdown id: tags:

### SemanticTransformer

%% Cell type:code id: tags:

``` 
``` python
# hubert checkpoints can be downloaded at
# https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
if not os.path.isdir("hubert"):
  os.makedirs("hubert")
if not os.path.isfile(hubert_ckpt):
  hubert_ckpt_download = f"https://dl.fbaipublicfiles.com/{hubert_ckpt}"
  urllib.request.urlretrieve(hubert_ckpt_download, f"./{hubert_ckpt}")
if not os.path.isfile(hubert_quantizer):
  hubert_quantizer_download = f"https://dl.fbaipublicfiles.com/{hubert_quantizer}"
  urllib.request.urlretrieve(hubert_quantizer_download, f"./{hubert_quantizer}")

wav2vec = HubertWithKmeans(
    checkpoint_path = f'./{hubert_ckpt}',
    kmeans_path = f'./{hubert_quantizer}'
)

semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6
).cuda()


trainer = SemanticTransformerTrainer(
    transformer = semantic_transformer,
    wav2vec = wav2vec,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1
)

trainer.train()
```

%% Output

    /usr/local/lib/python3.8/dist-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator MiniBatchKMeans from version 0.24.0 when using version 1.0.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
    https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations
      warnings.warn(

    training with dataset of 2 samples and validating with randomly splitted 1 samples
    do you want to clear previous experiment checkpoints and results? (y/n) n
    0: loss: 6.648584365844727
    0: valid loss 5.763116359710693
    0: saving model to results
    training complete

%% Cell type:markdown id: tags:

### CoarseTransformer

%% Cell type:code id: tags:

``` 
``` python
wav2vec = HubertWithKmeans(
    checkpoint_path = f'./{hubert_ckpt}',
    kmeans_path = f'./{hubert_quantizer}'
)

soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

soundstream.load(f"./{soundstream_ckpt}")

coarse_transformer = CoarseTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    codebook_size = 1024,
    num_coarse_quantizers = 3,
    dim = 512,
    depth = 6
)

trainer = CoarseTransformerTrainer(
    transformer = coarse_transformer,
    soundstream = soundstream,
    codec = soundstream,
    wav2vec = wav2vec,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    save_results_every = 2,
    save_model_every = 4,
    num_train_steps = 9
)
# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes
# adjusting save_*_every variables for the same reason

trainer.train()
```

%% Output

    /usr/local/lib/python3.8/dist-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator MiniBatchKMeans from version 0.24.0 when using version 1.0.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
    https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations
      warnings.warn(

    training with dataset of 2 samples and validating with randomly splitted 1 samples
    do you want to clear previous experiment checkpoints and results? (y/n) n
    0: loss: 63.983970642089844
    0: valid loss 63.398582458496094
    0: saving model to results
    1: loss: 65.85967254638672
    2: loss: 62.4722900390625
    2: valid loss 50.01605987548828
    3: loss: 11.735434532165527
    4: loss: 3.976104497909546
    4: valid loss 46.094608306884766
    4: saving model to results
    5: loss: 58.27140426635742
    6: loss: 41.68347930908203
    6: valid loss 45.54595184326172
    7: loss: 2.2387890815734863
    8: loss: 0.4718627631664276
    8: valid loss 39.10848617553711
    8: saving model to results
    training complete

%% Cell type:markdown id: tags:

### FineTransformer

%% Cell type:code id: tags:

``` 
``` python
soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

soundstream.load(f"./{soundstream_ckpt}")

fine_transformer = FineTransformer(
    num_coarse_quantizers = 3,
    num_fine_quantizers = 5,
    codebook_size = 1024,
    dim = 512,
    depth = 6
)

trainer = FineTransformerTrainer(
    transformer = fine_transformer,
    soundstream = soundstream,
    codec = soundstream,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 9
)
# NOTE: I changed num_train_steps to 9 (aka 8 + 1) from 10000 to make things go faster for demo purposes
# adjusting save_*_every variables for the same reason

trainer.train()
```

%% Output

    training with dataset of 2 samples and validating with randomly splitted 1 samples
    do you want to clear previous experiment checkpoints and results? (y/n) n
    0: loss: 70.90608215332031
    0: valid loss 65.99951171875
    0: saving model to results
    1: loss: 43.6014289855957
    2: loss: 8.300681114196777
    3: loss: 61.23375701904297
    4: loss: 63.34052276611328
    5: loss: 2.010118246078491
    6: loss: 56.52588653564453
    7: loss: 0.5423888564109802
    8: loss: 0.005095238331705332
    training complete

%% Cell type:markdown id: tags:

## Inference

%% Cell type:code id: tags:

``` 
``` python
# Everything together
audiolm = AudioLM(
    wav2vec = wav2vec,
    soundstream = soundstream,
    codec = soundstream,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
)

generated_wav = audiolm(batch_size = 1)
```

%% Output

    generating semantic:   0%|          | 10/2048 [00:00<00:25, 78.55it/s]
    generating coarse: 100%|██████████| 512/512 [00:14<00:00, 34.83it/s]
    generating fine: 100%|██████████| 512/512 [02:56<00:00,  2.91it/s]

%% Cell type:code id: tags:

``` 
``` python
output_path = "out.wav"
sample_rate = 44100
torchaudio.save(output_path, generated_wav.cpu(), sample_rate)
```

%% Cell type:code id: tags:

``` 
``` python
```