AutoGPTQ/auto_gptq/quantization/quant_cuda_kernel.cu
2023-04-14 01:09:40 +08:00

509 lines
12 KiB
Text

#include <torch/all.h>
#include <torch/python.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
// Software fallback for atomicAdd(double*, double), compiled only for device
// targets with compute capability below 6.0. It follows the compare-and-swap
// recipe from the CUDA C++ Programming Guide:
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#atomic-functions
//
// Returns the value stored at `address` immediately before the addition.
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 600
__device__ double atomicAdd(
  double* address,
  double val
) {
  // Operate on the raw 64-bit pattern so atomicCAS can be used.
  unsigned long long int* slot = (unsigned long long int*)address;
  unsigned long long int observed = *slot;
  unsigned long long int expected;
  do {
    expected = observed;
    const double summed = val + __longlong_as_double(expected);
    // The swap succeeds only if nobody changed the slot since we read it.
    observed = atomicCAS(slot, expected, __double_as_longlong(summed));
    // The comparison is done on the integer bit patterns so that a NaN
    // (which never compares equal to itself) cannot make the loop spin forever.
  } while (expected != observed);
  return __longlong_as_double(observed);
}
#endif
// Forward declaration of the 2-bit quantized matrix-vector kernel so that the
// host launcher vecquant2matmul_cuda can reference it before the definition,
// which appears further down in this file.
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const int* __restrict__ zeros,
const int* __restrict__ g_idx,
int batch,
int vec_height,
int height,
int width,
int zero_width
);
// Forward declaration of the 3-bit quantized matrix-vector kernel so that the
// host launcher vecquant3matmul_cuda can reference it before the definition,
// which appears further down in this file.
template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const int* __restrict__ zeros,
const int* __restrict__ g_idx,
int batch,
int vec_height,
int height,
int width,
int zero_width
);
// Forward declaration of the 4-bit quantized matrix-vector kernel so that the
// host launcher vecquant4matmul_cuda can reference it before the definition,
// which appears further down in this file.
template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const int* __restrict__ zeros,
const int* __restrict__ g_idx,
int batch,
int vec_height,
int height,
int width,
int zero_width
);
// Forward declaration of the 8-bit quantized matrix-vector kernel so that the
// host launcher vecquant8matmul_cuda can reference it before the definition,
// which appears further down in this file.
template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
const scalar_t* __restrict__ vec,
const int* __restrict__ mat,
scalar_t* __restrict__ mul,
const scalar_t* __restrict__ scales,
const int* __restrict__ zeros,
const int* __restrict__ g_idx,
int batch,
int vec_height,
int height,
int width,
int zero_width
);
// Threads per block. Each block also consumes BLOCKWIDTH consecutive
// (unpacked) input rows and each thread produces one output column.
constexpr int BLOCKWIDTH = 256;

