import numpy as np
import torch
import torch.nn as nn
import math


def quantize(x, scale, zero, maxq):
    q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
    return scale * (q - zero)
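
# Illustrative sketch (not part of the original module): quantize() is the affine
# round-trip x -> scale * (clamp(round(x / scale) + zero, 0, maxq) - zero).
# With hand-picked 4-bit parameters (assumed values, for demonstration only):
#
#   w = torch.randn(4, 8)
#   scale = torch.full((4, 1), 0.1)
#   zero = torch.full((4, 1), 8.0)
#   w_deq = quantize(w, scale, zero, maxq=torch.tensor(15))  # fake-quantized copy of w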


class Quantizer(nn.Module):

    def __init__(self, shape=1):
        super(Quantizer, self).__init__()
        self.register_buffer('maxq', torch.tensor(0))
        self.register_buffer('scale', torch.zeros(shape))
        self.register_buffer('zero', torch.zeros(shape))

    def configure(
        self,
        bits, perchannel=False, sym=True,
        mse=False, norm=2.4, grid=100, maxshrink=.8
    ):
        self.maxq = torch.tensor(2 ** bits - 1)
        self.perchannel = perchannel
        self.sym = sym
        self.mse = mse
        self.norm = norm
        self.grid = grid
        self.maxshrink = maxshrink

    def find_params(self, x, weight=False):
        dev = x.device
        self.maxq = self.maxq.to(dev)

        shape = x.shape
        if self.perchannel:
            if weight:
                x = x.flatten(1)
            else:
                if len(shape) == 4:
                    x = x.permute([1, 0, 2, 3])
                    x = x.flatten(1)
                if len(shape) == 3:
                    x = x.reshape((-1, shape[-1])).t()
                if len(shape) == 2:
                    x = x.t()
        else:
            x = x.flatten().unsqueeze(0)

        tmp = torch.zeros(x.shape[0], device=dev)
        xmin = torch.minimum(x.min(1)[0], tmp)
        xmax = torch.maximum(x.max(1)[0], tmp)

        if self.sym:
            xmax = torch.maximum(torch.abs(xmin), xmax)
            tmp = xmin < 0
            if torch.any(tmp):
                xmin[tmp] = -xmax[tmp]
        tmp = (xmin == 0) & (xmax == 0)
        xmin[tmp] = -1
        xmax[tmp] = +1

        self.scale = (xmax - xmin) / self.maxq
        if self.sym:
            self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
        else:
            self.zero = torch.round(-xmin / self.scale)

        if self.mse:
            best = torch.full([x.shape[0]], float('inf'), device=dev)
            for i in range(int(self.maxshrink * self.grid)):
                p = 1 - i / self.grid
                xmin1 = p * xmin
                xmax1 = p * xmax
                scale1 = (xmax1 - xmin1) / self.maxq
                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
                q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
                q -= x
                q.abs_()
                q.pow_(self.norm)
                err = torch.sum(q, 1)
                tmp = err < best
                if torch.any(tmp):
                    best[tmp] = err[tmp]
                    self.scale[tmp] = scale1[tmp]
                    self.zero[tmp] = zero1[tmp]
        if not self.perchannel:
            if weight:
                tmp = shape[0]
            else:
                tmp = shape[1] if len(shape) != 3 else shape[2]
            self.scale = self.scale.repeat(tmp)
            self.zero = self.zero.repeat(tmp)

        if weight:
            shape = [-1] + [1] * (len(shape) - 1)
            self.scale = self.scale.reshape(shape)
            self.zero = self.zero.reshape(shape)
            return
        if len(shape) == 4:
            self.scale = self.scale.reshape((1, -1, 1, 1))
            self.zero = self.zero.reshape((1, -1, 1, 1))
        if len(shape) == 3:
            self.scale = self.scale.reshape((1, 1, -1))
            self.zero = self.zero.reshape((1, 1, -1))
        if len(shape) == 2:
            self.scale = self.scale.unsqueeze(0)
            self.zero = self.zero.unsqueeze(0)

    def quantize(self, x):
        if self.ready():
            return quantize(x, self.scale, self.zero, self.maxq)
        return x

    def enabled(self):
        return self.maxq > 0

    def ready(self):
        return torch.all(self.scale != 0)
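
# Hedged usage sketch (assumed calling pattern, mirroring how GPTQ-style scripts
# typically drive this class):
#
#   quantizer = Quantizer()
#   quantizer.configure(bits=4, perchannel=True, sym=False, mse=False)
#   quantizer.find_params(linear.weight.data, weight=True)   # fills scale / zero
#   w_deq = quantizer.quantize(linear.weight.data)           # fake-quantized weights
#
# find_params() clamps each per-row range to include zero and, when mse=True,
# shrinks the range over a grid to minimize the norm-`norm` reconstruction error.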


try:
    import importlib
    quant_cuda = importlib.import_module("quant_cuda")
except ImportError:
    # The prebuilt extension is missing; compile and install it in place on first import.
    import os
    import sys
    argv = sys.argv
    sys.argv = ['quant.py', 'install']
    dir_path = os.path.dirname(os.path.realpath(__file__))
    prev_dir = os.getcwd()
    from setuptools import setup, Extension
    from torch.utils import cpp_extension
    os.chdir(dir_path)
    cucode = '''
#include <torch/all.h>
#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>

// atomicAdd for double-precision floating-point numbers on hardware with
// compute capability < 6.0 from:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
__device__ double atomicAdd(
    double* address,
    double val
) {
  unsigned long long int* address_as_ull = (unsigned long long int*)address;
  unsigned long long int old = *address_as_ull, assumed;

  do {
    assumed = old;
    old = atomicCAS(
      address_as_ull,
      assumed,
      __double_as_longlong(val + __longlong_as_double(assumed))
    );

  // Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
  } while (assumed != old);

  return __longlong_as_double(old);
}
#endif

template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,
    int groupsize
);

template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,
    int groupsize
);

template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,
    int groupsize
);

template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,
    int groupsize
);

const int BLOCKWIDTH = 256;
const int BLOCKHEIGHT2 = 16;
const int BLOCKHEIGHT3 = 24;
const int BLOCKHEIGHT4 = 32;
const int BLOCKHEIGHT8 = 64;
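
// Added note: each BLOCKHEIGHTb constant is BLOCKWIDTH * b / 32, i.e. the number
// of packed 32-bit rows of `mat` that hold BLOCKWIDTH unpacked b-bit weights
// (256*2/32 = 16, 256*3/32 = 24, 256*4/32 = 32, 256*8/32 = 64).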

__device__ inline unsigned int as_unsigned(int i) {
  return *reinterpret_cast<unsigned int*>(&i);
}

void vecquant2matmul_cuda(
    torch::Tensor vec,
    torch::Tensor mat,
    torch::Tensor mul,
    torch::Tensor scales,
    torch::Tensor zeros,
    int groupsize
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant2matmul_cuda", ([&] {
      VecQuant2MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(),
        batch, vec_height, height, width, zero_width, groupsize
      );
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,
    int groupsize
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT2 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t res = 0;
  int i = width * h + w;
  int g_h = h * 16;
  int k = 0;

  int z_w = w / 16;
  int z_mod = (w % 16) * 2;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);

    res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 4) & 0x3) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 6) & 0x3) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp >> 8) & 0x3) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp >> 10) & 0x3) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp >> 12) & 0x3) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp >> 14) & 0x3) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp >> 16) & 0x3) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp >> 18) & 0x3) - zero) * blockvec[k + 9];
    res += (scale * scalar_t((tmp >> 20) & 0x3) - zero) * blockvec[k + 10];
    res += (scale * scalar_t((tmp >> 22) & 0x3) - zero) * blockvec[k + 11];
    res += (scale * scalar_t((tmp >> 24) & 0x3) - zero) * blockvec[k + 12];
    res += (scale * scalar_t((tmp >> 26) & 0x3) - zero) * blockvec[k + 13];
    res += (scale * scalar_t((tmp >> 28) & 0x3) - zero) * blockvec[k + 14];
    res += (scale * scalar_t((tmp >> 30) & 0x3) - zero) * blockvec[k + 15];

    i += width;
    k += 16;
  }

  atomicAdd(&mul[b * width + w], res);
}
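
// Added note: each thread owns one output column w; along blockIdx.x, different
// thread blocks consume successive BLOCKWIDTH-row slices of the packed matrix
// and accumulate their partial dot products into mul[b * width + w] with
// atomicAdd, so the caller must pre-initialize mul (the Python forward() below
// seeds it with zeros or with the bias).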

void vecquant3matmul_cuda(
    torch::Tensor vec,
    torch::Tensor mat,
    torch::Tensor mul,
    torch::Tensor scales,
    torch::Tensor zeros,
    int groupsize
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant3matmul_cuda", ([&] {
      VecQuant3MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(),
        batch, vec_height, height, width, zero_width, groupsize
      );
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,
    int groupsize
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT3 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t res = 0;
  int i = width * h + w;
  int g_h = (h / 3) * 32;
  int k = 0;

  int z_w = (w / 32) * 3; // ((w / 256) * 24) / 3
  int z_mod = w % 32;
  int z_bit;
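
  // Added note: 32 3-bit zero-points are packed into three consecutive 32-bit
  // words, so two of them (column offsets 10 and 21 within each group of 32)
  // straddle a word boundary. The branches below precompute which word (z_w)
  // and bit offset (z_bit) this thread's zero-point lives at; the two
  // straddling cases are stitched together from both words inside the loop.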

  if (z_mod != 10){
    if (z_mod != 21){
      z_bit = z_mod;
      if (z_bit > 21){
        z_bit -= 22;
        z_bit *= 3;
        z_bit += 2;
        z_w += 2;
      } else if (z_bit > 10){
        z_bit -= 11;
        z_bit *= 3;
        z_bit += 1;
        z_w += 1;
      } else {
        z_bit *= 3;
      }
    } else {
      z_w += 1;
    }
  }

  unsigned int tmp1;
  unsigned int tmp2;
  unsigned int tmp;
  unsigned int z_tmp;

  while (k < BLOCKWIDTH) {
    tmp1 = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero;
    if (z_mod == 10) {
      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
      zero = scale * scalar_t((z_tmp) + 1);
    } else if (z_mod == 21){
      z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
      zero = scale * scalar_t((z_tmp) + 1);
    } else {
      zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
    }

    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];

    i += width;
    tmp2 = as_unsigned(mat[i]);
    tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4);
    tmp2 >>= 1;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;

    res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp2 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp2 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp2 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp2 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9];

    i += width;
    tmp1 = as_unsigned(mat[i]);
    tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6);
    tmp1 >>= 2;
    res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
    k += 11;

    res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp1 >> 9) & 0x7) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp1 >> 12) & 0x7) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp1 >> 15) & 0x7) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp1 >> 18) & 0x7) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
    res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
    res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];

    i += width;
    k += 10;
  }

  atomicAdd(&mul[b * width + w], res);
}

void vecquant4matmul_cuda(
    torch::Tensor vec,
    torch::Tensor mat,
    torch::Tensor mul,
    torch::Tensor scales,
    torch::Tensor zeros,
    int groupsize
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant4matmul_cuda", ([&] {
      VecQuant4MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(),
        batch, vec_height, height, width, zero_width, groupsize
      );
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,
    int groupsize
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT4 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t res = 0;
  int i = width * h + w;
  int g_h = h * 8;
  int k = 0;

  int z_w = w / 8;
  int z_mod = (w % 8) * 4;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);

    res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 8) & 0xF) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 12) & 0xF) - zero) * blockvec[k + 3];
    res += (scale * scalar_t((tmp >> 16) & 0xF) - zero) * blockvec[k + 4];
    res += (scale * scalar_t((tmp >> 20) & 0xF) - zero) * blockvec[k + 5];
    res += (scale * scalar_t((tmp >> 24) & 0xF) - zero) * blockvec[k + 6];
    res += (scale * scalar_t((tmp >> 28) & 0xF) - zero) * blockvec[k + 7];

    i += width;
    k += 8;
  }

  atomicAdd(&mul[b * width + w], res);
}

void vecquant8matmul_cuda(
    torch::Tensor vec,
    torch::Tensor mat,
    torch::Tensor mul,
    torch::Tensor scales,
    torch::Tensor zeros,
    int groupsize
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  dim3 blocks(
    (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
    batch
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.type(), "vecquant8matmul_cuda", ([&] {
      VecQuant8MatMulKernel<<<blocks, threads>>>(
        vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
        scales.data<scalar_t>(), zeros.data<int>(),
        batch, vec_height, height, width, zero_width, groupsize
      );
    })
  );
}

template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
    const scalar_t* __restrict__ vec,
    const int* __restrict__ mat,
    scalar_t* __restrict__ mul,
    const scalar_t* __restrict__ scales,
    const int* __restrict__ zeros,
    int batch,
    int vec_height,
    int height,
    int width,
    int zero_width,
    int groupsize
) {
  int b = blockIdx.z;
  int h = BLOCKHEIGHT8 * blockIdx.x;
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
  __syncthreads();

  scalar_t res = 0;
  int i = width * h + w;
  int g_h = h * 4;
  int k = 0;

  int z_w = w / 4;
  int z_mod = (w % 4) * 8;

  unsigned int tmp;

  while (k < BLOCKWIDTH) {
    tmp = as_unsigned(mat[i]);

    int g = (g_h + k) / groupsize;
    scalar_t scale = scales[g * width + w];
    scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);

    res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0];
    res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1];
    res += (scale * scalar_t((tmp >> 16) & 0xFF) - zero) * blockvec[k + 2];
    res += (scale * scalar_t((tmp >> 24) & 0xFF) - zero) * blockvec[k + 3];

    i += width;
    k += 4;
  }

  atomicAdd(&mul[b * width + w], res);
}
'''
    with open("quant_cuda_kernel.cu", "w") as f:
        f.write(cucode)

    cppcode = '''
#include <torch/all.h>
#include <torch/python.h>
#include <c10/cuda/CUDAGuard.h>

void vecquant2matmul_cuda(
    torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
    torch::Tensor scales, torch::Tensor zeros,
    int groupsize
);

void vecquant2matmul(
    torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
    torch::Tensor scales, torch::Tensor zeros,
    int groupsize
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant2matmul_cuda(vec, mat, mul, scales, zeros, groupsize);
}

void vecquant3matmul_cuda(
    torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
    torch::Tensor scales, torch::Tensor zeros,
    int groupsize
);

void vecquant3matmul(
    torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
    torch::Tensor scales, torch::Tensor zeros,
    int groupsize
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant3matmul_cuda(vec, mat, mul, scales, zeros, groupsize);
}

void vecquant4matmul_cuda(
    torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
    torch::Tensor scales, torch::Tensor zeros,
    int groupsize
);

void vecquant4matmul(
    torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
    torch::Tensor scales, torch::Tensor zeros,
    int groupsize
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant4matmul_cuda(vec, mat, mul, scales, zeros, groupsize);
}

void vecquant8matmul_cuda(
    torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
    torch::Tensor scales, torch::Tensor zeros,
    int groupsize
);

void vecquant8matmul(
    torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
    torch::Tensor scales, torch::Tensor zeros,
    int groupsize
) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  vecquant8matmul_cuda(vec, mat, mul, scales, zeros, groupsize);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA)");
  m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA)");
}
'''
    with open("quant_cuda.cpp", "w") as f:
        f.write(cppcode)

    setup(
        name='quant_cuda',
        ext_modules=[cpp_extension.CUDAExtension(
            'quant_cuda', ['quant_cuda.cpp', 'quant_cuda_kernel.cu']
        )],
        cmdclass={'build_ext': cpp_extension.BuildExtension}
    )
    # Restore the working directory and argv that were overridden for setup().
    os.chdir(prev_dir)
    sys.argv = argv
    # The freshly installed extension directory is not on sys.path yet; add it so
    # the module can be imported without restarting the interpreter.
    for i in sys.path:
        if i.endswith("site-packages"):
            for j in os.listdir(i):
                if j.find("quant_cuda") != -1:
                    sys.path.append(os.path.join(i, j))
                    break
            break
    import importlib
    quant_cuda = importlib.import_module("quant_cuda")
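
# Note: the fallback above builds and installs the "quant_cuda" extension the
# first time this module is imported on a machine without a prebuilt copy; this
# requires a working CUDA toolchain and can take a while on the first import.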


class QuantLinear(nn.Module):
    def __init__(self, bits, groupsize, infeatures, outfeatures):
        super().__init__()
        if bits not in [2, 3, 4, 8]:
            raise NotImplementedError("Only 2,3,4,8 bits are supported.")
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.bits = bits
        if groupsize != -1 and (groupsize < 32 or groupsize != 2 ** int(math.log2(groupsize))):
            raise NotImplementedError("groupsize must be a power of 2 that is at least 32 (e.g. 32, 64, 128).")
        groupsize = groupsize if groupsize != -1 else infeatures
        self.groupsize = groupsize
        self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)), dtype=torch.int))
        self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
        self.register_buffer('bias', torch.zeros(outfeatures))
        self.register_buffer(
            'qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int)
        )
        self._initialized_quant_state = False
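
    # Added note: qweight/qzeros store 32 // bits quantized values per int32
    # (for 3-bit, 32 values per three int32 words, packed across word boundaries),
    # so `infeatures // 256 * (bits * 8)` is just infeatures * bits / 32 written
    # for infeatures divisible by 256; `scales` holds one row per group of
    # `groupsize` input features.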

    def pack(self, linear, scales, zeros):
        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
        scale_zeros = zeros * scales
        self.scales = scales.clone()
        if linear.bias is not None:
            self.bias = linear.bias.clone()

        intweight = []
        for idx in range(self.infeatures):
            g_idx = idx // self.groupsize
            intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:, None])
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(np.uint32)
        qweight = np.zeros(
            (intweight.shape[0] // 256 * (self.bits * 8), intweight.shape[1]), dtype=np.uint32
        )
        i = 0
        row = 0
        while row < qweight.shape[0]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.bits)):
                    qweight[row] |= intweight[j] << (self.bits * (j - i))
                i += 32 // self.bits
                row += 1
            elif self.bits == 3:
                for j in range(i, i + 10):
                    qweight[row] |= intweight[j] << (3 * (j - i))
                i += 10
                qweight[row] |= intweight[i] << 30
                row += 1
                qweight[row] |= (intweight[i] >> 2) & 1
                i += 1
                for j in range(i, i + 10):
                    qweight[row] |= intweight[j] << (3 * (j - i) + 1)
                i += 10
                qweight[row] |= intweight[i] << 31
                row += 1
                qweight[row] |= (intweight[i] >> 1) & 0x3
                i += 1
                for j in range(i, i + 10):
                    qweight[row] |= intweight[j] << (3 * (j - i) + 2)
                i += 10
                row += 1
            else:
                raise NotImplementedError("Only 2,3,4,8 bits are supported.")

        qweight = qweight.astype(np.int32)
        self.qweight = torch.from_numpy(qweight)

        zeros -= 1
        zeros = zeros.numpy().astype(np.uint32)
        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 256 * (self.bits * 8)), dtype=np.uint32)
        i = 0
        col = 0
        while col < qzeros.shape[1]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (32 // self.bits)):
                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
                i += 32 // self.bits
                col += 1
            elif self.bits == 3:
                for j in range(i, i + 10):
                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
                i += 10
                qzeros[:, col] |= zeros[:, i] << 30
                col += 1
                qzeros[:, col] |= (zeros[:, i] >> 2) & 1
                i += 1
                for j in range(i, i + 10):
                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
                i += 10
                qzeros[:, col] |= zeros[:, i] << 31
                col += 1
                qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
                i += 1
                for j in range(i, i + 10):
                    qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
                i += 10
                col += 1
            else:
                raise NotImplementedError("Only 2,3,4,8 bits are supported.")

        qzeros = qzeros.astype(np.int32)
        self.qzeros = torch.from_numpy(qzeros)
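
    # Hedged usage sketch (assumed calling convention; the exact scale/zero
    # tensors come from whatever per-group quantization script drives this
    # module, e.g. a GPTQ/Quantizer pass over `linear.weight`):
    #
    #   qlayer = QuantLinear(bits=4, groupsize=128,
    #                        infeatures=linear.in_features,
    #                        outfeatures=linear.out_features)
    #   qlayer.pack(linear, scales, zeros)   # scales/zeros: (out_features, n_groups)
    #
    # pack() expects integer zero-points; it subtracts 1 before packing because
    # the CUDA kernels above add 1 back when they decode qzeros.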

    def forward(self, x):
        intermediate_dtype = torch.float32

        if not self._initialized_quant_state:
            # Do this once: keep the bias only if it is actually non-zero, and
            # move it to the intermediate dtype expected by the CUDA kernels.
            if self.bias is not None and bool(torch.any(self.bias != 0)):
                self.bias.data = self.bias.data.to(intermediate_dtype)
            else:
                self.bias = None
            self._initialized_quant_state = True

        outshape = list(x.shape)
        outshape[-1] = self.outfeatures
        x = x.reshape(-1, x.shape[-1])
        if self.bias is None:
            y = torch.zeros(x.shape[0], outshape[-1], dtype=intermediate_dtype, device=x.device)
        else:
            y = self.bias.clone().repeat(x.shape[0], 1)

        output_dtype = x.dtype
        x = x.to(intermediate_dtype)
        if self.bits == 2:
            quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
        elif self.bits == 3:
            quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
        elif self.bits == 4:
            quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
        elif self.bits == 8:
            quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
        else:
            raise NotImplementedError("Only 2,3,4,8 bits are supported.")
        y = y.to(output_dtype)
        return y.reshape(outshape)


def make_quant(module, names, bits, groupsize, name=''):
    if isinstance(module, QuantLinear):
        return
    for attr in dir(module):
        tmp = getattr(module, attr)
        name1 = name + '.' + attr if name != '' else attr
        if name1 in names:
            setattr(
                module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features)
            )
    for name1, child in module.named_children():
        make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
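

# Hedged example of driving make_quant (the module paths are illustrative, not
# taken from this file): given the dotted names of the nn.Linear layers to replace,
#
#   layer_names = {'self_attn.q_proj', 'self_attn.k_proj', 'mlp.down_proj'}
#   make_quant(model, layer_names, bits=4, groupsize=128)
#
# each matching nn.Linear is swapped for an empty QuantLinear of the same shape;
# the packed qweight/qzeros/scales buffers still have to be filled afterwards,
# either via pack() or by loading a pre-quantized checkpoint.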