diff --git "a/mamba/build/lib/mamba_ssm/ops/triton/ssd_chunk_scan.py" "b/mamba/build/lib/mamba_ssm/ops/triton/ssd_chunk_scan.py" new file mode 100644--- /dev/null +++ "b/mamba/build/lib/mamba_ssm/ops/triton/ssd_chunk_scan.py" @@ -0,0 +1,1829 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +"""We want triton==2.1.0 or 2.2.0 for this +""" + +import math +from packaging import version + +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + +from einops import rearrange, repeat + +from mamba_ssm.ops.triton.ssd_bmm import _bmm_chunk_fwd, _bmm_chunk_bwd + +TRITON_22 = version.parse(triton.__version__) >= version.parse('2.2.0') + + +def init_to_zero(names): + return lambda nargs: [nargs[name].zero_() for name in names if nargs[name] is not None] + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + ], + key=['chunk_size', 'hdim', 'dstate', 'IS_CAUSAL'], +) +@triton.jit +def _chunk_scan_fwd_kernel( + # Pointers to matrices + cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr, + # Matrix dimensions + chunk_size, hdim, dstate, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, + stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, + stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_D_head, + # Meta-parameters + IS_CAUSAL: tl.constexpr, + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, + IS_TRITON_22: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head + prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + if HAS_SEQ_IDX: + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Without the if (pid_c > -1), with Triton 2.1.0, I get + # Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mm a layout conversion"' failed. + # With Triton 2.2.0, this works + if IS_TRITON_22 or pid_c > -1: + # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128 + offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K) + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) + prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate) + if not HAS_SEQ_IDX: + scale_m = tl.exp(dA_cs_m) + else: + scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + if BLOCK_SIZE_DSTATE <= 128: + C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) + prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc = tl.dot(C, prev_states) * scale_m[:, None] + else: + for k in range(0, dstate, BLOCK_SIZE_K): + C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate - k), other=0.0) + # C = (C * scale_m[:, None]).to(C_ptr.dtype.element_ty) + prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim), other=0.0) + prev_states = prev_states.to(C_ptr.dtype.element_ty) + acc += tl.dot(C, prev_states) + C_ptrs += BLOCK_SIZE_K + prev_states_ptrs += BLOCK_SIZE_K + acc *= scale_m[:, None] + + offs_k = tl.arange(0, BLOCK_SIZE_K) + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) + x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_k * stride_dt_csize + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + K_MAX = chunk_size_limit if not IS_CAUSAL else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + for k in range(0, K_MAX, BLOCK_SIZE_K): + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k), other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j]. + # So we don't need masking wrt seq_idx here. + cb *= tl.exp((dA_cs_m[:, None] - dA_cs_k[None, :])) + dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + cb *= dt_k + if IS_CAUSAL: + mask = offs_m[:, None] >= k + offs_k[None, :] + cb = tl.where(mask, cb, 0.0) + cb = cb.to(x_ptr.dtype.element_ty) + x = tl.load(x_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim), other=0.0) + acc += tl.dot(cb, x) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + x_ptrs += BLOCK_SIZE_K * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_K * stride_dt_csize + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + if HAS_D: + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + x_residual = tl.load(x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim), + mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + acc += x_residual * D + + if HAS_Z: + out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]) + tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) + + z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head + z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) + z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32) + acc *= z * tl.sigmoid(z) + + out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + out_ptrs = out_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim) + tl.store(out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) + + +@triton.autotune( + configs=[ + # triton.Config({'BLOCK_SIZE_N': 256}, num_stages=4, num_warps=4), + # triton.Config({'BLOCK_SIZE_N': 128}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_N': 64}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_N': 32}, num_stages=4, num_warps=8), + ], + key=['chunk_size', 'hdim', 'dstate'], +) +@triton.jit +def _chunk_scan_fwd_kernel_wip( + # Pointers to matrices + cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, B_ptr, prev_states_ptr, D_ptr, + # Matrix dimensions + chunk_size, hdim, dstate, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, + stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, + stride_B_batch, stride_B_seqlen, stride_B_head, stride_B_dstate, + stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, + stride_D_head, + # Meta-parameters + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + pid_n = tl.program_id(axis=0) + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head + B_ptr += pid_b * stride_B_batch + pid_c * chunk_size * stride_B_seqlen + (pid_h // nheads_ngroups_ratio) * stride_B_head + prev_states_ptr += pid_b * stride_states_batch + pid_c * stride_states_chunk + pid_h * stride_states_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + + offs_m = tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE) + + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate) + B_ptrs = B_ptr + (offs_m[None, :] * stride_B_seqlen + offs_k_dstate[:, None] * stride_B_dstate) + prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_states_hdim + offs_k_dstate[:, None] * stride_states_dstate) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_m[None, :] * stride_cb_csize_k) + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) + + prev_states = tl.load(prev_states_ptrs, mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim), other=0.0) + # if pid_c == 0: + # if pid_b == 0: + # if pid_h == 0: + # tl.device_print("", prev_states) + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + + # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + # scale_m = tl.exp(dA_cs_m) + # C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate), other=0.0) + # acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None] + # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_m[None, :] < chunk_size), other=0.0).to(tl.float32) + # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :])) + # dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + # cb *= dt_m + # mask = offs_m[:, None] >= offs_m[None, :] + # cb = tl.where(mask, cb, 0.0) + # cb = cb.to(x_ptr.dtype.element_ty) + # x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0) + # acc += tl.dot(cb, x) + # if HAS_D: + # if D_HAS_HDIM: + # D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + # else: + # D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + # acc += x.to(tl.float32) * D + # tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + + for start_m in range(0, chunk_size_limit, BLOCK_SIZE_M): + start_m = tl.multiple_of(start_m, BLOCK_SIZE_M) + dA_cs_m = tl.load(dA_cumsum_ptr + (start_m + offs_m) * stride_dA_cs_csize, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_prev = tl.load(seq_idx_ptr + start_m - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + seq_idx_m = tl.load(seq_idx_ptr + (start_m + offs_m) * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit - start_m, other=-1) + if not HAS_SEQ_IDX: + scale_m = tl.exp(dA_cs_m) + else: + scale_m = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + C = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_k_dstate[None, :] < dstate), other=0.0) + acc = tl.dot(C, prev_states.to(C_ptr.dtype.element_ty)) * scale_m[:, None] + # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size - start_m) & (offs_m[None, :] < chunk_size - start_m), other=0.0).to(tl.float32) + # cb *= tl.exp((dA_cs_m[:, None] - dA_cs_m[None, :])) + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size - start_m, other=0.0).to(tl.float32) + # cb *= dt_m + # mask = offs_m[:, None] >= offs_m[None, :] + # cb = tl.where(mask, cb, 0.0) + # cb = cb.to(x_ptr.dtype.element_ty) + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim), other=0.0) + # acc += tl.dot(cb, x) + + if HAS_D: + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + acc += x.to(tl.float32) * D + + # if HAS_Z: + # out_x_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + # out_x_ptrs = out_x_ptr + (stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :]) + # tl.store(out_x_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)) + + # z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head + # z_ptrs = z_ptr + (stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]) + # z = tl.load(z_ptrs, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim), other=0.0).to(tl.float32) + # acc *= z * tl.sigmoid(z) + + tl.store(out_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit - start_m) & (offs_n[None, :] < hdim)) + + # TODO: this is not correct, and quite a bit slower + if start_m + BLOCK_SIZE_M < chunk_size_limit: + # B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0).to(tl.float32) + B = tl.load(B_ptrs, mask=(offs_m[None, :] < chunk_size_limit - start_m) & (offs_k_dstate[:, None] < dstate), other=0.0) + dA_cs_last = tl.load(dA_cumsum_ptr + (start_m + BLOCK_SIZE_M) * stride_dA_cs_csize).to(tl.float32) + # TODO: seq_idx + scale = tl.exp((dA_cs_last - dA_cs_m)) * dt_m + # B *= scale + B = B.to(x_ptr.dtype.element_ty) + tmp = tl.dot(B, x) + prev_states += tmp.to(prev_states.dtype) + + C_ptrs += BLOCK_SIZE_M * stride_C_seqlen + B_ptrs += BLOCK_SIZE_M * stride_B_seqlen + cb_ptrs += BLOCK_SIZE_M * stride_cb_csize_m + BLOCK_SIZE_M * stride_cb_csize_k + x_ptrs += BLOCK_SIZE_M * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_M * stride_dt_csize + out_ptrs += BLOCK_SIZE_M * stride_out_seqlen + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32}), + triton.Config({'BLOCK_SIZE_M': 64}), + triton.Config({'BLOCK_SIZE_M': 128}), + triton.Config({'BLOCK_SIZE_M': 256}), + ], + key=["chunk_size", "hdim"], +) +@triton.jit +def _chunk_scan_bwd_dz_kernel( + # Pointers to matrices + dout_ptr, out_ptr, z_ptr, x_ptr, D_ptr, outz_ptr, dz_ptr, dout_x_ptr, dD_ptr, ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, + # Strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, + stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_D_head, + stride_outz_batch, stride_outz_seqlen, stride_outz_head, stride_outz_hdim, + stride_dz_batch, stride_dz_seqlen, stride_dz_head, stride_dz_hdim, + stride_doutx_batch, stride_doutx_seqlen, stride_doutx_head, stride_doutx_hdim, + stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + # Meta-parameters + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + HAS_DDACS: tl.constexpr, + RECOMPUTE_OUTPUT: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dout_x_ptr += pid_b * stride_doutx_batch + pid_c * chunk_size * stride_doutx_seqlen + pid_h * stride_doutx_head + out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + z_ptr += pid_b * stride_z_batch + pid_c * chunk_size * stride_z_seqlen + pid_h * stride_z_head + dz_ptr += pid_b * stride_dz_batch + pid_c * chunk_size * stride_dz_seqlen + pid_h * stride_dz_head + if RECOMPUTE_OUTPUT: + outz_ptr += pid_b * stride_outz_batch + pid_c * chunk_size * stride_outz_seqlen + pid_h * stride_outz_head + if HAS_DDACS: + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + if HAS_D: + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_N) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dout_x_ptrs = dout_x_ptr + (offs_m[:, None] * stride_doutx_seqlen + offs_n[None, :] * stride_doutx_hdim) + out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) + z_ptrs = z_ptr + (offs_m[:, None] * stride_z_seqlen + offs_n[None, :] * stride_z_hdim) + dz_ptrs = dz_ptr + (offs_m[:, None] * stride_dz_seqlen + offs_n[None, :] * stride_dz_hdim) + if RECOMPUTE_OUTPUT: + outz_ptrs = outz_ptr + (offs_m[:, None] * stride_outz_seqlen + offs_n[None, :] * stride_outz_hdim) + if HAS_D: + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + if D_HAS_HDIM: + dD_ptrs = dD_ptr + offs_n * stride_dD_hdim + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + z = tl.load(z_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + z_sigmoid = tl.sigmoid(z) + if RECOMPUTE_OUTPUT: + outz = out * z * z_sigmoid + tl.store(outz_ptrs, outz, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + dz = dout * out * z_sigmoid * (1 + z * (1 - z_sigmoid)) + tl.store(dz_ptrs, dz, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + dout *= z * z_sigmoid + tl.store(dout_x_ptrs, dout, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + if HAS_D: + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if D_HAS_HDIM: + dD = tl.sum(dout * x, axis=0) + tl.store(dD_ptrs, dD, mask=offs_n < hdim) + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + dD = tl.sum(dout * x) + tl.store(dD_ptr, dD) + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + out -= x * D + if HAS_DDACS: + ddA_cs = tl.sum(dout * out, axis=1) + tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2), + ], + key=['hdim', 'dstate', 'chunk_size'], +) +@triton.jit +def _chunk_scan_bwd_dstates_kernel( + # Pointers to matrices + dout_ptr, c_ptr, dprev_states_ptr, dA_cumsum_ptr, seq_idx_ptr, + # Matrix dimensions + hdim, dstate, chunk_size, + batch, seqlen, nchunks, nheads_ngroups_ratio, + # Strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_c_batch, stride_c_seqlen, stride_c_head, stride_c_dstate, + stride_dprev_states_batch, stride_dprev_states_chunk, stride_dprev_states_head, stride_dprev_states_hdim, stride_dprev_states_dstate, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + c_ptr += pid_b * stride_c_batch + pid_c * chunk_size * stride_c_seqlen + (pid_h // nheads_ngroups_ratio) * stride_c_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_hdim + offs_k[None, :] * stride_dout_seqlen) + c_ptrs = c_ptr + (offs_n[None, :] * stride_c_dstate + offs_k[:, None] * stride_c_seqlen) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs = seq_idx_ptr + offs_k * stride_seq_idx_seqlen + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + if HAS_SEQ_IDX: + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + for k in range(0, chunk_size_limit, BLOCK_SIZE_K): + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k), other=0.0).to(tl.float32) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale_k = tl.exp(dA_cs_k) + else: + seq_idx_k = tl.load(seq_idx_ptrs, mask=offs_k < chunk_size_limit - k, other=-1) + scale_k = tl.where(seq_idx_k == seq_idx_prev, tl.exp(dA_cs_k), 0.0) + dout = (dout * scale_k).to(dout_ptr.dtype.element_ty) + c = tl.load(c_ptrs, mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate), other=0.0) + acc += tl.dot(dout, c) + dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen + c_ptrs += BLOCK_SIZE_K * stride_c_seqlen + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + if HAS_SEQ_IDX: + seq_idx_ptrs += BLOCK_SIZE_K * stride_seq_idx_seqlen + out = acc.to(dprev_states_ptr.dtype.element_ty) + + dprev_states_ptr += pid_b * stride_dprev_states_batch + pid_c * stride_dprev_states_chunk + pid_h * stride_dprev_states_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dprev_states_ptrs = dprev_states_ptr + (offs_m[:, None] * stride_dprev_states_hdim + offs_n[None, :] * stride_dprev_states_dstate) + tl.store(dprev_states_ptrs, out, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + ], + key=['chunk_size', 'dstate', 'hdim'], +) +@triton.jit +def _chunk_scan_bwd_dc_kernel( + # Pointers to matrices + dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr, + dc_ptr, ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, dstate, hdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + # Strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate, + stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_dc_batch, stride_dc_seqlen, stride_dc_split, stride_dc_group, stride_dc_dstate, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + # Meta-parameters + HAS_DDA_CS: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_sg = tl.program_id(axis=2) + pid_s = pid_sg // ngroups + pid_g = pid_sg - pid_s * ngroups + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head + dc_ptr += pid_b * stride_dc_batch + pid_c * chunk_size * stride_dc_seqlen + pid_g * stride_dc_group + pid_s * stride_dc_split + prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_prev_states_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head + if HAS_DDA_CS: + C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + pid_g * stride_C_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize + if HAS_DDA_CS: + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate) + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + if HAS_DDA_CS: + c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + if HAS_SEQ_IDX: + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program) + for h in range(nheads_iter): + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0) + prev_states = prev_states.to(dout_ptrs.dtype.element_ty) + dc = tl.dot(dout, prev_states) + dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_m) + else: + scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + dc *= scale[:, None] + if HAS_DDA_CS: + ddA_cs = tl.sum(dc * c, axis=1) + tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) + acc += dc + dout_ptrs += stride_dout_head + prev_states_ptrs += stride_prev_states_head + dA_cumsum_ptrs += stride_dA_cs_head + if HAS_DDA_CS: + ddA_cumsum_ptrs += stride_ddA_cs_head + # if HAS_SEQ_IDX: + # seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + # seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + # acc = tl.where(seq_idx_m[:, None] == seq_idx_prev, acc, 0.0) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dc_ptrs = dc_ptr + (offs_m[:, None] * stride_dc_seqlen + offs_n[None, :] * stride_dc_dstate) + tl.store(dc_ptrs, acc, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate)) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4, pre_hook=init_to_zero(["ddt_ptr"])), + ], + key=['chunk_size', 'hdim'], +) +@triton.jit +def _chunk_scan_bwd_dx_kernel( + # Pointers to matrices + x_ptr, cb_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, D_ptr, + dx_ptr, ddt_ptr, # dD_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_D_head, + stride_dx_batch, stride_dx_seqlen, stride_dx_head, stride_dx_hdim, + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, + # stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_hdim, stride_dD_csize, + # Meta-parameters + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + # if HAS_D: + # dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k) + dout_ptrs = dout_ptr + (offs_k[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # Idk why limiting K_MAX gives wrong results, is it a Triton bug? + # K_MAX = min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit) + K_MAX = chunk_size_limit + for k in range(0, K_MAX, BLOCK_SIZE_K): + # For some reason setting mask to (offs_m[:, None] < chunk_size_limit) is much slower + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < K_MAX - k), other=0.0) + dout = tl.load(dout_ptrs, mask=(offs_k[:, None] < K_MAX - k) & (offs_n[None, :] < hdim), other=0.0) + dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < K_MAX - k, other=0.0).to(tl.float32) + cb *= tl.exp(dA_cs_k[None, :] - dA_cs_m[:, None]) + # If we don't have the (k + offs_k[None, :] < K_MAX) mask, for indices outside this range, + # we might have dA_cs_m = 0.0 and dA_cs_k very negative, and tl.exp will return inf. + # Multiplying with cb, which is 0.0 outside the range, will make the result NaN. + # This will cause NaN in acc, and hence NaN in dx and ddt. + mask = (k + offs_k[None, :] >= offs_m[:, None]) & (k + offs_k[None, :] < K_MAX) + cb = tl.where(mask, cb, 0.0) + cb = cb.to(dout_ptr.dtype.element_ty) + acc += tl.dot(cb, dout) + cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k + dout_ptrs += BLOCK_SIZE_K * stride_dout_seqlen + dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + dt_ptrs = dt_ptr + offs_m * stride_dt_csize + dt_m = tl.load(dt_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + dx = acc * dt_m[:, None] + dx_ptr += pid_b * stride_dx_batch + pid_c * chunk_size * stride_dx_seqlen + pid_h * stride_dx_head + dx_ptrs = dx_ptr + (offs_m[:, None] * stride_dx_seqlen + offs_n[None, :] * stride_dx_hdim) + if HAS_D: + dout_res_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + dout_res = tl.load(dout_res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if D_HAS_HDIM: + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + dx += dout_res * D + tl.store(dx_ptrs, dx, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim)) + + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + ddt = tl.sum(acc * x, axis=1) + ddt_ptrs = ddt_ptr + offs_m * stride_ddt_csize + tl.atomic_add(ddt_ptrs, ddt, mask=offs_m < chunk_size) + + # if HAS_D: + # dout_new_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize + offs_n[None, :] * stride_dout_hdim) + # dout = tl.load(dout_new_ptrs, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N), other=0.0).to(tl.float32) + # dD = tl.sum(x * dout, axis=0) + # tl.store(dD_ptr + offs_n * stride_dD_hdim, dD, mask=offs_n < N) + + +# Disabling HAS_DDA_CS for now since it's much slower +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8), + ], + key=['chunk_size', 'hdim'], +) +# @triton.heuristics({"BLOCK_SIZE_N": lambda args: max(triton.next_power_of_2(args["chunk_size"]), 16)}) +# @triton.heuristics({"BLOCK_SIZE_N": lambda args: 32}) +@triton.jit +def _chunk_scan_bwd_dcb_kernel( + # Pointers to matrices + x_ptr, dout_ptr, cb_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, + dcb_ptr, ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_dcb_batch, stride_dcb_chunk, stride_dcb_split, stride_dcb_group, stride_dcb_csize_m, stride_dcb_csize_n, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n, + # Meta-parameters + HAS_DDA_CS: tl.constexpr, + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_sg = tl.program_id(axis=2) + pid_s = pid_sg // ngroups + pid_g = pid_sg - pid_s * ngroups + num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_x_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dout_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_dA_cs_head + if HAS_DDA_CS: + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + pid_g * stride_cb_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + (pid_g * (nheads // ngroups) + pid_s * nheads_per_program) * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_n * stride_dt_csize + if HAS_DDA_CS: + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n + + if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M: + dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split + dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n) + tl.store(dcb_ptrs, tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=dcb_ptr.dtype.element_ty), mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) + return + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) + acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + if HAS_DDA_CS: + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32) + nheads_iter = min(nheads_per_program, nheads // ngroups - pid_s * nheads_per_program) + for h in range(nheads_iter): + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) + dcb = tl.dot(dout, x) + dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32) + dcb *= dt_n + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size_limit, other=0.0).to(tl.float32) + dcb *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + if HAS_DDA_CS: + tl.static_assert(not HAS_SEQ_IDX, "HAS_SEQ_IDX not supported with HAS_DDA_CS yet") + ddA_cs = dcb * cb + mask = offs_m[:, None] >= offs_n[None, :] + 1 + ddA_cs = tl.where(mask, ddA_cs, 0.0) + ddA_cs = tl.cumsum(ddA_cs, axis=1) + ddA_cs = tl.where(mask, ddA_cs, 0.0) + ddA_cs = tl.sum(ddA_cs, axis=0) + tl.store(ddA_cumsum_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1) + tl.store(ddA_cumsum_ptr, 0.0) + acc += dcb + dout_ptrs += stride_dout_head + x_ptrs += stride_x_head + dt_ptrs += stride_dt_head + dA_cumsum_ptr += stride_dA_cs_head + if HAS_DDA_CS: + ddA_cumsum_ptr += stride_ddA_cs_head + ddA_cumsum_ptrs += stride_ddA_cs_head + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if HAS_SEQ_IDX: + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2) + acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0) + mask = offs_m[:, None] >= offs_n[None, :] + acc = tl.where(mask, acc, 0.0) + dcb_ptr += pid_b * stride_dcb_batch + pid_c * stride_dcb_chunk + pid_g * stride_dcb_group + pid_s * stride_dcb_split + dcb_ptrs = dcb_ptr + (offs_m[:, None] * stride_dcb_csize_m + offs_n[None, :] * stride_dcb_csize_n) + tl.store(dcb_ptrs, acc, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size)) + + +# Not numerically stable and should not be used. Leaving here for reference. +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32}), + triton.Config({'BLOCK_SIZE_M': 64}), + triton.Config({'BLOCK_SIZE_M': 128}), + triton.Config({'BLOCK_SIZE_M': 256}), + ], + key=["chunk_size", "hdim"], +) +@triton.jit +def _chunk_scan_bwd_ddAcs_unstable_kernel( + # Pointers to matrices + dout_ptr, out_ptr, dt_ptr, ddt_ptr, x_ptr, D_ptr, + ddA_cumsum_ptr, dD_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, + # Strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_ddt_batch, stride_ddt_chunk, stride_ddt_head, stride_ddt_csize, + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_D_head, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + stride_dD_batch, stride_dD_chunk, stride_dD_head, stride_dD_csize, stride_dD_hdim, + # Meta-parameters + HAS_D: tl.constexpr, + D_HAS_HDIM: tl.constexpr, + SUBTRACT_DDTDT: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + out_ptr += pid_b * stride_out_batch + pid_c * chunk_size * stride_out_seqlen + pid_h * stride_out_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + ddt_ptr += pid_b * stride_ddt_batch + pid_c * stride_ddt_chunk + pid_h * stride_ddt_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + if HAS_D: + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dD_ptr += pid_b * stride_dD_batch + pid_c * stride_dD_chunk + pid_h * stride_dD_head + pid_m * stride_dD_csize + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_N) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_n[None, :] * stride_dout_hdim) + out_ptrs = out_ptr + (offs_m[:, None] * stride_out_seqlen + offs_n[None, :] * stride_out_hdim) + if HAS_D: + x_ptrs = x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim) + if D_HAS_HDIM: + dD_ptrs = dD_ptr + offs_n * stride_dD_hdim + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + out = tl.load(out_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if HAS_D: + x = tl.load(x_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim), other=0.0).to(tl.float32) + if D_HAS_HDIM: + dD = tl.sum(dout * x, axis=0) + tl.store(dD_ptrs, dD, mask=offs_n < hdim) + D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(tl.float32) + else: + dD = tl.sum(dout * x) + tl.store(dD_ptr, dD) + D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32) + out -= x * D + ddA_cs = tl.sum(dout * out, axis=1) + if SUBTRACT_DDTDT: + dt = tl.load(dt_ptr + offs_m * stride_dt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + ddt = tl.load(ddt_ptr + offs_m * stride_ddt_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + ddA_cs -= dt * ddt + tl.store(ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize, ddA_cs, mask=offs_m < chunk_size) + + +@triton.autotune( + configs=[ + # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=3, num_warps=4), + # triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), + # triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 16}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 16}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 32}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64}, num_stages=4, num_warps=8), + triton.Config({'BLOCK_SIZE_M': 128}, num_stages=4, num_warps=8), + ], + key=['chunk_size', 'hdim'], +) +@triton.jit +def _chunk_scan_bwd_ddAcs_stable_kernel_old( + # Pointers to matrices + x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr, + ddAcs_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n, + stride_ddAcs_batch, stride_ddAcs_chunk, stride_ddAcs_head, stride_ddAcs_csize_m, stride_ddAcs_csize_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_n * stride_dt_csize + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) + # Doing a matmul loop with cumsum later on will cause Triton to crash + # Instead we do just one big matmul + # acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # for k in range(0, hdim, BLOCK_SIZE_K): + # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0) + # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0) + # acc += tl.dot(dout, x) + # dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim + # x_ptrs += BLOCK_SIZE_K * stride_x_hdim + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) + acc = tl.dot(dout, x) + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size), other=0.0).to(tl.float32) + acc *= cb + dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size, other=0.0).to(tl.float32) + acc *= dt_n + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size, other=0.0).to(tl.float32) + acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + mask = offs_m[:, None] >= offs_n[None, :] + 1 + acc = tl.where(mask, acc, 0.0) + acc = tl.cumsum(acc, axis=1) + acc = tl.where(mask, acc, 0.0) + ddA_cs = tl.sum(acc, axis=0) + ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n + tl.store(ddAcs_ptrs + stride_ddAcs_csize_n, ddA_cs, mask=offs_n < chunk_size - 1) + tl.store(ddAcs_ptr, 0.0) + + # offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, 64) + # offs_k = tl.arange(0, BLOCK_SIZE_K) + # dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + # x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) + # dt_ptrs = dt_ptr + offs_n * stride_dt_csize + # cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) + + # chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + # chunk_size_limit_n = min(chunk_size_limit, (pid_m + 1) * BLOCK_SIZE_M) + # rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + # dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + # ddAcs_ptr += pid_b * stride_ddAcs_batch + pid_c * stride_ddAcs_chunk + pid_h * stride_ddAcs_head + pid_m * stride_ddAcs_csize_m + # ddAcs_ptrs = ddAcs_ptr + offs_n * stride_ddAcs_csize_n + # for n in range(0, chunk_size_limit_n, 64): + # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n - n), other=0.0) + # acc = tl.dot(dout, x) + # cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - n), other=0.0).to(tl.float32) + # acc *= cb + # dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32) + # acc *= dt_n + # dA_cs_n = tl.load(dA_cumsum_ptr + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size - n, other=0.0).to(tl.float32) + # acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + # mask = offs_m[:, None] >= offs_n[None, :] + 1 + n + # acc = tl.where(mask, acc, 0.0) + # acc = tl.cumsum(acc, axis=1) + # acc = tl.where(mask, acc, 0.0) + # ddA_cs = tl.sum(acc, axis=0) + # tl.store(ddAcs_ptrs, ddA_cs, mask=offs_n < chunk_size - 1 - n) + # # tl.store(ddAcs_ptr, 0.0) + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4), + ], + key=['chunk_size', 'hdim'], +) +@triton.jit +def _chunk_scan_bwd_ddAcs_stable_kernel( + # Pointers to matrices + x_ptr, dout_ptr, dt_ptr, dA_cumsum_ptr, cb_ptr, + ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, hdim, + batch, seqlen, nheads_ngroups_ratio, + # Strides + stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_n, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize_m, stride_ddA_cs_csize_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + pid_m = tl.program_id(axis=0) + + x_ptr += pid_b * stride_x_batch + pid_c * chunk_size * stride_x_seqlen + pid_h * stride_x_head + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + dt_ptr += pid_b * stride_dt_batch + pid_c * stride_dt_chunk + pid_h * stride_dt_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + cb_ptr += pid_b * stride_cb_batch + pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + pid_m * stride_ddA_cs_csize_m + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + x_ptrs = x_ptr + (offs_n[None, :] * stride_x_seqlen + offs_k[:, None] * stride_x_hdim) + dt_ptrs = dt_ptr + offs_n * stride_dt_csize + cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_n[None, :] * stride_cb_csize_n) + ddAcs_ptrs = ddA_cumsum_ptr + offs_n * stride_ddA_cs_csize_n + tl.store(ddA_cumsum_ptr, 0.0) + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + rowsum = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + dA_cs_m = tl.load(dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0).to(tl.float32) + # Actually hi is (pid_m + 1) * BLOCK_SIZE_M - 1 but subtracting 1 makes it slower + lo, hi = 0, (pid_m + 1) * BLOCK_SIZE_M + # lo, hi = 0, chunk_size + for start_n in range(lo, hi, BLOCK_SIZE_N): + start_n = tl.multiple_of(start_n, BLOCK_SIZE_N) + # Doing a matmul loop with cumsum later on will cause Triton to crash + # Instead we do just one big matmul + # acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + # for k in range(0, hdim, BLOCK_SIZE_K): + # dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim - k), other=0.0) + # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim - k) & (offs_n[None, :] < chunk_size_limit), other=0.0) + # acc += tl.dot(dout, x) + # dout_ptrs += BLOCK_SIZE_K * stride_dout_hdim + # x_ptrs += BLOCK_SIZE_K * stride_x_hdim + # x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit_n), other=0.0) + x = tl.load(x_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < chunk_size_limit - start_n), other=0.0) + acc = tl.dot(dout, x) + dt_n = tl.load(dt_ptrs, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32) + acc *= dt_n + # If there's seq_idx, we already zero'ed out cb[i, j] for seq_idx[i] != seq_idx[j] + cb = tl.load(cb_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size - start_n), other=0.0).to(tl.float32) + acc *= cb + dA_cs_n = tl.load(dA_cumsum_ptr + start_n + offs_n * stride_dA_cs_csize, mask=offs_n < chunk_size - start_n, other=0.0).to(tl.float32) + acc *= tl.exp(dA_cs_m[:, None] - dA_cs_n[None, :]) + mask = offs_m[:, None] >= start_n + offs_n[None, :] + 1 + acc = tl.where(mask, acc, 0.0) + rowsum_new = rowsum + tl.sum(acc, axis=1) + acc = rowsum[:, None] + tl.cumsum(acc, axis=1) + rowsum = rowsum_new + acc = tl.where(mask, acc, 0.0) + ddA_cs = tl.sum(acc, axis=0) + tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, ddA_cs, mask=offs_n < chunk_size - start_n - 1) + x_ptrs += BLOCK_SIZE_N * stride_x_seqlen + dt_ptrs += BLOCK_SIZE_N * stride_dt_csize + cb_ptrs += BLOCK_SIZE_N * stride_cb_csize_n + ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n + + # Need to zero out the rest, since we'll be summing the rows together + for start_n in range(hi, chunk_size, BLOCK_SIZE_N): + tl.store(ddAcs_ptrs + stride_ddA_cs_csize_n, tl.zeros((BLOCK_SIZE_N,), dtype=tl.float32), mask=offs_n < chunk_size - start_n - 1) + ddAcs_ptrs += BLOCK_SIZE_N * stride_ddA_cs_csize_n + + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32}, num_stages=3, num_warps=4, pre_hook=init_to_zero(["ddA_cumsum_ptr"])), + ], + key=['chunk_size', 'dstate', 'hdim'], +) +@triton.jit +def _chunk_scan_bwd_ddAcs_prev_kernel( + # Pointers to matrices + dout_ptr, prev_states_ptr, C_ptr, dA_cumsum_ptr, seq_idx_ptr, + ddA_cumsum_ptr, + # Matrix dimensions + chunk_size, dstate, hdim, + batch, seqlen, nchunks, nheads_ngroups_ratio, + # Strides + stride_dout_batch, stride_dout_seqlen, stride_dout_head, stride_dout_hdim, + stride_prev_states_batch, stride_prev_states_chunk, stride_prev_states_head, stride_prev_states_hdim, stride_prev_states_dstate, + stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, + stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, + stride_seq_idx_batch, stride_seq_idx_seqlen, + stride_ddA_cs_batch, stride_ddA_cs_chunk, stride_ddA_cs_head, stride_ddA_cs_csize, + # Meta-parameters + HAS_SEQ_IDX: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +): + pid_bc = tl.program_id(axis=1) + pid_c = pid_bc // batch + pid_b = pid_bc - pid_c * batch + pid_h = tl.program_id(axis=2) + num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N) + pid_m = tl.program_id(axis=0) // num_pid_n + pid_n = tl.program_id(axis=0) % num_pid_n + dout_ptr += pid_b * stride_dout_batch + pid_c * chunk_size * stride_dout_seqlen + pid_h * stride_dout_head + prev_states_ptr += pid_b * stride_prev_states_batch + pid_c * stride_prev_states_chunk + pid_h * stride_prev_states_head + C_ptr += pid_b * stride_C_batch + pid_c * chunk_size * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head + ddA_cumsum_ptr += pid_b * stride_ddA_cs_batch + pid_c * stride_ddA_cs_chunk + pid_h * stride_ddA_cs_head + dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head + if HAS_SEQ_IDX: + seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_seqlen + offs_k[None, :] * stride_dout_hdim) + prev_states_ptrs = prev_states_ptr + (offs_n[None, :] * stride_prev_states_dstate + offs_k[:, None] * stride_prev_states_hdim) + C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_n[None, :] * stride_C_dstate) + dA_cumsum_ptrs = dA_cumsum_ptr + offs_m * stride_dA_cs_csize + + chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size) + dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < hdim), other=0.0) + prev_states = tl.load(prev_states_ptrs, mask=(offs_k[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0) + prev_states = prev_states.to(dout_ptrs.dtype.element_ty) + acc = tl.dot(dout, prev_states) + c = tl.load(C_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < dstate), other=0.0).to(tl.float32) + ddA_cs = tl.sum(acc * c, axis=1) + dA_cs_m = tl.load(dA_cumsum_ptrs, mask=offs_m < chunk_size_limit, other=0.0).to(tl.float32) + if not HAS_SEQ_IDX: + scale = tl.exp(dA_cs_m) + if HAS_SEQ_IDX: + seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_seqlen, mask=pid_c >= 1, other=0) + seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1) + scale = tl.where(seq_idx_m == seq_idx_prev, tl.exp(dA_cs_m), 0.0) + ddA_cs *= scale + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + ddA_cumsum_ptrs = ddA_cumsum_ptr + offs_m * stride_ddA_cs_csize + tl.atomic_add(ddA_cumsum_ptrs, ddA_cs, mask=offs_m < chunk_size) + + +def _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D=None, z=None, seq_idx=None): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + # Allocates output. + out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + if z is not None: + out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + assert out_x.stride() == out.stride() + else: + out_x = None + grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3)) + if z is not None else (0, 0, 0, 0)) + _chunk_scan_fwd_kernel[grid]( + cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, states, D, + chunk_size, headdim, dstate, + batch, seqlen, nheads // ngroups, + cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + z_strides[0], z_strides[1], z_strides[2], z_strides[3], + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + C.stride(0), C.stride(1), C.stride(2), C.stride(3), + states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4), + D.stride(0) if D is not None else 0, + True, + D is not None, + D.dim() == 2 if D is not None else True, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + HAS_Z=z is not None, + HAS_SEQ_IDX=seq_idx is not None, + IS_TRITON_22=TRITON_22, + ) + return out, out_x + + +def _chunk_scan_fwd_wip(cb, x, dt, dA_cumsum, C, B, states, D=None, z=None, seq_idx=None): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + assert B.shape == C.shape + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert states.shape == (batch, nchunks, nheads, headdim, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + # Allocates output. + out = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + if z is not None: + out_x = torch.empty(batch, seqlen, nheads, headdim, device=x.device, dtype=x.dtype) + assert out_x.stride() == out.stride() + else: + out_x = None + grid = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_N']), batch * nchunks, nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2), z.stride(3)) + if z is not None else (0, 0, 0, 0)) + _chunk_scan_fwd_kernel_wip[grid]( + cb, x, z, out, out_x, dt, dA_cumsum, seq_idx, C, B, states, D, + chunk_size, headdim, dstate, + batch, seqlen, nheads // ngroups, + cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + z_strides[0], z_strides[1], z_strides[2], z_strides[3], + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + C.stride(0), C.stride(1), C.stride(2), C.stride(3), + B.stride(0), B.stride(1), B.stride(2), B.stride(3), + states.stride(0), states.stride(1), states.stride(2), states.stride(3), states.stride(4), + D.stride(0) if D is not None else 0, + D is not None, + D.dim() == 2 if D is not None else True, + BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16), + BLOCK_SIZE_M=128, + HAS_Z=z is not None, + HAS_SEQ_IDX=seq_idx is not None, + ) + return out, out_x + + +def _chunk_scan_bwd_dz(x, z, out, dout, chunk_size, has_ddAcs=True, D=None, dz=None, recompute_output=False): + batch, seqlen, nheads, headdim = x.shape + assert z.shape == x.shape + assert out.shape == x.shape + assert dout.shape == out.shape + nchunks = math.ceil(seqlen / chunk_size) + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert D.stride(-1) == 1 + if has_ddAcs: + ddA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=x.device, dtype=torch.float32) + if D is not None: + BLOCK_SIZE_min = 32 + dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, + headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32) + else: + dD = None + if dz is not None: + assert dz.shape == z.shape + else: + dz = torch.empty_like(z) + if recompute_output: + outz = torch.empty_like(x) + dout_x = torch.empty_like(dout) + dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) + if D is not None else (0, 0, 0, 0, 0)) + grid_dz = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_dz_kernel[grid_dz]( + dout, out, z, x, D, outz if recompute_output else None, + dz, dout_x, dD, ddA_cumsum if has_ddAcs else None, + chunk_size, headdim, + batch, seqlen, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + z.stride(0), z.stride(1), z.stride(2), z.stride(3), + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + D.stride(0) if D is not None else 0, + *((outz.stride(0), outz.stride(1), outz.stride(2), outz.stride(3)) if recompute_output else (0, 0, 0, 0)), + dz.stride(0), dz.stride(1), dz.stride(2), dz.stride(3), + dout_x.stride(0), dout_x.stride(1), dout_x.stride(2), dout_x.stride(3), + dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4], + *((ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3)) + if has_ddAcs else (0, 0, 0, 0)), + D is not None, + D.dim() == 2 if D is not None else True, + has_ddAcs, + BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16), + RECOMPUTE_OUTPUT=recompute_output, + ) + if D is not None: + BLOCK_SIZE_actual = _chunk_scan_bwd_dz_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual + dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) + if D.dim() == 1: + dD = rearrange(dD, "h 1 -> h") + return_vals = (dz, dout_x, dD, ddA_cumsum) if has_ddAcs else (dz, dout_x, dD) + return return_vals if not recompute_output else (*return_vals, outz) + + +def _chunk_scan_bwd_dstates(C, dA_cumsum, dout, seq_idx=None, dtype=None): + batch, seqlen, nheads, headdim = dout.shape + _, _, nchunks, chunk_size = dA_cumsum.shape + _, _, ngroups, dstate = C.shape + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + dtype = C.dtype if dtype is None else dtype + dprev_states = torch.empty(batch, nchunks, nheads, headdim, dstate, device=C.device, dtype=dtype) + grid_dstates = lambda META: (triton.cdiv(headdim, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(C.device.index): + _chunk_scan_bwd_dstates_kernel[grid_dstates]( + dout, C, dprev_states, dA_cumsum, seq_idx, + headdim, dstate, chunk_size, + batch, seqlen, nchunks, nheads // ngroups, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + C.stride(0), C.stride(1), C.stride(2), C.stride(3), + dprev_states.stride(0), dprev_states.stride(1), dprev_states.stride(2), dprev_states.stride(3), dprev_states.stride(4), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + HAS_SEQ_IDX=seq_idx is not None, + ) + return dprev_states + + +def _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, seq_idx=None, C=None, ngroups=1): + batch, nchunks, nheads, headdim, dstate = prev_states.shape + _, seqlen, _, _ = dout.shape + _, _, _, chunk_size = dA_cumsum.shape + assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert dout.shape == (batch, seqlen, nheads, headdim) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if C is not None: + assert C.shape == (batch, seqlen, ngroups, dstate) + C_strides = (C.stride(0), C.stride(1), C.stride(2), C.stride(3)) + ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) + ddA_cumsum_prev_strides = (ddA_cumsum_prev.stride(0), ddA_cumsum_prev.stride(2), ddA_cumsum_prev.stride(1), ddA_cumsum_prev.stride(3)) + else: + C_strides = (0, 0, 0, 0) + ddA_cumsum_prev = None + ddA_cumsum_prev_strides = (0, 0, 0, 0) + nheads_ngroups_ratio = nheads // ngroups + sm_count = torch.cuda.get_device_properties(dout.device).multi_processor_count + nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) + nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) + dC = torch.empty(batch, seqlen, nsplits, ngroups, dstate, device=dout.device, dtype=torch.float32) + grid_dc = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch * nchunks, nsplits * ngroups) + with torch.cuda.device(dout.device.index): + _chunk_scan_bwd_dc_kernel[grid_dc]( + dout, prev_states, C, dA_cumsum, seq_idx, dC, ddA_cumsum_prev, + chunk_size, dstate, headdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + prev_states.stride(0), prev_states.stride(1), prev_states.stride(2), prev_states.stride(3), prev_states.stride(4), + *C_strides, + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + dC.stride(0), dC.stride(1), dC.stride(2), dC.stride(3), dC.stride(4), + *ddA_cumsum_prev_strides, + HAS_DDA_CS=ddA_cumsum_prev is not None, + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) + dC = dC.sum(2) + return dC if C is None else (dC, ddA_cumsum_prev) + + +def _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, seq_idx=None, CB=None, ngroups=1): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert dout.shape == x.shape + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + if CB is not None: + assert CB.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + CB_strides = (CB.stride(0), CB.stride(1), CB.stride(2), CB.stride(3), CB.stride(4)) + BLOCK_SIZE_M_min = 16 + ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), + chunk_size, device=x.device, dtype=torch.float32) + ddA_cumsum_strides = (ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4)) + else: + CB_strides = (0, 0, 0, 0, 0) + ddA_cumsum = None + ddA_cumsum_strides = (0, 0, 0, 0, 0) + nheads_ngroups_ratio = nheads // ngroups + sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count + nheads_per_program = max(min(math.ceil(batch * nchunks * nheads / sm_count), nheads_ngroups_ratio), 1) + nsplits = triton.cdiv(nheads_ngroups_ratio, nheads_per_program) + dcb = torch.empty(batch, nchunks, nsplits, ngroups, chunk_size, chunk_size, device=x.device, dtype=torch.float32) + grid_dcb = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']), + batch * nchunks, nsplits * ngroups) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_dcb_kernel[grid_dcb]( + x, dout, CB, dt, dA_cumsum, seq_idx, dcb, ddA_cumsum, + chunk_size, headdim, + batch, seqlen, nheads, nheads_per_program, ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + *CB_strides, + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + dcb.stride(0), dcb.stride(1), dcb.stride(2), dcb.stride(3), dcb.stride(4), dcb.stride(5), + *ddA_cumsum_strides, + HAS_DDA_CS=ddA_cumsum is not None, + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) + dcb = dcb.sum(2) + if ddA_cumsum is not None: + BLOCK_SIZE_M_actual = _chunk_scan_bwd_dcb_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual + ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3) + return dcb if CB is None else (dcb, ddA_cumsum) + + +def _chunk_scan_bwd_dx(cb, x, dt, dA_cumsum, dout, D=None): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + ngroups = cb.shape[2] + assert nheads % ngroups == 0 + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == dt.shape + assert dout.shape == x.shape + # if D is not None: + # BLOCK_SIZE_M_min = 32 + # dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_M_min), batch, nchunks, nheads, headdim, device=D.device, dtype=torch.float32) + # else: + # dD = None + dx = torch.empty_like(x) + ddt = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) + grid_dx = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(headdim, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_dx_kernel[grid_dx]( + x, cb, dout, dt, dA_cumsum, D, dx, ddt, # dD, + chunk_size, headdim, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(-1), cb.stride(-2), + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + D.stride(0) if D is not None else 0, + dx.stride(0), dx.stride(1), dx.stride(2), dx.stride(3), + ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), + # dD.stride(1) if dD is not None else 0, dD.stride(2) if dD is not None else 0, dD.stride(3) if dD is not None else 0, dD.stride(4) if dD is not None else 0, dD.stride(0) if dD is not None else 0, + D is not None, + D.dim() == 2 if D is not None else True, + ) + # if D is not None: + # BLOCK_SIZE_actual = _chunk_scan_bwd_dx_kernel.best_config.kwargs["BLOCK_SIZE_M"] + # n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual + # dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) + return dx, ddt.to(dtype=dt.dtype) + + +def _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=None, subtract_ddtdt=True): + """Not numerically stable and should not be used. Leaving here for reference. + """ + + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert ddt.shape == dt.shape + assert out.shape == x.shape + assert dout.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + ddA_cumsum = torch.empty_like(dt) + grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) + if D is not None: # Triton gives wrong results if we write to the same location + BLOCK_SIZE_min = 32 + dD = torch.empty(triton.cdiv(chunk_size, BLOCK_SIZE_min), batch, nchunks, nheads, + headdim if D.dim() == 2 else 1, device=D.device, dtype=torch.float32) + else: + dD = None + dD_strides = ((dD.stride(0), dD.stride(1), dD.stride(2), dD.stride(3), dD.stride(4)) + if D is not None else (0, 0, 0, 0, 0)) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_ddAcs_unstable_kernel[grid_ddtcs]( + dout, out, dt, ddt, x, D, ddA_cumsum, dD, + chunk_size, headdim, + batch, seqlen, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + out.stride(0), out.stride(1), out.stride(2), out.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + ddt.stride(0), ddt.stride(2), ddt.stride(1), ddt.stride(3), + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + D.stride(0) if D is not None else 0, + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), + dD_strides[1], dD_strides[2], dD_strides[3], dD_strides[0], dD_strides[4], + D is not None, + D.dim() == 2 if D is not None else True, + subtract_ddtdt, + BLOCK_SIZE_N=max(triton.next_power_of_2(headdim), 16), + ) + if D is not None: + BLOCK_SIZE_actual = _chunk_scan_bwd_ddAcs_unstable_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_actual - 1) // BLOCK_SIZE_actual + dD = dD[:n_valid_blocks].sum(dim=(0, 1, 2)).to(dtype=D.dtype) + if D.dim() == 1: + dD = rearrange(dD, "h 1 -> h") + return ddA_cumsum, dD + + +def _chunk_scan_bwd_ddAcs_stable_old(x, dt, dA_cumsum, dout, cb): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dout.shape == x.shape + assert dA_cumsum.shape == dt.shape + ngroups = cb.shape[2] + assert nheads % ngroups == 0 + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + BLOCK_SIZE_M_min = 16 + ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), + chunk_size, device=x.device, dtype=torch.float32) + grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_ddAcs_stable_kernel_old[grid_ddtcs]( + x, dout, dt, dA_cumsum, cb, ddA_cumsum, + chunk_size, headdim, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4), + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + BLOCK_SIZE_N=max(triton.next_power_of_2(chunk_size), 16), + ) + BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel_old.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual + ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3) + return ddA_cumsum + + +def _chunk_scan_bwd_ddAcs_stable(x, dt, dA_cumsum, dout, cb): + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dout.shape == x.shape + assert dA_cumsum.shape == dt.shape + ngroups = cb.shape[2] + assert nheads % ngroups == 0 + assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size) + BLOCK_SIZE_M_min = 32 + ddA_cumsum = torch.empty(batch, nheads, nchunks, triton.cdiv(chunk_size, BLOCK_SIZE_M_min), + chunk_size, device=x.device, dtype=torch.float32) + grid_ddtcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']), batch * nchunks, nheads) + with torch.cuda.device(x.device.index): + _chunk_scan_bwd_ddAcs_stable_kernel[grid_ddtcs]( + x, dout, dt, dA_cumsum, cb, ddA_cumsum, + chunk_size, headdim, + batch, seqlen, nheads // ngroups, + x.stride(0), x.stride(1), x.stride(2), x.stride(3), + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + dt.stride(0), dt.stride(2), dt.stride(1), dt.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + cb.stride(0), cb.stride(1), cb.stride(2), cb.stride(3), cb.stride(4), + ddA_cumsum.stride(0), ddA_cumsum.stride(2), ddA_cumsum.stride(1), ddA_cumsum.stride(3), ddA_cumsum.stride(4), + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) + BLOCK_SIZE_M_actual = _chunk_scan_bwd_ddAcs_stable_kernel.best_config.kwargs["BLOCK_SIZE_M"] + n_valid_blocks = (chunk_size + BLOCK_SIZE_M_actual - 1) // BLOCK_SIZE_M_actual + ddA_cumsum = ddA_cumsum[:, :, :, :n_valid_blocks].sum(dim=3) + return ddA_cumsum + + +def _chunk_scan_bwd_ddAcs_prev(prev_states, C, dout, dA_cumsum, seq_idx=None): + batch, nchunks, nheads, headdim, dstate = prev_states.shape + _, seqlen, _, _ = dout.shape + _, _, _, chunk_size = dA_cumsum.shape + assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert dout.shape == (batch, seqlen, nheads, headdim) + ngroups = C.shape[2] + assert nheads % ngroups == 0 + assert C.shape == (batch, seqlen, ngroups, dstate) + if seq_idx is not None: + assert seq_idx.shape == (batch, seqlen) + ddA_cumsum_prev = torch.empty(batch, nheads, nchunks, chunk_size, device=dout.device, dtype=torch.float32) + grid_ddAcs = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(dstate, META['BLOCK_SIZE_N']), + batch * nchunks, nheads) + with torch.cuda.device(dout.device.index): + _chunk_scan_bwd_ddAcs_prev_kernel[grid_ddAcs]( + dout, prev_states, C, dA_cumsum, seq_idx, ddA_cumsum_prev, + chunk_size, dstate, headdim, + batch, seqlen, nchunks, nheads // ngroups, + dout.stride(0), dout.stride(1), dout.stride(2), dout.stride(3), + prev_states.stride(0), prev_states.stride(1), prev_states.stride(2), prev_states.stride(3), prev_states.stride(4), + C.stride(0), C.stride(1), C.stride(2), C.stride(3), + dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3), + *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)), + ddA_cumsum_prev.stride(0), ddA_cumsum_prev.stride(2), ddA_cumsum_prev.stride(1), ddA_cumsum_prev.stride(3), + HAS_SEQ_IDX=seq_idx is not None, + BLOCK_SIZE_K=max(triton.next_power_of_2(headdim), 16), + ) + return ddA_cumsum_prev + + +class ChunkScanFn(torch.autograd.Function): + + @staticmethod + def forward(ctx, B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): + # Check constraints. + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + assert B.shape == (batch, seqlen, ngroups, dstate) + _, _, nchunks, chunk_size = dt.shape + assert seqlen == nchunks * chunk_size + assert C.shape == B.shape + if z is not None: + assert z.shape == x.shape + if D is not None: + assert D.shape == (nheads, headdim) or D.shape == (nheads,) + assert dt.shape == (batch, nheads, nchunks, chunk_size) + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) + assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate) + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if x.stride(-1) != 1 and x.stride(1) != 1: # Either M or K dimension should be contiguous + x = x.contiguous() + if z is not None and z.stride(-1) != 1 and z.stride(1) != 1: # Either M or K dimension should be contiguous + z = z.contiguous() + if D is not None and D.stride(-1) != 1: + D = D.contiguous() + CB = _bmm_chunk_fwd(C, B, chunk_size) + out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, prev_states, D=D, z=z) + ctx.save_for_backward(out if z is None else out_x, B, C, CB, x, dt, dA_cumsum, prev_states, D, z) + return out + + @staticmethod + def backward(ctx, dout): + if dout.stride(-1) != 1: + dout = dout.contiguous() + out, B, C, CB, x, dt, dA_cumsum, prev_states, D, z = ctx.saved_tensors + batch, seqlen, nheads, headdim = x.shape + _, _, nchunks, chunk_size = dt.shape + _, _, ngroups, dstate = B.shape + assert dout.shape == (batch, seqlen, nheads, headdim) + if z is not None: + dz, dout, dD, ddA_cumsum = _chunk_scan_bwd_dz(x, z, out, dout, chunk_size=chunk_size, D=D) + else: + dz = None + dprev_states = _chunk_scan_bwd_dstates(C, dA_cumsum, dout, dtype=prev_states.dtype) + dC = _chunk_scan_bwd_dC(prev_states, dA_cumsum, dout, ngroups=ngroups) + dC = dC.to(C.dtype) + dCB = _chunk_scan_bwd_dcb(x, dt, dA_cumsum, dout, ngroups=ngroups) + dCB = dCB.to(CB.dtype) + dB = _bmm_chunk_bwd(C, dCB) + dC = _bmm_chunk_bwd(B, rearrange(dCB, "... l s -> ... s l"), residual=dC) + dx, ddt = _chunk_scan_bwd_dx(CB, x, dt, dA_cumsum, dout, D=D) + # Formula for ddA_cumsum, assuming out is the output of the forward pass before adding x * D. + # ddA_cumsum = torch.einsum("bclhp,bclhp->bhcl", out.float(), dout.float()) - ddt * dt + if z is not None: + ddA_cumsum -= ddt * dt + else: # If z is not None, we already calculated ddA_cumsum and dD when computing dz + ddA_cumsum, dD = _chunk_scan_bwd_ddAcs_unstable(x, dt, out, dout, ddt, D=D) + ddA_cumsum = ddA_cumsum.to(dA_cumsum.dtype) + return dB, dC, dx, ddt, ddA_cumsum, dprev_states, dD, dz + + +def chunk_scan(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): + """ + prev_states contains the initial_states at index 0, and the state for the next-to-last chunk at index -1. + Argument: + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + prev_states: (batch, nchunks, nheads, headdim, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + Return: + out: (batch, seqlen, nheads, headdim) + """ + return ChunkScanFn.apply(B, C, x, dt, dA_cumsum, prev_states, D, z) + + +def chunk_scan_ref(B, C, x, dt, dA_cumsum, prev_states, D=None, z=None): + """ + Argument: + B: (batch, seqlen, ngroups, dstate) + C: (batch, seqlen, ngroups, dstate) + x: (batch, seqlen, nheads, headdim) + dt: (batch, nheads, nchunks, chunk_size) + dA_cumsum: (batch, nheads, nchunks, chunk_size) + prev_states: (batch, nchunks, nheads, headdim, dstate) + D: (nheads, headdim) or (nheads,) + z: (batch, seqlen, nheads, headdim) + Return: + out: (batch, seqlen, nheads, headdim) + """ + batch, seqlen, nheads, headdim = x.shape + _, _, ngroups, dstate = B.shape + assert B.shape == (batch, seqlen, ngroups, dstate) + _, _, nchunks, chunk_size = dt.shape + assert seqlen == nchunks * chunk_size + assert C.shape == B.shape + B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) + C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups) + CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), + rearrange(B, "b (c s) h n -> b c s h n", c=nchunks)) + # (batch, nheads, nchunks, chunksize, chunksize) + dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] + decay = torch.exp(dt_segment_sum) + scores_decay = CB * rearrange(decay, "b h c l s -> b c h l s") + causal_mask = torch.tril(torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) + scores_decay = scores_decay.masked_fill(~causal_mask, 0) + out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype), + rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)) + state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) + out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), + prev_states.to(C.dtype)) * state_decay_out + out = out + out_prev + out = rearrange(out, "b c l h p -> b (c l) h p") + if D is not None: + if D.dim() == 1: + D = rearrange(D, "h -> h 1") + out = out + x * D + return out if z is None else out * F.silu(z)