Commit ca13fbae authored by Phil Wang's avatar Phil Wang
Browse files

get a working example of text conditioned semantic transformer training

parent 228a2298
Loading
Loading
Loading
Loading
+63 −0
Original line number Diff line number Diff line
@@ -171,6 +171,69 @@ generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the d

```

## Text Conditioned Audio Synthesis

ex. Semantic Transformer

```python
import torch
from audiolm_pytorch import HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer, get_embeds, FairseqVQWav2Vec

wav2vec = HubertWithKmeans(
    checkpoint_path = './hubert/hubert_base_ls960.pt',
    kmeans_path = './hubert/hubert_base_ls960_L9_km500.bin'
)

semantic_transformer = SemanticTransformer(
    num_semantic_tokens = 500,
    dim = 1024,
    depth = 6,
    has_condition = True  # this will have to be set to True
).cuda()

# mock text audio dataset (as an example)

# you will have to write your own dataset by subclassing `Dataset`, returning an audio tensor and a string (the audio description) in any order (the framework will autodetect each one and route it into the transformer)

from torch.utils.data import Dataset

class MockTextAudioDataset(Dataset):
    """Toy dataset pairing a fixed caption string with random audio.

    Every item is a `(caption, audio)` tuple, where `caption` is always
    the string 'audio caption' and `audio` is a fresh 1-D tensor of
    `audio_length` samples drawn from `torch.randn`.

    Args:
        length: number of items the dataset reports via `len()`.
        audio_length: number of samples in each generated audio tensor.
    """

    def __init__(self, length = 100, audio_length = 320 * 32):
        super().__init__()
        self.len = length
        self.audio_length = audio_length

    def __len__(self):
        # size is fixed at construction time
        return self.len

    def __getitem__(self, idx):
        # `idx` is ignored - every item is generated on the fly
        caption = 'audio caption'
        waveform = torch.randn(self.audio_length)
        return caption, waveform

dataset = MockTextAudioDataset()

# instantiate semantic transformer trainer and train

trainer = SemanticTransformerTrainer(
    transformer = semantic_transformer,
    wav2vec = wav2vec,
    dataset = dataset,
    batch_size = 4,
    grad_accum_every = 8,
    data_max_length = 320 * 32,
    num_train_steps = 100000
)

trainer.train()

# after much training above

sample = trainer.generate(text = ['sound of rain drops on the rooftops'], batch_size = 1, max_length = 2) # (1, <= max_length) - may terminate early if it detects [eos]

