From f752336cdaecf68d4718c992772389a41d6a36e8 Mon Sep 17 00:00:00 2001
From: qwopqwop200
Date: Wed, 6 Sep 2023 16:39:22 +0900
Subject: [PATCH] fix bug

---
 .../cuda_256/autogptq_cuda_kernel_256.cu | 250 +++++++++---------
 .../cuda_64/autogptq_cuda_kernel_64.cu   |  36 +--
 .../exllama/cuda_func/q4_matmul.cu       |   8 +-
 .../exllama/cuda_func/q4_matrix.cu       |   4 +-
 autogptq_extension/exllama/matrix.cuh    |   8 +-
 autogptq_extension/qigen/generate.py     |   9 +-
 6 files changed, 154 insertions(+), 161 deletions(-)

diff --git a/autogptq_extension/cuda_256/autogptq_cuda_kernel_256.cu b/autogptq_extension/cuda_256/autogptq_cuda_kernel_256.cu
index 21c06d3..b356dc4 100644
--- a/autogptq_extension/cuda_256/autogptq_cuda_kernel_256.cu
+++ b/autogptq_extension/cuda_256/autogptq_cuda_kernel_256.cu
@@ -30,9 +30,9 @@
 // }
 // #endif
 
+
 #if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(USE_ROCM)
 // adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
-
 __device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
   unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
   unsigned int old = *address_as_ui;
