Commit 9a0554f4 authored by Phil Wang's avatar Phil Wang
Browse files

complete mulan, lay out a gameplan for remaining week

parent 7624306f
Loading
Loading
Loading
Loading
+49 −0
Original line number Diff line number Diff line
@@ -6,6 +6,55 @@ Implementation of <a href="https://google-research.github.io/seanet/musiclm/exam

They are basically using text-conditioned <a href="https://github.com/lucidrains/audiolm-pytorch">AudioLM</a>, but surprisingly with the embeddings from a text-audio contrastive learned model named <a href="https://arxiv.org/abs/2208.12415">MuLan</a>. MuLan is what will be built out in this repository, with AudioLM modified from the other repository to support the music generation needs here.

## Install

```bash
$ pip install musiclm-pytorch
```

## Usage

`MuLaN` first needs to be trained

```python
import torch
from musiclm_pytorch import MuLaN, AudioSpectrogramTransformer, TextTransformer

audio_transformer = AudioSpectrogramTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64,
    spec_n_fft = 128,
    spec_win_length = 24,
    spec_aug_stretch_factor = 0.8
)

text_transformer = TextTransformer(
    dim = 512,
    depth = 6,
    heads = 8,
    dim_head = 64
)

mulan = MuLaN(
    audio_transformer = audio_transformer,
    text_transformer = text_transformer
)

texts = torch.randint(0, 20000, (2, 256))
wavs = torch.randn(2, 1024)

loss = mulan(wavs, texts)
loss.backward()
```

## Todo

- [ ] wrap mulan with mulan wrapper and quantize the output, project to audiolm dimensions
- [ ] modify audiolm to accept conditioning embeddings, optionally take care of different dimensions through a separate projection
- [ ] audiolm and mulan go into musiclm, which generates samples and then filters them with mulan

## Appreciation

- <a href="https://stability.ai/">Stability.ai</a> for the generous sponsorship to work on and open source cutting edge artificial intelligence research
+1 −0
Original line number Diff line number Diff line
from musiclm_pytorch.musiclm_pytorch import MuLaN, MusicLM
from musiclm_pytorch.musiclm_pytorch import AudioSpectrogramTransformer, TextTransformer
+128 −5
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@ from vector_quantize_pytorch import ResidualVQ

from einops import rearrange, repeat, reduce, pack, unpack

from beartype.typing import List, Optional
from beartype import beartype

# functions
@@ -211,6 +212,8 @@ class AudioSpectrogramTransformer(nn.Module):

    ):
        super().__init__()
        self.dim = dim

        self.patch_size = pair(patch_size)
        self.to_patch_tokens = nn.Conv2d(self.patch_size[0] * self.patch_size[1], dim, 1)

@@ -284,8 +287,81 @@ class AudioSpectrogramTransformer(nn.Module):

# text transformer

# Bare placeholder with no base class and no behavior; the `TextTransformer`
# definition that immediately follows rebinds this name.
class TextTransformer:
    pass
