Commit 6b630672 authored by Phil Wang's avatar Phil Wang
Browse files

add best relative positional encoding for AST

parent 598aa341
Loading
Loading
Loading
Loading
+12 −1
Original line number Diff line number Diff line
@@ -142,8 +142,8 @@ music = musiclm(['the crystalline sounds of the piano in a ballroom']) # torch.T
- [x] wrap mulan with mulan wrapper and quantize the output, project to audiolm dimensions
- [x] modify audiolm to accept conditioning embeddings, optionally take care of different dimensions through a separate projection
- [x] audiolm and mulan goes into musiclm and generate, filter with mulan
- [x] give dynamic positional bias to self attention in AST

- [ ] give dynamic positional bias to self attention in AST
- [ ] add a version of mulan to <a href="https://github.com/mlfoundations/open_clip">open clip</a>
- [ ] set all the proper spectrogram hyperparameters

@@ -189,6 +189,17 @@ music = musiclm(['the crystalline sounds of the piano in a ballroom']) # torch.T
}
```

```bibtex
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```

*The only truth is music.* - Jack Kerouac

*Music is the universal language of mankind.* - Henry Wadsworth Longfellow
+67 −29
Original line number Diff line number Diff line
@@ -136,6 +136,7 @@ class Attention(nn.Module):
    def forward(
        self,
        x,
        rel_pos_bias = None,
        mask = None
    ):
        b, n, _, device = *x.shape, x.device
@@ -158,6 +159,9 @@ class Attention(nn.Module):

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        if exists(rel_pos_bias):
            sim = sim + rel_pos_bias

        if exists(mask):
            mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
@@ -202,35 +206,19 @@ class Transformer(nn.Module):
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout),
            ]))

    def forward(self, x, mask = None):
    def forward(
        self,
        x,
        rel_pos_bias = None,
        mask = None
    ):

        for attn, ff in self.layers:
            x = attn(x, mask = mask) + x
            x = attn(x, rel_pos_bias = rel_pos_bias, mask = mask) + x
            x = ff(x) + x

        return x

# Patch Dropout - https://arxiv.org/abs/2208.07220

class PatchDropout(nn.Module):
    def __init__(self, prob):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob

    def forward(self, x, force_keep_all = False):
        if not self.training or self.prob == 0. or force_keep_all:
            return x

        b, n, _, device = *x.shape, x.device

        batch_indices = torch.arange(b, device = device)
        batch_indices = rearrange(batch_indices, '... -> ... 1')
        num_patches_keep = max(1, int(n * (1 - self.prob)))
        patch_indices_keep = torch.randn(b, n, device = device).topk(num_patches_keep, dim = -1).indices

        return x[batch_indices, patch_indices_keep]

# Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778

def pair(t):
@@ -302,9 +290,30 @@ class AudioSpectrogramTransformer(nn.Module):

        self.norm = LayerNorm(dim)

        self.patch_dropout = PatchDropout(patch_dropout_prob)
        # patch dropout

        self.patch_dropout_prob = patch_dropout_prob

        # 2d dynamic positional bias

        mlp_hidden_dim = dim // 4

        self.dynamic_pos_bias_mlp = nn.Sequential(
            nn.Linear(2, mlp_hidden_dim),
            nn.SiLU(),
            nn.Linear(mlp_hidden_dim, mlp_hidden_dim),
            nn.SiLU(),
            nn.Linear(mlp_hidden_dim, heads),
            Rearrange('b i j h -> b h i j')
        )

    def forward(
        self,
        x,
        force_no_patch_dropout = False
    ):
        batch, device = x.shape[0], x.device

    def forward(self, x):
        x = self.spec(x)

        if self.training:
@@ -326,17 +335,46 @@ class AudioSpectrogramTransformer(nn.Module):

        x = self.to_patch_tokens(x)

        # get number of patches along height and width

        num_patch_height, num_patch_width = x.shape[-2:]

        # get 2d relative positions

        grid = torch.stack(torch.meshgrid(
            torch.arange(num_patch_height, device = device),
            torch.arange(num_patch_width, device = device)
        , indexing = 'ij'), dim = -1)

        grid = repeat(grid, '... c -> b (...) c', b = batch)

        # 2d sinusoidal positional embedding

        x = x + posemb_sincos_2d(x)

        # attention, what else

        x = rearrange(x, 'b ... c -> b (...) c')

        x = self.patch_dropout(x)
        # patch dropout

        if self.training and self.patch_dropout_prob > 0. and not force_no_patch_dropout:
            n, device = x.shape[1], x.device

            batch_indices = torch.arange(batch, device = device)
            batch_indices = rearrange(batch_indices, '... -> ... 1')
            num_patches_keep = max(1, int(n * (1 - self.patch_dropout_prob)))
            patch_indices_keep = torch.randn(batch, n, device = device).topk(num_patches_keep, dim = -1).indices

            x = x[batch_indices, patch_indices_keep]
            grid = grid[batch_indices, patch_indices_keep]

        # 2d relative positional bias

        rel_dist = rearrange(grid, 'b i c -> b i 1 c') - rearrange(grid, 'b j c -> b 1 j c')
        rel_pos_bias = self.dynamic_pos_bias_mlp(rel_dist.float())

        # attention, what else

        x = self.transformer(x)
        x = self.transformer(x, rel_pos_bias = rel_pos_bias)

        # final global average and norm (most recent papers show this is superior to CLS token)

+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'musiclm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.19',
  version = '0.0.20',
  license='MIT',
  description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
  author = 'Phil Wang',