// Number of packed 32-bit rows of `mat` that hold BLOCKWIDTH values at
// N bits per value, i.e. BLOCKWIDTH * N / 32.
constexpr int BLOCKHEIGHT2 = BLOCKWIDTH * 2 / 32;  // 16
constexpr int BLOCKHEIGHT3 = BLOCKWIDTH * 3 / 32;  // 24
constexpr int BLOCKHEIGHT4 = BLOCKWIDTH * 4 / 32;  // 32
constexpr int BLOCKHEIGHT8 = BLOCKWIDTH * 8 / 32;  // 64
// Returns the 32-bit pattern of `i` viewed as an unsigned integer, so that
// right shifts on packed words are logical rather than arithmetic.
__device__ inline unsigned int as_unsigned(int i) {
  // Signed-to-unsigned conversion is modulo 2^32, which keeps the same bits.
  return static_cast<unsigned int>(i);
}
// Identity helper: hands back the given int unchanged. Kept so call sites
// read symmetrically with as_unsigned().
__device__ inline int as_int(int i) {
  return i;
}
// Host launcher for the 2-bit quantized matrix-vector product.
//
// The kernel indexes the raw data pointers directly, so every tensor is
// expected to be a contiguous CUDA tensor laid out as follows (derived from
// the kernel's index arithmetic):
//   vec    : (batch, vec_height)   floating-point activations
//   mat    : (height, width)       int32, 16 two-bit weights packed per word
//                                  along the row dimension
//   mul    : (batch, width)        output; results are *accumulated* into it
//                                  with atomicAdd, so the caller must have
//                                  initialised it beforehand
//   scales : (groups, width)       same dtype as vec
//   zeros  : (groups, zero_width)  int32, 16 two-bit zero points per word
//   g_idx  : one int32 group index per unpacked input row
//
// The launch uses the device of `vec` and PyTorch's current CUDA stream, and
// throws (via C10_CUDA_KERNEL_LAUNCH_CHECK) if the launch is rejected.
void vecquant2matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  torch::Tensor g_idx
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  // Nothing to compute; a zero-sized grid would be an invalid launch.
  if (batch == 0 || height == 0 || width == 0) {
    return;
  }

  // Make sure the kernel runs on the GPU that owns the data and is ordered
  // with the other work PyTorch has queued on the current stream.
  const at::cuda::OptionalCUDAGuard device_guard(at::device_of(vec));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // grid.x walks blocks of packed rows, grid.y walks blocks of output columns.
  dim3 blocks(
    (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.scalar_type(), "vecquant2matmul_cuda", ([&] {
      VecQuant2MatMulKernel<<<blocks, threads, 0, stream>>>(
        vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
        scales.data_ptr<scalar_t>(), zeros.data_ptr<int>(), g_idx.data_ptr<int>(),
        batch, vec_height, height, width, zero_width
      );
    })
  );
  // Kernel launches do not return errors; surface bad configurations here.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// 2-bit quantized matrix-vector product.
//
// Launch layout: blockDim.x must equal BLOCKWIDTH. blockIdx.x selects a group
// of BLOCKHEIGHT2 packed rows of `mat` (= BLOCKWIDTH unpacked input rows),
// blockIdx.y selects a group of BLOCKWIDTH output columns; each thread owns
// one output column `w`. Uses BLOCKWIDTH * sizeof(scalar_t) bytes of static
// shared memory.
//
// For every batch row b the thread accumulates
//   mul[b, w] += sum_k scales[g, w] * (q[k, w] - (zero[g, w] + 1)) * vec[b, k]
// with g = g_idx[k], over the BLOCKWIDTH input rows k handled by this block.
// Partial sums from different blockIdx.x are combined with atomicAdd, so
// `mul` must be initialised by the caller.
//
// Threads whose column lies beyond `width`, and rows beyond `vec_height` /
// `height`, contribute nothing instead of touching out-of-range memory.
template <typename scalar_t>
__global__ void VecQuant2MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  const int* __restrict__ g_idx,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width
) {
  int h = BLOCKHEIGHT2 * blockIdx.x;                  // first packed row of this block
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;      // output column of this thread
  const bool col_ok = w < width;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  int i = width * h + w;
  int g_h = h * 16;                                   // first unpacked input row (16 values per word)
  int k;
  unsigned int g;
  scalar_t w_tmp;

  // Location of this column's zero point inside a packed `zeros` row.
  int z_w = w / 16;
  int z_mod = (w % 16) * 2;

  // Dequantized weights of column w for the block's BLOCKWIDTH input rows.
  float weight[BLOCKWIDTH];

  if (col_ok) {
    for (k = 0; k < BLOCKWIDTH; ++k){
      int k_w = (k / 16);
      int k_bit = (k % 16) * 2;

      // Tail rows that do not exist in the inputs contribute nothing.
      if (g_h + k >= vec_height || h + k_w >= height) {
        weight[k] = 0.0f;
        continue;
      }

      g = as_int(g_idx[g_h + k]);
      scalar_t scale = scales[g * width + w];
      // Stored zero points are offset by one.
      scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1);

      w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3);

      weight[k] = scale * (w_tmp - zero);
    }
  }

  scalar_t res;
  for (int b = 0; b < batch; ++b){
    // Cooperative load of this block's slice of activation row b.
    int v = blockIdx.x * BLOCKWIDTH + threadIdx.x;
    blockvec[threadIdx.x] = (v < vec_height) ? vec[b * vec_height + v] : scalar_t(0);
    __syncthreads();

    if (col_ok) {
      res = 0;
      for (k = 0; k < BLOCKWIDTH; ++k){
        res += weight[k] * blockvec[k];
      }
      atomicAdd(&mul[b * width + w], res);
    }

    // Everyone must be done reading blockvec before the next batch row
    // overwrites it.
    __syncthreads();
  }
}
// Host launcher for the 3-bit quantized matrix-vector product.
//
// The kernel indexes the raw data pointers directly, so every tensor is
// expected to be a contiguous CUDA tensor laid out as follows (derived from
// the kernel's index arithmetic):
//   vec    : (batch, vec_height)   floating-point activations
//   mat    : (height, width)       int32, 32 three-bit weights packed into
//                                  every 3 consecutive rows
//   mul    : (batch, width)        output; results are *accumulated* into it
//                                  with atomicAdd, so the caller must have
//                                  initialised it beforehand
//   scales : (groups, width)       same dtype as vec
//   zeros  : (groups, zero_width)  int32, 32 three-bit zero points packed into
//                                  every 3 consecutive words
//   g_idx  : one int32 group index per unpacked input row
//
// The launch uses the device of `vec` and PyTorch's current CUDA stream, and
// throws (via C10_CUDA_KERNEL_LAUNCH_CHECK) if the launch is rejected.
void vecquant3matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  torch::Tensor g_idx
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  // Nothing to compute; a zero-sized grid would be an invalid launch.
  if (batch == 0 || height == 0 || width == 0) {
    return;
  }

  // Make sure the kernel runs on the GPU that owns the data and is ordered
  // with the other work PyTorch has queued on the current stream.
  const at::cuda::OptionalCUDAGuard device_guard(at::device_of(vec));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // grid.x walks blocks of packed rows, grid.y walks blocks of output columns.
  dim3 blocks(
    (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.scalar_type(), "vecquant3matmul_cuda", ([&] {
      VecQuant3MatMulKernel<<<blocks, threads, 0, stream>>>(
        vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
        scales.data_ptr<scalar_t>(), zeros.data_ptr<int>(), g_idx.data_ptr<int>(),
        batch, vec_height, height, width, zero_width
      );
    })
  );
  // Kernel launches do not return errors; surface bad configurations here.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// 3-bit quantized matrix-vector product.
//
// Launch layout: blockDim.x must equal BLOCKWIDTH. blockIdx.x selects a group
// of BLOCKHEIGHT3 packed rows of `mat` (= BLOCKWIDTH unpacked input rows),
// blockIdx.y selects a group of BLOCKWIDTH output columns; each thread owns
// one output column `w`. Uses BLOCKWIDTH * sizeof(scalar_t) bytes of static
// shared memory.
//
// Packing: 32 three-bit values occupy 3 consecutive 32-bit words.
//   values  0..9  : word 0, bit offset 3 * n
//   value   10    : low 2 bits in word 0 (bits 30-31), high bit in word 1 (bit 0)
//   values 11..20 : word 1, bit offset 3 * (n - 11) + 1
//   value   21    : low bit in word 1 (bit 31), high 2 bits in word 2 (bits 0-1)
//   values 22..31 : word 2, bit offset 3 * (n - 22) + 2
// The same scheme is used along rows of `mat` (index k) and along the words
// of a `zeros` row (index w).
//
// For every batch row b the thread accumulates
//   mul[b, w] += sum_k scales[g, w] * (q[k, w] - (zero[g, w] + 1)) * vec[b, k]
// with g = g_idx[k], over the BLOCKWIDTH input rows k handled by this block.
// Partial sums from different blockIdx.x are combined with atomicAdd, so
// `mul` must be initialised by the caller.
//
// Threads whose column lies beyond `width`, and rows beyond `vec_height` /
// `height`, contribute nothing instead of touching out-of-range memory.
template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  const int* __restrict__ g_idx,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width
) {
  int h = BLOCKHEIGHT3 * blockIdx.x;                  // first packed row of this block
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;      // output column of this thread
  const bool col_ok = w < width;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  int i = width * h + w;
  int g_h = (h / 3) * 32;                             // first unpacked input row (32 values per 3 words)
  int k;
  unsigned int g;
  scalar_t w_tmp;

  // Locate this column's zero point: word index z_w and bit offset z_bit.
  // For the straddling positions (10 and 21) z_w is the word holding the low
  // bits and z_bit is not used.
  int z_w = (w / 32) * 3;
  int z_mod = w % 32;
  int z_bit = 0;
  unsigned int z_tmp;
  if (z_mod != 10){
    if (z_mod != 21){
      z_bit = z_mod;
      if (z_bit > 21){
        z_bit -= 22;
        z_bit *= 3;
        z_bit += 2;
        z_w += 2;
      } else if (z_bit > 10){
        z_bit -= 11;
        z_bit *= 3;
        z_bit += 1;
        z_w += 1;
      } else {
        z_bit *= 3;
      }
    } else {
      z_w += 1;
    }
  }

  // Dequantized weights of column w for the block's BLOCKWIDTH input rows.
  float weight[BLOCKWIDTH];

  if (col_ok) {
    for (k = 0; k < BLOCKWIDTH; ++k){
      // Same decoding as above, applied to the row dimension of `mat`.
      int k_w = (k / 32) * 3;
      int k_mod = k % 32;
      int k_bit = 0;

      if (k_mod != 10){
        if (k_mod != 21){
          k_bit = k_mod;
          if (k_bit > 21){
            k_bit -= 22;
            k_bit *= 3;
            k_bit += 2;
            k_w += 2;
          } else if (k_bit > 10){
            k_bit -= 11;
            k_bit *= 3;
            k_bit += 1;
            k_w += 1;
          } else {
            k_bit *= 3;
          }
        } else {
          k_w += 1;
        }
      }

      // Straddling values read two consecutive packed rows.
      const int k_rows = (k_mod == 10 || k_mod == 21) ? 2 : 1;
      // Tail rows that do not exist in the inputs contribute nothing.
      if (g_h + k >= vec_height || h + k_w + k_rows > height) {
        weight[k] = 0.0f;
        continue;
      }

      g = as_int(g_idx[g_h + k]);
      scalar_t scale = scales[g * width + w];
      // Stored zero points are offset by one.
      scalar_t zero;
      if (z_mod == 10) {
        z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
        zero = scalar_t((z_tmp) + 1);
      } else if (z_mod == 21){
        z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
        zero = scalar_t((z_tmp) + 1);
      } else {
        zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
      }

      if (k_mod == 10) {
        w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 30) | ((as_unsigned(mat[i + ((k_w + 1) * width)]) << 2) & 0x4);
      } else if (k_mod == 21){
        w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 31) | ((as_unsigned(mat[i + ((k_w + 1) * width)]) << 1) & 0x6);
      } else {
        w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x7);
      }

      weight[k] = scale * (w_tmp - zero);
    }
  }

  scalar_t res;
  for (int b = 0; b < batch; ++b){
    // Cooperative load of this block's slice of activation row b.
    int v = blockIdx.x * BLOCKWIDTH + threadIdx.x;
    blockvec[threadIdx.x] = (v < vec_height) ? vec[b * vec_height + v] : scalar_t(0);
    __syncthreads();

    if (col_ok) {
      res = 0;
      for (k = 0; k < BLOCKWIDTH; ++k){
        res += weight[k] * blockvec[k];
      }
      atomicAdd(&mul[b * width + w], res);
    }

    // Everyone must be done reading blockvec before the next batch row
    // overwrites it.
    __syncthreads();
  }
}
// Host launcher for the 4-bit quantized matrix-vector product.
//
// The kernel indexes the raw data pointers directly, so every tensor is
// expected to be a contiguous CUDA tensor laid out as follows (derived from
// the kernel's index arithmetic):
//   vec    : (batch, vec_height)   floating-point activations
//   mat    : (height, width)       int32, 8 four-bit weights packed per word
//                                  along the row dimension
//   mul    : (batch, width)        output; results are *accumulated* into it
//                                  with atomicAdd, so the caller must have
//                                  initialised it beforehand
//   scales : (groups, width)       same dtype as vec
//   zeros  : (groups, zero_width)  int32, 8 four-bit zero points per word
//   g_idx  : one int32 group index per unpacked input row
//
// The launch uses the device of `vec` and PyTorch's current CUDA stream, and
// throws (via C10_CUDA_KERNEL_LAUNCH_CHECK) if the launch is rejected.
void vecquant4matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  torch::Tensor g_idx
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  // Nothing to compute; a zero-sized grid would be an invalid launch.
  if (batch == 0 || height == 0 || width == 0) {
    return;
  }

  // Make sure the kernel runs on the GPU that owns the data and is ordered
  // with the other work PyTorch has queued on the current stream.
  const at::cuda::OptionalCUDAGuard device_guard(at::device_of(vec));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // grid.x walks blocks of packed rows, grid.y walks blocks of output columns.
  dim3 blocks(
    (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.scalar_type(), "vecquant4matmul_cuda", ([&] {
      VecQuant4MatMulKernel<<<blocks, threads, 0, stream>>>(
        vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
        scales.data_ptr<scalar_t>(), zeros.data_ptr<int>(), g_idx.data_ptr<int>(),
        batch, vec_height, height, width, zero_width
      );
    })
  );
  // Kernel launches do not return errors; surface bad configurations here.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// 4-bit quantized matrix-vector product.
//
// Launch layout: blockDim.x must equal BLOCKWIDTH. blockIdx.x selects a group
// of BLOCKHEIGHT4 packed rows of `mat` (= BLOCKWIDTH unpacked input rows),
// blockIdx.y selects a group of BLOCKWIDTH output columns; each thread owns
// one output column `w`. Uses BLOCKWIDTH * sizeof(scalar_t) bytes of static
// shared memory.
//
// For every batch row b the thread accumulates
//   mul[b, w] += sum_k scales[g, w] * (q[k, w] - (zero[g, w] + 1)) * vec[b, k]
// with g = g_idx[k], over the BLOCKWIDTH input rows k handled by this block.
// Partial sums from different blockIdx.x are combined with atomicAdd, so
// `mul` must be initialised by the caller.
//
// Threads whose column lies beyond `width`, and rows beyond `vec_height` /
// `height`, contribute nothing instead of touching out-of-range memory.
template <typename scalar_t>
__global__ void VecQuant4MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  const int* __restrict__ g_idx,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width
) {
  int h = BLOCKHEIGHT4 * blockIdx.x;                  // first packed row of this block
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;      // output column of this thread
  const bool col_ok = w < width;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  int i = width * h + w;
  int g_h = h * 8;                                    // first unpacked input row (8 values per word)
  int k;
  unsigned int g;
  scalar_t w_tmp;

  // Location of this column's zero point inside a packed `zeros` row.
  int z_w = w / 8;
  int z_mod = (w % 8) * 4;

  // Dequantized weights of column w for the block's BLOCKWIDTH input rows.
  float weight[BLOCKWIDTH];

  if (col_ok) {
    for (k = 0; k < BLOCKWIDTH; ++k){
      int k_w = (k / 8);
      int k_bit = (k % 8) * 4;

      // Tail rows that do not exist in the inputs contribute nothing.
      if (g_h + k >= vec_height || h + k_w >= height) {
        weight[k] = 0.0f;
        continue;
      }

      g = as_int(g_idx[g_h + k]);
      scalar_t scale = scales[g * width + w];
      // Stored zero points are offset by one.
      scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);

      w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF);

      weight[k] = scale * (w_tmp - zero);
    }
  }

  scalar_t res;
  for (int b = 0; b < batch; ++b){
    // Cooperative load of this block's slice of activation row b.
    int v = blockIdx.x * BLOCKWIDTH + threadIdx.x;
    blockvec[threadIdx.x] = (v < vec_height) ? vec[b * vec_height + v] : scalar_t(0);
    __syncthreads();

    if (col_ok) {
      res = 0;
      for (k = 0; k < BLOCKWIDTH; ++k){
        res += weight[k] * blockvec[k];
      }
      atomicAdd(&mul[b * width + w], res);
    }

    // Everyone must be done reading blockvec before the next batch row
    // overwrites it.
    __syncthreads();
  }
}
// Host launcher for the 8-bit quantized matrix-vector product.
//
// The kernel indexes the raw data pointers directly, so every tensor is
// expected to be a contiguous CUDA tensor laid out as follows (derived from
// the kernel's index arithmetic):
//   vec    : (batch, vec_height)   floating-point activations
//   mat    : (height, width)       int32, 4 eight-bit weights packed per word
//                                  along the row dimension
//   mul    : (batch, width)        output; results are *accumulated* into it
//                                  with atomicAdd, so the caller must have
//                                  initialised it beforehand
//   scales : (groups, width)       same dtype as vec
//   zeros  : (groups, zero_width)  int32, 4 eight-bit zero points per word
//   g_idx  : one int32 group index per unpacked input row
//
// The launch uses the device of `vec` and PyTorch's current CUDA stream, and
// throws (via C10_CUDA_KERNEL_LAUNCH_CHECK) if the launch is rejected.
void vecquant8matmul_cuda(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor scales,
  torch::Tensor zeros,
  torch::Tensor g_idx
) {
  int batch = vec.size(0);
  int vec_height = vec.size(1);
  int height = mat.size(0);
  int width = mat.size(1);
  int zero_width = zeros.size(1);

  // Nothing to compute; a zero-sized grid would be an invalid launch.
  if (batch == 0 || height == 0 || width == 0) {
    return;
  }

  // Make sure the kernel runs on the GPU that owns the data and is ordered
  // with the other work PyTorch has queued on the current stream.
  const at::cuda::OptionalCUDAGuard device_guard(at::device_of(vec));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // grid.x walks blocks of packed rows, grid.y walks blocks of output columns.
  dim3 blocks(
    (height + BLOCKHEIGHT8 - 1) / BLOCKHEIGHT8,
    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
  );
  dim3 threads(BLOCKWIDTH);

  AT_DISPATCH_FLOATING_TYPES(
    vec.scalar_type(), "vecquant8matmul_cuda", ([&] {
      VecQuant8MatMulKernel<<<blocks, threads, 0, stream>>>(
        vec.data_ptr<scalar_t>(), mat.data_ptr<int>(), mul.data_ptr<scalar_t>(),
        scales.data_ptr<scalar_t>(), zeros.data_ptr<int>(), g_idx.data_ptr<int>(),
        batch, vec_height, height, width, zero_width
      );
    })
  );
  // Kernel launches do not return errors; surface bad configurations here.
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// 8-bit quantized matrix-vector product.
//
// Launch layout: blockDim.x must equal BLOCKWIDTH. blockIdx.x selects a group
// of BLOCKHEIGHT8 packed rows of `mat` (= BLOCKWIDTH unpacked input rows),
// blockIdx.y selects a group of BLOCKWIDTH output columns; each thread owns
// one output column `w`. Uses BLOCKWIDTH * sizeof(scalar_t) bytes of static
// shared memory.
//
// For every batch row b the thread accumulates
//   mul[b, w] += sum_k scales[g, w] * (q[k, w] - (zero[g, w] + 1)) * vec[b, k]
// with g = g_idx[k], over the BLOCKWIDTH input rows k handled by this block.
// Partial sums from different blockIdx.x are combined with atomicAdd, so
// `mul` must be initialised by the caller.
//
// Threads whose column lies beyond `width`, and rows beyond `vec_height` /
// `height`, contribute nothing instead of touching out-of-range memory.
template <typename scalar_t>
__global__ void VecQuant8MatMulKernel(
  const scalar_t* __restrict__ vec,
  const int* __restrict__ mat,
  scalar_t* __restrict__ mul,
  const scalar_t* __restrict__ scales,
  const int* __restrict__ zeros,
  const int* __restrict__ g_idx,
  int batch,
  int vec_height,
  int height,
  int width,
  int zero_width
) {
  int h = BLOCKHEIGHT8 * blockIdx.x;                  // first packed row of this block
  int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;      // output column of this thread
  const bool col_ok = w < width;

  __shared__ scalar_t blockvec[BLOCKWIDTH];
  int i = width * h + w;
  int g_h = h * 4;                                    // first unpacked input row (4 values per word)
  int k;
  unsigned int g;
  scalar_t w_tmp;

  // Location of this column's zero point inside a packed `zeros` row.
  int z_w = w / 4;
  int z_mod = (w % 4) * 8;

  // Dequantized weights of column w for the block's BLOCKWIDTH input rows.
  float weight[BLOCKWIDTH];

  if (col_ok) {
    for (k = 0; k < BLOCKWIDTH; ++k){
      int k_w = (k / 4);
      int k_bit = (k % 4) * 8;

      // Tail rows that do not exist in the inputs contribute nothing.
      if (g_h + k >= vec_height || h + k_w >= height) {
        weight[k] = 0.0f;
        continue;
      }

      g = as_int(g_idx[g_h + k]);
      scalar_t scale = scales[g * width + w];
      // Stored zero points are offset by one.
      scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);

      w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF);

      weight[k] = scale * (w_tmp - zero);
    }
  }

  scalar_t res;
  for (int b = 0; b < batch; ++b){
    // Cooperative load of this block's slice of activation row b.
    int v = blockIdx.x * BLOCKWIDTH + threadIdx.x;
    blockvec[threadIdx.x] = (v < vec_height) ? vec[b * vec_height + v] : scalar_t(0);
    __syncthreads();

    if (col_ok) {
      res = 0;
      for (k = 0; k < BLOCKWIDTH; ++k){
        res += weight[k] * blockvec[k];
      }
      atomicAdd(&mul[b * width + w], res);
    }

    // Everyone must be done reading blockvec before the next batch row
    // overwrites it.
    __syncthreads();
  }
}