Merge pull request #325 from qwopqwop200/main
remove an unnecessary line (zeros -= 1) to make disabling the 'sym' feature truly possible
Commit ac23d6b819
14 changed files with 237 additions and 325 deletions
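Background (a sketch of the idea, not code from this repository): GPTQ stores a per-group `scale` and an integer zero-point. With `sym=True` the zero-point is always the mid-point of the integer range (for example 8 for 4-bit), so the old convention of packing `zero - 1` into `qzeros` and adding 1 back during dequantization was harmless. With `sym=False` the zero-point is data-dependent and can legitimately be 0, where the `-1` wraps around in the unsigned packed format; this PR therefore stores and uses the zero-point as-is. The snippet below only illustrates the two modes; `quantize_group` and `dequantize` are made-up helper names.

```python
import numpy as np

def quantize_group(w, bits=4, sym=False):
    # Minimal per-group affine quantization sketch (illustration only,
    # not AutoGPTQ's actual quantizer).
    maxq = 2 ** bits - 1
    if sym:
        scale = np.abs(w).max() / (maxq / 2)
        zero = (maxq + 1) // 2               # fixed mid-point, e.g. 8 for 4-bit
    else:
        scale = (w.max() - w.min()) / maxq
        zero = int(round(-w.min() / scale))  # data-dependent, anywhere in [0, maxq]
    q = np.clip(np.round(w / scale) + zero, 0, maxq).astype(np.uint32)
    return q, scale, zero

def dequantize(q, scale, zero):
    # After this PR the packed zero-point is used as-is: w ~ scale * (q - zero),
    # with no "-1 on pack / +1 on unpack" round trip.
    return scale * (q.astype(np.int64) - zero)

w = np.array([-0.9, -0.2, 0.1, 0.7])
q, s, z = quantize_group(w, bits=4, sym=False)
print(q, s, z, dequantize(q, s, z))
```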
@@ -39,7 +39,7 @@ class BaseQuantizeConfig(PushToHubMixin):
     damp_percent: float = field(default=0.01)
     desc_act: bool = field(default=True)
     static_groups: bool = field(default=False)
-    sym: bool = field(default=True)
+    sym: bool = field(default=False)
     true_sequential: bool = field(default=True)
     model_name_or_path: Optional[str] = field(default=None)
     model_file_base_name: Optional[str] = field(default=None)
@@ -967,6 +967,27 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                 checkpoint
             )
         model.load_state_dict(checkpoint)
+        # Preprocessing for backward compatibility
+        if quantize_config.sym:
+            QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, disable_exllama=disable_exllama, use_qigen=use_qigen,
+                                                         desc_act=quantize_config.desc_act, group_size=quantize_config.group_size, bits=quantize_config.bits)
+            for name, submodule in model.named_modules():
+                if isinstance(submodule, QuantLinear):
+                    if use_qigen:
+                        submodule.zeros.data = torch.full_like(submodule.zeros.data, (torch.tensor(2 ** quantize_config.bits - 1) + 1) / 2)
+                    else:
+                        if quantize_config.bits == 2:
+                            submodule.qzeros.data = torch.full_like(submodule.qzeros.data, -1431655766)
+                        elif quantize_config.bits == 3:
+                            submodule.qzeros.data[:,range(0,submodule.qzeros.data.shape[1],3)] = 613566756
+                            submodule.qzeros.data[:,range(1,submodule.qzeros.data.shape[1],3)] = 1227133513
+                            submodule.qzeros.data[:,range(2,submodule.qzeros.data.shape[1],3)] = -1840700270
+                        elif quantize_config.bits == 4:
+                            submodule.qzeros.data = torch.full_like(submodule.qzeros.data, -2004318072)
+                        elif quantize_config.bits == 8:
+                            submodule.qzeros.data = torch.full_like(submodule.qzeros.data, -2139062144)
+                        else:
+                            raise NotImplementedError("Only 2,3,4,8 bits are supported.")
         # == step4: set seqlen == #
         model_config = model.config.to_dict()
         seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
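The constants written into `qzeros` above are not self-explanatory. They appear to be the mid-point zero-point for each bit-width, `2 ** (bits - 1)`, replicated into every packed field of an int32 and reinterpreted as signed, which matches the `(2 ** bits - 1 + 1) / 2` value used on the qigen path. A standalone check (a sketch, not project code):

```python
def packed_midpoint(bits):
    # Replicate the mid-point zero-point 2**(bits-1) into every bits-wide field
    # of a 32-bit word, then reinterpret the word as a signed int32 (which is
    # how torch stores qzeros). Valid for widths that tile 32 bits exactly.
    mid = 1 << (bits - 1)
    word = 0
    for shift in range(0, 32, bits):
        word |= mid << shift
    return word - (1 << 32) if word >= (1 << 31) else word

def packed_midpoint_3bit():
    # 3-bit fields do not tile a 32-bit word, so the repeating pattern spans
    # three consecutive int32 words of qzeros (columns 0, 1, 2 mod 3).
    stream = 0
    for field in range(32):                 # 32 fields * 3 bits = 96 bits
        stream |= 4 << (3 * field)          # mid-point zero-point for 3-bit is 4
    words = [(stream >> (32 * i)) & 0xFFFFFFFF for i in range(3)]
    return [w - (1 << 32) if w >= (1 << 31) else w for w in words]

assert packed_midpoint(2) == -1431655766
assert packed_midpoint(4) == -2004318072
assert packed_midpoint(8) == -2139062144
assert packed_midpoint_3bit() == [613566756, 1227133513, -1840700270]
```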
@ -8,6 +8,8 @@ from transformers.models.gptj.modeling_gptj import GPTJAttention
|
|||
from ._fused_base import FusedBaseAttentionModule
|
||||
from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
|
||||
|
||||
from logging import getLogger
|
||||
logger = getLogger(__name__)
|
||||
|
||||
def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
|
||||
dim = x.shape[-1]
|
||||
|
@ -240,8 +242,13 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
|
|||
**kwargs
|
||||
):
|
||||
config = model.config
|
||||
|
||||
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2)
|
||||
|
||||
if QuantLinear.QUANT_TYPE in ["exllama", "exllamav2"] and desc_act:
|
||||
# See fused_llama_attn.py comment
|
||||
logger.warning(f"Exllama kernel does not support query/key/value fusion with act-order. Because of this, Fused attention is automatically disabled.")
|
||||
return False
|
||||
|
||||
for name, m in model.named_modules():
|
||||
if not isinstance(m, GPTJAttention):
|
||||
continue
|
||||
|
@ -257,11 +264,7 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
|
|||
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
|
||||
|
||||
if QuantLinear.QUANT_TYPE == "exllama":
|
||||
if desc_act:
|
||||
# See fused_llama_attn.py comment
|
||||
raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
|
||||
else:
|
||||
g_idx = None
|
||||
g_idx = None
|
||||
else:
|
||||
g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
|
||||
|
||||
|
@ -298,6 +301,6 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
|
|||
|
||||
setattr(parent, child_name, attn)
|
||||
del m
|
||||
|
||||
return True
|
||||
|
||||
__all__ = ["FusedGPTJAttentionForQuantizedModel"]
|
||||
|
|
|
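In both fused-attention injectors the incompatible combination (exllama kernels together with act-order) is now detected up front: the injector logs a warning and reports that fusion was skipped instead of raising a `ValueError` later. A minimal sketch of that guard with illustrative names (`inject_fused_attention_sketch` is not the project's API):

```python
from logging import getLogger

logger = getLogger(__name__)

def inject_fused_attention_sketch(model, quant_linear_cls, desc_act):
    # Early-out pattern added in this PR: if the active kernel cannot fuse
    # q/k/v with act-order, warn and report failure instead of raising.
    if quant_linear_cls.QUANT_TYPE in ["exllama", "exllamav2"] and desc_act:
        logger.warning(
            "Exllama kernel does not support query/key/value fusion with act-order. "
            "Because of this, Fused attention is automatically disabled."
        )
        return False
    # ... replace the attention modules here ...
    return True
```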
@ -7,6 +7,8 @@ from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotar
|
|||
from ._fused_base import FusedBaseAttentionModule
|
||||
from ..utils.import_utils import compare_pytorch_version, dynamically_import_QuantLinear
|
||||
|
||||
from logging import getLogger
|
||||
logger = getLogger(__name__)
|
||||
|
||||
class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
@ -142,8 +144,15 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
|
|||
"""
|
||||
Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
|
||||
"""
|
||||
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2)
|
||||
|
||||
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2)
|
||||
if QuantLinear.QUANT_TYPE in ["exllama", "exllamav2"] and desc_act:
|
||||
# TODO: support it. The issue lies maybe in the line:
|
||||
# int groups = qzeros.size(0);
|
||||
# in exllama_ext.cpp
|
||||
logger.warning(f"Exllama kernel does not support query/key/value fusion with act-order. Because of this, Fused attention is automatically disabled.")
|
||||
return False
|
||||
|
||||
for name, m in model.named_modules():
|
||||
if not isinstance(m, LlamaAttention):
|
||||
continue
|
||||
|
@ -157,13 +166,7 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
|
|||
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
|
||||
|
||||
if QuantLinear.QUANT_TYPE == "exllama":
|
||||
if desc_act:
|
||||
# TODO: support it. The issue lies maybe in the line:
|
||||
# int groups = qzeros.size(0);
|
||||
# in exllama_ext.cpp
|
||||
raise ValueError("Exllama kernel does not support query/key/value fusion with act-order. Please either use inject_fused_attention=False or disable_exllama=True.")
|
||||
else:
|
||||
g_idx = None
|
||||
g_idx = None
|
||||
else:
|
||||
g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
|
||||
|
||||
|
@ -198,6 +201,7 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
|
|||
child_name = name
|
||||
|
||||
setattr(parent, child_name, attn)
|
||||
return True
|
||||
|
||||
|
||||
__all__ = ["FusedLlamaAttentionForQuantizedModel"]
|
||||
|
|
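When fusion does proceed, the q/k/v projections' quantized buffers are concatenated into one fused layer; this hunk shows `scales` joined along dim=1 and, for non-exllama kernels, `g_idx` along dim=0. A sketch of that step, assuming `q_proj`, `k_proj` and `v_proj` are QuantLinear modules (other buffers such as `qweight` and `qzeros` are handled analogously but are not visible in this hunk):

```python
import torch

def fuse_qkv_buffers(q_proj, k_proj, v_proj, quant_type):
    # Output-channel-wise tensors are concatenated along dim=1; the per-input-
    # channel g_idx is only kept for kernels that need it.
    scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
    if quant_type == "exllama":
        g_idx = None  # exllama + act-order was already rejected by the guard above
    else:
        g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
    return scales, g_idx
```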
|
@@ -157,7 +157,6 @@ class QuantLinear(nn.Module):
         qweight = qweight.astype(np.int32)
         self.qweight = torch.from_numpy(qweight)

-        zeros -= 1
         zeros = zeros.numpy().astype(np.uint32)
         qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
         i = 0
@@ -221,7 +220,6 @@ class QuantLinear(nn.Module):
             ).to(torch.int16 if self.bits == 8 else torch.int8)
             torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)

-            zeros = zeros + 1
             zeros = zeros.reshape(self.scales.shape)

             weight = torch.bitwise_right_shift(
@@ -239,7 +237,6 @@ class QuantLinear(nn.Module):
             zeros = zeros & 0x7
             zeros = torch.cat([zeros[:, :, 0, :11], zeros[:, :, 1, 1:12], zeros[:, :, 2, 1:11]], dim=2)

-            zeros = zeros + 1
             zeros = zeros.reshape(self.scales.shape)

             weight = self.qweight.reshape(
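With `zeros -= 1` removed from `pack()` and the matching `zeros + 1` removed from dequantization, zero-points now round-trip through `qzeros` unchanged. A small numpy sketch of the 4-bit round trip (illustrative only; the module packs whole tensors rather than a single word):

```python
import numpy as np

bits = 4
maxq = (1 << bits) - 1
zeros = [0, 3, 8, 15, 7, 1, 12, 6]      # one group's zero-points, 8 columns

# pack: 32 // bits = 8 zero-points per uint32 word (no "-1" before packing any more)
word = 0
for col, z in enumerate(zeros):
    word |= z << (col * bits)
qzeros = np.array([word], dtype=np.uint32)

# unpack: shift and mask only (no "+1" after unpacking any more)
unpacked = [(int(qzeros[0]) >> (col * bits)) & maxq for col in range(len(zeros))]
assert unpacked == zeros
print(hex(int(qzeros[0])), unpacked)
```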
@ -157,7 +157,6 @@ class QuantLinear(nn.Module):
|
|||
qweight = qweight.astype(np.int32)
|
||||
self.qweight = torch.from_numpy(qweight)
|
||||
|
||||
zeros -= 1
|
||||
zeros = zeros.numpy().astype(np.uint32)
|
||||
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
|
||||
i = 0
|
||||
|
@ -231,7 +230,6 @@ class QuantLinear(nn.Module):
|
|||
zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), self.wf.unsqueeze(0)).to(torch.int16 if self.bits == 8 else torch.int8)
|
||||
torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)
|
||||
|
||||
zeros = zeros + 1
|
||||
zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
|
||||
|
||||
scales = self.scales
|
||||
|
@ -248,7 +246,6 @@ class QuantLinear(nn.Module):
|
|||
zeros = zeros & 0x7
|
||||
zeros = torch.cat([zeros[:,:,0,:11], zeros[:,:,1,1:12], zeros[:,:,2,1:11]], dim=2)
|
||||
|
||||
zeros = zeros + 1
|
||||
zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
|
||||
|
||||
scales = self.scales
|
||||
|
|
|
@ -146,7 +146,6 @@ class QuantLinear(nn.Module):
|
|||
qweight = qweight.astype(np.int32)
|
||||
self.qweight = torch.from_numpy(qweight)
|
||||
|
||||
zeros -= 1
|
||||
zeros = zeros.numpy().astype(np.uint32)
|
||||
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
|
||||
i = 0
|
||||
|
|
|
@ -114,7 +114,6 @@ class QuantLinear(nn.Module, TritonModuleMixin):
|
|||
qweight = qweight.astype(np.int32)
|
||||
self.qweight = torch.from_numpy(qweight)
|
||||
|
||||
zeros -= 1
|
||||
zeros = zeros.numpy().astype(np.uint32)
|
||||
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
|
||||
i = 0
|
||||
|
|
|
@ -144,7 +144,6 @@ def quant_matmul_248_kernel(
|
|||
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
|
||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||
zeros = (zeros + 1)
|
||||
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||
|
@ -290,7 +289,6 @@ def transpose_quant_matmul_248_kernel(
|
|||
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
|
||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||
zeros = (zeros + 1)
|
||||
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
|
||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||
|
|
|
@ -30,9 +30,9 @@
|
|||
// }
|
||||
// #endif
|
||||
|
||||
|
||||
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700) || defined(USE_ROCM)
|
||||
// adapted from https://github.com/torch/cutorch/blob/master/lib/THC/THCAtomics.cuh
|
||||
|
||||
__device__ __forceinline__ void atomicAdd(c10::Half* address, c10::Half val) {
|
||||
unsigned int *address_as_ui = reinterpret_cast<unsigned int *>(reinterpret_cast<char *>(address) - (reinterpret_cast<size_t>(address) & 2));
|
||||
unsigned int old = *address_as_ui;
|
||||
|
@ -77,7 +77,7 @@ __global__ void VecQuant2MatMulKernel(
|
|||
const int* __restrict__ zeros,
|
||||
const int* __restrict__ g_idx,
|
||||
int batch,
|
||||
int vec_height,
|
||||
int vec_height,
|
||||
int height,
|
||||
int width,
|
||||
int zero_width
|
||||
|
@ -92,7 +92,7 @@ __global__ void VecQuant3MatMulKernel(
|
|||
const int* __restrict__ zeros,
|
||||
const int* __restrict__ g_idx,
|
||||
int batch,
|
||||
int vec_height,
|
||||
int vec_height,
|
||||
int height,
|
||||
int width,
|
||||
int zero_width
|
||||
|
@ -113,6 +113,7 @@ __global__ void VecQuant4MatMulKernel(
|
|||
int zero_width
|
||||
);
|
||||
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void VecQuant8MatMulKernel(
|
||||
const scalar_t* __restrict__ vec,
|
||||
|
@ -122,7 +123,7 @@ __global__ void VecQuant8MatMulKernel(
|
|||
const int* __restrict__ zeros,
|
||||
const int* __restrict__ g_idx,
|
||||
int batch,
|
||||
int vec_height,
|
||||
int vec_height,
|
||||
int height,
|
||||
int width,
|
||||
int zero_width
|
||||
|
@ -136,7 +137,7 @@ __global__ void VecQuant2MatMulKernel_old(
|
|||
const scalar_t* __restrict__ scales,
|
||||
const int* __restrict__ zeros,
|
||||
int batch,
|
||||
int vec_height,
|
||||
int vec_height,
|
||||
int height,
|
||||
int width,
|
||||
int zero_width,
|
||||
|
@ -151,7 +152,7 @@ __global__ void VecQuant3MatMulKernel_old(
|
|||
const scalar_t* __restrict__ scales,
|
||||
const int* __restrict__ zeros,
|
||||
int batch,
|
||||
int vec_height,
|
||||
int vec_height,
|
||||
int height,
|
||||
int width,
|
||||
int zero_width,
|
||||
|
@ -166,7 +167,7 @@ __global__ void VecQuant4MatMulKernel_old(
|
|||
const scalar_t* __restrict__ scales,
|
||||
const int* __restrict__ zeros,
|
||||
int batch,
|
||||
int vec_height,
|
||||
int vec_height,
|
||||
int height,
|
||||
int width,
|
||||
int zero_width,
|
||||
|
@ -181,7 +182,7 @@ __global__ void VecQuant8MatMulKernel_old(
|
|||
const scalar_t* __restrict__ scales,
|
||||
const int* __restrict__ zeros,
|
||||
int batch,
|
||||
int vec_height,
|
||||
int vec_height,
|
||||
int height,
|
||||
int width,
|
||||
int zero_width,
|
||||
|
@ -209,7 +210,7 @@ __global__ void VecQuant3MatMulKernelFaster_old(
|
|||
const float* __restrict__ scales,
|
||||
const int* __restrict__ zeros,
|
||||
int batch,
|
||||
int vec_height,
|
||||
int vec_height,
|
||||
int height,
|
||||
int width,
|
||||
int zero_width,
|
||||
|
@ -223,7 +224,7 @@ __global__ void VecQuant4MatMulKernelFaster_old(
|
|||
const float* __restrict__ scales,
|
||||
const int* __restrict__ zeros,
|
||||
int batch,
|
||||
int vec_height,
|
||||
int vec_height,
|
||||
int height,
|
||||
int width,
|
||||
int zero_width,
|
||||
|
@ -270,7 +271,7 @@ void vecquant2matmul_cuda(
|
|||
vec.type(), "vecquant2matmul_cuda", ([&] {
|
||||
VecQuant2MatMulKernel<<<blocks, threads>>>(
|
||||
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
|
||||
scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
|
||||
scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
|
||||
batch, vec_height, height, width, zero_width
|
||||
);
|
||||
})
|
||||
|
@ -293,39 +294,39 @@ __global__ void VecQuant2MatMulKernel(
|
|||
) {
|
||||
int h = BLOCKHEIGHT2 * blockIdx.x;
|
||||
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
|
||||
|
||||
|
||||
__shared__ scalar_t blockvec[BLOCKWIDTH];
|
||||
int i = width * h + w;
|
||||
int g_h = h * 16;
|
||||
int k;
|
||||
unsigned int g;
|
||||
scalar_t w_tmp;
|
||||
|
||||
int z_w = w / 16;
|
||||
|
||||
int z_w = w / 16;
|
||||
int z_mod = (w % 16) * 2;
|
||||
|
||||
|
||||
float weight[BLOCKWIDTH];
|
||||
|
||||
for (k = 0; k < BLOCKWIDTH; ++k){
|
||||
int k_w = (k / 16);
|
||||
|
||||
for (k = 0; k < BLOCKWIDTH; ++k){
|
||||
int k_w = (k / 16);
|
||||
int k_bit = (k % 16) * 2;
|
||||
|
||||
|
||||
g = as_int(g_idx[g_h + k]);
|
||||
scalar_t scale = scales[g * width + w];
|
||||
scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);
|
||||
|
||||
scalar_t zero = scalar_t(as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3);
|
||||
|
||||
w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3);
|
||||
|
||||
|
||||
weight[k] = scale * (w_tmp - zero);
|
||||
}
|
||||
|
||||
scalar_t res;
|
||||
for (int b = 0; b < batch; ++b){
|
||||
for (int b = 0; b < batch; ++b){
|
||||
res = 0;
|
||||
|
||||
|
||||
blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
|
||||
__syncthreads();
|
||||
for (k = 0; k < BLOCKWIDTH; ++k){
|
||||
for (k = 0; k < BLOCKWIDTH; ++k){
|
||||
res += weight[k] * blockvec[k];
|
||||
}
|
||||
atomicAdd(&mul[b * width + w], res);
|
||||
|
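In the CUDA kernels the only arithmetic change to the zero-point is dropping the `+ 1`; the packed-word indexing (`z_w`, `z_mod`) is untouched. As a readable reference for what `(zeros[g * zero_width + z_w] >> z_mod) & 0x3` extracts in the 2-bit kernel, a Python sketch (assuming `qzeros` has shape `(groups, zero_width)`):

```python
import numpy as np

def extract_zero_2bit(qzeros, g, w):
    # Mirrors the 2-bit kernel's indexing: 16 two-bit zero-points per uint32.
    # After this PR the extracted value is used directly (no "+ 1").
    z_w = w // 16            # which packed word in the row holds column w
    z_mod = (w % 16) * 2     # bit offset of column w inside that word
    return (int(qzeros[g, z_w]) >> z_mod) & 0x3

qzeros = np.array([[0xAAAAAAAA]], dtype=np.uint32)  # mid-point zero (2) in every field
assert all(extract_zero_2bit(qzeros, 0, w) == 2 for w in range(16))
```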
@ -357,7 +358,7 @@ void vecquant3matmul_cuda(
|
|||
vec.type(), "vecquant3matmul_cuda", ([&] {
|
||||
VecQuant3MatMulKernel<<<blocks, threads>>>(
|
||||
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
|
||||
scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
|
||||
scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
|
||||
batch, vec_height, height, width, zero_width
|
||||
);
|
||||
})
|
||||
|
@ -380,15 +381,15 @@ __global__ void VecQuant3MatMulKernel(
|
|||
) {
|
||||
int h = BLOCKHEIGHT3 * blockIdx.x;
|
||||
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
|
||||
|
||||
|
||||
__shared__ scalar_t blockvec[BLOCKWIDTH];
|
||||
int i = width * h + w;
|
||||
int g_h = (h / 3) * 32;
|
||||
int k;
|
||||
unsigned int g;
|
||||
scalar_t w_tmp;
|
||||
|
||||
int z_w = (w / 32) * 3;
|
||||
|
||||
int z_w = (w / 32) * 3;
|
||||
int z_mod = w % 32;
|
||||
int z_bit;
|
||||
unsigned int z_tmp;
|
||||
|
@ -412,14 +413,14 @@ __global__ void VecQuant3MatMulKernel(
|
|||
z_w += 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
float weight[BLOCKWIDTH];
|
||||
|
||||
for (k = 0; k < BLOCKWIDTH; ++k){
|
||||
int k_w = (k / 32) * 3;
|
||||
|
||||
for (k = 0; k < BLOCKWIDTH; ++k){
|
||||
int k_w = (k / 32) * 3;
|
||||
int k_mod = k % 32;
|
||||
int k_bit;
|
||||
|
||||
|
||||
if (k_mod != 10){
|
||||
if (k_mod != 21){
|
||||
k_bit = k_mod;
|
||||
|
@ -440,20 +441,20 @@ __global__ void VecQuant3MatMulKernel(
|
|||
k_w += 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
g = as_int(g_idx[g_h + k]);
|
||||
scalar_t scale = scales[g * width + w];
|
||||
scalar_t zero;
|
||||
if (z_mod == 10) {
|
||||
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
|
||||
zero = scalar_t((z_tmp) + 1);
|
||||
zero = scalar_t(z_tmp);
|
||||
} else if (z_mod == 21){
|
||||
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
|
||||
zero = scalar_t((z_tmp) + 1);
|
||||
zero = scalar_t(z_tmp);
|
||||
} else {
|
||||
zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
|
||||
zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7);
|
||||
}
|
||||
|
||||
|
||||
if (k_mod == 10) {
|
||||
w_tmp = (as_unsigned(mat[i + (k_w * width)]) >> 30) | ((as_unsigned(mat[i + ((k_w + 1)* width)]) << 2) & 0x4);
|
||||
} else if (k_mod == 21){
|
||||
|
@ -465,12 +466,12 @@ __global__ void VecQuant3MatMulKernel(
|
|||
}
|
||||
|
||||
scalar_t res;
|
||||
for (int b = 0; b < batch; ++b){
|
||||
for (int b = 0; b < batch; ++b){
|
||||
res = 0;
|
||||
|
||||
|
||||
blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
|
||||
__syncthreads();
|
||||
for (k = 0; k < BLOCKWIDTH; ++k){
|
||||
for (k = 0; k < BLOCKWIDTH; ++k){
|
||||
res += weight[k] * blockvec[k];
|
||||
}
|
||||
atomicAdd(&mul[b * width + w], res);
|
||||
|
@ -502,7 +503,7 @@ void vecquant4matmul_cuda(
|
|||
vec.type(), "vecquant4matmul_cuda", ([&] {
|
||||
VecQuant4MatMulKernel<<<blocks, threads>>>(
|
||||
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
|
||||
scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
|
||||
scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
|
||||
batch, vec_height, height, width, zero_width
|
||||
);
|
||||
})
|
||||
|
@ -525,40 +526,40 @@ __global__ void VecQuant4MatMulKernel(
|
|||
) {
|
||||
int h = BLOCKHEIGHT4 * blockIdx.x;
|
||||
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
|
||||
|
||||
|
||||
__shared__ scalar_t blockvec[BLOCKWIDTH];
|
||||
int i = width * h + w;
|
||||
int g_h = h * 8;
|
||||
int k;
|
||||
unsigned int g;
|
||||
scalar_t w_tmp;
|
||||
|
||||
|
||||
|
||||
int z_w = w / 8;
|
||||
int z_w = w / 8;
|
||||
int z_mod = (w % 8) * 4;
|
||||
|
||||
|
||||
float weight[BLOCKWIDTH];
|
||||
|
||||
for (k = 0; k < BLOCKWIDTH; ++k){
|
||||
int k_w = (k / 8);
|
||||
|
||||
for (k = 0; k < BLOCKWIDTH; ++k){
|
||||
int k_w = (k / 8);
|
||||
int k_bit = (k % 8) * 4;
|
||||
|
||||
|
||||
g = as_int(g_idx[g_h + k]);
|
||||
scalar_t scale = scales[g * width + w];
|
||||
scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
|
||||
|
||||
scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF);
|
||||
|
||||
w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF);
|
||||
|
||||
|
||||
weight[k] = scale * (w_tmp - zero);
|
||||
}
|
||||
|
||||
scalar_t res;
|
||||
for (int b = 0; b < batch; ++b){
|
||||
for (int b = 0; b < batch; ++b){
|
||||
res = 0;
|
||||
|
||||
|
||||
blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
|
||||
__syncthreads();
|
||||
for (k = 0; k < BLOCKWIDTH; ++k){
|
||||
for (k = 0; k < BLOCKWIDTH; ++k){
|
||||
res += weight[k] * blockvec[k];
|
||||
}
|
||||
atomicAdd(&mul[b * width + w], res);
|
||||
|
@ -590,7 +591,7 @@ void vecquant8matmul_cuda(
|
|||
vec.type(), "vecquant8matmul_cuda", ([&] {
|
||||
VecQuant8MatMulKernel<<<blocks, threads>>>(
|
||||
vec.data<scalar_t>(), mat.data<int>(), mul.data<scalar_t>(),
|
||||
scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
|
||||
scales.data<scalar_t>(), zeros.data<int>(), g_idx.data<int>(),
|
||||
batch, vec_height, height, width, zero_width
|
||||
);
|
||||
})
|
||||
|
@ -613,39 +614,39 @@ __global__ void VecQuant8MatMulKernel(
|
|||
) {
|
||||
int h = BLOCKHEIGHT8 * blockIdx.x;
|
||||
int w = BLOCKWIDTH * blockIdx.y + threadIdx.x;
|
||||
|
||||
|
||||
__shared__ scalar_t blockvec[BLOCKWIDTH];
|
||||
int i = width * h + w;
|
||||
int g_h = h * 4;
|
||||
int k;
|
||||
unsigned int g;
|
||||
scalar_t w_tmp;
|
||||
|
||||
int z_w = w / 4;
|
||||
|
||||
int z_w = w / 4;
|
||||
int z_mod = (w % 4) * 8;
|
||||
|
||||
|
||||
float weight[BLOCKWIDTH];
|
||||
|
||||
for (k = 0; k < BLOCKWIDTH; ++k){
|
||||
int k_w = (k / 4);
|
||||
|
||||
for (k = 0; k < BLOCKWIDTH; ++k){
|
||||
int k_w = (k / 4);
|
||||
int k_bit = (k % 4) * 8;
|
||||
|
||||
|
||||
g = as_int(g_idx[g_h + k]);
|
||||
scalar_t scale = scales[g * width + w];
|
||||
scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
|
||||
|
||||
scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF);
|
||||
|
||||
w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF);
|
||||
|
||||
|
||||
weight[k] = scale * (w_tmp - zero);
|
||||
}
|
||||
|
||||
scalar_t res;
|
||||
for (int b = 0; b < batch; ++b){
|
||||
for (int b = 0; b < batch; ++b){
|
||||
res = 0;
|
||||
|
||||
|
||||
blockvec[threadIdx.x] = vec[b * vec_height + blockIdx.x * BLOCKWIDTH + threadIdx.x];
|
||||
__syncthreads();
|
||||
for (k = 0; k < BLOCKWIDTH; ++k){
|
||||
for (k = 0; k < BLOCKWIDTH; ++k){
|
||||
res += weight[k] * blockvec[k];
|
||||
}
|
||||
atomicAdd(&mul[b * width + w], res);
|
||||
|
@ -712,19 +713,19 @@ __global__ void VecQuant2MatMulKernel_old(
|
|||
int i = width * h + w;
|
||||
int g_h = h * 16;
|
||||
int k = 0;
|
||||
|
||||
int z_w = w / 16;
|
||||
|
||||
int z_w = w / 16;
|
||||
int z_mod = (w % 16) * 2;
|
||||
|
||||
unsigned int tmp;
|
||||
|
||||
while (k < BLOCKWIDTH) {
|
||||
tmp = as_unsigned(mat[i]);
|
||||
|
||||
|
||||
int g = (g_h + k) / groupsize;
|
||||
scalar_t scale = scales[g * width + w];
|
||||
scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);
|
||||
|
||||
scalar_t zero = scale * scalar_t(as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3);
|
||||
|
||||
res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0];
|
||||
res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1];
|
||||
res += (scale * scalar_t((tmp >> 4) & 0x3) - zero) * blockvec[k + 2];
|
||||
|
@ -741,7 +742,7 @@ __global__ void VecQuant2MatMulKernel_old(
|
|||
res += (scale * scalar_t((tmp >> 26) & 0x3) - zero) * blockvec[k + 13];
|
||||
res += (scale * scalar_t((tmp >> 28) & 0x3) - zero) * blockvec[k + 14];
|
||||
res += (scale * scalar_t((tmp >> 30) & 0x3) - zero) * blockvec[k + 15];
|
||||
|
||||
|
||||
i += width;
|
||||
k += 16;
|
||||
}
|
||||
|
@ -807,11 +808,11 @@ __global__ void VecQuant3MatMulKernel_old(
|
|||
int i = width * h + w;
|
||||
int g_h = (h / 3) * 32;
|
||||
int k = 0;
|
||||
|
||||
int z_w = (w / 32) * 3;
|
||||
|
||||
int z_w = (w / 32) * 3;
|
||||
int z_mod = w % 32;
|
||||
int z_bit;
|
||||
|
||||
|
||||
if (z_mod != 10){
|
||||
if (z_mod != 21){
|
||||
z_bit = z_mod;
|
||||
|
@ -832,7 +833,7 @@ __global__ void VecQuant3MatMulKernel_old(
|
|||
z_w += 1;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
unsigned int tmp1;
|
||||
unsigned int tmp2;
|
||||
unsigned int tmp;
|
||||
|
@ -840,20 +841,20 @@ __global__ void VecQuant3MatMulKernel_old(
|
|||
|
||||
while (k < BLOCKWIDTH) {
|
||||
tmp1 = as_unsigned(mat[i]);
|
||||
|
||||
|
||||
int g = (g_h + k) / groupsize;
|
||||
scalar_t scale = scales[g * width + w];
|
||||
scalar_t zero;
|
||||
if (z_mod == 10) {
|
||||
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
|
||||
zero = scale * scalar_t((z_tmp) + 1);
|
||||
zero = scale * scalar_t(z_tmp);
|
||||
} else if (z_mod == 21){
|
||||
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
|
||||
zero = scale * scalar_t((z_tmp) + 1);
|
||||
zero = scale * scalar_t(z_tmp);
|
||||
} else {
|
||||
zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
|
||||
zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7);
|
||||
}
|
||||
|
||||
|
||||
res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
|
||||
res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
|
||||
res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
|
||||
|
@ -864,14 +865,14 @@ __global__ void VecQuant3MatMulKernel_old(
|
|||
res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
|
||||
res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
|
||||
res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
|
||||
|
||||
|
||||
i += width;
|
||||
tmp2 = as_unsigned(mat[i]);
|
||||
tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x4);
|
||||
tmp2 >>= 1;
|
||||
res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
|
||||
k += 11;
|
||||
|
||||
|
||||
res += (scale * scalar_t((tmp2 >> 0) & 0x7) - zero) * blockvec[k + 0];
|
||||
res += (scale * scalar_t((tmp2 >> 3) & 0x7) - zero) * blockvec[k + 1];
|
||||
res += (scale * scalar_t((tmp2 >> 6) & 0x7) - zero) * blockvec[k + 2];
|
||||
|
@ -882,14 +883,14 @@ __global__ void VecQuant3MatMulKernel_old(
|
|||
res += (scale * scalar_t((tmp2 >> 21) & 0x7) - zero) * blockvec[k + 7];
|
||||
res += (scale * scalar_t((tmp2 >> 24) & 0x7) - zero) * blockvec[k + 8];
|
||||
res += (scale * scalar_t((tmp2 >> 27) & 0x7) - zero) * blockvec[k + 9];
|
||||
|
||||
|
||||
i += width;
|
||||
tmp1 = as_unsigned(mat[i]);
|
||||
tmp = (tmp2 >> 30) | ((tmp1 << 1) & 0x6);
|
||||
tmp1 >>= 2;
|
||||
res += (scale * scalar_t(tmp) - zero) * blockvec[k + 10];
|
||||
k += 11;
|
||||
|
||||
|
||||
res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
|
||||
res += (scale * scalar_t((tmp1 >> 3) & 0x7) - zero) * blockvec[k + 1];
|
||||
res += (scale * scalar_t((tmp1 >> 6) & 0x7) - zero) * blockvec[k + 2];
|
||||
|
@ -900,7 +901,7 @@ __global__ void VecQuant3MatMulKernel_old(
|
|||
res += (scale * scalar_t((tmp1 >> 21) & 0x7) - zero) * blockvec[k + 7];
|
||||
res += (scale * scalar_t((tmp1 >> 24) & 0x7) - zero) * blockvec[k + 8];
|
||||
res += (scale * scalar_t((tmp1 >> 27) & 0x7) - zero) * blockvec[k + 9];
|
||||
|
||||
|
||||
i += width;
|
||||
k += 10;
|
||||
}
|
||||
|
@ -967,18 +968,18 @@ __global__ void VecQuant4MatMulKernel_old(
|
|||
int g_h = h * 8;
|
||||
int k = 0;
|
||||
|
||||
int z_w = w / 8;
|
||||
int z_w = w / 8;
|
||||
int z_mod = (w % 8) * 4;
|
||||
|
||||
unsigned int tmp;
|
||||
|
||||
while (k < BLOCKWIDTH) {
|
||||
tmp = as_unsigned(mat[i]);
|
||||
|
||||
|
||||
int g = (g_h + k) / groupsize;
|
||||
scalar_t scale = scales[g * width + w];
|
||||
scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
|
||||
|
||||
scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF);
|
||||
|
||||
res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0];
|
||||
res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1];
|
||||
res += (scale * scalar_t((tmp >> 8) & 0xF) - zero) * blockvec[k + 2];
|
||||
|
@ -987,7 +988,7 @@ __global__ void VecQuant4MatMulKernel_old(
|
|||
res += (scale * scalar_t((tmp >> 20) & 0xF) - zero) * blockvec[k + 5];
|
||||
res += (scale * scalar_t((tmp >> 24) & 0xF) - zero) * blockvec[k + 6];
|
||||
res += (scale * scalar_t((tmp >> 28) & 0xF) - zero) * blockvec[k + 7];
|
||||
|
||||
|
||||
i += width;
|
||||
k += 8;
|
||||
}
|
||||
|
@ -1053,24 +1054,24 @@ __global__ void VecQuant8MatMulKernel_old(
|
|||
int i = width * h + w;
|
||||
int g_h = h * 4;
|
||||
int k = 0;
|
||||
|
||||
int z_w = w / 4;
|
||||
|
||||
int z_w = w / 4;
|
||||
int z_mod = (w % 4) * 8;
|
||||
|
||||
unsigned int tmp;
|
||||
|
||||
while (k < BLOCKWIDTH) {
|
||||
while (k < BLOCKWIDTH) {
|
||||
tmp = as_unsigned(mat[i]);
|
||||
|
||||
|
||||
int g = (g_h + k) / groupsize;
|
||||
scalar_t scale = scales[g * width + w];
|
||||
scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
|
||||
|
||||
scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF);
|
||||
|
||||
res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0];
|
||||
res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1];
|
||||
res += (scale * scalar_t((tmp >> 16) & 0xFF) - zero) * blockvec[k + 2];
|
||||
res += (scale * scalar_t((tmp >> 24) & 0xFF) - zero) * blockvec[k + 3];
|
||||
|
||||
|
||||
i += width;
|
||||
k += 4;
|
||||
}
|
||||
|
@ -1092,7 +1093,7 @@ void vecquant2matmul_faster_cuda_old(
|
|||
int height = mat.size(0);
|
||||
int width = mat.size(1);
|
||||
int zero_width = zeros.size(1);
|
||||
|
||||
|
||||
dim3 blocks(
|
||||
(height + BLOCKHEIGHT2 - 1) / BLOCKHEIGHT2,
|
||||
(width + BLOCKWIDTH - 1) / BLOCKWIDTH,
|
||||
|
@ -1144,8 +1145,8 @@ __global__ void VecQuant2MatMulKernelFaster_old(
|
|||
int i = width * h + w;
|
||||
int g_h = h * 16;
|
||||
int k = 0;
|
||||
|
||||
int z_w = w / 16;
|
||||
|
||||
int z_w = w / 16;
|
||||
int z_mod = (w % 16) * 2;
|
||||
|
||||
float res = 0;
|
||||
|
@ -1159,8 +1160,8 @@ __global__ void VecQuant2MatMulKernelFaster_old(
|
|||
int g = (g_h + (k * 2)) / groupsize;
|
||||
float scale_f = scales[g * width + w];
|
||||
half2 scale = __float2half2_rn(scale_f);
|
||||
half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1)));
|
||||
|
||||
half2 zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3)));
|
||||
|
||||
std::memset(&res2, 0, sizeof(half2));
|
||||
tmp = as_unsigned(mat[i]);
|
||||
res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xf][off], scale, zero), blockvec[k + 0], res2);
|
||||
|
@ -1192,7 +1193,7 @@ void vecquant3matmul_faster_cuda_old(
|
|||
int height = mat.size(0);
|
||||
int width = mat.size(1);
|
||||
int zero_width = zeros.size(1);
|
||||
|
||||
|
||||
dim3 blocks(
|
||||
(height + BLOCKHEIGHT3 - 1) / BLOCKHEIGHT3,
|
||||
(width + BLOCKWIDTH - 1) / BLOCKWIDTH,
|
||||
|
@ -1244,11 +1245,11 @@ __global__ void VecQuant3MatMulKernelFaster_old(
|
|||
int i = width * h + w;
|
||||
int g_h = (h / 3) * 32;
|
||||
int k = 0;
|
||||
|
||||
|
||||
int z_w = (w / 32) * 3;
|
||||
int z_mod = w % 32;
|
||||
int z_bit;
|
||||
|
||||
|
||||
if (z_mod != 10){
|
||||
if (z_mod != 21){
|
||||
z_bit = z_mod;
|
||||
|
@ -1287,14 +1288,14 @@ __global__ void VecQuant3MatMulKernelFaster_old(
|
|||
half2 zero;
|
||||
if (z_mod == 10) {
|
||||
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
|
||||
zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1)));
|
||||
zero = __float2half2_rn(-(scale_f * z_tmp));
|
||||
} else if (z_mod == 21){
|
||||
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
|
||||
zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1)));
|
||||
zero = __float2half2_rn(-(scale_f * z_tmp));
|
||||
} else {
|
||||
zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1)));
|
||||
zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7)));
|
||||
}
|
||||
|
||||
|
||||
std::memset(&res2, 0, sizeof(half2));
|
||||
tmp1 = as_unsigned(mat[i]);
|
||||
res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
|
||||
|
@ -1344,7 +1345,7 @@ void vecquant4matmul_faster_cuda_old(
|
|||
int height = mat.size(0);
|
||||
int width = mat.size(1);
|
||||
int zero_width = zeros.size(1);
|
||||
|
||||
|
||||
dim3 blocks(
|
||||
(height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
|
||||
(width + BLOCKWIDTH - 1) / BLOCKWIDTH,
|
||||
|
@ -1397,7 +1398,7 @@ __global__ void VecQuant4MatMulKernelFaster_old(
|
|||
int g_h = h * 8;
|
||||
int k = 0;
|
||||
|
||||
int z_w = w / 8;
|
||||
int z_w = w / 8;
|
||||
int z_mod = (w % 8) * 4;
|
||||
|
||||
float res = 0;
|
||||
|
@ -1410,14 +1411,9 @@ __global__ void VecQuant4MatMulKernelFaster_old(
|
|||
while (k < blockwidth2) {
|
||||
int g = (g_h + (k * 2)) / groupsize;
|
||||
float scale_f = scales[g * width + w];
|
||||
|
||||
half2 scale = __float2half2_rn(scale_f);
|
||||
half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1)));
|
||||
|
||||
//std::memset(&res2, 0, sizeof(half2));
|
||||
|
||||
//res2 = __float2half2_rn((float)0.);
|
||||
|
||||
half2 zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF)));
|
||||
|
||||
std::memset(&res2, 0, sizeof(half2));
|
||||
tmp = as_unsigned(mat[i]);
|
||||
res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scale, zero), blockvec[k + 0], res2);
|
||||
|
@ -1426,10 +1422,8 @@ __global__ void VecQuant4MatMulKernelFaster_old(
|
|||
res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scale, zero), blockvec[k + 3], res2);
|
||||
i += width;
|
||||
k += 4;
|
||||
|
||||
res += __low2float(res2) + __high2float(res2);
|
||||
|
||||
}
|
||||
|
||||
atomicAdd(&mul[b * width + w], res);
|
||||
}
|
||||
}
|
|
@ -313,7 +313,7 @@ __global__ void VecQuant2MatMulKernel(
|
|||
|
||||
g = as_int(g_idx[g_h + k]);
|
||||
scalar_t scale = scales[g * width + w];
|
||||
scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);
|
||||
scalar_t zero = scalar_t(as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3);
|
||||
|
||||
w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0x3);
|
||||
|
||||
|
@ -447,12 +447,12 @@ __global__ void VecQuant3MatMulKernel(
|
|||
scalar_t zero;
|
||||
if (z_mod == 10) {
|
||||
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
|
||||
zero = scalar_t((z_tmp) + 1);
|
||||
zero = scalar_t(z_tmp);
|
||||
} else if (z_mod == 21){
|
||||
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
|
||||
zero = scalar_t((z_tmp) + 1);
|
||||
zero = scalar_t(z_tmp);
|
||||
} else {
|
||||
zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
|
||||
zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7);
|
||||
}
|
||||
|
||||
if (k_mod == 10) {
|
||||
|
@ -546,7 +546,7 @@ __global__ void VecQuant4MatMulKernel(
|
|||
|
||||
g = as_int(g_idx[g_h + k]);
|
||||
scalar_t scale = scales[g * width + w];
|
||||
scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
|
||||
scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF);
|
||||
|
||||
w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xF);
|
||||
|
||||
|
@ -633,7 +633,7 @@ __global__ void VecQuant8MatMulKernel(
|
|||
|
||||
g = as_int(g_idx[g_h + k]);
|
||||
scalar_t scale = scales[g * width + w];
|
||||
scalar_t zero = scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
|
||||
scalar_t zero = scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF);
|
||||
|
||||
w_tmp = ((as_unsigned(mat[i + (k_w * width)]) >> k_bit) & 0xFF);
|
||||
|
||||
|
@ -724,7 +724,7 @@ __global__ void VecQuant2MatMulKernel_old(
|
|||
|
||||
int g = (g_h + k) / groupsize;
|
||||
scalar_t scale = scales[g * width + w];
|
||||
scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3) + 1);
|
||||
scalar_t zero = scale * scalar_t(as_unsigned(zeros[g * zero_width + z_w]) >> z_mod & 0x3);
|
||||
|
||||
res += (scale * scalar_t((tmp >> 0) & 0x3) - zero) * blockvec[k + 0];
|
||||
res += (scale * scalar_t((tmp >> 2) & 0x3) - zero) * blockvec[k + 1];
|
||||
|
@ -847,12 +847,12 @@ __global__ void VecQuant3MatMulKernel_old(
|
|||
scalar_t zero;
|
||||
if (z_mod == 10) {
|
||||
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
|
||||
zero = scale * scalar_t((z_tmp) + 1);
|
||||
zero = scale * scalar_t(z_tmp);
|
||||
} else if (z_mod == 21){
|
||||
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
|
||||
zero = scale * scalar_t((z_tmp) + 1);
|
||||
zero = scale * scalar_t(z_tmp);
|
||||
} else {
|
||||
zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1);
|
||||
zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7);
|
||||
}
|
||||
|
||||
res += (scale * scalar_t((tmp1 >> 0) & 0x7) - zero) * blockvec[k + 0];
|
||||
|
@ -978,7 +978,7 @@ __global__ void VecQuant4MatMulKernel_old(
|
|||
|
||||
int g = (g_h + k) / groupsize;
|
||||
scalar_t scale = scales[g * width + w];
|
||||
scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1);
|
||||
scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF);
|
||||
|
||||
res += (scale * scalar_t((tmp >> 0) & 0xF) - zero) * blockvec[k + 0];
|
||||
res += (scale * scalar_t((tmp >> 4) & 0xF) - zero) * blockvec[k + 1];
|
||||
|
@ -1065,7 +1065,7 @@ __global__ void VecQuant8MatMulKernel_old(
|
|||
|
||||
int g = (g_h + k) / groupsize;
|
||||
scalar_t scale = scales[g * width + w];
|
||||
scalar_t zero = scale * scalar_t(((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF) + 1);
|
||||
scalar_t zero = scale * scalar_t((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xFF);
|
||||
|
||||
res += (scale * scalar_t((tmp >> 0) & 0xFF) - zero) * blockvec[k + 0];
|
||||
res += (scale * scalar_t((tmp >> 8) & 0xFF) - zero) * blockvec[k + 1];
|
||||
|
@ -1160,7 +1160,7 @@ __global__ void VecQuant2MatMulKernelFaster_old(
|
|||
int g = (g_h + (k * 2)) / groupsize;
|
||||
float scale_f = scales[g * width + w];
|
||||
half2 scale = __float2half2_rn(scale_f);
|
||||
half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3) + 1)));
|
||||
half2 zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0x3)));
|
||||
|
||||
std::memset(&res2, 0, sizeof(half2));
|
||||
tmp = as_unsigned(mat[i]);
|
||||
|
@ -1288,12 +1288,12 @@ __global__ void VecQuant3MatMulKernelFaster_old(
|
|||
half2 zero;
|
||||
if (z_mod == 10) {
|
||||
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 30) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 2) & 0x4);
|
||||
zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1)));
|
||||
zero = __float2half2_rn(-(scale_f * z_tmp));
|
||||
} else if (z_mod == 21){
|
||||
z_tmp = (as_unsigned(zeros[g * zero_width + z_w]) >> 31) | ((as_unsigned(zeros[g * zero_width + (z_w + 1)]) << 1) & 0x6);
|
||||
zero = __float2half2_rn(-(scale_f * ((z_tmp) + 1)));
|
||||
zero = __float2half2_rn(-(scale_f * z_tmp));
|
||||
} else {
|
||||
zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7) + 1)));
|
||||
zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_bit) & 0x7)));
|
||||
}
|
||||
|
||||
std::memset(&res2, 0, sizeof(half2));
|
||||
|
@ -1412,7 +1412,7 @@ __global__ void VecQuant4MatMulKernelFaster_old(
|
|||
int g = (g_h + (k * 2)) / groupsize;
|
||||
float scale_f = scales[g * width + w];
|
||||
half2 scale = __float2half2_rn(scale_f);
|
||||
half2 zero = __float2half2_rn(-(scale_f * (((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF) + 1)));
|
||||
half2 zero = __float2half2_rn(-(scale_f * ((as_unsigned(zeros[g * zero_width + z_w]) >> z_mod) & 0xF)));
|
||||
|
||||
std::memset(&res2, 0, sizeof(half2));
|
||||
tmp = as_unsigned(mat[i]);
|
||||
|
@ -1426,4 +1426,4 @@ __global__ void VecQuant4MatMulKernelFaster_old(
|
|||
}
|
||||
|
||||
atomicAdd(&mul[b * width + w], res);
|
||||
}
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
|
||||
|
||||
#include "q4_matmul.cuh"
|
||||
#include "column_remap.cuh"
|
||||
|
@ -13,6 +13,8 @@
|
|||
const int THREADS_X = 32; // Block size and thread count along columns in w and out
|
||||
const int THREADS_Y = 1; // Block size and thread count along rows in x and out
|
||||
|
||||
const int GROUP_STEP = 32; // Assumed group size when block_size_z % groupsize != 0
|
||||
|
||||
typedef void (*fp_q4_matmul_kernel)
|
||||
(
|
||||
const half*,
|
||||
|
@ -46,12 +48,15 @@ __global__ void q4_matmul_kernel
|
|||
bool no_zero
|
||||
)
|
||||
{
|
||||
extern __shared__ half2 x_cache[];
|
||||
half* x_cache_h = (half*)x_cache;
|
||||
|
||||
// Start of block
|
||||
|
||||
int x_column = block_size_z * blockIdx.z;
|
||||
int x_column_end = min(dim, block_size_z * (blockIdx.z + 1));
|
||||
|
||||
int w_column = THREADS_X * blockIdx.x + threadIdx.x;
|
||||
int w_column = THREADS_X * blockIdx.x + threadIdx.x; // assume width of weight matrix divisible by THREADS_X
|
||||
int x_row = THREADS_Y * blockIdx.y + threadIdx.y;
|
||||
|
||||
int iterations = (x_column_end - x_column) / 8;
|
||||
|
@ -69,8 +74,8 @@ __global__ void q4_matmul_kernel
|
|||
if (!no_zero && blockIdx.z == 0 && (threadIdx.x & 1) == 0)
|
||||
{
|
||||
*((uint32_t*) out_.item_ptr(x_row, w_column)) = 0;
|
||||
__syncthreads();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Loop over part of x row (and w column)
|
||||
|
||||
|
@ -84,48 +89,56 @@ __global__ void q4_matmul_kernel
|
|||
|
||||
for (int k = x_column, group = x_column / groupsize; k < x_column + iterations * 8; group++, k += groupsize)
|
||||
{
|
||||
for (int i = threadIdx.x; i < groupsize; i += THREADS_X)
|
||||
{
|
||||
if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]);
|
||||
else x_cache_h[i] = *x_.item_ptr(x_row, k + i);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if constexpr (use_half2)
|
||||
{
|
||||
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
|
||||
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
||||
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column);
|
||||
acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
||||
}
|
||||
else
|
||||
{
|
||||
half w_scale = w_scales_.item(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
|
||||
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8, x_map);
|
||||
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column);
|
||||
acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, groupsize / 8);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Otherwise assume groupsize is a multiple of 8, do 8 columns per iteration and trust the cache
|
||||
// Otherwise assume groupsize is a multiple of GROUP_STEP, do GROUP_STEP columns per iteration and trust the cache
|
||||
|
||||
for (int k = x_column; k < x_column + iterations * 8; k += 8)
|
||||
for (int k = x_column; k < x_column + iterations * 8; k += GROUP_STEP)
|
||||
{
|
||||
for (int i = threadIdx.x; i < GROUP_STEP; i += THREADS_X)
|
||||
{
|
||||
if constexpr (use_x_map) x_cache_h[i] = *x_.item_ptr(x_row, x_map[k + i]);
|
||||
else x_cache_h[i] = *x_.item_ptr(x_row, k + i);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if constexpr (use_half2)
|
||||
{
|
||||
int group = k / groupsize;
|
||||
half2 w_scale = w_scales_.item_half2half2(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
|
||||
if constexpr (use_x_map) acc = dot_product_8_x_map(acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
||||
else acc = dot_product_8 (acc, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column);
|
||||
acc = dot_product_8(acc, x_cache, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
|
||||
}
|
||||
else
|
||||
{
|
||||
int group = k / groupsize;
|
||||
half w_scale = w_scales_.item(group, w_column);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column) + 1;
|
||||
|
||||
if constexpr (use_x_map) acc_h = dot_product_8_x_map_h(acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1, x_map);
|
||||
else acc_h = dot_product_8_h (acc_h, x_, x_row, k, w_, k, w_column, w_scale, w_zero, 1);
|
||||
uint32_t w_zero = w_zeros_.item(group, w_column);
|
||||
acc_h = dot_product_8_h(acc_h, x_cache_h, w_, k, w_column, w_scale, w_zero, GROUP_STEP / 8);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -133,7 +146,7 @@ __global__ void q4_matmul_kernel
|
|||
|
||||
if constexpr (use_half2)
|
||||
{
|
||||
half result = __hadd(__low2half(acc), __high2half(acc));
|
||||
half result = __hadd(acc.x, acc.y);
|
||||
atomicAdd(out_.item_ptr(x_row, w_column), result);
|
||||
}
|
||||
else
|
||||
|
@ -215,8 +228,8 @@ void q4_matmul_cuda
|
|||
);
|
||||
|
||||
fp_q4_matmul_kernel kernel = q4_matmul_kernel_pick(tuningParams, block_size_z, w->groupsize, x_map);
|
||||
|
||||
kernel<<<blocks, threads, 0, alt_stream>>> (x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
|
||||
int shared_mem = (block_size_z % w->groupsize == 0 ? w->groupsize : GROUP_STEP) * sizeof(half);
|
||||
kernel<<<blocks, threads, shared_mem, alt_stream>>>(x_mapped, w->cuda_qweight, out, w->cuda_scales, w->cuda_qzeros, height, dim, width, w->groupsize, block_size_z, x_map, no_zero);
|
||||
}
|
||||
|
||||
void q4_matmul_recons_cuda
|
||||
|
@ -240,7 +253,7 @@ void q4_matmul_recons_cuda
|
|||
const half* x_mapped = x;
|
||||
if (w->cuda_x_map)
|
||||
{
|
||||
TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "The temp_state buffer is too small in the exllama backend. Please call the exllama_set_max_input_length function to increase the buffer size. Example:\nfrom auto_gptq import exllama_set_max_input_length\nmodel = exllama_set_max_input_length(model, 4096)");
|
||||
TORCH_CHECK(buffers->temp_state_size >= x_height * dim, "temp_state buffer is too small");
|
||||
column_remap_cuda(x, buffers->temp_state, x_height, dim, w->cuda_x_map);
|
||||
x_mapped = buffers->temp_state;
|
||||
}
|
||||
|
@ -248,13 +261,18 @@ void q4_matmul_recons_cuda
|
|||
w->reconstruct(buffers->temp_dq);
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 700
|
||||
|
||||
const float alpha = 1.0f;
|
||||
const float beta = no_zero ? 1.0f : 0.0f;
|
||||
cublasSgemmEx(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, CUDA_R_16F, width,
|
||||
x_mapped, CUDA_R_16F, dim, &beta, out, CUDA_R_16F, width);
|
||||
|
||||
#else
|
||||
|
||||
const half alpha = __float2half(1.0f);
|
||||
const half beta = no_zero ? __float2half(1.0f) : __float2half(0.0f);
|
||||
cublasHgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, width, height, dim, &alpha, buffers->temp_dq, width, x_mapped, dim, &beta, out, width);
|
||||
|
||||
#endif
|
||||
|
||||
}
|
||||
|
|
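The reworked exllama `q4_matmul_kernel` stages a slice of `x` in dynamically sized shared memory, so the launch now passes an explicit shared-memory size: one group's worth of halves when `block_size_z` is a multiple of the group size, otherwise `GROUP_STEP` (32) halves. A quick sketch of that size computation (`sizeof(half)` is 2 bytes):

```python
def q4_matmul_shared_mem_bytes(block_size_z, groupsize, group_step=32, sizeof_half=2):
    # Mirrors: (block_size_z % w->groupsize == 0 ? w->groupsize : GROUP_STEP) * sizeof(half)
    return (groupsize if block_size_z % groupsize == 0 else group_step) * sizeof_half

print(q4_matmul_shared_mem_bytes(block_size_z=384, groupsize=128))  # 256 bytes
print(q4_matmul_shared_mem_bytes(block_size_z=384, groupsize=100))  # 64 bytes
```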
|
@ -197,7 +197,7 @@ __global__ void reconstruct_kernel
|
|||
int group = row / groupsize;
|
||||
|
||||
half w_scale = w_scales_.item(group, column);
|
||||
uint32_t w_zero = w_zeros_.item(group, column) + 1;
|
||||
uint32_t w_zero = w_zeros_.item(group, column);
|
||||
|
||||
uint32_t w_read = w_.item_uint32_t(row, column);
|
||||
half* out_ptr = out_.item_ptr(row, column);
|
||||
|
@ -222,4 +222,4 @@ void Q4Matrix::reconstruct(half* out)
|
|||
);
|
||||
|
||||
reconstruct_kernel<<<blocks, threads>>>(cuda_qweight, out, cuda_scales, cuda_qzeros, height / 8, width, groupsize);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -87,18 +87,15 @@ public:
|
|||
__device__ __forceinline__ half2 dot_product_8
|
||||
(
|
||||
const half2 acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
const half2* h_ptr,
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half2 v_scale_2,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const uint32_t v_zero,
|
||||
const int count
|
||||
)
|
||||
{
|
||||
const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half2 result = acc;
|
||||
|
||||
|
@ -138,18 +135,15 @@ __device__ __forceinline__ half2 dot_product_8
|
|||
__device__ __forceinline__ half dot_product_8_h
|
||||
(
|
||||
const half acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
const half* h_ptr,
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half v_scale,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const uint32_t v_zero,
|
||||
const int count
|
||||
)
|
||||
{
|
||||
const half* h_ptr = h_.item_ptr(h_row, h_column);
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half result = acc;
|
||||
|
||||
|
@ -180,115 +174,4 @@ __device__ __forceinline__ half dot_product_8_h
|
|||
return result;
|
||||
}
|
||||
|
||||
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
|
||||
|
||||
__device__ __forceinline__ half2 dot_product_8_x_map
|
||||
(
|
||||
const half2 acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half2 v_scale_2,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const int count,
|
||||
const uint32_t* x_map
|
||||
)
|
||||
{
|
||||
const half* h_ptr = h_.item_ptr(h_row, 0);
|
||||
const uint32_t* x_map_ptr = x_map + h_column;
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half2 result = acc;
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
uint32_t v_read = *v_ptr; v_ptr += v_.width;
|
||||
|
||||
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
|
||||
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
|
||||
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
|
||||
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
|
||||
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
|
||||
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
|
||||
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
|
||||
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
|
||||
|
||||
half2 v_01 = __halves2half2(v_0, v_1);
|
||||
half2 v_23 = __halves2half2(v_2, v_3);
|
||||
half2 v_45 = __halves2half2(v_4, v_5);
|
||||
half2 v_67 = __halves2half2(v_6, v_7);
|
||||
|
||||
half h_0 = h_ptr[*x_map_ptr++];
|
||||
half h_1 = h_ptr[*x_map_ptr++];
|
||||
half h_2 = h_ptr[*x_map_ptr++];
|
||||
half h_3 = h_ptr[*x_map_ptr++];
|
||||
half h_4 = h_ptr[*x_map_ptr++];
|
||||
half h_5 = h_ptr[*x_map_ptr++];
|
||||
half h_6 = h_ptr[*x_map_ptr++];
|
||||
half h_7 = h_ptr[*x_map_ptr++];
|
||||
|
||||
half2 h_01 = __halves2half2(h_0, h_1);
|
||||
half2 h_23 = __halves2half2(h_2, h_3);
|
||||
half2 h_45 = __halves2half2(h_4, h_5);
|
||||
half2 h_67 = __halves2half2(h_6, h_7);
|
||||
|
||||
half2 tmp = __hmul2(h_01, v_01);
|
||||
tmp = __hfma2(h_23, v_23, tmp);
|
||||
tmp = __hfma2(h_45, v_45, tmp);
|
||||
tmp = __hfma2(h_67, v_67, tmp);
|
||||
result = __hfma2(v_scale_2, tmp, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ half dot_product_8_x_map_h
|
||||
(
|
||||
const half acc,
|
||||
MatrixView_half& h_,
|
||||
const int h_row,
|
||||
const int h_column, // divisible by 8
|
||||
MatrixView_q4_column& v_,
|
||||
const int v_row, // divisible by 8
|
||||
const int v_column,
|
||||
const half v_scale,
|
||||
const uint32_t v_zero, // + 1 (!!)
|
||||
const int count,
|
||||
const uint32_t* x_map
|
||||
)
|
||||
{
|
||||
const half* h_ptr = h_.item_ptr(h_row, 0);
|
||||
const uint32_t* x_map_ptr = x_map + h_column;
|
||||
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
|
||||
half result = acc;
|
||||
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
uint32_t v_read = *v_ptr; v_ptr += v_.width;
|
||||
|
||||
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
|
||||
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
|
||||
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
|
||||
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
|
||||
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
|
||||
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
|
||||
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
|
||||
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
|
||||
|
||||
half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
|
||||
tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
|
||||
result = __hfma(v_scale, tmp, result);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
|
@@ -1162,7 +1162,7 @@ def unpack_zeros(bits):
     res += f"void unpack_zeros{bits}_cpu(const int* zv, float* ov, int n, int m)"
     packed = 32//bits
     mask = (2**bits)-1
-    res += "{\nconst __m256i ones = _mm256_set1_epi32(1);\n"
+    res += "{\n"
     res += f"const __m256i mask = _mm256_set1_epi32({mask});\n"
     if bits == 4:
         res += "const __m256i shift = _mm256_set_epi32(28,24,20,16,12,8,4,0);\n"
@@ -1179,15 +1179,14 @@ def unpack_zeros(bits):
         res += "__m256i z = _mm256_set1_epi32(zv[i*m/8 + j/8]);\n"
         res += "__m256i z0 = _mm256_srlv_epi32(z, shift);\n"
         res += "__m256i z1 = _mm256_and_si256(z0, mask);\n"
-        res += "__m256i z2 = _mm256_add_epi32(z1, ones);\n"
-        res += "__m256 z3 = _mm256_cvtepi32_ps(z2);\n"
-        res += "_mm256_storeu_ps(&ov[i*m +j], z3);\n"
+        res += "__m256 z2 = _mm256_cvtepi32_ps(z1);\n"
+        res += "_mm256_storeu_ps(&ov[i*m +j], z2);\n"
     elif bits == 2:
         res += f"for (int j = 0; j < m; j+={packed})"
         res += "{\n"
         res += f"for (int k = 0; k < {packed}; k++)"
         res += "{\n"
-        res += f"ov[i*m + j+k] = (((zv[j/{packed}] >> ({bits}*k)) & {mask})+1);\n"
+        res += f"ov[i*m + j+k] = ((zv[j/{packed}] >> ({bits}*k)) & {mask});\n"
         res += "}\n"
         # res += "for(int j = 0; j < m; j+=16){\n"
         # res += "__m256i z = _mm256_set1_epi32(zv[i*m/16 + j/16]);\n"
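The qigen code generator now emits unpack routines that write the stored zero-point directly: the `_mm256_add_epi32(z1, ones)` step and the scalar `+1` are gone. A numpy sketch of what the generated 4-bit `unpack_zeros4_cpu` computes for one packed word (an illustrative equivalent, not the generated C):

```python
import numpy as np

def unpack_zeros4_reference(packed_word):
    # Equivalent of the generated AVX2 path for bits=4 after this PR: eight
    # 4-bit fields are shifted out, masked, and converted to float, with no
    # "+1" adjustment.
    shifts = np.array([0, 4, 8, 12, 16, 20, 24, 28], dtype=np.uint32)
    fields = (np.uint32(packed_word) >> shifts) & np.uint32(0xF)
    return fields.astype(np.float32)

print(unpack_zeros4_reference(0x88888888))  # [8. 8. 8. 8. 8. 8. 8. 8.]
```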