Upload triton_flash_blocksparse_attn.py

#25
Files changed (1) hide show
  1. triton_flash_blocksparse_attn.py +58 -56
triton_flash_blocksparse_attn.py CHANGED
@@ -611,30 +611,31 @@ def _forward(ctx, q, k, v, layout_crow_indices, layout_col_indices, sm_scale, BL
611
  # print(f'> {q.shape=}, {k.shape=}, {layout_crow_indices.shape}, {layout_col_indices.shape}, {layout_crow_indices.stride()}, \
612
  # {layout_col_indices.stride()}, {layout_crow_indices=}, {layout_col_indices=}')
613
 
614
- _fwd_kernel[grid](
615
- q, k, v, sm_scale,
616
- layout_crow_indices,
617
- layout_col_indices,
618
- layout_crow_indices.stride(0), layout_crow_indices.stride(1),
619
- layout_col_indices.stride(0), layout_col_indices.stride(1),
620
- tmp, L, m,
621
- o,
622
- q.stride(0), q.stride(1), q.stride(2), q.stride(3),
623
- k.stride(0), k.stride(1), k.stride(2), k.stride(3),
624
- v.stride(0), v.stride(1), v.stride(2), v.stride(3),
625
- o.stride(0), o.stride(1), o.stride(2), o.stride(3),
626
- q.shape[0], q.shape[1], k.shape[2],
627
- k.shape[2] - q.shape[2],
628
- q_rounded_len,
629
- BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
630
- BLOCK_DMODEL=BLOCK_DMODEL,
631
- EVEN_M_BLOCK=q.shape[2] % BLOCK_M == 0,
632
- EVEN_N_BLOCK=k.shape[2] % BLOCK_N == 0 ,
633
- INFERENCE=inference,
634
- NUM_DBLOCKS=q.shape[-1] // BLOCK_DMODEL,
635
- num_warps=num_warps,
636
- num_stages=num_stages,
637
- )
 
638
  if inference:
639
  L, m = None, None
640
 
@@ -991,37 +992,38 @@ def blocksparse_flash_attn_padded_fwd(
991
 
992
  grid = (len(q_start_sids), n_heads)
993
 
994
- _fwd_kernel_batch_inference[grid](
995
- q, k, v, out,
996
- sm_scale,
997
- q_batch_starts,
998
- q_batch_ends,
999
- k_batch_starts,
1000
- k_batch_ends,
1001
- q_batch_ids,
1002
- q_start_sids,
1003
-
1004
- *q.stride(),
1005
- *k.stride(),
1006
- *v.stride(),
1007
- *out.stride(),
1008
-
1009
- layout_crow_indices,
1010
- layout_col_indices,
1011
- *layout_crow_indices.stride(),
1012
- *layout_col_indices.stride(),
1013
-
1014
- q_k_ratio,
1015
- HAS_BATCH_DIM = True,
1016
- D_HEAD = head_size,
1017
- BLOCK_M = block_size,
1018
- BLOCK_N = block_size,
1019
- BLOCK_D = block_d,
1020
- BLOCK_M_LOADING = 16 if q_len == 1 else block_size, # smaller for decoding
1021
- EVEN_D = block_d == head_size,
1022
- num_warps = 1 if q_len == 1 else 4,
1023
- num_stages = 3
1024
- )
 
1025
 
1026
  return out
1027
 
@@ -1940,4 +1942,4 @@ if __name__ == '__main__':
1940
  # 4 4096.0 3.401622 6.221376 1.636039
1941
  # 5 8192.0 11.915136 23.483391 3.968725
1942
  # 6 16384.0 44.660225 91.302910 10.857130
1943
- # 7 32768.0 175.038467 359.048187 32.778240
 
611
  # print(f'> {q.shape=}, {k.shape=}, {layout_crow_indices.shape}, {layout_col_indices.shape}, {layout_crow_indices.stride()}, \
612
  # {layout_col_indices.stride()}, {layout_crow_indices=}, {layout_col_indices=}')
613
 
614
+ with torch.cuda.device(q.device.index):
615
+ _fwd_kernel[grid](
616
+ q, k, v, sm_scale,
617
+ layout_crow_indices,
618
+ layout_col_indices,
619
+ layout_crow_indices.stride(0), layout_crow_indices.stride(1),
620
+ layout_col_indices.stride(0), layout_col_indices.stride(1),
621
+ tmp, L, m,
622
+ o,
623
+ q.stride(0), q.stride(1), q.stride(2), q.stride(3),
624
+ k.stride(0), k.stride(1), k.stride(2), k.stride(3),
625
+ v.stride(0), v.stride(1), v.stride(2), v.stride(3),
626
+ o.stride(0), o.stride(1), o.stride(2), o.stride(3),
627
+ q.shape[0], q.shape[1], k.shape[2],
628
+ k.shape[2] - q.shape[2],
629
+ q_rounded_len,
630
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
631
+ BLOCK_DMODEL=BLOCK_DMODEL,
632
+ EVEN_M_BLOCK=q.shape[2] % BLOCK_M == 0,
633
+ EVEN_N_BLOCK=k.shape[2] % BLOCK_N == 0 ,
634
+ INFERENCE=inference,
635
+ NUM_DBLOCKS=q.shape[-1] // BLOCK_DMODEL,
636
+ num_warps=num_warps,
637
+ num_stages=num_stages,
638
+ )
639
  if inference:
640
  L, m = None, None
641
 
 
992
 
993
  grid = (len(q_start_sids), n_heads)
994
 
995
+ with torch.cuda.device(q.device.index):
996
+ _fwd_kernel_batch_inference[grid](
997
+ q, k, v, out,
998
+ sm_scale,
999
+ q_batch_starts,
1000
+ q_batch_ends,
1001
+ k_batch_starts,
1002
+ k_batch_ends,
1003
+ q_batch_ids,
1004
+ q_start_sids,
1005
+
1006
+ *q.stride(),
1007
+ *k.stride(),
1008
+ *v.stride(),
1009
+ *out.stride(),
1010
+
1011
+ layout_crow_indices,
1012
+ layout_col_indices,
1013
+ *layout_crow_indices.stride(),
1014
+ *layout_col_indices.stride(),
1015
+
1016
+ q_k_ratio,
1017
+ HAS_BATCH_DIM = True,
1018
+ D_HEAD = head_size,
1019
+ BLOCK_M = block_size,
1020
+ BLOCK_N = block_size,
1021
+ BLOCK_D = block_d,
1022
+ BLOCK_M_LOADING = 16 if q_len == 1 else block_size, # smaller for decoding
1023
+ EVEN_D = block_d == head_size,
1024
+ num_warps = 1 if q_len == 1 else 4,
1025
+ num_stages = 3
1026
+ )
1027
 
1028
  return out
1029
 
 
1942
  # 4 4096.0 3.401622 6.221376 1.636039
1943
  # 5 8192.0 11.915136 23.483391 3.968725
1944
  # 6 16384.0 44.660225 91.302910 10.857130
1945
+ # 7 32768.0 175.038467 359.048187 32.778240