Commit 95280273 authored by Phil Wang's avatar Phil Wang
Browse files

always use shazeer glu feedforward

parent 9c6785b8
Loading
Loading
Loading
Loading
+9 −0
Original line number Diff line number Diff line
@@ -63,3 +63,12 @@ loss.backward()
  year   = {2021}
}
```

```bibtex
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}
}
```
+8 −3
Original line number Diff line number Diff line
@@ -442,12 +442,17 @@ class RelativePositionBias(nn.Module):

# feedforward

class GEGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return F.gelu(gate) * x

def FeedForward(dim, mult = 4):
    inner_dim = int(dim * mult)
    inner_dim = int(dim * 2 * mult / 3)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias = False),
        nn.GELU(),
        nn.Linear(dim, inner_dim * 2, bias = False),
        GEGLU(),
        nn.Linear(inner_dim, dim, bias = False)
    )