Commit f2c2ec24 authored by Leon Wu's avatar Leon Wu
Browse files

Fix branch

parent 3518d5a6
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
%% Cell type:code id: tags:

``` 
# pretrained substitute
# Encodec as a replacement for SoundStream, and MERT as a replacement for w2v-BERT.
# idea from https://github.com/zhvng/open-musiclm
```

%% Cell type:code id: tags:

``` 
!pip install transformers torch datasets

# download audiolm_pytorch manually so i can inject print statements
# !pip uninstall -y audiolm_pytorch
import urllib.request
import os
import zipfile
if not os.path.isfile("audiolm-pytorch.zip"):
  urllib.request.urlretrieve("https://github.com/lucidrains/audiolm-pytorch/archive/refs/heads/main.zip", "audiolm-pytorch.zip")
if not os.path.isdir("audiolm-pytorch"):
  with zipfile.ZipFile("audiolm-pytorch.zip", 'r') as zip_ref:
    zip_ref.extractall("audiolm-pytorch")
# !mv audiolm-pytorch/audiolm-pytorch-main/audiolm_pytorch .
# !mv audiolm-pytorch/audiolm-pytorch-personal_hacks/audiolm_pytorch .
!rm -rf audiolm-pytorch # not the one with underscore which is the actual library
```

%% Output

    Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
    Requirement already satisfied: transformers in /usr/local/lib/python3.8/dist-packages (4.26.1)
    Requirement already satisfied: torch in /usr/local/lib/python3.8/dist-packages (1.13.1+cu116)
    Requirement already satisfied: datasets in /usr/local/lib/python3.8/dist-packages (2.10.0)
    Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.8/dist-packages (from transformers) (1.22.4)
    Requirement already satisfied: huggingface-hub<1.0,>=0.11.0 in /usr/local/lib/python3.8/dist-packages (from transformers) (0.12.1)
    Requirement already satisfied: requests in /usr/local/lib/python3.8/dist-packages (from transformers) (2.25.1)
    Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.8/dist-packages (from transformers) (4.64.1)
    Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.8/dist-packages (from transformers) (23.0)
    Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.8/dist-packages (from transformers) (6.0)
    Requirement already satisfied: filelock in /usr/local/lib/python3.8/dist-packages (from transformers) (3.9.0)
    Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.8/dist-packages (from transformers) (0.13.2)
    Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.8/dist-packages (from transformers) (2022.6.2)
    Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch) (4.5.0)
    Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.8/dist-packages (from datasets) (0.18.0)
    Requirement already satisfied: xxhash in /usr/local/lib/python3.8/dist-packages (from datasets) (3.2.0)
    Requirement already satisfied: aiohttp in /usr/local/lib/python3.8/dist-packages (from datasets) (3.8.4)
    Requirement already satisfied: pyarrow>=6.0.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (9.0.0)
    Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.8/dist-packages (from datasets) (2023.1.0)
    Requirement already satisfied: dill<0.3.7,>=0.3.0 in /usr/local/lib/python3.8/dist-packages (from datasets) (0.3.6)
    Requirement already satisfied: multiprocess in /usr/local/lib/python3.8/dist-packages (from datasets) (0.70.14)
    Requirement already satisfied: pandas in /usr/local/lib/python3.8/dist-packages (from datasets) (1.3.5)
    Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (22.2.0)
    Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.3.1)
    Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (4.0.2)
    Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.8.2)
    Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (1.3.3)
    Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (6.0.4)
    Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.8/dist-packages (from aiohttp->datasets) (3.0.1)
    Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (2.10)
    Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (1.26.14)
    Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (2022.12.7)
    Requirement already satisfied: chardet<5,>=3.0.2 in /usr/local/lib/python3.8/dist-packages (from requests->transformers) (4.0.0)
    Requirement already satisfied: pytz>=2017.3 in /usr/local/lib/python3.8/dist-packages (from pandas->datasets) (2022.7.1)
    Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.8/dist-packages (from pandas->datasets) (2.8.2)
    Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.8/dist-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.15.0)

%% Cell type:code id: tags:

