|
"""Triton implementation of Flash Attention. |
|
|
|
# Copyright (c) 2022, Tri Dao. |
|
# |
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
# you may not use this file except in compliance with the License. |
|
# You may obtain a copy of the License at |
|
# |
|
# http://www.apache.org/licenses/LICENSE-2.0 |
|
# |
|
# Unless required by applicable law or agreed to in writing, software |
|
# distributed under the License is distributed on an "AS IS" BASIS, |
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
# See the License for the specific language governing permissions and |
|
# limitations under the License. |
|
|
|
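*Experimental* implementation of FlashAttention in Triton.

Note: this module currently implements only the forward pass; `_FlashAttnFunc.backward`
raises NotImplementedError, so the backward-pass items listed below do not apply here.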
|
We use the FlashAttention implementation from Phil Tillet as a starting point.
|
https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py |
|
|
|
Changes: |
|
- Implement both causal and non-causal attention. |
|
- Implement both self-attention and cross-attention. |
|
- Support arbitrary seqlens (not just multiples of 128), for both forward and backward. |
|
- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward. |
|
- Support attention bias. |
|
- Speed up the forward pass a bit, and only store the LSE instead of m and l. |
|
- Make the backward for d=128 much faster by reducing register spilling. |
|
- Optionally parallelize the backward pass across seqlen_k, to deal with the case of |
|
small batch size * nheads. |
|
|
|
Caution: |
|
- If you plan to use headdim other than 64 and 128, you should test for race conditions |
|
(due to the Triton compiler), as done in tests/test_flash_attn.py |
|
"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions |
|
for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident |
|
that there are none left for other head dimensions. |
|
Differences between this Triton version and the CUDA version: |
|
- Triton version doesn't support dropout. |
|
- Triton forward is generally faster than CUDA forward. |
|
- Triton backward is faster than CUDA backward when batch * nheads is small, and when headdim=64. |
|
It is slightly slower when headdim=128 and batch * nheads is large. |
|
- Triton version doesn't yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor). |
|
""" |
|
|
|
import math |
|
|
|
import torch |
|
import triton |
|
import triton.language as tl |
|
from einops import repeat |
|
|
|
|
|
# Forward FlashAttention kernel.
#
# Launch grid (see _flash_attn_forward): axis 0 indexes BLOCK_M-row tiles of
# the queries, axis 1 indexes the flattened (batch * nheads) pair.  Each
# program loads one Q tile and streams over BLOCK_N-row tiles of K and V.
# It keeps a running row maximum (m_i) and a running log-sum-exp (lse_i), so
# the softmax-weighted sum over V is built in a single pass.  At the end it
# writes the normalized output tile to Out and the per-row log-sum-exp to Lse.
#
# BIAS_TYPE is one of:
#   'none'   - no bias is added,
#   'vector' - one value per key, broadcast over query rows,
#   'matrix' - one value per (query, key) pair.
# TMP is a scratch buffer laid out like Lse (see the store/load notes below).
@triton.autotune(
    configs=[
        triton.Config({
            'BLOCK_M': 128,
            'BLOCK_N': 128
        },
                      num_warps=8,
                      num_stages=1),
    ],
    key=[
        'CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL',
        'BLOCK_HEADDIM'
    ])
@triton.heuristics({
    'EVEN_M': lambda args: args['seqlen_q'] % args['BLOCK_M'] == 0,
    'EVEN_N': lambda args: args['seqlen_k'] % args['BLOCK_N'] == 0,
    'EVEN_HEADDIM': lambda args: args['headdim'] == args['BLOCK_HEADDIM'],
})
@triton.jit
def _fwd_kernel(
    Q,
    K,
    V,
    Bias,
    Out,
    Lse,
    TMP,
    softmax_scale,
    stride_qb,
    stride_qh,
    stride_qm,
    stride_kb,
    stride_kh,
    stride_kn,
    stride_vb,
    stride_vh,
    stride_vn,
    stride_bb,
    stride_bh,
    stride_bm,
    stride_ob,
    stride_oh,
    stride_om,
    nheads,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    headdim,
    CACHE_KEY_SEQLEN_Q,
    CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads

    # Query-row offsets of this tile, key offsets within one K/V tile, and
    # head-dimension offsets.
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # Base pointers for this (batch, head).  The head dimension is indexed
    # with unit stride; _FlashAttnFunc.forward makes q/k/v contiguous in their
    # last dimension before launching.
    q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (
        offs_m[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (
        offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (
        offs_n[:, None] * stride_vn + offs_d[None, :])
    # Bias pointers.  'none' needs no pointer.  It must NOT fall through to
    # the error branch, otherwise the no-bias kernel variant cannot compile.
    if BIAS_TYPE == 'vector':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
    elif BIAS_TYPE == 'matrix':
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (
            offs_m[:, None] * stride_bm + offs_n[None, :])
    elif BIAS_TYPE != 'none':
        raise ValueError(
            "BIAS_TYPE must be one of {'none', 'vector', 'matrix'}")

    # Scratch pointers (same per-row layout as Lse) and running statistics.
    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)

    # Load the Q tile once; it is reused for every K/V tile.  Out-of-range
    # rows and head-dim columns are zero-filled.
    # NOTE(review): the unmasked-rows variant requires EVEN_M & EVEN_N, not
    # just EVEN_M.  This is presumably deliberate (the module docstring warns
    # about Triton compiler race conditions), so keep it.
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(q_ptrs,
                        mask=(offs_m[:, None] < seqlen_q) &
                        (offs_d[None, :] < headdim),
                        other=0.0)

    # With causal masking, key tiles that start past this query tile's last
    # row are fully masked, so the loop stops early.
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum(
        (start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)

        # -- load the K tile and compute raw scores qk = q @ k^T --
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn,
                            mask=offs_d[None, :] < headdim,
                            other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn,
                            mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) &
                            (offs_d[None, :] < headdim),
                            other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)

        # Exclude keys past seqlen_k (and, if causal, keys after the query
        # position) by adding -inf before the softmax.
        if not EVEN_N:
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0,
                           float('-inf'))
        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0,
                           float('-inf'))

        if BIAS_TYPE != 'none':
            if BIAS_TYPE == 'vector':
                if EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs + start_n,
                                   mask=(start_n + offs_n) < seqlen_k,
                                   other=0.0).to(tl.float32)
                bias = bias[None, :]
            elif BIAS_TYPE == 'matrix':
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs + start_n,
                                   mask=(offs_m[:, None] < seqlen_q) &
                                   ((start_n + offs_n)[None, :] < seqlen_k),
                                   other=0.0).to(tl.float32)
            else:
                raise ValueError(
                    "BIAS_TYPE must be one of {'none', 'vector', 'matrix'}")
            # With a bias, the scale must be applied before the bias is added,
            # so the scores are scaled here rather than inside tl.exp.
            qk = qk * softmax_scale + bias
            m_ij = tl.maximum(tl.max(qk, 1), lse_i)
            p = tl.exp(qk - m_ij[:, None])
        else:
            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
            p = tl.exp(qk * softmax_scale - m_ij[:, None])
        l_ij = tl.sum(p, 1)

        # Rescale the accumulator from the previous maximum to the new one.
        acc_o_scale = tl.exp(m_i - m_ij)
        # NOTE: storing a value to TMP and immediately loading it back is a
        # semantic no-op.  It is presumably a workaround for a Triton compiler
        # issue; do not remove it without re-running the race-condition tests.
        tl.store(t_ptrs, acc_o_scale)
        acc_o_scale = tl.load(t_ptrs)
        acc_o = acc_o * acc_o_scale[:, None]

        # -- load the V tile and accumulate p @ v --
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn,
                            mask=offs_d[None, :] < headdim,
                            other=0.0)
        else:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn,
                            mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) &
                            (offs_d[None, :] < headdim),
                            other=0.0)
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)

        # -- update running statistics: lse_i = log(exp(lse_i) + sum(p')) --
        m_i = m_ij
        l_i_new = tl.exp(lse_i - m_ij) + l_ij
        lse_i = m_ij + tl.log(l_i_new)

    # acc_o holds sum(exp(s - m_i) * v).  Multiplying by exp(m_i - lse_i)
    # divides by the softmax denominator.  (Same TMP round-trip as above.)
    o_scale = tl.exp(m_i - lse_i)
    tl.store(t_ptrs, o_scale)
    o_scale = tl.load(t_ptrs)
    acc_o = acc_o * o_scale[:, None]

    # Rematerialize the row offsets instead of keeping them live across the loop.
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # The unmasked store relies on seqlen_q_rounded being a multiple of
    # BLOCK_M.  The wrapper rounds up to 128, the only autotuned BLOCK_M.
    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
    tl.store(lse_ptrs, lse_i)

    # Write the output tile, masking rows past seqlen_q and columns past headdim.
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (
        offs_m[:, None] * stride_om + offs_d[None, :])
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(out_ptrs,
                     acc_o,
                     mask=(offs_m[:, None] < seqlen_q) &
                     (offs_d[None, :] < headdim))


def init_to_zero(name):
    """Return a hook that zeroes ``nargs[name]`` in place.

    The returned callable takes a mapping of arguments and returns the result
    of ``nargs[name].zero_()``.  It is presumably intended as a Triton
    autotune ``pre_hook``; it is not referenced in this part of the file.
    """
    return lambda nargs: nargs[name].zero_()


def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
    """Validate inputs and launch the Triton FlashAttention forward kernel.

    Args:
        q: (batch, seqlen_q, nheads, headdim) CUDA tensor, fp16 or bf16,
            headdim <= 128.
        k: (batch, seqlen_k, nheads, headdim), same dtype as q.
        v: (batch, seqlen_k, nheads, headdim), same dtype as q.
        bias: optional 4-D CUDA tensor of dtype q.dtype or float32.  Its last
            two dims must be (1, seqlen_k) ("vector" bias) or
            (seqlen_q, seqlen_k) ("matrix" bias).  Its first two dims must be
            (batch, nheads) or broadcast to it from (1, nheads), (batch, 1)
            or (1, 1).
        causal: whether to apply causal masking.
        softmax_scale: factor applied to q @ k^T; defaults to
            1 / sqrt(headdim) when None.

    Returns:
        Tuple ``(o, lse, softmax_scale)``:
        - ``o`` has q's shape and dtype.
        - ``lse`` is a float32 tensor of shape (batch, nheads, seqlen_q_rounded),
          where seqlen_q_rounded is seqlen_q rounded up to a multiple of 128.
        - ``softmax_scale`` is the scale actually used.

    Raises:
        AssertionError: on shape, dtype or device violations.
        RuntimeError: if the bias's last two dimensions are unsupported.
    """
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    assert k.shape == (batch, seqlen_k, nheads, d)
    assert v.shape == (batch, seqlen_k, nheads, d)
    assert d <= 128, 'FlashAttention only support head dimensions up to 128'
    assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
    assert q.dtype in [torch.float16,
                       torch.bfloat16], 'Only support fp16 and bf16'
    assert q.is_cuda and k.is_cuda and v.is_cuda
    # Use `is None` rather than `or` so that an explicit 0.0 is honoured.
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(d)

    has_bias = bias is not None
    bias_type = 'none'
    # With no bias the kernel never reads the bias strides.
    bias_strides = (0, 0, 0)
    if has_bias:
        assert bias.dtype in [q.dtype, torch.float]
        assert bias.is_cuda
        assert bias.dim() == 4
        # The kernel indexes the key dimension of bias with unit stride.
        if bias.stride(-1) != 1:
            bias = bias.contiguous()
        if bias.shape[2:] == (1, seqlen_k):
            bias_type = 'vector'
        elif bias.shape[2:] == (seqlen_q, seqlen_k):
            bias_type = 'matrix'
        else:
            raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
                               ' or (seqlen_q, seqlen_k)'
                               f' with seqlen_q={seqlen_q},'
                               f' seqlen_k={seqlen_k}.'
                               f' Bias has shape: {tuple(bias.shape)}')
        # Materialize singleton batch/head dims so the kernel can index bias
        # with per-(batch, head) strides.
        if bias.shape[:2] == (1, nheads):
            bias = repeat(bias, '1 h ... -> b h ...', b=batch)
        elif bias.shape[:2] == (batch, 1):
            bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
        elif bias.shape[:2] == (1, 1):
            bias = repeat(bias, '1 h ... -> b h ...', b=batch)
            bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
        assert bias.shape[:2] == (batch, nheads), (
            'First 2 dimensions of bias must be broadcastable to'
            f' (batch, nheads) = ({batch}, {nheads}).'
            f' Bias has shape: {bias.shape}')
        bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2))

    # Lse/TMP rows are padded to a multiple of 128 (the autotuned BLOCK_M) so
    # the kernel can store full row tiles without masking.
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
    lse = torch.empty((batch, nheads, seqlen_q_rounded),
                      device=q.device,
                      dtype=torch.float32)
    tmp = torch.empty((batch, nheads, seqlen_q_rounded),
                      device=q.device,
                      dtype=torch.float32)
    o = torch.empty_like(q)

    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
    grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads)
    # Strides are passed in kernel order (batch, head, sequence); tensors are
    # laid out as (batch, seqlen, nheads, headdim), hence stride(2) before stride(1).
    _fwd_kernel[grid](
        q,
        k,
        v,
        bias,
        o,
        lse,
        tmp,
        softmax_scale,
        q.stride(0),
        q.stride(2),
        q.stride(1),
        k.stride(0),
        k.stride(2),
        k.stride(1),
        v.stride(0),
        v.stride(2),
        v.stride(1),
        *bias_strides,
        o.stride(0),
        o.stride(2),
        o.stride(1),
        nheads,
        seqlen_q,
        seqlen_k,
        seqlen_q_rounded,
        d,
        # Coarse sequence-length keys so autotuning results are cached per
        # 32-token bucket rather than per exact length.
        seqlen_q // 32,
        seqlen_k // 32,
        bias_type,
        causal,
        BLOCK_HEADDIM,
    )
    return o, lse, softmax_scale
|
|
|
class _FlashAttnFunc(torch.autograd.Function): |
|
|
|
@staticmethod |
|
def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None): |
|
"""Forward pass for FlashAttention. |
|
|
|
Args: |
|
ctx: autograd context |
|
q: (batch_size, seqlen_q, nheads, headdim) |
|
k: (batch_size, seqlen_k, nheads, headdim) |
|
v: (batch_size, seqlen_k, nheads, headdim) |
|
bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k). |
|
For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k). |
|
ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k) |
|
causal (bool): whether to incorporate causal attention masking |
|
softmax_scale (float, optional): scale factor for softmax |
|
""" |
|
|
|
q, k, v = [ |
|
x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v] |
|
] |
|
o, lse, ctx.softmax_scale = _flash_attn_forward( |
|
q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale) |
|
ctx.save_for_backward(q, k, v, o, lse, bias) |
|
ctx.causal = causal |
|
return o |
|
|
|
@staticmethod |
|
def backward(ctx, do): |
|
raise NotImplementedError |
|
|
|
flash_attn_func = _FlashAttnFunc.apply |
|
|