Block Causal Attention

Restricting attention to causal pairs within fixed-size blocks allows the attention computation to scale linearly with sequence length: for block size B the cost is O(L·B) rather than the O(L²) of full causal attention.
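To see where the linear cost comes from: since each query attends only to keys in its own block, the sequence can be split into L // B independent blocks and attention computed per block. A minimal single-head NumPy sketch under that assumption (the function name and layout are illustrative, not from the original):

import numpy as np

def block_causal_attention(q, k, v, B):
    """Attention restricted to causal pairs within each block of size B.

    q, k, v: [L, d] arrays with L divisible by B.
    Cost is O(L * B * d) -- linear in L for a fixed block size.
    """
    L, d = q.shape
    nb = L // B
    # Reshape into [nb, B, d] so each block is processed independently.
    qb = q.reshape(nb, B, d)
    kb = k.reshape(nb, B, d)
    vb = v.reshape(nb, B, d)
    # [nb, B, B] scores: each query sees only its own block's keys.
    scores = qb @ kb.transpose(0, 2, 1) / np.sqrt(d)
    # Causal (lower-triangular) mask inside each block.
    causal = np.tril(np.ones((B, B), dtype=bool))
    scores = np.where(causal, scores, -np.inf)
    # Row-wise softmax; masked positions contribute zero weight.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return (weights @ vb).reshape(L, d)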

A vectorized way to generate the block causal mask:

import numpy as np

L, B = 16, 4                  # sequence length, block size

idx = np.arange(L)            # [0, 1, …, L-1]
row = idx[:, None]            # shape [L, 1]
col = idx[None, :]            # shape [1, L]

same_block = (row // B) == (col // B)   # query and key fall in the same block
lower_tri  = (row %  B) >= (col %  B)   # causal (lower-triangular) within the block
mask_bool  = same_block & lower_tri     # logical AND

print(np.where(mask_bool, 1, 0))        # 1 where attention is allowed
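To use the resulting [L, L] mask in a standard (non-blocked) attention computation, a common pattern is to set the disallowed positions to -inf before the row-wise softmax. A minimal sketch continuing from the variables above (q, k, and the dimension d are illustrative):

d = 8
q = np.random.randn(L, d)
k = np.random.randn(L, d)
scores = q @ k.T / np.sqrt(d)                    # [L, L] attention logits
scores = np.where(mask_bool, scores, -np.inf)    # block out disallowed pairs
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)   # row-wise softmax

Every row keeps at least its diagonal entry unmasked, so the softmax is always well defined.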