``` 
# # semantic- MERT
# # https://huggingface.co/m-a-p/MERT-v0
# # MERT-v0 is a completely unsupervised model trained on 1000 hour music audios.

# from transformers import Wav2Vec2Processor, HubertModel
# import torch
# from torch import nn
# from datasets import load_dataset

# # load demo audio and set processor
# dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
# dataset = dataset.sort("id")
# sampling_rate = dataset.features["audio"].sampling_rate
# processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")

# # loading our model weights
# model = HubertModel.from_pretrained("m-a-p/MERT-v0")

# # audio file is decoded on the fly
# inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
# with torch.no_grad():
#     outputs = model(**inputs, output_hidden_states=True)

# # take a look at the output shape, there are 13 layers of representation
# # each layer performs differently in different downstream tasks, you should choose empirically
# all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze()
# print(all_layer_hidden_states.shape) # [13 layer, 292 timestep, 768 feature_dim]

# # # for utterance level classification tasks, you can simply reduce the representation in time
# # time_reduced_hidden_states = all_layer_hidden_states.mean(-2)
# # print(time_reduced_hidden_states.shape) # [13, 768]

# # # you can even use a learnable weighted average representation
# # aggregator = nn.Conv1d(in_channels=13, out_channels=1, kernel_size=1)
# # weighted_avg_hidden_states = aggregator(time_reduced_hidden_states.unsqueeze(0)).squeeze()
# # print(weighted_avg_hidden_states.shape) # [768]
```

%% Cell type:code id: tags:

``` 
# original semantic transformer
import torch
from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer
import os
# `import urllib` alone does not guarantee the `request` submodule is loaded;
# import it explicitly so this cell does not depend on an earlier cell.
import urllib.request

# hubert checkpoints can be downloaded at
# https://github.com/facebookresearch/fairseq/tree/main/examples/hubert

hubert_ckpt = 'hubert/hubert_base_ls960.pt'
hubert_quantizer = 'hubert/hubert_base_ls960_L9_km500.bin' # listed in row "HuBERT Base (~95M params)", column Quantizer


def _download_if_missing(relative_path):
  """Fetch `relative_path` from dl.fbaipublicfiles.com unless it already exists locally.

  The remote URL is the public host followed by the same relative path that is
  used on disk. The file is written to a temporary ".part" name and renamed
  only after the download finishes, so an interrupted download is not mistaken
  for a complete file on the next run.
  """
  if os.path.isfile(relative_path):
    return
  partial_path = f"{relative_path}.part"
  urllib.request.urlretrieve(f"https://dl.fbaipublicfiles.com/{relative_path}", partial_path)
  os.replace(partial_path, relative_path)


os.makedirs("hubert", exist_ok=True)
_download_if_missing(hubert_ckpt)
_download_if_missing(hubert_quantizer)

# Build the model from the same paths that were just downloaded, instead of
# repeating the literals.
wav2vec = HubertWithKmeans(
    checkpoint_path = f"./{hubert_ckpt}",
    kmeans_path = f"./{hubert_quantizer}"
)

# The transformer's vocabulary size is taken from the quantizer's codebook size.
semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6
).cuda()
```

