#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
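
// atomicAdd on double is only provided by hardware on compute capability 6.0
// and newer; emulate it with a compare-and-swap loop on older GPUs.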
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
__device__ double atomicAdd(
  double* address,
  double val
) {
  unsigned long long int* address_as_ull = (unsigned long long int*)address;
  unsigned long long int old = *address_as_ull, assumed;

  do {
    assumed = old;
    old = atomicCAS(
      address_as_ull,
      assumed,
      __double_as_longlong(val + __longlong_as_double(assumed))
    );
  } while (assumed != old);

  return __longlong_as_double(old);
}
#endif
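
// Quantized matrix-vector product kernels for 2-, 3-, 4- and 8-bit weights.
// Weights are packed LSB-first into int32 words, one packed column of `mat`
// per output column. For group g = row / groupsize, a packed weight q and its
// packed zero point z are dequantized as scale[g][w] * q - scale[g][w] * (z + 1),
// and each kernel accumulates mul[b][w] += dequantized_weight * vec[b][row].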
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
);

template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
);

template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
);

template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
);
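
// Each thread block covers BLOCKWIDTH consecutive input rows and BLOCKWIDTH
// output columns (one column per thread). BLOCKHEIGHTn is the number of packed
// 32-bit words that hold BLOCKWIDTH n-bit weights: BLOCKWIDTH * n / 32.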
const int BLOCKWIDTH = 256;
const int BLOCKHEIGHT2 = 16;
const int BLOCKHEIGHT3 = 24;
const int BLOCKHEIGHT4 = 32;
const int BLOCKHEIGHT8 = 64;

__device__ inline unsigned int as_unsigned(int i) {
  return *reinterpret_cast<unsigned int*>(&i);
}
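
// Host-side launchers. The grid is (ceil(height / BLOCKHEIGHTn),
// ceil(width / BLOCKWIDTH), batch) with BLOCKWIDTH threads per block. The
// kernels accumulate into `mul` with atomicAdd rather than overwrite it, so
// the output tensor is expected to arrive zero-initialized.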
void vecquant2matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  int groupsize
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.scalar_type(), "vecquant2matmul_cuda", ([&] {
      VecQuant2MatMulKernel<<<blocks, threads>>>(
        vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
        scales.data_ptr<scalar_t>(), zeros.data_ptr<int>(),
        batch, vec_height, height, width, zero_width, groupsize
      );
    })
  );
}
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT2 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
  __syncthreads();
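
  // Each thread accumulates one output column w. `i` walks down the packed
  // column (16 two-bit weights per 32-bit word), g_h is the unpacked row index
  // at the top of this tile, and z_w / z_mod locate this column's 2-bit zero
  // point inside its packed `zeros` word.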
  scalar_t res = 0;
  int i = width * h + w;
  int g_h = h * 16;
  int k = 0;

  int z_w = w / 16;
  int z_mod = (w % 16) * 2;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);

    res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 4) & 0x3) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 6) & 0x3) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp >> 8) & 0x3) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp >> 10) & 0x3) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp >> 12) & 0x3) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp >> 14) & 0x3) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp >> 16) & 0x3) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp >> 18) & 0x3) - zero) * blockvec[k + 9];
    res += (scale * scalar_t((tmp >> 20) & 0x3) - zero) * blockvec[k + 10];
    res += (scale * scalar_t((tmp >> 22) & 0x3) - zero) * blockvec[k + 11];
    res += (scale * scalar_t((tmp >> 24) & 0x3) - zero) * blockvec[k + 12];
    res += (scale * scalar_t((tmp >> 26) & 0x3) - zero) * blockvec[k + 13];
    res += (scale * scalar_t((tmp >> 28) & 0x3) - zero) * blockvec[k + 14];
    res += (scale * scalar_t((tmp >> 30) & 0x3) - zero) * blockvec[k + 15];

    i += width;
    k += 16;
  }

  atomicAdd(&mul[b * width + w], res);
}
void vecquant3matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  int groupsize
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.scalar_type(), "vecquant3matmul_cuda", ([&] {
      VecQuant3MatMulKernel<<<blocks, threads>>>(
        vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
        scales.data_ptr<scalar_t>(), zeros.data_ptr<int>(),
        batch, vec_height, height, width, zero_width, groupsize
      );
    })
  );
}
template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT3 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
  __syncthreads();
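
  // 3-bit weights do not align to 32-bit words: every 32 weights span three
  // consecutive packed words, so each loop iteration consumes three words and
  // 32 entries of blockvec (10 values, 1 straddling value, 10, 1 straddling,
  // 10). Likewise a column's 3-bit zero point may straddle two words of
  // `zeros`; z_w / z_bit locate it, and columns with w % 32 == 10 or 21 are
  // the two straddling cases handled separately inside the loop.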
  scalar_t res = 0;
  int i = width * h + w;
  int g_h = (h / 3) * 32;
  int k = 0;

  int z_w = (w / 32) * 3;
  int z_mod = w % 32;
  int z_bit;

  if (z_mod != 10){
    if (z_mod != 21){
      z_bit = z_mod;
      if (z_bit > 21){
        z_bit -= 22;
        z_bit *= 3;
        z_bit += 2;
        z_w += 2;
      } else if (z_bit > 10){
        z_bit -= 11;
        z_bit *= 3;
        z_bit += 1;
        z_w += 1;
      } else {
        z_bit *= 3;
      }
    } else {
      z_w += 1;
    }
  }

  unsigned int tmp1;
  unsigned int tmp2;
  unsigned int tmp;
  unsigned int z_tmp;

  while (k < BLOCKWIDTH) {
    tmp1 = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero;
    if (z_mod == 10) {
      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
      zero = scale * scalar_t((z_tmp) + 1);
    } else if (z_mod == 21){
      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
      zero = scale * scalar_t((z_tmp) + 1);
    } else {
      zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
    }

    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];

    i += width;
    tmp2 = as_unsigned(mat[i]);
    tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4);
    tmp2 >>= 1;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;

    res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp2 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp2 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp2 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp2 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9];

    i += width;
    tmp1 = as_unsigned(mat[i]);
    tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6);
    tmp1 >>= 2;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;

    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];

    i += width;
    k += 10;
  }

  atomicAdd(&mul[b * width + w], res);
}
void vecquant4matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  int groupsize
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.scalar_type(), "vecquant4matmul_cuda", ([&] {
      VecQuant4MatMulKernel<<<blocks, threads>>>(
        vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
        scales.data_ptr<scalar_t>(), zeros.data_ptr<int>(),
        batch, vec_height, height, width, zero_width, groupsize
      );
    })
  );
}
template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
  __syncthreads();
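
  // 4-bit variant of the loop in VecQuant2MatMulKernel: eight weights per
  // packed 32-bit word and eight 4-bit zero points per word of `zeros`.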
  scalar_t res = 0;
  int i = width * h + w;
  int g_h = h * 8;
  int k = 0;

  int z_w = w / 8;
  int z_mod = (w % 8) * 4;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);

    res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 8) & 0xF) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 12) & 0xF) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp >> 16) & 0xF) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp >> 20) & 0xF) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp >> 24) & 0xF) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp >> 28) & 0xF) - zero) * blockvec[k + 7];

    i += width;
    k += 8;
  }

  atomicAdd(&mul[b * width + w], res);
}
void vecquant8matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  int groupsize
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.scalar_type(), "vecquant8matmul_cuda", ([&] {
      VecQuant8MatMulKernel<<<blocks, threads>>>(
        vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
        scales.data_ptr<scalar_t>(), zeros.data_ptr<int>(),
        batch, vec_height, height, width, zero_width, groupsize
      );
    })
  );
}
template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width,
  int groupsize
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT8 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
  __syncthreads();
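
  // 8-bit variant: four weights per packed 32-bit word and four 8-bit zero
  // points per word of `zeros`.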
  scalar_t res = 0;
  int i = width * h + w;
  int g_h = h * 4;
  int k = 0;

  int z_w = w / 4;
  int z_mod = (w % 4) * 8;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);

    res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 16) & 0xFF) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 24) & 0xFF) - zero) * blockvec[k + 3];

    i += width;
    k += 4;
  }

  atomicAdd(&mul[b * width + w], res);
}
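
// Illustrative only: a minimal sketch of how the launchers above could be
// exposed as a Torch extension. The real project likely registers its Python
// bindings in a separate translation unit, and the exported names below are
// assumptions, so the sketch is left commented out.
//
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//   m.def("vecquant2matmul", &vecquant2matmul_cuda, "2-bit quantized matvec (CUDA)");
//   m.def("vecquant3matmul", &vecquant3matmul_cuda, "3-bit quantized matvec (CUDA)");
//   m.def("vecquant4matmul", &vecquant4matmul_cuda, "4-bit quantized matvec (CUDA)");
//   m.def("vecquant8matmul", &vecquant8matmul_cuda, "8-bit quantized matvec (CUDA)");
// }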