@@ -77,7 +77,7 @@ __global__ void VecQuant2MatMulKernel(
     const int* __restrict__ zeros,
     const int* __restrict__ g_idx,
     int batch,
-    int vec_height, 
+    int vec_height,
     int height,
     int width,
     int zero_width
@@ -92,7 +92,7 @@ __global__ void VecQuant3MatMulKernel(
     const int* __restrict__ zeros,
     const int* __restrict__ g_idx,
     int batch,
-    int vec_height, 
+    int vec_height,
     int height,
     int width,
     int zero_width
@@ -113,6 +113,7 @@ __global__ void VecQuant4MatMulKernel(
     int zero_width
 );
 
+
 template <typename scalar_t>
 __global__ void VecQuant8MatMulKernel(
     const scalar_t* __restrict__ vec,
@@ -122,7 +123,7 @@ __global__ void VecQuant8MatMulKernel(
     const int* __restrict__ zeros,
     const int* __restrict__ g_idx,
     int batch,
-    int vec_height, 
+    int vec_height,
    int height,
    int width,
    int zero_width
@@ -136,7 +137,7 @@ __global__ void VecQuant2MatMulKernel_old(
     const scalar_t* __restrict__ scales,
     const int* __restrict__ zeros,
     int batch,
-    int vec_height, 
+    int vec_height,
     int height,
     int width,
     int zero_width,
@@ -151,7 +152,7 @@ __global__ void VecQuant3MatMulKernel_old(
     const scalar_t* __restrict__ scales,
     const int* __restrict__ zeros,
     int batch,
-    int vec_height, 
+    int vec_height,
     int height,
     int width,
     int zero_width,
@@ -166,7 +167,7 @@ __global__ void VecQuant4MatMulKernel_old(
     const scalar_t* __restrict__ scales,
     const int* __restrict__ zeros,
     int batch,
-    int vec_height, 
+    int vec_height,
     int height,
     int width,
     int zero_width,
@@ -181,7 +182,7 @@ __global__ void VecQuant8MatMulKernel_old(
     const scalar_t* __restrict__ scales,
     const int* __restrict__ zeros,
     int batch,
-    int vec_height, 
+    int vec_height,
     int height,
     int width,
     int zero_width,
@@ -209,7 +210,7 @@ __global__ void VecQuant3MatMulKernelFaster_old(
     const float* __restrict__ scales,
     const int* __restrict__ zeros,
     int batch,
-    int vec_height, 
+    int vec_height,
     int height,
     int width,
     int zero_width,
@@ -223,7 +224,7 @@ __global__ void VecQuant4MatMulKernelFaster_old(
     const float* __restrict__ scales,
     const int* __restrict__ zeros,
     int batch,
-    int vec_height, 
+    int vec_height,
     int height,
     int width,
     int zero_width,
@@ -270,7 +271,7 @@ void vecquant2matmul_cuda(
     vec.type(), "vecquant2matmul_cuda", ([&] {
       VecQuant2MatMulKernel<<<blocks, threads>>>(
         vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
-        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(), 
+        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
         batch, vec_height, height, width, zero_width
       );
     })
@@ -293,39 +294,39 @@ __global__ void VecQuant2MatMulKernel(
 ) {
   int h = BLOCKHEIGHT2 * blockIdx.x;
   int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
-  
+
   __shared__ scalar_t blockvec[BLOCKWIDTH];
   int i = width * h + w;
   int g_h = h * 16;
   int k;
   unsigned int g;
   scalar_t w_tmp;
-  
-  int z_w = w / 16; 
+
+  int z_w = w / 16;
   int z_mod = (w % 16) * 2;
-  
+
   float weight[BLOCKWIDTH];
-  
-  for (k = 0; k < BLOCKWIDTH; ++k){ 
-    int k_w = (k / 16); 
+
+  for (k = 0; k < BLOCKWIDTH; ++k){
+    int k_w = (k / 16);
     int k_bit = (k % 16) * 2;
-    
+
     g = as_int(g_idx[g_h + k]);
     scalar_t scale = scales[g * width + w];
-    scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);
-    
+    scalar_t zero = scalar_t(as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3);
+
     w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3);
-    
+
     weight[k] = scale * (w_tmp - zero);
   }
 
   scalar_t res;
-  for (int b = 0; b < batch; ++b){ 
+  for (int b = 0; b < batch; ++b){
     res = 0;
-    
+
     blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
     __syncthreads();
-    for (k = 0; k < BLOCKWIDTH; ++k){ 
+    for (k = 0; k < BLOCKWIDTH; ++k){
       res += weight[k] * blockvec[k];
     }
     atomicAdd(&mul[b * width + w], res);
@@ -357,7 +358,7 @@ void vecquant3matmul_cuda(
     vec.type(), "vecquant3matmul_cuda", ([&] {
      VecQuant3MatMulKernel<<<blocks, threads>>>(
         vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
-        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(), 
+        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
         batch, vec_height, height, width, zero_width
       );
     })
@@ -380,15 +381,15 @@ __global__ void VecQuant3MatMulKernel(
 ) {
   int h = BLOCKHEIGHT3 * blockIdx.x;
   int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
-  
+
   __shared__ scalar_t blockvec[BLOCKWIDTH];
   int i = width * h + w;
   int g_h = (h / 3) * 32;
   int k;
   unsigned int g;
   scalar_t w_tmp;
-  
-  int z_w = (w / 32) * 3; 
+
+  int z_w = (w / 32) * 3;
   int z_mod = w % 32;
   int z_bit;
   unsigned int z_tmp;
@@ -412,14 +413,14 @@ __global__ void VecQuant3MatMulKernel(
       z_w += 1;
     }
   }
-  
+
   float weight[BLOCKWIDTH];
-  
-  for (k = 0; k < BLOCKWIDTH; ++k){ 
-    int k_w = (k / 32) * 3;
+
+  for (k = 0; k < BLOCKWIDTH; ++k){
+    int k_w = (k / 32) * 3;
     int k_mod = k % 32;
     int k_bit;
-    
+
     if (k_mod != 10){
       if (k_mod != 21){
         k_bit = k_mod;
@@ -440,20 +441,20 @@ __global__ void VecQuant3MatMulKernel(
         k_w += 1;
       }
     }
-    
+
     g = as_int(g_idx[g_h + k]);
     scalar_t scale = scales[g * width + w];
     scalar_t zero;
     if (z_mod == 10) {
       z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
-      zero = scalar_t((z_tmp) + 1);
+      zero = scalar_t(z_tmp);
     } else if (z_mod == 21){
       z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
-      zero = scalar_t((z_tmp) + 1);
+      zero = scalar_t(z_tmp);
     } else {
-      zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
+      zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7);
     }
-    
+
     if (k_mod == 10) {
       w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 30) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 2) & 0x4);
     } else if (k_mod == 21){
@@ -465,12 +466,12 @@ __global__ void VecQuant3MatMulKernel(
   }
 
   scalar_t res;
-  for (int b = 0; b < batch; ++b){ 
+  for (int b = 0; b < batch; ++b){
     res = 0;
-    
+
     blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
     __syncthreads();
-    for (k = 0; k < BLOCKWIDTH; ++k){ 
+    for (k = 0; k < BLOCKWIDTH; ++k){
       res += weight[k] * blockvec[k];
     }
     atomicAdd(&mul[b * width + w], res);
@@ -502,7 +503,7 @@ void vecquant4matmul_cuda(
     vec.type(), "vecquant4matmul_cuda", ([&] {
       VecQuant4MatMulKernel<<<blocks, threads>>>(
         vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
-        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(), 
+        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
         batch, vec_height, height, width, zero_width
       );
     })
@@ -525,40 +526,40 @@ __global__ void VecQuant4MatMulKernel(
 ) {
   int h = BLOCKHEIGHT4 * blockIdx.x;
   int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
-  
+
   __shared__ scalar_t blockvec[BLOCKWIDTH];
   int i = width * h + w;
   int g_h = h * 8;
   int k;
   unsigned int g;
   scalar_t w_tmp;
+
-  
-  int z_w = w / 8; 
+  int z_w = w / 8;
   int z_mod = (w % 8) * 4;
-  
+
   float weight[BLOCKWIDTH];
-  
-  for (k = 0; k < BLOCKWIDTH; ++k){ 
-    int k_w = (k / 8); 
+
+  for (k = 0; k < BLOCKWIDTH; ++k){
+    int k_w = (k / 8);
     int k_bit = (k % 8) * 4;
-    
+
     g = as_int(g_idx[g_h + k]);
     scalar_t scale = scales[g * width + w];
-    scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
-    
+    scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF);
+
     w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF);
-    
+
     weight[k] = scale * (w_tmp - zero);
   }
 
   scalar_t res;
-  for (int b = 0; b < batch; ++b){ 
+  for (int b = 0; b < batch; ++b){
     res = 0;
-    
+
     blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
     __syncthreads();
-    for (k = 0; k < BLOCKWIDTH; ++k){ 
+    for (k = 0; k < BLOCKWIDTH; ++k){
       res += weight[k] * blockvec[k];
     }
     atomicAdd(&mul[b * width + w], res);
@@ -590,7 +591,7 @@ void vecquant8matmul_cuda(
     vec.type(), "vecquant8matmul_cuda", ([&] {
       VecQuant8MatMulKernel<<<blocks, threads>>>(
         vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
-        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(), 
+        scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
         batch, vec_height, height, width, zero_width
       );
     })
@@ -613,39 +614,39 @@ __global__ void VecQuant8MatMulKernel(
 ) {
   int h = BLOCKHEIGHT8 * blockIdx.x;
   int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
-  
+
   __shared__ scalar_t blockvec[BLOCKWIDTH];
   int i = width * h + w;
   int g_h = h * 4;
   int k;
   unsigned int g;
   scalar_t w_tmp;
-  
-  int z_w = w / 4; 
+
+  int z_w = w / 4;
   int z_mod = (w % 4) * 8;
-  
+
   float weight[BLOCKWIDTH];
-  
-  for (k = 0; k < BLOCKWIDTH; ++k){ 
-    int k_w = (k / 4); 
+
+  for (k = 0; k < BLOCKWIDTH; ++k){
+    int k_w = (k / 4);
    int k_bit = (k % 4) * 8;
-    
+
     g = as_int(g_idx[g_h + k]);
     scalar_t scale = scales[g * width + w];
-    scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
-    
+    scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF);
+
     w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF);
-    
+
     weight[k] = scale * (w_tmp - zero);
   }
 
   scalar_t res;
-  for (int b = 0; b < batch; ++b){ 
+  for (int b = 0; b < batch; ++b){
     res = 0;
-    
+
     blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
     __syncthreads();
-    for (k = 0; k < BLOCKWIDTH; ++k){ 
+    for (k = 0; k < BLOCKWIDTH; ++k){
       res += weight[k] * blockvec[k];
     }
     atomicAdd(&mul[b * width + w], res);
@@ -712,19 +713,19 @@ __global__ void VecQuant2MatMulKernel_old(
   int i = width * h + w;
   int g_h = h * 16;
   int k = 0;
-  
-  int z_w = w / 16; 
+
+  int z_w = w / 16;
   int z_mod = (w % 16) * 2;
 
   unsigned int tmp;
 
   while (k < BLOCKWIDTH) {
     tmp = as_unsigned(mat[i]);
-    
+
     int g = (g_h + k) / groupsize;
     scalar_t scale = scales[g * width + w];
-    scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);
-    
+    scalar_t zero = scale * scalar_t(as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3);
+
     res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0];
     res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1];
     res += (scale * scalar_t((tmp >> 4) & 0x3) - zero) * blockvec[k + 2];
@@ -741,7 +742,7 @@ __global__ void VecQuant2MatMulKernel_old(
     res += (scale * scalar_t((tmp >> 26) & 0x3) - zero) * blockvec[k + 13];
     res += (scale * scalar_t((tmp >> 28) & 0x3) - zero) * blockvec[k + 14];
     res += (scale * scalar_t((tmp >> 30) & 0x3) - zero) * blockvec[k + 15];
-    
+
     i += width;
     k += 16;
   }
@@ -807,11 +808,11 @@ __global__ void VecQuant3MatMulKernel_old(
   int i = width * h + w;
   int g_h = (h / 3) * 32;
   int k = 0;
-  
-  int z_w = (w / 32) * 3; 
+
+  int z_w = (w / 32) * 3;
   int z_mod = w % 32;
   int z_bit;
-  
+
   if (z_mod != 10){
     if (z_mod != 21){
       z_bit = z_mod;
@@ -832,7 +833,7 @@ __global__ void VecQuant3MatMulKernel_old(
       z_w += 1;
     }
   }
-  
+
   unsigned int tmp1;
   unsigned int tmp2;
   unsigned int tmp;
@@ -840,20 +841,20 @@ __global__ void VecQuant3MatMulKernel_old(
 
   while (k < BLOCKWIDTH) {
     tmp1 = as_unsigned(mat[i]);
-    
+
     int g = (g_h + k) / groupsize;
     scalar_t scale = scales[g * width + w];
     scalar_t zero;
     if (z_mod == 10) {
       z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
-      zero = scale * scalar_t((z_tmp) + 1);
+      zero = scale * scalar_t(z_tmp);
     } else if (z_mod == 21){
       z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
-      zero = scale * scalar_t((z_tmp) + 1);
+      zero = scale * scalar_t(z_tmp);
     } else {
-      zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
+      zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7);
     }
-    
+
     res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
     res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
     res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
@@ -864,14 +865,14 @@ __global__ void VecQuant3MatMulKernel_old(
     res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
     res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
     res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
-    
+
     i += width;
     tmp2 = as_unsigned(mat[i]);
     tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4);
     tmp2 >>= 1;
     res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
     k += 11;
-    
+
     res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0];
     res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1];
     res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2];
@@ -882,14 +883,14 @@ __global__ void VecQuant3MatMulKernel_old(
     res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7];
     res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8];
     res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9];
-    
+
     i += width;
     tmp1 = as_unsigned(mat[i]);
     tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6);
     tmp1 >>= 2;
     res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
     k += 11;
-    
+
     res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
     res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
     res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
@@ -900,7 +901,7 @@ __global__ void VecQuant3MatMulKernel_old(
     res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
     res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
     res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
-    
+
     i += width;
     k += 10;
   }
@@ -967,18 +968,18 @@ __global__ void VecQuant4MatMulKernel_old(
   int g_h = h * 8;
   int k = 0;
 
-  int z_w = w / 8; 
+  int z_w = w / 8;
   int z_mod = (w % 8) * 4;
 
   unsigned int tmp;
 
   while (k < BLOCKWIDTH) {
     tmp = as_unsigned(mat[i]);
-    
+
     int g = (g_h + k) / groupsize;
     scalar_t scale = scales[g * width + w];
-    scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
-    
+    scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF);
+
     res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0];
     res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1];
     res += (scale * scalar_t((tmp >> 8) & 0xF) - zero) * blockvec[k + 2];
@@ -987,7 +988,7 @@ __global__ void VecQuant4MatMulKernel_old(
     res += (scale * scalar_t((tmp >> 20) & 0xF) - zero) * blockvec[k + 5];
     res += (scale * scalar_t((tmp >> 24) & 0xF) - zero) * blockvec[k + 6];
     res += (scale * scalar_t((tmp >> 28) & 0xF) - zero) * blockvec[k + 7];
-    
+
     i += width;
     k += 8;
   }
@@ -1053,24 +1054,24 @@ __global__ void VecQuant8MatMulKernel_old(
   int i = width * h + w;
   int g_h = h * 4;
   int k = 0;
-  
-  int z_w = w / 4; 
+
+  int z_w = w / 4;
   int z_mod = (w % 4) * 8;
 
   unsigned int tmp;
 
-  while (k < BLOCKWIDTH) { 
+  while (k < BLOCKWIDTH) {
     tmp = as_unsigned(mat[i]);
-    
+
     int g = (g_h + k) / groupsize;
     scalar_t scale = scales[g * width + w];
-    scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
-    
+    scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF);
+
     res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0];
     res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1];
     res += (scale * scalar_t((tmp >> 16) & 0xFF) - zero) * blockvec[k + 2];
     res += (scale * scalar_t((tmp >> 24) & 0xFF) - zero) * blockvec[k + 3];
-    
+
     i += width;
     k += 4;
   }
@@ -1092,7 +1093,7 @@ void vecquant2matmul_faster_cuda_old(
   int height = mat.size(0);
   int width = mat.size(1);
   int zero_width = zeros.size(1);
-  
+
   dim3 blocks(
     (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
     (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
@@ -1144,8 +1145,8 @@ __global__ void VecQuant2MatMulKernelFaster_old(
   int i = width * h + w;
   int g_h = h * 16;
   int k = 0;
-  
-  int z_w = w / 16; 
+
+  int z_w = w / 16;
   int z_mod = (w % 16) * 2;
 
   float res = 0;
@@ -1159,8 +1160,8 @@ __global__ void VecQuant2MatMulKernelFaster_old(
     int g = (g_h + (k * 2)) / groupsize;
     float scale_f = scales[g * width + w];
     half2 scale = __float2half2_rn(scale_f);
-    half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1)));
-    
+    half2 zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3)));
+
     std::memset(&res2, 0, sizeof(half2));
     tmp = as_unsigned(mat[i]);
     res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xf][off], scale, zero), blockvec[k + 0], res2);
@@ -1192,7 +1193,7 @@ void vecquant3matmul_faster_cuda_old(
   int height = mat.size(0);
   int width = mat.size(1);
   int zero_width = zeros.size(1);
-  
+
   dim3 blocks(
     (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
     (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
@@ -1244,11 +1245,11 @@ __global__ void VecQuant3MatMulKernelFaster_old(
   int i = width * h + w;
   int g_h = (h / 3) * 32;
   int k = 0;
-  
+
   int z_w = (w / 32) * 3;
   int z_mod = w % 32;
   int z_bit;
-  
+
   if (z_mod != 10){
     if (z_mod != 21){
       z_bit = z_mod;
@@ -1287,14 +1288,14 @@ __global__ void VecQuant3MatMulKernelFaster_old(
     half2 zero;
     if (z_mod == 10) {
       z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
-      zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1)));
+      zero = __float2half2_rn(-(scale_f * z_tmp));
     } else if (z_mod == 21){
       z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
-      zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1)));
+      zero = __float2half2_rn(-(scale_f * z_tmp));
     } else {
-      zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1)));
+      zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7)));
     }
-    
+
     std::memset(&res2, 0, sizeof(half2));
     tmp1 = as_unsigned(mat[i]);
     res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
@@ -1344,7 +1345,7 @@ void vecquant4matmul_faster_cuda_old(
   int height = mat.size(0);
   int width = mat.size(1);
   int zero_width = zeros.size(1);
-  
+
   dim3 blocks(
     (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
     (width + BLOCKWIDTH - 1) / BLOCKWIDTH,
@@ -1397,7 +1398,7 @@ __global__ void VecQuant4MatMulKernelFaster_old(
   int g_h = h * 8;
   int k = 0;
 
-  int z_w = w / 8; 
+  int z_w = w / 8;
   int z_mod = (w % 8) * 4;
 
   float res = 0;
@@ -1410,14 +1411,9 @@ __global__ void VecQuant4MatMulKernelFaster_old(
   while (k < blockwidth2) {
     int g = (g_h + (k * 2)) / groupsize;
     float scale_f = scales[g * width + w];
-    
     half2 scale = __float2half2_rn(scale_f);
-    half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1)));
-    
-    //std::memset(&res2, 0, sizeof(half2));
-    
-    //res2 = __float2half2_rn((float)0.);
-    
+    half2 zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF)));
+
     std::memset(&res2, 0, sizeof(half2));
     tmp = as_unsigned(mat[i]);
     res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scale, zero), blockvec[k + 0], res2);
@@ -1426,10 +1422,8 @@ __global__ void VecQuant4MatMulKernelFaster_old(
     res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scale, zero), blockvec[k + 3], res2);
     i += width;
     k += 4;
-    
     res += __low2float(res2) + __high2float(res2);
-    
   }
 
   atomicAdd(&mul[b * width + w], res);
-}
+}
\ No newline at end of file
diff --git a/autogptq_extension/cuda_64/autogptq_cuda_kernel_64.cu b/autogptq_extension/cuda_64/autogptq_cuda_kernel_64.cu
index d0ddc7c..ba232bf 100644
--- a/autogptq_extension/cuda_64/autogptq_cuda_kernel_64.cu
+++ b/autogptq_extension/cuda_64/autogptq_cuda_kernel_64.cu
@@ -313,7 +313,7 @@ __global__ void VecQuant2MatMulKernel(
 
     g = as_int(g_idx[g_h + k]);
     scalar_t scale = scales[g * width + w];
-    scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);
+    scalar_t zero = scalar_t(as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3);
 
     w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3);
 
@@ -447,12 +447,12 @@ __global__ void VecQuant3MatMulKernel(
     scalar_t zero;
     if (z_mod == 10) {
       z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
-      zero = scalar_t((z_tmp) + 1);
+      zero = scalar_t(z_tmp);
     } else if (z_mod == 21){
       z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
-      zero = scalar_t((z_tmp) + 1);
+      zero = scalar_t(z_tmp);
     } else {
-      zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
+      zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7);
     }
 
     if (k_mod == 10) {
@@ -546,7 +546,7 @@ __global__ void VecQuant4MatMulKernel(
 
     g = as_int(g_idx[g_h + k]);
     scalar_t scale = scales[g * width + w];
-    scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
+    scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF);
 
     w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF);
 
@@ -633,7 +633,7 @@ __global__ void VecQuant8MatMulKernel(
 
     g = as_int(g_idx[g_h + k]);
     scalar_t scale = scales[g * width + w];
-    scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
+    scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF);
 
     w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF);
 
@@ -724,7 +724,7 @@ __global__ void VecQuant2MatMulKernel_old(
 
     int g = (g_h + k) / groupsize;
     scalar_t scale = scales[g * width + w];
-    scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);
+    scalar_t zero = scale * scalar_t(as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3);
 
     res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0];
     res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1];
@@ -847,12 +847,12 @@ __global__ void VecQuant3MatMulKernel_old(
     scalar_t zero;
     if (z_mod == 10) {
       z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
-      zero = scale * scalar_t((z_tmp) + 1);
+      zero = scale * scalar_t(z_tmp);
     } else if (z_mod == 21){
       z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
-      zero = scale * scalar_t((z_tmp) + 1);
+      zero = scale * scalar_t(z_tmp);
     } else {
-      zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
+      zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7);
     }
 
     res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
@@ -978,7 +978,7 @@ __global__ void VecQuant4MatMulKernel_old(
 
     int g = (g_h + k) / groupsize;
     scalar_t scale = scales[g * width + w];
-    scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
+    scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF);
 
     res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0];
     res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1];
@@ -1065,7 +1065,7 @@ __global__ void VecQuant8MatMulKernel_old(
 
     int g = (g_h + k) / groupsize;
     scalar_t scale = scales[g * width + w];
-    scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
+    scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF);
 
     res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0];
     res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1];
