# musiclm_pytorch/musiclm_pytorch.py (+138 −1)
# reconstructed from a whitespace-collapsed diff view
# regions the diff view kept folded are marked with comments and are intentionally not reproduced

import torch
import torch.nn.functional as F
from torch import nn, einsum

from torchaudio.transforms import Spectrogram, TimeStretch, FrequencyMasking, TimeMasking

from x_clip.tokenizer import tokenizer
from vector_quantize_pytorch import ResidualVQ

from einops import rearrange, repeat, reduce, pack, unpack

from beartype import beartype

# functions

def exists(val):
    """Return True when `val` is not None."""
    return val is not None

def round_down_nearest_multiple(n, divisor):
    """Floor `n` to a multiple of `divisor` (largest multiple <= n for a positive divisor)."""
    return n // divisor * divisor

# tensor functions

# def log(t, eps = 1e-20): ...
# (body folded in the diff view)

def l2norm(t):
    """L2 normalize `t` along its last dimension."""
    return F.normalize(t, p = 2, dim = -1)

# 2d sinusoidal positional embedding
# simple vit paper shows it is good enough compared to learned

def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
    """
    Build a fixed 2d sin / cos positional embedding for a grid of patch tokens.

    patches     - 4d tensor, unpacked as (batch, h, w, dim)
    temperature - base of the frequency progression, omega = 1 / temperature ** linspace(0, 1)
    dtype       - kept for interface compatibility only, the embedding has always been
                  cast to `patches.dtype` and still is

    returns a tensor of shape (h, w, dim) on the device of `patches`
    """
    _, h, w, dim = patches.shape
    device = patches.device

    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'

    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')

    num_freqs = dim // 4

    # with dim == 4 there is a single frequency and the old denominator (num_freqs - 1) was 0,
    # which turned the whole embedding into nan (0 / 0) - clamp the denominator to 1 instead
    omega = torch.arange(num_freqs, device = device) / max(num_freqs - 1, 1)
    omega = 1. / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]

    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    pe = pe.type(patches.dtype)
    return rearrange(pe, '(h w) d -> h w d', h = h, w = w)

# biasless layernorm

# class LayerNorm(nn.Module): ...
# (folded in the diff view, together with everything up to the end of class Transformer,
#  whose last visible statement is `return x`)

# Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778

def pair(t):
    """Return `t` unchanged if it is a tuple, otherwise duplicate it into `(t, t)`."""
    return t if isinstance(t, tuple) else (t, t)

