Commit 80ad8ccf authored by Phil Wang's avatar Phil Wang
Browse files

add sigmoid contrastive loss

parent b95e54b2
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -220,6 +220,14 @@ music = musiclm('the crystalline sounds of the piano in a ballroom', num_samples
}
```

```bibtex
@inproceedings{Zhai2023SigmoidLF,
    title   = {Sigmoid Loss for Language Image Pre-Training},
    author  = {Xiaohua Zhai and Basil Mustafa and Alexander Kolesnikov and Lucas Beyer},
    year    = {2023}
}
```

*The only truth is music.* - Jack Kerouac

*Music is the universal language of mankind.* - Henry Wadsworth Longfellow
+84 −39
Original line number Diff line number Diff line
import math
from functools import wraps

import torch
@@ -248,6 +249,76 @@ class Transformer(nn.Module):

        return x, torch.stack(layers[:-1])

# contrastive losses

class SoftmaxContrastiveLearning(nn.Module):
    def __init__(
        self,
        *,
        layers = 1,
        decoupled_contrastive_learning = False,
        init_temp = 10
    ):
        super().__init__()
        self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
        self.decoupled_contrastive_learning = decoupled_contrastive_learning

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, sims):
        batch = sims.shape[-1]

        if sims.ndim == 2:
            sims = rearrange(sims, 'i j -> 1 i j')

        sims = sims * self.temperatures.exp()

        cosine_sims_exp = sims.exp()

        numerator = matrix_diag(cosine_sims_exp)

        if self.decoupled_contrastive_learning:
            eye = torch.eye(batch, device = self.device, dtype = torch.bool)
            cosine_sims_exp = cosine_sims_exp.masked_fill(eye, 0.)

        denominator_i = reduce(cosine_sims_exp, 'l i j -> l i', 'sum')
        denominator_j = reduce(cosine_sims_exp, 'l i j -> l j', 'sum')

        contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j))

        contrastive_loss = reduce(contrastive_loss, 'l n -> l', 'mean')
        return contrastive_loss.sum()

class SigmoidContrastiveLearning(nn.Module):
    """ https://arxiv.org/abs/2303.15343 """

    def __init__(
        self,
        *,
        layers = 1,
        init_temp = 10,
        init_bias = -10
    ):
        super().__init__()
        self.temperatures = nn.Parameter(torch.ones(layers, 1, 1) * math.log(init_temp))
        self.bias = nn.Parameter(torch.ones(layers, 1, 1) * init_bias)

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, sims):
        if sims.ndim == 2:
            sims = rearrange(sims, 'i j -> 1 i j')

        n = sims.shape[-1]
        sims = sims * self.temperatures.exp() + self.bias
        labels = 2 * rearrange(torch.eye(n), 'i j -> 1 i j') - torch.ones_like(sims)

        return -F.logsigmoid(labels * sims).sum() / n

# Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778

def pair(t):
@@ -539,7 +610,8 @@ class MultiLayerContrastiveLoss(nn.Module):
        text_dim,
        dim_latent,
        layers,
        decoupled_contrastive_learning = False
        decoupled_contrastive_learning = False,
        sigmoid_contrastive_loss = False
    ):
        super().__init__()
        self.layers = layers
@@ -554,9 +626,8 @@ class MultiLayerContrastiveLoss(nn.Module):
        self.text_latent_weight = nn.Parameter(torch.randn(layers, text_dim, dim_latent))
        self.text_latent_bias = nn.Parameter(torch.randn(layers, 1, dim_latent))

        self.temperatures = nn.Parameter(torch.ones(layers, 1, 1))

        self.decoupled_contrastive_learning = decoupled_contrastive_learning
        klass = SigmoidContrastiveLearning if sigmoid_contrastive_loss else partial(SoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning)
        self.contrast = klass(layers = layers)

    def forward(self, *, audio_layers, text_layers):
        device, batch = audio_layers.device, audio_layers.shape[1]
@@ -571,23 +642,9 @@ class MultiLayerContrastiveLoss(nn.Module):
        text_latents = einsum('l b d, l d e -> l b e', text_embeds, self.text_latent_weight) + self.text_latent_bias
        text_latents = l2norm(text_latents)

        cosine_sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents) * self.temperatures.exp()

        cosine_sims_exp = cosine_sims.exp()
        cosine_sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents)

        numerator = matrix_diag(cosine_sims_exp)

        if self.decoupled_contrastive_learning:
            eye = torch.eye(batch, device = device, dtype = torch.bool)
            cosine_sims_exp = cosine_sims_exp.masked_fill(eye, 0.)

        denominator_i = reduce(cosine_sims_exp, 'l i j -> l i', 'sum')
        denominator_j = reduce(cosine_sims_exp, 'l i j -> l j', 'sum')

        contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j))

        contrastive_loss = reduce(contrastive_loss, 'l n -> l', 'mean')
        return contrastive_loss.sum()
        return self.contrast(cosine_sims)

# main classes

@@ -600,7 +657,8 @@ class MuLaN(nn.Module):
        dim_latent = 128,                       # they use 128
        decoupled_contrastive_learning = True,  # think this was used, make it optional
        hierarchical_contrastive_loss = False,
        hierarchical_contrastive_loss_layers = None
        hierarchical_contrastive_loss_layers = None,
        sigmoid_contrastive_loss = False
    ):
        super().__init__()
        self.dim_latent = dim_latent
@@ -608,12 +666,12 @@ class MuLaN(nn.Module):
        self.audio = audio_transformer
        self.text = text_transformer

        self.temperature = nn.Parameter(torch.tensor(1.))

        self.text_to_latents = nn.Linear(self.text.dim, dim_latent)
        self.audio_to_latents = nn.Linear(self.audio.dim, dim_latent)

        self.decoupled_contrastive_learning = decoupled_contrastive_learning
        klass = SigmoidContrastiveLearning if sigmoid_contrastive_loss else partial(SoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning)
        self.contrast = klass()

        self.multi_layer_contrastive_learning = None

@@ -629,7 +687,8 @@ class MuLaN(nn.Module):
                text_dim = self.text.dim,
                dim_latent = dim_latent,
                layers = num_layers,
                decoupled_contrastive_learning = decoupled_contrastive_learning
                decoupled_contrastive_learning = decoupled_contrastive_learning,
                sigmoid_contrastive_loss = sigmoid_contrastive_loss
            )

    def get_audio_latents(
@@ -688,21 +747,7 @@ class MuLaN(nn.Module):
        if return_pairwise_similarities:
            return cosine_sim

        cosine_sim = cosine_sim * self.temperature.exp()

        cosine_sim_exp = cosine_sim.exp()

        numerator = cosine_sim_exp.diag()

        if self.decoupled_contrastive_learning:
            eye = torch.eye(batch, device = device, dtype = torch.bool)
            cosine_sim_exp = cosine_sim_exp.masked_fill(eye, 0.)

        denominator_i = reduce(cosine_sim_exp, 'i j -> i', 'sum')
        denominator_j = reduce(cosine_sim_exp, 'i j -> j', 'sum')

        contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j))
        cl_loss = contrastive_loss.mean()
        cl_loss = self.contrast(cosine_sim)

        if not exists(self.multi_layer_contrastive_learning):
            return cl_loss
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'musiclm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.1.2',
  version = '0.2.0',
  license='MIT',
  description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
  author = 'Phil Wang',