Loading README.md +9 −0 Original line number Diff line number Diff line Loading @@ -63,3 +63,12 @@ loss.backward() year = {2021} } ``` ```bibtex @misc{shazeer2020glu, title = {GLU Variants Improve Transformer}, author = {Noam Shazeer}, year = {2020}, url = {https://arxiv.org/abs/2002.05202} } ``` audiolm_pytorch/audiolm_pytorch.py +8 −3 Original line number Diff line number Diff line Loading @@ -442,12 +442,17 @@ class RelativePositionBias(nn.Module): # feedforward class GEGLU(nn.Module): def forward(self, x): x, gate = x.chunk(2, dim = -1) return F.gelu(gate) * x def FeedForward(dim, mult = 4): inner_dim = int(dim * mult) inner_dim = int(dim * 2 * mult / 3) return nn.Sequential( nn.LayerNorm(dim), nn.Linear(dim, inner_dim, bias = False), nn.GELU(), nn.Linear(dim, inner_dim * 2, bias = False), GEGLU(), nn.Linear(inner_dim, dim, bias = False) ) Loading Loading
README.md +9 −0 Original line number Diff line number Diff line Loading @@ -63,3 +63,12 @@ loss.backward() year = {2021} } ``` ```bibtex @misc{shazeer2020glu, title = {GLU Variants Improve Transformer}, author = {Noam Shazeer}, year = {2020}, url = {https://arxiv.org/abs/2002.05202} } ```
audiolm_pytorch/audiolm_pytorch.py +8 −3 Original line number Diff line number Diff line Loading @@ -442,12 +442,17 @@ class RelativePositionBias(nn.Module): # feedforward class GEGLU(nn.Module): def forward(self, x): x, gate = x.chunk(2, dim = -1) return F.gelu(gate) * x def FeedForward(dim, mult = 4): inner_dim = int(dim * mult) inner_dim = int(dim * 2 * mult / 3) return nn.Sequential( nn.LayerNorm(dim), nn.Linear(dim, inner_dim, bias = False), nn.GELU(), nn.Linear(dim, inner_dim * 2, bias = False), GEGLU(), nn.Linear(inner_dim, dim, bias = False) ) Loading