Spaces:
Sleeping
Sleeping
File size: 3,501 Bytes
8b19012 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import torch
from torch import Tensor
import triton
import triton.language as tl
# Triton kernel: for each variable-length sequence, copy its last (up to)
# `state_len` tokens from the packed token buffer X into a per-sequence
# state tensor, right-aligned, zero-padding positions that fall before the
# sequence start.
# Grid: (cdiv(dim, BLOCK_N), cdiv(state_len, BLOCK_M), batch).
@triton.jit
def _causal_conv1d_varlen_states(
    X,                    # ptr to packed tokens, logically (total_tokens, dim)
    CU_SEQLENS,           # ptr to (batch + 1,) cumulative sequence lengths, starting at 0
    STATES,               # ptr to output states, logically (batch, state_len, dim) via the strides below
    state_len,            # number of trailing tokens to keep per sequence
    dim,                  # feature dimension
    stride_x_seqlen, stride_x_dim,
    stride_states_batch, stride_states_seqlen, stride_states_dim,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr
):
    # One program per (dim-tile, state-row-tile, sequence).
    batch_idx = tl.program_id(2)
    STATES += batch_idx * stride_states_batch
    end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
    # Clamp the window start so we never read tokens of the previous sequence.
    start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
    # Token rows, counted backwards from the sequence end so the copied
    # window is right-aligned in the state buffer.
    rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
    # Rows before start_idx (window shorter than state_len) load as 0,
    # which produces the zero padding in the output.
    x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
                mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
                other=0)
    # Destination rows inside the state buffer, mirrored from the token rows.
    rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
    # rows_states may go negative when state_len is not a multiple of BLOCK_M;
    # the mask drops those lanes.
    tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
             x,
             mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    dim = x.shape[1]
    num_seqs = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # Allocate as (batch, state_len, dim) and hand the kernel the matching
    # strides; the transposed view gives callers (batch, dim, state_len).
    states = torch.empty(
        num_seqs, state_len, dim, dtype=x.dtype, device=x.device
    ).transpose(1, 2)
    block_m = min(triton.next_power_of_2(state_len), 16)
    block_n = min(triton.next_power_of_2(dim), 256)
    launch_grid = (
        triton.cdiv(dim, block_n),
        triton.cdiv(state_len, block_m),
        num_seqs,
    )
    # Pin the CUDA device so the kernel launches where x lives.
    with torch.cuda.device(x.device.index):
        _causal_conv1d_varlen_states[launch_grid](
            x,
            cu_seqlens,
            states,
            state_len,
            dim,
            x.stride(0), x.stride(1),
            states.stride(0), states.stride(2), states.stride(1),
            BLOCK_M=block_m, BLOCK_N=block_n,
        )
    return states
def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
    """
    Forward pass only, does not support backward pass.
    Parameters:
        x: (total_tokens, dim)
        cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
        state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
            If some of those elements belong to a different sequence, the value of the states will be zero.
    Return:
        states: (batch, dim, state_len)
    """
    _, dim = x.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()
    # Zero-init so sequences shorter than state_len are left-padded with 0.
    states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
    for i in range(batch):
        end_idx = cu_seqlens[i + 1]
        # Never reach back into the previous sequence.
        start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
        n = int(end_idx - start_idx)
        # Guard n == 0: `states[i, :, -0:]` would select the WHOLE last axis
        # and fail to accept an empty (dim, 0) assignment.
        if n > 0:
            # Right-align the last n tokens of the sequence in the state.
            states[i, :, state_len - n:] = x[start_idx:end_idx].T
    return states
|