Commit 5f53d410 authored by Phil Wang's avatar Phil Wang
Browse files

add the hierarchical contrastive loss from the ViCHA strategy in https://arxiv.org/abs/2208.13628

parent acee34a1
Loading
Loading
Loading
Loading
+9 −0
Original line number Diff line number Diff line
@@ -211,6 +211,15 @@ music = musiclm('the crystalline sounds of the piano in a ballroom', num_samples
}
```

```bibtex
@inproceedings{Shukor2022EfficientVP,
    title   = {Efficient Vision-Language Pretraining with Visual Concepts and Hierarchical Alignment},
    author  = {Mustafa Shukor and Guillaume Couairon and Matthieu Cord},
    booktitle = {British Machine Vision Conference},
    year    = {2022}
}
```

*The only truth is music.* - Jack Kerouac

*Music is the universal language of mankind.* - Henry Wadsworth Longfellow
+161 −21
Original line number Diff line number Diff line
@@ -58,6 +58,16 @@ def log(t, eps = 1e-20):
def l2norm(t):
    return F.normalize(t, p = 2, dim = -1)

def matrix_diag(t):
    device = t.device
    i, j = t.shape[-2:]
    num_diag_el = min(i, j)
    i_range = torch.arange(i, device = device)
    j_range = torch.arange(j, device = device)
    diag_mask = rearrange(i_range, 'i -> i 1') == rearrange(j_range, 'j -> 1 j')
    diag_el = t.masked_select(diag_mask)
    return rearrange(diag_el, '(b d) -> b d', d = num_diag_el)

# 2d sinusoidal positional embedding
# simple vit paper shows it is good enough compared to learned

@@ -81,13 +91,15 @@ def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
# biasless layernorm

class LayerNorm(nn.Module):
    def __init__(self, dim):
    def __init__(self, dim, scale = True):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer('beta', torch.zeros(dim))
        self.learned_gamma = nn.Parameter(torch.ones(dim)) if scale else None

        self.register_buffer('gamma', torch.ones(dim), persistent = False)
        self.register_buffer('beta', torch.zeros(dim), persistent = False)

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
        return F.layer_norm(x, x.shape[-1:], default(self.learned_gamma, self.gamma), self.beta)

# feedforward

@@ -221,15 +233,21 @@ class Transformer(nn.Module):
        self,
        x,
        rel_pos_bias = None,
        mask = None
        mask = None,
        return_all_layers = False
    ):
        layers = []

        for attn, ff in self.layers:
            x = attn(x, rel_pos_bias = rel_pos_bias, mask = mask) + x
            x = ff(x) + x
            layers.append(x)

        if not return_all_layers:
            return x

        return x, torch.stack(layers[:-1])

# Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778

def pair(t):
@@ -262,6 +280,7 @@ class AudioSpectrogramTransformer(nn.Module):
    ):
        super().__init__()
        self.dim = dim
        self.depth = depth

        self.patch_size = pair(patch_size)
        patch_input_dim = self.patch_size[0] * self.patch_size[1]
@@ -326,7 +345,8 @@ class AudioSpectrogramTransformer(nn.Module):
    def forward(
        self,
        x,
        force_no_patch_dropout = False
        force_no_patch_dropout = False,
        return_all_layers = False
    ):
        batch, device = x.shape[0], x.device
        assert (self.accept_spec and x.ndim == 3) or (not self.accept_spec and x.ndim == 2)
@@ -397,13 +417,18 @@ class AudioSpectrogramTransformer(nn.Module):

        # attention, what else

        x = self.transformer(x, rel_pos_bias = rel_pos_bias)
        x, all_layers = self.transformer(x, rel_pos_bias = rel_pos_bias, return_all_layers = True)

        # final global average and norm (most recent papers show this is superior to CLS token)

        x = reduce(x, 'b n d -> b d', 'mean')

        return self.norm(x)
        out = self.norm(x)

        if not return_all_layers:
            return out

        return out, all_layers

# text transformer

@@ -428,6 +453,7 @@ class TextTransformer(nn.Module):
        self.token_emb = nn.Embedding(num_tokens, dim)
        self.pos_emb = nn.Embedding(max_seq_len, dim)

        self.depth = depth
        self.max_seq_len = max_seq_len

        self.cls_token = nn.Parameter(torch.randn(dim))
@@ -453,7 +479,8 @@ class TextTransformer(nn.Module):
        self,
        x = None,
        raw_texts: Optional[List[str]] = None,
        mask = None
        mask = None,
        return_all_layers = False
    ):
        assert exists(x) ^ exists(raw_texts)

@@ -484,13 +511,87 @@ class TextTransformer(nn.Module):

        # attention

        x = self.transformer(x, mask = mask)
        x, all_layers = self.transformer(x, mask = mask, return_all_layers = True)

        # unpack the cls tokens

        cls_tokens, _ = unpack(x, ps, 'b * d')

        return self.norm(cls_tokens)
        out = self.norm(cls_tokens)

        if not return_all_layers:
            return out

        return out, all_layers

# hierarchical cl loss

def pick_layers_evenly_interspersed(layers, tensor):
    total_layers, device = tensor.shape[0], tensor.device
    assert total_layers >= layers

    step = total_layers / layers
    indices = (torch.arange(0, layers) * step).floor().long()
    return tensor[indices]

class MultiLayerContrastiveLoss(nn.Module):
    def __init__(
        self,
        *,
        audio_dim,
        text_dim,
        dim_latent,
        layers,
        decoupled_contrastive_learning = False
    ):
        super().__init__()
        self.layers = layers

        self.audio_norm = LayerNorm(audio_dim, scale = False)
        self.audio_gamma = nn.Parameter(torch.ones(layers, 1, audio_dim))
        self.audio_latent_weight = nn.Parameter(torch.randn(layers, audio_dim, dim_latent))
        self.audio_latent_bias = nn.Parameter(torch.randn(layers, 1, dim_latent))

        self.text_norm = LayerNorm(text_dim, scale = False)
        self.text_gamma = nn.Parameter(torch.ones(layers, 1, text_dim))
        self.text_latent_weight = nn.Parameter(torch.randn(layers, text_dim, dim_latent))
        self.text_latent_bias = nn.Parameter(torch.randn(layers, 1, dim_latent))

        self.temperatures = nn.Parameter(torch.ones(layers, 1, 1))

        self.decoupled_contrastive_learning = decoupled_contrastive_learning

    def forward(self, *, audio_layers, text_layers):
        device, batch = audio_layers.device, audio_layers.shape[1]

        audio_layers = pick_layers_evenly_interspersed(self.layers, audio_layers)
        text_layers = pick_layers_evenly_interspersed(self.layers, text_layers)

        audio_gap = reduce(audio_layers, 'l b n d -> l b d', 'mean')
        audio_embeds = self.audio_norm(audio_gap) * self.audio_gamma
        audio_latents = einsum('l b d, l d e -> l b e', audio_embeds, self.audio_latent_weight) + self.audio_latent_bias
        audio_latents = l2norm(audio_latents)

        text_cls_tokens = text_layers[:, :, 0]
        text_embeds = self.text_norm(text_cls_tokens) * self.text_gamma
        text_latents = einsum('l b d, l d e -> l b e', text_embeds, self.text_latent_weight) + self.text_latent_bias
        text_latents = l2norm(text_latents)

        cosine_sims = einsum('l i d, l j d -> l i j', audio_latents, text_latents) * self.temperatures.exp()

        cosine_sims_exp = cosine_sims.exp()

        numerator = matrix_diag(cosine_sims_exp)

        if self.decoupled_contrastive_learning:
            eye = torch.eye(batch, device = device, dtype = torch.bool)
            cosine_sims_exp = cosine_sims_exp.masked_fill(eye, 0.)

        denominator = reduce(cosine_sims_exp, 'l i j -> l i', 'sum')
        contrastive_loss = -log(numerator) + log(denominator)

        contrastive_loss = reduce(contrastive_loss, 'l i -> l', 'mean')
        return contrastive_loss.sum()

# main classes

@@ -502,6 +603,7 @@ class MuLaN(nn.Module):
        text_transformer: TextTransformer,
        dim_latent = 128,                       # they use 128
        decoupled_contrastive_learning = True,  # think this was used, make it optional
        hierarchical_contrastive_loss = False
    ):
        super().__init__()
        self.dim_latent = dim_latent
@@ -516,22 +618,48 @@ class MuLaN(nn.Module):

        self.decoupled_contrastive_learning = decoupled_contrastive_learning

        self.multi_layer_contrastive_learning = None

        if hierarchical_contrastive_loss:
            num_layers = min(audio_transformer.depth, text_transformer.depth) - 1
            assert num_layers > 0

            self.multi_layer_contrastive_learning = MultiLayerContrastiveLoss(
                audio_dim = self.audio.dim,
                text_dim = self.text.dim,
                dim_latent = dim_latent,
                layers = num_layers,
                decoupled_contrastive_learning = decoupled_contrastive_learning
            )

    def get_audio_latents(
        self,
        wavs
        wavs,
        return_all_layers = False
    ):
        audio_embeds = self.audio(wavs)
        audio_embeds, audio_layers = self.audio(wavs, return_all_layers = True)
        audio_latents = self.audio_to_latents(audio_embeds)
        return l2norm(audio_latents)
        out = l2norm(audio_latents)

        if not return_all_layers:
            return out

        return out, audio_layers

    def get_text_latents(
        self,
        texts = None,
        raw_texts: Optional[List[str]] = None
        raw_texts: Optional[List[str]] = None,
        return_all_layers = False
    ):
        text_embeds = self.text(texts, raw_texts = raw_texts)
        text_embeds, text_layers = self.text(texts, raw_texts = raw_texts, return_all_layers = True)
        text_latents = self.text_to_latents(text_embeds)
        return l2norm(text_latents)
        out = l2norm(text_latents)

        if not return_all_layers:
            return out

        return out, text_layers

    def forward(
        self,
@@ -544,8 +672,8 @@ class MuLaN(nn.Module):
    ):
        batch, device = wavs.shape[0], wavs.device

        audio_latents = self.get_audio_latents(wavs)
        text_latents = self.get_text_latents(texts, raw_texts = raw_texts)
        audio_latents, audio_layers = self.get_audio_latents(wavs, return_all_layers = True)
        text_latents, text_layers = self.get_text_latents(texts, raw_texts = raw_texts, return_all_layers = True)

        if return_latents:
            return audio_latents, text_latents
@@ -573,7 +701,19 @@ class MuLaN(nn.Module):
        denominator = reduce(cosine_sim_exp, 'i j -> i', 'sum')

        contrastive_loss = -log(numerator) + log(denominator)
        return contrastive_loss.mean()
        cl_loss = contrastive_loss.mean()

        if not exists(self.multi_layer_contrastive_learning):
            return cl_loss

        # whether to do cl loss across all layers, from ViCHA paper https://arxiv.org/abs/2208.13628

        hierarchical_cl_loss = self.multi_layer_contrastive_learning(
            audio_layers = audio_layers,
            text_layers = text_layers
        )

        return cl_loss + hierarchical_cl_loss

# music lm

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'musiclm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.28',
  version = '0.1.0',
  license='MIT',
  description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
  author = 'Phil Wang',