class AudioSpectrogramTransformer(nn.Module):
    """
    audio -> spectrogram -> patch tokens -> transformer -> mean pooled, normed embedding

    dim, depth, dim_head, heads, attn_dropout, ff_mult, ff_dropout
                            - forwarded to `Transformer`
    patch_size              - int or (height, width) tuple, size of one spectrogram patch
    spec_n_fft, spec_win_length, spec_hop_length, spec_pad, spec_center, spec_pad_mode
                            - forwarded to torchaudio `Spectrogram`
    spec_power              - exponent applied to the magnitude of the complex spectrogram
    spec_aug_stretch_factor - fixed rate of `TimeStretch`, training only
    spec_aug_freq_mask      - `freq_mask_param` of `FrequencyMasking`, training only
    spec_aug_time_mask      - `time_mask_param` of `TimeMasking`, training only
    """

    def __init__(
        self,
        dim,
        depth,
        patch_size = 16,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        spec_n_fft = 128,
        spec_power = 2,
        spec_win_length = 24,
        spec_hop_length = None,
        spec_pad = 0,
        spec_center = True,
        spec_pad_mode = 'reflect',
        spec_aug_stretch_factor = 0.8,
        spec_aug_freq_mask = 80,
        spec_aug_time_mask = 80
    ):
        super().__init__()

        self.patch_size = pair(patch_size)
        patch_height, patch_width = self.patch_size

        # forward folds every patch into the channel dimension, so a 1x1 conv is the per patch linear projection
        self.to_patch_tokens = nn.Conv2d(patch_height * patch_width, dim, 1)

        # TimeStretch takes a complex spectrogram, so the stft is kept complex (power = None)
        # and `spec_power` is applied in forward, after the optional stretch
        self.spec_power = spec_power

        self.spec = Spectrogram(
            n_fft = spec_n_fft,
            power = None,
            win_length = spec_win_length,
            hop_length = spec_hop_length,
            pad = spec_pad,
            center = spec_center,
            pad_mode = spec_pad_mode
        )

        # this used to be TimeStretch(spec_aug_stretch_factor, fixed_rate=True), which handed the
        # stretch factor to `hop_length` and True to the rate, so the factor never reached the phase vocoder
        # hop length and number of frequency bins must agree with the stft above
        self.time_stretch = TimeStretch(
            hop_length = self.spec.hop_length,
            n_freq = spec_n_fft // 2 + 1,
            fixed_rate = spec_aug_stretch_factor
        )

        # masking operates on the real valued spectrogram, after `spec_power` has been applied
        self.aug = torch.nn.Sequential(
            FrequencyMasking(freq_mask_param = spec_aug_freq_mask),
            TimeMasking(time_mask_param = spec_aug_time_mask),
        )

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_mult = ff_mult,
            ff_dropout = ff_dropout
        )

        self.norm = LayerNorm(dim)

    def forward(self, x):
        """
        x - handed straight to `Spectrogram`; the patch rearrange below needs the spectrogram
            to be 3d, so presumably raw audio of shape (batch, samples) - confirm with callers

        returns the normed mean over all patch tokens, one vector per batch element
        """
        x = self.spec(x)

        if self.training:
            x = self.time_stretch(x)

        # complex stft -> real spectrogram, the same abs / pow that Spectrogram(power = spec_power) applies
        if exists(self.spec_power):
            x = x.abs().pow(self.spec_power)

        if self.training:
            x = self.aug(x)

        # automatically crop if audio does not yield a 2d spectrogram that is divisible by patch sizes

        height, width = x.shape[-2:]
        patch_height, patch_width = self.patch_size

        rounded_height = round_down_nearest_multiple(height, patch_height)
        rounded_width = round_down_nearest_multiple(width, patch_width)

        if (height, width) != (rounded_height, rounded_width):
            # just keep printing to be annoying until it is fixed
            print(f'spectrogram yielded shape of {(height, width)}, but had to be cropped to {(rounded_height, rounded_width)} to be patchified for transformer')

        x = x[..., :rounded_height, :rounded_width]

        # to patches

        x = rearrange(x, 'b (h p1) (w p2) -> b (p1 p2) h w', p1 = patch_height, p2 = patch_width)
        x = self.to_patch_tokens(x)

        # 2d sinusoidal positional embedding

        x = rearrange(x, 'b c h w -> b h w c')
        x = x + posemb_sincos_2d(x)

        # attention, what else

        x = rearrange(x, 'b ... c -> b (...) c')

        x = self.transformer(x)

        # final global average and norm (most recent papers show this is superior to CLS token)

        x = reduce(x, 'b n d -> b d', 'mean')

        return self.norm(x)

# text transformer

class TextTransformer:
    pass

# main classes

@beartype
class MuLaN(nn.Module):
    # the diff view showed both the removed `def __init__(self):` and the added signature, the added one is kept
    def __init__(
        self,
        audio_transformer: AudioSpectrogramTransformer,
        text_transformer: TextTransformer
    ):
        super().__init__()

    # def forward(self, x): ...
    # (body folded in the diff view)

# setup.py (+1 −0) hunk, as shown in the diff view:
#   install_requires=[ 'audiolm-pytorch', 'beartype', 'einops>=0.4', 'vector-quantize-pytorch>=0.10.15', 'x-clip', ...
# musiclm_pytorch/musiclm_pytorch.py (+138 −1)
# reconstructed from a whitespace-collapsed diff view
# regions the diff view kept folded are marked with comments and are intentionally not reproduced

import torch
import torch.nn.functional as F
from torch import nn, einsum

