Commit 9cd86d25 authored by Phil Wang
Browse files

causal mask needed in attention

parent 277dabf5
Loading
Loading
Loading
Loading
+4 −0
Original line number | Diff line number | Diff line
@@ -111,6 +111,10 @@ class Attention(nn.Module):
        if exists(attn_bias):
            sim = sim + attn_bias

        i, j = sim.shape[-2:]
        causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)