Loading musiclm_pytorch/musiclm_pytorch.py +11 −10 Original line number Diff line number Diff line Loading @@ -526,13 +526,10 @@ class TextTransformer(nn.Module): # hierarchical cl loss def pick_layers_evenly_interspersed(layers, tensor): total_layers, device = tensor.shape[0], tensor.device def interspersed_indices(layers, total_layers): assert total_layers >= layers step = total_layers / layers indices = (torch.arange(0, layers) * step).floor().long() return tensor[indices] return (torch.arange(0, layers) * step).floor().long() class MultiLayerContrastiveLoss(nn.Module): def __init__( Loading Loading @@ -564,9 +561,6 @@ class MultiLayerContrastiveLoss(nn.Module): def forward(self, *, audio_layers, text_layers): device, batch = audio_layers.device, audio_layers.shape[1] audio_layers = pick_layers_evenly_interspersed(self.layers, audio_layers) text_layers = pick_layers_evenly_interspersed(self.layers, text_layers) audio_gap = reduce(audio_layers, 'l b n d -> l b d', 'mean') audio_embeds = self.audio_norm(audio_gap) * self.audio_gamma audio_latents = einsum('l b d, l d e -> l b e', audio_embeds, self.audio_latent_weight) + self.audio_latent_bias Loading Loading @@ -603,7 +597,8 @@ class MuLaN(nn.Module): text_transformer: TextTransformer, dim_latent = 128, # they use 128 decoupled_contrastive_learning = True, # think this was used, make it optional hierarchical_contrastive_loss = False hierarchical_contrastive_loss = False, hierarchical_contrastive_loss_layers = None ): super().__init__() self.dim_latent = dim_latent Loading @@ -621,9 +616,12 @@ class MuLaN(nn.Module): self.multi_layer_contrastive_learning = None if hierarchical_contrastive_loss: num_layers = min(audio_transformer.depth, text_transformer.depth) - 1 num_layers = default(hierarchical_contrastive_loss_layers, min(audio_transformer.depth, text_transformer.depth) - 1) assert num_layers > 0 self.register_buffer('text_layers_indices', interspersed_indices(num_layers, text_transformer.depth)) self.register_buffer('audio_layers_indices', interspersed_indices(num_layers, audio_transformer.depth)) self.multi_layer_contrastive_learning = MultiLayerContrastiveLoss( audio_dim = self.audio.dim, text_dim = self.text.dim, Loading Loading @@ -706,6 +704,9 @@ class MuLaN(nn.Module): if not exists(self.multi_layer_contrastive_learning): return cl_loss audio_layers = audio_layers[self.audio_layers_indices] text_layers = text_layers[self.text_layers_indices] # whether to do cl loss across all layers, from ViCHA paper https://arxiv.org/abs/2208.13628 hierarchical_cl_loss = self.multi_layer_contrastive_learning( Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'musiclm-pytorch', packages = find_packages(exclude=[]), version = '0.1.0', version = '0.1.1', license='MIT', description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis', author = 'Phil Wang', Loading Loading
musiclm_pytorch/musiclm_pytorch.py +11 −10 Original line number Diff line number Diff line Loading @@ -526,13 +526,10 @@ class TextTransformer(nn.Module): # hierarchical cl loss def pick_layers_evenly_interspersed(layers, tensor): total_layers, device = tensor.shape[0], tensor.device def interspersed_indices(layers, total_layers): assert total_layers >= layers step = total_layers / layers indices = (torch.arange(0, layers) * step).floor().long() return tensor[indices] return (torch.arange(0, layers) * step).floor().long() class MultiLayerContrastiveLoss(nn.Module): def __init__( Loading Loading @@ -564,9 +561,6 @@ class MultiLayerContrastiveLoss(nn.Module): def forward(self, *, audio_layers, text_layers): device, batch = audio_layers.device, audio_layers.shape[1] audio_layers = pick_layers_evenly_interspersed(self.layers, audio_layers) text_layers = pick_layers_evenly_interspersed(self.layers, text_layers) audio_gap = reduce(audio_layers, 'l b n d -> l b d', 'mean') audio_embeds = self.audio_norm(audio_gap) * self.audio_gamma audio_latents = einsum('l b d, l d e -> l b e', audio_embeds, self.audio_latent_weight) + self.audio_latent_bias Loading Loading @@ -603,7 +597,8 @@ class MuLaN(nn.Module): text_transformer: TextTransformer, dim_latent = 128, # they use 128 decoupled_contrastive_learning = True, # think this was used, make it optional hierarchical_contrastive_loss = False hierarchical_contrastive_loss = False, hierarchical_contrastive_loss_layers = None ): super().__init__() self.dim_latent = dim_latent Loading @@ -621,9 +616,12 @@ class MuLaN(nn.Module): self.multi_layer_contrastive_learning = None if hierarchical_contrastive_loss: num_layers = min(audio_transformer.depth, text_transformer.depth) - 1 num_layers = default(hierarchical_contrastive_loss_layers, min(audio_transformer.depth, text_transformer.depth) - 1) assert num_layers > 0 self.register_buffer('text_layers_indices', interspersed_indices(num_layers, text_transformer.depth)) self.register_buffer('audio_layers_indices', interspersed_indices(num_layers, audio_transformer.depth)) self.multi_layer_contrastive_learning = MultiLayerContrastiveLoss( audio_dim = self.audio.dim, text_dim = self.text.dim, Loading Loading @@ -706,6 +704,9 @@ class MuLaN(nn.Module): if not exists(self.multi_layer_contrastive_learning): return cl_loss audio_layers = audio_layers[self.audio_layers_indices] text_layers = text_layers[self.text_layers_indices] # whether to do cl loss across all layers, from ViCHA paper https://arxiv.org/abs/2208.13628 hierarchical_cl_loss = self.multi_layer_contrastive_learning( Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'musiclm-pytorch', packages = find_packages(exclude=[]), version = '0.1.0', version = '0.1.1', license='MIT', description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis', author = 'Phil Wang', Loading