from torchaudio.transforms import Spectrogram, TimeStretch, FrequencyMasking, TimeMasking

from x_clip.tokenizer import tokenizer
from vector_quantize_pytorch import ResidualVQ

from einops import rearrange, repeat, reduce, pack, unpack

from beartype import beartype

# functions

def exists(val):
    """Return True when `val` is not None."""
    return val is not None

def round_down_nearest_multiple(n, divisor):
    """Floor `n` to a multiple of `divisor` (largest multiple <= n for a positive divisor)."""
    return n // divisor * divisor

# tensor functions

# def log(t, eps = 1e-20): ...
# (body folded in the diff view)

def l2norm(t):
    """L2 normalize `t` along its last dimension."""
    return F.normalize(t, p = 2, dim = -1)

# 2d sinusoidal positional embedding
# simple vit paper shows it is good enough compared to learned

def posemb_sincos_2d(patches, temperature = 10000, dtype = torch.float32):
    """
    Build a fixed 2d sin / cos positional embedding for a grid of patch tokens.

    patches     - 4d tensor, unpacked as (batch, h, w, dim)
    temperature - base of the frequency progression, omega = 1 / temperature ** linspace(0, 1)
    dtype       - kept for interface compatibility only, the embedding has always been
                  cast to `patches.dtype` and still is

    returns a tensor of shape (h, w, dim) on the device of `patches`
    """
    _, h, w, dim = patches.shape
    device = patches.device

    assert (dim % 4) == 0, 'feature dimension must be multiple of 4 for sincos emb'

    y, x = torch.meshgrid(torch.arange(h, device = device), torch.arange(w, device = device), indexing = 'ij')

    num_freqs = dim // 4

    # with dim == 4 there is a single frequency and the old denominator (num_freqs - 1) was 0,
    # which turned the whole embedding into nan (0 / 0) - clamp the denominator to 1 instead
    omega = torch.arange(num_freqs, device = device) / max(num_freqs - 1, 1)
    omega = 1. / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]

    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim = 1)
    pe = pe.type(patches.dtype)
    return rearrange(pe, '(h w) d -> h w d', h = h, w = w)

# biasless layernorm

# class LayerNorm(nn.Module): ...
# (folded in the diff view, together with everything up to the end of class Transformer,
#  whose last visible statement is `return x`)

# Audio Spectrogram Transformer - https://arxiv.org/abs/2104.01778

def pair(t):
    """Return `t` unchanged if it is a tuple, otherwise duplicate it into `(t, t)`."""
    return t if isinstance(t, tuple) else (t, t)

