exllamav2 integration
This commit is contained in:
parent
06e071e68e
commit
c912bf361a
28 changed files with 3355 additions and 18 deletions
|
@ -26,7 +26,7 @@ from ..nn_modules._fused_base import FusedBaseAttentionModule, FusedBaseMLPModul
|
||||||
from ..quantization import GPTQ
|
from ..quantization import GPTQ
|
||||||
from ..utils.data_utils import collate_data
|
from ..utils.data_utils import collate_data
|
||||||
from ..utils.import_utils import (
|
from ..utils.import_utils import (
|
||||||
dynamically_import_QuantLinear, TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE, EXLLAMA_KERNELS_AVAILABLE, QIGEN_AVAILABLE
|
dynamically_import_QuantLinear, TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE, EXLLAMA_KERNELS_AVAILABLE, QIGEN_AVAILABLE, EXLLAMAV2_KERNELS_AVAILABLE
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
@ -700,7 +700,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
warmup_triton: bool = False,
|
warmup_triton: bool = False,
|
||||||
trainable: bool = False,
|
trainable: bool = False,
|
||||||
disable_exllama: bool = False,
|
disable_exllama: bool = True,
|
||||||
|
disable_exllamav2: bool = False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
"""load quantized model from local disk"""
|
"""load quantized model from local disk"""
|
||||||
|
@ -743,6 +744,15 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
||||||
"auto_gptq from source."
|
"auto_gptq from source."
|
||||||
)
|
)
|
||||||
disable_exllama = True
|
disable_exllama = True
|
||||||
|
if not disable_exllamav2 and not EXLLAMAV2_KERNELS_AVAILABLE:
|
||||||
|
logger.warning(
|
||||||
|
"Exllamav2 kernel is not installed, reset disable_exllamav2 to True. "
|
||||||
|
"This may because you installed auto_gptq using a pre-build wheel "
|
||||||
|
"on Windows, in which exllama_kernels are not compiled. To use "
|
||||||
|
"exllama_kernels to further speedup inference, you can re-install "
|
||||||
|
"auto_gptq from source."
|
||||||
|
)
|
||||||
|
disable_exllamav2 = True
|
||||||
if not AUTOGPTQ_CUDA_AVAILABLE:
|
if not AUTOGPTQ_CUDA_AVAILABLE:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"CUDA kernels for auto_gptq are not installed, this will result in "
|
"CUDA kernels for auto_gptq are not installed, this will result in "
|
||||||
|
@ -758,6 +768,13 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
||||||
inject_fused_mlp = False
|
inject_fused_mlp = False
|
||||||
use_triton = False
|
use_triton = False
|
||||||
disable_exllama = True
|
disable_exllama = True
|
||||||
|
disable_exllamav2 = True
|
||||||
|
|
||||||
|
if not disable_exllamav2 and not disable_exllama:
|
||||||
|
logger.warning(
|
||||||
|
"You have activated both exllama and exllamav2 kernel. Setting disable_exllama to True and keeping disable_exllamav2 to False"
|
||||||
|
)
|
||||||
|
disable_exllama = True
|
||||||
|
|
||||||
# == step1: prepare configs and file names == #
|
# == step1: prepare configs and file names == #
|
||||||
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs)
|
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs)
|
||||||
|
@ -804,9 +821,10 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
||||||
|
|
||||||
model_save_name = resolved_archive_file
|
model_save_name = resolved_archive_file
|
||||||
|
|
||||||
if not disable_exllama and trainable:
|
if (not disable_exllama or not disable_exllamav2) and trainable:
|
||||||
logger.warning("QuantLinear with exllama backend not support trainable mode yet, Switch to the pytorch backend.")
|
logger.warning("QuantLinear with exllama backend not support trainable mode yet, Switch to the pytorch backend.")
|
||||||
disable_exllama = True
|
disable_exllama = True
|
||||||
|
disable_exllamav2 = True
|
||||||
|
|
||||||
elif not use_triton and trainable:
|
elif not use_triton and trainable:
|
||||||
logger.warning("QuantLinear with cuda backend not support trainable mode yet, Switch to the pytorch backend.")
|
logger.warning("QuantLinear with cuda backend not support trainable mode yet, Switch to the pytorch backend.")
|
||||||
|
@ -853,6 +871,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
||||||
quantize_config.group_size,
|
quantize_config.group_size,
|
||||||
use_triton=use_triton,
|
use_triton=use_triton,
|
||||||
disable_exllama=disable_exllama,
|
disable_exllama=disable_exllama,
|
||||||
|
disable_exllamav2=disable_exllamav2,
|
||||||
use_cuda_fp16=use_cuda_fp16,
|
use_cuda_fp16=use_cuda_fp16,
|
||||||
desc_act=quantize_config.desc_act,
|
desc_act=quantize_config.desc_act,
|
||||||
trainable=trainable
|
trainable=trainable
|
||||||
|
@ -926,6 +945,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
||||||
quantize_config.group_size,
|
quantize_config.group_size,
|
||||||
use_triton=use_triton,
|
use_triton=use_triton,
|
||||||
disable_exllama=disable_exllama,
|
disable_exllama=disable_exllama,
|
||||||
|
disable_exllamav2=disable_exllamav2,
|
||||||
use_cuda_fp16=use_cuda_fp16,
|
use_cuda_fp16=use_cuda_fp16,
|
||||||
desc_act=quantize_config.desc_act,
|
desc_act=quantize_config.desc_act,
|
||||||
trainable=trainable,
|
trainable=trainable,
|
||||||
|
@ -966,6 +986,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
||||||
trainable=trainable,
|
trainable=trainable,
|
||||||
bits=quantize_config.bits,
|
bits=quantize_config.bits,
|
||||||
disable_exllama=disable_exllama,
|
disable_exllama=disable_exllama,
|
||||||
|
disable_exllamav2=disable_exllamav2
|
||||||
)
|
)
|
||||||
if inject_fused_mlp:
|
if inject_fused_mlp:
|
||||||
if cls.fused_mlp_module_type is None:
|
if cls.fused_mlp_module_type is None:
|
||||||
|
|
|
@ -56,13 +56,14 @@ def make_quant(
|
||||||
group_size,
|
group_size,
|
||||||
name='',
|
name='',
|
||||||
use_triton: bool = False,
|
use_triton: bool = False,
|
||||||
disable_exllama: bool = False,
|
disable_exllama: bool = True,
|
||||||
|
disable_exllamav2: bool = False,
|
||||||
use_qigen: bool = False,
|
use_qigen: bool = False,
|
||||||
use_cuda_fp16: bool = True,
|
use_cuda_fp16: bool = True,
|
||||||
desc_act: bool = False,
|
desc_act: bool = False,
|
||||||
trainable: bool = False
|
trainable: bool = False
|
||||||
):
|
):
|
||||||
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, use_qigen=use_qigen)
|
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, use_qigen=use_qigen)
|
||||||
|
|
||||||
if isinstance(module, QuantLinear):
|
if isinstance(module, QuantLinear):
|
||||||
return
|
return
|
||||||
|
@ -101,6 +102,7 @@ def make_quant(
|
||||||
desc_act=desc_act,
|
desc_act=desc_act,
|
||||||
trainable=trainable,
|
trainable=trainable,
|
||||||
disable_exllama=disable_exllama,
|
disable_exllama=disable_exllama,
|
||||||
|
disable_exllamav2=disable_exllamav2,
|
||||||
use_qigen=use_qigen
|
use_qigen=use_qigen
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -339,7 +341,31 @@ def autogptq_post_init(model, use_act_order: bool, max_input_length: Optional[in
|
||||||
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllama":
|
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllama":
|
||||||
submodule.post_init()
|
submodule.post_init()
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
## exllamav2
|
||||||
|
fixed_bytes = {}
|
||||||
|
model_uses_exllamav2 = False
|
||||||
|
|
||||||
|
for _, submodule in model.named_modules():
|
||||||
|
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllamav2":
|
||||||
|
model_uses_exllamav2 = True
|
||||||
|
device = submodule.qweight.device
|
||||||
|
scratch_fixed = submodule.scratch_space_fixed()
|
||||||
|
fixed_bytes[device] = max(scratch_fixed, fixed_bytes.get(device,0))
|
||||||
|
|
||||||
|
if model_uses_exllamav2:
|
||||||
|
from ..nn_modules.qlinear.qlinear_exllamav2 import ExLlamaV2DeviceTensors
|
||||||
|
device_tensors = {}
|
||||||
|
for device, scratch_bytes in fixed_bytes.items():
|
||||||
|
device_tensors[device] = ExLlamaV2DeviceTensors(device.index, scratch_bytes)
|
||||||
|
|
||||||
|
# have persistent buffers, otherwise we will get OOM
|
||||||
|
model.device_tensors = device_tensors
|
||||||
|
|
||||||
|
for _, submodule in model.named_modules():
|
||||||
|
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllamav2":
|
||||||
|
device = submodule.qweight.device
|
||||||
|
submodule.post_init(temp_dq = model.device_tensors[device])
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
|
@ -82,7 +82,8 @@ class AutoGPTQForCausalLM:
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
warmup_triton: bool = False,
|
warmup_triton: bool = False,
|
||||||
trainable: bool = False,
|
trainable: bool = False,
|
||||||
disable_exllama: bool = False,
|
disable_exllama: bool = True,
|
||||||
|
disable_exllamav2: bool = False,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> BaseGPTQForCausalLM:
|
) -> BaseGPTQForCausalLM:
|
||||||
model_type = check_and_get_model_type(model_name_or_path, trust_remote_code)
|
model_type = check_and_get_model_type(model_name_or_path, trust_remote_code)
|
||||||
|
@ -123,6 +124,7 @@ class AutoGPTQForCausalLM:
|
||||||
warmup_triton=warmup_triton,
|
warmup_triton=warmup_triton,
|
||||||
trainable=trainable,
|
trainable=trainable,
|
||||||
disable_exllama=disable_exllama,
|
disable_exllama=disable_exllama,
|
||||||
|
disable_exllamav2=disable_exllamav2,
|
||||||
**keywords
|
**keywords
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -235,11 +235,12 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
|
||||||
desc_act=False,
|
desc_act=False,
|
||||||
trainable=False,
|
trainable=False,
|
||||||
bits: int = 4,
|
bits: int = 4,
|
||||||
disable_exllama=False,
|
disable_exllama=True,
|
||||||
|
disable_exllamav2=False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
config = model.config
|
config = model.config
|
||||||
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
|
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2)
|
||||||
|
|
||||||
for name, m in model.named_modules():
|
for name, m in model.named_modules():
|
||||||
if not isinstance(m, GPTJAttention):
|
if not isinstance(m, GPTJAttention):
|
||||||
|
|
|
@ -135,13 +135,14 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
|
||||||
desc_act=False,
|
desc_act=False,
|
||||||
trainable=False,
|
trainable=False,
|
||||||
bits: int = 4,
|
bits: int = 4,
|
||||||
disable_exllama=False,
|
disable_exllama=True,
|
||||||
|
disable_exllamav2=False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
|
Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
|
||||||
"""
|
"""
|
||||||
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
|
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2)
|
||||||
|
|
||||||
for name, m in model.named_modules():
|
for name, m in model.named_modules():
|
||||||
if not isinstance(m, LlamaAttention):
|
if not isinstance(m, LlamaAttention):
|
||||||
|
|
188
auto_gptq/nn_modules/qlinear/qlinear_exllamav2.py
Normal file
188
auto_gptq/nn_modules/qlinear/qlinear_exllamav2.py
Normal file
|
@ -0,0 +1,188 @@
|
||||||
|
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
|
||||||
|
|
||||||
|
from logging import getLogger
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import math
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from exllamav2_kernels import make_q_matrix, gemm_half_q_half
|
||||||
|
except ImportError:
|
||||||
|
logger.error('exllamav2_kernels not installed.')
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
|
||||||
|
none_tensor = torch.empty((1, 1), device="meta")
|
||||||
|
|
||||||
|
def _torch_device(idx):
|
||||||
|
if idx == -1: return "cpu"
|
||||||
|
return f"cuda:{idx}"
|
||||||
|
|
||||||
|
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
|
||||||
|
"""Matrix multiplication, returns x @ q4"""
|
||||||
|
output_shape = x.shape[:-1] + (q4_width,)
|
||||||
|
x = x.view(-1, x.shape[-1])
|
||||||
|
output = torch.empty((x.shape[0], q4_width), dtype = torch.half, device = x.device)
|
||||||
|
gemm_half_q_half(x, q_handle, output, force_cuda)
|
||||||
|
return output.view(output_shape)
|
||||||
|
|
||||||
|
def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
|
||||||
|
"""
|
||||||
|
Create Q matrix
|
||||||
|
"""
|
||||||
|
# EXL2
|
||||||
|
# won't work as the moment because the tensors are not the same.
|
||||||
|
if "q_weight" in w:
|
||||||
|
w["q_scale_max"] /= 256
|
||||||
|
w["q_perm"] = w["q_perm"].short()
|
||||||
|
w["q_invperm"] = w["q_invperm"].short()
|
||||||
|
return make_q_matrix(w["q_weight"],
|
||||||
|
w["q_perm"],
|
||||||
|
w["q_invperm"],
|
||||||
|
w["q_scale"],
|
||||||
|
w["q_scale_max"],
|
||||||
|
w["q_groups"],
|
||||||
|
none_tensor,
|
||||||
|
none_tensor,
|
||||||
|
none_tensor,
|
||||||
|
temp_dq)
|
||||||
|
# GPTQ
|
||||||
|
elif "qweight" in w:
|
||||||
|
if w["scales"].dtype == torch.float:
|
||||||
|
w["scales"] = w["scales"].half()
|
||||||
|
|
||||||
|
# GPTQ with g_idx (act_order)
|
||||||
|
if "g_idx" in w and not (w["g_idx"] == 0).all().item():
|
||||||
|
w["q_perm"] = torch.empty((w["qweight"].shape[0] * 8,), dtype = torch.short, device = w["qweight"].device)
|
||||||
|
w["q_invperm"] = torch.empty_like(w["q_perm"])
|
||||||
|
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
|
||||||
|
return make_q_matrix(w["qweight"],
|
||||||
|
w["q_perm"],
|
||||||
|
w["q_invperm"],
|
||||||
|
none_tensor,
|
||||||
|
none_tensor,
|
||||||
|
none_tensor,
|
||||||
|
w["qzeros"],
|
||||||
|
w["scales"],
|
||||||
|
w["g_idx"].cpu(),
|
||||||
|
temp_dq)
|
||||||
|
# GPTQ without g_idx
|
||||||
|
else:
|
||||||
|
return make_q_matrix(w["qweight"],
|
||||||
|
none_tensor,
|
||||||
|
none_tensor,
|
||||||
|
none_tensor,
|
||||||
|
none_tensor,
|
||||||
|
none_tensor,
|
||||||
|
w["qzeros"],
|
||||||
|
w["scales"],
|
||||||
|
none_tensor,
|
||||||
|
temp_dq)
|
||||||
|
|
||||||
|
class QuantLinear(nn.Module):
|
||||||
|
QUANT_TYPE = "exllamav2"
|
||||||
|
|
||||||
|
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
|
||||||
|
|
||||||
|
def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
if bits != 4:
|
||||||
|
raise ValueError(
|
||||||
|
f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.")
|
||||||
|
if trainable:
|
||||||
|
raise NotImplementedError("Exllamav2 kernel does not support training.")
|
||||||
|
|
||||||
|
self.q_handle = None
|
||||||
|
self.q_tensors = None
|
||||||
|
self.padding = - outfeatures % 32
|
||||||
|
|
||||||
|
self.infeatures = infeatures
|
||||||
|
self.outfeatures = outfeatures + self.padding
|
||||||
|
self.bits = bits
|
||||||
|
self.group_size = group_size if group_size != -1 else infeatures
|
||||||
|
self.trainable = trainable
|
||||||
|
self.maxq = 2 ** self.bits - 1
|
||||||
|
|
||||||
|
assert infeatures % 32 == 0
|
||||||
|
assert infeatures % self.group_size == 0
|
||||||
|
assert outfeatures % 32 == 0
|
||||||
|
|
||||||
|
# I need to register the tensors, otherwise, we won't be able to load them easily using transformers ...
|
||||||
|
self.register_buffer(
|
||||||
|
'qweight',
|
||||||
|
torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
'qzeros',
|
||||||
|
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures // 32 * self.bits), dtype=torch.int32)
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
'scales',
|
||||||
|
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
|
||||||
|
)
|
||||||
|
self.register_buffer(
|
||||||
|
'g_idx',
|
||||||
|
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
|
||||||
|
)
|
||||||
|
|
||||||
|
if bias:
|
||||||
|
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
|
||||||
|
else:
|
||||||
|
self.bias = None
|
||||||
|
|
||||||
|
def post_init(self, temp_dq):
|
||||||
|
assert self.qweight.device.type == "cuda"
|
||||||
|
assert self.qweight.device.index is not None
|
||||||
|
self.q_tensors = {
|
||||||
|
"qweight":self.qweight,
|
||||||
|
"qzeros":self.qzeros,
|
||||||
|
"scales":self.scales,
|
||||||
|
"g_idx":self.g_idx
|
||||||
|
}
|
||||||
|
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
|
||||||
|
self.q_handle = ext_make_q_matrix(
|
||||||
|
self.q_tensors, temp_dq
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, force_cuda = False):
|
||||||
|
output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
|
||||||
|
|
||||||
|
if self.bias is not None:
|
||||||
|
output.add_(self.bias)
|
||||||
|
return output
|
||||||
|
|
||||||
|
def temp_dq_size(self):
|
||||||
|
return self.infeatures * self.outfeatures * 2 + 128
|
||||||
|
|
||||||
|
def temp_fwd_size(self, max_input_len, max_batch_size):
|
||||||
|
return self.outfeatures * max_input_len * max_batch_size * 4 + 128
|
||||||
|
|
||||||
|
def scratch_space_fixed(self, max_input_len=2048, max_batch_size=8):
|
||||||
|
return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
|
||||||
|
|
||||||
|
|
||||||
|
class ExLlamaV2DeviceTensors:
|
||||||
|
|
||||||
|
device_idx: int
|
||||||
|
scratch_bytes: int
|
||||||
|
scratch_idx: int
|
||||||
|
scratch: torch.tensor = None
|
||||||
|
|
||||||
|
def __init__(self, device_idx, scratch_bytes):
|
||||||
|
self.device_idx = device_idx
|
||||||
|
self.scratch_bytes = scratch_bytes
|
||||||
|
|
||||||
|
def prepare(self):
|
||||||
|
self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = _torch_device(self.device_idx))
|
||||||
|
|
||||||
|
def get_scratch_slice(self, size_bytes):
|
||||||
|
|
||||||
|
if self.scratch is None: self.prepare()
|
||||||
|
|
||||||
|
size_bytes = ((size_bytes + 127) // 128) * 128
|
||||||
|
size_half = size_bytes // 2
|
||||||
|
scratch_slice = self.scratch.narrow(0, 0, size_half)
|
||||||
|
return scratch_slice
|
|
@ -25,6 +25,13 @@ try:
|
||||||
except:
|
except:
|
||||||
EXLLAMA_KERNELS_AVAILABLE = False
|
EXLLAMA_KERNELS_AVAILABLE = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import exllamav2_kernels
|
||||||
|
|
||||||
|
EXLLAMAV2_KERNELS_AVAILABLE = True
|
||||||
|
except:
|
||||||
|
EXLLAMAV2_KERNELS_AVAILABLE = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import cQIGen as qinfer
|
import cQIGen as qinfer
|
||||||
|
|
||||||
|
@ -35,7 +42,7 @@ except:
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool = False, use_qigen: bool = False):
|
def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool = True, disable_exllamav2:bool = False, use_qigen: bool = False):
|
||||||
if use_qigen:
|
if use_qigen:
|
||||||
from ..nn_modules.qlinear.qlinear_qigen import QuantLinear
|
from ..nn_modules.qlinear.qlinear_qigen import QuantLinear
|
||||||
else:
|
else:
|
||||||
|
@ -45,7 +52,9 @@ def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size:
|
||||||
|
|
||||||
from ..nn_modules.qlinear.qlinear_triton import QuantLinear
|
from ..nn_modules.qlinear.qlinear_triton import QuantLinear
|
||||||
else:
|
else:
|
||||||
if bits == 4 and not disable_exllama and EXLLAMA_KERNELS_AVAILABLE:
|
if bits == 4 and not disable_exllamav2 and EXLLAMAV2_KERNELS_AVAILABLE:
|
||||||
|
from ..nn_modules.qlinear.qlinear_exllamav2 import QuantLinear
|
||||||
|
elif bits == 4 and not disable_exllama and EXLLAMA_KERNELS_AVAILABLE:
|
||||||
from ..nn_modules.qlinear.qlinear_exllama import QuantLinear
|
from ..nn_modules.qlinear.qlinear_exllama import QuantLinear
|
||||||
elif not desc_act or group_size == -1:
|
elif not desc_act or group_size == -1:
|
||||||
from ..nn_modules.qlinear.qlinear_cuda_old import QuantLinear
|
from ..nn_modules.qlinear.qlinear_cuda_old import QuantLinear
|
||||||
|
|
13
autogptq_extension/exllamav2/config.h
Normal file
13
autogptq_extension/exllamav2/config.h
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
#ifndef _config_h
|
||||||
|
#define _config_h
|
||||||
|
|
||||||
|
#define MAX_Q_GEMM_ROWS 50
|
||||||
|
|
||||||
|
#define QMODE_2BIT 1
|
||||||
|
#define QMODE_3BIT 1
|
||||||
|
#define QMODE_4BIT 1
|
||||||
|
#define QMODE_5BIT 1
|
||||||
|
#define QMODE_6BIT 0
|
||||||
|
#define QMODE_8BIT 0
|
||||||
|
|
||||||
|
#endif
|
12
autogptq_extension/exllamav2/cpp/util.h
Normal file
12
autogptq_extension/exllamav2/cpp/util.h
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
#ifndef _util_h
|
||||||
|
#define _util_h
|
||||||
|
|
||||||
|
#define DBGS(__x) printf("%s\n", __x)
|
||||||
|
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
|
||||||
|
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
|
||||||
|
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
|
||||||
|
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
|
||||||
|
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
|
||||||
|
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
|
||||||
|
|
||||||
|
#endif
|
56
autogptq_extension/exllamav2/cuda/compat.cuh
Normal file
56
autogptq_extension/exllamav2/cuda/compat.cuh
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
#ifndef _compat_cuh
|
||||||
|
#define _compat_cuh
|
||||||
|
|
||||||
|
// atomicAdd for half types, to support CC < 7.x
|
||||||
|
|
||||||
|
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
|
||||||
|
{
|
||||||
|
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
|
||||||
|
unsigned int old = *address_as_ui;
|
||||||
|
unsigned int assumed;
|
||||||
|
|
||||||
|
do
|
||||||
|
{
|
||||||
|
assumed = old;
|
||||||
|
__half_raw hsum;
|
||||||
|
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
|
||||||
|
half tmpres = __hadd(hsum, val);
|
||||||
|
hsum = __half_raw(tmpres);
|
||||||
|
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
|
||||||
|
old = atomicCAS(address_as_ui, assumed, old);
|
||||||
|
}
|
||||||
|
while (assumed != old);
|
||||||
|
}
|
||||||
|
|
||||||
|
// atomicAdd for half2 types
|
||||||
|
|
||||||
|
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
|
||||||
|
{
|
||||||
|
unsigned int* address_as_ui = (unsigned int*)address;
|
||||||
|
unsigned int old = *address_as_ui;
|
||||||
|
unsigned int assumed;
|
||||||
|
do
|
||||||
|
{
|
||||||
|
assumed = old;
|
||||||
|
half2 old_val = *((half2*)&old);
|
||||||
|
half2 new_val = __hadd2(old_val, val);
|
||||||
|
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
|
||||||
|
}
|
||||||
|
while (assumed != old);
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
|
||||||
|
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
|
||||||
|
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
|
||||||
|
|
||||||
|
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
|
||||||
|
|
||||||
|
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
|
||||||
|
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif
|
121
autogptq_extension/exllamav2/cuda/matrix_view.cuh
Normal file
121
autogptq_extension/exllamav2/cuda/matrix_view.cuh
Normal file
|
@ -0,0 +1,121 @@
|
||||||
|
#ifndef _matrix_view_cuh
|
||||||
|
#define _matrix_view_cuh
|
||||||
|
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
|
||||||
|
#include "quant/qdq_util.cuh"
|
||||||
|
|
||||||
|
class MatrixView_half
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
const half* data;
|
||||||
|
const int height;
|
||||||
|
const int width;
|
||||||
|
|
||||||
|
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
|
||||||
|
: data(data), height(height), width(width)
|
||||||
|
{ }
|
||||||
|
|
||||||
|
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||||
|
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||||
|
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
|
||||||
|
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
|
||||||
|
|
||||||
|
__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
|
||||||
|
{
|
||||||
|
half2* ptr = (half2*) item_ptr(row, column);
|
||||||
|
half2 i01 = ptr[0];
|
||||||
|
half2 i23 = ptr[1];
|
||||||
|
items[0] = __low2half(i01);
|
||||||
|
items[1] = __high2half(i01);
|
||||||
|
items[2] = __low2half(i23);
|
||||||
|
items[3] = __high2half(i23);
|
||||||
|
}
|
||||||
|
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
|
||||||
|
{
|
||||||
|
half2* ptr = (half2*)item_ptr(row, column);
|
||||||
|
half2 i01 = ptr[0];
|
||||||
|
half2 i23 = ptr[1];
|
||||||
|
items[0] = __half2float(__low2half(i01));
|
||||||
|
items[1] = __half2float(__high2half(i01));
|
||||||
|
items[2] = __half2float(__low2half(i23));
|
||||||
|
items[3] = __half2float(__high2half(i23));
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
|
||||||
|
{
|
||||||
|
half2* ptr = (half2*)item_ptr(row, column);
|
||||||
|
half2 i01 = ptr[0];
|
||||||
|
half2 i23 = ptr[1];
|
||||||
|
items[0] = __half2half2(__low2half(i01));
|
||||||
|
items[1] = __half2half2(__high2half(i01));
|
||||||
|
items[2] = __half2half2(__low2half(i23));
|
||||||
|
items[3] = __half2half2(__high2half(i23));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class MatrixView_half_rw
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
half* data;
|
||||||
|
const int height;
|
||||||
|
const int width;
|
||||||
|
|
||||||
|
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
|
||||||
|
: data(data), height(height), width(width)
|
||||||
|
{ }
|
||||||
|
|
||||||
|
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||||
|
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||||
|
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
|
||||||
|
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
|
||||||
|
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
|
||||||
|
|
||||||
|
__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
|
||||||
|
{
|
||||||
|
half2 v01 = __halves2half2(v0, v1);
|
||||||
|
half2 v23 = __halves2half2(v2, v3);
|
||||||
|
half2* ptr = (half2*) item_ptr(row, column);
|
||||||
|
ptr[0] = v01;
|
||||||
|
ptr[1] = v23;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class MatrixView_q4_row
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
const uint32_t* data;
|
||||||
|
const int height;
|
||||||
|
const int width;
|
||||||
|
|
||||||
|
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
|
||||||
|
: data(data), height(height), width(width)
|
||||||
|
{ }
|
||||||
|
|
||||||
|
__device__ __forceinline__ int item(int row, int column) const
|
||||||
|
{
|
||||||
|
int shift = (column & 0x07) * 4;
|
||||||
|
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
|
||||||
|
{
|
||||||
|
int shift = (column & 0x07) * 4;
|
||||||
|
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||||
|
items[0] = d & 0x0f;
|
||||||
|
items[1] = (d >> 4) & 0x0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
||||||
|
{
|
||||||
|
int shift = (column & 0x07) * 4;
|
||||||
|
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||||
|
items[0] = d & 0x0f;
|
||||||
|
items[1] = (d >> 4) & 0x0f;
|
||||||
|
items[2] = (d >> 8) & 0x0f;
|
||||||
|
items[3] = (d >> 12) & 0x0f;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
238
autogptq_extension/exllamav2/cuda/q_gemm.cu
Normal file
238
autogptq_extension/exllamav2/cuda/q_gemm.cu
Normal file
|
@ -0,0 +1,238 @@
|
||||||
|
#include "q_gemm.cuh"
|
||||||
|
#include "util.cuh"
|
||||||
|
#include "matrix_view.cuh"
|
||||||
|
#include "../config.h"
|
||||||
|
|
||||||
|
#include "quant/qdq_2.cuh"
|
||||||
|
#include "quant/qdq_3.cuh"
|
||||||
|
#include "quant/qdq_4.cuh"
|
||||||
|
#include "quant/qdq_5.cuh"
|
||||||
|
#include "quant/qdq_6.cuh"
|
||||||
|
#include "quant/qdq_8.cuh"
|
||||||
|
|
||||||
|
#define BLOCK_KN_SIZE 128
|
||||||
|
#define BLOCK_M_SIZE_MAX 8
|
||||||
|
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
|
||||||
|
#define CLEAR_N_SIZE 256
|
||||||
|
|
||||||
|
#include "q_gemm_kernel.cuh"
|
||||||
|
#include "q_gemm_kernel_gptq.cuh"
|
||||||
|
|
||||||
|
#if defined(USE_ROCM)
|
||||||
|
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
|
||||||
|
hipblasOperation_t transA,
|
||||||
|
hipblasOperation_t transB,
|
||||||
|
int m,
|
||||||
|
int n,
|
||||||
|
int k,
|
||||||
|
const half* alpha,
|
||||||
|
const half* AP,
|
||||||
|
int lda,
|
||||||
|
const half* BP,
|
||||||
|
int ldb,
|
||||||
|
const half* beta,
|
||||||
|
half* CP,
|
||||||
|
int ldc) {
|
||||||
|
return hipblasHgemm(handle, transA, transB, m, n, k,
|
||||||
|
reinterpret_cast<const hipblasHalf *>(alpha),
|
||||||
|
reinterpret_cast<const hipblasHalf *>(AP), lda,
|
||||||
|
reinterpret_cast<const hipblasHalf *>(BP), ldb,
|
||||||
|
reinterpret_cast<const hipblasHalf *>(beta),
|
||||||
|
reinterpret_cast<hipblasHalf *>(CP), ldc);
|
||||||
|
}
|
||||||
|
#define hipblasHgemm __compat_hipblasHgemm
|
||||||
|
|
||||||
|
// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
|
||||||
|
#define rocblas_operation_none HIPBLAS_OP_N
|
||||||
|
#define rocblas_hgemm __compat_hipblasHgemm
|
||||||
|
#endif
|
||||||
|
|
||||||
|
void gemm_half_q_half_cuda_part
|
||||||
|
(
|
||||||
|
const half* a,
|
||||||
|
QMatrix* b,
|
||||||
|
half* c,
|
||||||
|
int size_m,
|
||||||
|
int size_n,
|
||||||
|
int size_k,
|
||||||
|
int m_count,
|
||||||
|
bool clear
|
||||||
|
)
|
||||||
|
{
|
||||||
|
if (!b->is_gptq)
|
||||||
|
{
|
||||||
|
dim3 blockDim, gridDim;
|
||||||
|
blockDim.x = BLOCK_KN_SIZE;
|
||||||
|
blockDim.y = 1;
|
||||||
|
blockDim.z = 1;
|
||||||
|
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
|
||||||
|
gridDim.y = DIVIDE(size_m, m_count);
|
||||||
|
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
|
||||||
|
|
||||||
|
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);
|
||||||
|
|
||||||
|
kernel<<<gridDim, blockDim>>>
|
||||||
|
(
|
||||||
|
a,
|
||||||
|
b->cuda_q_weight,
|
||||||
|
b->cuda_q_scale,
|
||||||
|
b->cuda_q_scale_max,
|
||||||
|
c,
|
||||||
|
size_m,
|
||||||
|
size_n,
|
||||||
|
size_k,
|
||||||
|
b->groups,
|
||||||
|
b->groupsize,
|
||||||
|
b->cuda_q_perm,
|
||||||
|
b->rows_8,
|
||||||
|
b->rows_6,
|
||||||
|
b->rows_5,
|
||||||
|
b->rows_4,
|
||||||
|
b->rows_3,
|
||||||
|
b->rows_2,
|
||||||
|
clear
|
||||||
|
);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
dim3 blockDim, gridDim;
|
||||||
|
blockDim.x = BLOCK_KN_SIZE;
|
||||||
|
blockDim.y = 1;
|
||||||
|
blockDim.z = 1;
|
||||||
|
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
|
||||||
|
gridDim.y = DIVIDE(size_m, m_count);
|
||||||
|
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
|
||||||
|
|
||||||
|
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
|
||||||
|
|
||||||
|
// DBGX((uint64_t) b->cuda_q_perm);
|
||||||
|
// DBGI(b->rows_4);
|
||||||
|
// DBGI(b->height);
|
||||||
|
|
||||||
|
kernel<<<gridDim, blockDim>>>
|
||||||
|
(
|
||||||
|
a,
|
||||||
|
b->cuda_q_weight,
|
||||||
|
b->cuda_gptq_qzeros,
|
||||||
|
b->cuda_gptq_scales,
|
||||||
|
c,
|
||||||
|
size_m,
|
||||||
|
size_n,
|
||||||
|
size_k,
|
||||||
|
b->groups,
|
||||||
|
b->groupsize,
|
||||||
|
b->cuda_q_perm,
|
||||||
|
b->rows_4,
|
||||||
|
clear
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void gemm_half_q_half_cuda
|
||||||
|
(
|
||||||
|
cublasHandle_t cublas_handle,
|
||||||
|
const half* a,
|
||||||
|
QMatrix* b,
|
||||||
|
half* c,
|
||||||
|
int size_m,
|
||||||
|
int size_n,
|
||||||
|
int size_k,
|
||||||
|
bool clear,
|
||||||
|
half* temp_dq,
|
||||||
|
bool force_cuda
|
||||||
|
)
|
||||||
|
{
|
||||||
|
if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
|
||||||
|
{
|
||||||
|
//printf("cublas\n");
|
||||||
|
|
||||||
|
// Reconstruct FP16 matrix, then cuBLAS
|
||||||
|
|
||||||
|
if (!temp_dq) temp_dq = b->temp_dq;
|
||||||
|
b->reconstruct(temp_dq);
|
||||||
|
|
||||||
|
//cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH);
|
||||||
|
|
||||||
|
const half alpha = __float2half(1.0f);
|
||||||
|
const half beta = clear ? __float2half(0.0f) : __float2half(1.0f);
|
||||||
|
cublasHgemm(cublas_handle,
|
||||||
|
CUBLAS_OP_N,
|
||||||
|
CUBLAS_OP_N,
|
||||||
|
size_n, size_m, size_k,
|
||||||
|
&alpha, temp_dq, size_n,
|
||||||
|
a, size_k,
|
||||||
|
&beta, c, size_n);
|
||||||
|
|
||||||
|
//const float alpha = 1.0f;
|
||||||
|
//const float beta = clear ? 0.0f : 1.0f;
|
||||||
|
//cublasSgemmEx(cublas_handle,
|
||||||
|
// CUBLAS_OP_N,
|
||||||
|
// CUBLAS_OP_N,
|
||||||
|
// size_n, size_m, size_k,
|
||||||
|
// &alpha, temp_dq, CUDA_R_16F, size_n,
|
||||||
|
// a, CUDA_R_16F, size_k,
|
||||||
|
// &beta, c, CUDA_R_16F, size_n);
|
||||||
|
|
||||||
|
//const float alpha = 1.0f;
|
||||||
|
//const float beta = clear ? 0.0f : 1.0f;
|
||||||
|
//cublasGemmEx(cublas_handle,
|
||||||
|
// CUBLAS_OP_N, CUBLAS_OP_N,
|
||||||
|
// size_n, size_m, size_k,
|
||||||
|
// &alpha, temp_dq, CUDA_R_16F, size_n,
|
||||||
|
// a, CUDA_R_16F, size_k,
|
||||||
|
// &beta, c, CUDA_R_16F, size_n,
|
||||||
|
// CUDA_R_16F, CUBLAS_GEMM_DFALT_TENSOR_OP);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
//printf("cuda\n");
|
||||||
|
|
||||||
|
// Quantized matmul
|
||||||
|
|
||||||
|
//if (clear) clear_tensor_cuda(c, size_m, size_n);
|
||||||
|
|
||||||
|
int max_chunks = size_m / BLOCK_M_SIZE_MAX;
|
||||||
|
int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
|
||||||
|
int last_chunk_size = size_m - last_chunk;
|
||||||
|
|
||||||
|
if (max_chunks)
|
||||||
|
{
|
||||||
|
gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (last_chunk_size)
|
||||||
|
{
|
||||||
|
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void clear_kernel
|
||||||
|
(
|
||||||
|
half* __restrict__ c,
|
||||||
|
const int size_m,
|
||||||
|
const int size_n
|
||||||
|
)
|
||||||
|
{
|
||||||
|
int m = blockIdx.y;
|
||||||
|
int n = (blockIdx.x * CLEAR_N_SIZE + threadIdx.x) * 8;
|
||||||
|
if (n >= size_n) return;
|
||||||
|
int4* c_ptr = (int4*)(c + m * size_n + n);
|
||||||
|
*c_ptr = {};
|
||||||
|
}
|
||||||
|
|
||||||
|
void clear_tensor_cuda
|
||||||
|
(
|
||||||
|
half* c,
|
||||||
|
int size_m,
|
||||||
|
int size_n
|
||||||
|
)
|
||||||
|
{
|
||||||
|
return;
|
||||||
|
dim3 blockDim, gridDim;
|
||||||
|
blockDim.x = CLEAR_N_SIZE;
|
||||||
|
blockDim.y = 1;
|
||||||
|
gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
|
||||||
|
gridDim.y = size_m;
|
||||||
|
clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
|
||||||
|
}
|
33
autogptq_extension/exllamav2/cuda/q_gemm.cuh
Normal file
33
autogptq_extension/exllamav2/cuda/q_gemm.cuh
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
#ifndef _q_gemm_cuh
|
||||||
|
#define _q_gemm_cuh
|
||||||
|
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
|
||||||
|
#include "q_matrix.cuh"
|
||||||
|
|
||||||
|
void gemm_half_q_half_cuda
|
||||||
|
(
|
||||||
|
cublasHandle_t cublas_handle,
|
||||||
|
const half* a,
|
||||||
|
QMatrix* b,
|
||||||
|
half* c,
|
||||||
|
int size_m,
|
||||||
|
int size_n,
|
||||||
|
int size_k,
|
||||||
|
bool clear = false,
|
||||||
|
half* reconstruct = NULL,
|
||||||
|
bool force_cuda = false
|
||||||
|
);
|
||||||
|
|
||||||
|
void clear_tensor_cuda
|
||||||
|
(
|
||||||
|
half* c,
|
||||||
|
int size_m,
|
||||||
|
int size_n
|
||||||
|
);
|
||||||
|
|
||||||
|
#endif
|
484
autogptq_extension/exllamav2/cuda/q_gemm_kernel.cuh
Normal file
484
autogptq_extension/exllamav2/cuda/q_gemm_kernel.cuh
Normal file
|
@ -0,0 +1,484 @@
|
||||||
|
#include "compat.cuh"
|
||||||
|
|
||||||
|
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
|
||||||
|
{
|
||||||
|
half2 result = {};
|
||||||
|
const half2* a2_ptr = (const half2*)a_ptr;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||||
|
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h)
|
||||||
|
{
|
||||||
|
half2 result = {};
|
||||||
|
const half2* a2_ptr = (const half2*)a_ptr;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||||
|
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h)
|
||||||
|
{
|
||||||
|
half2 result = {};
|
||||||
|
const half2* a2_ptr = (const half2*)a_ptr;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||||
|
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f)
|
||||||
|
{
|
||||||
|
half2 result = {};
|
||||||
|
const half2* a2_ptr = (const half2*)a_ptr;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||||
|
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||||
|
return fma(result_f, qs_f, g_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f)
|
||||||
|
{
|
||||||
|
half2 result = {};
|
||||||
|
const half2* a2_ptr = (const half2*)a_ptr;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||||
|
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||||
|
return fma(result_f, qs_f, g_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f)
|
||||||
|
{
|
||||||
|
half2 result = {};
|
||||||
|
const half2* a2_ptr = (const half2*)a_ptr;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||||
|
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||||
|
return fma(result_f, qs_f, g_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
typedef void (*fp_gemm_half_q_half_kernel)
|
||||||
|
(
|
||||||
|
const half*,
|
||||||
|
const uint32_t*,
|
||||||
|
const uint32_t*,
|
||||||
|
const half*,
|
||||||
|
half*,
|
||||||
|
const int,
|
||||||
|
const int,
|
||||||
|
const int,
|
||||||
|
const int,
|
||||||
|
const int,
|
||||||
|
const uint16_t*,
|
||||||
|
const int,
|
||||||
|
const int,
|
||||||
|
const int,
|
||||||
|
const int,
|
||||||
|
const int,
|
||||||
|
const int,
|
||||||
|
const bool
|
||||||
|
);
|
||||||
|
|
||||||
|
template <bool first_block, int m_count>
|
||||||
|
__global__ void gemm_half_q_half_kernel
|
||||||
|
(
|
||||||
|
const half* __restrict__ a,
|
||||||
|
const uint32_t* __restrict__ b_q_weight,
|
||||||
|
const uint32_t* __restrict__ b_q_scale,
|
||||||
|
const half* __restrict__ b_q_scale_max,
|
||||||
|
half* __restrict__ c,
|
||||||
|
const int size_m,
|
||||||
|
const int size_n,
|
||||||
|
const int size_k,
|
||||||
|
const int groups,
|
||||||
|
const int groupsize,
|
||||||
|
const uint16_t* __restrict__ b_q_perm,
|
||||||
|
const int rows_8,
|
||||||
|
const int rows_6,
|
||||||
|
const int rows_5,
|
||||||
|
const int rows_4,
|
||||||
|
const int rows_3,
|
||||||
|
const int rows_2,
|
||||||
|
const bool clear
|
||||||
|
)
|
||||||
|
{
|
||||||
|
MatrixView_half a_(a, size_m, size_k);
|
||||||
|
MatrixView_half_rw c_(c, size_m, size_n);
|
||||||
|
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
|
||||||
|
|
||||||
|
int t = threadIdx.x;
|
||||||
|
|
||||||
|
// Block
|
||||||
|
|
||||||
|
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
|
||||||
|
int offset_m = blockIdx.y * m_count;
|
||||||
|
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
|
||||||
|
|
||||||
|
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
|
||||||
|
int end_m = min(offset_m + m_count, size_m);
|
||||||
|
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||||
|
int n = offset_n + t * 4;
|
||||||
|
|
||||||
|
// Preload block_a
|
||||||
|
|
||||||
|
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
|
||||||
|
|
||||||
|
if (offset_k + t < end_k)
|
||||||
|
{
|
||||||
|
for (int m = 0; m < m_count; ++m)
|
||||||
|
{
|
||||||
|
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
|
||||||
|
half* block_a_ptr = block_a[m];
|
||||||
|
half a0 = a_ptr[b_q_perm[offset_k + t]];
|
||||||
|
block_a_ptr[t] = a0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear
|
||||||
|
|
||||||
|
if (n >= size_n) return;
|
||||||
|
|
||||||
|
if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
|
||||||
|
{
|
||||||
|
for (int m = 0; m < m_count; m++)
|
||||||
|
*((uint64_t*) c_.item_ptr(offset_m + m, n)) = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Find initial group
|
||||||
|
|
||||||
|
int group = offset_k / groupsize;
|
||||||
|
|
||||||
|
// Preload scales
|
||||||
|
|
||||||
|
float scales[MAX_GROUPS_IN_BLOCK][4];
|
||||||
|
|
||||||
|
int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
|
||||||
|
for (int g = 0; g < groups_in_block; g++)
|
||||||
|
{
|
||||||
|
int qscales[4];
|
||||||
|
b_q_scale_.item4(qscales, group + g, n);
|
||||||
|
qscales[0]++;
|
||||||
|
qscales[1]++;
|
||||||
|
qscales[2]++;
|
||||||
|
qscales[3]++;
|
||||||
|
float maxscale = __half2float(b_q_scale_max[group + g]);
|
||||||
|
scales[g][0] = __int2float_rn(qscales[0] * qscales[0]) * maxscale;
|
||||||
|
scales[g][1] = __int2float_rn(qscales[1] * qscales[1]) * maxscale;
|
||||||
|
scales[g][2] = __int2float_rn(qscales[2] * qscales[2]) * maxscale;
|
||||||
|
scales[g][3] = __int2float_rn(qscales[3] * qscales[3]) * maxscale;
|
||||||
|
}
|
||||||
|
|
||||||
|
// a, b offset
|
||||||
|
|
||||||
|
int pre_rows_8 = min(rows_8, offset_k);
|
||||||
|
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
|
||||||
|
int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
|
||||||
|
int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
|
||||||
|
int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
|
||||||
|
int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
|
||||||
|
int qk = 0;
|
||||||
|
qk += pre_rows_8 / 32 * 8;
|
||||||
|
qk += pre_rows_6 / 32 * 6;
|
||||||
|
qk += pre_rows_5 / 32 * 5;
|
||||||
|
qk += pre_rows_4 / 32 * 4;
|
||||||
|
qk += pre_rows_3 / 32 * 3;
|
||||||
|
qk += pre_rows_2 / 32 * 2;
|
||||||
|
|
||||||
|
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||||
|
const half* a_ptr = &block_a[0][0];
|
||||||
|
int a_stride = BLOCK_KN_SIZE;
|
||||||
|
|
||||||
|
// Initial group
|
||||||
|
|
||||||
|
int scales_idx = 0;
|
||||||
|
float qs_f0 = scales[scales_idx][0];
|
||||||
|
float qs_f1 = scales[scales_idx][1];
|
||||||
|
float qs_f2 = scales[scales_idx][2];
|
||||||
|
float qs_f3 = scales[scales_idx][3];
|
||||||
|
int nextgroup = offset_k + groupsize;
|
||||||
|
|
||||||
|
// Column result
|
||||||
|
|
||||||
|
float block_c[m_count][4] = {};
|
||||||
|
|
||||||
|
// Dequantize groups
|
||||||
|
|
||||||
|
int k = offset_k;
|
||||||
|
|
||||||
|
while (k < rows_8 && k < end_k)
|
||||||
|
{
|
||||||
|
if (k == nextgroup)
|
||||||
|
{
|
||||||
|
group++;
|
||||||
|
scales_idx++;
|
||||||
|
qs_f0 = scales[scales_idx][0];
|
||||||
|
qs_f1 = scales[scales_idx][1];
|
||||||
|
qs_f2 = scales[scales_idx][2];
|
||||||
|
qs_f3 = scales[scales_idx][3];
|
||||||
|
nextgroup += groupsize;
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < 4; j++)
|
||||||
|
{
|
||||||
|
int4 load_int4[2];
|
||||||
|
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||||
|
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
||||||
|
|
||||||
|
half2 dq[4][4];
|
||||||
|
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n);
|
||||||
|
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n);
|
||||||
|
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n);
|
||||||
|
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n);
|
||||||
|
|
||||||
|
for (int m = 0; m < m_count; m++)
|
||||||
|
{
|
||||||
|
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||||
|
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||||
|
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||||
|
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||||
|
}
|
||||||
|
a_ptr += 8;
|
||||||
|
}
|
||||||
|
k += 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
while (k < rows_6 && k < end_k)
|
||||||
|
{
|
||||||
|
if (k == nextgroup)
|
||||||
|
{
|
||||||
|
group++;
|
||||||
|
scales_idx++;
|
||||||
|
qs_f0 = scales[scales_idx][0];
|
||||||
|
qs_f1 = scales[scales_idx][1];
|
||||||
|
qs_f2 = scales[scales_idx][2];
|
||||||
|
qs_f3 = scales[scales_idx][3];
|
||||||
|
nextgroup += groupsize;
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < 2; j++)
|
||||||
|
{
|
||||||
|
int4 load_int4[3];
|
||||||
|
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||||
|
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
||||||
|
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
|
||||||
|
|
||||||
|
half2 dq[4][8];
|
||||||
|
dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
|
||||||
|
dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
|
||||||
|
dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
|
||||||
|
dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
|
||||||
|
|
||||||
|
for (int m = 0; m < m_count; m++)
|
||||||
|
{
|
||||||
|
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||||
|
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||||
|
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||||
|
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||||
|
}
|
||||||
|
a_ptr += 16;
|
||||||
|
}
|
||||||
|
k += 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
while (k < rows_5 && k < end_k)
|
||||||
|
{
|
||||||
|
if (k == nextgroup)
|
||||||
|
{
|
||||||
|
group++;
|
||||||
|
scales_idx++;
|
||||||
|
qs_f0 = scales[scales_idx][0];
|
||||||
|
qs_f1 = scales[scales_idx][1];
|
||||||
|
qs_f2 = scales[scales_idx][2];
|
||||||
|
qs_f3 = scales[scales_idx][3];
|
||||||
|
nextgroup += groupsize;
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < 1; j++)
|
||||||
|
{
|
||||||
|
int4 load_int4[5];
|
||||||
|
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||||
|
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
||||||
|
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
|
||||||
|
load_int4[3] = *((int4*) b_ptr); b_ptr += size_n;
|
||||||
|
load_int4[4] = *((int4*) b_ptr); b_ptr += size_n;
|
||||||
|
|
||||||
|
half2 dq[4][16];
|
||||||
|
dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, load_int4[3].x, load_int4[4].x, dq[0], size_n);
|
||||||
|
dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, load_int4[3].y, load_int4[4].y, dq[1], size_n);
|
||||||
|
dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, load_int4[3].z, load_int4[4].z, dq[2], size_n);
|
||||||
|
dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, load_int4[3].w, load_int4[4].w, dq[3], size_n);
|
||||||
|
|
||||||
|
for (int m = 0; m < m_count; m++)
|
||||||
|
{
|
||||||
|
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||||
|
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||||
|
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||||
|
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||||
|
}
|
||||||
|
a_ptr += 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
k += 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
while (k < rows_4 && k < end_k)
|
||||||
|
{
|
||||||
|
if (k == nextgroup)
|
||||||
|
{
|
||||||
|
group++;
|
||||||
|
scales_idx++;
|
||||||
|
qs_f0 = scales[scales_idx][0];
|
||||||
|
qs_f1 = scales[scales_idx][1];
|
||||||
|
qs_f2 = scales[scales_idx][2];
|
||||||
|
qs_f3 = scales[scales_idx][3];
|
||||||
|
nextgroup += groupsize;
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < 4; j++)
|
||||||
|
{
|
||||||
|
int4 load_int4[1];
|
||||||
|
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||||
|
|
||||||
|
half2 dq[4][4];
|
||||||
|
dequant_4bit_8(load_int4[0].x, dq[0], size_n);
|
||||||
|
dequant_4bit_8(load_int4[0].y, dq[1], size_n);
|
||||||
|
dequant_4bit_8(load_int4[0].z, dq[2], size_n);
|
||||||
|
dequant_4bit_8(load_int4[0].w, dq[3], size_n);
|
||||||
|
|
||||||
|
for (int m = 0; m < m_count; m++)
|
||||||
|
{
|
||||||
|
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||||
|
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||||
|
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||||
|
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||||
|
}
|
||||||
|
a_ptr += 8;
|
||||||
|
}
|
||||||
|
k += 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
while (k < rows_3 && k < end_k)
|
||||||
|
{
|
||||||
|
if (k == nextgroup)
|
||||||
|
{
|
||||||
|
group++;
|
||||||
|
scales_idx++;
|
||||||
|
qs_f0 = scales[scales_idx][0];
|
||||||
|
qs_f1 = scales[scales_idx][1];
|
||||||
|
qs_f2 = scales[scales_idx][2];
|
||||||
|
qs_f3 = scales[scales_idx][3];
|
||||||
|
nextgroup += groupsize;
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < 1; j++)
|
||||||
|
{
|
||||||
|
int4 load_int4[3];
|
||||||
|
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||||
|
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
||||||
|
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
|
||||||
|
|
||||||
|
half2 dq[4][16];
|
||||||
|
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
|
||||||
|
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
|
||||||
|
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
|
||||||
|
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
|
||||||
|
|
||||||
|
for (int m = 0; m < m_count; m++)
|
||||||
|
{
|
||||||
|
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||||
|
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||||
|
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||||
|
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||||
|
}
|
||||||
|
a_ptr += 32;
|
||||||
|
}
|
||||||
|
k += 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
while (k < rows_2 && k < end_k)
|
||||||
|
{
|
||||||
|
if (k == nextgroup)
|
||||||
|
{
|
||||||
|
group++;
|
||||||
|
scales_idx++;
|
||||||
|
qs_f0 = scales[scales_idx][0];
|
||||||
|
qs_f1 = scales[scales_idx][1];
|
||||||
|
qs_f2 = scales[scales_idx][2];
|
||||||
|
qs_f3 = scales[scales_idx][3];
|
||||||
|
nextgroup += groupsize;
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < 2; j++)
|
||||||
|
{
|
||||||
|
int4 load_int4[1];
|
||||||
|
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||||
|
|
||||||
|
half2 dq[4][8];
|
||||||
|
dequant_2bit_16(load_int4[0].x, dq[0], size_n);
|
||||||
|
dequant_2bit_16(load_int4[0].y, dq[1], size_n);
|
||||||
|
dequant_2bit_16(load_int4[0].z, dq[2], size_n);
|
||||||
|
dequant_2bit_16(load_int4[0].w, dq[3], size_n);
|
||||||
|
|
||||||
|
for (int m = 0; m < m_count; m++)
|
||||||
|
{
|
||||||
|
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||||
|
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||||
|
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||||
|
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||||
|
}
|
||||||
|
|
||||||
|
a_ptr += 16;
|
||||||
|
}
|
||||||
|
k += 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accumulate column sums in c
|
||||||
|
|
||||||
|
for (int m = 0; m < m_count; m++)
|
||||||
|
{
|
||||||
|
half2* out = (half2*)c_.item_ptr(offset_m + m, n);
|
||||||
|
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
|
||||||
|
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
|
||||||
|
atomicAdd(out , result01);
|
||||||
|
atomicAdd(out + 1, result23);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(bool first_block, const int m_count)
|
||||||
|
{
|
||||||
|
#if BLOCK_M_SIZE_MAX >= 1
|
||||||
|
if (m_count == 1) return gemm_half_q_half_kernel<true, 1>;
|
||||||
|
#endif
|
||||||
|
#if BLOCK_M_SIZE_MAX >= 2
|
||||||
|
if (m_count == 2) return gemm_half_q_half_kernel<true, 2>;
|
||||||
|
#endif
|
||||||
|
#if BLOCK_M_SIZE_MAX >= 3
|
||||||
|
if (m_count == 3) return gemm_half_q_half_kernel<true, 3>;
|
||||||
|
#endif
|
||||||
|
#if BLOCK_M_SIZE_MAX >= 4
|
||||||
|
if (m_count == 4) return gemm_half_q_half_kernel<true, 4>;
|
||||||
|
#endif
|
||||||
|
#if BLOCK_M_SIZE_MAX >= 5
|
||||||
|
if (m_count == 5) return gemm_half_q_half_kernel<true, 5>;
|
||||||
|
#endif
|
||||||
|
#if BLOCK_M_SIZE_MAX >= 6
|
||||||
|
if (m_count == 6) return gemm_half_q_half_kernel<true, 6>;
|
||||||
|
#endif
|
||||||
|
#if BLOCK_M_SIZE_MAX >= 7
|
||||||
|
if (m_count == 7) return gemm_half_q_half_kernel<true, 7>;
|
||||||
|
#endif
|
||||||
|
#if BLOCK_M_SIZE_MAX >= 8
|
||||||
|
if (m_count == 8) return gemm_half_q_half_kernel<true, 8>;
|
||||||
|
#endif
|
||||||
|
return NULL;
|
||||||
|
}
|
219
autogptq_extension/exllamav2/cuda/q_gemm_kernel_gptq.cuh
Normal file
219
autogptq_extension/exllamav2/cuda/q_gemm_kernel_gptq.cuh
Normal file
|
@ -0,0 +1,219 @@
|
||||||
|
#include "compat.cuh"
|
||||||
|
|
||||||
|
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result)
|
||||||
|
{
|
||||||
|
half2 result = {};
|
||||||
|
const half2* a2_ptr = (const half2*)a_ptr;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||||
|
return __hadd2(result, g_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
|
||||||
|
{
|
||||||
|
half2 result = {};
|
||||||
|
const half2* a2_ptr = (const half2*)a_ptr;
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||||
|
return __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef void (*fp_gemm_half_q_half_gptq_kernel)
|
||||||
|
(
|
||||||
|
const half*,
|
||||||
|
const uint32_t*,
|
||||||
|
const uint32_t*,
|
||||||
|
const half*,
|
||||||
|
half*,
|
||||||
|
const int,
|
||||||
|
const int,
|
||||||
|
const int,
|
||||||
|
const int,
|
||||||
|
const int,
|
||||||
|
const uint16_t*,
|
||||||
|
const int,
|
||||||
|
const bool
|
||||||
|
);
|
||||||
|
|
||||||
|
template <bool first_block, int m_count>
|
||||||
|
__global__ void gemm_half_q_half_gptq_kernel
|
||||||
|
(
|
||||||
|
const half* __restrict__ a,
|
||||||
|
const uint32_t* __restrict__ b_q_weight,
|
||||||
|
const uint32_t* __restrict__ b_gptq_qzeros,
|
||||||
|
const half* __restrict__ b_gptq_scales,
|
||||||
|
half* __restrict__ c,
|
||||||
|
const int size_m,
|
||||||
|
const int size_n,
|
||||||
|
const int size_k,
|
||||||
|
const int groups,
|
||||||
|
const int groupsize,
|
||||||
|
const uint16_t* __restrict__ b_q_perm,
|
||||||
|
const int rows_4,
|
||||||
|
const bool clear
|
||||||
|
)
|
||||||
|
{
|
||||||
|
MatrixView_half a_(a, size_m, size_k);
|
||||||
|
MatrixView_half_rw c_(c, size_m, size_n);
|
||||||
|
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
|
||||||
|
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
|
||||||
|
|
||||||
|
int t = threadIdx.x;
|
||||||
|
|
||||||
|
// Block
|
||||||
|
|
||||||
|
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
|
||||||
|
int offset_m = blockIdx.y * m_count;
|
||||||
|
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
|
||||||
|
|
||||||
|
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
|
||||||
|
int end_m = min(offset_m + m_count, size_m);
|
||||||
|
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||||
|
|
||||||
|
int n = offset_n + t * 4;
|
||||||
|
|
||||||
|
// Preload block_a
|
||||||
|
|
||||||
|
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
|
||||||
|
|
||||||
|
if (offset_k + t < end_k)
|
||||||
|
{
|
||||||
|
for (int m = 0; m < m_count; ++m)
|
||||||
|
{
|
||||||
|
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
|
||||||
|
half* block_a_ptr = block_a[m];
|
||||||
|
|
||||||
|
half a0;
|
||||||
|
if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
|
||||||
|
else a0 = a_ptr[offset_k + t];
|
||||||
|
block_a_ptr[t] = a0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Zero output
|
||||||
|
|
||||||
|
if (n >= size_n) return;
|
||||||
|
|
||||||
|
if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
|
||||||
|
{
|
||||||
|
for (int m = 0; m < m_count; m++)
|
||||||
|
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// Find initial group
|
||||||
|
|
||||||
|
int group = offset_k / groupsize;
|
||||||
|
int nextgroup = offset_k + groupsize;
|
||||||
|
|
||||||
|
// a, b offset
|
||||||
|
|
||||||
|
int qk = offset_k / (32 / 4);
|
||||||
|
|
||||||
|
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||||
|
const half* a_ptr = &block_a[0][0];
|
||||||
|
int a_stride = BLOCK_KN_SIZE;
|
||||||
|
|
||||||
|
// Initial group
|
||||||
|
|
||||||
|
int zeros[4];
|
||||||
|
float scales[4];
|
||||||
|
half2 z1z16[4][2];
|
||||||
|
half2 y1y16[4][2];
|
||||||
|
b_gptq_qzeros_.item4(zeros, group, n);
|
||||||
|
b_gptq_scales_.item4_f(scales, group, n);
|
||||||
|
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||||
|
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||||
|
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||||
|
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||||
|
|
||||||
|
// __syncthreads();
|
||||||
|
|
||||||
|
// Column result
|
||||||
|
|
||||||
|
float block_c[m_count][4] = {};
|
||||||
|
|
||||||
|
// Dequantize and multiply
|
||||||
|
|
||||||
|
int k = offset_k;
|
||||||
|
while (k < end_k)
|
||||||
|
{
|
||||||
|
if (k == nextgroup)
|
||||||
|
{
|
||||||
|
group++;
|
||||||
|
nextgroup += groupsize;
|
||||||
|
b_gptq_qzeros_.item4(zeros, group, n);
|
||||||
|
b_gptq_scales_.item4_f(scales, group, n);
|
||||||
|
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||||
|
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||||
|
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||||
|
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < 4; j++)
|
||||||
|
{
|
||||||
|
const int4* b_ptr4 = (int4*) b_ptr;
|
||||||
|
int4 load_int4 = *b_ptr4;
|
||||||
|
|
||||||
|
half2 dq[4][4];
|
||||||
|
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
|
||||||
|
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
|
||||||
|
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
|
||||||
|
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int m = 0; m < m_count; m++)
|
||||||
|
{
|
||||||
|
block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
|
||||||
|
block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
|
||||||
|
block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
|
||||||
|
block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
|
||||||
|
}
|
||||||
|
|
||||||
|
b_ptr += size_n;
|
||||||
|
a_ptr += 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
k += 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int m = 0; m < m_count; m++)
|
||||||
|
{
|
||||||
|
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
|
||||||
|
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
|
||||||
|
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
|
||||||
|
atomicAdd(out , result01);
|
||||||
|
atomicAdd(out + 1, result23);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
|
||||||
|
{
|
||||||
|
#if BLOCK_M_SIZE_MAX >= 1
|
||||||
|
if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
|
||||||
|
#endif
|
||||||
|
#if BLOCK_M_SIZE_MAX >= 2
|
||||||
|
if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
|
||||||
|
#endif
|
||||||
|
#if BLOCK_M_SIZE_MAX >= 3
|
||||||
|
if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
|
||||||
|
#endif
|
||||||
|
#if BLOCK_M_SIZE_MAX >= 4
|
||||||
|
if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
|
||||||
|
#endif
|
||||||
|
#if BLOCK_M_SIZE_MAX >= 5
|
||||||
|
if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
|
||||||
|
#endif
|
||||||
|
#if BLOCK_M_SIZE_MAX >= 6
|
||||||
|
if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
|
||||||
|
#endif
|
||||||
|
#if BLOCK_M_SIZE_MAX >= 7
|
||||||
|
if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
|
||||||
|
#endif
|
||||||
|
#if BLOCK_M_SIZE_MAX >= 8
|
||||||
|
if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
|
||||||
|
#endif
|
||||||
|
return NULL;
|
||||||
|
}
|
603
autogptq_extension/exllamav2/cuda/q_matrix.cu
Normal file
603
autogptq_extension/exllamav2/cuda/q_matrix.cu
Normal file
|
@ -0,0 +1,603 @@
|
||||||
|
#include "q_matrix.cuh"
|
||||||
|
#include "matrix_view.cuh"
|
||||||
|
#include "util.cuh"
|
||||||
|
|
||||||
|
#include "quant/qdq_2.cuh"
|
||||||
|
#include "quant/qdq_3.cuh"
|
||||||
|
#include "quant/qdq_4.cuh"
|
||||||
|
#include "quant/qdq_5.cuh"
|
||||||
|
#include "quant/qdq_6.cuh"
|
||||||
|
#include "quant/qdq_8.cuh"
|
||||||
|
|
||||||
|
#define BLOCK_KN_SIZE 128
|
||||||
|
|
||||||
|
#define THREADS_X 32
|
||||||
|
#define THREADS_Y 32
|
||||||
|
|
||||||
|
// Shuffle quantized data on load
|
||||||
|
|
||||||
|
__global__ void shuffle_kernel
|
||||||
|
(
|
||||||
|
uint32_t* __restrict__ b_q_weight,
|
||||||
|
const int size_k,
|
||||||
|
const int size_n,
|
||||||
|
const int rows_8,
|
||||||
|
const int rows_6,
|
||||||
|
const int rows_5,
|
||||||
|
const int rows_4,
|
||||||
|
const int rows_3,
|
||||||
|
const int rows_2
|
||||||
|
)
|
||||||
|
{
|
||||||
|
int n = blockIdx.x * THREADS_X + threadIdx.x;
|
||||||
|
if (n >= size_n) return;
|
||||||
|
int k = 0;
|
||||||
|
uint32_t* b_ptr = b_q_weight + n;
|
||||||
|
while (k < rows_8) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; }
|
||||||
|
while (k < rows_6) { shuffle_6bit_16(b_ptr, size_n); b_ptr += 3 * size_n; k += 16; }
|
||||||
|
while (k < rows_5) { shuffle_5bit_32(b_ptr, size_n); b_ptr += 5 * size_n; k += 32; }
|
||||||
|
while (k < rows_4) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; }
|
||||||
|
while (k < rows_3) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; }
|
||||||
|
while (k < rows_2) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; }
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// QMatrix constructor
|
||||||
|
|
||||||
|
QMatrix::QMatrix
|
||||||
|
(
|
||||||
|
const int _device,
|
||||||
|
const int _height,
|
||||||
|
const int _width,
|
||||||
|
const int _groups,
|
||||||
|
|
||||||
|
uint32_t* _q_weight,
|
||||||
|
uint16_t* _q_perm,
|
||||||
|
uint16_t* _q_invperm,
|
||||||
|
uint32_t* _q_scale,
|
||||||
|
half* _q_scale_max,
|
||||||
|
uint16_t* _q_groups,
|
||||||
|
|
||||||
|
uint32_t* _gptq_qzeros,
|
||||||
|
half* _gptq_scales,
|
||||||
|
uint32_t* _gptq_g_idx,
|
||||||
|
|
||||||
|
half* _temp_dq
|
||||||
|
) :
|
||||||
|
device(_device),
|
||||||
|
height(_height),
|
||||||
|
width(_width),
|
||||||
|
groups(_groups),
|
||||||
|
temp_dq(_temp_dq)
|
||||||
|
{
|
||||||
|
cudaSetDevice(device);
|
||||||
|
|
||||||
|
cuda_q_weight = _q_weight;
|
||||||
|
cuda_q_perm = _q_perm;
|
||||||
|
cuda_q_invperm = _q_invperm;
|
||||||
|
cuda_q_scale = _q_scale;
|
||||||
|
cuda_q_scale_max = _q_scale_max;
|
||||||
|
cuda_q_groups = _q_groups;
|
||||||
|
cuda_gptq_qzeros = _gptq_qzeros;
|
||||||
|
cuda_gptq_scales = _gptq_scales;
|
||||||
|
|
||||||
|
is_gptq = (_gptq_qzeros != NULL);
|
||||||
|
|
||||||
|
groupsize = 1;
|
||||||
|
while (groupsize * groups < height) groupsize *= 2;
|
||||||
|
|
||||||
|
// Create group map
|
||||||
|
|
||||||
|
rows_8 = 0;
|
||||||
|
rows_6 = 0;
|
||||||
|
rows_5 = 0;
|
||||||
|
rows_4 = 0;
|
||||||
|
rows_3 = 0;
|
||||||
|
rows_2 = 0;
|
||||||
|
|
||||||
|
if (!is_gptq)
|
||||||
|
{
|
||||||
|
uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
|
||||||
|
cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
|
||||||
|
|
||||||
|
for (int i = 0; i < groups; i++)
|
||||||
|
{
|
||||||
|
int bits = cpu_q_groups[i * 2];
|
||||||
|
if (bits == 8) rows_8 += groupsize;
|
||||||
|
if (bits == 6) rows_6 += groupsize;
|
||||||
|
if (bits == 5) rows_5 += groupsize;
|
||||||
|
if (bits == 4) rows_4 += groupsize;
|
||||||
|
if (bits == 3) rows_3 += groupsize;
|
||||||
|
if (bits == 2) rows_2 += groupsize;
|
||||||
|
}
|
||||||
|
|
||||||
|
free(cpu_q_groups);
|
||||||
|
|
||||||
|
rows_6 += rows_8;
|
||||||
|
rows_5 += rows_6;
|
||||||
|
rows_4 += rows_5;
|
||||||
|
rows_3 += rows_4;
|
||||||
|
rows_2 += rows_3;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
rows_4 = height;
|
||||||
|
rows_3 = height;
|
||||||
|
rows_2 = height;
|
||||||
|
|
||||||
|
if (_gptq_g_idx) make_sequential(_gptq_g_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shuffle quantized data
|
||||||
|
|
||||||
|
dim3 blockDim, gridDim;
|
||||||
|
blockDim.x = THREADS_X;
|
||||||
|
blockDim.y = 1;
|
||||||
|
gridDim.x = DIVIDE(width, THREADS_X);
|
||||||
|
gridDim.y = 1;
|
||||||
|
|
||||||
|
shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Reconstruct b[k,n] (GPTQ)
|
||||||
|
|
||||||
|
__global__ void reconstruct_gptq_kernel
|
||||||
|
(
|
||||||
|
const uint32_t* __restrict__ b_q_weight,
|
||||||
|
const uint16_t* __restrict__ b_q_perm,
|
||||||
|
const uint32_t* __restrict__ b_gptq_qzeros,
|
||||||
|
const half* __restrict__ b_gptq_scales,
|
||||||
|
//const uint16_t* __restrict__ b_q_groups,
|
||||||
|
const int size_k,
|
||||||
|
const int size_n,
|
||||||
|
const int groupsize,
|
||||||
|
const int groups,
|
||||||
|
half* __restrict__ b,
|
||||||
|
const int rows_4
|
||||||
|
)
|
||||||
|
{
|
||||||
|
MatrixView_half_rw b_(b, size_k, size_n);
|
||||||
|
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
|
||||||
|
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
|
||||||
|
|
||||||
|
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
|
||||||
|
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
|
||||||
|
|
||||||
|
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||||
|
|
||||||
|
// Preload remapping table
|
||||||
|
|
||||||
|
__shared__ uint16_t perm[BLOCK_KN_SIZE];
|
||||||
|
int t = threadIdx.x;
|
||||||
|
|
||||||
|
if (b_q_perm)
|
||||||
|
{
|
||||||
|
if (offset_k + t < size_k)
|
||||||
|
perm[t] = b_q_perm[offset_k + t];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Column
|
||||||
|
|
||||||
|
int n = offset_n + t * 4;
|
||||||
|
if (n >= size_n) return;
|
||||||
|
|
||||||
|
// Find initial group
|
||||||
|
|
||||||
|
int group = offset_k / groupsize;
|
||||||
|
int nextgroup = offset_k + groupsize;
|
||||||
|
|
||||||
|
// b offset
|
||||||
|
|
||||||
|
int qk = offset_k / (32 / 4);
|
||||||
|
|
||||||
|
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||||
|
|
||||||
|
// Initial zeros/scale
|
||||||
|
|
||||||
|
int zeros[4];
|
||||||
|
half2 scales[4];
|
||||||
|
half2 z1z16[4][2];
|
||||||
|
half2 y1y16[4][2];
|
||||||
|
b_gptq_qzeros_.item4(zeros, group, n);
|
||||||
|
b_gptq_scales_.item4_h2(scales, group, n);
|
||||||
|
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||||
|
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||||
|
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||||
|
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
int k = offset_k;
|
||||||
|
int lk = 0;
|
||||||
|
|
||||||
|
while (k < end_k)
|
||||||
|
{
|
||||||
|
if (k == nextgroup)
|
||||||
|
{
|
||||||
|
group++;
|
||||||
|
nextgroup += groupsize;
|
||||||
|
b_gptq_qzeros_.item4(zeros, group, n);
|
||||||
|
b_gptq_scales_.item4_h2(scales, group, n);
|
||||||
|
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||||
|
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||||
|
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||||
|
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int p = 0; p < 4; p++)
|
||||||
|
{
|
||||||
|
half2 dq[4][4];
|
||||||
|
const int4* b_ptr4 = (int4*) b_ptr;
|
||||||
|
int4 load_int4 = *b_ptr4;
|
||||||
|
|
||||||
|
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
|
||||||
|
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
|
||||||
|
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
|
||||||
|
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
|
||||||
|
|
||||||
|
b_ptr += size_n;
|
||||||
|
//half* dqh = (half*)dq;
|
||||||
|
if (b_q_perm)
|
||||||
|
{
|
||||||
|
for (int j = 0; j < 4; j++)
|
||||||
|
{
|
||||||
|
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
|
||||||
|
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
|
||||||
|
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
for (int j = 0; j < 4; j++)
|
||||||
|
{
|
||||||
|
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
|
||||||
|
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
|
||||||
|
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
k += 32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Reconstruct b[k,n]
|
||||||
|
|
||||||
|
__global__ void reconstruct_kernel
|
||||||
|
(
|
||||||
|
const uint32_t* __restrict__ b_q_weight,
|
||||||
|
const uint16_t* __restrict__ b_q_perm,
|
||||||
|
const uint32_t* __restrict__ b_q_scale,
|
||||||
|
const half* __restrict__ b_q_scale_max,
|
||||||
|
//const uint16_t* __restrict__ b_q_groups,
|
||||||
|
const int size_k,
|
||||||
|
const int size_n,
|
||||||
|
const int groupsize,
|
||||||
|
const int groups,
|
||||||
|
half* __restrict__ b,
|
||||||
|
const int rows_8,
|
||||||
|
const int rows_6,
|
||||||
|
const int rows_5,
|
||||||
|
const int rows_4,
|
||||||
|
const int rows_3,
|
||||||
|
const int rows_2
|
||||||
|
)
|
||||||
|
{
|
||||||
|
MatrixView_half_rw b_(b, size_k, size_n);
|
||||||
|
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
|
||||||
|
|
||||||
|
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
|
||||||
|
int offset_n = BLOCK_KN_SIZE * blockIdx.x;
|
||||||
|
|
||||||
|
// Preload remapping table
|
||||||
|
|
||||||
|
int t = threadIdx.x;
|
||||||
|
__shared__ uint16_t perm[BLOCK_KN_SIZE];
|
||||||
|
if (offset_k + t < size_k)
|
||||||
|
perm[t] = b_q_perm[offset_k + t];
|
||||||
|
|
||||||
|
// Column
|
||||||
|
|
||||||
|
int n = offset_n + t;
|
||||||
|
if (n >= size_n) return;
|
||||||
|
|
||||||
|
// Find initial group
|
||||||
|
|
||||||
|
int group = offset_k / groupsize;
|
||||||
|
|
||||||
|
int pre_rows_8 = min(rows_8, offset_k);
|
||||||
|
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
|
||||||
|
int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
|
||||||
|
int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
|
||||||
|
int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
|
||||||
|
int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
|
||||||
|
int qk = 0;
|
||||||
|
qk += pre_rows_8 / 32 * 8;
|
||||||
|
qk += pre_rows_6 / 32 * 6;
|
||||||
|
qk += pre_rows_5 / 32 * 5;
|
||||||
|
qk += pre_rows_4 / 32 * 4;
|
||||||
|
qk += pre_rows_3 / 32 * 3;
|
||||||
|
qk += pre_rows_2 / 32 * 2;
|
||||||
|
|
||||||
|
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||||
|
|
||||||
|
half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
|
||||||
|
half2 qs_h2 = __halves2half2(qs_h, qs_h);
|
||||||
|
int nextgroup = offset_k + groupsize;
|
||||||
|
|
||||||
|
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||||
|
int k = offset_k;
|
||||||
|
int lk = 0;
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
while (k < rows_8 && k < end_k)
|
||||||
|
{
|
||||||
|
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||||
|
for (int p = 0; p < 4; p++)
|
||||||
|
{
|
||||||
|
half2 dq[4];
|
||||||
|
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||||
|
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
||||||
|
dequant_8bit_8(q_0, q_1, dq, size_n);
|
||||||
|
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||||
|
half* dqh = (half*) dq;
|
||||||
|
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||||
|
}
|
||||||
|
k += 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
while (k < rows_6 && k < end_k)
|
||||||
|
{
|
||||||
|
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||||
|
for (int p = 0; p < 2; p++)
|
||||||
|
{
|
||||||
|
half2 dq[8];
|
||||||
|
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||||
|
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
||||||
|
uint32_t q_2 = *b_ptr; b_ptr += size_n;
|
||||||
|
dequant_6bit_16(q_0, q_1, q_2, dq, size_n);
|
||||||
|
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||||
|
half* dqh = (half*) dq;
|
||||||
|
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||||
|
}
|
||||||
|
k += 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
while (k < rows_5 && k < end_k)
|
||||||
|
{
|
||||||
|
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||||
|
for (int p = 0; p < 1; p++)
|
||||||
|
{
|
||||||
|
half2 dq[16];
|
||||||
|
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||||
|
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
||||||
|
uint32_t q_2 = *b_ptr; b_ptr += size_n;
|
||||||
|
uint32_t q_3 = *b_ptr; b_ptr += size_n;
|
||||||
|
uint32_t q_4 = *b_ptr; b_ptr += size_n;
|
||||||
|
dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n);
|
||||||
|
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||||
|
half* dqh = (half*) dq;
|
||||||
|
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||||
|
}
|
||||||
|
k += 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
while (k < rows_4 && k < end_k)
|
||||||
|
{
|
||||||
|
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||||
|
for (int p = 0; p < 4; p++)
|
||||||
|
{
|
||||||
|
half2 dq[4];
|
||||||
|
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||||
|
dequant_4bit_8(q_0, dq, size_n);
|
||||||
|
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||||
|
half* dqh = (half*) dq;
|
||||||
|
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||||
|
}
|
||||||
|
k += 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
while (k < rows_3 && k < end_k)
|
||||||
|
{
|
||||||
|
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||||
|
for (int p = 0; p < 1; p++)
|
||||||
|
{
|
||||||
|
half2 dq[16];
|
||||||
|
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||||
|
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
||||||
|
uint32_t q_2 = *b_ptr; b_ptr += size_n;
|
||||||
|
dequant_3bit_32(q_0, q_1, q_2, dq, size_n);
|
||||||
|
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||||
|
half* dqh = (half*) dq;
|
||||||
|
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||||
|
}
|
||||||
|
k += 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
while (k < rows_2 && k < end_k)
|
||||||
|
{
|
||||||
|
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||||
|
for (int p = 0; p < 2; p++)
|
||||||
|
{
|
||||||
|
half2 dq[8];
|
||||||
|
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||||
|
dequant_2bit_16(q_0, dq, size_n);
|
||||||
|
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||||
|
half* dqh = (half*) dq;
|
||||||
|
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||||
|
}
|
||||||
|
k += 32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void QMatrix::reconstruct(half* out)
|
||||||
|
{
|
||||||
|
dim3 blockDim, gridDim;
|
||||||
|
blockDim.x = BLOCK_KN_SIZE;
|
||||||
|
blockDim.y = 1;
|
||||||
|
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
|
||||||
|
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
|
||||||
|
|
||||||
|
if (!is_gptq)
|
||||||
|
{
|
||||||
|
reconstruct_kernel<<<gridDim, blockDim>>>
|
||||||
|
(
|
||||||
|
cuda_q_weight,
|
||||||
|
cuda_q_perm,
|
||||||
|
cuda_q_scale,
|
||||||
|
cuda_q_scale_max,
|
||||||
|
//cuda_q_groups,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
groupsize,
|
||||||
|
groups,
|
||||||
|
out,
|
||||||
|
rows_8,
|
||||||
|
rows_6,
|
||||||
|
rows_5,
|
||||||
|
rows_4,
|
||||||
|
rows_3,
|
||||||
|
rows_2
|
||||||
|
);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
reconstruct_gptq_kernel<<<gridDim, blockDim>>>
|
||||||
|
(
|
||||||
|
cuda_q_weight,
|
||||||
|
cuda_q_perm,
|
||||||
|
cuda_gptq_qzeros,
|
||||||
|
cuda_gptq_scales,
|
||||||
|
//const uint16_t* __restrict__ b_q_groups,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
groupsize,
|
||||||
|
groups,
|
||||||
|
out,
|
||||||
|
rows_4
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
__global__ void make_sequential_kernel
|
||||||
|
(
|
||||||
|
const uint32_t* __restrict__ w,
|
||||||
|
uint32_t* __restrict__ w_new,
|
||||||
|
const uint16_t* __restrict__ q_perm,
|
||||||
|
const int w_height,
|
||||||
|
const int w_width
|
||||||
|
)
|
||||||
|
{
|
||||||
|
const uint64_t* w2 = (uint64_t*) w;
|
||||||
|
uint64_t* w_new2 = (uint64_t*) w_new;
|
||||||
|
int w2_stride = w_width >> 1;
|
||||||
|
|
||||||
|
int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
|
||||||
|
if (w2_column >= w2_stride) return;
|
||||||
|
|
||||||
|
int w_new2_row = blockIdx.y;
|
||||||
|
|
||||||
|
int q_perm_idx = w_new2_row << 3;
|
||||||
|
|
||||||
|
uint64_t dst = 0;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 8; i++)
|
||||||
|
{
|
||||||
|
int source_row = q_perm[q_perm_idx++];
|
||||||
|
|
||||||
|
int w2_row = source_row >> 3;
|
||||||
|
int w2_subrow = source_row & 0x07;
|
||||||
|
int w2_row_shift = w2_subrow << 2;
|
||||||
|
int wnew2_row_shift = i << 2;
|
||||||
|
|
||||||
|
uint64_t src = w2[w2_row * w2_stride + w2_column];
|
||||||
|
src >>= w2_row_shift;
|
||||||
|
src &= 0x0000000f0000000f;
|
||||||
|
src <<= wnew2_row_shift;
|
||||||
|
dst |= src;
|
||||||
|
}
|
||||||
|
|
||||||
|
w_new2[w_new2_row * w2_stride + w2_column] = dst;
|
||||||
|
}
|
||||||
|
|
||||||
|
void QMatrix::make_sequential(const uint32_t* cpu_g_idx)
|
||||||
|
{
|
||||||
|
uint32_t* cuda_new_qweight = NULL;
|
||||||
|
cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
|
||||||
|
|
||||||
|
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
|
||||||
|
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
|
||||||
|
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
|
||||||
|
|
||||||
|
// Group histogram
|
||||||
|
|
||||||
|
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
|
||||||
|
|
||||||
|
// Group map
|
||||||
|
|
||||||
|
for (int i = 0, acc = 0; i < groups; i++)
|
||||||
|
{
|
||||||
|
short tmp = cpu_g_idx_map[i];
|
||||||
|
cpu_g_idx_map[i] = acc;
|
||||||
|
acc += tmp;
|
||||||
|
}
|
||||||
|
|
||||||
|
// X map (inverse)
|
||||||
|
|
||||||
|
for (int row = 0; row < height; row++)
|
||||||
|
{
|
||||||
|
uint32_t target_group = cpu_g_idx[row];
|
||||||
|
uint32_t target_row = cpu_g_idx_map[target_group];
|
||||||
|
cpu_g_idx_map[target_group]++;
|
||||||
|
cpu_x_map_inv[row] = target_row;
|
||||||
|
}
|
||||||
|
|
||||||
|
// X map
|
||||||
|
|
||||||
|
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
|
||||||
|
|
||||||
|
// Reduce to uint16_t
|
||||||
|
|
||||||
|
uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map;
|
||||||
|
uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv;
|
||||||
|
for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row];
|
||||||
|
for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row];
|
||||||
|
|
||||||
|
// Move to CUDA
|
||||||
|
|
||||||
|
cudaMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
|
||||||
|
cudaMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
|
||||||
|
|
||||||
|
// Rearrange rows in w
|
||||||
|
|
||||||
|
dim3 blockDim, gridDim;
|
||||||
|
blockDim.x = THREADS_X;
|
||||||
|
blockDim.y = 1;
|
||||||
|
gridDim.x = DIVIDE(width, THREADS_X);
|
||||||
|
gridDim.y = height / 8;
|
||||||
|
|
||||||
|
make_sequential_kernel<<<gridDim, blockDim>>>
|
||||||
|
(
|
||||||
|
cuda_q_weight,
|
||||||
|
cuda_new_qweight,
|
||||||
|
cuda_q_perm,
|
||||||
|
height / 8,
|
||||||
|
width
|
||||||
|
);
|
||||||
|
|
||||||
|
// Replace qweights
|
||||||
|
|
||||||
|
cudaMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
|
||||||
|
cudaDeviceSynchronize();
|
||||||
|
|
||||||
|
cudaFree(cuda_new_qweight);
|
||||||
|
free(cpu_g_idx_map);
|
||||||
|
free(cpu_x_map);
|
||||||
|
free(cpu_x_map_inv);
|
||||||
|
}
|
71
autogptq_extension/exllamav2/cuda/q_matrix.cuh
Normal file
71
autogptq_extension/exllamav2/cuda/q_matrix.cuh
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
#ifndef _q_matrix_cuh
|
||||||
|
#define _q_matrix_cuh
|
||||||
|
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstdio>
|
||||||
|
|
||||||
|
#define MAX_SUPERGROUPS 16
|
||||||
|
|
||||||
|
class QMatrix
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
|
||||||
|
int device;
|
||||||
|
bool is_gptq;
|
||||||
|
|
||||||
|
int height;
|
||||||
|
int width;
|
||||||
|
int groups;
|
||||||
|
int groupsize;
|
||||||
|
|
||||||
|
int rows_8;
|
||||||
|
int rows_6;
|
||||||
|
int rows_5;
|
||||||
|
int rows_4;
|
||||||
|
int rows_3;
|
||||||
|
int rows_2;
|
||||||
|
|
||||||
|
uint32_t* cuda_q_weight = NULL;
|
||||||
|
uint16_t* cuda_q_perm = NULL;
|
||||||
|
uint16_t* cuda_q_invperm = NULL;
|
||||||
|
uint32_t* cuda_q_scale = NULL;
|
||||||
|
half* cuda_q_scale_max = NULL;
|
||||||
|
uint16_t* cuda_q_groups = NULL;
|
||||||
|
uint32_t* cuda_gptq_qzeros = NULL;
|
||||||
|
half* cuda_gptq_scales = NULL;
|
||||||
|
|
||||||
|
half* temp_dq;
|
||||||
|
|
||||||
|
QMatrix
|
||||||
|
(
|
||||||
|
const int _device,
|
||||||
|
const int _height,
|
||||||
|
const int _width,
|
||||||
|
const int _groups,
|
||||||
|
|
||||||
|
uint32_t* _q_weight,
|
||||||
|
uint16_t* _q_perm,
|
||||||
|
uint16_t* _q_invperm,
|
||||||
|
uint32_t* _q_scale,
|
||||||
|
half* _q_scale_max,
|
||||||
|
uint16_t* _q_groups,
|
||||||
|
|
||||||
|
uint32_t* _gptq_qzeros,
|
||||||
|
half* _gptq_scales,
|
||||||
|
uint32_t* _gptq_g_idx,
|
||||||
|
|
||||||
|
half* _temp_dq
|
||||||
|
);
|
||||||
|
|
||||||
|
~QMatrix();
|
||||||
|
|
||||||
|
void reconstruct(half* out);
|
||||||
|
void make_sequential(const uint32_t* cpu_g_idx);
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
103
autogptq_extension/exllamav2/cuda/quant/qdq_2.cuh
Normal file
103
autogptq_extension/exllamav2/cuda/quant/qdq_2.cuh
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
#ifndef _qdq_2_cuh
|
||||||
|
#define _qdq_2_cuh
|
||||||
|
|
||||||
|
#include "qdq_util.cuh"
|
||||||
|
#include "../../config.h"
|
||||||
|
|
||||||
|
#if QMODE_2BIT == 1
|
||||||
|
|
||||||
|
// Permutation:
|
||||||
|
//
|
||||||
|
// ffddbb99 77553311 eeccaa88 66442200
|
||||||
|
|
||||||
|
__forceinline__ __device__ void shuffle_2bit_16
|
||||||
|
(
|
||||||
|
uint32_t* q,
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
uint32_t qa = q[0];
|
||||||
|
uint32_t qb = 0;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 8; i++)
|
||||||
|
{
|
||||||
|
uint32_t qa0 = qa & 0x03;
|
||||||
|
uint32_t qa1 = (qa & 0x0c) >> 2;
|
||||||
|
qa >>= 4;
|
||||||
|
qb |= (qa1 << (i * 2 + 16));
|
||||||
|
qb |= (qa0 << (i * 2));
|
||||||
|
}
|
||||||
|
q[0] = qb;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_2bit_16
|
||||||
|
(
|
||||||
|
const uint32_t q_0,
|
||||||
|
half2 (&dq)[8],
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
const uint32_t c0 = 0x64006400;
|
||||||
|
const half y4_ = __float2half_rn(1.0f / 4.0f);
|
||||||
|
const half y16_ = __float2half_rn(1.0f / 16.0f);
|
||||||
|
const half y64_ = __float2half_rn(1.0f / 64.0f);
|
||||||
|
const half2 y4 = __halves2half2(y4_, y4_);
|
||||||
|
const half2 y16 = __halves2half2(y16_, y16_);
|
||||||
|
const half2 y64 = __halves2half2(y64_, y64_);
|
||||||
|
const half z1_ = __float2half_rn(-1024.0f - 2.0f);
|
||||||
|
const half z4_ = __float2half_rn(-1024.0f / 4.0f - 2.0f);
|
||||||
|
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f);
|
||||||
|
const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f);
|
||||||
|
const half2 z1 = __halves2half2(z1_, z1_);
|
||||||
|
const half2 z4 = __halves2half2(z4_, z4_);
|
||||||
|
const half2 z16 = __halves2half2(z16_, z16_);
|
||||||
|
const half2 z64 = __halves2half2(z64_, z64_);
|
||||||
|
|
||||||
|
uint32_t qa = q_0;
|
||||||
|
half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||||
|
half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024
|
||||||
|
half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
|
||||||
|
half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
|
||||||
|
qa >>= 8;
|
||||||
|
half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024
|
||||||
|
half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024
|
||||||
|
half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
|
||||||
|
half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024
|
||||||
|
|
||||||
|
dq[0] = __hadd2(q0.as_half2, z1);
|
||||||
|
dq[1] = __hfma2(q1.as_half2, y4, z4);
|
||||||
|
dq[2] = __hfma2(q2.as_half2, y16, z16);
|
||||||
|
dq[3] = __hfma2(q3.as_half2, y64, z64);
|
||||||
|
dq[4] = __hadd2(q4.as_half2, z1);
|
||||||
|
dq[5] = __hfma2(q5.as_half2, y4, z4);
|
||||||
|
dq[6] = __hfma2(q6.as_half2, y16, z16);
|
||||||
|
dq[7] = __hfma2(q7.as_half2, y64, z64);
|
||||||
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
__forceinline__ __device__ void shuffle_2bit_16
|
||||||
|
(
|
||||||
|
uint32_t* q,
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_2bit_16
|
||||||
|
(
|
||||||
|
const uint32_t q_0,
|
||||||
|
half2 (&dq)[8],
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
half dqh[16];
|
||||||
|
for (int i = 0; i < 16; i++) dqh[i] = dq_ns(exb(q_0, i * 2, 0x03), 2);
|
||||||
|
|
||||||
|
for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif
|
169
autogptq_extension/exllamav2/cuda/quant/qdq_3.cuh
Normal file
169
autogptq_extension/exllamav2/cuda/quant/qdq_3.cuh
Normal file
|
@ -0,0 +1,169 @@
|
||||||
|
#ifndef _qdq_3_cuh
|
||||||
|
#define _qdq_3_cuh
|
||||||
|
|
||||||
|
#include "qdq_util.cuh"
|
||||||
|
#include "../../config.h"
|
||||||
|
|
||||||
|
#if QMODE_3BIT == 1
|
||||||
|
|
||||||
|
// Permutation:
|
||||||
|
//
|
||||||
|
// v9997775 55333111 u8886664 44222000 (u, v lsb)
|
||||||
|
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
|
||||||
|
// vtttrrrp ppnnnlll usssqqqo oommmkkk
|
||||||
|
|
||||||
|
__forceinline__ __device__ void shuffle_3bit_32
|
||||||
|
(
|
||||||
|
uint32_t* q,
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
uint32_t qa = q[0 * stride];
|
||||||
|
uint32_t qb = q[1 * stride];
|
||||||
|
uint32_t qc = q[2 * stride];
|
||||||
|
|
||||||
|
// qa: aa999888 77766655 54443332 22111000
|
||||||
|
// qb: lkkkjjji iihhhggg fffeeedd dcccbbba
|
||||||
|
// qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
|
||||||
|
|
||||||
|
uint32_t qd = qc >> 26;
|
||||||
|
qc <<= 4;
|
||||||
|
qc |= qb >> 28;
|
||||||
|
qb <<= 2;
|
||||||
|
qb |= qa >> 30;
|
||||||
|
|
||||||
|
// qa: ..999888 77766655 54443332 22111000
|
||||||
|
// qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
|
||||||
|
// qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
|
||||||
|
// qd: vvvuuu
|
||||||
|
|
||||||
|
uint32_t za = 0;
|
||||||
|
uint32_t zb = 0;
|
||||||
|
uint32_t zc = 0;
|
||||||
|
|
||||||
|
for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
|
||||||
|
for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
|
||||||
|
for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }
|
||||||
|
|
||||||
|
// za: 9997775 55333111 8886664 44222000
|
||||||
|
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa
|
||||||
|
// zc: tttrrrp ppnnnlll sssqqqo oommmkkk
|
||||||
|
// qd: vvvuuu
|
||||||
|
|
||||||
|
za |= ((qd & 0x01) >> 0) << 15;
|
||||||
|
zb |= ((qd & 0x02) >> 1) << 15;
|
||||||
|
zc |= ((qd & 0x04) >> 2) << 15;
|
||||||
|
za |= ((qd & 0x08) >> 3) << 31;
|
||||||
|
zb |= ((qd & 0x10) >> 4) << 31;
|
||||||
|
zc |= ((qd & 0x20) >> 5) << 31;
|
||||||
|
|
||||||
|
// za: v9997775 55333111 u8886664 44222000 (u, v lsb)
|
||||||
|
// zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
|
||||||
|
// zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
|
||||||
|
|
||||||
|
q[0 * stride] = za;
|
||||||
|
q[1 * stride] = zb;
|
||||||
|
q[2 * stride] = zc;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_3bit_32
|
||||||
|
(
|
||||||
|
const uint32_t q_0,
|
||||||
|
const uint32_t q_1,
|
||||||
|
const uint32_t q_2,
|
||||||
|
half2 (&dq)[16],
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
const uint32_t c0 = 0x64006400;
|
||||||
|
const half y8_ = __float2half_rn(1.0f / 8.0f);
|
||||||
|
const half y64_ = __float2half_rn(1.0f / 64.0f);
|
||||||
|
const half2 y8 = __halves2half2(y8_, y8_);
|
||||||
|
const half2 y64 = __halves2half2(y64_, y64_);
|
||||||
|
const half z1_ = __float2half_rn(-1024.0f - 4.0f);
|
||||||
|
const half z8_ = __float2half_rn(-1024.0f / 8.0f - 4.0f);
|
||||||
|
const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f);
|
||||||
|
const half2 z1 = __halves2half2(z1_, z1_);
|
||||||
|
const half2 z8 = __halves2half2(z8_, z8_);
|
||||||
|
const half2 z64 = __halves2half2(z64_, z64_);
|
||||||
|
|
||||||
|
uint32_t qa = q_0;
|
||||||
|
uint32_t qb = q_1;
|
||||||
|
uint32_t qc = q_2;
|
||||||
|
|
||||||
|
half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||||
|
half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
|
||||||
|
qa >>= 6;
|
||||||
|
half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
|
||||||
|
half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
|
||||||
|
half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
|
||||||
|
qa >>= 9;
|
||||||
|
qa &= 0x00010001;
|
||||||
|
half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
|
||||||
|
half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
|
||||||
|
qb >>= 6;
|
||||||
|
half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
|
||||||
|
half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
|
||||||
|
half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
|
||||||
|
qb >>= 8;
|
||||||
|
qb &= 0x00020002;
|
||||||
|
half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
|
||||||
|
half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
|
||||||
|
qc >>= 6;
|
||||||
|
half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
|
||||||
|
half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
|
||||||
|
half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
|
||||||
|
qc >>= 7;
|
||||||
|
qc &= 0x00040004;
|
||||||
|
half2_uint32 q15((qa | qb | qc) | c0);
|
||||||
|
|
||||||
|
dq[ 0] = __hadd2( q0.as_half2, z1);
|
||||||
|
dq[ 1] = __hfma2( q1.as_half2, y8, z8);
|
||||||
|
dq[ 2] = __hadd2( q2.as_half2, z1);
|
||||||
|
dq[ 3] = __hfma2( q3.as_half2, y8, z8);
|
||||||
|
dq[ 4] = __hfma2( q4.as_half2, y64, z64);
|
||||||
|
dq[ 5] = __hadd2( q5.as_half2, z1);
|
||||||
|
dq[ 6] = __hfma2( q6.as_half2, y8, z8);
|
||||||
|
dq[ 7] = __hadd2( q7.as_half2, z1);
|
||||||
|
dq[ 8] = __hfma2( q8.as_half2, y8, z8);
|
||||||
|
dq[ 9] = __hfma2( q9.as_half2, y64, z64);
|
||||||
|
dq[10] = __hadd2(q10.as_half2, z1);
|
||||||
|
dq[11] = __hfma2(q11.as_half2, y8, z8);
|
||||||
|
dq[12] = __hadd2(q12.as_half2, z1);
|
||||||
|
dq[13] = __hfma2(q13.as_half2, y8, z8);
|
||||||
|
dq[14] = __hfma2(q14.as_half2, y64, z64);
|
||||||
|
dq[15] = __hadd2(q15.as_half2, z1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
__forceinline__ __device__ void shuffle_3bit_32
|
||||||
|
(
|
||||||
|
uint32_t* q,
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_3bit_32
|
||||||
|
(
|
||||||
|
const uint32_t q_0,
|
||||||
|
const uint32_t q_1,
|
||||||
|
const uint32_t q_2,
|
||||||
|
half2 (&dq)[16],
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
half dqh[32];
|
||||||
|
for (int i = 0; i < 10; i++) dqh[ i] = dq_ns(exb( q_0, i * 3 , 0x07), 4);
|
||||||
|
dqh[10 ] = dq_ns(exb(q_1, q_0, 30, 0x07), 4);
|
||||||
|
for (int i = 0; i < 10; i++) dqh[11 + i] = dq_ns(exb( q_1, i * 3 + 1, 0x07), 4);
|
||||||
|
dqh[21 ] = dq_ns(exb(q_2, q_1, 31, 0x07), 4);
|
||||||
|
for (int i = 0; i < 10; i++) dqh[22 + i] = dq_ns(exb( q_2, i * 3 + 2, 0x07), 4);
|
||||||
|
|
||||||
|
for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif
|
227
autogptq_extension/exllamav2/cuda/quant/qdq_4.cuh
Normal file
227
autogptq_extension/exllamav2/cuda/quant/qdq_4.cuh
Normal file
|
@ -0,0 +1,227 @@
|
||||||
|
#ifndef _qdq_4_cuh
|
||||||
|
#define _qdq_4_cuh
|
||||||
|
|
||||||
|
#include "qdq_util.cuh"
|
||||||
|
#include "../../config.h"
|
||||||
|
|
||||||
|
#if QMODE_4BIT == 1
|
||||||
|
|
||||||
|
// Permutation:
|
||||||
|
//
|
||||||
|
// 77775555 33331111 66664444 22220000
|
||||||
|
|
||||||
|
__forceinline__ __device__ void shuffle_4bit_8
|
||||||
|
(
|
||||||
|
uint32_t* q,
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
uint32_t qa = q[0];
|
||||||
|
uint32_t qb = 0;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < 4; i++)
|
||||||
|
{
|
||||||
|
uint32_t qa0 = qa & 0x0f;
|
||||||
|
uint32_t qa1 = (qa & 0xf0) >> 4;
|
||||||
|
qa >>= 8;
|
||||||
|
qb |= (qa1 << (i * 4 + 16));
|
||||||
|
qb |= (qa0 << (i * 4));
|
||||||
|
}
|
||||||
|
q[0] = qb;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_4bit_8
|
||||||
|
(
|
||||||
|
const uint32_t q_0,
|
||||||
|
half2 (&dq)[4],
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
const uint32_t c0 = 0x64006400;
|
||||||
|
const half y16_ = __float2half_rn(1.0f / 16.0f);
|
||||||
|
const half2 y16 = __halves2half2(y16_, y16_);
|
||||||
|
const half z1_ = __float2half_rn(-1024.0f - 8.0f);
|
||||||
|
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
|
||||||
|
const half2 z1 = __halves2half2(z1_, z1_);
|
||||||
|
const half2 z16 = __halves2half2(z16_, z16_);
|
||||||
|
|
||||||
|
uint32_t qa = q_0;
|
||||||
|
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||||
|
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
|
||||||
|
qa >>= 8;
|
||||||
|
half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024
|
||||||
|
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024
|
||||||
|
|
||||||
|
dq[0] = __hadd2(q0.as_half2, z1);
|
||||||
|
dq[1] = __hfma2(q1.as_half2, y16, z16);
|
||||||
|
dq[2] = __hadd2(q2.as_half2, z1);
|
||||||
|
dq[3] = __hfma2(q3.as_half2, y16, z16);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
|
||||||
|
(
|
||||||
|
const uint32_t zero,
|
||||||
|
const half scale,
|
||||||
|
half2 (&z1z16)[2],
|
||||||
|
half2 (&y1y16)[2]
|
||||||
|
)
|
||||||
|
{
|
||||||
|
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
|
||||||
|
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||||
|
|
||||||
|
half2 scale2 = __half2half2(scale);
|
||||||
|
|
||||||
|
z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
|
||||||
|
z1z16[1] = __hmul2(scale2, __half2half2(z16));
|
||||||
|
|
||||||
|
const half y1 = __float2half_rn(1.0f);
|
||||||
|
const half y16 = __float2half_rn(1.0f / 16.0f);
|
||||||
|
|
||||||
|
y1y16[0] = __hmul2(scale2, __half2half2(y1));
|
||||||
|
y1y16[1] = __hmul2(scale2, __half2half2(y16));
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_4bit_8_prep_zero
|
||||||
|
(
|
||||||
|
const uint32_t zero,
|
||||||
|
half2(&z1z16)[2],
|
||||||
|
half2(&y1y16)[2]
|
||||||
|
)
|
||||||
|
{
|
||||||
|
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
|
||||||
|
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||||
|
|
||||||
|
z1z16[0] = __half2half2(z1.as_half);
|
||||||
|
z1z16[1] = __half2half2(z16);
|
||||||
|
|
||||||
|
const half y1 = __float2half_rn(1.0f);
|
||||||
|
const half y16 = __float2half_rn(1.0f / 16.0f);
|
||||||
|
|
||||||
|
y1y16[0] = __half2half2(y1);
|
||||||
|
y1y16[1] = __half2half2(y16);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_4bit_8_gptq
|
||||||
|
(
|
||||||
|
const uint32_t q_0,
|
||||||
|
half2 (&dq)[4],
|
||||||
|
half2 (&z1z16)[2],
|
||||||
|
half2 (&y1y16)[2],
|
||||||
|
int stride,
|
||||||
|
bool scaled
|
||||||
|
)
|
||||||
|
{
|
||||||
|
const uint32_t c0 = 0x64006400;
|
||||||
|
|
||||||
|
uint32_t qa = q_0;
|
||||||
|
half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
|
||||||
|
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
|
||||||
|
qa >>= 8;
|
||||||
|
half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
|
||||||
|
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
|
||||||
|
|
||||||
|
if (scaled)
|
||||||
|
{
|
||||||
|
dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
|
||||||
|
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
|
||||||
|
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
|
||||||
|
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
|
||||||
|
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
|
||||||
|
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
|
||||||
|
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
__forceinline__ __device__ void shuffle_4bit_8
|
||||||
|
(
|
||||||
|
uint32_t* q,
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_4bit_8
|
||||||
|
(
|
||||||
|
const uint32_t q_0,
|
||||||
|
half2 (&dq)[4],
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
half dqh[8];
|
||||||
|
for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);
|
||||||
|
|
||||||
|
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
|
||||||
|
(
|
||||||
|
const uint32_t zero,
|
||||||
|
const half scale,
|
||||||
|
half2 (&z1)[2],
|
||||||
|
half2 (&y1)[2]
|
||||||
|
)
|
||||||
|
{
|
||||||
|
half z = __int2half_rn(-((int)zero));
|
||||||
|
z = __hmul(z, scale);
|
||||||
|
z1[0] = __half2half2(z);
|
||||||
|
y1[0] = __half2half2(scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_4bit_8_prep_zero
|
||||||
|
(
|
||||||
|
const uint32_t zero,
|
||||||
|
half2(&z1)[2],
|
||||||
|
half2(&y1)[2]
|
||||||
|
)
|
||||||
|
{
|
||||||
|
half z = __int2half_rn(-((int)zero));
|
||||||
|
z1[0] = __half2half2(z);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_4bit_8_gptq
|
||||||
|
(
|
||||||
|
const uint32_t q_0,
|
||||||
|
half2 (&dq)[4],
|
||||||
|
half2 (&z1)[2],
|
||||||
|
half2 (&y1)[2],
|
||||||
|
int stride,
|
||||||
|
bool scaled
|
||||||
|
)
|
||||||
|
{
|
||||||
|
half2 dqh2[8];
|
||||||
|
|
||||||
|
uint32_t qa = q_0;
|
||||||
|
for (int i = 0; i < 4; i++)
|
||||||
|
{
|
||||||
|
half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
|
||||||
|
half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
|
||||||
|
dqh2[i] = __halves2half2(d0, d1);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (scaled)
|
||||||
|
{
|
||||||
|
dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
|
||||||
|
dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
|
||||||
|
dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
|
||||||
|
dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
dq[0] = __hadd2(dqh2[0], z1[0]);
|
||||||
|
dq[1] = __hadd2(dqh2[1], z1[0]);
|
||||||
|
dq[2] = __hadd2(dqh2[2], z1[0]);
|
||||||
|
dq[3] = __hadd2(dqh2[3], z1[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif
|
207
autogptq_extension/exllamav2/cuda/quant/qdq_5.cuh
Normal file
207
autogptq_extension/exllamav2/cuda/quant/qdq_5.cuh
Normal file
|
@ -0,0 +1,207 @@
|
||||||
|
#ifndef _qdq_5_cuh
|
||||||
|
#define _qdq_5_cuh
|
||||||
|
|
||||||
|
#include "qdq_util.cuh"
|
||||||
|
#include "../../config.h"
|
||||||
|
|
||||||
|
#if QMODE_5BIT == 1
|
||||||
|
|
||||||
|
// Permutation:
|
||||||
|
//
|
||||||
|
// v5555533 33311111 u4444422 22200000 (u, v lsb)
|
||||||
|
// vbbbbb99 99977777 uaaaaa88 88866666
|
||||||
|
// vhhhhhff fffddddd ugggggee eeeccccc
|
||||||
|
// vnnnnnll llljjjjj ummmmmkk kkkiiiii
|
||||||
|
// vtttttrr rrrppppp usssssqq qqqooooo
|
||||||
|
|
||||||
|
__forceinline__ __device__ void shuffle_5bit_32
|
||||||
|
(
|
||||||
|
uint32_t* q,
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
uint32_t qa = q[0 * stride];
|
||||||
|
uint32_t qb = q[1 * stride];
|
||||||
|
uint32_t qc = q[2 * stride];
|
||||||
|
uint32_t qd = q[3 * stride];
|
||||||
|
uint32_t qe = q[4 * stride];
|
||||||
|
|
||||||
|
// qa: 66555554 44443333 32222211 11100000
|
||||||
|
// qb: ccccbbbb baaaaa99 99988888 77777666
|
||||||
|
// qc: jiiiiihh hhhggggg fffffeee eedddddc
|
||||||
|
// qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj
|
||||||
|
// qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp
|
||||||
|
|
||||||
|
uint32_t qf = qe >> 22;
|
||||||
|
qe <<= 8;
|
||||||
|
qe |= qd >> 24;
|
||||||
|
qd <<= 6;
|
||||||
|
qd |= qc >> 26;
|
||||||
|
qc <<= 4;
|
||||||
|
qc |= qb >> 28;
|
||||||
|
qb <<= 2;
|
||||||
|
qb |= qa >> 30;
|
||||||
|
|
||||||
|
// qa: 555554 44443333 32222211 11100000
|
||||||
|
// qb: bbbbba aaaa9999 98888877 77766666
|
||||||
|
// qc: hhhhhg ggggffff feeeeedd dddccccc
|
||||||
|
// qd: nnnnnm mmmmllll lkkkkkjj jjjiiiii
|
||||||
|
// qe: ttttts ssssrrrr rqqqqqpp pppooooo
|
||||||
|
// qf: vv vvvuuuuu
|
||||||
|
|
||||||
|
uint32_t za = 0;
|
||||||
|
uint32_t zb = 0;
|
||||||
|
uint32_t zc = 0;
|
||||||
|
uint32_t zd = 0;
|
||||||
|
uint32_t ze = 0;
|
||||||
|
|
||||||
|
for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); }
|
||||||
|
for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); }
|
||||||
|
for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); }
|
||||||
|
for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); }
|
||||||
|
for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); }
|
||||||
|
|
||||||
|
// za: 5555533 33311111 4444422 22200000
|
||||||
|
// zb: bbbbb99 99977777 aaaaa88 88866666
|
||||||
|
// zc: hhhhhff fffddddd gggggee eeeccccc
|
||||||
|
// zd: nnnnnll llljjjjj mmmmmkk kkkiiiii
|
||||||
|
// ze: tttttrr rrrppppp sssssqq qqqooooo
|
||||||
|
// qf: vv vvvuuuuu
|
||||||
|
|
||||||
|
za |= ((qf & 0x001) >> 0) << 15;
|
||||||
|
zb |= ((qf & 0x002) >> 1) << 15;
|
||||||
|
zc |= ((qf & 0x004) >> 2) << 15;
|
||||||
|
zd |= ((qf & 0x008) >> 3) << 15;
|
||||||
|
ze |= ((qf & 0x010) >> 4) << 15;
|
||||||
|
za |= ((qf & 0x020) >> 5) << 31;
|
||||||
|
zb |= ((qf & 0x040) >> 6) << 31;
|
||||||
|
zc |= ((qf & 0x080) >> 7) << 31;
|
||||||
|
zd |= ((qf & 0x100) >> 8) << 31;
|
||||||
|
ze |= ((qf & 0x200) >> 9) << 31;
|
||||||
|
|
||||||
|
// za: v5555533 33311111 u4444422 22200000 (u, v lsb)
|
||||||
|
// zb: vbbbbb99 99977777 uaaaaa88 88866666
|
||||||
|
// zc: vhhhhhff fffddddd ugggggee eeeccccc
|
||||||
|
// zd: vnnnnnll llljjjjj ummmmmkk kkkiiiii
|
||||||
|
// ze: vtttttrr rrrppppp usssssqq qqqooooo
|
||||||
|
|
||||||
|
q[0 * stride] = za;
|
||||||
|
q[1 * stride] = zb;
|
||||||
|
q[2 * stride] = zc;
|
||||||
|
q[3 * stride] = zd;
|
||||||
|
q[4 * stride] = ze;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_5bit_32
|
||||||
|
(
|
||||||
|
const uint32_t q_0,
|
||||||
|
const uint32_t q_1,
|
||||||
|
const uint32_t q_2,
|
||||||
|
const uint32_t q_3,
|
||||||
|
const uint32_t q_4,
|
||||||
|
half2 (&dq)[16],
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
const uint32_t c0 = 0x64006400;
|
||||||
|
const half y32_ = __float2half_rn(1.0f / 32.0f);
|
||||||
|
const half2 y32 = __halves2half2(y32_, y32_);
|
||||||
|
const half z1_ = __float2half_rn(-1024.0f - 16.0f);
|
||||||
|
const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f);
|
||||||
|
const half2 z1 = __halves2half2(z1_, z1_);
|
||||||
|
const half2 z32 = __halves2half2(z32_, z32_);
|
||||||
|
|
||||||
|
uint32_t qa = q_0;
|
||||||
|
uint32_t qb = q_1;
|
||||||
|
uint32_t qc = q_2;
|
||||||
|
uint32_t qd = q_3;
|
||||||
|
uint32_t qe = q_4;
|
||||||
|
|
||||||
|
half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||||
|
half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024
|
||||||
|
qa >>= 10;
|
||||||
|
half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5]) + 1024
|
||||||
|
qa >>= 5;
|
||||||
|
qa &= 0x00010001;
|
||||||
|
half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7]) + 1024
|
||||||
|
half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024
|
||||||
|
qb >>= 10;
|
||||||
|
half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11]) + 1024
|
||||||
|
qb >>= 4;
|
||||||
|
qb &= 0x00020002;
|
||||||
|
half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13]) + 1024
|
||||||
|
half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024
|
||||||
|
qc >>= 10;
|
||||||
|
half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17]) + 1024
|
||||||
|
qc >>= 3;
|
||||||
|
qc &= 0x00040004;
|
||||||
|
half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19]) + 1024
|
||||||
|
half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024
|
||||||
|
qd >>= 10;
|
||||||
|
half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23]) + 1024
|
||||||
|
qd >>= 2;
|
||||||
|
qd &= 0x00080008;
|
||||||
|
half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25]) + 1024
|
||||||
|
half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024
|
||||||
|
qe >>= 10;
|
||||||
|
half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29]) + 1024
|
||||||
|
qe >>= 1;
|
||||||
|
qe &= 0x00100010;
|
||||||
|
half2_uint32 q15((qa | qb | qc | qd | qe) | c0);
|
||||||
|
|
||||||
|
dq[ 0] = __hadd2( q0.as_half2, z1);
|
||||||
|
dq[ 1] = __hfma2( q1.as_half2, y32, z32);
|
||||||
|
dq[ 2] = __hadd2( q2.as_half2, z1);
|
||||||
|
dq[ 3] = __hadd2( q3.as_half2, z1);
|
||||||
|
dq[ 4] = __hfma2( q4.as_half2, y32, z32);
|
||||||
|
dq[ 5] = __hadd2( q5.as_half2, z1);
|
||||||
|
dq[ 6] = __hadd2( q6.as_half2, z1);
|
||||||
|
dq[ 7] = __hfma2( q7.as_half2, y32, z32);
|
||||||
|
dq[ 8] = __hadd2( q8.as_half2, z1);
|
||||||
|
dq[ 9] = __hadd2( q9.as_half2, z1);
|
||||||
|
dq[10] = __hfma2(q10.as_half2, y32, z32);
|
||||||
|
dq[11] = __hadd2(q11.as_half2, z1);
|
||||||
|
dq[12] = __hadd2(q12.as_half2, z1);
|
||||||
|
dq[13] = __hfma2(q13.as_half2, y32, z32);
|
||||||
|
dq[14] = __hadd2(q14.as_half2, z1);
|
||||||
|
dq[15] = __hadd2(q15.as_half2, z1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
__forceinline__ __device__ void shuffle_5bit_32
|
||||||
|
(
|
||||||
|
uint32_t* q,
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_5bit_32
|
||||||
|
(
|
||||||
|
const uint32_t q_0,
|
||||||
|
const uint32_t q_1,
|
||||||
|
const uint32_t q_2,
|
||||||
|
const uint32_t q_3,
|
||||||
|
const uint32_t q_4,
|
||||||
|
half2 (&dq)[16],
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
half dqh[32];
|
||||||
|
for (int i = 0; i < 6; i++) dqh[ i] = dq_ns(exb( q_0, i * 5 , 0x1f), 16);
|
||||||
|
dqh[ 6 ] = dq_ns(exb(q_1, q_0, 30, 0x1f), 16);
|
||||||
|
for (int i = 0; i < 5; i++) dqh[ 7 + i] = dq_ns(exb( q_1, i * 5 + 3, 0x1f), 16);
|
||||||
|
dqh[12 ] = dq_ns(exb(q_2, q_1, 28, 0x1f), 16);
|
||||||
|
for (int i = 0; i < 6; i++) dqh[13 + i] = dq_ns(exb( q_2, i * 5 + 1, 0x1f), 16);
|
||||||
|
dqh[19 ] = dq_ns(exb(q_3, q_2, 31, 0x1f), 16);
|
||||||
|
for (int i = 0; i < 5; i++) dqh[20 + i] = dq_ns(exb( q_3, i * 5 + 4, 0x1f), 16);
|
||||||
|
dqh[25 ] = dq_ns(exb(q_4, q_3, 29, 0x1f), 16);
|
||||||
|
for (int i = 0; i < 6; i++) dqh[26 + i] = dq_ns(exb( q_4, i * 5 + 2, 0x1f), 16);
|
||||||
|
|
||||||
|
for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif
|
44
autogptq_extension/exllamav2/cuda/quant/qdq_6.cuh
Normal file
44
autogptq_extension/exllamav2/cuda/quant/qdq_6.cuh
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
#ifndef _qdq_6_cuh
|
||||||
|
#define _qdq_6_cuh
|
||||||
|
|
||||||
|
#include "qdq_util.cuh"
|
||||||
|
#include "../../config.h"
|
||||||
|
|
||||||
|
#if QMODE_6BIT == 1
|
||||||
|
|
||||||
|
// Not implemented
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
__forceinline__ __device__ void shuffle_6bit_16
|
||||||
|
(
|
||||||
|
uint32_t* q,
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_6bit_16
|
||||||
|
(
|
||||||
|
const uint32_t q_0,
|
||||||
|
const uint32_t q_1,
|
||||||
|
const uint32_t q_2,
|
||||||
|
half2 (&dq)[8],
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
half dqh[16];
|
||||||
|
for (int i = 0; i < 5; i++) dqh[ i] = dq_ns(exb( q_0, i * 6 , 0x3f), 32);
|
||||||
|
dqh[ 5 ] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32);
|
||||||
|
for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb( q_1, i * 6 + 4, 0x3f), 32);
|
||||||
|
dqh[10 ] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32);
|
||||||
|
for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb( q_2, i * 6 + 2, 0x3f), 32);
|
||||||
|
|
||||||
|
for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
38
autogptq_extension/exllamav2/cuda/quant/qdq_8.cuh
Normal file
38
autogptq_extension/exllamav2/cuda/quant/qdq_8.cuh
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
#ifndef _qdq_8_cuh
|
||||||
|
#define _qdq_8_cuh
|
||||||
|
|
||||||
|
#include "qdq_util.cuh"
|
||||||
|
#include "../../config.h"
|
||||||
|
|
||||||
|
#if QMODE_8BIT == 1
|
||||||
|
|
||||||
|
// Not implemented
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
__forceinline__ __device__ void shuffle_8bit_4
|
||||||
|
(
|
||||||
|
uint32_t* q,
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ void dequant_8bit_8
|
||||||
|
(
|
||||||
|
const uint32_t q_0,
|
||||||
|
const uint32_t q_1,
|
||||||
|
half2 (&dq)[4],
|
||||||
|
int stride
|
||||||
|
)
|
||||||
|
{
|
||||||
|
half dqh[8];
|
||||||
|
for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), 128);
|
||||||
|
for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128);
|
||||||
|
|
||||||
|
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif
|
51
autogptq_extension/exllamav2/cuda/quant/qdq_util.cuh
Normal file
51
autogptq_extension/exllamav2/cuda/quant/qdq_util.cuh
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
#ifndef _qdq_util_cuh
|
||||||
|
#define _qdq_util_cuh
|
||||||
|
|
||||||
|
union half2_uint32
|
||||||
|
{
|
||||||
|
uint32_t as_uint32;
|
||||||
|
half2 as_half2;
|
||||||
|
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
|
||||||
|
__device__ half2_uint32(half2 val) : as_half2(val) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
union half_uint16
|
||||||
|
{
|
||||||
|
uint16_t as_uint16;
|
||||||
|
half as_half;
|
||||||
|
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
|
||||||
|
__device__ half_uint16(half val) : as_half(val) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Max_scale premultiplied by 1/256
|
||||||
|
|
||||||
|
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
|
||||||
|
{
|
||||||
|
int qs_i = qs + 1;
|
||||||
|
half qs_h = __int2half_rn(qs_i * qs_i);
|
||||||
|
qs_h = __hmul(qs_h, max_scale);
|
||||||
|
return qs_h;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
|
||||||
|
{
|
||||||
|
return __hmul(__int2half_rn(q - qzero), scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ half dq_ns(const int q, const int qzero)
|
||||||
|
{
|
||||||
|
//return __hsub(__int2half_rn(q), __int2half_rn(qzero));
|
||||||
|
return __int2half_rn(q - qzero);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
|
||||||
|
{
|
||||||
|
return (int)((q >> shift) & mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
|
||||||
|
{
|
||||||
|
return (int)(__funnelshift_rc(q0, q1, shift) & mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
32
autogptq_extension/exllamav2/cuda/util.cuh
Normal file
32
autogptq_extension/exllamav2/cuda/util.cuh
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
|
||||||
|
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
|
||||||
|
|
||||||
|
#define DBGS(__x) printf("%s\n", __x)
|
||||||
|
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
|
||||||
|
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
|
||||||
|
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
|
||||||
|
#define DBGX(__x) printf("%s: %x\n", #__x, __x)
|
||||||
|
#define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y)
|
||||||
|
#define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z)
|
||||||
|
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
|
||||||
|
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
|
||||||
|
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
|
||||||
|
#define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x))
|
||||||
|
#define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y))
|
||||||
|
#define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z))
|
||||||
|
|
||||||
|
#define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y))
|
||||||
|
#define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z))
|
||||||
|
|
||||||
|
__forceinline__ __device__ half dq_scale_(const int qs, const half max_scale)
|
||||||
|
{
|
||||||
|
half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f));
|
||||||
|
qs_h = __hmul(qs_h, qs_h);
|
||||||
|
qs_h = __hmul(qs_h, max_scale);
|
||||||
|
return qs_h;
|
||||||
|
}
|
||||||
|
|
||||||
|
__forceinline__ __device__ float clamp(float x, float a, float b)
|
||||||
|
{
|
||||||
|
return fmaxf(a, fminf(b, x));
|
||||||
|
}
|
134
autogptq_extension/exllamav2/ext.cpp
Normal file
134
autogptq_extension/exllamav2/ext.cpp
Normal file
|
@ -0,0 +1,134 @@
|
||||||
|
#include <torch/extension.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <cuda_runtime.h>
|
||||||
|
#include <cuda_fp16.h>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <cstdio>
|
||||||
|
|
||||||
|
#include "config.h"
|
||||||
|
|
||||||
|
#include "cuda/q_matrix.cuh"
|
||||||
|
#include "cuda/q_gemm.cuh"
|
||||||
|
|
||||||
|
#include "cpp/util.h"
|
||||||
|
|
||||||
|
// Some decluttering macros
|
||||||
|
|
||||||
|
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||||
|
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||||
|
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||||
|
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||||
|
|
||||||
|
|
||||||
|
// Quant matrix
|
||||||
|
|
||||||
|
uintptr_t make_q_matrix
|
||||||
|
(
|
||||||
|
torch::Tensor q_weight,
|
||||||
|
torch::Tensor q_perm,
|
||||||
|
torch::Tensor q_invperm,
|
||||||
|
torch::Tensor q_scale,
|
||||||
|
torch::Tensor q_scale_max,
|
||||||
|
torch::Tensor q_groups,
|
||||||
|
torch::Tensor gptq_qzeros,
|
||||||
|
torch::Tensor gptq_scales,
|
||||||
|
torch::Tensor gptq_g_idx,
|
||||||
|
torch::Tensor temp_dq
|
||||||
|
)
|
||||||
|
{
|
||||||
|
TORCH_CHECK_DTYPE(q_weight, kInt);
|
||||||
|
TORCH_CHECK_DTYPE_OPT(q_perm, kShort);
|
||||||
|
TORCH_CHECK_DTYPE_OPT(q_invperm, kShort);
|
||||||
|
TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
|
||||||
|
TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
|
||||||
|
TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
|
||||||
|
TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
|
||||||
|
TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
|
||||||
|
TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);
|
||||||
|
|
||||||
|
TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1);
|
||||||
|
|
||||||
|
int device = q_weight.device().index();
|
||||||
|
int width = q_weight.size(1);
|
||||||
|
int groups;
|
||||||
|
int height;
|
||||||
|
|
||||||
|
if (!q_scale.device().is_meta())
|
||||||
|
{
|
||||||
|
TORCH_CHECK_SHAPES(q_weight, 1, q_scale, 1, 8);
|
||||||
|
TORCH_CHECK_SHAPES(q_scale_max, 0, q_scale, 0, 1);
|
||||||
|
groups = q_scale.size(0);
|
||||||
|
height = q_invperm.size(0);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8);
|
||||||
|
TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1);
|
||||||
|
groups = gptq_qzeros.size(0);
|
||||||
|
height = q_weight.size(0) * 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
TORCH_CHECK(temp_dq.size(0) >= width * height, "Insufficient size of temp_dq buffer")
|
||||||
|
|
||||||
|
QMatrix* m = new QMatrix
|
||||||
|
(
|
||||||
|
device,
|
||||||
|
height,
|
||||||
|
width,
|
||||||
|
groups,
|
||||||
|
(uint32_t*) q_weight.data_ptr(),
|
||||||
|
q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(),
|
||||||
|
q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(),
|
||||||
|
q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
|
||||||
|
q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
|
||||||
|
q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
|
||||||
|
gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
|
||||||
|
gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
|
||||||
|
gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
|
||||||
|
(half*) temp_dq.data_ptr()
|
||||||
|
);
|
||||||
|
|
||||||
|
return reinterpret_cast<uintptr_t> (m);
|
||||||
|
}
|
||||||
|
|
||||||
|
void gemm_half_q_half
|
||||||
|
(
|
||||||
|
torch::Tensor a,
|
||||||
|
uintptr_t b,
|
||||||
|
torch::Tensor c,
|
||||||
|
bool force_cuda
|
||||||
|
)
|
||||||
|
{
|
||||||
|
QMatrix* qm = reinterpret_cast<QMatrix*> (b);
|
||||||
|
|
||||||
|
TORCH_CHECK_DTYPE(a, kHalf);
|
||||||
|
TORCH_CHECK_DTYPE(c, kHalf);
|
||||||
|
TORCH_CHECK_SHAPES(a, 0, c, 0, 1);
|
||||||
|
TORCH_CHECK(qm->height == a.size(1), "a and b have incompatible shapes")
|
||||||
|
TORCH_CHECK(qm->width == c.size(1), "b and c have incompatible shapes")
|
||||||
|
|
||||||
|
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
||||||
|
|
||||||
|
gemm_half_q_half_cuda
|
||||||
|
(
|
||||||
|
at::cuda::getCurrentCUDABlasHandle(),
|
||||||
|
(const half*) a.data_ptr(),
|
||||||
|
qm,
|
||||||
|
(half*) c.data_ptr(),
|
||||||
|
c.size(0), // m
|
||||||
|
c.size(1), // n
|
||||||
|
a.size(1), // k
|
||||||
|
true,
|
||||||
|
NULL,
|
||||||
|
force_cuda
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bindings
|
||||||
|
|
||||||
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||||
|
{
|
||||||
|
m.def("make_q_matrix", &make_q_matrix, "make_q_matrix");
|
||||||
|
m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half");
|
||||||
|
}
|
11
setup.py
11
setup.py
|
@ -158,6 +158,17 @@ if BUILD_CUDA_EXT:
|
||||||
extra_link_args=extra_link_args
|
extra_link_args=extra_link_args
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
extensions.append(
|
||||||
|
cpp_extension.CUDAExtension(
|
||||||
|
"exllamav2_kernels",
|
||||||
|
[
|
||||||
|
"autogptq_extension/exllamav2/ext.cpp",
|
||||||
|
"autogptq_extension/exllamav2/cuda/q_matrix.cu",
|
||||||
|
"autogptq_extension/exllamav2/cuda/q_gemm.cu",
|
||||||
|
],
|
||||||
|
extra_link_args=extra_link_args
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
additional_setup_kwargs = {
|
additional_setup_kwargs = {
|
||||||
"ext_modules": extensions,
|
"ext_modules": extensions,
|
||||||
|
|
233
tests/test_q4.py
233
tests/test_q4.py
|
@ -143,7 +143,7 @@ class TestsQ4Exllama(unittest.TestCase):
|
||||||
n = 1024
|
n = 1024
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
|
|
||||||
linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=4)
|
linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=4, disable_exllama=False, disable_exllamav2=True)
|
||||||
|
|
||||||
linear = linear_class(
|
linear = linear_class(
|
||||||
bits=4,
|
bits=4,
|
||||||
|
@ -197,7 +197,7 @@ class TestsQ4Exllama(unittest.TestCase):
|
||||||
revision = "actorder"
|
revision = "actorder"
|
||||||
model_basename = "vicuna-13B-1.1-GPTQ-4bit-128g.latest"
|
model_basename = "vicuna-13B-1.1-GPTQ-4bit-128g.latest"
|
||||||
|
|
||||||
model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=False, inject_fused_mlp=True, model_basename=model_basename, disable_exllama=False)
|
model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=False, inject_fused_mlp=True, model_basename=model_basename, disable_exllama=False, disable_exllamav2=True)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
|
||||||
inp = tokenizer(prompt, return_tensors="pt").to(device)
|
inp = tokenizer(prompt, return_tensors="pt").to(device)
|
||||||
|
@ -227,7 +227,7 @@ class TestsQ4Exllama(unittest.TestCase):
|
||||||
|
|
||||||
model_id = "TheBloke/WizardLM-7B-uncensored-GPTQ"
|
model_id = "TheBloke/WizardLM-7B-uncensored-GPTQ"
|
||||||
model_basename = "model"
|
model_basename = "model"
|
||||||
model_q = AutoGPTQForCausalLM.from_quantized(model_id, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=True, inject_fused_mlp=True, model_basename=model_basename)
|
model_q = AutoGPTQForCausalLM.from_quantized(model_id, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=True, inject_fused_mlp=True, model_basename=model_basename, disable_exllama=False, disable_exllamav2=True)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
|
||||||
inp = tokenizer(prompt, return_tensors="pt").to(device)
|
inp = tokenizer(prompt, return_tensors="pt").to(device)
|
||||||
|
@ -249,7 +249,7 @@ class TestsQ4Exllama(unittest.TestCase):
|
||||||
revision = "actorder"
|
revision = "actorder"
|
||||||
model_basename = "vicuna-13B-1.1-GPTQ-4bit-128g.latest"
|
model_basename = "vicuna-13B-1.1-GPTQ-4bit-128g.latest"
|
||||||
|
|
||||||
model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=False, inject_fused_mlp=True, model_basename=model_basename, disable_exllama=False)
|
model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=False, inject_fused_mlp=True, model_basename=model_basename, disable_exllama=False, disable_exllamav2=True)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
|
||||||
inp = tokenizer(prompt, return_tensors="pt").to(device)
|
inp = tokenizer(prompt, return_tensors="pt").to(device)
|
||||||
|
@ -338,7 +338,7 @@ class TestsQ4CUDA(unittest.TestCase):
|
||||||
n = 256
|
n = 256
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
|
|
||||||
linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=4, disable_exllama=True)
|
linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=4, disable_exllamav2=True)
|
||||||
|
|
||||||
linear = linear_class(
|
linear = linear_class(
|
||||||
bits=4,
|
bits=4,
|
||||||
|
@ -369,3 +369,226 @@ class TestsQ4CUDA(unittest.TestCase):
|
||||||
reference = self.REFERENCE_OLD_NO_HALF.to(device)
|
reference = self.REFERENCE_OLD_NO_HALF.to(device)
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(res, reference), get_diff(res, reference))
|
self.assertTrue(torch.allclose(res, reference), get_diff(res, reference))
|
||||||
|
|
||||||
|
|
||||||
|
class TestsQ4ExllamaV2(unittest.TestCase):
|
||||||
|
|
||||||
|
# reference generated with cuda_old
|
||||||
|
REFERENCE = torch.Tensor([5.8398, 6.8555, 7.2734, 6.4219, 6.2070, 5.8203, 6.5664, 6.4219, 6.2148,
|
||||||
|
5.3281, 5.7578, 7.5312, 8.1016, 6.1133, 7.2031, 6.6484, 6.5156, 6.0117,
|
||||||
|
6.0312, 6.1914, 6.2109, 6.8125, 5.8125, 7.1172, 7.3125, 6.7305, 5.9961,
|
||||||
|
6.5117, 6.1914, 5.9648, 7.1680, 6.4766, 7.2070, 6.5469, 6.7734, 6.4219,
|
||||||
|
6.8086, 7.0469, 5.9297, 6.4727, 6.2539, 5.9570, 7.2383, 5.8945, 6.0820,
|
||||||
|
5.7969, 7.1094, 6.2188, 6.7500, 7.3555, 6.2930, 6.7734, 5.9219, 7.4805,
|
||||||
|
6.8750, 6.4102, 6.5898, 6.5469, 7.6016, 6.7461, 5.9492, 7.2227, 5.8164,
|
||||||
|
5.4570, 6.2930, 7.3984, 6.0938, 7.3984, 5.9609, 6.3516, 6.5664, 5.7969,
|
||||||
|
7.1250, 6.0781, 6.7930, 5.9492, 6.1641, 6.5898, 6.0586, 6.3359, 6.7930,
|
||||||
|
7.0469, 6.0664, 6.3320, 5.4414, 6.7617, 5.1641, 7.2891, 6.8516, 6.5312,
|
||||||
|
5.6914, 7.3711, 6.8203, 5.9492, 7.0781, 6.3164, 7.1992, 7.1133, 7.4219,
|
||||||
|
7.5586, 7.1836, 6.9102, 6.4844, 6.9805, 6.1953, 6.5156, 5.4844, 6.6602,
|
||||||
|
6.6719, 7.9844, 6.4727, 6.6367, 6.2227, 6.4531, 5.0625, 6.4609, 6.7031,
|
||||||
|
6.6445, 6.5234, 6.8633, 6.6055, 5.6055, 6.4453, 7.2617, 6.3945, 6.6367,
|
||||||
|
6.1055, 7.0664, 6.0820, 6.6875, 6.1445, 6.8672, 6.2070, 6.8828, 6.1484,
|
||||||
|
6.7070, 6.8516, 6.2734, 7.1055, 7.0586, 6.9648, 5.9727, 6.1016, 6.8750,
|
||||||
|
7.0078, 7.1523, 5.7383, 5.9531, 6.5508, 7.5352, 6.1602, 6.2578, 6.3906,
|
||||||
|
5.7383, 6.7031, 5.7344, 6.3516, 5.2852, 7.5312, 6.4531, 6.6406, 6.2266,
|
||||||
|
6.1094, 5.9102, 5.7617, 6.3789, 7.0508, 6.3750, 6.3320, 6.8555, 6.7266,
|
||||||
|
7.0352, 7.7695, 6.3984, 6.5039, 6.8320, 6.1602, 6.0312, 6.3828, 6.9023,
|
||||||
|
7.4336, 7.3711, 6.1016, 7.0703, 6.3281, 6.8281, 6.4922, 5.9453, 5.1016,
|
||||||
|
6.7188, 6.1406, 6.6289, 7.2695, 6.2070, 6.7070, 7.2930, 7.1836, 6.3828,
|
||||||
|
6.1992, 6.7070, 7.8008, 7.7773, 5.6602, 7.0273, 6.6172, 6.0898, 5.3516,
|
||||||
|
7.3359, 5.9727, 6.0078, 7.0586, 6.3086, 6.8555, 7.2617, 7.3477, 6.3828,
|
||||||
|
7.1133, 6.6328, 7.3516, 6.9141, 7.2031, 6.9805, 6.1719, 6.7812, 8.3047,
|
||||||
|
6.5898, 6.3633, 6.2539, 7.2773, 6.5938, 6.4141, 6.8203, 6.8906, 7.8828,
|
||||||
|
5.9609, 6.4180, 7.3984, 5.7539, 7.1758, 6.6641, 6.9062, 6.2578, 7.5508,
|
||||||
|
6.1719, 6.5742, 5.9375, 6.7891, 6.2109, 6.5039, 6.8750, 6.2031, 6.8828,
|
||||||
|
7.1094, 5.9570, 7.2969, 6.6797, 6.8828, 5.5430, 6.9648, 5.8398, 6.5430,
|
||||||
|
6.3945, 6.5664, 5.8086, 6.6172, 7.0586, 6.8867, 6.0820, 5.8125, 6.7070,
|
||||||
|
7.5742, 6.2578, 6.1328, 6.5391, 5.4531, 6.8242, 6.6953, 6.8008, 6.3398,
|
||||||
|
6.4805, 7.2266, 6.3281, 6.6875, 6.4688, 5.9414, 7.4297, 5.8711, 6.0625,
|
||||||
|
5.8750, 6.5664, 5.8867, 6.3477, 6.1133, 6.9453, 5.0547, 6.7812, 6.4922,
|
||||||
|
7.2422, 5.4688, 6.2109, 7.2148, 6.1758, 5.9297, 7.1953, 5.5195, 6.3203,
|
||||||
|
5.9961, 7.9297, 6.2695, 6.4414, 6.7266, 7.1875, 7.3203, 5.4062, 6.0625,
|
||||||
|
7.0898, 5.3828, 5.6133, 6.0742, 6.6836, 5.7109, 7.2852, 7.7539, 7.5820,
|
||||||
|
6.4258, 5.9336, 6.3750, 6.3555, 7.5469, 6.2539, 6.5898, 6.4102, 7.0469,
|
||||||
|
5.7344, 7.2031, 6.7969, 5.6836, 7.6523, 6.9297, 7.8672, 6.4766, 6.3008,
|
||||||
|
7.0977, 6.5430, 7.0938, 5.8398, 6.9883, 6.5312, 6.3203, 6.3594, 5.4062,
|
||||||
|
6.9688, 5.7930, 6.3164, 6.5547, 7.1992, 5.8750, 6.3008, 6.7930, 6.0391,
|
||||||
|
7.4766, 6.6094, 6.5625, 5.9805, 6.2422, 7.2109, 6.6875, 5.3047, 7.6211,
|
||||||
|
5.9453, 6.5625, 6.1641, 6.1250, 6.5977, 7.7422, 7.0742, 5.6875, 6.2656,
|
||||||
|
6.6250, 6.8945, 5.7070, 6.3203, 5.7500, 6.2695, 6.2773, 6.8516, 6.4883,
|
||||||
|
7.0000, 6.7578, 6.1875, 5.9844, 5.5703, 6.7188, 5.5273, 5.3438, 7.2500,
|
||||||
|
6.7852, 6.5195, 6.8125, 6.0664, 6.7852, 7.0000, 7.0781, 6.8477, 7.2930,
|
||||||
|
6.3438, 7.1523, 6.3281, 6.8047, 7.3203, 5.3359, 6.1484, 6.5586, 7.3828,
|
||||||
|
6.2344, 7.1523, 6.4102, 5.5898, 7.0195, 7.1172, 5.8008, 6.5742, 6.2891,
|
||||||
|
8.0312, 6.9023, 6.5898, 7.1953, 6.7266, 6.0078, 5.5430, 6.4766, 6.4258,
|
||||||
|
5.9648, 8.0859, 5.0547, 7.2188, 7.4375, 6.5156, 5.9922, 6.3281, 6.2852,
|
||||||
|
6.7734, 6.2461, 6.9805, 5.4648, 5.8867, 6.8242, 6.3008, 6.3281, 7.3047,
|
||||||
|
7.1836, 6.5195, 6.6328, 6.7188, 5.4336, 6.5078, 5.3477, 5.5508, 7.3125,
|
||||||
|
5.8750, 6.5195, 6.2383, 6.3594, 6.0898, 6.4141, 5.9844, 6.6250, 7.7109,
|
||||||
|
6.0391, 7.2344, 5.9453, 5.9453, 7.0586, 5.6641, 7.2773, 6.5195, 7.2227,
|
||||||
|
6.3359, 5.3203, 6.4375, 7.2383, 6.4023, 6.2148, 7.3750, 5.8164, 6.2109,
|
||||||
|
6.5430, 5.8164, 6.1680, 6.7656, 6.0820, 6.1094, 6.5312, 6.8906, 6.8320,
|
||||||
|
6.1289, 6.3125, 7.6797, 6.3008, 6.0000, 7.3320, 6.7852, 6.9297, 6.6328,
|
||||||
|
6.2266, 5.1602, 6.2031, 7.0547, 5.9492, 6.0703, 6.0977, 6.8086, 6.0742,
|
||||||
|
6.0195, 7.0625, 6.5781, 5.7461, 6.1562, 7.0430, 6.7148, 6.5312, 6.5820,
|
||||||
|
6.4570, 7.5508, 5.6289, 6.0547, 6.5000, 7.3125, 5.8477, 5.9297, 6.2578,
|
||||||
|
6.0078, 5.9922, 7.3398, 7.4922, 7.8906, 7.5547, 5.4648, 6.5156, 6.3242,
|
||||||
|
6.1094, 6.9219, 6.7227, 6.6836, 7.4023, 5.9648, 7.2383, 6.7695, 6.6797,
|
||||||
|
7.0547, 6.3047, 6.4688, 6.9961, 6.0391, 5.9727, 6.8398, 6.7422, 5.7656,
|
||||||
|
5.4766, 6.7852, 7.0820, 5.3516, 7.6523, 5.1562, 6.6445, 6.1211, 6.2695,
|
||||||
|
6.0703, 6.3594, 6.4062, 6.3398, 5.7578, 6.5391, 6.2500, 6.5742, 6.5000,
|
||||||
|
7.5625, 7.0117, 6.5547, 7.1250, 6.4453, 6.6094, 6.1875, 6.4219, 6.6172,
|
||||||
|
6.4336, 6.5703, 6.1758, 6.4219, 6.6016, 6.7383, 6.7070, 6.1328, 5.5586,
|
||||||
|
6.6367, 6.3789, 6.2578, 5.5039, 6.6172, 6.4648, 5.8086, 7.2031, 5.8125,
|
||||||
|
6.3711, 7.6758, 7.1289, 5.8086, 6.3008, 6.2109, 6.1602, 6.1797, 7.2305,
|
||||||
|
6.7266, 6.2422, 5.6719, 6.7070, 6.9414, 6.8594, 7.4023, 7.2109, 6.0156,
|
||||||
|
6.6680, 6.6172, 7.1250, 6.6523, 6.9531, 6.7617, 6.4961, 6.9414, 5.7188,
|
||||||
|
7.6367, 6.5469, 6.2305, 6.4414, 7.4648, 5.9102, 6.2461, 6.1367, 6.8203,
|
||||||
|
6.5703, 6.8867, 7.0000, 6.7539, 6.1719, 6.5469, 6.2422, 5.4297, 5.7305,
|
||||||
|
5.1641, 6.1875, 7.0312, 6.6484, 6.0234, 7.4102, 6.8711, 6.3086, 6.3711,
|
||||||
|
6.7344, 6.6992, 5.9766, 7.3906, 7.1875, 6.4883, 6.3984, 7.3438, 6.9688,
|
||||||
|
6.9062, 6.4375, 6.7891, 7.0117, 6.4883, 5.7500, 7.0898, 7.0742, 6.7070,
|
||||||
|
5.8750, 6.0469, 6.6445, 5.2773, 6.8984, 6.1641, 7.0508, 7.4609, 5.0273,
|
||||||
|
6.7734, 6.4531, 5.7656, 6.5312, 7.4648, 6.1250, 6.5625, 7.1367, 6.0625,
|
||||||
|
6.1211, 6.9766, 6.6758, 6.3164, 6.8828, 6.8203, 6.7500, 6.5352, 7.3008,
|
||||||
|
6.7852, 6.1914, 5.0508, 6.7188, 7.1172, 6.8008, 6.8086, 5.4883, 6.9180,
|
||||||
|
6.5742, 6.1719, 7.0469, 7.1523, 5.9492, 5.8594, 6.8320, 6.1719, 6.2031,
|
||||||
|
6.8398, 7.3008, 6.6289, 6.4922, 6.0000, 5.4766, 6.3320, 6.5117, 6.2812,
|
||||||
|
7.5742, 6.3516, 7.0039, 6.4570, 7.1523, 7.6289, 6.2578, 7.1875, 6.4844,
|
||||||
|
5.7930, 6.7070, 7.5508, 7.1797, 6.0430, 6.8711, 6.5742, 7.5781, 6.4766,
|
||||||
|
6.5391, 6.9453, 6.1992, 6.6367, 6.2812, 6.0234, 6.6953, 7.0312, 6.2031,
|
||||||
|
6.5625, 6.6719, 6.1719, 6.5586, 5.7031, 7.4609, 6.6211, 7.7227, 6.9141,
|
||||||
|
6.0469, 6.2500, 5.3828, 6.0078, 5.8164, 5.8867, 6.1523, 6.6523, 6.6953,
|
||||||
|
7.3125, 6.4844, 5.9570, 5.9531, 6.2109, 5.5039, 6.5117, 6.8203, 6.6133,
|
||||||
|
6.4766, 5.9297, 7.1445, 7.1914, 6.0117, 6.8281, 6.7422, 6.1328, 6.9805,
|
||||||
|
6.5625, 6.9180, 7.1133, 7.3359, 5.7617, 5.8711, 6.4961, 6.5859, 6.2422,
|
||||||
|
6.5273, 6.7461, 6.6992, 6.7695, 6.6289, 5.9453, 5.9805, 7.1172, 6.6719,
|
||||||
|
6.0039, 7.6875, 6.7812, 7.8359, 6.9531, 7.4336, 7.6602, 6.8164, 7.3945,
|
||||||
|
7.1602, 6.8789, 5.0078, 6.0547, 6.8086, 6.7070, 6.4688, 6.4492, 6.6172,
|
||||||
|
5.5625, 6.6914, 6.4297, 5.7461, 5.3359, 6.8750, 6.4609, 7.4062, 5.2070,
|
||||||
|
6.0820, 6.7383, 6.5703, 6.1797, 6.7070, 6.5977, 5.9961, 6.6328, 6.9375,
|
||||||
|
6.3906, 6.6484, 4.9609, 6.6445, 6.5898, 7.1875, 7.5195, 6.7969, 6.1367,
|
||||||
|
6.8906, 7.4297, 6.3633, 6.0508, 6.5000, 6.4648, 6.7539, 6.7109, 5.8086,
|
||||||
|
6.6016, 7.1133, 4.8672, 6.6367, 6.1641, 5.1758, 6.9453, 6.3242, 7.0664,
|
||||||
|
6.4805, 6.3516, 6.7383, 8.4688, 6.7305, 5.9844, 6.5938, 7.2969, 6.5977,
|
||||||
|
7.5898, 6.2969, 6.8672, 6.6680, 7.1289, 6.6875, 5.4258, 8.1875, 8.0391,
|
||||||
|
7.7969, 6.6445, 7.0703, 7.3359, 6.9805, 6.6328, 6.5352, 6.2422, 5.5820,
|
||||||
|
6.8633, 6.8047, 6.5703, 6.0117, 6.7539, 7.1719, 6.8438, 7.3633, 6.6016,
|
||||||
|
7.2070, 6.4727, 5.8008, 7.4062, 7.4805, 6.6445, 5.9023, 6.3984, 6.9961,
|
||||||
|
6.6680, 6.8242, 6.7148, 6.6172, 6.9727, 6.8320, 5.9766, 6.6133, 5.5977,
|
||||||
|
6.7773, 7.3906, 6.9219, 7.0781, 6.6914, 5.7539, 6.7969, 6.8008, 5.8047,
|
||||||
|
7.1055, 6.4961, 6.0352, 5.6211, 7.4414, 7.0703, 6.1172, 6.7461, 6.4492,
|
||||||
|
7.7148, 6.4258, 6.0039, 6.5156, 7.2188, 7.4531, 7.4844, 7.5938, 7.4023,
|
||||||
|
6.7617, 6.0078, 6.3320, 5.8906, 7.5977, 5.6523, 6.7734, 6.3008, 5.2227,
|
||||||
|
7.1719, 7.1289, 6.6602, 5.4609, 7.0312, 6.0820, 6.1719, 6.0000, 6.5547,
|
||||||
|
6.6328, 7.0547, 7.0859, 6.2656, 5.5234, 6.0273, 6.7891, 7.1875, 6.9531,
|
||||||
|
6.8203, 6.3516, 6.1172, 6.4648, 6.9180, 7.3906, 6.2812, 5.7109, 6.1484,
|
||||||
|
6.9102, 6.8711, 7.0156, 6.1445, 5.8867, 6.3828, 5.9961, 6.6914, 6.7891,
|
||||||
|
7.0820, 6.6719, 6.9297, 6.3750, 6.7578, 6.4883, 6.2227, 6.2305, 6.0508,
|
||||||
|
6.6484, 5.7578, 7.2070, 7.2383, 6.9375, 7.2578, 6.5312, 6.0312, 6.7930,
|
||||||
|
6.2578, 7.0625, 7.2148, 6.4961, 7.0703, 6.4727, 7.3906]).to(torch.float16)
|
||||||
|
|
||||||
|
def test_exllamav2(self):
|
||||||
|
from auto_gptq.nn_modules.qlinear.qlinear_exllamav2 import QuantLinear
|
||||||
|
|
||||||
|
group_size = 128
|
||||||
|
|
||||||
|
m = 1
|
||||||
|
k = 1024
|
||||||
|
n = 1024
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
|
||||||
|
linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=4)
|
||||||
|
|
||||||
|
linear = linear_class(
|
||||||
|
bits=4,
|
||||||
|
group_size=group_size,
|
||||||
|
infeatures=k,
|
||||||
|
outfeatures=n,
|
||||||
|
bias=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(isinstance(linear, QuantLinear))
|
||||||
|
|
||||||
|
torch.manual_seed(42)
|
||||||
|
|
||||||
|
linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32)
|
||||||
|
linear.scales = linear.scales + 0.002
|
||||||
|
|
||||||
|
linear = linear.eval()
|
||||||
|
linear = linear.to(device)
|
||||||
|
|
||||||
|
linear = autogptq_post_init(linear, use_act_order=False)
|
||||||
|
|
||||||
|
inp = torch.rand(1, m, k, dtype=torch.float16).to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
res = linear(inp)[0][0]
|
||||||
|
|
||||||
|
reference = self.REFERENCE.to(device)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(res, reference, rtol=3e-5, atol=2e-2), get_diff(res, reference))
|
||||||
|
|
||||||
|
def test_generation_no_act_order(self):
|
||||||
|
prompt = "I am in Paris and"
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
|
||||||
|
# Reference generated with the cuda-old kernel
|
||||||
|
reference_output = "<s> I am in Paris and I am going to the Louvre Museum. What time does it open and what is the best way to get there?\nThe Louvre Museum in Paris is open from 9:00 AM to 6:00 PM every day except for Tuesdays. The best way to get"
|
||||||
|
|
||||||
|
model_id = "TheBloke/WizardLM-7B-uncensored-GPTQ"
|
||||||
|
model_basename = "model"
|
||||||
|
|
||||||
|
model_q = AutoGPTQForCausalLM.from_quantized(model_id, device="cuda:0", use_triton=False, use_safetensors=True, model_basename=model_basename)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
|
||||||
|
inp = tokenizer(prompt, return_tensors="pt").to(device)
|
||||||
|
|
||||||
|
res = model_q.generate(**inp, num_beams=1, min_new_tokens=60, max_new_tokens=60)
|
||||||
|
|
||||||
|
predicted_text = tokenizer.decode(res[0])
|
||||||
|
|
||||||
|
|
||||||
|
self.assertEqual(predicted_text, reference_output)
|
||||||
|
|
||||||
|
def test_generation_with_act_order(self):
|
||||||
|
prompt = "I am in Paris and"
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
|
||||||
|
# Reference generated with the cuda-old kernel
|
||||||
|
reference_output = "<s> I am in Paris and it is a beautiful day. I am sitting in a café, drinking coffee and writing this book. I am surrounded by the sights and sounds of the city, and I am filled with a sense of contentment and gratitude.\n\nI am grateful for the opportunity to live and"
|
||||||
|
|
||||||
|
model_id = "TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g"
|
||||||
|
revision = "actorder"
|
||||||
|
model_basename = "vicuna-13B-1.1-GPTQ-4bit-128g.latest"
|
||||||
|
|
||||||
|
model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=False, inject_fused_mlp=True, model_basename=model_basename)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
|
||||||
|
inp = tokenizer(prompt, return_tensors="pt").to(device)
|
||||||
|
|
||||||
|
res = model_q.generate(**inp, num_beams=1, min_new_tokens=60, max_new_tokens=60)
|
||||||
|
|
||||||
|
predicted_text = tokenizer.decode(res[0])
|
||||||
|
|
||||||
|
self.assertEqual(predicted_text, reference_output)
|
||||||
|
|
||||||
|
def test_exllama_buffer_size(self):
|
||||||
|
# prompt = "I'm in Paris and" * 450
|
||||||
|
prompt = "I'm in Paris and" * 1000
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
|
||||||
|
model_id = "TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g"
|
||||||
|
revision = "actorder"
|
||||||
|
model_basename = "vicuna-13B-1.1-GPTQ-4bit-128g.latest"
|
||||||
|
|
||||||
|
model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=True, inject_fused_mlp=True, model_basename=model_basename)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
|
||||||
|
inp = tokenizer(prompt, return_tensors="pt").to(device)
|
||||||
|
|
||||||
|
self.assertTrue(inp["input_ids"].shape[1] > 2048) # 2048 is the default max_input_length for LLama
|
||||||
|
|
||||||
|
res = model_q.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)
|
Loading…
Add table
Reference in a new issue