Commit aeb4962a authored by Phil Wang's avatar Phil Wang
Browse files

increase odds transformers are trained successfully

parent 9145655b
Loading
Loading
Loading
Loading
+9 −0
Original line number Diff line number Diff line
@@ -409,3 +409,12 @@ sample = trainer.generate(text = ['sound of rain drops on the rooftops'], batch_
    year    = {2022}
}
```

```bibtex
@misc{gilmer2023intriguing,
    title  = {Intriguing Properties of Transformer Training Instabilities},
    author = {Justin Gilmer and Andrea Schioppa and Jeremy Cohen},
    year   = {2023},
    status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
}
```
+16 −4
Original line number Diff line number Diff line
@@ -80,6 +80,9 @@ def grad_shrink(t, alpha = 0.1):
def log(t, eps = 1e-20):
    """Return the natural log of `t + eps`; `eps` offsets the input before the log is taken."""
    shifted = t + eps
    return torch.log(shifted)

def l2norm(t):
    """Normalize `t` along its last dimension using `F.normalize` with the L2 norm."""
    unit = F.normalize(t, p = 2.0, dim = -1)
    return unit

def gumbel_noise(t):
    """Return `-log(-log(u))` where `u` is uniform(0, 1) noise created with `zeros_like(t)`."""
    # draw the uniform samples in a buffer created from `t`
    uniform = torch.zeros_like(t).uniform_(0, 1)
    # apply the double negative-log transform using the file's `log` helper
    inner = -log(uniform)
    return -log(inner)
@@ -243,11 +246,12 @@ class Attention(nn.Module):
        heads = 8,
        norm_context = False,
        num_null_kv = 0,
        dropout = 0.1
        dropout = 0.1,
        scale = 8
    ):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.scale = scale
        self.causal = causal
        inner_dim = dim_head * heads

@@ -263,6 +267,10 @@ class Attention(nn.Module):

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim_context, dim_head * 2, bias = False)

        self.q_scale = nn.Parameter(torch.ones(dim_head))
        self.k_scale = nn.Parameter(torch.ones(dim_head))

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.Dropout(dropout)
@@ -321,11 +329,15 @@ class Attention(nn.Module):

        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        q = q * self.scale
        # new technique, rmsnormed queries and keys, first used by 22B parameter model successfully https://arxiv.org/abs/2302.05442

        q, k = map(l2norm, (q, k))
        q = q * self.q_scale
        k = k * self.k_scale

        # similarities

        sim = einsum('b h i d, b j d -> b h i j', q, k)
        sim = einsum('b h i d, b j d -> b h i j', q, k) * self.scale

        if exists(attn_bias):
            attn_bias = F.pad(attn_bias, (self.num_null_kv, 0), value = 0.)
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.12.0',
  version = '0.12.1',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',