Loading README.md +8 −0 Original line number Diff line number Diff line Loading @@ -220,6 +220,14 @@ music = musiclm('the crystalline sounds of the piano in a ballroom', num_samples } ``` ```bibtex @inproceedings{Zhai2023SigmoidLF, title = {Sigmoid Loss for Language Image Pre-Training}, author = {Xiaohua Zhai and Basil Mustafa and Alexander Kolesnikov and Lucas Beyer}, year = {2023} } ``` *The only truth is music.* - Jack Kerouac *Music is the universal language of mankind.* - Henry Wadsworth Longfellow musiclm_pytorch/musiclm_pytorch.py +84 −39 Original line number Diff line number Diff line import math from functools import wraps import torch Loading Loading @@ -248,6 +249,76 @@ class Transformer(nn.Module): return x, torch.stack(layers[:-1]) # contrastive losses class SoftmaxContrastiveLearning(nn.Module): def __init__( self, *, layers = 1, decoupled_contrastive_learning = False, init_temp = 10 ): super().__init__() self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp)) self.decoupled_contrastive_learning = decoupled_contrastive_learning @property def device(self): return next(self.parameters()).device def forward(self, sims): batch = sims.shape[-1] if sims.ndim == 2: sims = rearrange(sims, 'i j -> 1 i j') sims = sims * self.temperatures.exp() cosine_sims_exp = sims.exp() numerator = matrix_diag(cosine_sims_exp) if self.decoupled_contrastive_learning: eye = torch.eye(batch, device = self.device, dtype = torch.bool) cosine_sims_exp = cosine_sims_exp.masked_fill(eye, 0.) denominator_i = reduce(cosine_sims_exp, 'l i j -> l i', 'sum') denominator_j = reduce(cosine_sims_exp, 'l i j -> l j', 'sum') contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j)) contrastive_loss = reduce(contrastive_loss, 'l n -> l', 'mean') return contrastive_loss.sum() class SigmoidContrastiveLearning(nn.Module): """ https://arxiv.org/abs/2303.15343 """ def __init__( self, *, layers = 1, init_temp = 10, init_bias = -10 ): super().__init__() self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp)) self.bias = nn.Parameter(torch.ones(layers, 1, 1) * init_bias) @property def device(self): return next(self.parameters()).device def forward(self, sims): if sims.ndim == 2: sims = rearrange(sims, 'i j -> 1 i j') n = sims.shape[-1] sims = sims * self.temperatures.exp() + self.bias labels = 2 * rearrange(torch.eye(n), 'i j -> 1 i j') - torch.ones_like(sims) return -F.logsigmoid(labels * sims).sum() / n # Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778 def pair(t): Loading Loading @@ -539,7 +610,8 @@ class MultiLayerContrastiveLoss(nn.Module): text_dim, dim_latent, layers, decoupled_contrastive_learning = False decoupled_contrastive_learning = False, sigmoid_contrastive_loss = False ): super().__init__() self.layers = layers Loading @@ -554,9 +626,8 @@ class MultiLayerContrastiveLoss(nn.Module): self.text_latent_weight = nn.Parameter(torch.randn(layers, text_dim, dim_latent)) self.text_latent_bias = nn.Parameter(torch.randn(layers, 1, dim_latent)) self.temperatures = nn.Parameter(torch.ones(layers, 1, 1)) self.decoupled_contrastive_learning = decoupled_contrastive_learning klass = SigmoidContrastiveLearning if sigmoid_contrastive_loss else partial(SoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning) self.contrast = klass(layers = layers) def forward(self, *, audio_layers, text_layers): device, batch = audio_layers.device, audio_layers.shape[1] Loading @@ -571,23 +642,9 @@ class MultiLayerContrastiveLoss(nn.Module): text_latents = einsum('l b d, l d e -> l b e', text_embeds, self.text_latent_weight) + self.text_latent_bias text_latents = l2norm(text_latents) cosine_sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents) * self.temperatures.exp() cosine_sims_exp = cosine_sims.exp() cosine_sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents) numerator = matrix_diag(cosine_sims_exp) if self.decoupled_contrastive_learning: eye = torch.eye(batch, device = device, dtype = torch.bool) cosine_sims_exp = cosine_sims_exp.masked_fill(eye, 0.) denominator_i = reduce(cosine_sims_exp, 'l i j -> l i', 'sum') denominator_j = reduce(cosine_sims_exp, 'l i j -> l j', 'sum') contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j)) contrastive_loss = reduce(contrastive_loss, 'l n -> l', 'mean') return contrastive_loss.sum() return self.contrast(cosine_sims) # main classes Loading @@ -600,7 +657,8 @@ class MuLaN(nn.Module): dim_latent = 128, # they use 128 decoupled_contrastive_learning = True, # think this was used, make it optional hierarchical_contrastive_loss = False, hierarchical_contrastive_loss_layers = None hierarchical_contrastive_loss_layers = None, sigmoid_contrastive_loss = False ): super().__init__() self.dim_latent = dim_latent Loading @@ -608,12 +666,12 @@ class MuLaN(nn.Module): self.audio = audio_transformer self.text = text_transformer self.temperature = nn.Parameter(torch.tensor(1.)) self.text_to_latents = nn.Linear(self.text.dim, dim_latent) self.audio_to_latents = nn.Linear(self.audio.dim, dim_latent) self.decoupled_contrastive_learning = decoupled_contrastive_learning klass = SigmoidContrastiveLearning if sigmoid_contrastive_loss else partial(SoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning) self.contrast = klass() self.multi_layer_contrastive_learning = None Loading @@ -629,7 +687,8 @@ class MuLaN(nn.Module): text_dim = self.text.dim, dim_latent = dim_latent, layers = num_layers, decoupled_contrastive_learning = decoupled_contrastive_learning decoupled_contrastive_learning = decoupled_contrastive_learning, sigmoid_contrastive_loss = sigmoid_contrastive_loss ) def get_audio_latents( Loading Loading @@ -688,21 +747,7 @@ class MuLaN(nn.Module): if return_pairwise_similarities: return cosine_sim cosine_sim = cosine_sim * self.temperature.exp() cosine_sim_exp = cosine_sim.exp() numerator = cosine_sim_exp.diag() if self.decoupled_contrastive_learning: eye = torch.eye(batch, device = device, dtype = torch.bool) cosine_sim_exp = cosine_sim_exp.masked_fill(eye, 0.) denominator_i = reduce(cosine_sim_exp, 'i j -> i', 'sum') denominator_j = reduce(cosine_sim_exp, 'i j -> j', 'sum') contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j)) cl_loss = contrastive_loss.mean() cl_loss = self.contrast(cosine_sim) if not exists(self.multi_layer_contrastive_learning): return cl_loss Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'musiclm-pytorch', packages = find_packages(exclude=[]), version = '0.1.2', version = '0.2.0', license='MIT', description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis', author = 'Phil Wang', Loading Loading
README.md +8 −0 Original line number Diff line number Diff line Loading @@ -220,6 +220,14 @@ music = musiclm('the crystalline sounds of the piano in a ballroom', num_samples } ``` ```bibtex @inproceedings{Zhai2023SigmoidLF, title = {Sigmoid Loss for Language Image Pre-Training}, author = {Xiaohua Zhai and Basil Mustafa and Alexander Kolesnikov and Lucas Beyer}, year = {2023} } ``` *The only truth is music.* - Jack Kerouac *Music is the universal language of mankind.* - Henry Wadsworth Longfellow
musiclm_pytorch/musiclm_pytorch.py +84 −39 Original line number Diff line number Diff line import math from functools import wraps import torch Loading Loading @@ -248,6 +249,76 @@ class Transformer(nn.Module): return x, torch.stack(layers[:-1]) # contrastive losses class SoftmaxContrastiveLearning(nn.Module): def __init__( self, *, layers = 1, decoupled_contrastive_learning = False, init_temp = 10 ): super().__init__() self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp)) self.decoupled_contrastive_learning = decoupled_contrastive_learning @property def device(self): return next(self.parameters()).device def forward(self, sims): batch = sims.shape[-1] if sims.ndim == 2: sims = rearrange(sims, 'i j -> 1 i j') sims = sims * self.temperatures.exp() cosine_sims_exp = sims.exp() numerator = matrix_diag(cosine_sims_exp) if self.decoupled_contrastive_learning: eye = torch.eye(batch, device = self.device, dtype = torch.bool) cosine_sims_exp = cosine_sims_exp.masked_fill(eye, 0.) denominator_i = reduce(cosine_sims_exp, 'l i j -> l i', 'sum') denominator_j = reduce(cosine_sims_exp, 'l i j -> l j', 'sum') contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j)) contrastive_loss = reduce(contrastive_loss, 'l n -> l', 'mean') return contrastive_loss.sum() class SigmoidContrastiveLearning(nn.Module): """ https://arxiv.org/abs/2303.15343 """ def __init__( self, *, layers = 1, init_temp = 10, init_bias = -10 ): super().__init__() self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp)) self.bias = nn.Parameter(torch.ones(layers, 1, 1) * init_bias) @property def device(self): return next(self.parameters()).device def forward(self, sims): if sims.ndim == 2: sims = rearrange(sims, 'i j -> 1 i j') n = sims.shape[-1] sims = sims * self.temperatures.exp() + self.bias labels = 2 * rearrange(torch.eye(n), 'i j -> 1 i j') - torch.ones_like(sims) return -F.logsigmoid(labels * sims).sum() / n # Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778 def pair(t): Loading Loading @@ -539,7 +610,8 @@ class MultiLayerContrastiveLoss(nn.Module): text_dim, dim_latent, layers, decoupled_contrastive_learning = False decoupled_contrastive_learning = False, sigmoid_contrastive_loss = False ): super().__init__() self.layers = layers Loading @@ -554,9 +626,8 @@ class MultiLayerContrastiveLoss(nn.Module): self.text_latent_weight = nn.Parameter(torch.randn(layers, text_dim, dim_latent)) self.text_latent_bias = nn.Parameter(torch.randn(layers, 1, dim_latent)) self.temperatures = nn.Parameter(torch.ones(layers, 1, 1)) self.decoupled_contrastive_learning = decoupled_contrastive_learning klass = SigmoidContrastiveLearning if sigmoid_contrastive_loss else partial(SoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning) self.contrast = klass(layers = layers) def forward(self, *, audio_layers, text_layers): device, batch = audio_layers.device, audio_layers.shape[1] Loading @@ -571,23 +642,9 @@ class MultiLayerContrastiveLoss(nn.Module): text_latents = einsum('l b d, l d e -> l b e', text_embeds, self.text_latent_weight) + self.text_latent_bias text_latents = l2norm(text_latents) cosine_sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents) * self.temperatures.exp() cosine_sims_exp = cosine_sims.exp() cosine_sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents) numerator = matrix_diag(cosine_sims_exp) if self.decoupled_contrastive_learning: eye = torch.eye(batch, device = device, dtype = torch.bool) cosine_sims_exp = cosine_sims_exp.masked_fill(eye, 0.) denominator_i = reduce(cosine_sims_exp, 'l i j -> l i', 'sum') denominator_j = reduce(cosine_sims_exp, 'l i j -> l j', 'sum') contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j)) contrastive_loss = reduce(contrastive_loss, 'l n -> l', 'mean') return contrastive_loss.sum() return self.contrast(cosine_sims) # main classes Loading @@ -600,7 +657,8 @@ class MuLaN(nn.Module): dim_latent = 128, # they use 128 decoupled_contrastive_learning = True, # think this was used, make it optional hierarchical_contrastive_loss = False, hierarchical_contrastive_loss_layers = None hierarchical_contrastive_loss_layers = None, sigmoid_contrastive_loss = False ): super().__init__() self.dim_latent = dim_latent Loading @@ -608,12 +666,12 @@ class MuLaN(nn.Module): self.audio = audio_transformer self.text = text_transformer self.temperature = nn.Parameter(torch.tensor(1.)) self.text_to_latents = nn.Linear(self.text.dim, dim_latent) self.audio_to_latents = nn.Linear(self.audio.dim, dim_latent) self.decoupled_contrastive_learning = decoupled_contrastive_learning klass = SigmoidContrastiveLearning if sigmoid_contrastive_loss else partial(SoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning) self.contrast = klass() self.multi_layer_contrastive_learning = None Loading @@ -629,7 +687,8 @@ class MuLaN(nn.Module): text_dim = self.text.dim, dim_latent = dim_latent, layers = num_layers, decoupled_contrastive_learning = decoupled_contrastive_learning decoupled_contrastive_learning = decoupled_contrastive_learning, sigmoid_contrastive_loss = sigmoid_contrastive_loss ) def get_audio_latents( Loading Loading @@ -688,21 +747,7 @@ class MuLaN(nn.Module): if return_pairwise_similarities: return cosine_sim cosine_sim = cosine_sim * self.temperature.exp() cosine_sim_exp = cosine_sim.exp() numerator = cosine_sim_exp.diag() if self.decoupled_contrastive_learning: eye = torch.eye(batch, device = device, dtype = torch.bool) cosine_sim_exp = cosine_sim_exp.masked_fill(eye, 0.) denominator_i = reduce(cosine_sim_exp, 'i j -> i', 'sum') denominator_j = reduce(cosine_sim_exp, 'i j -> j', 'sum') contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j)) cl_loss = contrastive_loss.mean() cl_loss = self.contrast(cosine_sim) if not exists(self.multi_layer_contrastive_learning): return cl_loss Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'musiclm-pytorch', packages = find_packages(exclude=[]), version = '0.1.2', version = '0.2.0', license='MIT', description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis', author = 'Phil Wang', Loading