From 9e0682a63e37ac46e9e1fa138955cfa09f62aa94 Mon Sep 17 00:00:00 2001 From: qwopqwop200 Date: Thu, 7 Sep 2023 12:54:46 +0900 Subject: [PATCH] Optimize q4_matmul https://github.com/turboderp/exllama/pull/275 --- .../exllama/cuda_func/q4_matmul.cu | 60 ++++++--- autogptq_extension/exllama/matrix.cuh | 121 +----------------- 2 files changed, 41 insertions(+), 140 deletions(-) diff --git a/autogptq_extension/exllama/cuda_func/q4_matmul.cu b/autogptq_extension/exllama/cuda_func/q4_matmul.cu index 18ee972..7e4d6af 100644 --- a/autogptq_extension/exllama/cuda_func/q4_matmul.cu +++ b/autogptq_extension/exllama/cuda_func/q4_matmul.cu @@ -1,4 +1,4 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama +// Adapted from turboderp exllama: https://github.com/turboderp/exllama #include "q4_matmul.cuh" #include "column_remap.cuh" @@ -13,6 +13,8 @@ const int THREADS_X = 32; // Block size and thread count along columns in w and out const int THREADS_Y = 1; // Block size and thread count along rows in x and out +const int GROUP_STEP = 32; // Assumed group size when block_size_z % groupsize != 0 + typedef void (*fp_q4_matmul_kernel) ( const half*, @@ -46,12 +48,15 @@ __global__ void q4_matmul_kernel bool no_zero ) { + extern __shared__ half2 x_cache[]; + half* x_cache_h = (half*)x_cache; + // Start of block int x_column = block_size_z * blockIdx.z; int x_column_end = min(dim, block_size_z * (blockIdx.z + 1)); - int w_column = THREADS_X * blockIdx.x + threadIdx.x; + int w_column = THREADS_X * blockIdx.x + threadIdx.x; // assume width of weight matrix divisible by THREADS_X int x_row = THREADS_Y * blockIdx.y + threadIdx.y; int iterations = (x_column_end - x_column) / 8; @@ -69,8 +74,8 @@ __global__ void q4_matmul_kernel if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0) { *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0; - __syncthreads(); } + __syncthreads(); // Loop over part of x row (and w column) @@ -84,48 +89,56 @@ __global__ void q4_matmul_kernel for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize) { + for (int i = threadIdx.x; i < groupsize; i += THREADS_X) + { + if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]); + else x_cache_h[i] = *x_.item_ptr(x_row, k + i); + } + __syncthreads(); + if constexpr (use_half2) { half2 w_scale = w_scales_.item_half2half2(group, w_column); uint32_t w_zero = w_zeros_.item(group, w_column); - - if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); - else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); + acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, groupsize / 8); } else { half w_scale = w_scales_.item(group, w_column); uint32_t w_zero = w_zeros_.item(group, w_column); - - if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); - else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); + acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, groupsize / 8); } + __syncthreads(); } } else { - // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache + // Otherwise assume groupsize is a multiple of GROUP_STEP, do GROUP_STEP columns per iteration and trust the cache - for (int k = x_column; k < x_column + iterations * 8; k += 8) + for (int k = x_column; k < x_column + iterations * 8; k += GROUP_STEP) { + for (int i = threadIdx.x; i < GROUP_STEP; i += THREADS_X) + { + if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]); + else x_cache_h[i] = *x_.item_ptr(x_row, k + i); + } + __syncthreads(); + if constexpr (use_half2) { int group = k / groupsize; half2 w_scale = w_scales_.item_half2half2(group, w_column); uint32_t w_zero = w_zeros_.item(group, w_column); - - if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); - else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); + acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8); } else { int group = k / groupsize; half w_scale = w_scales_.item(group, w_column); uint32_t w_zero = w_zeros_.item(group, w_column); - - if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); - else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); + acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8); } + __syncthreads(); } } @@ -133,7 +146,7 @@ __global__ void q4_matmul_kernel if constexpr (use_half2) { - half result = __hadd(__low2half(acc), __high2half(acc)); + half result = __hadd(acc.x, acc.y); atomicAdd(out_.item_ptr(x_row, w_column), result); } else @@ -215,8 +228,8 @@ void q4_matmul_cuda ); fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map); - - kernel<<>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); + int shared_mem = (block_size_z % w->groupsize == 0 ? w->groupsize : GROUP_STEP) * sizeof(half); + kernel<<>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); } void q4_matmul_recons_cuda @@ -240,7 +253,7 @@ void q4_matmul_recons_cuda const half* x_mapped = x; if (w->cuda_x_map) { - TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "The temp_state buffer is too small in the exllama backend. Please call the exllama_set_max_input_length function to increase the buffer size. Example:\nfrom auto_gptq import exllama_set_max_input_length\nmodel = exllama_set_max_input_length(model, 4096)"); + TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small"); column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); x_mapped = buffers->temp_state; } @@ -248,13 +261,18 @@ void q4_matmul_recons_cuda w->reconstruct(buffers->temp_dq); #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700 + const float alpha = 1.0f; const float beta = no_zero ? 1.0f : 0.0f; cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width, x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width); + #else + const half alpha = __float2half(1.0f); const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f); cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width); + #endif + } diff --git a/autogptq_extension/exllama/matrix.cuh b/autogptq_extension/exllama/matrix.cuh index 37eb4b6..e5efd76 100644 --- a/autogptq_extension/exllama/matrix.cuh +++ b/autogptq_extension/exllama/matrix.cuh @@ -87,9 +87,7 @@ public: __device__ __forceinline__ half2 dot_product_8 ( const half2 acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 + const half2* h_ptr, MatrixView_q4_column& v_, const int v_row, // divisible by 8 const int v_column, @@ -98,7 +96,6 @@ __device__ __forceinline__ half2 dot_product_8 const int count ) { - const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column); const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); half2 result = acc; @@ -138,9 +135,7 @@ __device__ __forceinline__ half2 dot_product_8 __device__ __forceinline__ half dot_product_8_h ( const half acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 + const half* h_ptr, MatrixView_q4_column& v_, const int v_row, // divisible by 8 const int v_column, @@ -149,7 +144,6 @@ __device__ __forceinline__ half dot_product_8_h const int count ) { - const half* h_ptr = h_.item_ptr(h_row, h_column); const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); half result = acc; @@ -180,115 +174,4 @@ __device__ __forceinline__ half dot_product_8_h return result; } -// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map - -__device__ __forceinline__ half2 dot_product_8_x_map -( - const half2 acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half2 v_scale_2, - const uint32_t v_zero, - const int count, - const uint32_t* x_map -) -{ - const half* h_ptr = h_.item_ptr(h_row, 0); - const uint32_t* x_map_ptr = x_map + h_column; - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half2 result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half2 v_01 = __halves2half2(v_0, v_1); - half2 v_23 = __halves2half2(v_2, v_3); - half2 v_45 = __halves2half2(v_4, v_5); - half2 v_67 = __halves2half2(v_6, v_7); - - half h_0 = h_ptr[*x_map_ptr++]; - half h_1 = h_ptr[*x_map_ptr++]; - half h_2 = h_ptr[*x_map_ptr++]; - half h_3 = h_ptr[*x_map_ptr++]; - half h_4 = h_ptr[*x_map_ptr++]; - half h_5 = h_ptr[*x_map_ptr++]; - half h_6 = h_ptr[*x_map_ptr++]; - half h_7 = h_ptr[*x_map_ptr++]; - - half2 h_01 = __halves2half2(h_0, h_1); - half2 h_23 = __halves2half2(h_2, h_3); - half2 h_45 = __halves2half2(h_4, h_5); - half2 h_67 = __halves2half2(h_6, h_7); - - half2 tmp = __hmul2(h_01, v_01); - tmp = __hfma2(h_23, v_23, tmp); - tmp = __hfma2(h_45, v_45, tmp); - tmp = __hfma2(h_67, v_67, tmp); - result = __hfma2(v_scale_2, tmp, result); - } - - return result; -} - -__device__ __forceinline__ half dot_product_8_x_map_h -( - const half acc, - MatrixView_half& h_, - const int h_row, - const int h_column, // divisible by 8 - MatrixView_q4_column& v_, - const int v_row, // divisible by 8 - const int v_column, - const half v_scale, - const uint32_t v_zero, - const int count, - const uint32_t* x_map -) -{ - const half* h_ptr = h_.item_ptr(h_row, 0); - const uint32_t* x_map_ptr = x_map + h_column; - const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); - half result = acc; - - for (int i = 0; i < count; i++) - { - uint32_t v_read = *v_ptr; v_ptr += v_.width; - - half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); - half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); - half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); - half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); - half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); - half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); - half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); - half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); - - half tmp = __hmul(h_ptr[*x_map_ptr++], v_0); - tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp); - tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp); - result = __hfma(v_scale, tmp, result); - } - - return result; -} - #endif