Commit 7dcfc974 authored by Phil Wang
Browse files

add patch dropout

parent 5aaee642
Loading
Loading
Loading
Loading
+10 −0
Original line number Diff line number Diff line
@@ -170,4 +170,14 @@ music = musiclm(['the crystalline sounds of the piano in a ballroom']) # torch.T
}
```

```bibtex
@article{Liu2022PatchDropoutEV,
    title   = {PatchDropout: Economizing Vision Transformers Using Patch Dropout},
    author  = {Yue Liu and Christos Matsoukas and Fredrik Strand and Hossein Azizpour and Kevin Smith},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.07220}
}
```

*The only truth is music.* - Jack Kerouac
+27 −1
Original line number Diff line number Diff line
@@ -193,6 +193,27 @@ class Transformer(nn.Module):

        return x

# Patch Dropout - https://arxiv.org/abs/2208.07220

class PatchDropout(nn.Module):
    """Randomly keep a subset of patches (dim 1) of the input while training.

    A different random subset is drawn for every batch element. Outside of
    training mode, or when `prob` is 0, the input is returned untouched.

    Args:
        prob: fraction of patches to drop, must satisfy 0 <= prob < 1.
    """

    def __init__(self, prob):
        super().__init__()
        assert 0 <= prob < 1., 'patch dropout probability must be in [0, 1)'
        self.prob = prob

    def forward(self, x, force_keep_all = False):
        """Drop patches from `x`.

        Args:
            x: tensor whose first two dimensions are (batch, patches); any
                number of trailing dimensions is carried through unchanged.
            force_keep_all: when True, skip dropout and return `x` as is.

        Returns:
            `x` itself when dropout is inactive, otherwise a tensor of shape
            (batch, num_patches_keep, *trailing) gathered from `x`. The kept
            patches come out in random order, not in their original order.
        """
        if not self.training or self.prob == 0. or force_keep_all:
            return x

        # only the two leading dims matter for the gather below, so do not
        # insist on the input being exactly 3-dimensional
        b, n = x.shape[:2]
        device = x.device

        # int() truncates toward zero; always keep at least one patch
        num_patches_keep = max(1, int(n * (1 - self.prob)))

        # the top-k positions of i.i.d. gaussian noise form a random subset
        # without replacement, sampled independently for every batch row
        patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices

        # (b, 1) batch indices broadcast against the (b, k) patch indices
        batch_indices = torch.arange(b, device = device)[:, None]

        return x[batch_indices, patch_indices_keep]

# Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778

def pair(t):
@@ -219,7 +240,8 @@ class AudioSpectrogramTransformer(nn.Module):
        spec_aug_stretch_factor = 0.8,
        spec_aug_freq_mask = 80,
        spec_aug_time_mask = 80,
        dual_patchnorm = True
        dual_patchnorm = True,
        patch_dropout_prob = 0.5
    ):
        super().__init__()
        self.dim = dim
@@ -264,6 +286,8 @@ class AudioSpectrogramTransformer(nn.Module):

        self.norm = LayerNorm(dim)

        self.patch_dropout = PatchDropout(patch_dropout_prob)

    def forward(self, x):
        x = self.spec(x)

@@ -294,6 +318,8 @@ class AudioSpectrogramTransformer(nn.Module):

        x = rearrange(x, 'b ... c -> b (...) c')

        x = self.patch_dropout(x)

        x = self.transformer(x)

        # final global average and norm (most recent papers show this is superior to CLS token)
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'musiclm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.9',
  version = '0.0.10',
  license='MIT',
  description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
  author = 'Phil Wang',