Loading audiolm_pytorch/audiolm_pytorch.py +4 −0 Original line number Diff line number Diff line Loading @@ -111,6 +111,10 @@ class Attention(nn.Module): if exists(attn_bias): sim = sim + attn_bias i, j = sim.shape[-2:] causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1) sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) attn = sim.softmax(dim = -1) out = einsum('b h i j, b h j d -> b h i d', attn, v) Loading Loading
audiolm_pytorch/audiolm_pytorch.py +4 −0 Original line number Diff line number Diff line Loading @@ -111,6 +111,10 @@ class Attention(nn.Module): if exists(attn_bias): sim = sim + attn_bias i, j = sim.shape[-2:] causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1) sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) attn = sim.softmax(dim = -1) out = einsum('b h i j, b h j d -> b h i d', attn, v) Loading