parent 6b1ceb1897
commit 9e0682a63e
2 changed files with 41 additions and 140 deletions
@@ -1,4 +1,4 @@
 // Adapted from turboderp exllama: https://github.com/turboderp/exllama

 #include "q4_matmul.cuh"
 #include "column_remap.cuh"
@@ -13,6 +13,8 @@
 const int THREADS_X = 32; // Block size and thread count along columns in w and out
 const int THREADS_Y = 1; // Block size and thread count along rows in x and out

+const int GROUP_STEP = 32; // Assumed group size when block_size_z % groupsize != 0
+
 typedef void (*fp_q4_matmul_kernel)
 (
     const half*,
@@ -46,12 +48,15 @@ __global__ void q4_matmul_kernel
     bool no_zero
 )
 {
+    extern __shared__ half2 x_cache[];
+    half* x_cache_h = (half*)x_cache;
+
     // Start of block

     int x_column = block_size_z * blockIdx.z;
     int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));

-    int w_column = THREADS_X * blockIdx.x + threadIdx.x;
+    int w_column = THREADS_X * blockIdx.x + threadIdx.x; // assume width of weight matrix divisible by THREADS_X
     int x_row = THREADS_Y * blockIdx.y + threadIdx.y;

     int iterations = (x_column_end - x_column) / 8;
@@ -69,8 +74,8 @@ __global__ void q4_matmul_kernel
     if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
     {
         *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
-        __syncthreads();
     }
+    __syncthreads();

     // Loop over part of x row (and w column)

@@ -84,48 +89,56 @@ __global__ void q4_matmul_kernel

         for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
         {
+            for (int i = threadIdx.x; i < groupsize; i += THREADS_X)
+            {
+                if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]);
+                else x_cache_h[i] = *x_.item_ptr(x_row, k + i);
+            }
+            __syncthreads();

             if constexpr (use_half2)
             {
                 half2 w_scale = w_scales_.item_half2half2(group, w_column);
                 uint32_t w_zero = w_zeros_.item(group, w_column);
-                if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
-                else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
+                acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, groupsize / 8);
             }
             else
             {
                 half w_scale = w_scales_.item(group, w_column);
                 uint32_t w_zero = w_zeros_.item(group, w_column);
-                if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
-                else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
+                acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, groupsize / 8);
             }
+            __syncthreads();
         }
     }
     else
     {
-        // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
+        // Otherwise assume groupsize is a multiple of GROUP_STEP, do GROUP_STEP columns per iteration and trust the cache

-        for (int k = x_column; k < x_column + iterations * 8; k += 8)
+        for (int k = x_column; k < x_column + iterations * 8; k += GROUP_STEP)
         {
+            for (int i = threadIdx.x; i < GROUP_STEP; i += THREADS_X)
+            {
+                if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]);
+                else x_cache_h[i] = *x_.item_ptr(x_row, k + i);
+            }
+            __syncthreads();

             if constexpr (use_half2)
             {
                 int group = k / groupsize;
                 half2 w_scale = w_scales_.item_half2half2(group, w_column);
                 uint32_t w_zero = w_zeros_.item(group, w_column);
-                if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
-                else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
+                acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
             }
             else
             {
                 int group = k / groupsize;
                 half w_scale = w_scales_.item(group, w_column);
                 uint32_t w_zero = w_zeros_.item(group, w_column);
-                if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
-                else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
+                acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
             }
+            __syncthreads();
         }
     }

@@ -133,7 +146,7 @@ __global__ void q4_matmul_kernel

     if constexpr (use_half2)
     {
-        half result = __hadd(__low2half(acc), __high2half(acc));
+        half result = __hadd(acc.x, acc.y);
         atomicAdd(out_.item_ptr(x_row, w_column), result);
     }
     else
@@ -215,8 +228,8 @@ void q4_matmul_cuda
     );

     fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
-    kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
+    int shared_mem = (block_size_z % w->groupsize == 0 ? w->groupsize : GROUP_STEP) * sizeof(half);
+    kernel<<<blocks, threads, shared_mem, alt_stream>>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
 }

 void q4_matmul_recons_cuda
@@ -240,7 +253,7 @@ void q4_matmul_recons_cuda
     const half* x_mapped = x;
     if (w->cuda_x_map)
     {
-        TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "The temp_state buffer is too small in the exllama backend. Please call the exllama_set_max_input_length function to increase the buffer size. Example:\nfrom auto_gptq import exllama_set_max_input_length\nmodel = exllama_set_max_input_length(model, 4096)");
+        TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small");
         column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
         x_mapped = buffers->temp_state;
     }
@@ -248,13 +261,18 @@
     w->reconstruct(buffers->temp_dq);

 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
+
     const float alpha = 1.0f;
     const float beta = no_zero ? 1.0f : 0.0f;
     cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
         x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
+
 #else
+
     const half alpha = __float2half(1.0f);
     const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
     cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
+
 #endif
+
 }

@@ -87,9 +87,7 @@ public:
 __device__ __forceinline__ half2 dot_product_8
 (
     const half2 acc,
-    MatrixView_half& h_,
-    const int h_row,
-    const int h_column, // divisible by 8
+    const half2* h_ptr,
     MatrixView_q4_column& v_,
     const int v_row, // divisible by 8
     const int v_column,
@@ -98,7 +96,6 @@ __device__ __forceinline__ half2 dot_product_8
     const int count
 )
 {
-    const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
     const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
     half2 result = acc;

@@ -138,9 +135,7 @@
 __device__ __forceinline__ half dot_product_8_h
 (
     const half acc,
-    MatrixView_half& h_,
-    const int h_row,
-    const int h_column, // divisible by 8
+    const half* h_ptr,
     MatrixView_q4_column& v_,
     const int v_row, // divisible by 8
     const int v_column,
@@ -149,7 +144,6 @@ __device__ __forceinline__ half dot_product_8_h
     const int count
 )
 {
-    const half* h_ptr = h_.item_ptr(h_row, h_column);
     const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
     half result = acc;

@@ -180,115 +174,4 @@ __device__ __forceinline__ half dot_product_8_h
     return result;
 }

-// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
-
-__device__ __forceinline__ half2 dot_product_8_x_map
-(
-    const half2 acc,
-    MatrixView_half& h_,
-    const int h_row,
-    const int h_column, // divisible by 8
-    MatrixView_q4_column& v_,
-    const int v_row, // divisible by 8
-    const int v_column,
-    const half2 v_scale_2,
-    const uint32_t v_zero,
-    const int count,
-    const uint32_t* x_map
-)
-{
-    const half* h_ptr = h_.item_ptr(h_row, 0);
-    const uint32_t* x_map_ptr = x_map + h_column;
-    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
-    half2 result = acc;
-
-    for (int i = 0; i < count; i++)
-    {
-        uint32_t v_read = *v_ptr; v_ptr += v_.width;
-
-        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
-        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
-        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
-        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
-        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
-        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
-        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
-        half v_7 = __int2half_rn((int)((v_read >> 28)        ) - v_zero);
-
-        half2 v_01 = __halves2half2(v_0, v_1);
-        half2 v_23 = __halves2half2(v_2, v_3);
-        half2 v_45 = __halves2half2(v_4, v_5);
-        half2 v_67 = __halves2half2(v_6, v_7);
-
-        half h_0 = h_ptr[*x_map_ptr++];
-        half h_1 = h_ptr[*x_map_ptr++];
-        half h_2 = h_ptr[*x_map_ptr++];
-        half h_3 = h_ptr[*x_map_ptr++];
-        half h_4 = h_ptr[*x_map_ptr++];
-        half h_5 = h_ptr[*x_map_ptr++];
-        half h_6 = h_ptr[*x_map_ptr++];
-        half h_7 = h_ptr[*x_map_ptr++];
-
-        half2 h_01 = __halves2half2(h_0, h_1);
-        half2 h_23 = __halves2half2(h_2, h_3);
-        half2 h_45 = __halves2half2(h_4, h_5);
-        half2 h_67 = __halves2half2(h_6, h_7);
-
-        half2 tmp = __hmul2(h_01, v_01);
-        tmp = __hfma2(h_23, v_23, tmp);
-        tmp = __hfma2(h_45, v_45, tmp);
-        tmp = __hfma2(h_67, v_67, tmp);
-        result = __hfma2(v_scale_2, tmp, result);
-    }
-
-    return result;
-}
-
-__device__ __forceinline__ half dot_product_8_x_map_h
-(
-    const half acc,
-    MatrixView_half& h_,
-    const int h_row,
-    const int h_column, // divisible by 8
-    MatrixView_q4_column& v_,
-    const int v_row, // divisible by 8
-    const int v_column,
-    const half v_scale,
-    const uint32_t v_zero,
-    const int count,
-    const uint32_t* x_map
-)
-{
-    const half* h_ptr = h_.item_ptr(h_row, 0);
-    const uint32_t* x_map_ptr = x_map + h_column;
-    const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
-    half result = acc;
-
-    for (int i = 0; i < count; i++)
-    {
-        uint32_t v_read = *v_ptr; v_ptr += v_.width;
-
-        half v_0 = __int2half_rn((int)((v_read      ) & 0x0f) - v_zero);
-        half v_1 = __int2half_rn((int)((v_read >>  4) & 0x0f) - v_zero);
-        half v_2 = __int2half_rn((int)((v_read >>  8) & 0x0f) - v_zero);
-        half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
-        half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
-        half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
-        half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
-        half v_7 = __int2half_rn((int)((v_read >> 28)        ) - v_zero);
-
-        half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
-        tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
-        tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
-        tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
-        tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
-        tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
-        tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
-        tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
-        result = __hfma(v_scale, tmp, result);
-    }
-
-    return result;
-}
-
 #endif
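Note on the launch change above: the kernel now stages the active slice of x in dynamically sized shared memory (extern __shared__ half2 x_cache[]), so q4_matmul_cuda must pass the byte count as the third <<<...>>> launch argument. Below is a minimal, self-contained sketch of that pattern, separate from the commit itself; the kernel and variable names (cache_row_kernel, n, etc.) are illustrative only, not part of either changed file.

// Sketch: dynamic shared memory sized at launch time, as used by the new x_cache path.
#include <cuda_fp16.h>
#include <cstdio>

__global__ void cache_row_kernel(const half* __restrict__ x, half* __restrict__ out, int n)
{
    extern __shared__ half x_cache[];   // size supplied by the host at launch

    // Cooperatively stage one row slice into shared memory (same loop shape as the diff).
    for (int i = threadIdx.x; i < n; i += blockDim.x) x_cache[i] = x[i];
    __syncthreads();

    // Subsequent reads hit shared memory instead of global memory.
    for (int i = threadIdx.x; i < n; i += blockDim.x)
        out[i] = __float2half(__half2float(x_cache[i]) * 2.0f);
}

int main()
{
    const int n = 128;
    half *x, *out;
    cudaMallocManaged(&x, n * sizeof(half));
    cudaMallocManaged(&out, n * sizeof(half));
    for (int i = 0; i < n; ++i) x[i] = __float2half(1.0f);

    int shared_mem = n * sizeof(half);  // one half per cached x element, mirroring q4_matmul_cuda
    cache_row_kernel<<<1, 32, shared_mem>>>(x, out, n);
    cudaDeviceSynchronize();

    printf("out[0] = %f\n", __half2float(out[0]));
    cudaFree(x); cudaFree(out);
    return 0;
}

The shared_mem value in the diff follows the same rule: groupsize cached elements when block_size_z is a multiple of the group size, otherwise GROUP_STEP elements.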