Commit b95e54b2 authored by Phil Wang's avatar Phil Wang
Browse files

contrastive loss needs both directions

parent cad7b494
Loading
Loading
Loading
Loading
+8 −5
Original line number Diff line number Diff line
@@ -581,10 +581,12 @@ class MultiLayerContrastiveLoss(nn.Module):
            eye = torch.eye(batch, device = device, dtype = torch.bool)
            cosine_sims_exp = cosine_sims_exp.masked_fill(eye, 0.)

        denominator = reduce(cosine_sims_exp, 'l i j -> l i', 'sum')
        contrastive_loss = -log(numerator) + log(denominator)
        denominator_i = reduce(cosine_sims_exp, 'l i j -> l i', 'sum')
        denominator_j = reduce(cosine_sims_exp, 'l i j -> l j', 'sum')

        contrastive_loss = reduce(contrastive_loss, 'l i -> l', 'mean')
        contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j))

        contrastive_loss = reduce(contrastive_loss, 'l n -> l', 'mean')
        return contrastive_loss.sum()

# main classes
@@ -696,9 +698,10 @@ class MuLaN(nn.Module):
            eye = torch.eye(batch, device = device, dtype = torch.bool)
            cosine_sim_exp = cosine_sim_exp.masked_fill(eye, 0.)

        denominator = reduce(cosine_sim_exp, 'i j -> i', 'sum')
        denominator_i = reduce(cosine_sim_exp, 'i j -> i', 'sum')
        denominator_j = reduce(cosine_sim_exp, 'i j -> j', 'sum')

        contrastive_loss = -log(numerator) + log(denominator)
        contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j))
        cl_loss = contrastive_loss.mean()

        if not exists(self.multi_layer_contrastive_learning):
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'musiclm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.1.1',
  version = '0.1.2',
  license='MIT',
  description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
  author = 'Phil Wang',