Commit e70ecea0 authored by Phil Wang
Browse files

transformer for text and audio, going for AST encoder

parent 7ecb3a5f
Loading
Loading
Loading
Loading
+48 −2
Original line number Diff line number Diff line
@@ -20,6 +20,24 @@ class LayerNorm(nn.Module):
    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# feedforward

class GEGLU(nn.Module):
    """GELU-gated linear unit.

    Splits the last dimension of the input into two chunks: the first is
    the value stream, the second is passed through GELU and used as a
    multiplicative gate. The output's last dimension is therefore half of
    the input's (the input's last dimension is expected to be even).
    """

    def forward(self, x):
        """Return ``gelu(gate) * value`` for the two chunks of ``x``."""
        value, gate = x.chunk(2, dim = -1)
        activated_gate = F.gelu(gate)
        return activated_gate * value

def FeedForward(dim, mult = 4, dropout = 0.):
    """Build a gated feedforward block as an ``nn.Sequential``.

    The block applies, in order: LayerNorm, a bias-free projection to twice
    the hidden width, GEGLU (which halves that width again), dropout, and a
    bias-free projection back to ``dim``.

    Args:
        dim: feature width of the input and of the output.
        mult: expansion factor; the hidden width is ``int(dim * mult * 2 / 3)``.
        dropout: dropout probability applied after the gated activation.

    Returns:
        An ``nn.Sequential`` holding the five layers described above.
    """
    dim_hidden = int(dim * mult * 2 / 3)

    # GEGLU consumes half of its input as the gate, so the first projection
    # has to emit twice the hidden width.
    gated_width = dim_hidden * 2

    stages = [
        LayerNorm(dim),
        nn.Linear(dim, gated_width, bias = False),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim_hidden, dim, bias = False),
    ]

    return nn.Sequential(*stages)

# attention

class Attention(nn.Module):
@@ -29,8 +47,7 @@ class Attention(nn.Module):
        causal = False,
        dim_head = 64,
        heads = 8,
        num_null_kv = 0,
        dropout = 0.1
        dropout = 0.
    ):
        super().__init__()
        self.heads = heads
@@ -98,6 +115,35 @@ class Attention(nn.Module):
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# transformer

class Transformer(nn.Module):
    """A stack of ``depth`` layers, each an attention block then a feedforward block.

    Both sub-blocks are applied with a residual connection, i.e. the
    sub-block's output is added to its own input.
    """

    def __init__(
        self,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.
    ):
        """Create the layer stack.

        Args:
            dim: feature width, forwarded to every Attention and FeedForward.
            depth: number of (attention, feedforward) layer pairs.
            dim_head: forwarded to Attention as ``dim_head``.
            heads: forwarded to Attention as ``heads``.
            attn_dropout: forwarded to Attention as ``dropout``.
            ff_mult: forwarded to FeedForward as ``mult``.
            ff_dropout: forwarded to FeedForward as ``dropout``.
        """
        super().__init__()

        def make_layer():
            # Attention is constructed before FeedForward so that parameter
            # creation order within a layer stays (attention, feedforward).
            return nn.ModuleList([
                Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout),
                FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout),
            ])

        self.layers = nn.ModuleList([make_layer() for _ in range(depth)])

    def forward(self, x, mask = None):
        """Run ``x`` through every layer and return the result.

        ``mask`` is handed unchanged to each attention block; its meaning is
        defined by Attention.
        """
        for attention, feedforward in self.layers:
            attended = attention(x, mask = mask)
            x = attended + x

            transformed = feedforward(x)
            x = transformed + x

        return x

# main classes

class MuLaN(nn.Module):