Commit 08ef21e7 authored by Leon Wu's avatar Leon Wu
Browse files

Start working on replacing hubert and soundstream (see...

parent 1bac3bdc
Loading
Loading
Loading
Loading
+158 −0
Original line number Diff line number Diff line
%% Cell type:code id: tags:

``` 
# Setup cell: run AudioLM against a patched audiolm_pytorch in which soundstream etc. are replaced -- see changes starting with
# https://github.com/LWprogramming/audiolm-pytorch/commit/37be3c512cadeecab2184def3dd9bb12171b1bca

!rm -rf audiolm_pytorch/ setup.py audiolm-pytorch.zip # clean out any old stuff floating around

!pip install torch datasets boto3

# download audiolm_pytorch manually (rather than relying on an installed release) so i can inject print statements
# !pip uninstall -y audiolm_pytorch

# Deliberate stop: the cell halts here as a reminder, so nothing below runs
# until this line is removed or commented out.
raise AssertionError("don't forget to upload the customized version of audiolm_pytorch with print statements")

# !zip -r audiolm_pytorch.zip audiolm_pytorch/

import urllib.request
import os
import zipfile
# Fetch the personal_hacks branch archive unless a copy is already on disk.
if not os.path.isfile("audiolm-pytorch.zip"):
  urllib.request.urlretrieve("https://github.com/LWProgramming/audiolm-pytorch/archive/refs/heads/personal_hacks.zip", "audiolm-pytorch.zip")
# Unpack into ./audiolm-pytorch (hyphenated name) unless that directory already exists.
if not os.path.isdir("audiolm-pytorch"):
  with zipfile.ZipFile("audiolm-pytorch.zip", 'r') as zip_ref:
    zip_ref.extractall("audiolm-pytorch")
# Move the library package itself (underscore name) out of the extracted tree into the working directory.
!mv audiolm-pytorch/audiolm-pytorch-personal_hacks/audiolm_pytorch .

# install necessary files for patched audiolm-pytorch
!mv audiolm-pytorch/audiolm-pytorch-personal_hacks/setup.py .
!pip install . # install requirements from the patched audiolm-pytorch dir
!rm -rf audiolm-pytorch # not the one with underscore which is the actual library
```

%% Cell type:markdown id: tags:

# Semantic (MERT)

%% Cell type:code id: tags:

``` 
# Semantic stage: build the HubertWithKmeans wrapper (with use_mert = True) and
# the SemanticTransformer sized from its codebook.
# NOTE(review): use_mert is a flag of the patched audiolm_pytorch; presumably it
# swaps the HuBERT backbone for MERT -- confirm against the patched library.
import torch
from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerWrapper, SemanticTransformerTrainer
import os
# Import the submodule explicitly: a bare `import urllib` does not guarantee
# that `urllib.request` is loaded, so this cell would otherwise depend on
# another cell having imported it first.
import urllib.request

# hubert checkpoints can be downloaded at
# https://github.com/facebookresearch/fairseq/tree/main/examples/hubert

hubert_ckpt = 'hubert/hubert_base_ls960.pt'
hubert_quantizer = 'hubert/hubert_base_ls960_L9_km500.bin' # listed in row "HuBERT Base (~95M params)", column Quantizer

# Create the local download directory; exist_ok makes re-running the cell safe.
os.makedirs("hubert", exist_ok=True)

# Download each artifact only if it is not already on disk. The remote URL
# mirrors the local relative path under dl.fbaipublicfiles.com.
# NOTE(review): the checkpoint is still downloaded even though checkpoint_path
# is None below -- confirm whether it is still needed.
if not os.path.isfile(hubert_ckpt):
    hubert_ckpt_download = f"https://dl.fbaipublicfiles.com/{hubert_ckpt}"
    urllib.request.urlretrieve(hubert_ckpt_download, f"./{hubert_ckpt}")
if not os.path.isfile(hubert_quantizer):
    hubert_quantizer_download = f"https://dl.fbaipublicfiles.com/{hubert_quantizer}"
    urllib.request.urlretrieve(hubert_quantizer_download, f"./{hubert_quantizer}")

wav2vec = HubertWithKmeans(
    # checkpoint_path = './hubert/hubert_base_ls960.pt',
    checkpoint_path = None,
    # Reuse the path defined above so the download target and the loaded file
    # cannot drift apart (same value as the previous literal).
    kmeans_path = f"./{hubert_quantizer}",
    use_mert = True
)

# The transformer's token vocabulary size is taken from the wav2vec codebook.
semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6
)

```

%% Cell type:markdown id: tags:

# Coarse (encodec)

%% Cell type:code id: tags:

``` 
# Coarse stage: build the CoarseTransformer, taking the semantic vocabulary
# size from the wav2vec object created in the semantic cell.
from audiolm_pytorch import CoarseTransformer

assert 'wav2vec' in locals() # fail early if the semantic cell (which defines wav2vec) has not been run

coarse_transformer = CoarseTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    codebook_size = 1024,
    num_coarse_quantizers = 3,
    dim = 512,
    depth = 6
)
```

%% Cell type:markdown id: tags:

# Fine (encodec)

%% Cell type:code id: tags:

``` 
# Fine stage: build the FineTransformer.
# NOTE(review): num_coarse_quantizers (3) and codebook_size (1024) equal the
# values passed to CoarseTransformer in the coarse cell; presumably they must
# stay in sync -- confirm before changing either.
from audiolm_pytorch import FineTransformer

fine_transformer = FineTransformer(
    num_coarse_quantizers = 3,
    num_fine_quantizers = 5,
    codebook_size = 1024,
    dim = 512,
    depth = 6
)
```

%% Cell type:markdown id: tags:

# Generate

%% Cell type:code id: tags:

``` 
# Generation: assemble the three transformers plus the wav2vec and soundstream
# objects into an AudioLM, move it to the GPU, and generate one sample.
# NOTE(review): `soundstream` is not assigned in any cell shown here, so this
# cell raises NameError unless it is defined elsewhere -- confirm which codec
# object is meant to be passed.
from audiolm_pytorch import AudioLM

audiolm = AudioLM(
    wav2vec = wav2vec,
    soundstream = soundstream,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
).cuda() # TODO: possibly things are already cuda but I'm not sure

# Generate a single sample (batch_size = 1).
generated_wav = audiolm(batch_size = 1)
```

%% Output

    ---------------------------------------------------------------------------
    AssertionError                            Traceback (most recent call last)
    <ipython-input-2-a871fdc9ebee> in <module>
    ----> 1 assert False

    AssertionError:

%% Cell type:code id: tags:

``` 
# Scratch check: with both flags true, the negated conjunction displays False.
x, y = True, True
(not x) or (not y)
```

%% Output

    False

%% Cell type:code id: tags:

``` 
```