class AudioSpectrogramTransformer(nn.Module):
    """
    audio -> spectrogram -> patch tokens -> transformer -> mean pooled, normed embedding

    dim, depth, dim_head, heads, attn_dropout, ff_mult, ff_dropout
                            - forwarded to `Transformer`
    patch_size              - int or (height, width) tuple, size of one spectrogram patch
    spec_n_fft, spec_win_length, spec_hop_length, spec_pad, spec_center, spec_pad_mode
                            - forwarded to torchaudio `Spectrogram`
    spec_power              - exponent applied to the magnitude of the complex spectrogram
    spec_aug_stretch_factor - fixed rate of `TimeStretch`, training only
    spec_aug_freq_mask      - `freq_mask_param` of `FrequencyMasking`, training only
    spec_aug_time_mask      - `time_mask_param` of `TimeMasking`, training only
    """

    def __init__(
        self,
        dim,
        depth,
        patch_size = 16,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        spec_n_fft = 128,
        spec_power = 2,
        spec_win_length = 24,
        spec_hop_length = None,
        spec_pad = 0,
        spec_center = True,
        spec_pad_mode = 'reflect',
        spec_aug_stretch_factor = 0.8,
        spec_aug_freq_mask = 80,
        spec_aug_time_mask = 80
    ):
        super().__init__()

        self.patch_size = pair(patch_size)
        patch_height, patch_width = self.patch_size

        # forward folds every patch into the channel dimension, so a 1x1 conv is the per patch linear projection
        self.to_patch_tokens = nn.Conv2d(patch_height * patch_width, dim, 1)

        # TimeStretch takes a complex spectrogram, so the stft is kept complex (power = None)
        # and `spec_power` is applied in forward, after the optional stretch
        self.spec_power = spec_power

        self.spec = Spectrogram(
            n_fft = spec_n_fft,
            power = None,
            win_length = spec_win_length,
            hop_length = spec_hop_length,
            pad = spec_pad,
            center = spec_center,
            pad_mode = spec_pad_mode
        )

        # this used to be TimeStretch(spec_aug_stretch_factor, fixed_rate=True), which handed the
        # stretch factor to `hop_length` and True to the rate, so the factor never reached the phase vocoder
        # hop length and number of frequency bins must agree with the stft above
        self.time_stretch = TimeStretch(
            hop_length = self.spec.hop_length,
            n_freq = spec_n_fft // 2 + 1,
            fixed_rate = spec_aug_stretch_factor
        )

        # masking operates on the real valued spectrogram, after `spec_power` has been applied
        self.aug = torch.nn.Sequential(
            FrequencyMasking(freq_mask_param = spec_aug_freq_mask),
            TimeMasking(time_mask_param = spec_aug_time_mask),
        )

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            dim_head = dim_head,
            heads = heads,
            attn_dropout = attn_dropout,
            ff_mult = ff_mult,
            ff_dropout = ff_dropout
        )

        self.norm = LayerNorm(dim)

    def forward(self, x):
        """
        x - handed straight to `Spectrogram`; the patch rearrange below needs the spectrogram
            to be 3d, so presumably raw audio of shape (batch, samples) - confirm with callers

        returns the normed mean over all patch tokens, one vector per batch element
        """
        x = self.spec(x)

        if self.training:
            x = self.time_stretch(x)

        # complex stft -> real spectrogram, the same abs / pow that Spectrogram(power = spec_power) applies
        if exists(self.spec_power):
            x = x.abs().pow(self.spec_power)

        if self.training:
            x = self.aug(x)

        # automatically crop if audio does not yield a 2d spectrogram that is divisible by patch sizes

        height, width = x.shape[-2:]
        patch_height, patch_width = self.patch_size

        rounded_height = round_down_nearest_multiple(height, patch_height)
        rounded_width = round_down_nearest_multiple(width, patch_width)

        if (height, width) != (rounded_height, rounded_width):
            # just keep printing to be annoying until it is fixed
            print(f'spectrogram yielded shape of {(height, width)}, but had to be cropped to {(rounded_height, rounded_width)} to be patchified for transformer')

        x = x[..., :rounded_height, :rounded_width]

        # to patches

        x = rearrange(x, 'b (h p1) (w p2) -> b (p1 p2) h w', p1 = patch_height, p2 = patch_width)
        x = self.to_patch_tokens(x)

        # 2d sinusoidal positional embedding

        x = rearrange(x, 'b c h w -> b h w c')
        x = x + posemb_sincos_2d(x)

        # attention, what else

        x = rearrange(x, 'b ... c -> b (...) c')

        x = self.transformer(x)

        # final global average and norm (most recent papers show this is superior to CLS token)

        x = reduce(x, 'b n d -> b d', 'mean')

        return self.norm(x)

# text transformer

class TextTransformer:
    pass

# main classes

@beartype
class MuLaN(nn.Module):
    # the diff view showed both the removed `def __init__(self):` and the added signature, the added one is kept
    def __init__(
        self,
        audio_transformer: AudioSpectrogramTransformer,
        text_transformer: TextTransformer
    ):
        super().__init__()

    # def forward(self, x): ...
    # (body folded in the diff view)
setup.py +1 −0 Original line number Diff line number Diff line Loading @@ -20,6 +20,7 @@ setup( ], install_requires=[ 'audiolm-pytorch', 'beartype', 'einops>=0.4', 'vector-quantize-pytorch>=0.10.15', 'x-clip', Loading