Commit 9b1fef52 authored by Phil Wang
Browse files

complete basic attention encoder

parent d3bb0612
Loading
Loading
Loading
Loading
+54 −4
Original line number Diff line number Diff line
@@ -14,7 +14,18 @@ class SoundStream(nn.Module):
    def forward(self, x):
        return x

# classes
# feedforward

def FeedForward(dim, mult = 4):
    """Build the feedforward sub-block as an ``nn.Sequential``.

    The returned module layer-normalises its input, projects it from ``dim``
    up to ``int(dim * mult)`` features, applies GELU, and projects back down
    to ``dim``. Both linear projections are created without a bias term.
    No residual connection is added here.

    Args:
        dim: feature size of the input and of the output.
        mult: expansion factor for the hidden width (truncated with ``int``).

    Returns:
        An ``nn.Sequential`` of LayerNorm -> Linear -> GELU -> Linear.
    """
    hidden_dim = int(dim * mult)

    # modules are created in the same order they are applied
    pre_norm = nn.LayerNorm(dim)
    expand = nn.Linear(dim, hidden_dim, bias = False)
    activation = nn.GELU()
    contract = nn.Linear(hidden_dim, dim, bias = False)

    return nn.Sequential(pre_norm, expand, activation, contract)

# attention

class Attention(nn.Module):
    def __init__(
@@ -28,14 +39,18 @@ class Attention(nn.Module):
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        x = self.norm(x)

        q, k, v = self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d'), (q, k, v))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        q = q * self.scale

@@ -48,11 +63,46 @@ class Attention(nn.Module):
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

# transformer

class Transformer(nn.Module):
    """Stack of ``depth`` attention + feedforward layers with a final norm.

    Every layer holds one ``Attention`` module and one ``FeedForward`` module.
    In ``forward`` the output of each module is added back onto its own input,
    and the result of the last layer goes through a closing ``nn.LayerNorm``.

    Args:
        dim: feature size, passed to every sub-module and to the final norm.
        depth: number of (attention, feedforward) layers to build.
        **kwargs: forwarded unchanged to each ``Attention`` constructor.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        **kwargs
    ):
        super().__init__()

        # one [attention, feedforward] pair per layer, built in that order
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim = dim, **kwargs),
                FeedForward(dim = dim)
            ])
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Run ``x`` through every layer, then through the final LayerNorm."""
        hidden = x

        for attention, feedforward in self.layers:
            attended = attention(hidden)
            hidden = attended + hidden      # residual around attention

            projected = feedforward(hidden)
            hidden = projected + hidden     # residual around feedforward

        return self.norm(hidden)

# audio LM

class AudioLM(nn.Module):
    """Model wrapper that owns a single ``Transformer`` and delegates to it.

    Args:
        dim: feature size handed to the transformer.
        depth: number of transformer layers.
        **kwargs: forwarded unchanged to the ``Transformer`` constructor.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        **kwargs
    ):
        super().__init__()
        self.transformer = Transformer(dim = dim, depth = depth, **kwargs)

    def forward(self, x):
        """Return the transformer's output for ``x``."""
        # the earlier placeholder `return x` is gone, so the transformer
        # call is now actually reached
        return self.transformer(x)