@beartype
class TextTransformer(nn.Module):
    """Text encoder that summarizes each token sequence into a single CLS embedding.

    Token ids are embedded, summed with learned absolute positional embeddings,
    prefixed with a learned CLS token, passed through `Transformer`, and the
    CLS position is layer-normed and returned.

    Args:
        dim: model width; also exposed as `self.dim`.
        depth: forwarded to `Transformer`.
        num_tokens: vocabulary size of the token embedding (defaults to the
            module-level tokenizer's vocabulary size).
        max_seq_len: number of learned positions; longer inputs are rejected.
        dim_head, heads, attn_dropout, ff_dropout, ff_mult: forwarded to `Transformer`.
        pad_id: token id treated as padding when no mask is supplied.
    """

    def __init__(
        self,
        dim,
        depth,
        num_tokens = tokenizer.vocab_size,
        max_seq_len = 256,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_mult = 4,
        pad_id = 0
    ):
        super().__init__()
        self.dim = dim

        # kept so forward can reject inputs longer than the positional table
        self.max_seq_len = max_seq_len

        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.cls_token = nn.Parameter(torch.randn(dim))

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout,
            ff_mult = ff_mult
        )

        self.pad_id = pad_id
        self.norm = LayerNorm(dim)

    def forward(
        self,
        x = None,
        raw_texts: Optional[List[str]] = None,
        mask = None
    ):
        """Encode a batch of texts into one embedding per text.

        Args:
            x: 2-d tensor of token ids (batch, seq_len). Mutually exclusive with `raw_texts`.
            raw_texts: list of strings, tokenized with the module-level tokenizer.
            mask: optional (batch, seq_len) mask; when omitted it is derived as
                `x != pad_id`. Presumably True marks positions to attend to -
                confirm against `Transformer`.

        Returns:
            The layer-normed CLS embedding, one row per input text (batch, dim).
        """
        assert exists(x) ^ exists(raw_texts), 'either token ids `x` or `raw_texts` must be passed in, but not both'

        if exists(raw_texts):
            # the tokenizer knows nothing about where this module lives, so move
            # its output next to the embedding weights to avoid a device mismatch
            x = tokenizer.tokenize(raw_texts).to(self.token_emb.weight.device)

        if not exists(mask):
            mask = x != self.pad_id

        b, n, device = *x.shape, x.device

        # fail with a readable message instead of an out-of-range embedding lookup
        assert n <= self.max_seq_len, f'sequence length {n} exceeds the maximum of {self.max_seq_len} positions'

        # token embedding + positional embedding

        x = self.token_emb(x)
        x = x + self.pos_emb(torch.arange(n, device = device))

        # cls tokens, as in bert

        cls_tokens = repeat(self.cls_token, 'd -> b d', b = b)
        x, ps = pack([cls_tokens, x], 'b * d')

        # account for attending to cls token with self attention mask

        mask = F.pad(mask, (1, 0), value = True)

        # attention

        x = self.transformer(x, mask = mask)

        # unpack the cls tokens

        cls_tokens, _ = unpack(x, ps, 'b * d')

        return self.norm(cls_tokens)

# main classes

@@ -294,15 +370,62 @@ class MuLaN(nn.Module):
    def __init__(
        self,
        audio_transformer: AudioSpectrogramTransformer,
        text_transformer: TextTransformer
        text_transformer: TextTransformer,
        dim_latent = 128 # they use 128
    ):
        super().__init__()
        self.audio = audio_transformer
        self.text = text_transformer

    def forward(self, x):
        return x
        self.temperature = nn.Parameter(torch.tensor(1.))

        self.text_to_latents = nn.Linear(self.text.dim, dim_latent)
        self.audio_to_latents = nn.Linear(self.audio.dim, dim_latent)

    def forward(
        self,
        wavs,
        texts = None,
        raw_texts: Optional[List[str]] = None,
        return_similarities = False
    ):
        """Contrastive forward pass over paired audio and text.

        Both modalities are encoded, projected into the shared latent space and
        l2-normalized; a pairwise audio-by-text similarity matrix is then formed.

        Args:
            wavs: batch of audio handed to the audio transformer; its first
                dimension is used as the batch size.
            texts: token ids handed to the text transformer (optional).
            raw_texts: raw strings handed to the text transformer instead of `texts`.
            return_similarities: when True, return the unscaled similarity
                matrix rather than the loss.

        Returns:
            The (audio, text) similarity matrix if `return_similarities` is set,
            otherwise the cross entropy of the temperature-scaled similarities
            against the diagonal (row i should match column i).
        """
        num_pairs = wavs.shape[0]
        device = wavs.device

        # encode each modality

        audio_out = self.audio(wavs)
        text_out = self.text(texts, raw_texts = raw_texts)

        # project into the shared latent space, then normalize

        audio_z = self.audio_to_latents(audio_out)
        text_z = self.text_to_latents(text_out)

        audio_z = l2norm(audio_z)
        text_z = l2norm(text_z)

        # pairwise similarity: rows index audio, columns index text

        sim = einsum('i d, j d -> i j', audio_z, text_z)

        num_audio, num_text = sim.shape
        assert num_audio == num_text, 'batch sizes for audio and text are not equal'

        if return_similarities:
            return sim

        # learned temperature is stored in log space

        logits = sim * self.temperature.exp()

        # the matching text for audio i sits at column i

        targets = torch.arange(num_pairs, device = device)

        return F.cross_entropy(logits, targets)

# music lm

@beartype
class MuLaNEmbedQuantizer(nn.Module):
    """Unimplemented placeholder module that takes a `MuLaN` instance at construction.

    As written, the constructor does not keep a reference to `mulan`, and
    calling the module always raises `NotImplementedError`.
    """

    def __init__(
        self,
        mulan: MuLaN
    ):
        # NOTE: `mulan` is accepted but not stored on the instance yet
        super().__init__()

    def forward(self, x):
        # not implemented yet - every call raises
        raise NotImplementedError

class MusicLM(nn.Module):
    def __init__(self):
        super().__init__()