import torch
from torch import Tensor

import triton
import triton.language as tl

@triton.jit
def _causal_conv1d_varlen_states(
    X,
    CU_SEQLENS,
    STATES,
    state_len,
    dim,
    stride_x_seqlen, stride_x_dim,
    stride_states_batch, stride_states_seqlen, stride_states_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # One program per (dim tile, state_len tile, batch element).
    batch_idx = tl.program_id(2)
    STATES += batch_idx * stride_states_batch
    end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
    # Clamp the window start so we never read tokens from the previous sequence.
    start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
    # Walk the last state_len rows of x backwards from the end of this sequence.
    rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    # Rows before start_idx load 0, so short sequences yield zero-padded states.
    x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
                mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
                other=0)
    # Right-align the copied rows within the state buffer.
    rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
             x,
             mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))

def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Forward pass only; does not support a backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1,). Cumulative sequence lengths, starting from 0; must already
            be sorted (non-decreasing).
        state_len: int. For each sequence, how many of its last elements to copy from x into
            its state. If a sequence is shorter than state_len, the leading entries of its
            state are zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # Allocate as (batch, state_len, dim) and transpose, so the returned
    # (batch, dim, state_len) tensor is contiguous along dim.
    states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    BLOCK_M = min(triton.next_power_of_2(state_len), 16)
    BLOCK_N = min(triton.next_power_of_2(dim), 256)
    grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
    with torch.cuda.device(x.device.index):
        _causal_conv1d_varlen_states[grid](
            x,
            cu_seqlens,
            states,
            state_len,
            dim,
            x.stride(0), x.stride(1),
            states.stride(0), states.stride(2), states.stride(1),
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
        )
    return states
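

def _example_usage():
    # Illustrative sketch, not part of the original module: pack two sequences
    # (lengths 3 and 7) into one (total_tokens, dim) tensor and grab the last 4
    # tokens of each as conv states. Assumes a CUDA device, since the Triton
    # kernel is GPU-only.
    x = torch.randn(10, 64, device="cuda")
    cu_seqlens = torch.tensor([0, 3, 10], device="cuda", dtype=torch.int32)
    states = causal_conv1d_varlen_states(x, cu_seqlens, state_len=4)
    # states has shape (2, 64, 4); sequence 0 has only 3 tokens, so
    # states[0, :, 0] is zero and states[0, :, 1:] equals x[0:3].T.
    return states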

def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Reference implementation. Forward pass only; does not support a backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1,). Cumulative sequence lengths, starting from 0; must already
            be sorted (non-decreasing).
        state_len: int. For each sequence, how many of its last elements to copy from x into
            its state. If a sequence is shorter than state_len, the leading entries of its
            state are zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # zeros (not empty), so sequences shorter than state_len stay zero-padded on the left.
    states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    for i in range(batch):
        end_idx = cu_seqlens[i + 1]
        start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
        # Right-align the last (end_idx - start_idx) tokens within the state.
        states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
    return states
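

if __name__ == "__main__":
    # Quick consistency check (illustrative, not from the original file): compare
    # the Triton kernel against the reference on random ragged data, including
    # sequences shorter than state_len. Requires a CUDA device.
    torch.manual_seed(0)
    lengths = torch.tensor([5, 1, 12], dtype=torch.int32, device="cuda")
    cu_seqlens = torch.cat([lengths.new_zeros(1), lengths.cumsum(0).to(torch.int32)])
    x = torch.randn(int(lengths.sum()), 32, device="cuda")
    state_len = 8
    out = causal_conv1d_varlen_states(x, cu_seqlens, state_len)
    ref = causal_conv1d_varlen_states_ref(x, cu_seqlens, state_len)
    torch.testing.assert_close(out, ref)
    print("Triton kernel matches reference, states shape:", tuple(out.shape))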