Commit ff4c75f8 authored by Phil Wang's avatar Phil Wang
Browse files

add decoupled contrastive learning for mulan as an option

parent 022860c9
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -51,12 +51,13 @@ loss.backward()

## Todo

- [x] mulan seems to be using decoupled contrastive learning, offer that as an option

- [ ] wrap mulan with mulan wrapper and quantize the output, project to audiolm dimensions
- [ ] modify audiolm to accept conditioning embeddings, optionally take care of different dimensions through a separate projection
- [ ] audiolm and mulan goes into musiclm and generate, filter with mulan
- [ ] add a version of mulan to <a href="https://github.com/mlfoundations/open_clip">open clip</a>
- [ ] set all the proper spectrogram hyperparameters
- [ ] mulan seems to be using decoupled contrastive learning, offer that as an option
- [ ] email some contrastive learning experts and figure out why some papers are sharing the projection from embeddings to latent space

## Appreciation
+19 −12
Original line number Diff line number Diff line
@@ -375,7 +375,8 @@ class MuLaN(nn.Module):
        self,
        audio_transformer: AudioSpectrogramTransformer,
        text_transformer: TextTransformer,
        dim_latent = 128 # they use 128
        dim_latent = 128,                       # they use 128
        decoupled_contrastive_learning = True,  # think this was used, make it optional
    ):
        super().__init__()
        self.dim_latent = dim_latent
@@ -388,6 +389,8 @@ class MuLaN(nn.Module):
        self.text_to_latents = nn.Linear(self.text.dim, dim_latent)
        self.audio_to_latents = nn.Linear(self.audio.dim, dim_latent)

        self.decoupled_contrastive_learning = decoupled_contrastive_learning

    def get_audio_latents(
        self,
        wavs
@@ -398,7 +401,8 @@ class MuLaN(nn.Module):

    def get_text_latents(
        self,
        texts,
        texts = None,
        raw_texts: Optional[List[str]] = None
    ):
        text_embeds = self.text(texts)
        text_latents = self.text_to_latents(text_embeds)
@@ -413,13 +417,8 @@ class MuLaN(nn.Module):
    ):
        batch, device = wavs.shape[0], wavs.device

        audio_embeds = self.audio(wavs)
        text_embeds = self.text(texts, raw_texts = raw_texts)

        audio_latents = self.audio_to_latents(audio_embeds)
        text_latents = self.text_to_latents(text_embeds)

        audio_latents, text_latents = map(l2norm, (audio_latents, text_latents))
        audio_latents = self.get_audio_latents(wavs)
        text_latents = self.get_text_latents(texts, raw_texts = raw_texts)

        cosine_sim = einsum('i d, j d -> i j', audio_latents, text_latents)

@@ -430,10 +429,18 @@ class MuLaN(nn.Module):

        cosine_sim = cosine_sim * self.temperature.exp()

        labels = torch.arange(batch, device = device)
        cosine_sim_exp = cosine_sim.exp()

        numerator = cosine_sim_exp.diag()

        if self.decoupled_contrastive_learning:
            eye = torch.eye(batch, device = device)
            cosine_sim_exp = cosine_sim_exp.masked_fill(eye, 0.)

        denominator = reduce(cosine_sim_exp, 'i j -> i', 'sum')

        contrastive_loss = F.cross_entropy(cosine_sim, labels)
        return contrastive_loss
        contrastive_loss = -log(numerator / denominator)
        return contrastive_loss.mean()

# music lm

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'musiclm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.1',
  version = '0.0.2',
  license='MIT',
  description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
  author = 'Phil Wang',