Loading musiclm_pytorch/musiclm_pytorch.py +61 −6 Original line number Diff line number Diff line Loading @@ -4,6 +4,8 @@ from torch import nn, einsum from torchaudio.transforms import Spectrogram, TimeStretch, FrequencyMasking, TimeMasking from audiolm_pytorch import AudioLM from x_clip.tokenizer import tokenizer from vector_quantize_pytorch import ResidualVQ Loading Loading @@ -376,6 +378,8 @@ class MuLaN(nn.Module): dim_latent = 128 # they use 128 ): super().__init__() self.dim_latent = dim_latent self.audio = audio_transformer self.text = text_transformer Loading Loading @@ -421,16 +425,67 @@ class MuLaN(nn.Module): class MuLaNEmbedQuantizer(nn.Module): def __init__( self, mulan: MuLaN mulan: MuLaN, rq_num_quantizers = 8, rq_ema_decay = 0.9, codebook_size = 1024, ): super().__init__() self.mulan = mulan self.rq = ResidualVQ( dim = mulan.dim_latent, num_quantizers = rq_num_quantizers, codebook_size = codebook_size, decay = rq_ema_decay, commitment_weight = 0, # only use EMA to update codebooks kmeans_init = True, threshold_ema_dead_code = 2, quantize_dropout = False # no quantize dropout ) def forward(self, x): raise NotImplementedError def forward( self, wavs = None, texts = None ): assert exists(wavs) ^ exist(texts) with torch.no_grad(): self.mulan.eval() # sound and language live in joint embedding space because of contrastive learning if exists(wavs): latents = self.mulan.get_audio_latents(wavs) elif exists(texts): latents = self.mulan.get_text_latents(texts) _, indices, _ = self.rq(latents) return indices @beartype class MusicLM(nn.Module): def __init__(self): def __init__( self, audio_lm: AudioLM, mulan_embed_quantizer: MuLaNEmbedQuantizer ): super().__init__() self.mulan_embed_quantizer = mulan_embed_quantizer self.audio_lm = audio_lm def forward(self, x): return x @torch.no_grad() def forward( self, raw_texts: List[str], **audio_lm_kwargs ): self.eval() texts = tokenizer.tokenize(raw_texts) cond_tokens = self.mulan_embed_quantizer(texts = texts) wavs = self.audio_lm.generate(cond_tokens = cond_tokens, **audio_lm_kwargs) return wavs Loading
musiclm_pytorch/musiclm_pytorch.py +61 −6 Original line number Diff line number Diff line Loading @@ -4,6 +4,8 @@ from torch import nn, einsum from torchaudio.transforms import Spectrogram, TimeStretch, FrequencyMasking, TimeMasking from audiolm_pytorch import AudioLM from x_clip.tokenizer import tokenizer from vector_quantize_pytorch import ResidualVQ Loading Loading @@ -376,6 +378,8 @@ class MuLaN(nn.Module): dim_latent = 128 # they use 128 ): super().__init__() self.dim_latent = dim_latent self.audio = audio_transformer self.text = text_transformer Loading Loading @@ -421,16 +425,67 @@ class MuLaN(nn.Module): class MuLaNEmbedQuantizer(nn.Module): def __init__( self, mulan: MuLaN mulan: MuLaN, rq_num_quantizers = 8, rq_ema_decay = 0.9, codebook_size = 1024, ): super().__init__() self.mulan = mulan self.rq = ResidualVQ( dim = mulan.dim_latent, num_quantizers = rq_num_quantizers, codebook_size = codebook_size, decay = rq_ema_decay, commitment_weight = 0, # only use EMA to update codebooks kmeans_init = True, threshold_ema_dead_code = 2, quantize_dropout = False # no quantize dropout ) def forward(self, x): raise NotImplementedError def forward( self, wavs = None, texts = None ): assert exists(wavs) ^ exist(texts) with torch.no_grad(): self.mulan.eval() # sound and language live in joint embedding space because of contrastive learning if exists(wavs): latents = self.mulan.get_audio_latents(wavs) elif exists(texts): latents = self.mulan.get_text_latents(texts) _, indices, _ = self.rq(latents) return indices @beartype class MusicLM(nn.Module): def __init__(self): def __init__( self, audio_lm: AudioLM, mulan_embed_quantizer: MuLaNEmbedQuantizer ): super().__init__() self.mulan_embed_quantizer = mulan_embed_quantizer self.audio_lm = audio_lm def forward(self, x): return x @torch.no_grad() def forward( self, raw_texts: List[str], **audio_lm_kwargs ): self.eval() texts = tokenizer.tokenize(raw_texts) cond_tokens = self.mulan_embed_quantizer(texts = texts) wavs = self.audio_lm.generate(cond_tokens = cond_tokens, **audio_lm_kwargs) return wavs