Loading README.md +1 −1 Original line number Diff line number Diff line Loading @@ -133,7 +133,7 @@ musiclm = MusicLM( mulan_embed_quantizer = mulan_embed_quantizer ) music = musiclm(['the crystalline sounds of the piano in a ballroom']) # torch.Tensor music = musiclm('the crystalline sounds of the piano in a ballroom', num_samples = 4) # sample 4 and pick the top match with mulan ``` ## Todo Loading musiclm_pytorch/musiclm_pytorch.py +48 −6 Original line number Diff line number Diff line Loading @@ -23,6 +23,9 @@ from beartype import beartype def exists(val): return val is not None def first(it): return it[0] def default(val, d): return val if exists(val) else d Loading Loading @@ -243,6 +246,8 @@ class AudioSpectrogramTransformer(nn.Module): attn_dropout = 0., ff_mult = 4, ff_dropout = 0., accept_spec = False, accept_spec_time_first = True, spec_n_fft = 128, spec_power = 2, spec_win_length = 24, Loading @@ -268,6 +273,9 @@ class AudioSpectrogramTransformer(nn.Module): nn.LayerNorm(dim) ) self.accept_spec = accept_spec self.accept_spec_time_first = accept_spec_time_first self.spec = Spectrogram( n_fft = spec_n_fft, power = spec_power, Loading Loading @@ -321,7 +329,12 @@ class AudioSpectrogramTransformer(nn.Module): force_no_patch_dropout = False ): batch, device = x.shape[0], x.device assert (self.accept_spec and x.ndim == 3) or (not self.accept_spec and x.ndim == 2) if self.accept_spec and self.accept_spec_time_first: x = rearrange(x, 'b t f -> b f t') if not self.accept_spec: x = self.spec(x) if self.training: Loading Loading @@ -525,18 +538,26 @@ class MuLaN(nn.Module): wavs, texts = None, raw_texts: Optional[List[str]] = None, return_similarities = False return_latents = False, return_similarities = False, return_pairwise_similarities = False ): batch, device = wavs.shape[0], wavs.device audio_latents = self.get_audio_latents(wavs) text_latents = self.get_text_latents(texts, raw_texts = raw_texts) if return_latents: return audio_latents, text_latents if return_similarities: return einsum('i d, i d -> i', audio_latents, text_latents) cosine_sim = einsum('i d, j d -> i j', audio_latents, text_latents) assert cosine_sim.shape[0] == cosine_sim.shape[1], 'batch sizes for audio and text are not equal' if return_similarities: if return_pairwise_similarities: return cosine_sim cosine_sim = cosine_sim * self.temperature.exp() Loading Loading @@ -661,13 +682,34 @@ class MusicLM(nn.Module): @torch.no_grad() def forward( self, raw_texts: List[str], text: str, num_samples = 1, **audio_lm_kwargs ): self.eval() texts = tokenizer.tokenize(raw_texts).to(self.device) texts = tokenizer.tokenize([text]).to(self.device) text_embeds = self.mulan_embed_quantizer(texts = texts) return self.audio_lm(text_embeds = text_embeds, **audio_lm_kwargs) # unable to deal with variable lengthed audio for now samples = [] for _ in range(num_samples): music = self.audio_lm(text_embeds = text_embeds, **audio_lm_kwargs) samples.append(music) # if one sample, just return it if num_samples == 1: return first(samples) mulan = self.mulan_embed_quantizer.mulan # get the one with the highest similarity score, of all the samples sims = torch.cat([mulan(texts = texts, wavs = music, return_similarities = True) for music in samples], dim = 0) top_matching_index = sims.topk(1, dim = 0).indices.item() return samples[top_matching_index] setup.py +2 −2 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'musiclm-pytorch', packages = find_packages(exclude=[]), version = '0.0.26', version = '0.0.28', license='MIT', description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis', author = 'Phil Wang', Loading @@ -20,7 +20,7 @@ setup( ], install_requires=[ 'accelerate', 'audiolm-pytorch>=0.10.4', 'audiolm-pytorch>=0.17.0', 'beartype', 'einops>=0.6', 'lion-pytorch', Loading Loading
README.md +1 −1 Original line number Diff line number Diff line Loading @@ -133,7 +133,7 @@ musiclm = MusicLM( mulan_embed_quantizer = mulan_embed_quantizer ) music = musiclm(['the crystalline sounds of the piano in a ballroom']) # torch.Tensor music = musiclm('the crystalline sounds of the piano in a ballroom', num_samples = 4) # sample 4 and pick the top match with mulan ``` ## Todo Loading
musiclm_pytorch/musiclm_pytorch.py +48 −6 Original line number Diff line number Diff line Loading @@ -23,6 +23,9 @@ from beartype import beartype def exists(val): return val is not None def first(it): return it[0] def default(val, d): return val if exists(val) else d Loading Loading @@ -243,6 +246,8 @@ class AudioSpectrogramTransformer(nn.Module): attn_dropout = 0., ff_mult = 4, ff_dropout = 0., accept_spec = False, accept_spec_time_first = True, spec_n_fft = 128, spec_power = 2, spec_win_length = 24, Loading @@ -268,6 +273,9 @@ class AudioSpectrogramTransformer(nn.Module): nn.LayerNorm(dim) ) self.accept_spec = accept_spec self.accept_spec_time_first = accept_spec_time_first self.spec = Spectrogram( n_fft = spec_n_fft, power = spec_power, Loading Loading @@ -321,7 +329,12 @@ class AudioSpectrogramTransformer(nn.Module): force_no_patch_dropout = False ): batch, device = x.shape[0], x.device assert (self.accept_spec and x.ndim == 3) or (not self.accept_spec and x.ndim == 2) if self.accept_spec and self.accept_spec_time_first: x = rearrange(x, 'b t f -> b f t') if not self.accept_spec: x = self.spec(x) if self.training: Loading Loading @@ -525,18 +538,26 @@ class MuLaN(nn.Module): wavs, texts = None, raw_texts: Optional[List[str]] = None, return_similarities = False return_latents = False, return_similarities = False, return_pairwise_similarities = False ): batch, device = wavs.shape[0], wavs.device audio_latents = self.get_audio_latents(wavs) text_latents = self.get_text_latents(texts, raw_texts = raw_texts) if return_latents: return audio_latents, text_latents if return_similarities: return einsum('i d, i d -> i', audio_latents, text_latents) cosine_sim = einsum('i d, j d -> i j', audio_latents, text_latents) assert cosine_sim.shape[0] == cosine_sim.shape[1], 'batch sizes for audio and text are not equal' if return_similarities: if return_pairwise_similarities: return cosine_sim cosine_sim = cosine_sim * self.temperature.exp() Loading Loading @@ -661,13 +682,34 @@ class MusicLM(nn.Module): @torch.no_grad() def forward( self, raw_texts: List[str], text: str, num_samples = 1, **audio_lm_kwargs ): self.eval() texts = tokenizer.tokenize(raw_texts).to(self.device) texts = tokenizer.tokenize([text]).to(self.device) text_embeds = self.mulan_embed_quantizer(texts = texts) return self.audio_lm(text_embeds = text_embeds, **audio_lm_kwargs) # unable to deal with variable lengthed audio for now samples = [] for _ in range(num_samples): music = self.audio_lm(text_embeds = text_embeds, **audio_lm_kwargs) samples.append(music) # if one sample, just return it if num_samples == 1: return first(samples) mulan = self.mulan_embed_quantizer.mulan # get the one with the highest similarity score, of all the samples sims = torch.cat([mulan(texts = texts, wavs = music, return_similarities = True) for music in samples], dim = 0) top_matching_index = sims.topk(1, dim = 0).indices.item() return samples[top_matching_index]
setup.py +2 −2 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'musiclm-pytorch', packages = find_packages(exclude=[]), version = '0.0.26', version = '0.0.28', license='MIT', description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis', author = 'Phil Wang', Loading @@ -20,7 +20,7 @@ setup( ], install_requires=[ 'accelerate', 'audiolm-pytorch>=0.10.4', 'audiolm-pytorch>=0.17.0', 'beartype', 'einops>=0.6', 'lion-pytorch', Loading