Commit 99d1f044 authored by Phil Wang's avatar Phil Wang
Browse files

complete the residual vector quantization of the joint embedding music-text...

complete the residual vector quantization of the joint embedding music-text space of mulan, with fetching of learned conditioning embeddings, setup so all three transformers in audiolm can have its own
parent e81b2d47
Loading
Loading
Loading
Loading
+22 −2
Original line number Diff line number Diff line
@@ -60,16 +60,36 @@ embeds = mulan.get_audio_latents(wavs) # during training
embeds = mulan.get_text_latents(texts)  # during inference
```

To obtain the conditioning embeddings for the three transformers that are a part of `AudioLM`, you must use the `MuLaNEmbedQuantizer` as follows:

```python
from musiclm_pytorch import MuLaNEmbedQuantizer

wavs = torch.randn(2, 1024)
embeds = mulan.get_audio_latents(wavs)

# set up the quantizer with namespaced conditioning embeddings, unique per quantizer as well as per namespace (one namespace per transformer)

quantizer = MuLaNEmbedQuantizer(
    mulan = mulan,
    conditioning_dims = (1024, 1024, 1024), # say all three transformers have model dimensions of 1024
    namespaces = ('semantic', 'coarse', 'fine')
)

# now say you want the conditioning embeddings for semantic transformer

conds = quantizer(wavs = wavs, namespace = 'semantic') # (2, 8, 1024) - 8 is number of quantizers
```

## Todo

- [x] mulan seems to be using decoupled contrastive learning, offer that as an option
- [x] wrap mulan with mulan wrapper and quantize the output, project to audiolm dimensions

- [ ] wrap mulan with mulan wrapper and quantize the output, project to audiolm dimensions
- [ ] modify audiolm to accept conditioning embeddings, optionally take care of different dimensions through a separate projection
- [ ] audiolm and mulan go into musiclm, which generates and then filters with mulan
- [ ] add a version of mulan to <a href="https://github.com/mlfoundations/open_clip">open clip</a>
- [ ] set all the proper spectrogram hyperparameters
- [ ] email some contrastive learning experts and figure out why some papers are sharing the projection from embeddings to latent space
- [ ] improvise a bit and give the audio transformer a position generating module before each attention layer

## Appreciation
+2 −1
Original line number Diff line number Diff line
from musiclm_pytorch.musiclm_pytorch import MuLaN, MusicLM
from musiclm_pytorch.musiclm_pytorch import MuLaN, MuLaNEmbedQuantizer, MusicLM

from musiclm_pytorch.musiclm_pytorch import AudioSpectrogramTransformer, TextTransformer
+46 −5
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ from vector_quantize_pytorch import ResidualVQ

from einops import rearrange, repeat, reduce, pack, unpack

from beartype.typing import List, Optional
from beartype.typing import List, Optional, Tuple
from beartype import beartype

# functions
@@ -19,6 +19,9 @@ from beartype import beartype
def exists(val):
    """Return True when `val` is anything other than None."""
    if val is None:
        return False
    return True

def default(val, d):
    """Return `val` when `exists(val)` holds, otherwise the fallback `d`."""
    if exists(val):
        return val
    return d

def round_down_nearest_multiple(n, divisor):
    """Round `n` down to a multiple of `divisor` using floor division."""
    whole_multiples = n // divisor
    return whole_multiples * divisor

@@ -449,15 +452,26 @@ class MuLaNEmbedQuantizer(nn.Module):
    def __init__(
        self,
        mulan: MuLaN,
        conditioning_dims: Tuple[int, ...],
        rq_num_quantizers = 8,
        rq_ema_decay = 0.9,
        codebook_size = 1024,
        namespaces: Tuple[str, ...] = ('semantic', 'coarse', 'fine'),

    ):
        super().__init__()
        self.mulan = mulan

        assert len(namespaces) > 0
        self.namespaces = namespaces
        self.conditioning_dims = conditioning_dims

        assert len(conditioning_dims) == len(namespaces), 'number of conditioning dimensions must be equal to number of namespaces'

        dim = mulan.dim_latent

        self.rq = ResidualVQ(
            dim = mulan.dim_latent,
            dim = dim,
            num_quantizers = rq_num_quantizers,
            codebook_size = codebook_size,
            decay = rq_ema_decay,
@@ -467,12 +481,33 @@ class MuLaNEmbedQuantizer(nn.Module):
            quantize_dropout = False  # no quantize dropout
        )

        self.dim = dim
        self.num_codebooks = rq_num_quantizers

        self.cond_embeddings = nn.ParameterDict({})

        for namespace, conditioning_dim in zip(namespaces, conditioning_dims):
            cond_embeddings = nn.Parameter(torch.randn(rq_num_quantizers, codebook_size, conditioning_dim))
            nn.init.normal_(cond_embeddings, std = 0.02)

            self.cond_embeddings[namespace] = cond_embeddings

        self.set_default_namespace(namespaces[0])

    def set_default_namespace(self, namespace):
        """Set the namespace `forward` falls back to when none is passed.

        The namespace is checked against `self.namespaces` here, so a typo
        fails at the point of configuration instead of on a later `forward`
        call. A rejected namespace leaves the previous default untouched.
        """
        # same check and message as the namespace assertion in `forward`
        assert namespace in self.namespaces, f'namespace {namespace} not found'
        self._default_namespace = namespace

    def forward(
        self,
        wavs = None,
        texts = None
        texts = None,
        namespace = None
    ):
        assert exists(wavs) ^ exist(texts)
        assert exists(wavs) ^ exists(texts)

        namespace = default(namespace, self._default_namespace)
        assert namespace in self.namespaces, f'namespace {namespace} not found'
        cond_embeddings = self.cond_embeddings[namespace]

        with torch.no_grad():
            self.mulan.eval()
@@ -486,7 +521,13 @@ class MuLaNEmbedQuantizer(nn.Module):

        _, indices, _ = self.rq(latents)

        return indices
        batch, num_codebooks, dim = indices.shape[0], self.num_codebooks, cond_embeddings.shape[-1]

        cond_embeddings = repeat(cond_embeddings, 'q c d -> b q c d', b = batch)
        indices = repeat(indices, 'b q -> b q 1 d', q = num_codebooks, d = dim)

        cond_embeddings = cond_embeddings.gather(2, indices)
        return rearrange(cond_embeddings, 'b q 1 d -> b q d')

@beartype
class MusicLM(nn.Module):
+2 −2
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'musiclm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.2',
  version = '0.0.3',
  license='MIT',
  description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
  author = 'Phil Wang',
@@ -22,7 +22,7 @@ setup(
    'audiolm-pytorch',
    'beartype',
    'einops>=0.4',
    'vector-quantize-pytorch>=0.10.15',
    'vector-quantize-pytorch>=1.0.0',
    'x-clip',
    'torch>=1.6',
    'torchaudio'