Commit ac00f9dc authored by Phil Wang's avatar Phil Wang
Browse files

some rough outlines for mulan wrapper and eventual musiclm

parent d745fcd8
Loading
Loading
Loading
Loading
+61 −6
Original line number Diff line number Diff line
@@ -4,6 +4,8 @@ from torch import nn, einsum

from torchaudio.transforms import Spectrogram, TimeStretch, FrequencyMasking, TimeMasking

from audiolm_pytorch import AudioLM

from x_clip.tokenizer import tokenizer
from vector_quantize_pytorch import ResidualVQ

@@ -376,6 +378,8 @@ class MuLaN(nn.Module):
        dim_latent = 128 # they use 128
    ):
        super().__init__()
        self.dim_latent = dim_latent

        self.audio = audio_transformer
        self.text = text_transformer

@@ -421,16 +425,67 @@ class MuLaN(nn.Module):
class MuLaNEmbedQuantizer(nn.Module):
    def __init__(
        self,
        mulan: MuLaN
        mulan: MuLaN,
        rq_num_quantizers = 8,
        rq_ema_decay = 0.9,
        codebook_size = 1024,
    ):
        super().__init__()
        self.mulan = mulan

        self.rq = ResidualVQ(
            dim = mulan.dim_latent,
            num_quantizers = rq_num_quantizers,
            codebook_size = codebook_size,
            decay = rq_ema_decay,
            commitment_weight = 0,    # only use EMA to update codebooks
            kmeans_init = True,
            threshold_ema_dead_code = 2,
            quantize_dropout = False  # no quantize dropout
        )

    def forward(self, x):
        raise NotImplementedError
    def forward(
        self,
        wavs = None,
        texts = None
    ):
        assert exists(wavs) ^ exist(texts)

        with torch.no_grad():
            self.mulan.eval()

            # sound and language live in joint embedding space because of contrastive learning

            if exists(wavs):
                latents = self.mulan.get_audio_latents(wavs)
            elif exists(texts):
                latents = self.mulan.get_text_latents(texts)

        _, indices, _ = self.rq(latents)

        return indices

@beartype
class MusicLM(nn.Module):
    def __init__(self):
    def __init__(
        self,
        audio_lm: AudioLM,
        mulan_embed_quantizer: MuLaNEmbedQuantizer
    ):
        super().__init__()
        self.mulan_embed_quantizer = mulan_embed_quantizer
        self.audio_lm = audio_lm

    def forward(self, x):
        return x
    @torch.no_grad()
    def forward(
        self,
        raw_texts: List[str],
        **audio_lm_kwargs
    ):
        self.eval()

        texts = tokenizer.tokenize(raw_texts)
        cond_tokens = self.mulan_embed_quantizer(texts = texts)

        wavs = self.audio_lm.generate(cond_tokens = cond_tokens, **audio_lm_kwargs)
        return wavs