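// NOTE(review): the line below appears to be hosting-page status text captured with the file, not source code.
// Spaces: Sleeping Sleeping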
/******************************************************************************
 * Copyright (c) 2023, Tri Dao.
 ******************************************************************************/
/// Compile-time configuration bundle for causal_conv1d_update_kernel.
///
/// Template parameters:
///   kNThreads_  threads per block (also used as the kernel's launch bound).
///   kWidth_     number of filter taps of the convolution.
///   input_t_    element type used for the input, the state and the output pointers.
///   weight_t_   element type used for the weight and bias pointers.
template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
struct Causal_conv1d_update_kernel_traits {
    using input_t = input_t_;
    using weight_t = weight_t_;
    static constexpr int kNThreads = kNThreads_;
    static constexpr int kWidth = kWidth_;
    // Size of one input element in bytes; only 2-byte and 4-byte element types are accepted.
    static constexpr int kNBytes = sizeof(input_t);
    static_assert(kNBytes == 2 || kNBytes == 4, "input_t must be a 2-byte or 4-byte type");
};
/// Incremental causal depthwise 1-D convolution: consumes params.seqlen new inputs per
/// (batch, channel) pair, updates the rolling convolution state in place and writes one
/// output per new input.
///
/// Thread mapping: blockIdx.x selects the batch element and
/// blockIdx.y * kNThreads + threadIdx.x selects the channel. Each thread processes its
/// (batch, channel) pair sequentially; threads past params.dim exit immediately. The kernel
/// uses no shared memory and no barriers.
///
/// State handling:
///   kIsCircularBuffer == false: the state is shifted towards index 0 by params.seqlen
///     entries and the new inputs are stored at its tail.
///   kIsCircularBuffer == true: the state is a ring of params.conv_state_len entries; the
///     first new input is written at params.cache_seqlens[batch] % conv_state_len and the
///     write position wraps around.
///
/// NOTE(review): both paths read the kWidth - 1 most recent state entries, which assumes
/// params.conv_state_len >= kWidth - 1 - confirm that the caller enforces this.
template<typename Ktraits, bool kIsCircularBuffer>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_update_kernel(ConvParamsBase params) {
    constexpr int kWidth = Ktraits::kWidth;
    constexpr int kNThreads = Ktraits::kNThreads;
    using input_t = typename Ktraits::input_t;
    using weight_t = typename Ktraits::weight_t;

    const int tidx = threadIdx.x;
    const int batch_id = blockIdx.x;
    const int channel_id = blockIdx.y * kNThreads + tidx;
    if (channel_id >= params.dim) return;

    // Widen the indices before multiplying by the strides so the base offsets cannot wrap
    // around if the stride fields are 32-bit.
    const long long batch_off = batch_id;
    const long long channel_off = channel_id;

    const input_t *x = reinterpret_cast<const input_t *>(params.x_ptr)
        + batch_off * params.x_batch_stride + channel_off * params.x_c_stride;
    input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
        + batch_off * params.conv_state_batch_stride + channel_off * params.conv_state_c_stride;
    const weight_t *weight = reinterpret_cast<const weight_t *>(params.weight_ptr)
        + channel_off * params.weight_c_stride;
    input_t *out = reinterpret_cast<input_t *>(params.out_ptr)
        + batch_off * params.out_batch_stride + channel_off * params.out_c_stride;
    // A null bias pointer means "no bias".
    float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);

    int state_len = params.conv_state_len;
    int advance_len = params.seqlen;
    // Ring-buffer position at which the first new input will be stored (unused otherwise).
    int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
    // Start reading kWidth - 1 entries before that position, wrapping below zero.
    int update_idx = cache_seqlen - (kWidth - 1);
    update_idx = update_idx < 0 ? update_idx + state_len : update_idx;

    // Fixed-trip loops over the per-thread arrays are explicitly unrolled so every array
    // index is a compile-time constant.
    float weight_vals[kWidth] = {0};
    #pragma unroll
    for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }

    // Sliding window: slots [0, kWidth - 2] hold history, slot kWidth - 1 holds the current input.
    float x_vals[kWidth] = {0};
    if constexpr (!kIsCircularBuffer) {
        // Shift the older part of the state towards index 0 by advance_len entries.
        for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
            conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
        }
        // Handle the last kWidth - 1 entries separately: each one is read first (it seeds
        // the window) and only then copied to its shifted slot, if that slot is in range.
        #pragma unroll
        for (int i = 0; i < kWidth - 1; ++i) {
            input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
            if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
                conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
            }
            x_vals[i] = float(state_val);
        }
    } else {
        // Read the kWidth - 1 entries preceding the write position, wrapping at state_len.
        // After this loop update_idx is back at the write position.
        #pragma unroll
        for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
            input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
            x_vals[i] = float(state_val);
        }
    }

    for (int i = 0; i < params.seqlen; ++i) {
        input_t x_val = x[i * params.x_l_stride];
        if constexpr (!kIsCircularBuffer) {
            // Store the new input at the tail of the state, skipping slots before index 0.
            if (i < advance_len && state_len - advance_len + i >= 0) {
                conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
            }
        } else {
            conv_state[update_idx * params.conv_state_l_stride] = x_val;
            ++update_idx;
            update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
        }
        x_vals[kWidth - 1] = float(x_val);

        // Dot product of the window with the filter taps, accumulated in float.
        float out_val = bias_val;
        #pragma unroll
        for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
        // Optional SiLU: v * sigmoid(v) written as v / (1 + exp(-v)).
        if (params.silu_activation) { out_val = out_val / (1.f + expf(-out_val)); }
        out[i * params.out_l_stride] = input_t(out_val);

        // Shift the input buffer by 1
        #pragma unroll
        for (int k = 0; k < kWidth - 1; ++k) { x_vals[k] = x_vals[k + 1]; }
    }
}
/// Launches causal_conv1d_update_kernel for one (kNThreads, kWidth, input_t, weight_t)
/// configuration on the given stream.
///
/// Grid layout: grid.x = params.batch, grid.y = ceil(params.dim / kNThreads); each block
/// has kNThreads threads and uses no dynamic shared memory. The circular-buffer variant of
/// the kernel is chosen when params.cache_seqlens is non-null.
///
/// The launch is asynchronous with respect to the host; only launch-time errors are
/// reported here, through C10_CUDA_KERNEL_LAUNCH_CHECK.
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
    using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
    // A zero-sized grid dimension is an invalid launch configuration; with no batch
    // elements or no channels there is nothing to compute.
    if (params.batch == 0 || params.dim == 0) { return; }
    dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
    auto kernel = params.cache_seqlens == nullptr
        ? &causal_conv1d_update_kernel<Ktraits, false>
        : &causal_conv1d_update_kernel<Ktraits, true>;
    kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
}
/// Dispatches the runtime filter width (params.width) to the matching compile-time kernel
/// instantiation. Widths 2, 3 and 4 are supported, each launched with 64 threads per block.
///
/// Any other width raises an error through TORCH_CHECK instead of returning without
/// launching anything, which would leave the output buffer unwritten.
template<typename input_t, typename weight_t>
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
    if (params.width == 2) {
        causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
    } else if (params.width == 3) {
        causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
    } else if (params.width == 4) {
        causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
    } else {
        TORCH_CHECK(false, "causal_conv1d_update only supports width 2, 3 or 4, got ", params.width);
    }
}
// Explicit instantiations for every <input_t, weight_t> combination of
// {float, at::Half, at::BFloat16} x {float, at::Half, at::BFloat16}, so the templates
// defined in this file are emitted here for code in other translation units to link against.

// weight_t = float
template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
// weight_t = at::Half
template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
// weight_t = at::BFloat16
template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);