Jackmin801 committed on
Commit
c3b584d
1 Parent(s): 6f3de15

use tri daos version with backprop

Files changed (1)
  1. flash_attn_triton.py +928 -169
flash_attn_triton.py CHANGED
@@ -1,20 +1,11 @@
1
- """Triton implementation of Flash Attention.
2
-
3
- # Copyright (c) 2022, Tri Dao.
4
- #
5
- # Licensed under the Apache License, Version 2.0 (the "License");
6
- # you may not use this file except in compliance with the License.
7
- # You may obtain a copy of the License at
8
- #
9
- # http://www.apache.org/licenses/LICENSE-2.0
10
- #
11
- # Unless required by applicable law or agreed to in writing, software
12
- # distributed under the License is distributed on an "AS IS" BASIS,
13
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
- # See the License for the specific language governing permissions and
15
- # limitations under the License.
16
-
17
  *Experimental* implementation of FlashAttention in Triton.
18
  We use the FlashAttention implementation from Phil Tillet as a starting point.
19
  https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
20
 
@@ -30,47 +21,47 @@ Changes:
30
  small batch size * nheads.
31
 
32
  Caution:
 
 
 
33
  - If you plan to use headdim other than 64 and 128, you should test for race conditions
34
  (due to the Triton compiler), as done in tests/test_flash_attn.py
35
  "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
36
  for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
37
  that there are none left for other head dimensions.
 
38
  Differences between this Triton version and the CUDA version:
39
  - Triton version doesn't support dropout.
40
- - Triton forward is generally faster than CUDA forward.
41
- - Triton backward is faster than CUDA backward when batch * nheads is small, and when headdim=64.
42
- It is slightly slower when headdim=128 and batch * nheads is large.
43
- - Triton version doesn't yet support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
 
44
  """
45
 
46
  import math
47
 
48
  import torch
49
- import triton # type: ignore (reportMissingImports)
50
- import triton.language as tl # type: ignore (reportMissingImports)
51
- from einops import repeat
52
 
53
 
54
- @triton.autotune(
55
- configs=[
56
- triton.Config({
57
- 'BLOCK_M': 128,
58
- 'BLOCK_N': 128
59
- },
60
- num_warps=8,
61
- num_stages=1),
62
- # This config has a race condition when EVEN_M == False, disabling it for now.
63
- # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
64
- ],
65
- key=[
66
- 'CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL',
67
- 'BLOCK_HEADDIM'
68
- ])
69
- @triton.heuristics({
70
- 'EVEN_M': lambda args: args['seqlen_q'] % args['BLOCK_M'] == 0,
71
- 'EVEN_N': lambda args: args['seqlen_k'] % args['BLOCK_N'] == 0,
72
- 'EVEN_HEADDIM': lambda args: args['headdim'] == args['BLOCK_HEADDIM'],
73
- })
74
  @triton.jit
75
  def _fwd_kernel(
76
  Q,
@@ -81,11 +72,21 @@ def _fwd_kernel(
81
  Lse,
82
  TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
83
  softmax_scale,
84
- stride_qb, stride_qh, stride_qm,
85
- stride_kb, stride_kh, stride_kn,
86
- stride_vb, stride_vh, stride_vn,
87
- stride_bb, stride_bh, stride_bm,
88
- stride_ob, stride_oh, stride_om,
89
  nheads,
90
  seqlen_q,
91
  seqlen_k,
@@ -117,23 +118,28 @@ def _fwd_kernel(
117
  # Adding parenthesis around indexing might use int32 math instead of int64 math?
118
  # https://github.com/openai/triton/issues/741
119
  # I'm seeing a tiny bit of difference (5-7us)
120
- q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (
121
- offs_m[:, None] * stride_qm + offs_d[None, :])
122
- k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (
123
- offs_n[:, None] * stride_kn + offs_d[None, :])
124
- v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (
125
- offs_n[:, None] * stride_vn + offs_d[None, :])
126
- if BIAS_TYPE == 'vector':
 
 
 
127
  b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
128
- elif BIAS_TYPE == 'matrix':
129
- b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (
130
- offs_m[:, None] * stride_bm + offs_n[None, :])
131
- else:
132
- raise ValueError("BIAS_TYPE must be one of {'vector', 'matrix'}")
 
 
133
  # initialize pointer to m and l
134
  t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
135
- lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
136
- m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
137
  acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
138
  # load q: it will stay in SRAM throughout
139
  # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
@@ -147,13 +153,11 @@ def _fwd_kernel(
147
  if EVEN_HEADDIM:
148
  q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
149
  else:
150
- q = tl.load(q_ptrs,
151
- mask=(offs_m[:, None] < seqlen_q) &
152
- (offs_d[None, :] < headdim),
153
- other=0.0)
154
  # loop over k, v and update accumulator
155
- end_n = seqlen_k if not IS_CAUSAL else tl.minimum(
156
- (start_m + 1) * BLOCK_M, seqlen_k)
157
  for start_n in range(0, end_n, BLOCK_N):
158
  start_n = tl.multiple_of(start_n, BLOCK_N)
159
  # -- compute qk ----
@@ -161,48 +165,46 @@ def _fwd_kernel(
161
  if EVEN_HEADDIM:
162
  k = tl.load(k_ptrs + start_n * stride_kn)
163
  else:
164
- k = tl.load(k_ptrs + start_n * stride_kn,
165
- mask=offs_d[None, :] < headdim,
166
- other=0.0)
167
  else:
168
  if EVEN_HEADDIM:
169
- k = tl.load(k_ptrs + start_n * stride_kn,
170
- mask=(start_n + offs_n)[:, None] < seqlen_k,
171
- other=0.0)
 
 
172
  else:
173
- k = tl.load(k_ptrs + start_n * stride_kn,
174
- mask=((start_n + offs_n)[:, None] < seqlen_k) &
175
- (offs_d[None, :] < headdim),
176
- other=0.0)
 
177
  qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
178
  qk += tl.dot(q, k, trans_b=True)
179
  # Trying to combine the two masks seems to make the result wrong
180
  if not EVEN_N: # Need to mask out otherwise the softmax is wrong
181
- qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0,
182
- float('-inf'))
183
  if IS_CAUSAL:
184
- qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0,
185
- float('-inf'))
186
- if BIAS_TYPE != 'none':
187
- if BIAS_TYPE == 'vector':
188
  if EVEN_N:
189
  bias = tl.load(b_ptrs + start_n).to(tl.float32)
190
  else:
191
- bias = tl.load(b_ptrs + start_n,
192
- mask=(start_n + offs_n) < seqlen_k,
193
- other=0.0).to(tl.float32)
194
  bias = bias[None, :]
195
- elif BIAS_TYPE == 'matrix':
196
  if EVEN_M & EVEN_N:
197
  bias = tl.load(b_ptrs + start_n).to(tl.float32)
198
  else:
199
- bias = tl.load(b_ptrs + start_n,
200
- mask=(offs_m[:, None] < seqlen_q) &
201
- ((start_n + offs_n)[None, :] < seqlen_k),
202
- other=0.0).to(tl.float32)
203
- else:
204
- raise ValueError(
205
- "BIAS_TYPE must be one of {'vector', 'matrix'}")
206
  # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
207
  # can then fuse the mult and add into an fma instruction. But if we have bias we need to
208
  # multiply with softmax_scale here.
@@ -227,19 +229,20 @@ def _fwd_kernel(
227
  if EVEN_HEADDIM:
228
  v = tl.load(v_ptrs + start_n * stride_vn)
229
  else:
230
- v = tl.load(v_ptrs + start_n * stride_vn,
231
- mask=offs_d[None, :] < headdim,
232
- other=0.0)
233
  else:
234
  if EVEN_HEADDIM:
235
- v = tl.load(v_ptrs + start_n * stride_vn,
236
- mask=(start_n + offs_n)[:, None] < seqlen_k,
237
- other=0.0)
 
 
238
  else:
239
- v = tl.load(v_ptrs + start_n * stride_vn,
240
- mask=((start_n + offs_n)[:, None] < seqlen_k) &
241
- (offs_d[None, :] < headdim),
242
- other=0.0)
 
243
  p = p.to(v.dtype)
244
  acc_o += tl.dot(p, v)
245
 
@@ -260,9 +263,13 @@ def _fwd_kernel(
260
  lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
261
  tl.store(lse_ptrs, lse_i)
262
  # initialize pointers to output
263
- offs_n = tl.arange(0, BLOCK_HEADDIM)
264
- out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (
265
- offs_m[:, None] * stride_om + offs_n[None, :])
 
 
 
 
266
  if EVEN_M:
267
  if EVEN_HEADDIM:
268
  tl.store(out_ptrs, acc_o)
@@ -272,29 +279,550 @@ def _fwd_kernel(
272
  if EVEN_HEADDIM:
273
  tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
274
  else:
275
- tl.store(out_ptrs,
276
- acc_o,
277
- mask=(offs_m[:, None] < seqlen_q) &
278
- (offs_d[None, :] < headdim))
279
 
280
  def init_to_zero(name):
281
  return lambda nargs: nargs[name].zero_()
282
 
283
  def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
284
  # shape constraints
285
  batch, seqlen_q, nheads, d = q.shape
286
  _, seqlen_k, _, _ = k.shape
287
  assert k.shape == (batch, seqlen_k, nheads, d)
288
  assert v.shape == (batch, seqlen_k, nheads, d)
289
- assert d <= 128, 'FlashAttention only support head dimensions up to 128'
290
- assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type'
291
- assert q.dtype in [torch.float16,
292
- torch.bfloat16], 'Only support fp16 and bf16'
293
  assert q.is_cuda and k.is_cuda and v.is_cuda
294
  softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
295
 
296
  has_bias = bias is not None
297
- bias_type = 'none'
298
  if has_bias:
299
  assert bias.dtype in [q.dtype, torch.float]
300
  assert bias.is_cuda
@@ -302,40 +830,26 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
302
  if bias.stride(-1) != 1:
303
  bias = bias.contiguous()
304
  if bias.shape[2:] == (1, seqlen_k):
305
- bias_type = 'vector'
306
  elif bias.shape[2:] == (seqlen_q, seqlen_k):
307
- bias_type = 'matrix'
308
  else:
309
- raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
310
- ' or (seqlen_q, seqlen_k)')
311
- if bias.shape[:2] == (1, nheads):
312
- bias = repeat(bias, '1 h ... -> b h ...', b=batch)
313
- elif bias.shape[:2] == (batch, 1):
314
- bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
315
- elif bias.shape[:2] == (1, 1):
316
- bias = repeat(bias, '1 h ... -> b h ...', b=batch)
317
- bias = repeat(bias, 'b 1 ... -> b h ...', h=nheads)
318
- assert bias.shape[:2] == (
319
- batch, nheads
320
- ), f'First 2 dimensions of bias must be broadcastible to (batch, nheads) = ({batch, nheads}). Bias has shape: {bias.shape}'
321
- assert bias is not None # for type checking
322
- bias_strides = (bias.stride(0), bias.stride(1),
323
- bias.stride(2)) if has_bias else (0, 0, 0)
324
 
325
  seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
326
- lse = torch.empty((batch, nheads, seqlen_q_rounded),
327
- device=q.device,
328
- dtype=torch.float32)
329
- tmp = torch.empty((batch, nheads, seqlen_q_rounded),
330
- device=q.device,
331
- dtype=torch.float32)
332
  o = torch.empty_like(q)
333
 
334
  BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
335
- # BLOCK = 128
336
- # num_warps = 4 if d <= 64 else 8
337
- grid = lambda META: (triton.cdiv(seqlen_q, META['BLOCK_M']), batch * nheads)
338
- _fwd_kernel[grid]( # type: ignore
339
  q,
340
  k,
341
  v,
@@ -344,11 +858,138 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
344
  lse,
345
  tmp,
346
  softmax_scale,
347
- q.stride(0), q.stride(2), q.stride(1),
348
- k.stride(0), k.stride(2), k.stride(1),
349
- v.stride(0), v.stride(2), v.stride(1),
350
  *bias_strides,
351
- o.stride(0), o.stride(2), o.stride(1),
352
  nheads,
353
  seqlen_q,
354
  seqlen_k,
@@ -361,41 +1002,159 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
361
  bias_type,
362
  causal,
363
  BLOCK_HEADDIM,
364
- # BLOCK_M=BLOCK, BLOCK_N=BLOCK,
 
365
  # num_warps=num_warps,
366
  # num_stages=1,
367
  )
368
- return o, lse, softmax_scale # softmax_scale could have been updated
369
 
370
- class _FlashAttnFunc(torch.autograd.Function):
371
 
372
  @staticmethod
373
  def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
374
- """Forward pass for FlashAttention.
375
-
376
- Args:
377
- ctx: autograd context
378
- q: (batch_size, seqlen_q, nheads, headdim)
379
- k: (batch_size, seqlen_k, nheads, headdim)
380
- v: (batch_size, seqlen_k, nheads, headdim)
381
- bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
382
- For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
383
- ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
384
- causal (bool): whether to incorporate causal attention masking
385
- softmax_scale (float, optional): scale factor for softmax
386
  """
387
  # Make sure that the last dimension is contiguous
388
- q, k, v = [
389
- x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]
390
- ]
391
  o, lse, ctx.softmax_scale = _flash_attn_forward(
392
- q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale)
 
393
  ctx.save_for_backward(q, k, v, o, lse, bias)
394
  ctx.causal = causal
395
  return o
396
 
397
  @staticmethod
398
  def backward(ctx, do):
399
- raise NotImplementedError
400
 
401
- flash_attn_func = _FlashAttnFunc.apply
 
1
+ """
2
  *Experimental* implementation of FlashAttention in Triton.
3
+ Tested with triton==2.0.0.dev20221202.
4
+ Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
5
+ other than 64:
6
+ https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
7
+ We'll update this implementation with the new Triton backend once this is fixed.
8
+
9
  We use the FlashAttention implementation from Phil Tillet as a starting point.
10
  https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py
11
 
 
21
  small batch size * nheads.
22
 
23
  Caution:
24
+ - This is an *experimental* implementation. The forward pass should be quite robust but
25
+ I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler).
26
+ - This implementation has only been tested on A100.
27
  - If you plan to use headdim other than 64 and 128, you should test for race conditions
28
  (due to the Triton compiler), as done in tests/test_flash_attn.py
29
  "test_flash_attn_triton_race_condition". I've tested and fixed many race conditions
30
  for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident
31
  that there are none left for other head dimensions.
32
+
33
  Differences between this Triton version and the CUDA version:
34
  - Triton version doesn't support dropout.
35
+ - Triton forward is generally faster than CUDA forward, while Triton backward is
36
+ generally slower than CUDA backward. Overall Triton forward + backward is slightly slower
37
+ than CUDA forward + backward.
38
+ - Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor).
39
+ - Triton version supports attention bias, while CUDA version doesn't.
40
  """
41
 
42
  import math
43
 
44
  import torch
45
+ import triton
46
+ import triton.language as tl
 
47
 
48
 
49
+ # Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128
50
+ # @triton.autotune(
51
+ # configs=[
52
+ # triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
53
+ # # This config has a race condition when EVEN_M == False, disabling it for now.
54
+ # # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
55
+ # ],
56
+ # key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
57
+ # )
58
+ @triton.heuristics(
59
+ {
60
+ "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
61
+ "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
62
+ "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
63
+ }
64
+ )
 
 
 
 
65
  @triton.jit
66
  def _fwd_kernel(
67
  Q,
 
72
  Lse,
73
  TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
74
  softmax_scale,
75
+ stride_qb,
76
+ stride_qh,
77
+ stride_qm,
78
+ stride_kb,
79
+ stride_kh,
80
+ stride_kn,
81
+ stride_vb,
82
+ stride_vh,
83
+ stride_vn,
84
+ stride_bb,
85
+ stride_bh,
86
+ stride_bm,
87
+ stride_ob,
88
+ stride_oh,
89
+ stride_om,
90
  nheads,
91
  seqlen_q,
92
  seqlen_k,
 
118
  # Adding parenthesis around indexing might use int32 math instead of int64 math?
119
  # https://github.com/openai/triton/issues/741
120
  # I'm seeing a tiny bit of difference (5-7us)
121
+ q_ptrs = (
122
+ Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
123
+ )
124
+ k_ptrs = (
125
+ K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
126
+ )
127
+ v_ptrs = (
128
+ V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
129
+ )
130
+ if BIAS_TYPE == "vector":
131
  b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
132
+ elif BIAS_TYPE == "matrix":
133
+ b_ptrs = (
134
+ Bias
135
+ + off_b * stride_bb
136
+ + off_h * stride_bh
137
+ + (offs_m[:, None] * stride_bm + offs_n[None, :])
138
+ )
139
  # initialize pointer to m and l
140
  t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
141
+ lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
142
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
143
  acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
144
  # load q: it will stay in SRAM throughout
145
  # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
 
153
  if EVEN_HEADDIM:
154
  q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
155
  else:
156
+ q = tl.load(
157
+ q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
158
+ )
 
159
  # loop over k, v and update accumulator
160
+ end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
 
161
  for start_n in range(0, end_n, BLOCK_N):
162
  start_n = tl.multiple_of(start_n, BLOCK_N)
163
  # -- compute qk ----
 
165
  if EVEN_HEADDIM:
166
  k = tl.load(k_ptrs + start_n * stride_kn)
167
  else:
168
+ k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
 
 
169
  else:
170
  if EVEN_HEADDIM:
171
+ k = tl.load(
172
+ k_ptrs + start_n * stride_kn,
173
+ mask=(start_n + offs_n)[:, None] < seqlen_k,
174
+ other=0.0,
175
+ )
176
  else:
177
+ k = tl.load(
178
+ k_ptrs + start_n * stride_kn,
179
+ mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
180
+ other=0.0,
181
+ )
182
  qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
183
  qk += tl.dot(q, k, trans_b=True)
184
  # Trying to combine the two masks seems to make the result wrong
185
  if not EVEN_N: # Need to mask out otherwise the softmax is wrong
186
+ qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
 
187
  if IS_CAUSAL:
188
+ qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
189
+ if BIAS_TYPE != "none":
190
+ if BIAS_TYPE == "vector":
 
191
  if EVEN_N:
192
  bias = tl.load(b_ptrs + start_n).to(tl.float32)
193
  else:
194
+ bias = tl.load(
195
+ b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0
196
+ ).to(tl.float32)
197
  bias = bias[None, :]
198
+ elif BIAS_TYPE == "matrix":
199
  if EVEN_M & EVEN_N:
200
  bias = tl.load(b_ptrs + start_n).to(tl.float32)
201
  else:
202
+ bias = tl.load(
203
+ b_ptrs + start_n,
204
+ mask=(offs_m[:, None] < seqlen_q)
205
+ & ((start_n + offs_n)[None, :] < seqlen_k),
206
+ other=0.0,
207
+ ).to(tl.float32)
 
208
  # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
209
  # can then fuse the mult and add into an fma instruction. But if we have bias we need to
210
  # multiply with softmax_scale here.
 
229
  if EVEN_HEADDIM:
230
  v = tl.load(v_ptrs + start_n * stride_vn)
231
  else:
232
+ v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
 
 
233
  else:
234
  if EVEN_HEADDIM:
235
+ v = tl.load(
236
+ v_ptrs + start_n * stride_vn,
237
+ mask=(start_n + offs_n)[:, None] < seqlen_k,
238
+ other=0.0,
239
+ )
240
  else:
241
+ v = tl.load(
242
+ v_ptrs + start_n * stride_vn,
243
+ mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
244
+ other=0.0,
245
+ )
246
  p = p.to(v.dtype)
247
  acc_o += tl.dot(p, v)
248
 
 
263
  lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
264
  tl.store(lse_ptrs, lse_i)
265
  # initialize pointers to output
266
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
267
+ out_ptrs = (
268
+ Out
269
+ + off_b * stride_ob
270
+ + off_h * stride_oh
271
+ + (offs_m[:, None] * stride_om + offs_d[None, :])
272
+ )
273
  if EVEN_M:
274
  if EVEN_HEADDIM:
275
  tl.store(out_ptrs, acc_o)
 
279
  if EVEN_HEADDIM:
280
  tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
281
  else:
282
+ tl.store(
283
+ out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
284
+ )
285
+
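For one (batch, head) slice with no bias and no causal mask, the forward kernel above is equivalent to the following PyTorch reference, a sketch in fp32 where Out corresponds to the attention output and Lse to the per-row log-sum-exp that the backward pass reuses (the helper name fwd_reference is illustrative):

import torch

def fwd_reference(q, k, v, softmax_scale):
    # q: (seqlen_q, d); k, v: (seqlen_k, d)
    scores = (q.float() @ k.float().t()) * softmax_scale
    lse = torch.logsumexp(scores, dim=-1)     # what the kernel writes to Lse
    p = torch.exp(scores - lse[:, None])      # softmax expressed through lse
    return p @ v.float(), lse                 # attention output and log-sum-exp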
286
+
287
+ @triton.jit
288
+ def _bwd_preprocess_do_o_dot(
289
+ Out,
290
+ DO,
291
+ Delta,
292
+ stride_ob,
293
+ stride_oh,
294
+ stride_om,
295
+ stride_dob,
296
+ stride_doh,
297
+ stride_dom,
298
+ nheads,
299
+ seqlen_q,
300
+ seqlen_q_rounded,
301
+ headdim,
302
+ BLOCK_M: tl.constexpr,
303
+ BLOCK_HEADDIM: tl.constexpr,
304
+ ):
305
+ start_m = tl.program_id(0)
306
+ off_hb = tl.program_id(1)
307
+ off_b = off_hb // nheads
308
+ off_h = off_hb % nheads
309
+ # initialize offsets
310
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
311
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
312
+ # load
313
+ o = tl.load(
314
+ Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
315
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
316
+ other=0.0,
317
+ ).to(tl.float32)
318
+ do = tl.load(
319
+ DO
320
+ + off_b * stride_dob
321
+ + off_h * stride_doh
322
+ + offs_m[:, None] * stride_dom
323
+ + offs_d[None, :],
324
+ mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
325
+ other=0.0,
326
+ ).to(tl.float32)
327
+ delta = tl.sum(o * do, axis=1)
328
+ # write-back
329
+ tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)
330
+
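This preprocess kernel writes, for every query row, the inner product of the output and its gradient over the head dimension. In PyTorch terms, a sketch that ignores the padding of seqlen_q up to a multiple of 128:

# o, do: (batch, seqlen_q, nheads, headdim) -> delta: (batch, nheads, seqlen_q)
delta = (o.float() * do.float()).sum(dim=-1).transpose(1, 2)

The main backward kernel reads this buffer as D when forming ds = p * (dp - delta).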
331
+
332
+ @triton.jit
333
+ def _bwd_store_dk_dv(
334
+ dk_ptrs,
335
+ dv_ptrs,
336
+ dk,
337
+ dv,
338
+ offs_n,
339
+ offs_d,
340
+ seqlen_k,
341
+ headdim,
342
+ EVEN_M: tl.constexpr,
343
+ EVEN_N: tl.constexpr,
344
+ EVEN_HEADDIM: tl.constexpr,
345
+ ):
346
+ # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
347
+ # if we just call tl.store(dv_ptrs), there's a race condition
348
+ if EVEN_N & EVEN_M:
349
+ if EVEN_HEADDIM:
350
+ tl.store(dv_ptrs, dv)
351
+ tl.store(dk_ptrs, dk)
352
+ else:
353
+ tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim)
354
+ tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim)
355
+ else:
356
+ if EVEN_HEADDIM:
357
+ tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k)
358
+ tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k)
359
+ else:
360
+ tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
361
+ tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim))
362
+
363
+
364
+ @triton.jit
365
+ def _bwd_kernel_one_col_block(
366
+ start_n,
367
+ Q,
368
+ K,
369
+ V,
370
+ Bias,
371
+ DO,
372
+ DQ,
373
+ DK,
374
+ DV,
375
+ LSE,
376
+ D,
377
+ softmax_scale,
378
+ stride_qm,
379
+ stride_kn,
380
+ stride_vn,
381
+ stride_bm,
382
+ stride_dom,
383
+ stride_dqm,
384
+ stride_dkn,
385
+ stride_dvn,
386
+ seqlen_q,
387
+ seqlen_k,
388
+ headdim,
389
+ ATOMIC_ADD: tl.constexpr,
390
+ BIAS_TYPE: tl.constexpr,
391
+ IS_CAUSAL: tl.constexpr,
392
+ BLOCK_HEADDIM: tl.constexpr,
393
+ EVEN_M: tl.constexpr,
394
+ EVEN_N: tl.constexpr,
395
+ EVEN_HEADDIM: tl.constexpr,
396
+ BLOCK_M: tl.constexpr,
397
+ BLOCK_N: tl.constexpr,
398
+ ):
399
+ # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
400
+ begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
401
+ # initialize row/col offsets
402
+ offs_qm = begin_m + tl.arange(0, BLOCK_M)
403
+ offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
404
+ offs_m = tl.arange(0, BLOCK_M)
405
+ offs_d = tl.arange(0, BLOCK_HEADDIM)
406
+ # initialize pointers to value-like data
407
+ q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :])
408
+ k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :])
409
+ v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
410
+ do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
411
+ dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
412
+ if BIAS_TYPE == "vector":
413
+ b_ptrs = Bias + offs_n
414
+ elif BIAS_TYPE == "matrix":
415
+ b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
416
+ # initialize dv and dk
417
+ dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
418
+ dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
419
+ # There seems to be some problem with Triton pipelining that makes results wrong for
420
+ # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop
421
+ # may have zero step, and pipelining with the bias matrix could screw it up.
422
+ # So we just exit early.
423
+ if begin_m >= seqlen_q:
424
+ dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
425
+ dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
426
+ _bwd_store_dk_dv(
427
+ dk_ptrs,
428
+ dv_ptrs,
429
+ dk,
430
+ dv,
431
+ offs_n,
432
+ offs_d,
433
+ seqlen_k,
434
+ headdim,
435
+ EVEN_M=EVEN_M,
436
+ EVEN_N=EVEN_N,
437
+ EVEN_HEADDIM=EVEN_HEADDIM,
438
+ )
439
+ return
440
+ # k and v stay in SRAM throughout
441
+ # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,
442
+ # if we just call tl.load(k_ptrs), we get the wrong output!
443
+ if EVEN_N & EVEN_M:
444
+ if EVEN_HEADDIM:
445
+ k = tl.load(k_ptrs)
446
+ v = tl.load(v_ptrs)
447
+ else:
448
+ k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
449
+ v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
450
+ else:
451
+ if EVEN_HEADDIM:
452
+ k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
453
+ v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
454
+ else:
455
+ k = tl.load(
456
+ k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
457
+ )
458
+ v = tl.load(
459
+ v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
460
+ )
461
+ # loop over rows
462
+ num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
463
+ for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
464
+ start_m = tl.multiple_of(start_m, BLOCK_M)
465
+ offs_m_curr = start_m + offs_m
466
+ # load q, k, v, do on-chip
467
+ # Same bug as below. Otherwise gives wrong result for headdim=40, seqlen=(128, 117)
468
+ if EVEN_M & EVEN_HEADDIM:
469
+ q = tl.load(q_ptrs)
470
+ else:
471
+ if EVEN_HEADDIM:
472
+ q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
473
+ else:
474
+ q = tl.load(
475
+ q_ptrs,
476
+ mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
477
+ other=0.0,
478
+ )
479
+ # recompute p = softmax(qk, dim=-1).T
480
+ qk = tl.dot(q, k, trans_b=True)
481
+ # Trying to combine the two masks seems to make the result wrong
482
+ if not EVEN_N: # Need to mask out otherwise the softmax is wrong
483
+ qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
484
+ if IS_CAUSAL:
485
+ qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
486
+ if BIAS_TYPE != "none":
487
+ tl.debug_barrier() # Race condition otherwise
488
+ if BIAS_TYPE == "vector":
489
+ if EVEN_N:
490
+ bias = tl.load(b_ptrs).to(tl.float32)
491
+ else:
492
+ bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
493
+ bias = bias[None, :]
494
+ elif BIAS_TYPE == "matrix":
495
+ if EVEN_M & EVEN_N:
496
+ bias = tl.load(b_ptrs).to(tl.float32)
497
+ else:
498
+ bias = tl.load(
499
+ b_ptrs,
500
+ mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k),
501
+ other=0.0,
502
+ ).to(tl.float32)
503
+ qk = qk * softmax_scale + bias
504
+ # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
505
+ # Also wrong for headdim=64.
506
+ if not (EVEN_M & EVEN_HEADDIM):
507
+ tl.debug_barrier()
508
+ lse_i = tl.load(LSE + offs_m_curr)
509
+ if BIAS_TYPE == "none":
510
+ p = tl.exp(qk * softmax_scale - lse_i[:, None])
511
+ else:
512
+ p = tl.exp(qk - lse_i[:, None])
513
+ # compute dv
514
+ # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call
515
+ # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs
516
+ # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512,
517
+ # the output is correct.
518
+ if EVEN_M & EVEN_HEADDIM:
519
+ do = tl.load(do_ptrs)
520
+ else:
521
+ # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
522
+ do = tl.load(
523
+ do_ptrs,
524
+ mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
525
+ other=0.0,
526
+ )
527
+ # if EVEN_M:
528
+ # if EVEN_HEADDIM:
529
+ # do = tl.load(do_ptrs)
530
+ # else:
531
+ # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
532
+ # else:
533
+ # if EVEN_HEADDIM:
534
+ # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
535
+ # else:
536
+ # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
537
+ # & (offs_d[None, :] < headdim), other=0.0)
538
+ dv += tl.dot(p.to(do.dtype), do, trans_a=True)
539
+ # compute dp = dot(v, do)
540
+ # There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
541
+ # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
542
+ # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False
543
+ if not (EVEN_M & EVEN_HEADDIM):
544
+ tl.debug_barrier()
545
+ dp = tl.dot(do, v, trans_b=True)
546
+ # There's a race condition for headdim=48
547
+ if not EVEN_HEADDIM:
548
+ tl.debug_barrier()
549
+ # compute ds = p * (dp - delta[:, None])
550
+ # Putting the subtraction after the dp matmul (instead of before) is slightly faster
551
+ Di = tl.load(D + offs_m_curr)
552
+ # Converting ds to q.dtype here reduces register pressure and makes it much faster
553
+ # for BLOCK_HEADDIM=128
554
+ ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
555
+ # compute dk = dot(ds.T, q)
556
+ dk += tl.dot(ds, q, trans_a=True)
557
+ # compute dq
558
+ if not (
559
+ EVEN_M & EVEN_HEADDIM
560
+ ): # Otherwise there's a race condition when BIAS_TYPE='matrix'
561
+ tl.debug_barrier()
562
+ if not ATOMIC_ADD:
563
+ if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
564
+ dq = tl.load(dq_ptrs, eviction_policy="evict_last")
565
+ dq += tl.dot(ds, k)
566
+ tl.store(dq_ptrs, dq, eviction_policy="evict_last")
567
+ else:
568
+ if EVEN_HEADDIM:
569
+ dq = tl.load(
570
+ dq_ptrs,
571
+ mask=offs_m_curr[:, None] < seqlen_q,
572
+ other=0.0,
573
+ eviction_policy="evict_last",
574
+ )
575
+ dq += tl.dot(ds, k)
576
+ tl.store(
577
+ dq_ptrs,
578
+ dq,
579
+ mask=offs_m_curr[:, None] < seqlen_q,
580
+ eviction_policy="evict_last",
581
+ )
582
+ else:
583
+ dq = tl.load(
584
+ dq_ptrs,
585
+ mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
586
+ other=0.0,
587
+ eviction_policy="evict_last",
588
+ )
589
+ dq += tl.dot(ds, k)
590
+ tl.store(
591
+ dq_ptrs,
592
+ dq,
593
+ mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
594
+ eviction_policy="evict_last",
595
+ )
596
+ else: # If we're parallelizing across the seqlen_k dimension
597
+ dq = tl.dot(ds, k)
598
+ if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
599
+ tl.atomic_add(dq_ptrs, dq)
600
+ else:
601
+ if EVEN_HEADDIM:
602
+ tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q)
603
+ else:
604
+ tl.atomic_add(
605
+ dq_ptrs,
606
+ dq,
607
+ mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
608
+ )
609
+ # increment pointers
610
+ dq_ptrs += BLOCK_M * stride_dqm
611
+ q_ptrs += BLOCK_M * stride_qm
612
+ do_ptrs += BLOCK_M * stride_dom
613
+ if BIAS_TYPE == "matrix":
614
+ b_ptrs += BLOCK_M * stride_bm
615
+ # write-back
616
+ dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
617
+ dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
618
+ _bwd_store_dk_dv(
619
+ dk_ptrs,
620
+ dv_ptrs,
621
+ dk,
622
+ dv,
623
+ offs_n,
624
+ offs_d,
625
+ seqlen_k,
626
+ headdim,
627
+ EVEN_M=EVEN_M,
628
+ EVEN_N=EVEN_N,
629
+ EVEN_HEADDIM=EVEN_HEADDIM,
630
+ )
631
+
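For a single (batch, head) slice without bias or causal masking, the accumulation in this column-block routine matches the standard attention backward. A PyTorch sketch in fp32, with lse the values saved by the forward pass (the helper name bwd_reference is illustrative):

import torch

def bwd_reference(q, k, v, do, lse, softmax_scale):
    # q, do: (seqlen_q, d); k, v: (seqlen_k, d); lse: (seqlen_q,)
    scores = (q.float() @ k.float().t()) * softmax_scale
    p = torch.exp(scores - lse[:, None])            # recomputed probabilities, as in the kernel
    dv = p.t() @ do.float()
    dp = do.float() @ v.float().t()
    delta = (p * dp).sum(dim=-1)                    # equals rowsum(o * do) from the preprocess
    ds = p * (dp - delta[:, None]) * softmax_scale
    dk = ds.t() @ q.float()
    dq = ds @ k.float()
    return dq, dk, dv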
632
 
633
  def init_to_zero(name):
634
  return lambda nargs: nargs[name].zero_()
635
 
636
+
637
+ @triton.autotune(
638
+ configs=[
639
+ triton.Config(
640
+ {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
641
+ num_warps=8,
642
+ num_stages=1,
643
+ pre_hook=init_to_zero("DQ"),
644
+ ),
645
+ triton.Config(
646
+ {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
647
+ num_warps=8,
648
+ num_stages=1,
649
+ pre_hook=init_to_zero("DQ"),
650
+ ),
651
+ # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
652
+ # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
653
+ # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
654
+ # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
655
+ # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
656
+ # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
657
+ ],
658
+ key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"],
659
+ )
660
+ @triton.heuristics(
661
+ {
662
+ "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
663
+ "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
664
+ "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
665
+ }
666
+ )
667
+ @triton.jit
668
+ def _bwd_kernel(
669
+ Q,
670
+ K,
671
+ V,
672
+ Bias,
673
+ DO,
674
+ DQ,
675
+ DK,
676
+ DV,
677
+ LSE,
678
+ D,
679
+ softmax_scale,
680
+ stride_qb,
681
+ stride_qh,
682
+ stride_qm,
683
+ stride_kb,
684
+ stride_kh,
685
+ stride_kn,
686
+ stride_vb,
687
+ stride_vh,
688
+ stride_vn,
689
+ stride_bb,
690
+ stride_bh,
691
+ stride_bm,
692
+ stride_dob,
693
+ stride_doh,
694
+ stride_dom,
695
+ stride_dqb,
696
+ stride_dqh,
697
+ stride_dqm,
698
+ stride_dkb,
699
+ stride_dkh,
700
+ stride_dkn,
701
+ stride_dvb,
702
+ stride_dvh,
703
+ stride_dvn,
704
+ nheads,
705
+ seqlen_q,
706
+ seqlen_k,
707
+ seqlen_q_rounded,
708
+ headdim,
709
+ CACHE_KEY_SEQLEN_Q,
710
+ CACHE_KEY_SEQLEN_K,
711
+ BIAS_TYPE: tl.constexpr,
712
+ IS_CAUSAL: tl.constexpr,
713
+ BLOCK_HEADDIM: tl.constexpr,
714
+ SEQUENCE_PARALLEL: tl.constexpr,
715
+ EVEN_M: tl.constexpr,
716
+ EVEN_N: tl.constexpr,
717
+ EVEN_HEADDIM: tl.constexpr,
718
+ BLOCK_M: tl.constexpr,
719
+ BLOCK_N: tl.constexpr,
720
+ ):
721
+ off_hb = tl.program_id(1)
722
+ off_b = off_hb // nheads
723
+ off_h = off_hb % nheads
724
+ # offset pointers for batch/head
725
+ Q += off_b * stride_qb + off_h * stride_qh
726
+ K += off_b * stride_kb + off_h * stride_kh
727
+ V += off_b * stride_vb + off_h * stride_vh
728
+ DO += off_b * stride_dob + off_h * stride_doh
729
+ DQ += off_b * stride_dqb + off_h * stride_dqh
730
+ DK += off_b * stride_dkb + off_h * stride_dkh
731
+ DV += off_b * stride_dvb + off_h * stride_dvh
732
+ if BIAS_TYPE != "none":
733
+ Bias += off_b * stride_bb + off_h * stride_bh
734
+ # pointer to row-wise quantities in value-like data
735
+ D += off_hb * seqlen_q_rounded
736
+ LSE += off_hb * seqlen_q_rounded
737
+ if not SEQUENCE_PARALLEL:
738
+ num_block_n = tl.cdiv(seqlen_k, BLOCK_N)
739
+ for start_n in range(0, num_block_n):
740
+ _bwd_kernel_one_col_block(
741
+ start_n,
742
+ Q,
743
+ K,
744
+ V,
745
+ Bias,
746
+ DO,
747
+ DQ,
748
+ DK,
749
+ DV,
750
+ LSE,
751
+ D,
752
+ softmax_scale,
753
+ stride_qm,
754
+ stride_kn,
755
+ stride_vn,
756
+ stride_bm,
757
+ stride_dom,
758
+ stride_dqm,
759
+ stride_dkn,
760
+ stride_dvn,
761
+ seqlen_q,
762
+ seqlen_k,
763
+ headdim,
764
+ ATOMIC_ADD=False,
765
+ BIAS_TYPE=BIAS_TYPE,
766
+ IS_CAUSAL=IS_CAUSAL,
767
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
768
+ EVEN_M=EVEN_M,
769
+ EVEN_N=EVEN_N,
770
+ EVEN_HEADDIM=EVEN_HEADDIM,
771
+ BLOCK_M=BLOCK_M,
772
+ BLOCK_N=BLOCK_N,
773
+ )
774
+ else:
775
+ start_n = tl.program_id(0)
776
+ _bwd_kernel_one_col_block(
777
+ start_n,
778
+ Q,
779
+ K,
780
+ V,
781
+ Bias,
782
+ DO,
783
+ DQ,
784
+ DK,
785
+ DV,
786
+ LSE,
787
+ D,
788
+ softmax_scale,
789
+ stride_qm,
790
+ stride_kn,
791
+ stride_vn,
792
+ stride_bm,
793
+ stride_dom,
794
+ stride_dqm,
795
+ stride_dkn,
796
+ stride_dvn,
797
+ seqlen_q,
798
+ seqlen_k,
799
+ headdim,
800
+ ATOMIC_ADD=True,
801
+ BIAS_TYPE=BIAS_TYPE,
802
+ IS_CAUSAL=IS_CAUSAL,
803
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
804
+ EVEN_M=EVEN_M,
805
+ EVEN_N=EVEN_N,
806
+ EVEN_HEADDIM=EVEN_HEADDIM,
807
+ BLOCK_M=BLOCK_M,
808
+ BLOCK_N=BLOCK_N,
809
+ )
810
+
811
+
812
  def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
813
  # shape constraints
814
  batch, seqlen_q, nheads, d = q.shape
815
  _, seqlen_k, _, _ = k.shape
816
  assert k.shape == (batch, seqlen_k, nheads, d)
817
  assert v.shape == (batch, seqlen_k, nheads, d)
818
+ assert d <= 128, "FlashAttention only support head dimensions up to 128"
819
+ assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
820
+ assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
 
821
  assert q.is_cuda and k.is_cuda and v.is_cuda
822
  softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
823
 
824
  has_bias = bias is not None
825
+ bias_type = "none"
826
  if has_bias:
827
  assert bias.dtype in [q.dtype, torch.float]
828
  assert bias.is_cuda
 
830
  if bias.stride(-1) != 1:
831
  bias = bias.contiguous()
832
  if bias.shape[2:] == (1, seqlen_k):
833
+ bias_type = "vector"
834
  elif bias.shape[2:] == (seqlen_q, seqlen_k):
835
+ bias_type = "matrix"
836
  else:
837
+ raise RuntimeError(
838
+ "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)"
839
+ )
840
+ bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
841
+ bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
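Because expand only adjusts strides (broadcast dimensions get stride 0), a bias given as (1, nheads, 1, seqlen_k) is never materialized to the full (batch, nheads, seqlen_q, seqlen_k) shape; the kernel simply indexes it through the zero strides passed in bias_strides. A quick illustration with arbitrary example shapes:

import torch

bias = torch.randn(1, 16, 1, 1024, device="cuda", dtype=torch.float16)
expanded = bias.expand(2, 16, 1024, 1024)
print(expanded.stride())   # (0, 1024, 0, 1): zero strides on the broadcast dimensions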
842
 
843
  seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
844
+ lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
845
+ tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
 
 
 
 
846
  o = torch.empty_like(q)
847
 
848
  BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
849
+ BLOCK = 128
850
+ num_warps = 4 if d <= 64 else 8
851
+ grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
852
+ _fwd_kernel[grid](
853
  q,
854
  k,
855
  v,
 
858
  lse,
859
  tmp,
860
  softmax_scale,
861
+ q.stride(0),
862
+ q.stride(2),
863
+ q.stride(1),
864
+ k.stride(0),
865
+ k.stride(2),
866
+ k.stride(1),
867
+ v.stride(0),
868
+ v.stride(2),
869
+ v.stride(1),
870
+ *bias_strides,
871
+ o.stride(0),
872
+ o.stride(2),
873
+ o.stride(1),
874
+ nheads,
875
+ seqlen_q,
876
+ seqlen_k,
877
+ seqlen_q_rounded,
878
+ d,
879
+ seqlen_q // 32,
880
+ seqlen_k // 32, # key for triton cache (limit number of compilations)
881
+ # Can't use kwargs here because triton autotune expects key to be args, not kwargs
882
+ # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
883
+ bias_type,
884
+ causal,
885
+ BLOCK_HEADDIM,
886
+ BLOCK_M=BLOCK,
887
+ BLOCK_N=BLOCK,
888
+ num_warps=num_warps,
889
+ num_stages=1,
890
+ )
891
+ return o, lse, softmax_scale # softmax_scale could have been updated
892
+
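The launch grid above assigns one program to each block of BLOCK_M query rows for every (batch, head) pair. As a rough sketch of its size for the hard-coded BLOCK = 128, using example sizes that are not from this file:

import math

batch, seqlen_q, nheads = 2, 1000, 16
BLOCK_M = 128
grid = (math.ceil(seqlen_q / BLOCK_M), batch * nheads)   # -> (8, 32) programs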
893
+
894
+ def _flash_attn_backward(
895
+ do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None
896
+ ):
897
+ # Make sure that the last dimension is contiguous
898
+ if do.stride(-1) != 1:
899
+ do = do.contiguous()
900
+ batch, seqlen_q, nheads, d = q.shape
901
+ _, seqlen_k, _, _ = k.shape
902
+ # assert d in {16, 32, 64, 128}
903
+ assert d <= 128
904
+ seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
905
+ assert lse.shape == (batch, nheads, seqlen_q_rounded)
906
+ assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1
907
+ assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1
908
+ softmax_scale = softmax_scale or 1.0 / math.sqrt(d)
909
+ # dq_accum = torch.zeros_like(q, dtype=torch.float32)
910
+ dq_accum = torch.empty_like(q, dtype=torch.float32)
911
+ delta = torch.empty_like(lse)
912
+ # delta = torch.zeros_like(lse)
913
+
914
+ BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)
915
+ grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
916
+ _bwd_preprocess_do_o_dot[grid](
917
+ o,
918
+ do,
919
+ delta,
920
+ o.stride(0),
921
+ o.stride(2),
922
+ o.stride(1),
923
+ do.stride(0),
924
+ do.stride(2),
925
+ do.stride(1),
926
+ nheads,
927
+ seqlen_q,
928
+ seqlen_q_rounded,
929
+ d,
930
+ BLOCK_M=128,
931
+ BLOCK_HEADDIM=BLOCK_HEADDIM,
932
+ )
933
+
934
+ has_bias = bias is not None
935
+ bias_type = "none"
936
+ if has_bias:
937
+ assert bias.dtype in [q.dtype, torch.float]
938
+ assert bias.is_cuda
939
+ assert bias.dim() == 4
940
+ assert bias.stride(-1) == 1
941
+ if bias.shape[2:] == (1, seqlen_k):
942
+ bias_type = "vector"
943
+ elif bias.shape[2:] == (seqlen_q, seqlen_k):
944
+ bias_type = "matrix"
945
+ else:
946
+ raise RuntimeError(
947
+ "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)"
948
+ )
949
+ bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
950
+ bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
951
+
952
+ # BLOCK_M = 128
953
+ # BLOCK_N = 64
954
+ # num_warps = 4
955
+ grid = lambda META: (
956
+ triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1,
957
+ batch * nheads,
958
+ )
959
+ _bwd_kernel[grid](
960
+ q,
961
+ k,
962
+ v,
963
+ bias,
964
+ do,
965
+ dq_accum,
966
+ dk,
967
+ dv,
968
+ lse,
969
+ delta,
970
+ softmax_scale,
971
+ q.stride(0),
972
+ q.stride(2),
973
+ q.stride(1),
974
+ k.stride(0),
975
+ k.stride(2),
976
+ k.stride(1),
977
+ v.stride(0),
978
+ v.stride(2),
979
+ v.stride(1),
980
  *bias_strides,
981
+ do.stride(0),
982
+ do.stride(2),
983
+ do.stride(1),
984
+ dq_accum.stride(0),
985
+ dq_accum.stride(2),
986
+ dq_accum.stride(1),
987
+ dk.stride(0),
988
+ dk.stride(2),
989
+ dk.stride(1),
990
+ dv.stride(0),
991
+ dv.stride(2),
992
+ dv.stride(1),
993
  nheads,
994
  seqlen_q,
995
  seqlen_k,
 
1002
  bias_type,
1003
  causal,
1004
  BLOCK_HEADDIM,
1005
+ # SEQUENCE_PARALLEL=False,
1006
+ # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
1007
  # num_warps=num_warps,
1008
  # num_stages=1,
1009
  )
1010
+ dq.copy_(dq_accum)
1011
 
 
1012
 
1013
+ class FlashAttnQKVPackedFunc(torch.autograd.Function):
1014
+ @staticmethod
1015
+ def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None):
1016
+ """
1017
+ qkv: (batch, seqlen, 3, nheads, headdim)
1018
+ bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen).
1019
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen).
1020
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen)
1021
+ """
1022
+ # Make sure that the last dimension is contiguous
1023
+ if qkv.stride(-1) != 1:
1024
+ qkv = qkv.contiguous()
1025
+ o, lse, ctx.softmax_scale = _flash_attn_forward(
1026
+ qkv[:, :, 0],
1027
+ qkv[:, :, 1],
1028
+ qkv[:, :, 2],
1029
+ bias=bias,
1030
+ causal=causal,
1031
+ softmax_scale=softmax_scale,
1032
+ )
1033
+ ctx.save_for_backward(qkv, o, lse, bias)
1034
+ ctx.causal = causal
1035
+ return o
1036
+
1037
+ @staticmethod
1038
+ def backward(ctx, do):
1039
+ qkv, o, lse, bias = ctx.saved_tensors
1040
+ assert not ctx.needs_input_grad[1], "FlashAttention does not support bias gradient yet"
1041
+ # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
1042
+ # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
1043
+ with torch.inference_mode():
1044
+ dqkv = torch.empty_like(qkv)
1045
+ _flash_attn_backward(
1046
+ do,
1047
+ qkv[:, :, 0],
1048
+ qkv[:, :, 1],
1049
+ qkv[:, :, 2],
1050
+ o,
1051
+ lse,
1052
+ dqkv[:, :, 0],
1053
+ dqkv[:, :, 1],
1054
+ dqkv[:, :, 2],
1055
+ bias=bias,
1056
+ causal=ctx.causal,
1057
+ softmax_scale=ctx.softmax_scale,
1058
+ )
1059
+ return dqkv, None, None, None
1060
+
1061
+
1062
+ flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
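A usage sketch for the packed interface, assuming an fp16-capable CUDA GPU and passing the arguments positionally:

import torch

qkv = torch.randn(2, 1024, 3, 16, 64, device="cuda", dtype=torch.float16, requires_grad=True)
out = flash_attn_qkvpacked_func(qkv, None, True)   # bias=None, causal=True
out.sum().backward()                               # qkv.grad produced by the Triton backward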
1063
+
1064
+
1065
+ class FlashAttnKVPackedFunc(torch.autograd.Function):
1066
+ @staticmethod
1067
+ def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
1068
+ """
1069
+ q: (batch, seqlen_q, nheads, headdim)
1070
+ kv: (batch, seqlen_k, 2, nheads, headdim)
1071
+ bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
1072
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
1073
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
1074
+ """
1075
+ # Make sure that the last dimension is contiguous
1076
+ q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
1077
+ o, lse, ctx.softmax_scale = _flash_attn_forward(
1078
+ q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale
1079
+ )
1080
+ ctx.save_for_backward(q, kv, o, lse, bias)
1081
+ ctx.causal = causal
1082
+ return o
1083
+
1084
+ @staticmethod
1085
+ def backward(ctx, do):
1086
+ q, kv, o, lse, bias = ctx.saved_tensors
1087
+ if len(ctx.needs_input_grad) >= 3:
1088
+ assert not ctx.needs_input_grad[2], "FlashAttention does not support bias gradient yet"
1089
+ # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
1090
+ # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
1091
+ with torch.inference_mode():
1092
+ dq = torch.empty_like(q)
1093
+ dkv = torch.empty_like(kv)
1094
+ _flash_attn_backward(
1095
+ do,
1096
+ q,
1097
+ kv[:, :, 0],
1098
+ kv[:, :, 1],
1099
+ o,
1100
+ lse,
1101
+ dq,
1102
+ dkv[:, :, 0],
1103
+ dkv[:, :, 1],
1104
+ bias=bias,
1105
+ causal=ctx.causal,
1106
+ softmax_scale=ctx.softmax_scale,
1107
+ )
1108
+ return dq, dkv, None, None, None
1109
+
1110
+
1111
+ flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
1112
+
1113
+
1114
+ class FlashAttnFunc(torch.autograd.Function):
1115
  @staticmethod
1116
  def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
1117
+ """
1118
+ q: (batch_size, seqlen_q, nheads, headdim)
1119
+ k, v: (batch_size, seqlen_k, nheads, headdim)
1120
+ bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
1121
+ For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
1122
+ ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
1123
  """
1124
  # Make sure that the last dimension is contiguous
1125
+ q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
 
 
1126
  o, lse, ctx.softmax_scale = _flash_attn_forward(
1127
+ q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
1128
+ )
1129
  ctx.save_for_backward(q, k, v, o, lse, bias)
1130
  ctx.causal = causal
1131
  return o
1132
 
1133
  @staticmethod
1134
  def backward(ctx, do):
1135
+ q, k, v, o, lse, bias = ctx.saved_tensors
1136
+ assert not ctx.needs_input_grad[3], "FlashAttention does not support bias gradient yet"
1137
+ # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
1138
+ # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
1139
+ with torch.inference_mode():
1140
+ dq = torch.empty_like(q)
1141
+ dk = torch.empty_like(k)
1142
+ dv = torch.empty_like(v)
1143
+ _flash_attn_backward(
1144
+ do,
1145
+ q,
1146
+ k,
1147
+ v,
1148
+ o,
1149
+ lse,
1150
+ dq,
1151
+ dk,
1152
+ dv,
1153
+ bias=bias,
1154
+ causal=ctx.causal,
1155
+ softmax_scale=ctx.softmax_scale,
1156
+ )
1157
+ return dq, dk, dv, None, None, None
1158
+
1159
 
1160
+ flash_attn_func = FlashAttnFunc.apply
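And a usage sketch for the unpacked interface, with q, k, v in (batch, seqlen, nheads, headdim) layout, again assuming fp16 CUDA tensors and positional arguments:

import torch

q, k, v = [torch.randn(2, 1024, 16, 64, device="cuda", dtype=torch.float16, requires_grad=True)
           for _ in range(3)]
out = flash_attn_func(q, k, v, None, True)   # bias=None, causal=True, scale defaults to 1/sqrt(64)
out.sum().backward()                         # q.grad, k.grad, v.grad via _flash_attn_backward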