Loading README.md +10 −0 Original line number Diff line number Diff line Loading @@ -170,4 +170,14 @@ music = musiclm(['the crystalline sounds of the piano in a ballroom']) # torch.T } ``` ```bibtex @article{Liu2022PatchDropoutEV, title = {PatchDropout: Economizing Vision Transformers Using Patch Dropout}, author = {Yue Liu and Christos Matsoukas and Fredrik Strand and Hossein Azizpour and Kevin Smith}, journal = {ArXiv}, year = {2022}, volume = {abs/2208.07220} } ``` *The only truth is music.* - Jack Kerouac musiclm_pytorch/musiclm_pytorch.py +27 −1 Original line number Diff line number Diff line Loading @@ -193,6 +193,27 @@ class Transformer(nn.Module): return x # Patch Dropout - https://arxiv.org/abs/2208.07220 class PatchDropout(nn.Module): def __init__(self, prob): super().__init__() assert 0 <= prob < 1. self.prob = prob def forward(self, x, force_keep_all = False): if not self.training or self.prob == 0. or force_keep_all: return x b, n, _, device = *x.shape, x.device batch_indices = torch.arange(b, device = device) batch_indices = rearrange(batch_indices, '... -> ... 1') num_patches_keep = max(1, int(n * (1 - self.prob))) patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices return x[batch_indices, patch_indices_keep] # Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778 def pair(t): Loading @@ -219,7 +240,8 @@ class AudioSpectrogramTransformer(nn.Module): spec_aug_stretch_factor = 0.8, spec_aug_freq_mask = 80, spec_aug_time_mask = 80, dual_patchnorm = True dual_patchnorm = True, patch_dropout_prob = 0.5 ): super().__init__() self.dim = dim Loading Loading @@ -264,6 +286,8 @@ class AudioSpectrogramTransformer(nn.Module): self.norm = LayerNorm(dim) self.patch_dropout = PatchDropout(patch_dropout_prob) def forward(self, x): x = self.spec(x) Loading Loading @@ -294,6 +318,8 @@ class AudioSpectrogramTransformer(nn.Module): x = rearrange(x, 'b ... c -> b (...) c') x = self.patch_dropout(x) x = self.transformer(x) # final global average and norm (most recent papers show this is superior to CLS token) Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'musiclm-pytorch', packages = find_packages(exclude=[]), version = '0.0.9', version = '0.0.10', license='MIT', description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis', author = 'Phil Wang', Loading Loading
README.md +10 −0 Original line number Diff line number Diff line Loading @@ -170,4 +170,14 @@ music = musiclm(['the crystalline sounds of the piano in a ballroom']) # torch.T } ``` ```bibtex @article{Liu2022PatchDropoutEV, title = {PatchDropout: Economizing Vision Transformers Using Patch Dropout}, author = {Yue Liu and Christos Matsoukas and Fredrik Strand and Hossein Azizpour and Kevin Smith}, journal = {ArXiv}, year = {2022}, volume = {abs/2208.07220} } ``` *The only truth is music.* - Jack Kerouac
musiclm_pytorch/musiclm_pytorch.py +27 −1 Original line number Diff line number Diff line Loading @@ -193,6 +193,27 @@ class Transformer(nn.Module): return x # Patch Dropout - https://arxiv.org/abs/2208.07220 class PatchDropout(nn.Module): def __init__(self, prob): super().__init__() assert 0 <= prob < 1. self.prob = prob def forward(self, x, force_keep_all = False): if not self.training or self.prob == 0. or force_keep_all: return x b, n, _, device = *x.shape, x.device batch_indices = torch.arange(b, device = device) batch_indices = rearrange(batch_indices, '... -> ... 1') num_patches_keep = max(1, int(n * (1 - self.prob))) patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices return x[batch_indices, patch_indices_keep] # Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778 def pair(t): Loading @@ -219,7 +240,8 @@ class AudioSpectrogramTransformer(nn.Module): spec_aug_stretch_factor = 0.8, spec_aug_freq_mask = 80, spec_aug_time_mask = 80, dual_patchnorm = True dual_patchnorm = True, patch_dropout_prob = 0.5 ): super().__init__() self.dim = dim Loading Loading @@ -264,6 +286,8 @@ class AudioSpectrogramTransformer(nn.Module): self.norm = LayerNorm(dim) self.patch_dropout = PatchDropout(patch_dropout_prob) def forward(self, x): x = self.spec(x) Loading Loading @@ -294,6 +318,8 @@ class AudioSpectrogramTransformer(nn.Module): x = rearrange(x, 'b ... c -> b (...) c') x = self.patch_dropout(x) x = self.transformer(x) # final global average and norm (most recent papers show this is superior to CLS token) Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'musiclm-pytorch', packages = find_packages(exclude=[]), version = '0.0.9', version = '0.0.10', license='MIT', description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis', author = 'Phil Wang', Loading