Loading musiclm_pytorch/musiclm_pytorch.py +8 −5 Original line number Diff line number Diff line Loading @@ -581,10 +581,12 @@ class MultiLayerContrastiveLoss(nn.Module): eye = torch.eye(batch, device = device, dtype = torch.bool) cosine_sims_exp = cosine_sims_exp.masked_fill(eye, 0.) denominator = reduce(cosine_sims_exp, 'l i j -> l i', 'sum') contrastive_loss = -log(numerator) + log(denominator) denominator_i = reduce(cosine_sims_exp, 'l i j -> l i', 'sum') denominator_j = reduce(cosine_sims_exp, 'l i j -> l j', 'sum') contrastive_loss = reduce(contrastive_loss, 'l i -> l', 'mean') contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j)) contrastive_loss = reduce(contrastive_loss, 'l n -> l', 'mean') return contrastive_loss.sum() # main classes Loading Loading @@ -696,9 +698,10 @@ class MuLaN(nn.Module): eye = torch.eye(batch, device = device, dtype = torch.bool) cosine_sim_exp = cosine_sim_exp.masked_fill(eye, 0.) denominator = reduce(cosine_sim_exp, 'i j -> i', 'sum') denominator_i = reduce(cosine_sim_exp, 'i j -> i', 'sum') denominator_j = reduce(cosine_sim_exp, 'i j -> j', 'sum') contrastive_loss = -log(numerator) + log(denominator) contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j)) cl_loss = contrastive_loss.mean() if not exists(self.multi_layer_contrastive_learning): Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'musiclm-pytorch', packages = find_packages(exclude=[]), version = '0.1.1', version = '0.1.2', license='MIT', description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis', author = 'Phil Wang', Loading Loading
musiclm_pytorch/musiclm_pytorch.py +8 −5 Original line number Diff line number Diff line Loading @@ -581,10 +581,12 @@ class MultiLayerContrastiveLoss(nn.Module): eye = torch.eye(batch, device = device, dtype = torch.bool) cosine_sims_exp = cosine_sims_exp.masked_fill(eye, 0.) denominator = reduce(cosine_sims_exp, 'l i j -> l i', 'sum') contrastive_loss = -log(numerator) + log(denominator) denominator_i = reduce(cosine_sims_exp, 'l i j -> l i', 'sum') denominator_j = reduce(cosine_sims_exp, 'l i j -> l j', 'sum') contrastive_loss = reduce(contrastive_loss, 'l i -> l', 'mean') contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j)) contrastive_loss = reduce(contrastive_loss, 'l n -> l', 'mean') return contrastive_loss.sum() # main classes Loading Loading @@ -696,9 +698,10 @@ class MuLaN(nn.Module): eye = torch.eye(batch, device = device, dtype = torch.bool) cosine_sim_exp = cosine_sim_exp.masked_fill(eye, 0.) denominator = reduce(cosine_sim_exp, 'i j -> i', 'sum') denominator_i = reduce(cosine_sim_exp, 'i j -> i', 'sum') denominator_j = reduce(cosine_sim_exp, 'i j -> j', 'sum') contrastive_loss = -log(numerator) + log(denominator) contrastive_loss = -log(numerator) + 0.5 * (log(denominator_i) + log(denominator_j)) cl_loss = contrastive_loss.mean() if not exists(self.multi_layer_contrastive_learning): Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'musiclm-pytorch', packages = find_packages(exclude=[]), version = '0.1.1', version = '0.1.2', license='MIT', description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis', author = 'Phil Wang', Loading