Commit 9c6785b8 authored by Phil Wang
Browse files

get some code for fine transformer (last stage) into place

parent 90904ee5
Loading
Loading
Loading
Loading
+40 −0
Original line number Diff line number Diff line
@@ -4,6 +4,46 @@

Implementation of <a href="https://google-research.github.io/seanet/audiolm/examples/">AudioLM</a>, a Language Modeling Approach to Audio Generation out of Google Research, in PyTorch

## Install

```bash
$ pip install audiolm-pytorch
```

## Usage

```python
import torch
from audiolm_pytorch.audiolm_pytorch import SoundStream, AudioLM, FineTransformer, FineTransformerWrapper

# residual-quantized codec; its 8 quantizers are split below into 3 coarse + 5 fine
soundstream = SoundStream(
    codebook_size = 1024,
    rq_num_quantizers = 8,
)

# last-stage transformer: num_coarse_quantizers + num_fine_quantizers adds up to
# the 8 quantizers of the SoundStream above, and codebook_size matches it
transformer = FineTransformer(
    num_coarse_quantizers = 3,
    num_fine_quantizers = 5,
    codebook_size = 1024,
    dim = 512,
    depth = 6
)

# training wrapper: turns a raw waveform into coarse / fine token ids with
# SoundStream (under no_grad) and computes the loss for the transformer
train_wrapper = FineTransformerWrapper(
    soundstream = soundstream,
    transformer = transformer
).cuda()

# random stand-in for a batch containing one audio clip
raw_waveform = torch.randn(1, 320 * 512).cuda()

# with return_loss = True the wrapper returns the mean of the coarse and fine
# cross entropy losses instead of the logits
loss = train_wrapper(
    raw_wave = raw_waveform,
    return_loss = True
)

loss.backward()
```

## Citations

