Block Causal Attention
Block causal attention lets the attention computation scale linearly with sequence length: each token attends only to tokens within its own block, so for a fixed block size B the cost is O(L · B · d) rather than the O(L² · d) of full attention.
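A minimal single-head sketch of why the cost is linear, computing attention block by block instead of materializing the full L × L score matrix (the function name block_causal_attention and the reshape-based layout are illustrative assumptions, not from the original):

import numpy as np

def block_causal_attention(q, k, v, B):
    # Minimal single-head sketch: q, k, v have shape [L, d], L divisible by B.
    L, d = q.shape
    qb = q.reshape(-1, B, d)              # [num_blocks, B, d]
    kb = k.reshape(-1, B, d)
    vb = v.reshape(-1, B, d)
    # Scores only within each block: [num_blocks, B, B], i.e. O(L * B) entries.
    scores = qb @ kb.transpose(0, 2, 1) / np.sqrt(d)
    causal = np.tril(np.ones((B, B), dtype=bool))  # causal mask inside a block
    scores = np.where(causal, scores, -np.inf)
    # Softmax over the key dimension of each block.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return (weights @ vb).reshape(L, d)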
A vectorized way to generate the block causal mask:
import numpy as np

L, B = 16, 4                  # sequence length and block size
idx = np.arange(L)            # [0, 1, ..., L-1]
row = idx[:, None]            # shape [L, 1]
col = idx[None, :]            # shape [1, L]
same_block = (row // B) == (col // B)  # query and key fall in the same block
lower_tri = (row % B) >= (col % B)     # causal (lower-triangular) within the block
mask_bool = same_block & lower_tri     # logical AND
print(np.where(mask_bool, 1, 0))       # 1 where attention is allowed
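To apply the mask in a dense attention computation, set disallowed positions to -inf before the softmax. A hedged usage sketch, where the random scores array stands in for Q Kᵀ / √d:

scores = np.random.randn(L, L)                  # stand-in for Q @ K.T / sqrt(d)
masked = np.where(mask_bool, scores, -np.inf)   # block out-of-block and future keys
weights = np.exp(masked - masked.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)  # rows sum to 1 over allowed positions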