%% Output

    /usr/local/lib/python3.8/dist-packages/sklearn/base.py:329: UserWarning: Trying to unpickle estimator MiniBatchKMeans from version 0.24.0 when using version 1.0.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
    https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations
      warnings.warn(

%% Cell type:code id: tags:

``` 
# import wave
# import struct

# # dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
# # dataset[0]["audio"]["array"]
# sampling_rate = dataset.features["audio"].sampling_rate

# def save_wav(file_name, audio, sample_rate=sampling_rate):
#   # Open up a wav file
#   wav_file=wave.open(file_name,"w")
#   # wav params
#   nchannels = 1
#   sampwidth = 2
#   # 44100 is the industry standard sample rate - CD quality.  If you need to
#   # save on file size you can adjust it downwards. The standard for low quality
#   # is 8000 or 8kHz.
#   nframes = len(audio)
#   comptype = "NONE"
#   compname = "not compressed"
#   wav_file.setparams((nchannels, sampwidth, sample_rate, nframes, comptype, compname))
#   # WAV files here are using short, 16 bit, signed integers for the
#   # sample size.  So we multiply the floating point data we have by 32767, the
#   # maximum value for a short integer.  NOTE: It is theoretically possible to
#   # use the floating point -1.0 to 1.0 data directly in a WAV file but not
#   # obvious how to do that using the wave module in python.
#   for sample in audio:
#     wav_file.writeframes(struct.pack('h', int( sample * 32767.0 )))
#   wav_file.close()
#   return
# save_wav("test.wav", dataset[1]["audio"]["array"])
```

%% Cell type:code id: tags:

``` 
from audiolm_pytorch import SemanticTransformerWrapper
import numpy as np

# in case not already loaded
from datasets import load_dataset
# load the demo audio; one clip from it is used as the priming waveform below
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")

batch_size = 2
# The priming clip (77040 samples at 16kHz for this dataset entry) is repeated
# batch_size times, so prime_wave has shape (batch_size, num_samples) and always
# agrees with the batch_size passed to generate() below. No hard-coded reshape
# is needed: stacking already yields the 2-D layout.
prime_clip = dataset[1]["audio"]["array"]
samples = np.stack([prime_clip] * batch_size)
prime_wave = torch.tensor(samples).cuda()
max_length = 2048
semantic = SemanticTransformerWrapper(
            wav2vec = wav2vec,
            transformer = semantic_transformer,
            audio_conditioner = None,
            unique_consecutive = True
        ).cuda()
semantic_tokens = semantic.generate(
            text_embeds = None, # no text, it's not musicLM
            batch_size = batch_size,
            prime_wave = prime_wave,
            max_length = max_length
        )
# semantic.device # should be cuda
```

%% Output

    WARNING:datasets.builder:Found cached dataset librispeech_asr_demo (/root/.cache/huggingface/datasets/hf-internal-testing___librispeech_asr_demo/clean/2.1.0/d3bc4c2bc2078fcde3ad0f0f635862e4c0fef78ba94c4a34c4c250a097af240b)

    embed.keys(): dict_keys(['x', 'padding_mask', 'features'])
    embed['x'] shape: torch.Size([2, 240, 768]), embed['features'].shape: torch.Size([2, 240, 768])
    wav_input shape: torch.Size([2, 77040]), embed shape: torch.Size([480, 768]), packed_shape: [torch.Size([2, 240])]
    codebook_indices before unpacking: torch.Size([480])
    codebook_indices after unpacking: torch.Size([2, 240])
    ids.shape: torch.Size([2, 240]) and prime_wave True

    generating semantic:  17%|█▋        | 324/1905 [00:10<00:52, 30.37it/s]

    ---------------------------------------------------------------------------
    KeyboardInterrupt                         Traceback (most recent call last)
    <ipython-input-6-11be373fc09e> in <module>
         20             unique_consecutive = True
         21         ).cuda()
    ---> 22 semantic_tokens = semantic.generate(
         23             text_embeds = None, # no text, it's not musicLM
         24             batch_size = batch_size,
    /content/audiolm_pytorch/audiolm_pytorch.py in inner(model, *args, **kwargs)
         55         was_training = model.training
         56         model.eval()
    ---> 57         out = fn(model, *args, **kwargs)
         58         model.train(was_training)
         59         return out
    /usr/local/lib/python3.8/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
         25         def decorate_context(*args, **kwargs):
         26             with self.clone():
    ---> 27                 return func(*args, **kwargs)
         28         return cast(F, decorate_context)
         29
    <@beartype(audiolm_pytorch.audiolm_pytorch.SemanticTransformerWrapper.generate) at 0x7fcf8e4a9280> in generate(__beartype_func, __beartype_conf, __beartype_get_violation, __beartype_object_9474080, __beartype_getrandbits, *args, **kwargs)
    /content/audiolm_pytorch/audiolm_pytorch.py in generate(self, max_length, text, text_embeds, prime_wave, prime_ids, batch_size, cond_scale, filter_thres, temperature, include_eos_in_output, **kwargs)
       1043             sample_semantic_ids = torch.cat((sample_semantic_ids, sampled), dim = -1)
       1044
    -> 1045             if all_rows_have_eos_id(sample_semantic_ids, self.eos_id):
       1046                 break
       1047
    KeyboardInterrupt:

%% Cell type:code id: tags:

``` 
# ?semantic.wav2vec
# torch.tensor(dataset[1]["audio"]["array"]).cuda().device
# semantic_tokens.shape
# semantic_tokens[:,-1] # unfortunately doesn't seem to be the EOS we're looking for
# semantic_tokens[:, 0]
```

%% Cell type:code id: tags:

``` 
# dataset[1]["audio"]["array"].shape
# # len(dataset[1]["audio"]["array"]) # 77040
# # dataset.features["audio"].sampling_rate # 16000
# # so 4.815 seconds of audio
```

%% Cell type:code id: tags:

``` 
# Scratch check: build an empty int64 tensor with batch_size rows and zero columns.
torch.empty(batch_size, 0, dtype=torch.long)
```

%% Cell type:code id: tags:

``` 
```