```bibtex
+3 −0
Original line number Diff line number Diff line
from audiolm_pytorch.audiolm_pytorch import AudioLM
from audiolm_pytorch.audiolm_pytorch import SoundStream

from audiolm_pytorch.audiolm_pytorch import SemanticTransformer, CoarseTransformer, FineTransformer
from audiolm_pytorch.audiolm_pytorch import FineTransformerWrapper
+114 −12
Original line number Diff line number Diff line
import math
from functools import partial
from typing import Optional

import torch
from torch import nn, einsum
from torch.autograd import grad as torch_grad
import torch.nn.functional as F

from einops import rearrange
from einops import rearrange, repeat

from vector_quantize_pytorch import ResidualVQ

@@ -305,6 +306,7 @@ class SoundStream(nn.Module):
    def forward(
        self,
        x,
        return_encoded = False,
        return_discr_loss = False,
        return_stft_discr_loss = False
    ):
@@ -319,6 +321,9 @@ class SoundStream(nn.Module):
        x, indices, commit_loss = self.rq(x)
        x = rearrange(x, 'b n c -> b c n')

        if return_encoded:
            return x, indices, commit_loss

        recon_x = self.decoder(x)

        # stft discr loss
@@ -536,6 +541,7 @@ class SemanticTransformer(nn.Module):
        **kwargs
    ):
        super().__init__()
        self.transformer = Transformer(dim = dim, **kwargs)

    def forward(
        self,
@@ -548,11 +554,13 @@ class CoarseTransformer(nn.Module):
        self,
        *,
        num_semantic_tokens,
        codebook_size,
        num_coarse_tokens,
        dim,
        **kwargs
    ):
        super().__init__()
        self.transformer = Transformer(dim = dim, **kwargs)

    def forward(
        self,
@@ -565,19 +573,112 @@ class FineTransformer(nn.Module):
    def __init__(
        self,
        *,
        num_coarse_tokens = None,
        num_fine_tokens = None,
        num_coarse_quantizers,
        num_fine_quantizers,
        codebook_size,
        dim,
        **kwargs
    ):
        """Build the embeddings, transformer and output heads of the fine stage.

        Args:
            num_coarse_tokens: Unused. Accepted (and ignored) only so that
                callers which still pass it keep working.
            num_fine_tokens: Unused. Accepted (and ignored) only so that
                callers which still pass it keep working.
            num_coarse_quantizers: Number of quantizers treated as coarse;
                sizes the coarse embedding table.
            num_fine_quantizers: Number of quantizers treated as fine;
                sizes the fine embedding table.
            codebook_size: Number of codes per quantizer; also the number of
                classes produced by both logit heads.
            dim: Model / embedding dimension.
            **kwargs: Forwarded unchanged to ``Transformer``.
        """
        super().__init__()

        # learned embedding placed in front of every sequence
        self.start_token = nn.Parameter(torch.randn(dim))

        # one embedding row per (quantizer, code) pair
        # NOTE(review): nothing in this class offsets the ids by quantizer
        # index before the lookup - confirm whether callers are expected to
        self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size, dim)
        self.fine_embedding = nn.Embedding(num_fine_quantizers * codebook_size, dim)

        self.transformer = Transformer(dim = dim, **kwargs)

        # separate projection heads for coarse and fine positions
        self.coarse_logits = nn.Linear(dim, codebook_size)
        self.fine_logits = nn.Linear(dim, codebook_size)

    def forward(
        self,
        coarse_token_ids,
        fine_token_ids
    ):
        """Embed coarse and fine token ids and return logits for both.

        Args:
            coarse_token_ids: Integer ids with a leading batch dimension; all
                trailing dimensions are flattened into one sequence axis.
            fine_token_ids: Integer ids with a leading batch dimension; all
                trailing dimensions are flattened into one sequence axis.

        Returns:
            Tuple ``(coarse_logits, fine_logits)``. With ``n`` flattened coarse
            tokens, ``coarse_logits`` comes from the first ``n`` output
            positions and ``fine_logits`` from every remaining position.
        """
        # flatten everything after the batch dimension
        coarse_token_ids, fine_token_ids = map(lambda t: rearrange(t, 'b ... -> b (...)'), (coarse_token_ids, fine_token_ids))

        b, n = coarse_token_ids.shape

        coarse_tokens = self.coarse_embedding(coarse_token_ids)
        fine_tokens = self.fine_embedding(fine_token_ids)

        # same learned start token for every batch element
        start_tokens = repeat(self.start_token, 'd -> b 1 d', b = b)

        # sequence layout: [start, coarse ..., fine ...]
        tokens = torch.cat((start_tokens, coarse_tokens, fine_tokens), dim = 1)

        tokens = self.transformer(tokens)

        # because of the prepended start token, the first n positions are the
        # start token plus the first n - 1 coarse tokens
        pred_coarse_tokens, pred_fine_tokens = tokens[:, :n], tokens[:, n:]

        coarse_logits = self.coarse_logits(pred_coarse_tokens)
        fine_logits = self.fine_logits(pred_fine_tokens)

        return coarse_logits, fine_logits

# training wrappers

class FineTransformerWrapper(nn.Module):
    """Training wrapper around ``FineTransformer``.

    Accepts either a raw waveform (tokenized with the stored SoundStream) or
    ready-made coarse / fine token ids, and returns either the logits or the
    averaged cross entropy loss.
    """

    def __init__(
        self,
        *,
        soundstream: Optional[SoundStream] = None,
        transformer: FineTransformer,
        num_coarse_quantize = 3
    ):
        """Store the codec and transformer.

        Args:
            soundstream: Codec used to turn raw waveforms into token ids. May
                be ``None`` when token ids are always passed to ``forward``.
            transformer: The fine transformer being trained.
            num_coarse_quantize: How many leading quantizers of the SoundStream
                indices count as coarse; the rest count as fine. Must be > 0.
        """
        super().__init__()
        self.soundstream = soundstream
        self.transformer = transformer

        assert num_coarse_quantize > 0
        self.num_coarse_quantize = num_coarse_quantize

    def forward(
        self,
        *,
        raw_wave = None,
        coarse_token_ids = None,
        fine_token_ids = None,
        return_loss = False
    ):
        """Run the transformer and optionally compute the training loss.

        Args:
            raw_wave: Raw waveform to tokenize with SoundStream. Mutually
                exclusive with passing both token id tensors.
            coarse_token_ids: Coarse token ids, batch dimension first.
            fine_token_ids: Fine token ids, batch dimension first.
            return_loss: When true, return the loss instead of the logits.

        Returns:
            ``(coarse_logits, fine_logits)`` when ``return_loss`` is false,
            otherwise the mean of the coarse and fine cross entropy losses.

        Raises:
            AssertionError: If the inputs are not exactly one of "raw wave" or
                "both token id tensors", or if a raw wave is given but no
                SoundStream was provided.
        """
        assert exists(raw_wave) ^ (exists(coarse_token_ids) and exists(fine_token_ids)), 'either raw waveform (raw_wave) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'

        if exists(raw_wave):
            assert exists(self.soundstream), 'SoundStream must be provided if given raw wave for training'

            # the codec only supplies targets, so no gradients flow into it;
            # note that eval() is not undone afterwards
            with torch.no_grad():
                self.soundstream.eval()
                _, indices, _ = self.soundstream(raw_wave, return_encoded = True)
                # split along the last (quantizer) axis into coarse and fine
                coarse_token_ids, fine_token_ids = indices[..., :self.num_coarse_quantize], indices[..., self.num_coarse_quantize:]

        # flatten everything after the batch dimension
        coarse_token_ids = rearrange(coarse_token_ids, 'b ... -> b (...)')
        fine_token_ids = rearrange(fine_token_ids, 'b ... -> b (...)')

        if return_loss:
            coarse_labels, fine_labels = coarse_token_ids, fine_token_ids.clone()
            # drop the last fine token from the input: together with the start
            # token the transformer prepends, the number of output positions
            # then equals the number of labels
            fine_token_ids = fine_token_ids[:, :-1]

        coarse_logits, fine_logits = self.transformer(
            coarse_token_ids = coarse_token_ids,
            fine_token_ids = fine_token_ids
        )

        if not return_loss:
            return coarse_logits, fine_logits

        # F.cross_entropy expects the class dimension right after the batch
        coarse_logits, fine_logits = map(lambda t: rearrange(t, 'b n c -> b c n'), (coarse_logits, fine_logits))

        coarse_loss = F.cross_entropy(
            coarse_logits,
            coarse_labels
        )

        fine_loss = F.cross_entropy(
            fine_logits,
            fine_labels
        )

        return (coarse_loss + fine_loss) * 0.5

# audio LM

@@ -585,15 +686,16 @@ class AudioLM(nn.Module):
    def __init__(
        self,
        *,
        soundstream: SoundStream,
        semantic_transformer: SemanticTransformer,
        coarse_transformer: CoarseTransformer,
        fine_transformer: FineTransformer,
    ):
        """Hold the codec and the three stage transformers.

        Args:
            soundstream: The SoundStream codec.
            semantic_transformer: Transformer for the semantic stage.
            coarse_transformer: Transformer for the coarse stage.
            fine_transformer: Transformer for the fine stage.
        """
        super().__init__()
        self.soundstream = soundstream
        self.semantic = semantic_transformer
        self.coarse = coarse_transformer
        self.fine = fine_transformer

    def forward(self, x):
        """Placeholder; the full pipeline is not implemented yet.

        Raises:
            NotImplementedError: Always.
        """
        # NotImplementedError is the exception; `raise NotImplemented` would
        # itself fail with a TypeError because NotImplemented is a constant
        raise NotImplementedError