Loading README.md +9 −0 Original line number Diff line number Diff line Loading @@ -211,6 +211,15 @@ music = musiclm('the crystalline sounds of the piano in a ballroom', num_samples } ``` ```bibtex @inproceedings{Shukor2022EfficientVP, title = {Efficient Vision-Language Pretraining with Visual Concepts and Hierarchical Alignment}, author = {Mustafa Shukor and Guillaume Couairon and Matthieu Cord}, booktitle = {British Machine Vision Conference}, year = {2022} } ``` *The only truth is music.* - Jack Kerouac *Music is the universal language of mankind.* - Henry Wadsworth Longfellow musiclm_pytorch/musiclm_pytorch.py +161 −21 Original line number Diff line number Diff line Loading @@ -58,6 +58,16 @@ def log(t, eps = 1e-20): def l2norm(t): return F.normalize(t, p = 2, dim = -1) def matrix_diag(t): device = t.device i, j = t.shape[-2:] num_diag_el = min(i, j) i_range = torch.arange(i, device = device) j_range = torch.arange(j, device = device) diag_mask = rearrange(i_range, 'i -> i 1') == rearrange(j_range, 'j -> 1 j') diag_el = t.masked_select(diag_mask) return rearrange(diag_el, '(b d) -> b d', d = num_diag_el) # 2d sinusoidal positional embedding # simple vit paper shows it is good enough compared to learned Loading @@ -81,13 +91,15 @@ def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32): # biasless layernorm class LayerNorm(nn.Module): def __init__(self, dim): def __init__(self, dim, scale = True): super().__init__() self.gamma = nn.Parameter(torch.ones(dim)) self.register_buffer('beta', torch.zeros(dim)) self.learned_gamma = nn.Parameter(torch.ones(dim)) if scale else None self.register_buffer('gamma', torch.ones(dim), persistent = False) self.register_buffer('beta', torch.zeros(dim), persistent = False) def forward(self, x): return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) return F.layer_norm(x, x.shape[-1:], default(self.learned_gamma, self.gamma), self.beta) # feedforward Loading Loading @@ -221,15 +233,21 @@ class Transformer(nn.Module): self, x, rel_pos_bias = None, mask = None mask = None, return_all_layers = False ): layers = [] for attn, ff in self.layers: x = attn(x, rel_pos_bias = rel_pos_bias, mask = mask) + x x = ff(x) + x layers.append(x) if not return_all_layers: return x return x, torch.stack(layers[:-1]) # Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778 def pair(t): Loading Loading @@ -262,6 +280,7 @@ class AudioSpectrogramTransformer(nn.Module): ): super().__init__() self.dim = dim self.depth = depth self.patch_size = pair(patch_size) patch_input_dim = self.patch_size[0] * self.patch_size[1] Loading Loading @@ -326,7 +345,8 @@ class AudioSpectrogramTransformer(nn.Module): def forward( self, x, force_no_patch_dropout = False force_no_patch_dropout = False, return_all_layers = False ): batch, device = x.shape[0], x.device assert (self.accept_spec and x.ndim == 3) or (not self.accept_spec and x.ndim == 2) Loading Loading @@ -397,13 +417,18 @@ class AudioSpectrogramTransformer(nn.Module): # attention, what else x = self.transformer(x, rel_pos_bias = rel_pos_bias) x, all_layers = self.transformer(x, rel_pos_bias = rel_pos_bias, return_all_layers = True) # final global average and norm (most recent papers show this is superior to CLS token) x = reduce(x, 'b n d -> b d', 'mean') return self.norm(x) out = self.norm(x) if not return_all_layers: return out return out, all_layers # text transformer Loading @@ -428,6 +453,7 @@ class TextTransformer(nn.Module): self.token_emb = nn.Embedding(num_tokens, dim) self.pos_emb = nn.Embedding(max_seq_len, dim) self.depth = depth self.max_seq_len = max_seq_len self.cls_token = nn.Parameter(torch.randn(dim)) Loading @@ -453,7 +479,8 @@ class TextTransformer(nn.Module): self, x = None, raw_texts: Optional[List[str]] = None, mask = None mask = None, return_all_layers = False ): assert exists(x) ^ exists(raw_texts) Loading Loading @@ -484,13 +511,87 @@ class TextTransformer(nn.Module): # attention x = self.transformer(x, mask = mask) x, all_layers = self.transformer(x, mask = mask, return_all_layers = True) # unpack the cls tokens cls_tokens, _ = unpack(x, ps, 'b * d') return self.norm(cls_tokens) out = self.norm(cls_tokens) if not return_all_layers: return out return out, all_layers # hierarchical cl loss def pick_layers_evenly_interspersed(layers, tensor): total_layers, device = tensor.shape[0], tensor.device assert total_layers >= layers step = total_layers / layers indices = (torch.arange(0, layers) * step).floor().long() return tensor[indices] class MultiLayerContrastiveLoss(nn.Module): def __init__( self, *, audio_dim, text_dim, dim_latent, layers, decoupled_contrastive_learning = False ): super().__init__() self.layers = layers self.audio_norm = LayerNorm(audio_dim, scale = False) self.audio_gamma = nn.Parameter(torch.ones(layers, 1, audio_dim)) self.audio_latent_weight = nn.Parameter(torch.randn(layers, audio_dim, dim_latent)) self.audio_latent_bias = nn.Parameter(torch.randn(layers, 1, dim_latent)) self.text_norm = LayerNorm(text_dim, scale = False) self.text_gamma = nn.Parameter(torch.ones(layers, 1, text_dim)) self.text_latent_weight = nn.Parameter(torch.randn(layers, text_dim, dim_latent)) self.text_latent_bias = nn.Parameter(torch.randn(layers, 1, dim_latent)) self.temperatures = nn.Parameter(torch.ones(layers, 1, 1)) self.decoupled_contrastive_learning = decoupled_contrastive_learning def forward(self, *, audio_layers, text_layers): device, batch = audio_layers.device, audio_layers.shape[1] audio_layers = pick_layers_evenly_interspersed(self.layers, audio_layers) text_layers = pick_layers_evenly_interspersed(self.layers, text_layers) audio_gap = reduce(audio_layers, 'l b n d -> l b d', 'mean') audio_embeds = self.audio_norm(audio_gap) * self.audio_gamma audio_latents = einsum('l b d, l d e -> l b e', audio_embeds, self.audio_latent_weight) + self.audio_latent_bias audio_latents = l2norm(audio_latents) text_cls_tokens = text_layers[:, :, 0] text_embeds = self.text_norm(text_cls_tokens) * self.text_gamma text_latents = einsum('l b d, l d e -> l b e', text_embeds, self.text_latent_weight) + self.text_latent_bias text_latents = l2norm(text_latents) cosine_sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents) * self.temperatures.exp() cosine_sims_exp = cosine_sims.exp() numerator = matrix_diag(cosine_sims_exp) if self.decoupled_contrastive_learning: eye = torch.eye(batch, device = device, dtype = torch.bool) cosine_sims_exp = cosine_sims_exp.masked_fill(eye, 0.) denominator = reduce(cosine_sims_exp, 'l i j -> l i', 'sum') contrastive_loss = -log(numerator) + log(denominator) contrastive_loss = reduce(contrastive_loss, 'l i -> l', 'mean') return contrastive_loss.sum() # main classes Loading @@ -502,6 +603,7 @@ class MuLaN(nn.Module): text_transformer: TextTransformer, dim_latent = 128, # they use 128 decoupled_contrastive_learning = True, # think this was used, make it optional hierarchical_contrastive_loss = False ): super().__init__() self.dim_latent = dim_latent Loading @@ -516,22 +618,48 @@ class MuLaN(nn.Module): self.decoupled_contrastive_learning = decoupled_contrastive_learning self.multi_layer_contrastive_learning = None if hierarchical_contrastive_loss: num_layers = min(audio_transformer.depth, text_transformer.depth) - 1 assert num_layers > 0 self.multi_layer_contrastive_learning = MultiLayerContrastiveLoss( audio_dim = self.audio.dim, text_dim = self.text.dim, dim_latent = dim_latent, layers = num_layers, decoupled_contrastive_learning = decoupled_contrastive_learning ) def get_audio_latents( self, wavs wavs, return_all_layers = False ): audio_embeds = self.audio(wavs) audio_embeds, audio_layers = self.audio(wavs, return_all_layers = True) audio_latents = self.audio_to_latents(audio_embeds) return l2norm(audio_latents) out = l2norm(audio_latents) if not return_all_layers: return out return out, audio_layers def get_text_latents( self, texts = None, raw_texts: Optional[List[str]] = None raw_texts: Optional[List[str]] = None, return_all_layers = False ): text_embeds = self.text(texts, raw_texts = raw_texts) text_embeds, text_layers = self.text(texts, raw_texts = raw_texts, return_all_layers = True) text_latents = self.text_to_latents(text_embeds) return l2norm(text_latents) out = l2norm(text_latents) if not return_all_layers: return out return out, text_layers def forward( self, Loading @@ -544,8 +672,8 @@ class MuLaN(nn.Module): ): batch, device = wavs.shape[0], wavs.device audio_latents = self.get_audio_latents(wavs) text_latents = self.get_text_latents(texts, raw_texts = raw_texts) audio_latents, audio_layers = self.get_audio_latents(wavs, return_all_layers = True) text_latents, text_layers = self.get_text_latents(texts, raw_texts = raw_texts, return_all_layers = True) if return_latents: return audio_latents, text_latents Loading Loading @@ -573,7 +701,19 @@ class MuLaN(nn.Module): denominator = reduce(cosine_sim_exp, 'i j -> i', 'sum') contrastive_loss = -log(numerator) + log(denominator) return contrastive_loss.mean() cl_loss = contrastive_loss.mean() if not exists(self.multi_layer_contrastive_learning): return cl_loss # whether to do cl loss across all layers, from ViCHA paper https://arxiv.org/abs/2208.13628 hierarchical_cl_loss = self.multi_layer_contrastive_learning( audio_layers = audio_layers, text_layers = text_layers ) return cl_loss + hierarchical_cl_loss # music lm Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'musiclm-pytorch', packages = find_packages(exclude=[]), version = '0.0.28', version = '0.1.0', license='MIT', description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis', author = 'Phil Wang', Loading Loading
README.md +9 −0 Original line number Diff line number Diff line Loading @@ -211,6 +211,15 @@ music = musiclm('the crystalline sounds of the piano in a ballroom', num_samples } ``` ```bibtex @inproceedings{Shukor2022EfficientVP, title = {Efficient Vision-Language Pretraining with Visual Concepts and Hierarchical Alignment}, author = {Mustafa Shukor and Guillaume Couairon and Matthieu Cord}, booktitle = {British Machine Vision Conference}, year = {2022} } ``` *The only truth is music.* - Jack Kerouac *Music is the universal language of mankind.* - Henry Wadsworth Longfellow
musiclm_pytorch/musiclm_pytorch.py +161 −21 Original line number Diff line number Diff line Loading @@ -58,6 +58,16 @@ def log(t, eps = 1e-20): def l2norm(t): return F.normalize(t, p = 2, dim = -1) def matrix_diag(t): device = t.device i, j = t.shape[-2:] num_diag_el = min(i, j) i_range = torch.arange(i, device = device) j_range = torch.arange(j, device = device) diag_mask = rearrange(i_range, 'i -> i 1') == rearrange(j_range, 'j -> 1 j') diag_el = t.masked_select(diag_mask) return rearrange(diag_el, '(b d) -> b d', d = num_diag_el) # 2d sinusoidal positional embedding # simple vit paper shows it is good enough compared to learned Loading @@ -81,13 +91,15 @@ def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32): # biasless layernorm class LayerNorm(nn.Module): def __init__(self, dim): def __init__(self, dim, scale = True): super().__init__() self.gamma = nn.Parameter(torch.ones(dim)) self.register_buffer('beta', torch.zeros(dim)) self.learned_gamma = nn.Parameter(torch.ones(dim)) if scale else None self.register_buffer('gamma', torch.ones(dim), persistent = False) self.register_buffer('beta', torch.zeros(dim), persistent = False) def forward(self, x): return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) return F.layer_norm(x, x.shape[-1:], default(self.learned_gamma, self.gamma), self.beta) # feedforward Loading Loading @@ -221,15 +233,21 @@ class Transformer(nn.Module): self, x, rel_pos_bias = None, mask = None mask = None, return_all_layers = False ): layers = [] for attn, ff in self.layers: x = attn(x, rel_pos_bias = rel_pos_bias, mask = mask) + x x = ff(x) + x layers.append(x) if not return_all_layers: return x return x, torch.stack(layers[:-1]) # Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778 def pair(t): Loading Loading @@ -262,6 +280,7 @@ class AudioSpectrogramTransformer(nn.Module): ): super().__init__() self.dim = dim self.depth = depth self.patch_size = pair(patch_size) patch_input_dim = self.patch_size[0] * self.patch_size[1] Loading Loading @@ -326,7 +345,8 @@ class AudioSpectrogramTransformer(nn.Module): def forward( self, x, force_no_patch_dropout = False force_no_patch_dropout = False, return_all_layers = False ): batch, device = x.shape[0], x.device assert (self.accept_spec and x.ndim == 3) or (not self.accept_spec and x.ndim == 2) Loading Loading @@ -397,13 +417,18 @@ class AudioSpectrogramTransformer(nn.Module): # attention, what else x = self.transformer(x, rel_pos_bias = rel_pos_bias) x, all_layers = self.transformer(x, rel_pos_bias = rel_pos_bias, return_all_layers = True) # final global average and norm (most recent papers show this is superior to CLS token) x = reduce(x, 'b n d -> b d', 'mean') return self.norm(x) out = self.norm(x) if not return_all_layers: return out return out, all_layers # text transformer Loading @@ -428,6 +453,7 @@ class TextTransformer(nn.Module): self.token_emb = nn.Embedding(num_tokens, dim) self.pos_emb = nn.Embedding(max_seq_len, dim) self.depth = depth self.max_seq_len = max_seq_len self.cls_token = nn.Parameter(torch.randn(dim)) Loading @@ -453,7 +479,8 @@ class TextTransformer(nn.Module): self, x = None, raw_texts: Optional[List[str]] = None, mask = None mask = None, return_all_layers = False ): assert exists(x) ^ exists(raw_texts) Loading Loading @@ -484,13 +511,87 @@ class TextTransformer(nn.Module): # attention x = self.transformer(x, mask = mask) x, all_layers = self.transformer(x, mask = mask, return_all_layers = True) # unpack the cls tokens cls_tokens, _ = unpack(x, ps, 'b * d') return self.norm(cls_tokens) out = self.norm(cls_tokens) if not return_all_layers: return out return out, all_layers # hierarchical cl loss def pick_layers_evenly_interspersed(layers, tensor): total_layers, device = tensor.shape[0], tensor.device assert total_layers >= layers step = total_layers / layers indices = (torch.arange(0, layers) * step).floor().long() return tensor[indices] class MultiLayerContrastiveLoss(nn.Module): def __init__( self, *, audio_dim, text_dim, dim_latent, layers, decoupled_contrastive_learning = False ): super().__init__() self.layers = layers self.audio_norm = LayerNorm(audio_dim, scale = False) self.audio_gamma = nn.Parameter(torch.ones(layers, 1, audio_dim)) self.audio_latent_weight = nn.Parameter(torch.randn(layers, audio_dim, dim_latent)) self.audio_latent_bias = nn.Parameter(torch.randn(layers, 1, dim_latent)) self.text_norm = LayerNorm(text_dim, scale = False) self.text_gamma = nn.Parameter(torch.ones(layers, 1, text_dim)) self.text_latent_weight = nn.Parameter(torch.randn(layers, text_dim, dim_latent)) self.text_latent_bias = nn.Parameter(torch.randn(layers, 1, dim_latent)) self.temperatures = nn.Parameter(torch.ones(layers, 1, 1)) self.decoupled_contrastive_learning = decoupled_contrastive_learning def forward(self, *, audio_layers, text_layers): device, batch = audio_layers.device, audio_layers.shape[1] audio_layers = pick_layers_evenly_interspersed(self.layers, audio_layers) text_layers = pick_layers_evenly_interspersed(self.layers, text_layers) audio_gap = reduce(audio_layers, 'l b n d -> l b d', 'mean') audio_embeds = self.audio_norm(audio_gap) * self.audio_gamma audio_latents = einsum('l b d, l d e -> l b e', audio_embeds, self.audio_latent_weight) + self.audio_latent_bias audio_latents = l2norm(audio_latents) text_cls_tokens = text_layers[:, :, 0] text_embeds = self.text_norm(text_cls_tokens) * self.text_gamma text_latents = einsum('l b d, l d e -> l b e', text_embeds, self.text_latent_weight) + self.text_latent_bias text_latents = l2norm(text_latents) cosine_sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents) * self.temperatures.exp() cosine_sims_exp = cosine_sims.exp() numerator = matrix_diag(cosine_sims_exp) if self.decoupled_contrastive_learning: eye = torch.eye(batch, device = device, dtype = torch.bool) cosine_sims_exp = cosine_sims_exp.masked_fill(eye, 0.) denominator = reduce(cosine_sims_exp, 'l i j -> l i', 'sum') contrastive_loss = -log(numerator) + log(denominator) contrastive_loss = reduce(contrastive_loss, 'l i -> l', 'mean') return contrastive_loss.sum() # main classes Loading @@ -502,6 +603,7 @@ class MuLaN(nn.Module): text_transformer: TextTransformer, dim_latent = 128, # they use 128 decoupled_contrastive_learning = True, # think this was used, make it optional hierarchical_contrastive_loss = False ): super().__init__() self.dim_latent = dim_latent Loading @@ -516,22 +618,48 @@ class MuLaN(nn.Module): self.decoupled_contrastive_learning = decoupled_contrastive_learning self.multi_layer_contrastive_learning = None if hierarchical_contrastive_loss: num_layers = min(audio_transformer.depth, text_transformer.depth) - 1 assert num_layers > 0 self.multi_layer_contrastive_learning = MultiLayerContrastiveLoss( audio_dim = self.audio.dim, text_dim = self.text.dim, dim_latent = dim_latent, layers = num_layers, decoupled_contrastive_learning = decoupled_contrastive_learning ) def get_audio_latents( self, wavs wavs, return_all_layers = False ): audio_embeds = self.audio(wavs) audio_embeds, audio_layers = self.audio(wavs, return_all_layers = True) audio_latents = self.audio_to_latents(audio_embeds) return l2norm(audio_latents) out = l2norm(audio_latents) if not return_all_layers: return out return out, audio_layers def get_text_latents( self, texts = None, raw_texts: Optional[List[str]] = None raw_texts: Optional[List[str]] = None, return_all_layers = False ): text_embeds = self.text(texts, raw_texts = raw_texts) text_embeds, text_layers = self.text(texts, raw_texts = raw_texts, return_all_layers = True) text_latents = self.text_to_latents(text_embeds) return l2norm(text_latents) out = l2norm(text_latents) if not return_all_layers: return out return out, text_layers def forward( self, Loading @@ -544,8 +672,8 @@ class MuLaN(nn.Module): ): batch, device = wavs.shape[0], wavs.device audio_latents = self.get_audio_latents(wavs) text_latents = self.get_text_latents(texts, raw_texts = raw_texts) audio_latents, audio_layers = self.get_audio_latents(wavs, return_all_layers = True) text_latents, text_layers = self.get_text_latents(texts, raw_texts = raw_texts, return_all_layers = True) if return_latents: return audio_latents, text_latents Loading Loading @@ -573,7 +701,19 @@ class MuLaN(nn.Module): denominator = reduce(cosine_sim_exp, 'i j -> i', 'sum') contrastive_loss = -log(numerator) + log(denominator) return contrastive_loss.mean() cl_loss = contrastive_loss.mean() if not exists(self.multi_layer_contrastive_learning): return cl_loss # whether to do cl loss across all layers, from ViCHA paper https://arxiv.org/abs/2208.13628 hierarchical_cl_loss = self.multi_layer_contrastive_learning( audio_layers = audio_layers, text_layers = text_layers ) return cl_loss + hierarchical_cl_loss # music lm Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'musiclm-pytorch', packages = find_packages(exclude=[]), version = '0.0.28', version = '0.1.0', license='MIT', description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis', author = 'Phil Wang', Loading