Commit deceb7a0 authored by Phil Wang
Browse files

bring the feedforward design from localvit to audiolm

parent 94fe05d1
Loading
Loading
Loading
Loading
+10 −0
Original line number Diff line number Diff line
@@ -364,3 +364,13 @@ sample = trainer.generate(text = ['sound of rain drops on the rooftops'], batch_
    note    = {under review}
}
```

```bibtex
@article{Li2021LocalViTBL,
    title   = {LocalViT: Bringing Locality to Vision Transformers},
    author  = {Yawei Li and K. Zhang and Jie Cao and Radu Timofte and Luc Van Gool},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2104.05707}
}
```
+14 −1
Original line number Diff line number Diff line
import math
from functools import partial

from typing import Optional, Union, List
from beartype.typing import Optional, Union, List
from beartype import beartype

import torch
@@ -11,6 +11,7 @@ import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange

from audiolm_pytorch.vq_wav2vec import FairseqVQWav2Vec
from audiolm_pytorch.hubert_kmeans import HubertWithKmeans
@@ -203,6 +204,17 @@ class RelativePositionBias(nn.Module):

# feedforward

class CausalDSConv(nn.Module):
    """Causal depthwise 1D convolution applied along the sequence axis.

    Every channel is filtered on its own (``groups = dim``), and the sequence
    is padded on the left only by ``kernel_size - 1`` steps. The output
    therefore has the same length as the input, and position ``t`` depends
    only on input positions ``<= t``.

    Args:
        dim: number of channels; used as in-channels, out-channels and groups.
        kernel_size: width of the filter along the sequence axis. The default
            of 3 reproduces the previous hard-coded behaviour (left pad of 2).

    Raises:
        ValueError: if ``kernel_size`` is smaller than 1.
    """

    def __init__(self, dim, kernel_size = 3):
        super().__init__()

        if kernel_size < 1:
            raise ValueError(f'kernel_size must be at least 1, got {kernel_size}')

        # derive the left padding from the kernel width so the two can never
        # drift apart (previously the literals 3 and 2 had to be kept in sync)
        self.causal_padding = kernel_size - 1

        # attribute name kept as `ds_conv` so existing state dict keys still load
        self.ds_conv = nn.Conv1d(dim, dim, kernel_size, bias = False, groups = dim)

    def forward(self, x):
        """Filter ``x`` of shape (batch, seq, channels); returns the same shape."""
        # Conv1d expects channels before the sequence axis: b n c -> b c n
        x = x.transpose(1, 2)

        # pad only at the start of the sequence - this is what makes it causal
        x = F.pad(x, (self.causal_padding, 0))
        x = self.ds_conv(x)

        # back to channels-last: b c n -> b n c
        return x.transpose(1, 2)

class GEGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
@@ -213,6 +225,7 @@ def FeedForward(dim, mult = 4, dropout = 0.1):
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, inner_dim * 2, bias = False),
        CausalDSConv(inner_dim * 2),
        GEGLU(),
        LayerNorm(inner_dim),
        nn.Dropout(dropout),
+1 −1
Original line number Diff line number Diff line
from pathlib import Path
from functools import partial, wraps

from typing import Tuple
from beartype.typing import Tuple
from beartype.door import is_bearable

import torchaudio
+1 −1
Original line number Diff line number Diff line
@@ -4,7 +4,7 @@ from random import choice
from pathlib import Path
from shutil import rmtree

from typing import Union, List, Optional, Tuple
from beartype.typing import Union, List, Optional, Tuple
from typing_extensions import Annotated

from beartype import beartype
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.1.22',
  version = '0.2.0',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',