Commit cad7b494 authored by Phil Wang
Browse files

precompute audio and text layers to be used for hierarchical contrastive loss

parent 5f53d410
Loading
Loading
Loading
Loading
+11 −10
Original line number Diff line number Diff line
@@ -526,13 +526,10 @@ class TextTransformer(nn.Module):

 # hierarchical cl loss

-def pick_layers_evenly_interspersed(layers, tensor):
-    total_layers, device = tensor.shape[0], tensor.device
+def interspersed_indices(layers, total_layers):
     assert total_layers >= layers
-
     step = total_layers / layers
-    indices = (torch.arange(0, layers) * step).floor().long()
-    return tensor[indices]
+    return (torch.arange(0, layers) * step).floor().long()

 class MultiLayerContrastiveLoss(nn.Module):
     def __init__(
@@ -564,9 +561,6 @@ class MultiLayerContrastiveLoss(nn.Module):
     def forward(self, *, audio_layers, text_layers):
         device, batch = audio_layers.device, audio_layers.shape[1]

-        audio_layers = pick_layers_evenly_interspersed(self.layers, audio_layers)
-        text_layers = pick_layers_evenly_interspersed(self.layers, text_layers)
-
         audio_gap = reduce(audio_layers, 'l b n d -> l b d', 'mean')
         audio_embeds = self.audio_norm(audio_gap) * self.audio_gamma
         audio_latents = einsum('l b d, l d e -> l b e', audio_embeds, self.audio_latent_weight) + self.audio_latent_bias
@@ -603,7 +597,8 @@ class MuLaN(nn.Module):
         text_transformer: TextTransformer,
         dim_latent = 128,                       # they use 128
         decoupled_contrastive_learning = True,  # think this was used, make it optional
-        hierarchical_contrastive_loss = False
+        hierarchical_contrastive_loss = False,
+        hierarchical_contrastive_loss_layers = None
     ):
         super().__init__()
         self.dim_latent = dim_latent
@@ -621,9 +616,12 @@ class MuLaN(nn.Module):
         self.multi_layer_contrastive_learning = None

         if hierarchical_contrastive_loss:
-            num_layers = min(audio_transformer.depth, text_transformer.depth) - 1
+            num_layers = default(hierarchical_contrastive_loss_layers, min(audio_transformer.depth, text_transformer.depth) - 1)
             assert num_layers > 0

+            self.register_buffer('text_layers_indices', interspersed_indices(num_layers, text_transformer.depth))
+            self.register_buffer('audio_layers_indices', interspersed_indices(num_layers, audio_transformer.depth))
+
             self.multi_layer_contrastive_learning = MultiLayerContrastiveLoss(
                 audio_dim = self.audio.dim,
                 text_dim = self.text.dim,
@@ -706,6 +704,9 @@ class MuLaN(nn.Module):
         if not exists(self.multi_layer_contrastive_learning):
             return cl_loss

+        audio_layers = audio_layers[self.audio_layers_indices]
+        text_layers = text_layers[self.text_layers_indices]
+
         # whether to do cl loss across all layers, from ViCHA paper https://arxiv.org/abs/2208.13628

         hierarchical_cl_loss = self.multi_layer_contrastive_learning(
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'musiclm-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.1.0',
+  version = '0.1.1',
   license='MIT',
   description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
   author = 'Phil Wang',