Commit 4abc58ed authored by Phil Wang's avatar Phil Wang
Browse files

use bias-less layernorm, normformer extra norm in feedforward, remove wip as nearly ready

parent 8c4d01ad
Loading
Loading
Loading
Loading
+12 −1
Original line number Diff line number Diff line
<img src="./audiolm.png" width="600px"></img>

## AudioLM - Pytorch (wip)
## AudioLM - Pytorch

Implementation of <a href="https://google-research.github.io/seanet/audiolm/examples/">AudioLM</a>, a Language Modeling Approach to Audio Generation out of Google Research, in Pytorch

@@ -285,3 +285,14 @@ generated_wav_with_text_condition = audiolm(text = ['chirping of birds and the d
    volume  = {abs/2210.13432}
}
```

```bibtex
@inproceedings{anonymous2022normformer,
    title   = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
    author  = {Anonymous},
    booktitle = {Submitted to The Tenth International Conference on Learning Representations },
    year    = {2022},
    url     = {https://openreview.net/forum?id=GMYWzWztDx5},
    note    = {under review}
}
```
+17 −4
Original line number Diff line number Diff line
@@ -144,6 +144,18 @@ def get_embeds(

    return embeds

# bias-less layernorm, being used in more recent T5s, PaLM, also in @borisdayma 's experiments shared with me
# greater stability

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# relative positional bias

class RelativePositionBias(nn.Module):
@@ -199,9 +211,10 @@ class GEGLU(nn.Module):
def FeedForward(dim, mult = 4, dropout = 0.1):
    inner_dim = int(dim * 2 * mult / 3)
    return nn.Sequential(
        nn.LayerNorm(dim),
        LayerNorm(dim),
        nn.Linear(dim, inner_dim * 2, bias = False),
        GEGLU(),
        LayerNorm(inner_dim),
        nn.Dropout(dropout),
        nn.Linear(inner_dim, dim, bias = False)
    )
@@ -228,8 +241,8 @@ class Attention(nn.Module):

        dim_context = default(dim_context, dim)

        self.norm = nn.LayerNorm(dim)
        self.context_norm = nn.LayerNorm(dim_context) if norm_context else nn.Identity()
        self.norm = LayerNorm(dim)
        self.context_norm = LayerNorm(dim_context) if norm_context else nn.Identity()

        self.num_null_kv = num_null_kv
        self.null_kv = nn.Parameter(torch.randn(2, num_null_kv, dim_head))
@@ -321,7 +334,7 @@ class Transformer(nn.Module):
                FeedForward(dim = dim, dropout = ff_dropout)
            ]))

        self.norm = nn.LayerNorm(dim)
        self.norm = LayerNorm(dim)

    def forward(
        self,
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.68',
  version = '0.1.0',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',