@@ -1160,7 +1160,7 @@ __global__ void VecQuant2MatMulKernelFaster_old(
     int g = (g_h + (k * 2)) / groupsize;
     float scale_f = scales[g * width + w];
     half2 scale = __float2half2_rn(scale_f);
-    half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1)));
+    half2 zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3)));
 
     std::memset(&res2, 0, sizeof(half2));
     tmp = as_unsigned(mat[i]);
@@ -1288,12 +1288,12 @@ __global__ void VecQuant3MatMulKernelFaster_old(
     half2 zero;
     if (z_mod == 10) {
       z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
-      zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1)));
+      zero = __float2half2_rn(-(scale_f * z_tmp));
     } else if (z_mod == 21){
       z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
-      zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1)));
+      zero = __float2half2_rn(-(scale_f * z_tmp));
     } else {
-      zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1)));
+      zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7)));
     }
 
     std::memset(&res2, 0, sizeof(half2));
@@ -1412,7 +1412,7 @@ __global__ void VecQuant4MatMulKernelFaster_old(
     int g = (g_h + (k * 2)) / groupsize;
     float scale_f = scales[g * width + w];
     half2 scale = __float2half2_rn(scale_f);
-    half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1)));
+    half2 zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF)));
 
     std::memset(&res2, 0, sizeof(half2));
     tmp = as_unsigned(mat[i]);
