/* CUDA extension implementing the anti-aliased activation forward pass.
 * The headers below are the ones this translation unit needs to compile;
 * DISPATCH_FLOAT_HALF_AND_BFLOAT is assumed to be provided by a project
 * type-shim header (here included as "type_shim.h"). */
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

#include <cstdint>

#include "type_shim.h" // assumed project header providing DISPATCH_FLOAT_HALF_AND_BFLOAT

namespace { |
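// NOTE: the commented-out block below contains vectorized load/store and
// warp-reduction helpers that are not used by the current kernel; it is kept
// for reference.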
|
|
|
/* |
|
template <typename Datatype, int ELEMENTS_PER_LDG> |
|
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src); |
|
|
|
template <> |
|
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; } |
|
|
|
template <> |
|
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); } |
|
|
|
template <> |
|
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; } |
|
|
|
template <> |
|
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); } |
|
|
|
template <> |
|
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; } |
|
|
|
template <> |
|
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) {*((half2*) dst) = *((half2*) src); } |
|
|
|
int log2_ceil(int value) { |
|
int log2_value = 0; |
|
while ((1 << log2_value) < value) ++log2_value; |
|
return log2_value; |
|
} |
|
|
|
template<typename T> |
|
struct Add { |
|
__device__ __forceinline__ T operator()(T a, T b) const { |
|
return a + b; |
|
} |
|
}; |
|
|
|
template<typename T> |
|
struct Max { |
|
__device__ __forceinline__ T operator()(T a, T b) const { |
|
return a < b ? b : a; |
|
} |
|
}; |
|
|
|
template <typename T> |
|
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) |
|
{ |
|
|
|
#if CUDA_VERSION >= 9000
    return __shfl_xor_sync(mask, value, laneMask, width);
#else
    return __shfl_xor(value, laneMask, width);
#endif
|
|
|
} |
|
|
|
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp> |
|
__device__ __forceinline__ void warp_reduce(acc_t* sum) { |
|
ReduceOp<acc_t> r; |
|
|
|
for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) { |
|
|
|
for (int i = 0; i < WARP_BATCH; ++i) { |
|
acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE); |
|
sum[i] = r(sum[i], b); |
|
} |
|
} |
|
} |
|
*/ |
|
|
|
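// Forward kernel for the anti-aliased activation. Each thread:
//   1. loads BUFFER_SIZE input samples plus HALF_FILTER_SIZE samples of context
//      on each side (with replication padding at the sequence edges),
//   2. zero-stuffs them to 2x length and convolves with the FILTER_SIZE-tap
//      low-pass filter `ftr` (2x upsampling),
//   3. applies the Snake-style activation y = x + (1 / beta) * sin^2(alpha * x),
//   4. writes 2 * BUFFER_SIZE output samples.
// alpha and beta hold one (log-space) value per channel.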
template <typename input_t, typename output_t, typename acc_t> |
|
__global__ void anti_alias_activation_forward( |
|
output_t *dst, |
|
const input_t *src, |
|
const input_t *ftr, |
|
const input_t *alpha, |
|
const input_t *beta, |
|
int batch_size, |
|
int channels, |
|
int seq_len) |
|
{ |
|
|
|
constexpr int ELEMENTS_PER_LDG_STG = 1; // unused in this kernel

constexpr int BUFFER_SIZE = 32; // input samples processed per thread

constexpr int FILTER_SIZE = 12; // number of low-pass FIR filter taps

constexpr int HALF_FILTER_SIZE = 6; // FILTER_SIZE / 2

constexpr int REPLICATION_PAD = 5; // replication padding on each side of the sequence
|
|
|
|
|
|
|
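// Grid mapping: blockIdx.x tiles the sequence in chunks of 128 * BUFFER_SIZE
// samples (128 = threads per block, which must match the launch configuration),
// blockIdx.y indexes the channel and blockIdx.z the batch.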
int block_offset = (blockIdx.x * 128 * BUFFER_SIZE + seq_len * (blockIdx.y + gridDim.y * blockIdx.z)); |
|
int local_offset = threadIdx.x * BUFFER_SIZE; |
|
int seq_offset = blockIdx.x * 128 * BUFFER_SIZE + local_offset; |
|
|
|
|
|
|
|
|
|
|
|
|
|
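// Output-side offsets: the output sequence is 2x the input length, so every
// offset is doubled.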
int output_seq_len = seq_len * 2;

int output_block_offset = (blockIdx.x * 128 * BUFFER_SIZE * 2 + output_seq_len * (blockIdx.y + gridDim.y * blockIdx.z));

int output_local_offset = threadIdx.x * BUFFER_SIZE * 2;

int output_seq_offset = blockIdx.x * 128 * BUFFER_SIZE * 2 + output_local_offset;
|
|
|
// Boundary values of this (batch, channel) row, used for replication padding.
const input_t *seq_base_ptr = src + (seq_len * (blockIdx.y + gridDim.y * blockIdx.z));

input_t seq_left_most_value = seq_base_ptr[0];

input_t seq_right_most_value = seq_base_ptr[seq_len - 1];
|
|
|
// Advance the input and output pointers to this thread's tile.
src += block_offset + local_offset;

dst += output_block_offset + output_local_offset;

// One alpha/beta value per channel, stored in log space.
alpha = alpha + blockIdx.y;

input_t alpha_val = expf(alpha[0]);

beta = beta + blockIdx.y;

input_t beta_val = expf(beta[0]);
|
|
|
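// `elements`: zero-stuffed 2x-upsampled input tile (even slots hold samples,
// odd slots stay zero), sized to leave room for a filter halo on each side.
// `intermediates`: the low-pass-filtered and activated result.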
input_t elements[2*FILTER_SIZE+2*BUFFER_SIZE] = {0}; |
|
input_t intermediates[2*FILTER_SIZE+2*BUFFER_SIZE] = {0}; |
|
|
|
input_t filter[FILTER_SIZE]; |
|
|
|
|
|
|
|
|
|
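// Load the low-pass filter taps into registers.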
for (int it = 0; it < FILTER_SIZE; it+=1) { |
|
filter[it] = ftr[it]; |
|
} |
|
|
|
|
|
|
|
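// Gather BUFFER_SIZE samples plus HALF_FILTER_SIZE samples of context on each
// side into the even slots of `elements`. Indices within REPLICATION_PAD
// outside the sequence replicate the edge value; the factor of 2 compensates
// for the gain lost by zero-stuffing.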
for (int it = -HALF_FILTER_SIZE; it < BUFFER_SIZE + HALF_FILTER_SIZE ; it+=1) { |
|
int element_index = seq_offset + it; |
|
if ((element_index < 0) && (element_index >= -REPLICATION_PAD)) { |
|
elements[2*(HALF_FILTER_SIZE+it)] = 2*seq_left_most_value; |
|
} |
|
if ((element_index >= seq_len) && (element_index < seq_len + REPLICATION_PAD)) { |
|
elements[2*(HALF_FILTER_SIZE+it)] = 2*seq_right_most_value; |
|
} |
|
if ((element_index >= 0) && (element_index < seq_len)) { |
|
elements[2*(HALF_FILTER_SIZE+it)] = 2*src[it]; |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
|
|
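// FIR low-pass filtering of the zero-stuffed signal (interpolation filter).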
for (int it = 0; it < (2 * BUFFER_SIZE + 2*FILTER_SIZE); it+=1) { |
|
input_t acc = 0.0; |
|
|
|
int element_index = output_seq_offset + it; |
|
|
|
for (int f_idx = 0; f_idx < FILTER_SIZE; f_idx+=1){ |
|
if ((element_index + f_idx) >= 0){ |
|
acc += filter[f_idx] * elements[it+f_idx]; |
|
} |
|
} |
|
intermediates[it] = acc; |
|
} |
|
|
|
// Snake-style activation on the upsampled signal:
// y = x + (1 / beta) * sin^2(alpha * x), with a small epsilon to avoid division by zero.
double no_div_by_zero = 0.000000001;

for (int it = 0; it < FILTER_SIZE + 2 * BUFFER_SIZE; it++) {

intermediates[it] += (1.0 / (beta_val + no_div_by_zero)) * sinf(intermediates[it] * alpha_val) * sinf(intermediates[it] * alpha_val);

}
|
|
|
|
|
|
|
|
|
// Write out 2 * BUFFER_SIZE output samples; the + HALF_FILTER_SIZE offset
// compensates for the filter delay, and the bounds check guards the tail of
// the last block.
for (int it = 0; it < 2 * BUFFER_SIZE; it += 1) {

int element_index = output_seq_offset + it;

if (element_index < output_seq_len) {

dst[it] = intermediates[it + HALF_FILTER_SIZE];

}

}
|
} |
|
|
|
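// Launch helper: one thread block covers seq_len_per_block input samples of a
// single (batch, channel) row; the grid is (sequence tiles, channels, batches).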
template<typename input_t, typename output_t, typename acc_t> |
|
void dispatch_anti_alias_activation_forward( |
|
output_t *dst, |
|
const input_t *src, |
|
const input_t *ftr, |
|
const input_t *alpha, |
|
const input_t *beta, |
|
int batch_size, |
|
int channels, |
|
int seq_len) |
|
{ |
|
if (seq_len == 0) { |
|
return; |
|
} else { |
|
|
|
constexpr int threads_per_block = 128;

constexpr int seq_len_per_block = 4096; // threads_per_block * BUFFER_SIZE (128 * 32)

int blocks_per_seq_len = (seq_len + seq_len_per_block - 1) / seq_len_per_block;
|
dim3 blocks(blocks_per_seq_len, channels, batch_size); |
|
dim3 threads(threads_per_block, 1, 1); |
|
|
|
anti_alias_activation_forward<input_t, output_t, acc_t> |
|
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(dst, src, ftr, alpha, beta, batch_size, channels, seq_len); |
|
} |
|
} |
|
} |
|
|
|
namespace anti_alias_activation { |
|
|
|
torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& filter, torch::Tensor const& alpha, torch::Tensor const& beta) |
|
{ |
|
// Input is a 3-D tensor with dimensions [batches, channels, seq_len].
|
const int batches = input.size(0); |
|
const int channels = input.size(1); |
|
const int seq_len = input.size(2); |
|
|
|
|
|
auto act_options = input.options().requires_grad(false); |
|
int output_seq_len = seq_len * 2; // the kernel upsamples the sequence by 2x
|
|
|
torch::Tensor anti_alias_activation_results = |
|
torch::empty({batches, channels, output_seq_len}, act_options); |
|
|
|
|
|
void* input_ptr = static_cast<void*>(input.data_ptr()); |
|
void* filter_ptr = static_cast<void*>(filter.data_ptr()); |
|
void* alpha_ptr = static_cast<void*>(alpha.data_ptr()); |
|
void* beta_ptr = static_cast<void*>(beta.data_ptr()); |
|
void* anti_alias_activation_results_ptr = static_cast<void*>(anti_alias_activation_results.data_ptr()); |
|
|
|
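// DISPATCH_FLOAT_HALF_AND_BFLOAT (assumed to come from the project's type-shim
// header) instantiates the call below with scalar_t set to float, half, or
// bfloat16 to match the input dtype.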
DISPATCH_FLOAT_HALF_AND_BFLOAT( |
|
input.scalar_type(), |
|
"dispatch anti alias activation_forward", |
|
dispatch_anti_alias_activation_forward<scalar_t, scalar_t, float>( |
|
reinterpret_cast<scalar_t*>(anti_alias_activation_results_ptr), |
|
reinterpret_cast<const scalar_t*>(input_ptr), |
|
reinterpret_cast<const scalar_t*>(filter_ptr), |
|
reinterpret_cast<const scalar_t*>(alpha_ptr), |
|
reinterpret_cast<const scalar_t*>(beta_ptr), |
|
batches, |
|
channels, |
|
seq_len); |
|
); |
|
return anti_alias_activation_results; |
|
} |
|
} |
|
|