From 3de7fbb0d53ccc4516910a7a4000d526c6289d2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BD=98=E5=85=B6=E5=A8=81=28William=29?= <46810637+PanQiWei@users.noreply.github.com> Date: Wed, 27 Sep 2023 10:37:31 +0800 Subject: [PATCH] Revert "fix bug(breaking change) remove (zeors -= 1)" --- auto_gptq/modeling/_base.py | 23 +- auto_gptq/nn_modules/fused_gptj_attn.py | 17 +- auto_gptq/nn_modules/fused_llama_attn.py | 20 +- auto_gptq/nn_modules/qlinear/qlinear_cuda.py | 3 + .../nn_modules/qlinear/qlinear_cuda_old.py | 3 + .../nn_modules/qlinear/qlinear_exllama.py | 1 + .../nn_modules/qlinear/qlinear_triton.py | 1 + auto_gptq/nn_modules/triton_utils/kernels.py | 2 + .../cuda_256/autogptq_cuda_kernel_256.cu | 250 +++++++++--------- .../cuda_64/autogptq_cuda_kernel_64.cu | 36 +-- .../exllama/cuda_func/q4_matmul.cu | 68 ++--- .../exllama/cuda_func/q4_matrix.cu | 4 +- autogptq_extension/exllama/matrix.cuh | 125 ++++++++- autogptq_extension/qigen/generate.py | 9 +- 14 files changed, 325 insertions(+), 237 deletions(-) diff --git a/auto_gptq/modeling/_base.py b/auto_gptq/modeling/_base.py index e86498e..84e92bf 100644 --- a/auto_gptq/modeling/_base.py +++ b/auto_gptq/modeling/_base.py @@ -39,7 +39,7 @@ class BaseQuantizeConfig(PushToHubMixin): damp_percent: float = field(default=0.01) desc_act: bool = field(default=True) static_groups: bool = field(default=False) - sym: bool = field(default=False) + sym: bool = field(default=True) true_sequential: bool = field(default=True) model_name_or_path: Optional[str] = field(default=None) model_file_base_name: Optional[str] = field(default=None) @@ -967,27 +967,6 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin): checkpoint ) model.load_state_dict(checkpoint) - # Preprocessing for backward compatibility - if quantize_config.sym: - QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, disable_exllama=disable_exllama, use_qigen=use_qigen, - desc_act=quantize_config.desc_act, group_size=quantize_config.group_size, bits=quantize_config.bits) - for name, submodule in model.named_modules(): - if isinstance(submodule, QuantLinear): - if use_qigen: - submodule.zeros.data = torch.full_like(submodule.zeros.data, (torch.tensor(2 ** quantize_config.bits - 1) + 1) / 2) - else: - if quantize_config.bits == 2: - submodule.qzeros.data = torch.full_like(submodule.qzeros.data, -1431655766) - elif quantize_config.bits == 3: - submodule.qzeros.data[:,range(0,submodule.qzeros.data.shape[1],3)] = 613566756 - submodule.qzeros.data[:,range(1,submodule.qzeros.data.shape[1],3)] = 1227133513 - submodule.qzeros.data[:,range(2,submodule.qzeros.data.shape[1],3)] = -1840700270 - elif quantize_config.bits == 4: - submodule.qzeros.data = torch.full_like(submodule.qzeros.data, -2004318072) - elif quantize_config.bits == 8: - submodule.qzeros.data = torch.full_like(submodule.qzeros.data, -2139062144) - else: - raise NotImplementedError("Only 2,3,4,8 bits are supported.") # == step4: set seqlen == # model_config = model.config.to_dict() seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"] diff --git a/auto_gptq/nn_modules/fused_gptj_attn.py b/auto_gptq/nn_modules/fused_gptj_attn.py index c4d05ca..236785e 100644 --- a/auto_gptq/nn_modules/fused_gptj_attn.py +++ b/auto_gptq/nn_modules/fused_gptj_attn.py @@ -8,8 +8,6 @@ from transformers.models.gptj.modeling_gptj import GPTJAttention from ._fused_base import FusedBaseAttentionModule from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear -from logging import getLogger -logger = 
getLogger(__name__) def fixed_pos_embedding(x, seq_dim=1, seq_len=None): dim = x.shape[-1] @@ -242,13 +240,8 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule): **kwargs ): config = model.config - QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2) - if QuantLinear.QUANT_TYPE in ["exllama", "exllamav2"] and desc_act: - # See fused_llama_attn.py comment - logger.warning(f"Exllama kernel does not support query/key/value fusion with act-order. Because of this, Fused attention is automatically disabled.") - return False - + for name, m in model.named_modules(): if not isinstance(m, GPTJAttention): continue @@ -264,7 +257,11 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule): scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) if QuantLinear.QUANT_TYPE == "exllama": - g_idx = None + if desc_act: + # See fused_llama_attn.py comment + raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.") + else: + g_idx = None else: g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0) @@ -301,6 +298,6 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule): setattr(parent, child_name, attn) del m - return True + __all__ = ["FusedGPTJAttentionForQuantizedModel"] diff --git a/auto_gptq/nn_modules/fused_llama_attn.py b/auto_gptq/nn_modules/fused_llama_attn.py index e81e97a..5936dda 100644 --- a/auto_gptq/nn_modules/fused_llama_attn.py +++ b/auto_gptq/nn_modules/fused_llama_attn.py @@ -7,8 +7,6 @@ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotar from ._fused_base import FusedBaseAttentionModule from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear -from logging import getLogger -logger = getLogger(__name__) class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -144,15 +142,8 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule): """ Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections. """ - QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2) - if QuantLinear.QUANT_TYPE in ["exllama", "exllamav2"] and desc_act: - # TODO: support it. The issue lies maybe in the line: - # int groups = qzeros.size(0); - # in exllama_ext.cpp - logger.warning(f"Exllama kernel does not support query/key/value fusion with act-order. Because of this, Fused attention is automatically disabled.") - return False - + for name, m in model.named_modules(): if not isinstance(m, LlamaAttention): continue @@ -166,7 +157,13 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule): scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) if QuantLinear.QUANT_TYPE == "exllama": - g_idx = None + if desc_act: + # TODO: support it. The issue lies maybe in the line: + # int groups = qzeros.size(0); + # in exllama_ext.cpp + raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. 
Please either use inject_fused_attention=False or disable_exllama=True.") + else: + g_idx = None else: g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0) @@ -201,7 +198,6 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule): child_name = name setattr(parent, child_name, attn) - return True __all__ = ["FusedLlamaAttentionForQuantizedModel"] diff --git a/auto_gptq/nn_modules/qlinear/qlinear_cuda.py b/auto_gptq/nn_modules/qlinear/qlinear_cuda.py index 8c737cb..6355c3c 100644 --- a/auto_gptq/nn_modules/qlinear/qlinear_cuda.py +++ b/auto_gptq/nn_modules/qlinear/qlinear_cuda.py @@ -157,6 +157,7 @@ class QuantLinear(nn.Module): qweight = qweight.astype(np.int32) self.qweight = torch.from_numpy(qweight) + zeros -= 1 zeros = zeros.numpy().astype(np.uint32) qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) i = 0 @@ -220,6 +221,7 @@ class QuantLinear(nn.Module): ).to(torch.int16 if self.bits == 8 else torch.int8) torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros) + zeros = zeros + 1 zeros = zeros.reshape(self.scales.shape) weight = torch.bitwise_right_shift( @@ -237,6 +239,7 @@ class QuantLinear(nn.Module): zeros = zeros & 0x7 zeros = torch.cat([zeros[:, :, 0, :11], zeros[:, :, 1, 1:12], zeros[:, :, 2, 1:11]], dim=2) + zeros = zeros + 1 zeros = zeros.reshape(self.scales.shape) weight = self.qweight.reshape( diff --git a/auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py b/auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py index b328b6d..4cd25f2 100644 --- a/auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py +++ b/auto_gptq/nn_modules/qlinear/qlinear_cuda_old.py @@ -157,6 +157,7 @@ class QuantLinear(nn.Module): qweight = qweight.astype(np.int32) self.qweight = torch.from_numpy(qweight) + zeros -= 1 zeros = zeros.numpy().astype(np.uint32) qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) i = 0 @@ -230,6 +231,7 @@ class QuantLinear(nn.Module): zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), self.wf.unsqueeze(0)).to(torch.int16 if self.bits == 8 else torch.int8) torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros) + zeros = zeros + 1 zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2]) scales = self.scales @@ -246,6 +248,7 @@ class QuantLinear(nn.Module): zeros = zeros & 0x7 zeros = torch.cat([zeros[:,:,0,:11], zeros[:,:,1,1:12], zeros[:,:,2,1:11]], dim=2) + zeros = zeros + 1 zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2]) scales = self.scales diff --git a/auto_gptq/nn_modules/qlinear/qlinear_exllama.py b/auto_gptq/nn_modules/qlinear/qlinear_exllama.py index b28fd28..87268a9 100644 --- a/auto_gptq/nn_modules/qlinear/qlinear_exllama.py +++ b/auto_gptq/nn_modules/qlinear/qlinear_exllama.py @@ -146,6 +146,7 @@ class QuantLinear(nn.Module): qweight = qweight.astype(np.int32) self.qweight = torch.from_numpy(qweight) + zeros -= 1 zeros = zeros.numpy().astype(np.uint32) qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) i = 0 diff --git a/auto_gptq/nn_modules/qlinear/qlinear_triton.py b/auto_gptq/nn_modules/qlinear/qlinear_triton.py index 79bc8e8..4ebc7a3 100644 --- a/auto_gptq/nn_modules/qlinear/qlinear_triton.py +++ b/auto_gptq/nn_modules/qlinear/qlinear_triton.py @@ -114,6 +114,7 @@ class QuantLinear(nn.Module, TritonModuleMixin): qweight = qweight.astype(np.int32) self.qweight = torch.from_numpy(qweight) + zeros -= 1 zeros = zeros.numpy().astype(np.uint32) qzeros = 
np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) i = 0 diff --git a/auto_gptq/nn_modules/triton_utils/kernels.py b/auto_gptq/nn_modules/triton_utils/kernels.py index 4483585..b8c777c 100644 --- a/auto_gptq/nn_modules/triton_utils/kernels.py +++ b/auto_gptq/nn_modules/triton_utils/kernels.py @@ -144,6 +144,7 @@ def quant_matmul_248_kernel( zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated @@ -289,6 +290,7 @@ def transpose_quant_matmul_248_kernel( zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N) b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated diff --git a/autogptq_extension/cuda_256/autogptq_cuda_kernel_256.cu b/autogptq_extension/cuda_256/autogptq_cuda_kernel_256.cu index b356dc4..21c06d3 100644 --- a/autogptq_extension/cuda_256/autogptq_cuda_kernel_256.cu +++ b/autogptq_extension/cuda_256/autogptq_cuda_kernel_256.cu @@ -30,9 +30,9 @@ // } // #endif - #if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(USE_ROCM) // adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh + __device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) { unsigned int *address_as_ui = reinterpret_cast(reinterpret_cast(address) - (reinterpret_cast(address) & 2)); unsigned int old = *address_as_ui; @@ -77,7 +77,7 @@ __global__ void VecQuant2MatMulKernel( const int* __restrict__ zeros, const int* __restrict__ g_idx, int batch, - int vec_height, + int vec_height, int height, int width, int zero_width @@ -92,7 +92,7 @@ __global__ void VecQuant3MatMulKernel( const int* __restrict__ zeros, const int* __restrict__ g_idx, int batch, - int vec_height, + int vec_height, int height, int width, int zero_width @@ -113,7 +113,6 @@ __global__ void VecQuant4MatMulKernel( int zero_width ); - template __global__ void VecQuant8MatMulKernel( const scalar_t* __restrict__ vec, @@ -123,7 +122,7 @@ __global__ void VecQuant8MatMulKernel( const int* __restrict__ zeros, const int* __restrict__ g_idx, int batch, - int vec_height, + int vec_height, int height, int width, int zero_width @@ -137,7 +136,7 @@ __global__ void VecQuant2MatMulKernel_old( const scalar_t* __restrict__ scales, const int* __restrict__ zeros, int batch, - int vec_height, + int vec_height, int height, int width, int zero_width, @@ -152,7 +151,7 @@ __global__ void VecQuant3MatMulKernel_old( const scalar_t* __restrict__ scales, const int* __restrict__ zeros, int batch, - int vec_height, + int vec_height, int height, int width, int zero_width, @@ -167,7 +166,7 @@ __global__ void VecQuant4MatMulKernel_old( const scalar_t* __restrict__ scales, const int* __restrict__ zeros, int batch, - int vec_height, + int vec_height, int height, int width, int zero_width, @@ -182,7 +181,7 @@ __global__ void VecQuant8MatMulKernel_old( const scalar_t* __restrict__ scales, const int* __restrict__ zeros, int batch, - int vec_height, + int vec_height, int height, int width, int zero_width, @@ -210,7 +209,7 @@ __global__ void VecQuant3MatMulKernelFaster_old( const float* __restrict__ scales, const int* __restrict__ zeros, int batch, - int vec_height, + int vec_height, int height, int width, int 
zero_width, @@ -224,7 +223,7 @@ __global__ void VecQuant4MatMulKernelFaster_old( const float* __restrict__ scales, const int* __restrict__ zeros, int batch, - int vec_height, + int vec_height, int height, int width, int zero_width, @@ -271,7 +270,7 @@ void vecquant2matmul_cuda( vec.type(), "vecquant2matmul_cuda", ([&] { VecQuant2MatMulKernel<<>>( vec.data(), mat.data(), mul.data(), - scales.data(), zeros.data(), g_idx.data(), + scales.data(), zeros.data(), g_idx.data(), batch, vec_height, height, width, zero_width ); }) @@ -294,39 +293,39 @@ __global__ void VecQuant2MatMulKernel( ) { int h = BLOCKHEIGHT2 * blockIdx.x; int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; - + __shared__ scalar_t blockvec[BLOCKWIDTH]; int i = width * h + w; int g_h = h * 16; int k; unsigned int g; scalar_t w_tmp; - - int z_w = w / 16; + + int z_w = w / 16; int z_mod = (w % 16) * 2; - + float weight[BLOCKWIDTH]; - - for (k = 0; k < BLOCKWIDTH; ++k){ - int k_w = (k / 16); + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 16); int k_bit = (k % 16) * 2; - + g = as_int(g_idx[g_h + k]); scalar_t scale = scales[g * width + w]; - scalar_t zero = scalar_t(as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3); - + scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1); + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3); - + weight[k] = scale * (w_tmp - zero); } scalar_t res; - for (int b = 0; b < batch; ++b){ + for (int b = 0; b < batch; ++b){ res = 0; - + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; __syncthreads(); - for (k = 0; k < BLOCKWIDTH; ++k){ + for (k = 0; k < BLOCKWIDTH; ++k){ res += weight[k] * blockvec[k]; } atomicAdd(&mul[b * width + w], res); @@ -358,7 +357,7 @@ void vecquant3matmul_cuda( vec.type(), "vecquant3matmul_cuda", ([&] { VecQuant3MatMulKernel<<>>( vec.data(), mat.data(), mul.data(), - scales.data(), zeros.data(), g_idx.data(), + scales.data(), zeros.data(), g_idx.data(), batch, vec_height, height, width, zero_width ); }) @@ -381,15 +380,15 @@ __global__ void VecQuant3MatMulKernel( ) { int h = BLOCKHEIGHT3 * blockIdx.x; int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; - + __shared__ scalar_t blockvec[BLOCKWIDTH]; int i = width * h + w; int g_h = (h / 3) * 32; int k; unsigned int g; scalar_t w_tmp; - - int z_w = (w / 32) * 3; + + int z_w = (w / 32) * 3; int z_mod = w % 32; int z_bit; unsigned int z_tmp; @@ -413,14 +412,14 @@ __global__ void VecQuant3MatMulKernel( z_w += 1; } } - + float weight[BLOCKWIDTH]; - - for (k = 0; k < BLOCKWIDTH; ++k){ - int k_w = (k / 32) * 3; + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 32) * 3; int k_mod = k % 32; int k_bit; - + if (k_mod != 10){ if (k_mod != 21){ k_bit = k_mod; @@ -441,20 +440,20 @@ __global__ void VecQuant3MatMulKernel( k_w += 1; } } - + g = as_int(g_idx[g_h + k]); scalar_t scale = scales[g * width + w]; scalar_t zero; if (z_mod == 10) { z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4); - zero = scalar_t(z_tmp); + zero = scalar_t((z_tmp) + 1); } else if (z_mod == 21){ z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6); - zero = scalar_t(z_tmp); + zero = scalar_t((z_tmp) + 1); } else { - zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7); + zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1); } - + if (k_mod == 10) { w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 
30) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 2) & 0x4); } else if (k_mod == 21){ @@ -466,12 +465,12 @@ __global__ void VecQuant3MatMulKernel( } scalar_t res; - for (int b = 0; b < batch; ++b){ + for (int b = 0; b < batch; ++b){ res = 0; - + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; __syncthreads(); - for (k = 0; k < BLOCKWIDTH; ++k){ + for (k = 0; k < BLOCKWIDTH; ++k){ res += weight[k] * blockvec[k]; } atomicAdd(&mul[b * width + w], res); @@ -503,7 +502,7 @@ void vecquant4matmul_cuda( vec.type(), "vecquant4matmul_cuda", ([&] { VecQuant4MatMulKernel<<>>( vec.data(), mat.data(), mul.data(), - scales.data(), zeros.data(), g_idx.data(), + scales.data(), zeros.data(), g_idx.data(), batch, vec_height, height, width, zero_width ); }) @@ -526,40 +525,40 @@ __global__ void VecQuant4MatMulKernel( ) { int h = BLOCKHEIGHT4 * blockIdx.x; int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; - + __shared__ scalar_t blockvec[BLOCKWIDTH]; int i = width * h + w; int g_h = h * 8; int k; unsigned int g; scalar_t w_tmp; - - int z_w = w / 8; + + int z_w = w / 8; int z_mod = (w % 8) * 4; - + float weight[BLOCKWIDTH]; - - for (k = 0; k < BLOCKWIDTH; ++k){ - int k_w = (k / 8); + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 8); int k_bit = (k % 8) * 4; - + g = as_int(g_idx[g_h + k]); scalar_t scale = scales[g * width + w]; - scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF); - + scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1); + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF); - + weight[k] = scale * (w_tmp - zero); } scalar_t res; - for (int b = 0; b < batch; ++b){ + for (int b = 0; b < batch; ++b){ res = 0; - + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; __syncthreads(); - for (k = 0; k < BLOCKWIDTH; ++k){ + for (k = 0; k < BLOCKWIDTH; ++k){ res += weight[k] * blockvec[k]; } atomicAdd(&mul[b * width + w], res); @@ -591,7 +590,7 @@ void vecquant8matmul_cuda( vec.type(), "vecquant8matmul_cuda", ([&] { VecQuant8MatMulKernel<<>>( vec.data(), mat.data(), mul.data(), - scales.data(), zeros.data(), g_idx.data(), + scales.data(), zeros.data(), g_idx.data(), batch, vec_height, height, width, zero_width ); }) @@ -614,39 +613,39 @@ __global__ void VecQuant8MatMulKernel( ) { int h = BLOCKHEIGHT8 * blockIdx.x; int w = BLOCKWIDTH * blockIdx.y + threadIdx.x; - + __shared__ scalar_t blockvec[BLOCKWIDTH]; int i = width * h + w; int g_h = h * 4; int k; unsigned int g; scalar_t w_tmp; - - int z_w = w / 4; + + int z_w = w / 4; int z_mod = (w % 4) * 8; - + float weight[BLOCKWIDTH]; - - for (k = 0; k < BLOCKWIDTH; ++k){ - int k_w = (k / 4); + + for (k = 0; k < BLOCKWIDTH; ++k){ + int k_w = (k / 4); int k_bit = (k % 4) * 8; - + g = as_int(g_idx[g_h + k]); scalar_t scale = scales[g * width + w]; - scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF); - + scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1); + w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF); - + weight[k] = scale * (w_tmp - zero); } scalar_t res; - for (int b = 0; b < batch; ++b){ + for (int b = 0; b < batch; ++b){ res = 0; - + blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x]; __syncthreads(); - for (k = 0; k < BLOCKWIDTH; ++k){ + for (k = 0; k < BLOCKWIDTH; ++k){ res += weight[k] * blockvec[k]; } atomicAdd(&mul[b * width + w], res); @@ -713,19 +712,19 @@ __global__ void 
VecQuant2MatMulKernel_old( int i = width * h + w; int g_h = h * 16; int k = 0; - - int z_w = w / 16; + + int z_w = w / 16; int z_mod = (w % 16) * 2; unsigned int tmp; while (k < BLOCKWIDTH) { tmp = as_unsigned(mat[i]); - + int g = (g_h + k) / groupsize; scalar_t scale = scales[g * width + w]; - scalar_t zero = scale * scalar_t(as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3); - + scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1); + res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0]; res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1]; res += (scale * scalar_t((tmp >> 4) & 0x3) - zero) * blockvec[k + 2]; @@ -742,7 +741,7 @@ __global__ void VecQuant2MatMulKernel_old( res += (scale * scalar_t((tmp >> 26) & 0x3) - zero) * blockvec[k + 13]; res += (scale * scalar_t((tmp >> 28) & 0x3) - zero) * blockvec[k + 14]; res += (scale * scalar_t((tmp >> 30) & 0x3) - zero) * blockvec[k + 15]; - + i += width; k += 16; } @@ -808,11 +807,11 @@ __global__ void VecQuant3MatMulKernel_old( int i = width * h + w; int g_h = (h / 3) * 32; int k = 0; - - int z_w = (w / 32) * 3; + + int z_w = (w / 32) * 3; int z_mod = w % 32; int z_bit; - + if (z_mod != 10){ if (z_mod != 21){ z_bit = z_mod; @@ -833,7 +832,7 @@ __global__ void VecQuant3MatMulKernel_old( z_w += 1; } } - + unsigned int tmp1; unsigned int tmp2; unsigned int tmp; @@ -841,20 +840,20 @@ __global__ void VecQuant3MatMulKernel_old( while (k < BLOCKWIDTH) { tmp1 = as_unsigned(mat[i]); - + int g = (g_h + k) / groupsize; scalar_t scale = scales[g * width + w]; scalar_t zero; if (z_mod == 10) { z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4); - zero = scale * scalar_t(z_tmp); + zero = scale * scalar_t((z_tmp) + 1); } else if (z_mod == 21){ z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6); - zero = scale * scalar_t(z_tmp); + zero = scale * scalar_t((z_tmp) + 1); } else { - zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7); + zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1); } - + res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0]; res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1]; res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2]; @@ -865,14 +864,14 @@ __global__ void VecQuant3MatMulKernel_old( res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7]; res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8]; res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9]; - + i += width; tmp2 = as_unsigned(mat[i]); tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4); tmp2 >>= 1; res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10]; k += 11; - + res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0]; res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1]; res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2]; @@ -883,14 +882,14 @@ __global__ void VecQuant3MatMulKernel_old( res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7]; res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8]; res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9]; - + i += width; tmp1 = as_unsigned(mat[i]); tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6); tmp1 >>= 2; res += (scale * scalar_t(tmp) - zero) * blockvec[k + 
10]; k += 11; - + res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0]; res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1]; res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2]; @@ -901,7 +900,7 @@ __global__ void VecQuant3MatMulKernel_old( res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7]; res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8]; res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9]; - + i += width; k += 10; } @@ -968,18 +967,18 @@ __global__ void VecQuant4MatMulKernel_old( int g_h = h * 8; int k = 0; - int z_w = w / 8; + int z_w = w / 8; int z_mod = (w % 8) * 4; unsigned int tmp; while (k < BLOCKWIDTH) { tmp = as_unsigned(mat[i]); - + int g = (g_h + k) / groupsize; scalar_t scale = scales[g * width + w]; - scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF); - + scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1); + res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0]; res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1]; res += (scale * scalar_t((tmp >> 8) & 0xF) - zero) * blockvec[k + 2]; @@ -988,7 +987,7 @@ __global__ void VecQuant4MatMulKernel_old( res += (scale * scalar_t((tmp >> 20) & 0xF) - zero) * blockvec[k + 5]; res += (scale * scalar_t((tmp >> 24) & 0xF) - zero) * blockvec[k + 6]; res += (scale * scalar_t((tmp >> 28) & 0xF) - zero) * blockvec[k + 7]; - + i += width; k += 8; } @@ -1054,24 +1053,24 @@ __global__ void VecQuant8MatMulKernel_old( int i = width * h + w; int g_h = h * 4; int k = 0; - - int z_w = w / 4; + + int z_w = w / 4; int z_mod = (w % 4) * 8; unsigned int tmp; - while (k < BLOCKWIDTH) { + while (k < BLOCKWIDTH) { tmp = as_unsigned(mat[i]); - + int g = (g_h + k) / groupsize; scalar_t scale = scales[g * width + w]; - scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF); - + scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1); + res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0]; res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1]; res += (scale * scalar_t((tmp >> 16) & 0xFF) - zero) * blockvec[k + 2]; res += (scale * scalar_t((tmp >> 24) & 0xFF) - zero) * blockvec[k + 3]; - + i += width; k += 4; } @@ -1093,7 +1092,7 @@ void vecquant2matmul_faster_cuda_old( int height = mat.size(0); int width = mat.size(1); int zero_width = zeros.size(1); - + dim3 blocks( (height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2, (width + BLOCKWIDTH - 1) / BLOCKWIDTH, @@ -1145,8 +1144,8 @@ __global__ void VecQuant2MatMulKernelFaster_old( int i = width * h + w; int g_h = h * 16; int k = 0; - - int z_w = w / 16; + + int z_w = w / 16; int z_mod = (w % 16) * 2; float res = 0; @@ -1160,8 +1159,8 @@ __global__ void VecQuant2MatMulKernelFaster_old( int g = (g_h + (k * 2)) / groupsize; float scale_f = scales[g * width + w]; half2 scale = __float2half2_rn(scale_f); - half2 zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3))); - + half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1))); + std::memset(&res2, 0, sizeof(half2)); tmp = as_unsigned(mat[i]); res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xf][off], scale, zero), blockvec[k + 0], res2); @@ -1193,7 +1192,7 @@ void vecquant3matmul_faster_cuda_old( int height = mat.size(0); int width = mat.size(1); 
int zero_width = zeros.size(1); - + dim3 blocks( (height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3, (width + BLOCKWIDTH - 1) / BLOCKWIDTH, @@ -1245,11 +1244,11 @@ __global__ void VecQuant3MatMulKernelFaster_old( int i = width * h + w; int g_h = (h / 3) * 32; int k = 0; - + int z_w = (w / 32) * 3; int z_mod = w % 32; int z_bit; - + if (z_mod != 10){ if (z_mod != 21){ z_bit = z_mod; @@ -1288,14 +1287,14 @@ __global__ void VecQuant3MatMulKernelFaster_old( half2 zero; if (z_mod == 10) { z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4); - zero = __float2half2_rn(-(scale_f * z_tmp)); + zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1))); } else if (z_mod == 21){ z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6); - zero = __float2half2_rn(-(scale_f * z_tmp)); + zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1))); } else { - zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7))); + zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1))); } - + std::memset(&res2, 0, sizeof(half2)); tmp1 = as_unsigned(mat[i]); res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2); @@ -1345,7 +1344,7 @@ void vecquant4matmul_faster_cuda_old( int height = mat.size(0); int width = mat.size(1); int zero_width = zeros.size(1); - + dim3 blocks( (height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, (width + BLOCKWIDTH - 1) / BLOCKWIDTH, @@ -1398,7 +1397,7 @@ __global__ void VecQuant4MatMulKernelFaster_old( int g_h = h * 8; int k = 0; - int z_w = w / 8; + int z_w = w / 8; int z_mod = (w % 8) * 4; float res = 0; @@ -1411,9 +1410,14 @@ __global__ void VecQuant4MatMulKernelFaster_old( while (k < blockwidth2) { int g = (g_h + (k * 2)) / groupsize; float scale_f = scales[g * width + w]; + half2 scale = __float2half2_rn(scale_f); - half2 zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF))); - + half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1))); + + //std::memset(&res2, 0, sizeof(half2)); + + //res2 = __float2half2_rn((float)0.); + std::memset(&res2, 0, sizeof(half2)); tmp = as_unsigned(mat[i]); res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scale, zero), blockvec[k + 0], res2); @@ -1422,8 +1426,10 @@ __global__ void VecQuant4MatMulKernelFaster_old( res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scale, zero), blockvec[k + 3], res2); i += width; k += 4; + res += __low2float(res2) + __high2float(res2); + } atomicAdd(&mul[b * width + w], res); -} \ No newline at end of file +} diff --git a/autogptq_extension/cuda_64/autogptq_cuda_kernel_64.cu b/autogptq_extension/cuda_64/autogptq_cuda_kernel_64.cu index ba232bf..d0ddc7c 100644 --- a/autogptq_extension/cuda_64/autogptq_cuda_kernel_64.cu +++ b/autogptq_extension/cuda_64/autogptq_cuda_kernel_64.cu @@ -313,7 +313,7 @@ __global__ void VecQuant2MatMulKernel( g = as_int(g_idx[g_h + k]); scalar_t scale = scales[g * width + w]; - scalar_t zero = scalar_t(as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3); + scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1); w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3); @@ -447,12 +447,12 @@ __global__ void VecQuant3MatMulKernel( scalar_t zero; if (z_mod == 10) { z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | 
((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4); - zero = scalar_t(z_tmp); + zero = scalar_t((z_tmp) + 1); } else if (z_mod == 21){ z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6); - zero = scalar_t(z_tmp); + zero = scalar_t((z_tmp) + 1); } else { - zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7); + zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1); } if (k_mod == 10) { @@ -546,7 +546,7 @@ __global__ void VecQuant4MatMulKernel( g = as_int(g_idx[g_h + k]); scalar_t scale = scales[g * width + w]; - scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF); + scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1); w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF); @@ -633,7 +633,7 @@ __global__ void VecQuant8MatMulKernel( g = as_int(g_idx[g_h + k]); scalar_t scale = scales[g * width + w]; - scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF); + scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1); w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF); @@ -724,7 +724,7 @@ __global__ void VecQuant2MatMulKernel_old( int g = (g_h + k) / groupsize; scalar_t scale = scales[g * width + w]; - scalar_t zero = scale * scalar_t(as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3); + scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1); res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0]; res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1]; @@ -847,12 +847,12 @@ __global__ void VecQuant3MatMulKernel_old( scalar_t zero; if (z_mod == 10) { z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4); - zero = scale * scalar_t(z_tmp); + zero = scale * scalar_t((z_tmp) + 1); } else if (z_mod == 21){ z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6); - zero = scale * scalar_t(z_tmp); + zero = scale * scalar_t((z_tmp) + 1); } else { - zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7); + zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1); } res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0]; @@ -978,7 +978,7 @@ __global__ void VecQuant4MatMulKernel_old( int g = (g_h + k) / groupsize; scalar_t scale = scales[g * width + w]; - scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF); + scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1); res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0]; res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1]; @@ -1065,7 +1065,7 @@ __global__ void VecQuant8MatMulKernel_old( int g = (g_h + k) / groupsize; scalar_t scale = scales[g * width + w]; - scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF); + scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1); res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0]; res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1]; @@ -1160,7 +1160,7 @@ __global__ void VecQuant2MatMulKernelFaster_old( int g = (g_h + (k * 2)) / 
groupsize; float scale_f = scales[g * width + w]; half2 scale = __float2half2_rn(scale_f); - half2 zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3))); + half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1))); std::memset(&res2, 0, sizeof(half2)); tmp = as_unsigned(mat[i]); @@ -1288,12 +1288,12 @@ __global__ void VecQuant3MatMulKernelFaster_old( half2 zero; if (z_mod == 10) { z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4); - zero = __float2half2_rn(-(scale_f * z_tmp)); + zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1))); } else if (z_mod == 21){ z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6); - zero = __float2half2_rn(-(scale_f * z_tmp)); + zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1))); } else { - zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7))); + zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1))); } std::memset(&res2, 0, sizeof(half2)); @@ -1412,7 +1412,7 @@ __global__ void VecQuant4MatMulKernelFaster_old( int g = (g_h + (k * 2)) / groupsize; float scale_f = scales[g * width + w]; half2 scale = __float2half2_rn(scale_f); - half2 zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF))); + half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1))); std::memset(&res2, 0, sizeof(half2)); tmp = as_unsigned(mat[i]); @@ -1426,4 +1426,4 @@ __global__ void VecQuant4MatMulKernelFaster_old( } atomicAdd(&mul[b * width + w], res); -} \ No newline at end of file +} diff --git a/autogptq_extension/exllama/cuda_func/q4_matmul.cu b/autogptq_extension/exllama/cuda_func/q4_matmul.cu index 7e4d6af..0ee6e16 100644 --- a/autogptq_extension/exllama/cuda_func/q4_matmul.cu +++ b/autogptq_extension/exllama/cuda_func/q4_matmul.cu @@ -1,4 +1,4 @@ -// Adapted from turboderp exllama: https://github.com/turboderp/exllama +// Adapted from turboderp exllama: https://github.com/turboderp/exllama #include "q4_matmul.cuh" #include "column_remap.cuh" @@ -13,8 +13,6 @@ const int THREADS_X = 32; // Block size and thread count along columns in w and out const int THREADS_Y = 1; // Block size and thread count along rows in x and out -const int GROUP_STEP = 32; // Assumed group size when block_size_z % groupsize != 0 - typedef void (*fp_q4_matmul_kernel) ( const half*, @@ -48,15 +46,12 @@ __global__ void q4_matmul_kernel bool no_zero ) { - extern __shared__ half2 x_cache[]; - half* x_cache_h = (half*)x_cache; - // Start of block int x_column = block_size_z * blockIdx.z; int x_column_end = min(dim, block_size_z * (blockIdx.z + 1)); - int w_column = THREADS_X * blockIdx.x + threadIdx.x; // assume width of weight matrix divisible by THREADS_X + int w_column = THREADS_X * blockIdx.x + threadIdx.x; int x_row = THREADS_Y * blockIdx.y + threadIdx.y; int iterations = (x_column_end - x_column) / 8; @@ -74,8 +69,8 @@ __global__ void q4_matmul_kernel if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0) { *((uint32_t*) out_.item_ptr(x_row, w_column)) = 0; + __syncthreads(); } - __syncthreads(); // Loop over part of x row (and w column) @@ -89,56 +84,48 @@ __global__ void q4_matmul_kernel for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += 
groupsize) { - for (int i = threadIdx.x; i < groupsize; i += THREADS_X) - { - if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]); - else x_cache_h[i] = *x_.item_ptr(x_row, k + i); - } - __syncthreads(); - if constexpr (use_half2) { half2 w_scale = w_scales_.item_half2half2(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column); - acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, groupsize / 8); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); } else { half w_scale = w_scales_.item(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column); - acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, groupsize / 8); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8); } - __syncthreads(); } } else { - // Otherwise assume groupsize is a multiple of GROUP_STEP, do GROUP_STEP columns per iteration and trust the cache + // Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache - for (int k = x_column; k < x_column + iterations * 8; k += GROUP_STEP) + for (int k = x_column; k < x_column + iterations * 8; k += 8) { - for (int i = threadIdx.x; i < GROUP_STEP; i += THREADS_X) - { - if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]); - else x_cache_h[i] = *x_.item_ptr(x_row, k + i); - } - __syncthreads(); - if constexpr (use_half2) { int group = k / groupsize; half2 w_scale = w_scales_.item_half2half2(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column); - acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); } else { int group = k / groupsize; half w_scale = w_scales_.item(group, w_column); - uint32_t w_zero = w_zeros_.item(group, w_column); - acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8); + uint32_t w_zero = w_zeros_.item(group, w_column) + 1; + + if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map); + else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1); } - __syncthreads(); } } @@ -146,7 +133,7 @@ __global__ void q4_matmul_kernel if constexpr (use_half2) { - half result = __hadd(acc.x, acc.y); + half result = __hadd(__low2half(acc), __high2half(acc)); atomicAdd(out_.item_ptr(x_row, w_column), result); } else @@ -228,8 +215,8 @@ void q4_matmul_cuda ); fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map); - int shared_mem = (block_size_z % w->groupsize == 0 ? 
w->groupsize : GROUP_STEP) * sizeof(half); - kernel<<>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); + + kernel<<>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero); } void q4_matmul_recons_cuda @@ -253,7 +240,7 @@ void q4_matmul_recons_cuda const half* x_mapped = x; if (w->cuda_x_map) { - TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small"); + TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "The temp_state buffer is too small in the exllama backend. Please call the exllama_set_max_input_length function to increase the buffer size. Example:\nfrom auto_gptq import exllama_set_max_input_length\nmodel = exllama_set_max_input_length(model, 4096)"); column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map); x_mapped = buffers->temp_state; } @@ -261,18 +248,13 @@ void q4_matmul_recons_cuda w->reconstruct(buffers->temp_dq); #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700 - const float alpha = 1.0f; const float beta = no_zero ? 1.0f : 0.0f; cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width, x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width); - #else - const half alpha = __float2half(1.0f); const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f); cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width); - #endif - } diff --git a/autogptq_extension/exllama/cuda_func/q4_matrix.cu b/autogptq_extension/exllama/cuda_func/q4_matrix.cu index a3774f2..2b3600e 100644 --- a/autogptq_extension/exllama/cuda_func/q4_matrix.cu +++ b/autogptq_extension/exllama/cuda_func/q4_matrix.cu @@ -197,7 +197,7 @@ __global__ void reconstruct_kernel int group = row / groupsize; half w_scale = w_scales_.item(group, column); - uint32_t w_zero = w_zeros_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; uint32_t w_read = w_.item_uint32_t(row, column); half* out_ptr = out_.item_ptr(row, column); @@ -222,4 +222,4 @@ void Q4Matrix::reconstruct(half* out) ); reconstruct_kernel<<>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize); -} +} \ No newline at end of file diff --git a/autogptq_extension/exllama/matrix.cuh b/autogptq_extension/exllama/matrix.cuh index e5efd76..2fd5ab0 100644 --- a/autogptq_extension/exllama/matrix.cuh +++ b/autogptq_extension/exllama/matrix.cuh @@ -87,15 +87,18 @@ public: __device__ __forceinline__ half2 dot_product_8 ( const half2 acc, - const half2* h_ptr, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 MatrixView_q4_column& v_, const int v_row, // divisible by 8 const int v_column, const half2 v_scale_2, - const uint32_t v_zero, + const uint32_t v_zero, // + 1 (!!) const int count ) { + const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column); const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); half2 result = acc; @@ -135,15 +138,18 @@ __device__ __forceinline__ half2 dot_product_8 __device__ __forceinline__ half dot_product_8_h ( const half acc, - const half* h_ptr, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 MatrixView_q4_column& v_, const int v_row, // divisible by 8 const int v_column, const half v_scale, - const uint32_t v_zero, + const uint32_t v_zero, // + 1 (!!) 
const int count ) { + const half* h_ptr = h_.item_ptr(h_row, h_column); const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); half result = acc; @@ -174,4 +180,115 @@ __device__ __forceinline__ half dot_product_8_h return result; } +// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map + +__device__ __forceinline__ half2 dot_product_8_x_map +( + const half2 acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half2 v_scale_2, + const uint32_t v_zero, // + 1 (!!) + const int count, + const uint32_t* x_map +) +{ + const half* h_ptr = h_.item_ptr(h_row, 0); + const uint32_t* x_map_ptr = x_map + h_column; + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half2 result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half2 v_01 = __halves2half2(v_0, v_1); + half2 v_23 = __halves2half2(v_2, v_3); + half2 v_45 = __halves2half2(v_4, v_5); + half2 v_67 = __halves2half2(v_6, v_7); + + half h_0 = h_ptr[*x_map_ptr++]; + half h_1 = h_ptr[*x_map_ptr++]; + half h_2 = h_ptr[*x_map_ptr++]; + half h_3 = h_ptr[*x_map_ptr++]; + half h_4 = h_ptr[*x_map_ptr++]; + half h_5 = h_ptr[*x_map_ptr++]; + half h_6 = h_ptr[*x_map_ptr++]; + half h_7 = h_ptr[*x_map_ptr++]; + + half2 h_01 = __halves2half2(h_0, h_1); + half2 h_23 = __halves2half2(h_2, h_3); + half2 h_45 = __halves2half2(h_4, h_5); + half2 h_67 = __halves2half2(h_6, h_7); + + half2 tmp = __hmul2(h_01, v_01); + tmp = __hfma2(h_23, v_23, tmp); + tmp = __hfma2(h_45, v_45, tmp); + tmp = __hfma2(h_67, v_67, tmp); + result = __hfma2(v_scale_2, tmp, result); + } + + return result; +} + +__device__ __forceinline__ half dot_product_8_x_map_h +( + const half acc, + MatrixView_half& h_, + const int h_row, + const int h_column, // divisible by 8 + MatrixView_q4_column& v_, + const int v_row, // divisible by 8 + const int v_column, + const half v_scale, + const uint32_t v_zero, // + 1 (!!) 
+ const int count, + const uint32_t* x_map +) +{ + const half* h_ptr = h_.item_ptr(h_row, 0); + const uint32_t* x_map_ptr = x_map + h_column; + const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column); + half result = acc; + + for (int i = 0; i < count; i++) + { + uint32_t v_read = *v_ptr; v_ptr += v_.width; + + half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero); + half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero); + half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero); + half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero); + half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero); + half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero); + half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero); + half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero); + + half tmp = __hmul(h_ptr[*x_map_ptr++], v_0); + tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp); + tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp); + result = __hfma(v_scale, tmp, result); + } + + return result; +} + #endif diff --git a/autogptq_extension/qigen/generate.py b/autogptq_extension/qigen/generate.py index ac7a8cd..724c891 100644 --- a/autogptq_extension/qigen/generate.py +++ b/autogptq_extension/qigen/generate.py @@ -1162,7 +1162,7 @@ def unpack_zeros(bits): res += f"void unpack_zeros{bits}_cpu(const int* zv, float* ov, int n, int m)" packed = 32//bits mask = (2**bits)-1 - res += "{\n" + res += "{\nconst __m256i ones = _mm256_set1_epi32(1);\n" res += f"const __m256i mask = _mm256_set1_epi32({mask});\n" if bits == 4: res += "const __m256i shift = _mm256_set_epi32(28,24,20,16,12,8,4,0);\n" @@ -1179,14 +1179,15 @@ def unpack_zeros(bits): res += "__m256i z = _mm256_set1_epi32(zv[i*m/8 + j/8]);\n" res += "__m256i z0 = _mm256_srlv_epi32(z, shift);\n" res += "__m256i z1 = _mm256_and_si256(z0, mask);\n" - res += "__m256 z2 = _mm256_cvtepi32_ps(z1);\n" - res += "_mm256_storeu_ps(&ov[i*m +j], z2);\n" + res += "__m256i z2 = _mm256_add_epi32(z1, ones);\n" + res += "__m256 z3 = _mm256_cvtepi32_ps(z2);\n" + res += "_mm256_storeu_ps(&ov[i*m +j], z3);\n" elif bits == 2: res += f"for (int j = 0; j < m; j+={packed})" res += "{\n" res += f"for (int k = 0; k < {packed}; k++)" res += "{\n" - res += f"ov[i*m + j+k] = ((zv[j/{packed}] >> ({bits}*k)) & {mask});\n" + res += f"ov[i*m + j+k] = (((zv[j/{packed}] >> ({bits}*k)) & {mask})+1);\n" res += "}\n" # res += "for(int j = 0; j < m; j+=16){\n" # res += "__m256i z = _mm256_set1_epi32(zv[i*m/16 + j/16]);\n"
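
Note (not part of the patch): a minimal Python sketch of the zero-point convention this revert restores. Zero points are stored off by one at pack time (the re-added `zeros -= 1` in the QuantLinear.pack methods) and re-incremented during dequantization (the `+ 1` terms re-added across the Python, Triton, CUDA, exllama, and qigen paths). The variable names and example values below are illustrative placeholders, not code taken from auto_gptq.

    # Hedged sketch of the restored 4-bit zero-point packing convention.
    bits = 4
    maxq = (1 << bits) - 1            # 0xF
    zeros = [8, 8, 8, 8, 8, 8, 8, 8]  # per-column zero points for one group (example values)

    # Pack: store (zero - 1), eight 4-bit fields per 32-bit word,
    # mirroring the `zeros -= 1` re-added before packing.
    qzeros_word = 0
    for i, z in enumerate(zeros):
        qzeros_word |= ((z - 1) & maxq) << (bits * i)

    # Unpack: shift, mask, and add 1 back, mirroring the `+ 1`
    # re-added in the dequantization kernels.
    recovered = [((qzeros_word >> (bits * i)) & maxq) + 1 for i in range(len(zeros))]
    assert recovered == zeros

    # Dequantization then uses the recovered zero point as usual:
    scale, q = 0.01, 11
    w = scale * (q - recovered[0])    # 0.01 * (11 - 8) -> 0.03

Checkpoints packed with the off-by-one convention are only compatible with kernels that add the 1 back, which is why this revert touches the packing code and every dequantization backend together.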