```


## Appreciation

- <a href="https://stability.ai/">Stability.ai</a> for the generous sponsorship to work and open source cutting edge artificial intelligence research
+3 −0
Original line number Diff line number Diff line
@@ -844,6 +844,7 @@ class SemanticTransformerWrapper(nn.Module):

    @eval_decorator
    @torch.no_grad()
    @beartype
    def generate(
        self,
        *,
@@ -1008,6 +1009,7 @@ class CoarseTransformerWrapper(nn.Module):

    @eval_decorator
    @torch.no_grad()
    @beartype
    def generate(
        self,
        *,
@@ -1201,6 +1203,7 @@ class FineTransformerWrapper(nn.Module):

    @eval_decorator
    @torch.no_grad()
    @beartype
    def generate(
        self,
        *,
+14 −2
Original line number Diff line number Diff line
from pathlib import Path
from functools import partial, wraps

from typing import Tuple
from beartype.door import is_bearable

import torchaudio
from torchaudio.functional import resample

@@ -89,9 +92,18 @@ def collate_one_or_multiple_tensors(fn):
        is_one_data = not isinstance(data[0], tuple)

        if is_one_data:
            return fn(data)
            data = (data,)

        outputs = []
        for datum in zip(*data):
            if is_bearable(datum, Tuple[str, ...]):
                output = list(datum)
            else:
                output = fn(datum)

            outputs.append(output)

        return tuple(map(fn, zip(*data)))
        return tuple(outputs)

    return inner

+75 −39
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ from random import choice
from pathlib import Path
from shutil import rmtree

from typing import Union, List, Optional
from typing import Union, List, Optional, Tuple
from typing_extensions import Annotated

from beartype import beartype
@@ -385,8 +385,9 @@ class SemanticTransformerTrainer(nn.Module):
        *,
        num_train_steps,
        batch_size,
        dataset: Optional[Dataset] = None,
        data_max_length = None,
        folder,
        folder = None,
        lr = 3e-4,
        grad_accum_every = 1,
        wd = 0.,
@@ -425,6 +426,10 @@ class SemanticTransformerTrainer(nn.Module):

        # create dataset

        self.ds = dataset
        if not exists(self.ds):
            assert exists(folder), 'folder must be passed in, if not passing in a custom dataset for text conditioned audio synthesis training'

            self.ds = SoundDataset(
                folder,
                max_length = data_max_length,
@@ -432,6 +437,8 @@ class SemanticTransformerTrainer(nn.Module):
                seq_len_multiple_of = wav2vec.seq_len_multiple_of
            )

        self.ds_fields = None

        # split for validation

        if valid_frac > 0:
@@ -500,6 +507,13 @@ class SemanticTransformerTrainer(nn.Module):
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    def data_tuple_to_kwargs(self, data):
        """Turn a batch tuple into a dict keyed by dataset field name.

        On the first call the field names are derived from `data` with
        `determine_types(data, DATASET_FIELD_TYPE_CONFIG)` and stored on
        `self.ds_fields`; later calls reuse the stored names.

        Raises:
            AssertionError: if the derived field names contain duplicates.
        """
        fields = self.ds_fields

        if not exists(fields):
            # derive the names once and remember them for subsequent batches
            fields = determine_types(data, DATASET_FIELD_TYPE_CONFIG)
            self.ds_fields = fields
            assert not has_duplicates(fields), 'dataset fields must not have duplicate field names'

        # pair each field name with the batch element at the same position
        return dict(zip(fields, data))

    def train_step(self):
        device = self.device

@@ -514,9 +528,9 @@ class SemanticTransformerTrainer(nn.Module):
        # update vae (generator)

        for _ in range(self.grad_accum_every):
            wave = next(self.dl_iter).to(device)
            data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter))

            loss = self.train_wrapper(raw_wave = wave, return_loss = True)
            loss = self.train_wrapper(**data_kwargs, return_loss = True)

            self.accelerator.backward(loss / self.grad_accum_every)

@@ -535,11 +549,11 @@ class SemanticTransformerTrainer(nn.Module):
        # sample results every so often

        if self.is_main and not (steps % self.save_results_every):
            wave = next(self.valid_dl_iter).to(device)
            data_kwargs = self.data_tuple_to_kwargs(next(self.valid_dl_iter))

            with torch.no_grad():
                self.train_wrapper.eval()
                valid_loss = self.train_wrapper(raw_wave = wave, return_loss = True)
                valid_loss = self.train_wrapper(**data_kwargs, return_loss = True)

            self.print(f'{steps}: valid loss {valid_loss}')

@@ -575,8 +589,10 @@ class CoarseTransformerTrainer(nn.Module):
        *,
        num_train_steps,
        batch_size,
        dataset: Optional[Dataset] = None,
        ds_fields: Tuple[str, ...] = ('raw_wave', 'raw_wave_for_soundstream', 'text'),
        data_max_length = None,
        folder,
        folder = None,
        lr = 3e-4,
        grad_accum_every = 1,
        wd = 0.,
@@ -617,6 +633,11 @@ class CoarseTransformerTrainer(nn.Module):

        # create dataset

        self.ds = dataset

        if not exists(self.ds):
            assert exists(folder), 'folder must be passed in, if not passing in a custom dataset for text conditioned audio synthesis training'

            self.ds = SoundDataset(
                folder,
                max_length = data_max_length,
@@ -627,6 +648,8 @@ class CoarseTransformerTrainer(nn.Module):
                seq_len_multiple_of = soundstream.seq_len_multiple_of
            )

        self.ds_fields = ds_fields

        # split for validation

        if valid_frac > 0:
@@ -711,11 +734,10 @@ class CoarseTransformerTrainer(nn.Module):
        # update vae (generator)

        for _ in range(self.grad_accum_every):
            wave_wav2vec, wave_soundstream = tuple(map(lambda t: t.to(device), next(self.dl_iter)))
            data_kwargs = dict(zip(self.ds_fields, next(self.dl_iter)))

            loss = self.train_wrapper(
                raw_wave = wave_wav2vec,
                raw_wave_for_soundstream = wave_soundstream,
                **data_kwargs,
                return_loss = True
            )

@@ -736,14 +758,13 @@ class CoarseTransformerTrainer(nn.Module):
        # sample results every so often

        if self.is_main and not (steps % self.save_results_every):
            wave_wav2vec, wave_soundstream = tuple(map(lambda t: t.to(device), next(self.valid_dl_iter)))
            data_kwargs = dict(zip(self.ds_fields, next(self.valid_dl_iter)))

            with torch.no_grad():
                self.train_wrapper.eval()

                valid_loss = self.train_wrapper(
                    raw_wave = wave_wav2vec,
                    raw_wave_for_soundstream = wave_soundstream,
                    **data_kwargs,
                    return_loss = True
                )

@@ -780,8 +801,9 @@ class FineTransformerTrainer(nn.Module):
        *,
        num_train_steps,
        batch_size,
        dataset: Optional[Dataset] = None,
        data_max_length = None,
        folder,
        folder = None,
        lr = 3e-4,
        grad_accum_every = 1,
        wd = 0.,
@@ -820,6 +842,11 @@ class FineTransformerTrainer(nn.Module):

        # create dataset

        self.ds = dataset

        if not exists(self.ds):
            assert exists(folder), 'folder must be passed in, if not passing in a custom dataset for text conditioned audio synthesis training'

            self.ds = SoundDataset(
                folder,
                max_length = data_max_length,
@@ -827,6 +854,8 @@ class FineTransformerTrainer(nn.Module):
                seq_len_multiple_of = soundstream.seq_len_multiple_of
            )

        self.ds_fields = None

        # split for validation

        if valid_frac > 0:
@@ -897,6 +926,13 @@ class FineTransformerTrainer(nn.Module):
    def is_local_main(self):
        return self.accelerator.is_local_main_process

    def data_tuple_to_kwargs(self, data):
        """Build keyword arguments from a batch tuple using dataset field names.

        The field names come from `self.ds_fields`. When that is not yet
        set, they are computed from `data` through
        `determine_types(data, DATASET_FIELD_TYPE_CONFIG)` and saved on
        `self.ds_fields` so the lookup happens only once.

        Raises:
            AssertionError: if the computed field names contain duplicates.
        """
        field_names = self.ds_fields

        if not exists(field_names):
            # first batch seen: work out the names and keep them on the instance
            field_names = determine_types(data, DATASET_FIELD_TYPE_CONFIG)
            self.ds_fields = field_names
            assert not has_duplicates(field_names), 'dataset fields must not have duplicate field names'

        # positional zip: the i-th name labels the i-th element of the batch
        return dict(zip(field_names, data))

    def train_step(self):
        device = self.device

@@ -911,8 +947,8 @@ class FineTransformerTrainer(nn.Module):
        # update vae (generator)

        for _ in range(self.grad_accum_every):
            wave = next(self.dl_iter).to(device)
            loss = self.train_wrapper(raw_wave = wave, return_loss = True)
            data_kwargs = self.data_tuple_to_kwargs(next(self.dl_iter))
            loss = self.train_wrapper(**data_kwargs, return_loss = True)

            self.accelerator.backward(loss / self.grad_accum_every)

@@ -931,11 +967,11 @@ class FineTransformerTrainer(nn.Module):
        # sample results every so often

        if self.is_main and not (steps % self.save_results_every):
            wave = next(self.valid_dl_iter).to(device)
            data_kwargs = self.data_tuple_to_kwargs(next(self.valid_dl_iter))

            with torch.no_grad():
                self.train_wrapper.eval()
                valid_loss = self.train_wrapper(raw_wave = wave, return_loss = True)
                valid_loss = self.train_wrapper(**data_kwargs, return_loss = True)

            self.print(f'{steps}: valid loss {valid_loss}')

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.1.16',
  version = '0.1.17',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',