qwopqwop200 2023-05-02 12:00:50 +09:00 committed by GitHub
parent 144bd80436
commit 3c108d4232

@@ -39,7 +39,7 @@ __global__ void VecQuant2MatMulKernel(
     const int* __restrict__ zeros,
     const int* __restrict__ g_idx,
     int batch,
-    int vec_height,
+    int vec_height,
     int height,
     int width,
     int zero_width
@@ -54,7 +54,7 @@ __global__ void VecQuant3MatMulKernel(
     const int* __restrict__ zeros,
     const int* __restrict__ g_idx,
     int batch,
-    int vec_height,
+    int vec_height,
     int height,
     int width,
     int zero_width
@@ -69,7 +69,7 @@ __global__ void VecQuant4MatMulKernel(
     const int* __restrict__ zeros,
     const int* __restrict__ g_idx,
     int batch,
-    int vec_height,
+    int vec_height,
     int height,
     int width,
     int zero_width
@@ -84,7 +84,7 @@ __global__ void VecQuant8MatMulKernel(
     const int* __restrict__ zeros,
     const int* __restrict__ g_idx,
     int batch,
-    int vec_height,
+    int vec_height,
     int height,
     int width,
     int zero_width
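The four hunks above appear to touch only whitespace in the kernel signatures. The diff never states the tensor shapes these parameters describe; the summary below is inferred from the index arithmetic in the kernel bodies further down, and should be read as an annotation rather than part of the commit.

// Inferred shapes (annotation, not from the source):
//   vec    : [batch, vec_height]   input activations
//   mat    : [height, width]       int32 words, each packing 32/bits weights
//   mul    : [batch, width]        output, accumulated with atomicAdd
//   scales : [groups, width]       per-group, per-column dequant scales
//   zeros  : [groups, zero_width]  zero-points, bit-packed like mat
//   g_idx  : [vec_height]          maps each input element to its group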
@@ -129,7 +129,7 @@ void vecquant2matmul_cuda(
     vec.type(), "vecquant2matmul_cuda", ([&] {
       VecQuant2MatMulKernel<<<blocks, threads>>>(
         vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
-        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
+        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
         batch, vec_height, height, width, zero_width
       );
     })
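The launcher dispatches over the input dtype and hands the same grid to every kernel. The grid setup sits just above this hunk and is not shown; a minimal sketch of what it plausibly looks like, assuming the BLOCKWIDTH and BLOCKHEIGHT2 constants defined near the top of this file:

// Sketch of the launch geometry for the call above (reconstructed,
// not part of this hunk).
dim3 blocks(
  (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,  // tiles of packed rows
  (width + BLOCKWIDTH - 1) / BLOCKWIDTH        // tiles of output columns
);
dim3 threads(BLOCKWIDTH);  // one thread per output column in a tile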
@@ -152,42 +152,43 @@ __global__ void VecQuant2MatMulKernel(
 ) {
   int h = BLOCKHEIGHT2 * blockIdx.x;
   int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
   __shared__ scalar_t blockvec[BLOCKWIDTH];
   int i = width * h + w;
   int g_h = h * 16;
   int k;
   unsigned int g;
   scalar_t w_tmp;
-  int z_w = w / 16;
+  int z_w = w / 16;
   int z_mod = (w % 16) * 2;
   float weight[BLOCKWIDTH];
-  for (k = 0; k < BLOCKWIDTH; ++k){
-    int k_w = (k / 16);
+  for (k = 0; k < BLOCKWIDTH; ++k){
+    int k_w = (k / 16);
     int k_bit = (k % 16) * 2;
     g = as_int(g_idx[g_h + k]);
     scalar_t scale = scales[g * width + w];
     scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);
     w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3);
     weight[k] = scale * (w_tmp - zero);
   }
   scalar_t res;
-  for (int b = 0; b < batch; ++b){
+  for (int b = 0; b < batch; ++b){
     res = 0;
     blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
     __syncthreads();
-    for (k = 0; k < BLOCKWIDTH; ++k){
+    for (k = 0; k < BLOCKWIDTH; ++k){
       res += weight[k] * blockvec[k];
     }
     atomicAdd(&mul[b * width + w], res);
+    __syncthreads();
   }
 }
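For reference, the weight loop above unpacks 16 two-bit values from each int32 of mat, and the stored zero-point is offset by one. A minimal CPU sketch of the same arithmetic (function and parameter names are illustrative, not from the source):

#include <stdint.h>

// Dequantize the k-th 2-bit weight held in a packed 32-bit word,
// mirroring the shift/mask/offset sequence in the kernel above.
static inline float dequant2(uint32_t packed_w, uint32_t packed_z,
                             int k, int z, float scale) {
  uint32_t q    = (packed_w >> ((k % 16) * 2)) & 0x3;        // 2-bit weight
  uint32_t zero = ((packed_z >> ((z % 16) * 2)) & 0x3) + 1;  // stored zero + 1
  return scale * ((float)q - (float)zero);
}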
@@ -215,7 +216,7 @@ void vecquant3matmul_cuda(
     vec.type(), "vecquant3matmul_cuda", ([&] {
       VecQuant3MatMulKernel<<<blocks, threads>>>(
         vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
-        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
+        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
         batch, vec_height, height, width, zero_width
       );
     })
@@ -238,15 +239,15 @@ __global__ void VecQuant3MatMulKernel(
 ) {
   int h = BLOCKHEIGHT3 * blockIdx.x;
   int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
   __shared__ scalar_t blockvec[BLOCKWIDTH];
   int i = width * h + w;
   int g_h = (h / 3) * 32;
   int k;
   unsigned int g;
   scalar_t w_tmp;
-  int z_w = (w / 32) * 3;
+  int z_w = (w / 32) * 3;
   int z_mod = w % 32;
   int z_bit;
   unsigned int z_tmp;
@@ -270,14 +271,14 @@ __global__ void VecQuant3MatMulKernel(
       z_w += 1;
     }
   }
   float weight[BLOCKWIDTH];
-  for (k = 0; k < BLOCKWIDTH; ++k){
-    int k_w = (k / 32) * 3;
+  for (k = 0; k < BLOCKWIDTH; ++k){
+    int k_w = (k / 32) * 3;
     int k_mod = k % 32;
     int k_bit;
     if (k_mod != 10){
       if (k_mod != 21){
         k_bit = k_mod;
@@ -298,7 +299,7 @@ __global__ void VecQuant3MatMulKernel(
         k_w += 1;
       }
     }
     g = as_int(g_idx[g_h + k]);
     scalar_t scale = scales[g * width + w];
     scalar_t zero;
@@ -311,7 +312,7 @@ __global__ void VecQuant3MatMulKernel(
     } else {
       zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
     }
     if (k_mod == 10) {
       w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 30) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 2) & 0x4);
     } else if (k_mod == 21){
@@ -323,15 +324,16 @@ __global__ void VecQuant3MatMulKernel(
   }
   scalar_t res;
-  for (int b = 0; b < batch; ++b){
+  for (int b = 0; b < batch; ++b){
     res = 0;
     blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
     __syncthreads();
-    for (k = 0; k < BLOCKWIDTH; ++k){
+    for (k = 0; k < BLOCKWIDTH; ++k){
       res += weight[k] * blockvec[k];
     }
     atomicAdd(&mul[b * width + w], res);
+    __syncthreads();
   }
 }
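The k_mod == 10 and k_mod == 21 branches exist because 32 three-bit values occupy exactly three int32 words (96 bits), so two values straddle word boundaries: value 10 spans words 0 and 1, value 21 spans words 1 and 2. A hedged CPU sketch of that unpacking (the k == 21 case is reconstructed from the pattern of the k == 10 line shown above):

#include <stdint.h>

// Unpack the k-th of 32 consecutive 3-bit values stored in three
// 32-bit words, handling the two boundary-straddling cases.
static inline uint32_t unpack3(const uint32_t w[3], int k) {
  if (k < 10)  return (w[0] >> (3 * k)) & 0x7;            // inside word 0
  if (k == 10) return (w[0] >> 30) | ((w[1] << 2) & 0x4); // bits 30..32
  if (k < 21)  return (w[1] >> (3 * k - 32)) & 0x7;       // inside word 1
  if (k == 21) return (w[1] >> 31) | ((w[2] << 1) & 0x6); // bits 63..65
  return (w[2] >> (3 * k - 64)) & 0x7;                    // inside word 2
}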
@@ -359,7 +361,7 @@ void vecquant4matmul_cuda(
     vec.type(), "vecquant4matmul_cuda", ([&] {
      VecQuant4MatMulKernel<<<blocks, threads>>>(
         vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
-        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
+        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
         batch, vec_height, height, width, zero_width
       );
     })
@@ -382,43 +384,44 @@ __global__ void VecQuant4MatMulKernel(
 ) {
   int h = BLOCKHEIGHT4 * blockIdx.x;
   int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
   __shared__ scalar_t blockvec[BLOCKWIDTH];
   int i = width * h + w;
   int g_h = h * 8;
   int k;
   unsigned int g;
   scalar_t w_tmp;
-  int z_w = w / 8;
+  int z_w = w / 8;
   int z_mod = (w % 8) * 4;
   float weight[BLOCKWIDTH];
-  for (k = 0; k < BLOCKWIDTH; ++k){
-    int k_w = (k / 8);
+  for (k = 0; k < BLOCKWIDTH; ++k){
+    int k_w = (k / 8);
     int k_bit = (k % 8) * 4;
     g = as_int(g_idx[g_h + k]);
     scalar_t scale = scales[g * width + w];
     scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
     w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF);
     weight[k] = scale * (w_tmp - zero);
   }
   scalar_t res;
-  for (int b = 0; b < batch; ++b){
+  for (int b = 0; b < batch; ++b){
     res = 0;
     blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
     __syncthreads();
-    for (k = 0; k < BLOCKWIDTH; ++k){
+    for (k = 0; k < BLOCKWIDTH; ++k){
       res += weight[k] * blockvec[k];
     }
     atomicAdd(&mul[b * width + w], res);
+    __syncthreads();
   }
 }
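The 4-bit and 8-bit kernels share the 2-bit kernel's structure; only the packing density per int32 changes. For any bit width that divides 32 the decode generalizes as below (a sketch under that assumption; the 3-bit kernel cannot use it because its values straddle word boundaries):

#include <stdint.h>

// Generic decode for bit widths dividing 32 (2, 4, 8): 32/bits values
// per word, a mask of `bits` ones, zero-point stored minus one.
static inline float dequant_pow2(uint32_t word, uint32_t zword, int bits,
                                 int k, int z, float scale) {
  int per = 32 / bits;
  uint32_t mask = (1u << bits) - 1u;
  uint32_t q    = (word >> ((k % per) * bits)) & mask;
  uint32_t zero = ((zword >> ((z % per) * bits)) & mask) + 1u;
  return scale * ((float)q - (float)zero);
}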
@@ -446,7 +449,7 @@ void vecquant8matmul_cuda(
     vec.type(), "vecquant8matmul_cuda", ([&] {
       VecQuant8MatMulKernel<<<blocks, threads>>>(
         vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
-        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
+        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
         batch, vec_height, height, width, zero_width
       );
     })
@@ -469,41 +472,42 @@ __global__ void VecQuant8MatMulKernel(
 ) {
   int h = BLOCKHEIGHT8 * blockIdx.x;
   int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
   __shared__ scalar_t blockvec[BLOCKWIDTH];
   int i = width * h + w;
   int g_h = h * 4;
   int k;
   unsigned int g;
   scalar_t w_tmp;
-  int z_w = w / 4;
+  int z_w = w / 4;
   int z_mod = (w % 4) * 8;
   float weight[BLOCKWIDTH];
-  for (k = 0; k < BLOCKWIDTH; ++k){
-    int k_w = (k / 4);
+  for (k = 0; k < BLOCKWIDTH; ++k){
+    int k_w = (k / 4);
     int k_bit = (k % 4) * 8;
     g = as_int(g_idx[g_h + k]);
     scalar_t scale = scales[g * width + w];
     scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
     w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF);
     weight[k] = scale * (w_tmp - zero);
   }
   scalar_t res;
-  for (int b = 0; b < batch; ++b){
+  for (int b = 0; b < batch; ++b){
     res = 0;
     blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
     __syncthreads();
-    for (k = 0; k < BLOCKWIDTH; ++k){
+    for (k = 0; k < BLOCKWIDTH; ++k){
      res += weight[k] * blockvec[k];
     }
     atomicAdd(&mul[b * width + w], res);
+    __syncthreads();
   }
 }
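Across all four kernels, the substantive change appears to be a single __syncthreads() added at the end of each batch loop, after the atomicAdd: every kernel-body hunk grows by exactly one line (e.g. 42 to 43). Without that barrier, a thread that finishes its reduction early can advance to iteration b+1 and overwrite its blockvec slot while slower threads in the same block are still reading it in iteration b. Schematically (annotation mirroring the kernels above, not part of the diff):

for (int b = 0; b < batch; ++b) {
  scalar_t res = 0;
  blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
  __syncthreads();                    // all writes visible before any reads
  for (int k = 0; k < BLOCKWIDTH; ++k)
    res += weight[k] * blockvec[k];   // every thread reads all slots
  atomicAdd(&mul[b * width + w], res);
  __syncthreads();                    // added barrier: without it, a fast
                                      // thread's next-iteration write to
                                      // blockvec races slower threads' reads
}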