Commit 5aaee642 authored by Phil Wang
Browse files

adopt dual patchnorm

parent 26c82682
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -158,4 +158,16 @@ music = musiclm(['the crystalline sounds of the piano in a ballroom']) # torch.T
}
```

```bibtex
@misc{https://doi.org/10.48550/arxiv.2302.01327,
    doi     = {10.48550/ARXIV.2302.01327},
    url     = {https://arxiv.org/abs/2302.01327},
    author  = {Kumar, Manoj and Dehghani, Mostafa and Houlsby, Neil},
    title   = {Dual PatchNorm},
    publisher = {arXiv},
    year    = {2023},
    copyright = {Creative Commons Attribution 4.0 International}
}
```

*The only truth is music.* - Jack Kerouac
+14 −5
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ from x_clip.tokenizer import tokenizer
from vector_quantize_pytorch import ResidualVQ

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Rearrange

from beartype.typing import List, Optional, Tuple
from beartype import beartype
@@ -26,6 +27,9 @@ def default(val, d):
def round_down_nearest_multiple(n, divisor):
    """Return the largest multiple of `divisor` that does not exceed `n`."""
    # floor division counts how many whole `divisor`s fit into `n`
    whole_multiples = n // divisor
    return whole_multiples * divisor

def Sequential(*modules):
    """Build an `nn.Sequential` from `modules`, dropping entries for which `exists` is false.

    Lets callers write optional layers inline, e.g.
    `nn.LayerNorm(dim) if flag else None`, without filtering them by hand.
    """
    present = [module for module in modules if exists(module)]
    return nn.Sequential(*present)

# tensor functions

def log(t, eps = 1e-20):
@@ -214,14 +218,21 @@ class AudioSpectrogramTransformer(nn.Module):
        spec_pad_mode = 'reflect',
        spec_aug_stretch_factor = 0.8,
        spec_aug_freq_mask = 80,
        spec_aug_time_mask = 80

        spec_aug_time_mask = 80,
        dual_patchnorm = True
    ):
        super().__init__()
        self.dim = dim

        self.patch_size = pair(patch_size)
        self.to_patch_tokens = nn.Conv2d(self.patch_size[0] * self.patch_size[1], dim, 1)
        patch_input_dim = self.patch_size[0] * self.patch_size[1]

        self.to_patch_tokens = Sequential(
            Rearrange('b (h p1) (w p2) -> b h w (p1 p2)', p1 = self.patch_size[0], p2 = self.patch_size[1]),
            nn.LayerNorm(patch_input_dim) if dual_patchnorm else None,
            nn.Linear(patch_input_dim, dim),
            nn.LayerNorm(dim) if dual_patchnorm else None
        )

        self.spec = Spectrogram(
            n_fft = spec_n_fft,
@@ -273,12 +284,10 @@ class AudioSpectrogramTransformer(nn.Module):

        # to patches

        x = rearrange(x, 'b (h p1) (w p2) -> b (p1 p2) h w', p1 = patch_height, p2 = patch_width)
        x = self.to_patch_tokens(x)

        # 2d sinusoidal positional embedding

        x = rearrange(x, 'b c h w -> b h w c')
        x = x + posemb_sincos_2d(x)

        # attention, what else
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'musiclm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.8',
  version = '0.0.9',
  license='MIT',
  description = 'MusicLM - AudioLM + Audio CLIP to text to music synthesis',
  author = 'Phil Wang',