@@ -1426,4 +1426,4 @@ __global__ void VecQuant4MatMulKernelFaster_old(
   }
 
   atomicAdd(&mul[b * width + w], res);
-}
+}
\ No newline at end of file
diff --git a/autogptq_extension/exllama/cuda_func/q4_matmul.cu b/autogptq_extension/exllama/cuda_func/q4_matmul.cu
index 0ee6e16..18ee972 100644
--- a/autogptq_extension/exllama/cuda_func/q4_matmul.cu
+++ b/autogptq_extension/exllama/cuda_func/q4_matmul.cu
@@ -87,7 +87,7 @@ __global__ void q4_matmul_kernel
         if constexpr (use_half2)
         {
             half2 w_scale = w_scales_.item_half2half2(group, w_column);
-            uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+            uint32_t w_zero = w_zeros_.item(group, w_column);
 
             if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
             else                     acc = dot_product_8      (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
@@ -95,7 +95,7 @@ __global__ void q4_matmul_kernel
         else
         {
             half w_scale = w_scales_.item(group, w_column);
-            uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+            uint32_t w_zero = w_zeros_.item(group, w_column);
 
             if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
             else                     acc_h = dot_product_8_h      (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
@@ -112,7 +112,7 @@ __global__ void q4_matmul_kernel
         {
             int group = k / groupsize;
             half2 w_scale = w_scales_.item_half2half2(group, w_column);
-            uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+            uint32_t w_zero = w_zeros_.item(group, w_column);
 
             if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
             else                     acc = dot_product_8      (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
@@ -121,7 +121,7 @@ __global__ void q4_matmul_kernel
         {
             int group = k / groupsize;
             half w_scale = w_scales_.item(group, w_column);
-            uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
+            uint32_t w_zero = w_zeros_.item(group, w_column);
 
             if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
             else                     acc_h = dot_product_8_h      (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
diff --git a/autogptq_extension/exllama/cuda_func/q4_matrix.cu b/autogptq_extension/exllama/cuda_func/q4_matrix.cu
index 2b3600e..a3774f2 100644
--- a/autogptq_extension/exllama/cuda_func/q4_matrix.cu
+++ b/autogptq_extension/exllama/cuda_func/q4_matrix.cu
@@ -197,7 +197,7 @@ __global__ void reconstruct_kernel
     int group = row / groupsize;
 
     half w_scale = w_scales_.item(group, column);
-    uint32_t w_zero = w_zeros_.item(group, column) + 1;
+    uint32_t w_zero = w_zeros_.item(group, column);
 
     uint32_t w_read = w_.item_uint32_t(row, column);
     half* out_ptr = out_.item_ptr(row, column);
@@ -222,4 +222,4 @@ void Q4Matrix::reconstruct(half* out)
     );
 
     reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
-}
\ No newline at end of file
+}
diff --git a/autogptq_extension/exllama/matrix.cuh b/autogptq_extension/exllama/matrix.cuh
index 2fd5ab0..37eb4b6 100644
--- a/autogptq_extension/exllama/matrix.cuh
+++ b/autogptq_extension/exllama/matrix.cuh
@@ -94,7 +94,7 @@ __device__ __forceinline__ half2 dot_product_8
     const int v_row,        // divisible by 8
     const int v_column,
     const half2 v_scale_2,
-    const uint32_t v_zero,  // + 1 (!!)
+    const uint32_t v_zero,
     const int count
 )
 
@@ -145,7 +145,7 @@ __device__ __forceinline__ half dot_product_8_h
     const int v_row,        // divisible by 8
     const int v_column,
     const half v_scale,
-    const uint32_t v_zero,  // + 1 (!!)
+    const uint32_t v_zero,
     const int count
 )
 
@@ -192,7 +192,7 @@ __device__ __forceinline__ half2 dot_product_8_x_map
     const int v_row,        // divisible by 8
     const int v_column,
     const half2 v_scale_2,
-    const uint32_t v_zero,  // + 1 (!!)
+    const uint32_t v_zero,
     const int count,
     const uint32_t* x_map
 )
@@ -254,7 +254,7 @@ __device__ __forceinline__ half dot_product_8_x_map_h
     const int v_row,        // divisible by 8
     const int v_column,
     const half v_scale,
-    const uint32_t v_zero,  // + 1 (!!)
+    const uint32_t v_zero,
     const int count,
     const uint32_t* x_map
 )
diff --git a/autogptq_extension/qigen/generate.py b/autogptq_extension/qigen/generate.py
index 724c891..ac7a8cd 100644
--- a/autogptq_extension/qigen/generate.py
+++ b/autogptq_extension/qigen/generate.py
@@ -1162,7 +1162,7 @@ def unpack_zeros(bits):
     res += f"void unpack_zeros{bits}_cpu(const int* zv, float* ov, int n, int m)"
     packed = 32//bits
    mask = (2**bits)-1
-    res += "{\nconst __m256i ones = _mm256_set1_epi32(1);\n"
+    res += "{\n"
     res += f"const __m256i mask = _mm256_set1_epi32({mask});\n"
     if bits == 4:
         res += "const __m256i shift = _mm256_set_epi32(28,24,20,16,12,8,4,0);\n"
@@ -1179,15 +1179,14 @@ def unpack_zeros(bits):
         res += "__m256i z = _mm256_set1_epi32(zv[i*m/8 + j/8]);\n"
         res += "__m256i z0 = _mm256_srlv_epi32(z, shift);\n"
         res += "__m256i z1 = _mm256_and_si256(z0, mask);\n"
-        res += "__m256i z2 = _mm256_add_epi32(z1, ones);\n"
-        res += "__m256 z3 = _mm256_cvtepi32_ps(z2);\n"
-        res += "_mm256_storeu_ps(&ov[i*m +j], z3);\n"
+        res += "__m256 z2 = _mm256_cvtepi32_ps(z1);\n"
+        res += "_mm256_storeu_ps(&ov[i*m +j], z2);\n"
     elif bits == 2:
         res += f"for (int j = 0; j < m; j+={packed})"
         res += "{\n"
         res += f"for (int k = 0; k < {packed}; k++)"
         res += "{\n"
-        res += f"ov[i*m + j+k] = (((zv[j/{packed}] >> ({bits}*k)) & {mask})+1);\n"
+        res += f"ov[i*m + j+k] = ((zv[j/{packed}] >> ({bits}*k)) & {mask});\n"
         res += "}\n"
         # res += "for(int j = 0; j < m; j+=16){\n"
         # res += "__m256i z = _mm256_set1_epi32(zv[i*m/16 + j/16]);\n"