Commit d0cdbb25 authored by Phil Wang's avatar Phil Wang
Browse files

repeat batch for relative positional bias only if doing patch dropout

parent 6b630672
Loading
Loading
Loading
Loading
+5 −3
Original line number Diff line number Diff line
@@ -304,7 +304,7 @@ class AudioSpectrogramTransformer(nn.Module):
            nn.Linear(mlp_hidden_dim, mlp_hidden_dim),
            nn.SiLU(),
            nn.Linear(mlp_hidden_dim, heads),
            Rearrange('b i j h -> b h i j')
            Rearrange('... i j h -> ... h i j')
        )

    def forward(
@@ -346,7 +346,7 @@ class AudioSpectrogramTransformer(nn.Module):
            torch.arange(num_patch_width, device = device)
        , indexing = 'ij'), dim = -1)

        grid = repeat(grid, '... c -> b (...) c', b = batch)
        grid = rearrange(grid, '... c -> (...) c')

        # 2d sinusoidal positional embedding

@@ -365,11 +365,13 @@ class AudioSpectrogramTransformer(nn.Module):
            patch_indices_keep = torch.randn(batch, n, device = device).topk(num_patches_keep, dim = -1).indices

            x = x[batch_indices, patch_indices_keep]

            grid = repeat(grid, '... -> b ...', b = batch)
            grid = grid[batch_indices, patch_indices_keep]

        # 2d relative positional bias

        rel_dist = rearrange(grid, 'b i c -> b i 1 c') - rearrange(grid, 'b j c -> b 1 j c')
        rel_dist = rearrange(grid, '... i c -> ... i 1 c') - rearrange(grid, '... j c -> ... 1 j c')
        rel_pos_bias = self.dynamic_pos_bias_mlp(rel_dist.float())

        # attention, what else
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'musiclm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.20',
  version = '0.0.21',
  license='MIT',
  description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
  author = 'Phil Wang',