Merge branch 'main' into main
This commit is contained in:
commit
62fd0371ac
32 changed files with 3385 additions and 26 deletions
|
@ -15,7 +15,12 @@
|
|||
</p>
|
||||
</h4>
|
||||
|
||||
*<center>📣 Long time no see! 👋 Architecture upgrade, performance optimization and more new features will come in July and August, stay tune! 🥂</center>*
|
||||
|
||||
## The path to v1.0.0
|
||||
|
||||
Hi, fellow community members, long time no see! I'm sorry that I haven't been able to update this project more frequently due to personal reasons during this period. The past few weeks have been huge in terms of my career plans. Not long ago, I officially bid farewell to the startup team that I joined for two years after graduation. I'm very grateful to the leaders and colleagues of the team for their trust and guidance, which enabled me to grow rapidly in two years; at the same time, I'm also really grateful to the team for allowing me to use the internal A100 GPU server cluster free of charge since the start of the AutoGPTQ project to complete various experiments and performance evaluations. (Of course, it can no longer be used in the future, so **it will mean a lot to me if there will be new hardware sponsorship!**) In the past two years, I have served as an AI engineer in this team, responsible for the LLM based dialogue system's architecture design and develop. We had successfully launched a product called gemsouls, but unfortunately it has ceased operations. Now, the team is about to launch a new product called [modelize](https://www.beta.modelize.ai/), which is **a LLM-native AI agent platform, where users can use multiple AI agents to build a highly automated team, allowing them to interact with each other in the workflow, collaborate to complete complex projects efficiently.**
|
||||
|
||||
Getting back to the topic, I'm very excited to see that in the past few months, research on optimizing the inference performance of LLMs has made tremendous progress. Now we can not only complete the inference of LLMs on high-end GPUs efficiently, but also on CPUs and even edge devices. A series of technological advancements make me eager to make more contributions to the open source community. Therefore, I will first use about four weeks to gradually update AutoGPTQ to the v1.0.0 official version. During this period, there will also be 2~3 minor versions are released to allow users to experience performance optimization and new features timely. In my vision, **by the time v1.0.0 is officially released, AutoGPTQ will be able to serve as an extendable and flexible quantization backend that supports all GPTQ-like methods and automatically quantize LLMs written by Pytorch**. I detailed the development plan in [this issue](https://github.com/PanQiWei/AutoGPTQ/issues/348), feel free to drop in there for discussion and give your suggestions!
|
||||
|
||||
## News or Update
|
||||
|
||||
|
|
|
@ -15,7 +15,11 @@
|
|||
</p>
|
||||
</h4>
|
||||
|
||||
*<center>📣 好久不见!👋 七月和八月将会迎来架构升级,性能优化和新特性,敬请关注!🥂</center>*
|
||||
## 通向 v1.0.0 之路
|
||||
|
||||
嗨,社区的伙伴们,好久不见!很抱歉这段时间由于个人原因,我没能以较高的频率来更新这个项目。过去几周对我的职业生涯规划而言意义重大。在不久前,我正式告别了毕业后便加入两年之久的创业团队,非常感谢团队的领导和同事们给予我的信任与指导,让我能够在两年时间里飞速地成长;同时也十分感激团队允许我自 AutoGPTQ 项目创立以来一直无偿使用内部的 A100 GPU 服务器集群以完成各项实验与性能测评。(当然今后是无法继续使用了,因此**若有新的硬件赞助我将感激不尽**!)过去的两年里,我在这个团队中担任算法工程师的角色,负责基于大语言模型的对话系统架构设计与开发,我们曾成功推出一款名为 gemsouls 的产品,但不幸的是它已经停止运营。而现在,这个团队即将推出一款名为 [modelize](https://www.beta.modelize.ai/) 的新产品,**这是一个大模型原生的 AI 智能体平台,用户可以使用多个 AI 智能体搭建一个高度自动化的团队,让它们在工作流中相互合作,高效完成复杂的项目。**
|
||||
|
||||
话归正题,我非常兴奋地看到,在过去几个月的时间里,针对大语言模型推理性能优化的研究取得了巨大的进展,如今我们不仅能够在高端显卡上完成大语言模型的推理,甚至在 CPU 和边缘设备上都可以轻松运行大语言模型。一系列的技术进步,让我同样迫不及待地在开源社区上做出更多的贡献,因此,首先,我将用约四周的时间将 AutoGPTQ 迭代至 v1.0.0 正式版本,在此期间,也会有 2~3 个小版本发布以让用户能够及时体验性能优化和新特性。在我的愿景里,**到 v1.0.0 版本正式发布时,AutoGPTQ 将能够作为一个灵活可拓展的、支持所有 GPTQ-like 方法的量化后端,自动地完成各种基于 Pytorch 编写的大语言模型的量化工作**。我在[这里](https://github.com/PanQiWei/AutoGPTQ/issues/348)详细介绍了开发计划,欢迎移步至此进行讨论并给出你们的建议!
|
||||
|
||||
## 新闻或更新
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ from ..nn_modules._fused_base import FusedBaseAttentionModule, FusedBaseMLPModul
|
|||
from ..quantization import GPTQ
|
||||
from ..utils.data_utils import collate_data
|
||||
from ..utils.import_utils import (
|
||||
dynamically_import_QuantLinear, TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE, EXLLAMA_KERNELS_AVAILABLE, QIGEN_AVAILABLE
|
||||
dynamically_import_QuantLinear, TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE, EXLLAMA_KERNELS_AVAILABLE, QIGEN_AVAILABLE, EXLLAMAV2_KERNELS_AVAILABLE
|
||||
)
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
@ -91,9 +91,17 @@ class BaseQuantizeConfig(PushToHubMixin):
|
|||
_commit_hash=commit_hash,
|
||||
)
|
||||
|
||||
field_names = [field.name for field in fields(cls)]
|
||||
with open(resolved_config_file, "r", encoding="utf-8") as f:
|
||||
return cls(**json.load(f))
|
||||
|
||||
args_from_json = json.load(f)
|
||||
filtered_args = {}
|
||||
for key, val in args_from_json.items():
|
||||
if key in field_names:
|
||||
filtered_args[key] = val
|
||||
else:
|
||||
logger.warning(f"ignoring unknown parameter in {quantize_config_filename}: {key}.")
|
||||
return cls(**filtered_args)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"bits": self.bits,
|
||||
|
@ -700,7 +708,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
trust_remote_code: bool = False,
|
||||
warmup_triton: bool = False,
|
||||
trainable: bool = False,
|
||||
disable_exllama: bool = False,
|
||||
disable_exllama: bool = True,
|
||||
disable_exllamav2: bool = False,
|
||||
**kwargs
|
||||
):
|
||||
"""load quantized model from local disk"""
|
||||
|
@ -743,6 +752,15 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
"auto_gptq from source."
|
||||
)
|
||||
disable_exllama = True
|
||||
if not disable_exllamav2 and not EXLLAMAV2_KERNELS_AVAILABLE:
|
||||
logger.warning(
|
||||
"Exllamav2 kernel is not installed, reset disable_exllamav2 to True. "
|
||||
"This may because you installed auto_gptq using a pre-build wheel "
|
||||
"on Windows, in which exllama_kernels are not compiled. To use "
|
||||
"exllama_kernels to further speedup inference, you can re-install "
|
||||
"auto_gptq from source."
|
||||
)
|
||||
disable_exllamav2 = True
|
||||
if not AUTOGPTQ_CUDA_AVAILABLE:
|
||||
logger.warning(
|
||||
"CUDA kernels for auto_gptq are not installed, this will result in "
|
||||
|
@ -758,6 +776,13 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
inject_fused_mlp = False
|
||||
use_triton = False
|
||||
disable_exllama = True
|
||||
disable_exllamav2 = True
|
||||
|
||||
if not disable_exllamav2 and not disable_exllama:
|
||||
logger.warning(
|
||||
"You have activated both exllama and exllamav2 kernel. Setting disable_exllama to True and keeping disable_exllamav2 to False"
|
||||
)
|
||||
disable_exllama = True
|
||||
|
||||
# == step1: prepare configs and file names == #
|
||||
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs)
|
||||
|
@ -804,9 +829,10 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
|
||||
model_save_name = resolved_archive_file
|
||||
|
||||
if not disable_exllama and trainable:
|
||||
if (not disable_exllama or not disable_exllamav2) and trainable:
|
||||
logger.warning("QuantLinear with exllama backend not support trainable mode yet, Switch to the pytorch backend.")
|
||||
disable_exllama = True
|
||||
disable_exllamav2 = True
|
||||
|
||||
elif not use_triton and trainable:
|
||||
logger.warning("QuantLinear with cuda backend not support trainable mode yet, Switch to the pytorch backend.")
|
||||
|
@ -853,6 +879,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
quantize_config.group_size,
|
||||
use_triton=use_triton,
|
||||
disable_exllama=disable_exllama,
|
||||
disable_exllamav2=disable_exllamav2,
|
||||
use_cuda_fp16=use_cuda_fp16,
|
||||
desc_act=quantize_config.desc_act,
|
||||
trainable=trainable
|
||||
|
@ -926,6 +953,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
quantize_config.group_size,
|
||||
use_triton=use_triton,
|
||||
disable_exllama=disable_exllama,
|
||||
disable_exllamav2=disable_exllamav2,
|
||||
use_cuda_fp16=use_cuda_fp16,
|
||||
desc_act=quantize_config.desc_act,
|
||||
trainable=trainable,
|
||||
|
@ -987,6 +1015,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
trainable=trainable,
|
||||
bits=quantize_config.bits,
|
||||
disable_exllama=disable_exllama,
|
||||
disable_exllamav2=disable_exllamav2
|
||||
)
|
||||
if inject_fused_mlp:
|
||||
if cls.fused_mlp_module_type is None:
|
||||
|
|
|
@ -24,6 +24,8 @@ SUPPORTED_MODELS = [
|
|||
]
|
||||
if compare_transformers_version("v4.28.0", op="ge"):
|
||||
SUPPORTED_MODELS.append("llama")
|
||||
if compare_transformers_version("v4.33.0", op="ge"):
|
||||
SUPPORTED_MODELS.append("falcon")
|
||||
|
||||
EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048
|
||||
|
||||
|
|
|
@ -56,13 +56,14 @@ def make_quant(
|
|||
group_size,
|
||||
name='',
|
||||
use_triton: bool = False,
|
||||
disable_exllama: bool = False,
|
||||
disable_exllama: bool = True,
|
||||
disable_exllamav2: bool = False,
|
||||
use_qigen: bool = False,
|
||||
use_cuda_fp16: bool = True,
|
||||
desc_act: bool = False,
|
||||
trainable: bool = False
|
||||
):
|
||||
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, use_qigen=use_qigen)
|
||||
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, use_qigen=use_qigen)
|
||||
|
||||
if isinstance(module, QuantLinear):
|
||||
return
|
||||
|
@ -101,6 +102,7 @@ def make_quant(
|
|||
desc_act=desc_act,
|
||||
trainable=trainable,
|
||||
disable_exllama=disable_exllama,
|
||||
disable_exllamav2=disable_exllamav2,
|
||||
use_qigen=use_qigen
|
||||
)
|
||||
|
||||
|
@ -339,8 +341,32 @@ def autogptq_post_init(model, use_act_order: bool, max_input_length: Optional[in
|
|||
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllama":
|
||||
submodule.post_init()
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
## exllamav2
|
||||
fixed_bytes = {}
|
||||
model_uses_exllamav2 = False
|
||||
|
||||
for _, submodule in model.named_modules():
|
||||
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllamav2":
|
||||
model_uses_exllamav2 = True
|
||||
device = submodule.qweight.device
|
||||
scratch_fixed = submodule.scratch_space_fixed()
|
||||
fixed_bytes[device] = max(scratch_fixed, fixed_bytes.get(device,0))
|
||||
|
||||
if model_uses_exllamav2:
|
||||
from ..nn_modules.qlinear.qlinear_exllamav2 import ExLlamaV2DeviceTensors
|
||||
device_tensors = {}
|
||||
for device, scratch_bytes in fixed_bytes.items():
|
||||
device_tensors[device] = ExLlamaV2DeviceTensors(device.index, scratch_bytes)
|
||||
|
||||
# have persistent buffers, otherwise we will get OOM
|
||||
model.device_tensors = device_tensors
|
||||
|
||||
for _, submodule in model.named_modules():
|
||||
if hasattr(submodule, "QUANT_TYPE") and submodule.QUANT_TYPE == "exllamav2":
|
||||
device = submodule.qweight.device
|
||||
submodule.post_init(temp_dq = model.device_tensors[device])
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return model
|
||||
|
||||
|
||||
|
|
|
@ -29,6 +29,7 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
|
|||
"codegen": CodeGenGPTQForCausalLM,
|
||||
"RefinedWebModel": RWGPTQForCausalLM,
|
||||
"RefinedWeb": RWGPTQForCausalLM,
|
||||
"falcon": RWGPTQForCausalLM,
|
||||
"baichuan": BaiChuanGPTQForCausalLM,
|
||||
"internlm": InternLMGPTQForCausalLM,
|
||||
"qwen": QwenGPTQForCausalLM,
|
||||
|
@ -81,7 +82,8 @@ class AutoGPTQForCausalLM:
|
|||
trust_remote_code: bool = False,
|
||||
warmup_triton: bool = False,
|
||||
trainable: bool = False,
|
||||
disable_exllama: bool = False,
|
||||
disable_exllama: bool = True,
|
||||
disable_exllamav2: bool = False,
|
||||
**kwargs
|
||||
) -> BaseGPTQForCausalLM:
|
||||
model_type = check_and_get_model_type(model_name_or_path, trust_remote_code)
|
||||
|
@ -122,6 +124,7 @@ class AutoGPTQForCausalLM:
|
|||
warmup_triton=warmup_triton,
|
||||
trainable=trainable,
|
||||
disable_exllama=disable_exllama,
|
||||
disable_exllamav2=disable_exllamav2,
|
||||
**keywords
|
||||
)
|
||||
|
||||
|
|
|
@ -237,16 +237,18 @@ class FusedGPTJAttentionForQuantizedModel(FusedBaseAttentionModule):
|
|||
desc_act=False,
|
||||
trainable=False,
|
||||
bits: int = 4,
|
||||
disable_exllama=False,
|
||||
disable_exllama=True,
|
||||
disable_exllamav2=False,
|
||||
**kwargs
|
||||
):
|
||||
config = model.config
|
||||
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
|
||||
if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
|
||||
|
||||
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2)
|
||||
if QuantLinear.QUANT_TYPE in ["exllama", "exllamav2"] and desc_act:
|
||||
# See fused_llama_attn.py comment
|
||||
logger.warning(f"Exllama kernel does not support query/key/value fusion with act-order. Because of this, Fused attention is automatically disabled.")
|
||||
return False
|
||||
|
||||
|
||||
for name, m in model.named_modules():
|
||||
if not isinstance(m, GPTJAttention):
|
||||
continue
|
||||
|
|
|
@ -137,14 +137,16 @@ class FusedLlamaAttentionForQuantizedModel(FusedBaseAttentionModule):
|
|||
desc_act=False,
|
||||
trainable=False,
|
||||
bits: int = 4,
|
||||
disable_exllama=False,
|
||||
disable_exllama=True,
|
||||
disable_exllamav2=False,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
|
||||
"""
|
||||
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
|
||||
if QuantLinear.QUANT_TYPE == "exllama" and desc_act:
|
||||
|
||||
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2)
|
||||
if QuantLinear.QUANT_TYPE in ["exllama", "exllamav2"] and desc_act:
|
||||
# TODO: support it. The issue lies maybe in the line:
|
||||
# int groups = qzeros.size(0);
|
||||
# in exllama_ext.cpp
|
||||
|
|
188
auto_gptq/nn_modules/qlinear/qlinear_exllamav2.py
Normal file
188
auto_gptq/nn_modules/qlinear/qlinear_exllamav2.py
Normal file
|
@ -0,0 +1,188 @@
|
|||
# Adapted from turboderp exllama: https://github.com/turboderp/exllamav2
|
||||
|
||||
from logging import getLogger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
try:
|
||||
from exllamav2_kernels import make_q_matrix, gemm_half_q_half
|
||||
except ImportError:
|
||||
logger.error('exllamav2_kernels not installed.')
|
||||
raise
|
||||
|
||||
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
|
||||
none_tensor = torch.empty((1, 1), device="meta")
|
||||
|
||||
def _torch_device(idx):
|
||||
if idx == -1: return "cpu"
|
||||
return f"cuda:{idx}"
|
||||
|
||||
def ext_gemm_half_q_half(x, q_handle, q4_width, force_cuda):
|
||||
"""Matrix multiplication, returns x @ q4"""
|
||||
output_shape = x.shape[:-1] + (q4_width,)
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output = torch.empty((x.shape[0], q4_width), dtype = torch.half, device = x.device)
|
||||
gemm_half_q_half(x, q_handle, output, force_cuda)
|
||||
return output.view(output_shape)
|
||||
|
||||
def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
|
||||
"""
|
||||
Create Q matrix
|
||||
"""
|
||||
# EXL2
|
||||
# won't work as the moment because the tensors are not the same.
|
||||
if "q_weight" in w:
|
||||
w["q_scale_max"] /= 256
|
||||
w["q_perm"] = w["q_perm"].short()
|
||||
w["q_invperm"] = w["q_invperm"].short()
|
||||
return make_q_matrix(w["q_weight"],
|
||||
w["q_perm"],
|
||||
w["q_invperm"],
|
||||
w["q_scale"],
|
||||
w["q_scale_max"],
|
||||
w["q_groups"],
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
temp_dq)
|
||||
# GPTQ
|
||||
elif "qweight" in w:
|
||||
if w["scales"].dtype == torch.float:
|
||||
w["scales"] = w["scales"].half()
|
||||
|
||||
# GPTQ with g_idx (act_order)
|
||||
if "g_idx" in w and not (w["g_idx"] == 0).all().item():
|
||||
w["q_perm"] = torch.empty((w["qweight"].shape[0] * 8,), dtype = torch.short, device = w["qweight"].device)
|
||||
w["q_invperm"] = torch.empty_like(w["q_perm"])
|
||||
# make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx.
|
||||
return make_q_matrix(w["qweight"],
|
||||
w["q_perm"],
|
||||
w["q_invperm"],
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
w["qzeros"],
|
||||
w["scales"],
|
||||
w["g_idx"].cpu(),
|
||||
temp_dq)
|
||||
# GPTQ without g_idx
|
||||
else:
|
||||
return make_q_matrix(w["qweight"],
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
none_tensor,
|
||||
w["qzeros"],
|
||||
w["scales"],
|
||||
none_tensor,
|
||||
temp_dq)
|
||||
|
||||
class QuantLinear(nn.Module):
|
||||
QUANT_TYPE = "exllamav2"
|
||||
|
||||
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
|
||||
|
||||
def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
|
||||
super().__init__()
|
||||
if bits != 4:
|
||||
raise ValueError(
|
||||
f"Exllamav2 kernel supports only bits=4, requested bits={bits}. Something is wrong in the model initialization.")
|
||||
if trainable:
|
||||
raise NotImplementedError("Exllamav2 kernel does not support training.")
|
||||
|
||||
self.q_handle = None
|
||||
self.q_tensors = None
|
||||
self.padding = - outfeatures % 32
|
||||
|
||||
self.infeatures = infeatures
|
||||
self.outfeatures = outfeatures + self.padding
|
||||
self.bits = bits
|
||||
self.group_size = group_size if group_size != -1 else infeatures
|
||||
self.trainable = trainable
|
||||
self.maxq = 2 ** self.bits - 1
|
||||
|
||||
assert infeatures % 32 == 0
|
||||
assert infeatures % self.group_size == 0
|
||||
assert outfeatures % 32 == 0
|
||||
|
||||
# I need to register the tensors, otherwise, we won't be able to load them easily using transformers ...
|
||||
self.register_buffer(
|
||||
'qweight',
|
||||
torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
|
||||
)
|
||||
self.register_buffer(
|
||||
'qzeros',
|
||||
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures // 32 * self.bits), dtype=torch.int32)
|
||||
)
|
||||
self.register_buffer(
|
||||
'scales',
|
||||
torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float16)
|
||||
)
|
||||
self.register_buffer(
|
||||
'g_idx',
|
||||
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32)
|
||||
)
|
||||
|
||||
if bias:
|
||||
self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def post_init(self, temp_dq):
|
||||
assert self.qweight.device.type == "cuda"
|
||||
assert self.qweight.device.index is not None
|
||||
self.q_tensors = {
|
||||
"qweight":self.qweight,
|
||||
"qzeros":self.qzeros,
|
||||
"scales":self.scales,
|
||||
"g_idx":self.g_idx
|
||||
}
|
||||
temp_dq = temp_dq.get_scratch_slice(self.temp_dq_size())
|
||||
self.q_handle = ext_make_q_matrix(
|
||||
self.q_tensors, temp_dq
|
||||
)
|
||||
|
||||
def forward(self, x, force_cuda = False):
|
||||
output = ext_gemm_half_q_half(x, self.q_handle, self.outfeatures, force_cuda)
|
||||
|
||||
if self.bias is not None:
|
||||
output.add_(self.bias)
|
||||
return output
|
||||
|
||||
def temp_dq_size(self):
|
||||
return self.infeatures * self.outfeatures * 2 + 128
|
||||
|
||||
def temp_fwd_size(self, max_input_len, max_batch_size):
|
||||
return self.outfeatures * max_input_len * max_batch_size * 4 + 128
|
||||
|
||||
def scratch_space_fixed(self, max_input_len=2048, max_batch_size=8):
|
||||
return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
|
||||
|
||||
|
||||
class ExLlamaV2DeviceTensors:
|
||||
|
||||
device_idx: int
|
||||
scratch_bytes: int
|
||||
scratch_idx: int
|
||||
scratch: torch.tensor = None
|
||||
|
||||
def __init__(self, device_idx, scratch_bytes):
|
||||
self.device_idx = device_idx
|
||||
self.scratch_bytes = scratch_bytes
|
||||
|
||||
def prepare(self):
|
||||
self.scratch = torch.empty((self.scratch_bytes // 2,), dtype = torch.half, device = _torch_device(self.device_idx))
|
||||
|
||||
def get_scratch_slice(self, size_bytes):
|
||||
|
||||
if self.scratch is None: self.prepare()
|
||||
|
||||
size_bytes = ((size_bytes + 127) // 128) * 128
|
||||
size_half = size_bytes // 2
|
||||
scratch_slice = self.scratch.narrow(0, 0, size_half)
|
||||
return scratch_slice
|
|
@ -24,7 +24,14 @@ try:
|
|||
EXLLAMA_KERNELS_AVAILABLE = True
|
||||
except:
|
||||
EXLLAMA_KERNELS_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import exllamav2_kernels
|
||||
|
||||
EXLLAMAV2_KERNELS_AVAILABLE = True
|
||||
except:
|
||||
EXLLAMAV2_KERNELS_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import cQIGen as qinfer
|
||||
|
||||
|
@ -35,7 +42,7 @@ except:
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool = False, use_qigen: bool = False):
|
||||
def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool = True, disable_exllamav2:bool = False, use_qigen: bool = False):
|
||||
if use_qigen:
|
||||
from ..nn_modules.qlinear.qlinear_qigen import QuantLinear
|
||||
else:
|
||||
|
@ -45,7 +52,9 @@ def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size:
|
|||
|
||||
from ..nn_modules.qlinear.qlinear_triton import QuantLinear
|
||||
else:
|
||||
if bits == 4 and not disable_exllama and EXLLAMA_KERNELS_AVAILABLE:
|
||||
if bits == 4 and not disable_exllamav2 and EXLLAMAV2_KERNELS_AVAILABLE:
|
||||
from ..nn_modules.qlinear.qlinear_exllamav2 import QuantLinear
|
||||
elif bits == 4 and not disable_exllama and EXLLAMA_KERNELS_AVAILABLE:
|
||||
from ..nn_modules.qlinear.qlinear_exllama import QuantLinear
|
||||
elif not desc_act or group_size == -1:
|
||||
from ..nn_modules.qlinear.qlinear_cuda_old import QuantLinear
|
||||
|
|
|
@ -402,7 +402,7 @@ def get_gptq_peft_model(
|
|||
with hijack_peft_mappings():
|
||||
try:
|
||||
if train_mode:
|
||||
peft_model = get_peft_model(model.model, peft_config)
|
||||
peft_model = get_peft_model(model.model, peft_config, adapter_name=adapter_name)
|
||||
else:
|
||||
peft_model = PeftModel.from_pretrained(model.model, model_id, adapter_name)
|
||||
except:
|
||||
|
|
13
autogptq_extension/exllamav2/config.h
Normal file
13
autogptq_extension/exllamav2/config.h
Normal file
|
@ -0,0 +1,13 @@
|
|||
#ifndef _config_h
|
||||
#define _config_h
|
||||
|
||||
#define MAX_Q_GEMM_ROWS 50
|
||||
|
||||
#define QMODE_2BIT 1
|
||||
#define QMODE_3BIT 1
|
||||
#define QMODE_4BIT 1
|
||||
#define QMODE_5BIT 1
|
||||
#define QMODE_6BIT 0
|
||||
#define QMODE_8BIT 0
|
||||
|
||||
#endif
|
12
autogptq_extension/exllamav2/cpp/util.h
Normal file
12
autogptq_extension/exllamav2/cpp/util.h
Normal file
|
@ -0,0 +1,12 @@
|
|||
#ifndef _util_h
|
||||
#define _util_h
|
||||
|
||||
#define DBGS(__x) printf("%s\n", __x)
|
||||
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
|
||||
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
|
||||
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
|
||||
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
|
||||
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
|
||||
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
|
||||
|
||||
#endif
|
56
autogptq_extension/exllamav2/cuda/compat.cuh
Normal file
56
autogptq_extension/exllamav2/cuda/compat.cuh
Normal file
|
@ -0,0 +1,56 @@
|
|||
#ifndef _compat_cuh
|
||||
#define _compat_cuh
|
||||
|
||||
// atomicAdd for half types, to support CC < 7.x
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
|
||||
{
|
||||
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
__half_raw hsum;
|
||||
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
|
||||
half tmpres = __hadd(hsum, val);
|
||||
hsum = __half_raw(tmpres);
|
||||
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
|
||||
old = atomicCAS(address_as_ui, assumed, old);
|
||||
}
|
||||
while (assumed != old);
|
||||
}
|
||||
|
||||
// atomicAdd for half2 types
|
||||
|
||||
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
|
||||
{
|
||||
unsigned int* address_as_ui = (unsigned int*)address;
|
||||
unsigned int old = *address_as_ui;
|
||||
unsigned int assumed;
|
||||
do
|
||||
{
|
||||
assumed = old;
|
||||
half2 old_val = *((half2*)&old);
|
||||
half2 new_val = __hadd2(old_val, val);
|
||||
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
|
||||
}
|
||||
while (assumed != old);
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
|
||||
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
|
||||
|
||||
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
|
||||
|
||||
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
|
||||
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#endif
|
121
autogptq_extension/exllamav2/cuda/matrix_view.cuh
Normal file
121
autogptq_extension/exllamav2/cuda/matrix_view.cuh
Normal file
|
@ -0,0 +1,121 @@
|
|||
#ifndef _matrix_view_cuh
|
||||
#define _matrix_view_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include "quant/qdq_util.cuh"
|
||||
|
||||
class MatrixView_half
|
||||
{
|
||||
public:
|
||||
const half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
|
||||
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
|
||||
|
||||
__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
|
||||
{
|
||||
half2* ptr = (half2*) item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __low2half(i01);
|
||||
items[1] = __high2half(i01);
|
||||
items[2] = __low2half(i23);
|
||||
items[3] = __high2half(i23);
|
||||
}
|
||||
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
|
||||
{
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __half2float(__low2half(i01));
|
||||
items[1] = __half2float(__high2half(i01));
|
||||
items[2] = __half2float(__low2half(i23));
|
||||
items[3] = __half2float(__high2half(i23));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
|
||||
{
|
||||
half2* ptr = (half2*)item_ptr(row, column);
|
||||
half2 i01 = ptr[0];
|
||||
half2 i23 = ptr[1];
|
||||
items[0] = __half2half2(__low2half(i01));
|
||||
items[1] = __half2half2(__high2half(i01));
|
||||
items[2] = __half2half2(__low2half(i23));
|
||||
items[3] = __half2half2(__high2half(i23));
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_half_rw
|
||||
{
|
||||
public:
|
||||
half* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
|
||||
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
|
||||
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
|
||||
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
|
||||
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
|
||||
|
||||
__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
|
||||
{
|
||||
half2 v01 = __halves2half2(v0, v1);
|
||||
half2 v23 = __halves2half2(v2, v3);
|
||||
half2* ptr = (half2*) item_ptr(row, column);
|
||||
ptr[0] = v01;
|
||||
ptr[1] = v23;
|
||||
}
|
||||
};
|
||||
|
||||
class MatrixView_q4_row
|
||||
{
|
||||
public:
|
||||
const uint32_t* data;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
|
||||
: data(data), height(height), width(width)
|
||||
{ }
|
||||
|
||||
__device__ __forceinline__ int item(int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||
items[0] = d & 0x0f;
|
||||
items[1] = (d >> 4) & 0x0f;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
|
||||
{
|
||||
int shift = (column & 0x07) * 4;
|
||||
uint32_t d = data[row * width / 8 + column / 8] >> shift;
|
||||
items[0] = d & 0x0f;
|
||||
items[1] = (d >> 4) & 0x0f;
|
||||
items[2] = (d >> 8) & 0x0f;
|
||||
items[3] = (d >> 12) & 0x0f;
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
238
autogptq_extension/exllamav2/cuda/q_gemm.cu
Normal file
238
autogptq_extension/exllamav2/cuda/q_gemm.cu
Normal file
|
@ -0,0 +1,238 @@
|
|||
#include "q_gemm.cuh"
|
||||
#include "util.cuh"
|
||||
#include "matrix_view.cuh"
|
||||
#include "../config.h"
|
||||
|
||||
#include "quant/qdq_2.cuh"
|
||||
#include "quant/qdq_3.cuh"
|
||||
#include "quant/qdq_4.cuh"
|
||||
#include "quant/qdq_5.cuh"
|
||||
#include "quant/qdq_6.cuh"
|
||||
#include "quant/qdq_8.cuh"
|
||||
|
||||
#define BLOCK_KN_SIZE 128
|
||||
#define BLOCK_M_SIZE_MAX 8
|
||||
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
|
||||
#define CLEAR_N_SIZE 256
|
||||
|
||||
#include "q_gemm_kernel.cuh"
|
||||
#include "q_gemm_kernel_gptq.cuh"
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle,
|
||||
hipblasOperation_t transA,
|
||||
hipblasOperation_t transB,
|
||||
int m,
|
||||
int n,
|
||||
int k,
|
||||
const half* alpha,
|
||||
const half* AP,
|
||||
int lda,
|
||||
const half* BP,
|
||||
int ldb,
|
||||
const half* beta,
|
||||
half* CP,
|
||||
int ldc) {
|
||||
return hipblasHgemm(handle, transA, transB, m, n, k,
|
||||
reinterpret_cast<const hipblasHalf *>(alpha),
|
||||
reinterpret_cast<const hipblasHalf *>(AP), lda,
|
||||
reinterpret_cast<const hipblasHalf *>(BP), ldb,
|
||||
reinterpret_cast<const hipblasHalf *>(beta),
|
||||
reinterpret_cast<hipblasHalf *>(CP), ldc);
|
||||
}
|
||||
#define hipblasHgemm __compat_hipblasHgemm
|
||||
|
||||
// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS.
|
||||
#define rocblas_operation_none HIPBLAS_OP_N
|
||||
#define rocblas_hgemm __compat_hipblasHgemm
|
||||
#endif
|
||||
|
||||
void gemm_half_q_half_cuda_part
|
||||
(
|
||||
const half* a,
|
||||
QMatrix* b,
|
||||
half* c,
|
||||
int size_m,
|
||||
int size_n,
|
||||
int size_k,
|
||||
int m_count,
|
||||
bool clear
|
||||
)
|
||||
{
|
||||
if (!b->is_gptq)
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
blockDim.z = 1;
|
||||
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
|
||||
gridDim.y = DIVIDE(size_m, m_count);
|
||||
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
|
||||
|
||||
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);
|
||||
|
||||
kernel<<<gridDim, blockDim>>>
|
||||
(
|
||||
a,
|
||||
b->cuda_q_weight,
|
||||
b->cuda_q_scale,
|
||||
b->cuda_q_scale_max,
|
||||
c,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
b->groups,
|
||||
b->groupsize,
|
||||
b->cuda_q_perm,
|
||||
b->rows_8,
|
||||
b->rows_6,
|
||||
b->rows_5,
|
||||
b->rows_4,
|
||||
b->rows_3,
|
||||
b->rows_2,
|
||||
clear
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
blockDim.z = 1;
|
||||
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
|
||||
gridDim.y = DIVIDE(size_m, m_count);
|
||||
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
|
||||
|
||||
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
|
||||
|
||||
// DBGX((uint64_t) b->cuda_q_perm);
|
||||
// DBGI(b->rows_4);
|
||||
// DBGI(b->height);
|
||||
|
||||
kernel<<<gridDim, blockDim>>>
|
||||
(
|
||||
a,
|
||||
b->cuda_q_weight,
|
||||
b->cuda_gptq_qzeros,
|
||||
b->cuda_gptq_scales,
|
||||
c,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
b->groups,
|
||||
b->groupsize,
|
||||
b->cuda_q_perm,
|
||||
b->rows_4,
|
||||
clear
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
void gemm_half_q_half_cuda
|
||||
(
|
||||
cublasHandle_t cublas_handle,
|
||||
const half* a,
|
||||
QMatrix* b,
|
||||
half* c,
|
||||
int size_m,
|
||||
int size_n,
|
||||
int size_k,
|
||||
bool clear,
|
||||
half* temp_dq,
|
||||
bool force_cuda
|
||||
)
|
||||
{
|
||||
if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
|
||||
{
|
||||
//printf("cublas\n");
|
||||
|
||||
// Reconstruct FP16 matrix, then cuBLAS
|
||||
|
||||
if (!temp_dq) temp_dq = b->temp_dq;
|
||||
b->reconstruct(temp_dq);
|
||||
|
||||
//cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH);
|
||||
|
||||
const half alpha = __float2half(1.0f);
|
||||
const half beta = clear ? __float2half(0.0f) : __float2half(1.0f);
|
||||
cublasHgemm(cublas_handle,
|
||||
CUBLAS_OP_N,
|
||||
CUBLAS_OP_N,
|
||||
size_n, size_m, size_k,
|
||||
&alpha, temp_dq, size_n,
|
||||
a, size_k,
|
||||
&beta, c, size_n);
|
||||
|
||||
//const float alpha = 1.0f;
|
||||
//const float beta = clear ? 0.0f : 1.0f;
|
||||
//cublasSgemmEx(cublas_handle,
|
||||
// CUBLAS_OP_N,
|
||||
// CUBLAS_OP_N,
|
||||
// size_n, size_m, size_k,
|
||||
// &alpha, temp_dq, CUDA_R_16F, size_n,
|
||||
// a, CUDA_R_16F, size_k,
|
||||
// &beta, c, CUDA_R_16F, size_n);
|
||||
|
||||
//const float alpha = 1.0f;
|
||||
//const float beta = clear ? 0.0f : 1.0f;
|
||||
//cublasGemmEx(cublas_handle,
|
||||
// CUBLAS_OP_N, CUBLAS_OP_N,
|
||||
// size_n, size_m, size_k,
|
||||
// &alpha, temp_dq, CUDA_R_16F, size_n,
|
||||
// a, CUDA_R_16F, size_k,
|
||||
// &beta, c, CUDA_R_16F, size_n,
|
||||
// CUDA_R_16F, CUBLAS_GEMM_DFALT_TENSOR_OP);
|
||||
}
|
||||
else
|
||||
{
|
||||
//printf("cuda\n");
|
||||
|
||||
// Quantized matmul
|
||||
|
||||
//if (clear) clear_tensor_cuda(c, size_m, size_n);
|
||||
|
||||
int max_chunks = size_m / BLOCK_M_SIZE_MAX;
|
||||
int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
|
||||
int last_chunk_size = size_m - last_chunk;
|
||||
|
||||
if (max_chunks)
|
||||
{
|
||||
gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear);
|
||||
}
|
||||
|
||||
if (last_chunk_size)
|
||||
{
|
||||
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void clear_kernel
|
||||
(
|
||||
half* __restrict__ c,
|
||||
const int size_m,
|
||||
const int size_n
|
||||
)
|
||||
{
|
||||
int m = blockIdx.y;
|
||||
int n = (blockIdx.x * CLEAR_N_SIZE + threadIdx.x) * 8;
|
||||
if (n >= size_n) return;
|
||||
int4* c_ptr = (int4*)(c + m * size_n + n);
|
||||
*c_ptr = {};
|
||||
}
|
||||
|
||||
void clear_tensor_cuda
|
||||
(
|
||||
half* c,
|
||||
int size_m,
|
||||
int size_n
|
||||
)
|
||||
{
|
||||
return;
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = CLEAR_N_SIZE;
|
||||
blockDim.y = 1;
|
||||
gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
|
||||
gridDim.y = size_m;
|
||||
clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
|
||||
}
|
33
autogptq_extension/exllamav2/cuda/q_gemm.cuh
Normal file
33
autogptq_extension/exllamav2/cuda/q_gemm.cuh
Normal file
|
@ -0,0 +1,33 @@
|
|||
#ifndef _q_gemm_cuh
|
||||
#define _q_gemm_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "q_matrix.cuh"
|
||||
|
||||
void gemm_half_q_half_cuda
|
||||
(
|
||||
cublasHandle_t cublas_handle,
|
||||
const half* a,
|
||||
QMatrix* b,
|
||||
half* c,
|
||||
int size_m,
|
||||
int size_n,
|
||||
int size_k,
|
||||
bool clear = false,
|
||||
half* reconstruct = NULL,
|
||||
bool force_cuda = false
|
||||
);
|
||||
|
||||
void clear_tensor_cuda
|
||||
(
|
||||
half* c,
|
||||
int size_m,
|
||||
int size_n
|
||||
);
|
||||
|
||||
#endif
|
484
autogptq_extension/exllamav2/cuda/q_gemm_kernel.cuh
Normal file
484
autogptq_extension/exllamav2/cuda/q_gemm_kernel.cuh
Normal file
|
@ -0,0 +1,484 @@
|
|||
#include "compat.cuh"
|
||||
|
||||
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result, const half qs_h)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half2 dot22_16(half2(&dq)[8], const half* a_ptr, const half2 g_result, const half qs_h)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half2 dot22_32(half2(&dq)[16], const half* a_ptr, const half2 g_result, const half qs_h)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return __hfma2(result, __halves2half2(qs_h, qs_h), g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr, const float g_result, const float qs_f)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||
return fma(result_f, qs_f, g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float dot22_16_f(half2(&dq)[8], const half* a_ptr, const float g_result, const float qs_f)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||
return fma(result_f, qs_f, g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float dot22_32_f(half2(&dq)[16], const half* a_ptr, const float g_result, const float qs_f)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; i += 1) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
float result_f = __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||
return fma(result_f, qs_f, g_result);
|
||||
}
|
||||
|
||||
|
||||
|
||||
typedef void (*fp_gemm_half_q_half_kernel)
|
||||
(
|
||||
const half*,
|
||||
const uint32_t*,
|
||||
const uint32_t*,
|
||||
const half*,
|
||||
half*,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const uint16_t*,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const bool
|
||||
);
|
||||
|
||||
template <bool first_block, int m_count>
|
||||
__global__ void gemm_half_q_half_kernel
|
||||
(
|
||||
const half* __restrict__ a,
|
||||
const uint32_t* __restrict__ b_q_weight,
|
||||
const uint32_t* __restrict__ b_q_scale,
|
||||
const half* __restrict__ b_q_scale_max,
|
||||
half* __restrict__ c,
|
||||
const int size_m,
|
||||
const int size_n,
|
||||
const int size_k,
|
||||
const int groups,
|
||||
const int groupsize,
|
||||
const uint16_t* __restrict__ b_q_perm,
|
||||
const int rows_8,
|
||||
const int rows_6,
|
||||
const int rows_5,
|
||||
const int rows_4,
|
||||
const int rows_3,
|
||||
const int rows_2,
|
||||
const bool clear
|
||||
)
|
||||
{
|
||||
MatrixView_half a_(a, size_m, size_k);
|
||||
MatrixView_half_rw c_(c, size_m, size_n);
|
||||
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
|
||||
|
||||
int t = threadIdx.x;
|
||||
|
||||
// Block
|
||||
|
||||
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
|
||||
int offset_m = blockIdx.y * m_count;
|
||||
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
|
||||
|
||||
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
|
||||
int end_m = min(offset_m + m_count, size_m);
|
||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||
int n = offset_n + t * 4;
|
||||
|
||||
// Preload block_a
|
||||
|
||||
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
|
||||
|
||||
if (offset_k + t < end_k)
|
||||
{
|
||||
for (int m = 0; m < m_count; ++m)
|
||||
{
|
||||
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
|
||||
half* block_a_ptr = block_a[m];
|
||||
half a0 = a_ptr[b_q_perm[offset_k + t]];
|
||||
block_a_ptr[t] = a0;
|
||||
}
|
||||
}
|
||||
|
||||
// Clear
|
||||
|
||||
if (n >= size_n) return;
|
||||
|
||||
if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
|
||||
{
|
||||
for (int m = 0; m < m_count; m++)
|
||||
*((uint64_t*) c_.item_ptr(offset_m + m, n)) = 0;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Find initial group
|
||||
|
||||
int group = offset_k / groupsize;
|
||||
|
||||
// Preload scales
|
||||
|
||||
float scales[MAX_GROUPS_IN_BLOCK][4];
|
||||
|
||||
int groups_in_block = DIVIDE((end_k - offset_k), groupsize);
|
||||
for (int g = 0; g < groups_in_block; g++)
|
||||
{
|
||||
int qscales[4];
|
||||
b_q_scale_.item4(qscales, group + g, n);
|
||||
qscales[0]++;
|
||||
qscales[1]++;
|
||||
qscales[2]++;
|
||||
qscales[3]++;
|
||||
float maxscale = __half2float(b_q_scale_max[group + g]);
|
||||
scales[g][0] = __int2float_rn(qscales[0] * qscales[0]) * maxscale;
|
||||
scales[g][1] = __int2float_rn(qscales[1] * qscales[1]) * maxscale;
|
||||
scales[g][2] = __int2float_rn(qscales[2] * qscales[2]) * maxscale;
|
||||
scales[g][3] = __int2float_rn(qscales[3] * qscales[3]) * maxscale;
|
||||
}
|
||||
|
||||
// a, b offset
|
||||
|
||||
int pre_rows_8 = min(rows_8, offset_k);
|
||||
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
|
||||
int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
|
||||
int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
|
||||
int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
|
||||
int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
|
||||
int qk = 0;
|
||||
qk += pre_rows_8 / 32 * 8;
|
||||
qk += pre_rows_6 / 32 * 6;
|
||||
qk += pre_rows_5 / 32 * 5;
|
||||
qk += pre_rows_4 / 32 * 4;
|
||||
qk += pre_rows_3 / 32 * 3;
|
||||
qk += pre_rows_2 / 32 * 2;
|
||||
|
||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||
const half* a_ptr = &block_a[0][0];
|
||||
int a_stride = BLOCK_KN_SIZE;
|
||||
|
||||
// Initial group
|
||||
|
||||
int scales_idx = 0;
|
||||
float qs_f0 = scales[scales_idx][0];
|
||||
float qs_f1 = scales[scales_idx][1];
|
||||
float qs_f2 = scales[scales_idx][2];
|
||||
float qs_f3 = scales[scales_idx][3];
|
||||
int nextgroup = offset_k + groupsize;
|
||||
|
||||
// Column result
|
||||
|
||||
float block_c[m_count][4] = {};
|
||||
|
||||
// Dequantize groups
|
||||
|
||||
int k = offset_k;
|
||||
|
||||
while (k < rows_8 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
int4 load_int4[2];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][4];
|
||||
dequant_8bit_8(load_int4[0].x, load_int4[1].x, dq[0], size_n);
|
||||
dequant_8bit_8(load_int4[0].y, load_int4[1].y, dq[1], size_n);
|
||||
dequant_8bit_8(load_int4[0].z, load_int4[1].z, dq[2], size_n);
|
||||
dequant_8bit_8(load_int4[0].w, load_int4[1].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
}
|
||||
a_ptr += 8;
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_6 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2; j++)
|
||||
{
|
||||
int4 load_int4[3];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][8];
|
||||
dequant_6bit_16(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
|
||||
dequant_6bit_16(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
|
||||
dequant_6bit_16(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
|
||||
dequant_6bit_16(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
}
|
||||
a_ptr += 16;
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_5 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 1; j++)
|
||||
{
|
||||
int4 load_int4[5];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[3] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[4] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][16];
|
||||
dequant_5bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, load_int4[3].x, load_int4[4].x, dq[0], size_n);
|
||||
dequant_5bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, load_int4[3].y, load_int4[4].y, dq[1], size_n);
|
||||
dequant_5bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, load_int4[3].z, load_int4[4].z, dq[2], size_n);
|
||||
dequant_5bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, load_int4[3].w, load_int4[4].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
}
|
||||
a_ptr += 32;
|
||||
}
|
||||
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_4 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
int4 load_int4[1];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][4];
|
||||
dequant_4bit_8(load_int4[0].x, dq[0], size_n);
|
||||
dequant_4bit_8(load_int4[0].y, dq[1], size_n);
|
||||
dequant_4bit_8(load_int4[0].z, dq[2], size_n);
|
||||
dequant_4bit_8(load_int4[0].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_8_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_8_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_8_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_8_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
}
|
||||
a_ptr += 8;
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_3 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 1; j++)
|
||||
{
|
||||
int4 load_int4[3];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[1] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
load_int4[2] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][16];
|
||||
dequant_3bit_32(load_int4[0].x, load_int4[1].x, load_int4[2].x, dq[0], size_n);
|
||||
dequant_3bit_32(load_int4[0].y, load_int4[1].y, load_int4[2].y, dq[1], size_n);
|
||||
dequant_3bit_32(load_int4[0].z, load_int4[1].z, load_int4[2].z, dq[2], size_n);
|
||||
dequant_3bit_32(load_int4[0].w, load_int4[1].w, load_int4[2].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_32_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_32_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_32_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_32_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
}
|
||||
a_ptr += 32;
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_2 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
scales_idx++;
|
||||
qs_f0 = scales[scales_idx][0];
|
||||
qs_f1 = scales[scales_idx][1];
|
||||
qs_f2 = scales[scales_idx][2];
|
||||
qs_f3 = scales[scales_idx][3];
|
||||
nextgroup += groupsize;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2; j++)
|
||||
{
|
||||
int4 load_int4[1];
|
||||
load_int4[0] = *((int4*) b_ptr); b_ptr += size_n;
|
||||
|
||||
half2 dq[4][8];
|
||||
dequant_2bit_16(load_int4[0].x, dq[0], size_n);
|
||||
dequant_2bit_16(load_int4[0].y, dq[1], size_n);
|
||||
dequant_2bit_16(load_int4[0].z, dq[2], size_n);
|
||||
dequant_2bit_16(load_int4[0].w, dq[3], size_n);
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = dot22_16_f(dq[0], a_ptr + m * a_stride, block_c[m][0], qs_f0);
|
||||
block_c[m][1] = dot22_16_f(dq[1], a_ptr + m * a_stride, block_c[m][1], qs_f1);
|
||||
block_c[m][2] = dot22_16_f(dq[2], a_ptr + m * a_stride, block_c[m][2], qs_f2);
|
||||
block_c[m][3] = dot22_16_f(dq[3], a_ptr + m * a_stride, block_c[m][3], qs_f3);
|
||||
}
|
||||
|
||||
a_ptr += 16;
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
// Accumulate column sums in c
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
half2* out = (half2*)c_.item_ptr(offset_m + m, n);
|
||||
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
|
||||
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
|
||||
atomicAdd(out , result01);
|
||||
atomicAdd(out + 1, result23);
|
||||
}
|
||||
}
|
||||
|
||||
fp_gemm_half_q_half_kernel pick_gemm_half_q_half_kernel(bool first_block, const int m_count)
|
||||
{
|
||||
#if BLOCK_M_SIZE_MAX >= 1
|
||||
if (m_count == 1) return gemm_half_q_half_kernel<true, 1>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 2
|
||||
if (m_count == 2) return gemm_half_q_half_kernel<true, 2>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 3
|
||||
if (m_count == 3) return gemm_half_q_half_kernel<true, 3>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 4
|
||||
if (m_count == 4) return gemm_half_q_half_kernel<true, 4>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 5
|
||||
if (m_count == 5) return gemm_half_q_half_kernel<true, 5>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 6
|
||||
if (m_count == 6) return gemm_half_q_half_kernel<true, 6>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 7
|
||||
if (m_count == 7) return gemm_half_q_half_kernel<true, 7>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 8
|
||||
if (m_count == 8) return gemm_half_q_half_kernel<true, 8>;
|
||||
#endif
|
||||
return NULL;
|
||||
}
|
219
autogptq_extension/exllamav2/cuda/q_gemm_kernel_gptq.cuh
Normal file
219
autogptq_extension/exllamav2/cuda/q_gemm_kernel_gptq.cuh
Normal file
|
@ -0,0 +1,219 @@
|
|||
#include "compat.cuh"
|
||||
|
||||
__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return __hadd2(result, g_result);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
|
||||
{
|
||||
half2 result = {};
|
||||
const half2* a2_ptr = (const half2*)a_ptr;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
|
||||
return __half2float(__low2half(result)) + __half2float(__high2half(result));
|
||||
}
|
||||
|
||||
typedef void (*fp_gemm_half_q_half_gptq_kernel)
|
||||
(
|
||||
const half*,
|
||||
const uint32_t*,
|
||||
const uint32_t*,
|
||||
const half*,
|
||||
half*,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const int,
|
||||
const uint16_t*,
|
||||
const int,
|
||||
const bool
|
||||
);
|
||||
|
||||
template <bool first_block, int m_count>
|
||||
__global__ void gemm_half_q_half_gptq_kernel
|
||||
(
|
||||
const half* __restrict__ a,
|
||||
const uint32_t* __restrict__ b_q_weight,
|
||||
const uint32_t* __restrict__ b_gptq_qzeros,
|
||||
const half* __restrict__ b_gptq_scales,
|
||||
half* __restrict__ c,
|
||||
const int size_m,
|
||||
const int size_n,
|
||||
const int size_k,
|
||||
const int groups,
|
||||
const int groupsize,
|
||||
const uint16_t* __restrict__ b_q_perm,
|
||||
const int rows_4,
|
||||
const bool clear
|
||||
)
|
||||
{
|
||||
MatrixView_half a_(a, size_m, size_k);
|
||||
MatrixView_half_rw c_(c, size_m, size_n);
|
||||
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
|
||||
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
|
||||
|
||||
int t = threadIdx.x;
|
||||
|
||||
// Block
|
||||
|
||||
int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
|
||||
int offset_m = blockIdx.y * m_count;
|
||||
int offset_k = blockIdx.z * BLOCK_KN_SIZE;
|
||||
|
||||
int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
|
||||
int end_m = min(offset_m + m_count, size_m);
|
||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||
|
||||
int n = offset_n + t * 4;
|
||||
|
||||
// Preload block_a
|
||||
|
||||
__shared__ half block_a[m_count][BLOCK_KN_SIZE];
|
||||
|
||||
if (offset_k + t < end_k)
|
||||
{
|
||||
for (int m = 0; m < m_count; ++m)
|
||||
{
|
||||
const half* a_ptr = a_.item_ptr(offset_m + m, 0);
|
||||
half* block_a_ptr = block_a[m];
|
||||
|
||||
half a0;
|
||||
if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
|
||||
else a0 = a_ptr[offset_k + t];
|
||||
block_a_ptr[t] = a0;
|
||||
}
|
||||
}
|
||||
|
||||
// Zero output
|
||||
|
||||
if (n >= size_n) return;
|
||||
|
||||
if (clear && blockIdx.z == 0) // && (threadIdx.x & 1) == 0)
|
||||
{
|
||||
for (int m = 0; m < m_count; m++)
|
||||
*((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Find initial group
|
||||
|
||||
int group = offset_k / groupsize;
|
||||
int nextgroup = offset_k + groupsize;
|
||||
|
||||
// a, b offset
|
||||
|
||||
int qk = offset_k / (32 / 4);
|
||||
|
||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||
const half* a_ptr = &block_a[0][0];
|
||||
int a_stride = BLOCK_KN_SIZE;
|
||||
|
||||
// Initial group
|
||||
|
||||
int zeros[4];
|
||||
float scales[4];
|
||||
half2 z1z16[4][2];
|
||||
half2 y1y16[4][2];
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_f(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||
|
||||
// __syncthreads();
|
||||
|
||||
// Column result
|
||||
|
||||
float block_c[m_count][4] = {};
|
||||
|
||||
// Dequantize and multiply
|
||||
|
||||
int k = offset_k;
|
||||
while (k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
nextgroup += groupsize;
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_f(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
const int4* b_ptr4 = (int4*) b_ptr;
|
||||
int4 load_int4 = *b_ptr4;
|
||||
|
||||
half2 dq[4][4];
|
||||
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
|
||||
|
||||
#pragma unroll
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
|
||||
block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
|
||||
block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
|
||||
block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
|
||||
}
|
||||
|
||||
b_ptr += size_n;
|
||||
a_ptr += 8;
|
||||
}
|
||||
|
||||
k += 32;
|
||||
}
|
||||
|
||||
for (int m = 0; m < m_count; m++)
|
||||
{
|
||||
half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
|
||||
half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
|
||||
half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
|
||||
atomicAdd(out , result01);
|
||||
atomicAdd(out + 1, result23);
|
||||
}
|
||||
}
|
||||
|
||||
fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
|
||||
{
|
||||
#if BLOCK_M_SIZE_MAX >= 1
|
||||
if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 2
|
||||
if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 3
|
||||
if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 4
|
||||
if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 5
|
||||
if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 6
|
||||
if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 7
|
||||
if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
|
||||
#endif
|
||||
#if BLOCK_M_SIZE_MAX >= 8
|
||||
if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
|
||||
#endif
|
||||
return NULL;
|
||||
}
|
603
autogptq_extension/exllamav2/cuda/q_matrix.cu
Normal file
603
autogptq_extension/exllamav2/cuda/q_matrix.cu
Normal file
|
@ -0,0 +1,603 @@
|
|||
#include "q_matrix.cuh"
|
||||
#include "matrix_view.cuh"
|
||||
#include "util.cuh"
|
||||
|
||||
#include "quant/qdq_2.cuh"
|
||||
#include "quant/qdq_3.cuh"
|
||||
#include "quant/qdq_4.cuh"
|
||||
#include "quant/qdq_5.cuh"
|
||||
#include "quant/qdq_6.cuh"
|
||||
#include "quant/qdq_8.cuh"
|
||||
|
||||
#define BLOCK_KN_SIZE 128
|
||||
|
||||
#define THREADS_X 32
|
||||
#define THREADS_Y 32
|
||||
|
||||
// Shuffle quantized data on load
|
||||
|
||||
__global__ void shuffle_kernel
|
||||
(
|
||||
uint32_t* __restrict__ b_q_weight,
|
||||
const int size_k,
|
||||
const int size_n,
|
||||
const int rows_8,
|
||||
const int rows_6,
|
||||
const int rows_5,
|
||||
const int rows_4,
|
||||
const int rows_3,
|
||||
const int rows_2
|
||||
)
|
||||
{
|
||||
int n = blockIdx.x * THREADS_X + threadIdx.x;
|
||||
if (n >= size_n) return;
|
||||
int k = 0;
|
||||
uint32_t* b_ptr = b_q_weight + n;
|
||||
while (k < rows_8) { shuffle_8bit_4 (b_ptr, size_n); b_ptr += 1 * size_n; k += 4; }
|
||||
while (k < rows_6) { shuffle_6bit_16(b_ptr, size_n); b_ptr += 3 * size_n; k += 16; }
|
||||
while (k < rows_5) { shuffle_5bit_32(b_ptr, size_n); b_ptr += 5 * size_n; k += 32; }
|
||||
while (k < rows_4) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; }
|
||||
while (k < rows_3) { shuffle_3bit_32(b_ptr, size_n); b_ptr += 3 * size_n; k += 32; }
|
||||
while (k < rows_2) { shuffle_2bit_16(b_ptr, size_n); b_ptr += 1 * size_n; k += 16; }
|
||||
}
|
||||
|
||||
|
||||
// QMatrix constructor
|
||||
|
||||
QMatrix::QMatrix
|
||||
(
|
||||
const int _device,
|
||||
const int _height,
|
||||
const int _width,
|
||||
const int _groups,
|
||||
|
||||
uint32_t* _q_weight,
|
||||
uint16_t* _q_perm,
|
||||
uint16_t* _q_invperm,
|
||||
uint32_t* _q_scale,
|
||||
half* _q_scale_max,
|
||||
uint16_t* _q_groups,
|
||||
|
||||
uint32_t* _gptq_qzeros,
|
||||
half* _gptq_scales,
|
||||
uint32_t* _gptq_g_idx,
|
||||
|
||||
half* _temp_dq
|
||||
) :
|
||||
device(_device),
|
||||
height(_height),
|
||||
width(_width),
|
||||
groups(_groups),
|
||||
temp_dq(_temp_dq)
|
||||
{
|
||||
cudaSetDevice(device);
|
||||
|
||||
cuda_q_weight = _q_weight;
|
||||
cuda_q_perm = _q_perm;
|
||||
cuda_q_invperm = _q_invperm;
|
||||
cuda_q_scale = _q_scale;
|
||||
cuda_q_scale_max = _q_scale_max;
|
||||
cuda_q_groups = _q_groups;
|
||||
cuda_gptq_qzeros = _gptq_qzeros;
|
||||
cuda_gptq_scales = _gptq_scales;
|
||||
|
||||
is_gptq = (_gptq_qzeros != NULL);
|
||||
|
||||
groupsize = 1;
|
||||
while (groupsize * groups < height) groupsize *= 2;
|
||||
|
||||
// Create group map
|
||||
|
||||
rows_8 = 0;
|
||||
rows_6 = 0;
|
||||
rows_5 = 0;
|
||||
rows_4 = 0;
|
||||
rows_3 = 0;
|
||||
rows_2 = 0;
|
||||
|
||||
if (!is_gptq)
|
||||
{
|
||||
uint16_t* cpu_q_groups = (uint16_t*)calloc(groups * 2, sizeof(uint16_t));
|
||||
cudaMemcpy(cpu_q_groups, cuda_q_groups, groups * 2 * sizeof(uint16_t), cudaMemcpyDeviceToHost);
|
||||
|
||||
for (int i = 0; i < groups; i++)
|
||||
{
|
||||
int bits = cpu_q_groups[i * 2];
|
||||
if (bits == 8) rows_8 += groupsize;
|
||||
if (bits == 6) rows_6 += groupsize;
|
||||
if (bits == 5) rows_5 += groupsize;
|
||||
if (bits == 4) rows_4 += groupsize;
|
||||
if (bits == 3) rows_3 += groupsize;
|
||||
if (bits == 2) rows_2 += groupsize;
|
||||
}
|
||||
|
||||
free(cpu_q_groups);
|
||||
|
||||
rows_6 += rows_8;
|
||||
rows_5 += rows_6;
|
||||
rows_4 += rows_5;
|
||||
rows_3 += rows_4;
|
||||
rows_2 += rows_3;
|
||||
}
|
||||
else
|
||||
{
|
||||
rows_4 = height;
|
||||
rows_3 = height;
|
||||
rows_2 = height;
|
||||
|
||||
if (_gptq_g_idx) make_sequential(_gptq_g_idx);
|
||||
}
|
||||
|
||||
// Shuffle quantized data
|
||||
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = THREADS_X;
|
||||
blockDim.y = 1;
|
||||
gridDim.x = DIVIDE(width, THREADS_X);
|
||||
gridDim.y = 1;
|
||||
|
||||
shuffle_kernel<<<gridDim, blockDim>>>(cuda_q_weight, height, width, rows_8, rows_6, rows_5, rows_4, rows_3, rows_2);
|
||||
}
|
||||
|
||||
|
||||
// Reconstruct b[k,n] (GPTQ)
|
||||
|
||||
__global__ void reconstruct_gptq_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ b_q_weight,
|
||||
const uint16_t* __restrict__ b_q_perm,
|
||||
const uint32_t* __restrict__ b_gptq_qzeros,
|
||||
const half* __restrict__ b_gptq_scales,
|
||||
//const uint16_t* __restrict__ b_q_groups,
|
||||
const int size_k,
|
||||
const int size_n,
|
||||
const int groupsize,
|
||||
const int groups,
|
||||
half* __restrict__ b,
|
||||
const int rows_4
|
||||
)
|
||||
{
|
||||
MatrixView_half_rw b_(b, size_k, size_n);
|
||||
MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
|
||||
MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);
|
||||
|
||||
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
|
||||
int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;
|
||||
|
||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||
|
||||
// Preload remapping table
|
||||
|
||||
__shared__ uint16_t perm[BLOCK_KN_SIZE];
|
||||
int t = threadIdx.x;
|
||||
|
||||
if (b_q_perm)
|
||||
{
|
||||
if (offset_k + t < size_k)
|
||||
perm[t] = b_q_perm[offset_k + t];
|
||||
}
|
||||
|
||||
// Column
|
||||
|
||||
int n = offset_n + t * 4;
|
||||
if (n >= size_n) return;
|
||||
|
||||
// Find initial group
|
||||
|
||||
int group = offset_k / groupsize;
|
||||
int nextgroup = offset_k + groupsize;
|
||||
|
||||
// b offset
|
||||
|
||||
int qk = offset_k / (32 / 4);
|
||||
|
||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||
|
||||
// Initial zeros/scale
|
||||
|
||||
int zeros[4];
|
||||
half2 scales[4];
|
||||
half2 z1z16[4][2];
|
||||
half2 y1y16[4][2];
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_h2(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
int k = offset_k;
|
||||
int lk = 0;
|
||||
|
||||
while (k < end_k)
|
||||
{
|
||||
if (k == nextgroup)
|
||||
{
|
||||
group++;
|
||||
nextgroup += groupsize;
|
||||
b_gptq_qzeros_.item4(zeros, group, n);
|
||||
b_gptq_scales_.item4_h2(scales, group, n);
|
||||
dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
|
||||
dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
|
||||
dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
|
||||
dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
|
||||
}
|
||||
|
||||
for (int p = 0; p < 4; p++)
|
||||
{
|
||||
half2 dq[4][4];
|
||||
const int4* b_ptr4 = (int4*) b_ptr;
|
||||
int4 load_int4 = *b_ptr4;
|
||||
|
||||
dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
|
||||
dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);
|
||||
|
||||
b_ptr += size_n;
|
||||
//half* dqh = (half*)dq;
|
||||
if (b_q_perm)
|
||||
{
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
|
||||
b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
|
||||
b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for (int j = 0; j < 4; j++)
|
||||
{
|
||||
for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
|
||||
b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
|
||||
b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
|
||||
}
|
||||
}
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Reconstruct b[k,n]
|
||||
|
||||
__global__ void reconstruct_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ b_q_weight,
|
||||
const uint16_t* __restrict__ b_q_perm,
|
||||
const uint32_t* __restrict__ b_q_scale,
|
||||
const half* __restrict__ b_q_scale_max,
|
||||
//const uint16_t* __restrict__ b_q_groups,
|
||||
const int size_k,
|
||||
const int size_n,
|
||||
const int groupsize,
|
||||
const int groups,
|
||||
half* __restrict__ b,
|
||||
const int rows_8,
|
||||
const int rows_6,
|
||||
const int rows_5,
|
||||
const int rows_4,
|
||||
const int rows_3,
|
||||
const int rows_2
|
||||
)
|
||||
{
|
||||
MatrixView_half_rw b_(b, size_k, size_n);
|
||||
MatrixView_q4_row b_q_scale_(b_q_scale, groups, size_n);
|
||||
|
||||
int offset_k = BLOCK_KN_SIZE * blockIdx.y;
|
||||
int offset_n = BLOCK_KN_SIZE * blockIdx.x;
|
||||
|
||||
// Preload remapping table
|
||||
|
||||
int t = threadIdx.x;
|
||||
__shared__ uint16_t perm[BLOCK_KN_SIZE];
|
||||
if (offset_k + t < size_k)
|
||||
perm[t] = b_q_perm[offset_k + t];
|
||||
|
||||
// Column
|
||||
|
||||
int n = offset_n + t;
|
||||
if (n >= size_n) return;
|
||||
|
||||
// Find initial group
|
||||
|
||||
int group = offset_k / groupsize;
|
||||
|
||||
int pre_rows_8 = min(rows_8, offset_k);
|
||||
int pre_rows_6 = offset_k > rows_8 ? min(rows_6, offset_k) - rows_8 : 0;
|
||||
int pre_rows_5 = offset_k > rows_6 ? min(rows_5, offset_k) - rows_6 : 0;
|
||||
int pre_rows_4 = offset_k > rows_5 ? min(rows_4, offset_k) - rows_5 : 0;
|
||||
int pre_rows_3 = offset_k > rows_4 ? min(rows_3, offset_k) - rows_4 : 0;
|
||||
int pre_rows_2 = offset_k > rows_3 ? min(rows_2, offset_k) - rows_3 : 0;
|
||||
int qk = 0;
|
||||
qk += pre_rows_8 / 32 * 8;
|
||||
qk += pre_rows_6 / 32 * 6;
|
||||
qk += pre_rows_5 / 32 * 5;
|
||||
qk += pre_rows_4 / 32 * 4;
|
||||
qk += pre_rows_3 / 32 * 3;
|
||||
qk += pre_rows_2 / 32 * 2;
|
||||
|
||||
const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
|
||||
|
||||
half qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]);
|
||||
half2 qs_h2 = __halves2half2(qs_h, qs_h);
|
||||
int nextgroup = offset_k + groupsize;
|
||||
|
||||
int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);
|
||||
int k = offset_k;
|
||||
int lk = 0;
|
||||
|
||||
__syncthreads();
|
||||
|
||||
while (k < rows_8 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 4; p++)
|
||||
{
|
||||
half2 dq[4];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
||||
dequant_8bit_8(q_0, q_1, dq, size_n);
|
||||
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_6 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 2; p++)
|
||||
{
|
||||
half2 dq[8];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_2 = *b_ptr; b_ptr += size_n;
|
||||
dequant_6bit_16(q_0, q_1, q_2, dq, size_n);
|
||||
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_5 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 1; p++)
|
||||
{
|
||||
half2 dq[16];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_2 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_3 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_4 = *b_ptr; b_ptr += size_n;
|
||||
dequant_5bit_32(q_0, q_1, q_2, q_3, q_4, dq, size_n);
|
||||
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_4 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 4; p++)
|
||||
{
|
||||
half2 dq[4];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
dequant_4bit_8(q_0, dq, size_n);
|
||||
for (int j = 0; j < 4; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 8; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_3 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 1; p++)
|
||||
{
|
||||
half2 dq[16];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_1 = *b_ptr; b_ptr += size_n;
|
||||
uint32_t q_2 = *b_ptr; b_ptr += size_n;
|
||||
dequant_3bit_32(q_0, q_1, q_2, dq, size_n);
|
||||
for (int j = 0; j < 16; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 32; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
|
||||
while (k < rows_2 && k < end_k)
|
||||
{
|
||||
if (k == nextgroup) { group++; qs_h = dq_scale(b_q_scale_.item(group, n), b_q_scale_max[group]); nextgroup += groupsize; qs_h2 = __halves2half2(qs_h, qs_h); }
|
||||
for (int p = 0; p < 2; p++)
|
||||
{
|
||||
half2 dq[8];
|
||||
uint32_t q_0 = *b_ptr; b_ptr += size_n;
|
||||
dequant_2bit_16(q_0, dq, size_n);
|
||||
for (int j = 0; j < 8; j++) dq[j] = __hmul2(dq[j], qs_h2);
|
||||
half* dqh = (half*) dq;
|
||||
for (int j = 0; j < 16; j++) b_.set(perm[lk++], n, dqh[j]);
|
||||
}
|
||||
k += 32;
|
||||
}
|
||||
}
|
||||
|
||||
void QMatrix::reconstruct(half* out)
|
||||
{
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = BLOCK_KN_SIZE;
|
||||
blockDim.y = 1;
|
||||
gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
|
||||
gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
|
||||
|
||||
if (!is_gptq)
|
||||
{
|
||||
reconstruct_kernel<<<gridDim, blockDim>>>
|
||||
(
|
||||
cuda_q_weight,
|
||||
cuda_q_perm,
|
||||
cuda_q_scale,
|
||||
cuda_q_scale_max,
|
||||
//cuda_q_groups,
|
||||
height,
|
||||
width,
|
||||
groupsize,
|
||||
groups,
|
||||
out,
|
||||
rows_8,
|
||||
rows_6,
|
||||
rows_5,
|
||||
rows_4,
|
||||
rows_3,
|
||||
rows_2
|
||||
);
|
||||
}
|
||||
else
|
||||
{
|
||||
reconstruct_gptq_kernel<<<gridDim, blockDim>>>
|
||||
(
|
||||
cuda_q_weight,
|
||||
cuda_q_perm,
|
||||
cuda_gptq_qzeros,
|
||||
cuda_gptq_scales,
|
||||
//const uint16_t* __restrict__ b_q_groups,
|
||||
height,
|
||||
width,
|
||||
groupsize,
|
||||
groups,
|
||||
out,
|
||||
rows_4
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void make_sequential_kernel
|
||||
(
|
||||
const uint32_t* __restrict__ w,
|
||||
uint32_t* __restrict__ w_new,
|
||||
const uint16_t* __restrict__ q_perm,
|
||||
const int w_height,
|
||||
const int w_width
|
||||
)
|
||||
{
|
||||
const uint64_t* w2 = (uint64_t*) w;
|
||||
uint64_t* w_new2 = (uint64_t*) w_new;
|
||||
int w2_stride = w_width >> 1;
|
||||
|
||||
int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
|
||||
if (w2_column >= w2_stride) return;
|
||||
|
||||
int w_new2_row = blockIdx.y;
|
||||
|
||||
int q_perm_idx = w_new2_row << 3;
|
||||
|
||||
uint64_t dst = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++)
|
||||
{
|
||||
int source_row = q_perm[q_perm_idx++];
|
||||
|
||||
int w2_row = source_row >> 3;
|
||||
int w2_subrow = source_row & 0x07;
|
||||
int w2_row_shift = w2_subrow << 2;
|
||||
int wnew2_row_shift = i << 2;
|
||||
|
||||
uint64_t src = w2[w2_row * w2_stride + w2_column];
|
||||
src >>= w2_row_shift;
|
||||
src &= 0x0000000f0000000f;
|
||||
src <<= wnew2_row_shift;
|
||||
dst |= src;
|
||||
}
|
||||
|
||||
w_new2[w_new2_row * w2_stride + w2_column] = dst;
|
||||
}
|
||||
|
||||
void QMatrix::make_sequential(const uint32_t* cpu_g_idx)
|
||||
{
|
||||
uint32_t* cuda_new_qweight = NULL;
|
||||
cudaMalloc(&cuda_new_qweight, height / 8 * width * sizeof(uint32_t));
|
||||
|
||||
uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t));
|
||||
uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t));
|
||||
uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t));
|
||||
|
||||
// Group histogram
|
||||
|
||||
for (int i = 0; i < height; i++) cpu_g_idx_map[cpu_g_idx[i]]++;
|
||||
|
||||
// Group map
|
||||
|
||||
for (int i = 0, acc = 0; i < groups; i++)
|
||||
{
|
||||
short tmp = cpu_g_idx_map[i];
|
||||
cpu_g_idx_map[i] = acc;
|
||||
acc += tmp;
|
||||
}
|
||||
|
||||
// X map (inverse)
|
||||
|
||||
for (int row = 0; row < height; row++)
|
||||
{
|
||||
uint32_t target_group = cpu_g_idx[row];
|
||||
uint32_t target_row = cpu_g_idx_map[target_group];
|
||||
cpu_g_idx_map[target_group]++;
|
||||
cpu_x_map_inv[row] = target_row;
|
||||
}
|
||||
|
||||
// X map
|
||||
|
||||
for (int row = 0; row < height; row++) cpu_x_map[cpu_x_map_inv[row]] = row;
|
||||
|
||||
// Reduce to uint16_t
|
||||
|
||||
uint16_t* cpu_x_map16 = (uint16_t*)cpu_x_map;
|
||||
uint16_t* cpu_x_map_inv16 = (uint16_t*)cpu_x_map_inv;
|
||||
for (int row = 0; row < height; row++) cpu_x_map16[row] = (uint16_t) cpu_x_map[row];
|
||||
for (int row = 0; row < height; row++) cpu_x_map_inv16[row] = (uint16_t) cpu_x_map_inv[row];
|
||||
|
||||
// Move to CUDA
|
||||
|
||||
cudaMemcpyAsync(cuda_q_perm, cpu_x_map16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
|
||||
cudaMemcpyAsync(cuda_q_invperm, cpu_x_map_inv16, height * sizeof(uint16_t), cudaMemcpyHostToDevice);
|
||||
|
||||
// Rearrange rows in w
|
||||
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = THREADS_X;
|
||||
blockDim.y = 1;
|
||||
gridDim.x = DIVIDE(width, THREADS_X);
|
||||
gridDim.y = height / 8;
|
||||
|
||||
make_sequential_kernel<<<gridDim, blockDim>>>
|
||||
(
|
||||
cuda_q_weight,
|
||||
cuda_new_qweight,
|
||||
cuda_q_perm,
|
||||
height / 8,
|
||||
width
|
||||
);
|
||||
|
||||
// Replace qweights
|
||||
|
||||
cudaMemcpyAsync(cuda_q_weight, cuda_new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
|
||||
|
||||
// Cleanup
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
cudaFree(cuda_new_qweight);
|
||||
free(cpu_g_idx_map);
|
||||
free(cpu_x_map);
|
||||
free(cpu_x_map_inv);
|
||||
}
|
71
autogptq_extension/exllamav2/cuda/q_matrix.cuh
Normal file
71
autogptq_extension/exllamav2/cuda/q_matrix.cuh
Normal file
|
@ -0,0 +1,71 @@
|
|||
#ifndef _q_matrix_cuh
|
||||
#define _q_matrix_cuh
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
#define MAX_SUPERGROUPS 16
|
||||
|
||||
class QMatrix
|
||||
{
|
||||
public:
|
||||
|
||||
int device;
|
||||
bool is_gptq;
|
||||
|
||||
int height;
|
||||
int width;
|
||||
int groups;
|
||||
int groupsize;
|
||||
|
||||
int rows_8;
|
||||
int rows_6;
|
||||
int rows_5;
|
||||
int rows_4;
|
||||
int rows_3;
|
||||
int rows_2;
|
||||
|
||||
uint32_t* cuda_q_weight = NULL;
|
||||
uint16_t* cuda_q_perm = NULL;
|
||||
uint16_t* cuda_q_invperm = NULL;
|
||||
uint32_t* cuda_q_scale = NULL;
|
||||
half* cuda_q_scale_max = NULL;
|
||||
uint16_t* cuda_q_groups = NULL;
|
||||
uint32_t* cuda_gptq_qzeros = NULL;
|
||||
half* cuda_gptq_scales = NULL;
|
||||
|
||||
half* temp_dq;
|
||||
|
||||
QMatrix
|
||||
(
|
||||
const int _device,
|
||||
const int _height,
|
||||
const int _width,
|
||||
const int _groups,
|
||||
|
||||
uint32_t* _q_weight,
|
||||
uint16_t* _q_perm,
|
||||
uint16_t* _q_invperm,
|
||||
uint32_t* _q_scale,
|
||||
half* _q_scale_max,
|
||||
uint16_t* _q_groups,
|
||||
|
||||
uint32_t* _gptq_qzeros,
|
||||
half* _gptq_scales,
|
||||
uint32_t* _gptq_g_idx,
|
||||
|
||||
half* _temp_dq
|
||||
);
|
||||
|
||||
~QMatrix();
|
||||
|
||||
void reconstruct(half* out);
|
||||
void make_sequential(const uint32_t* cpu_g_idx);
|
||||
|
||||
private:
|
||||
|
||||
};
|
||||
|
||||
#endif
|
103
autogptq_extension/exllamav2/cuda/quant/qdq_2.cuh
Normal file
103
autogptq_extension/exllamav2/cuda/quant/qdq_2.cuh
Normal file
|
@ -0,0 +1,103 @@
|
|||
#ifndef _qdq_2_cuh
|
||||
#define _qdq_2_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
#include "../../config.h"
|
||||
|
||||
#if QMODE_2BIT == 1
|
||||
|
||||
// Permutation:
|
||||
//
|
||||
// ffddbb99 77553311 eeccaa88 66442200
|
||||
|
||||
__forceinline__ __device__ void shuffle_2bit_16
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
uint32_t qa = q[0];
|
||||
uint32_t qb = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++)
|
||||
{
|
||||
uint32_t qa0 = qa & 0x03;
|
||||
uint32_t qa1 = (qa & 0x0c) >> 2;
|
||||
qa >>= 4;
|
||||
qb |= (qa1 << (i * 2 + 16));
|
||||
qb |= (qa0 << (i * 2));
|
||||
}
|
||||
q[0] = qb;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_2bit_16
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[8],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y4_ = __float2half_rn(1.0f / 4.0f);
|
||||
const half y16_ = __float2half_rn(1.0f / 16.0f);
|
||||
const half y64_ = __float2half_rn(1.0f / 64.0f);
|
||||
const half2 y4 = __halves2half2(y4_, y4_);
|
||||
const half2 y16 = __halves2half2(y16_, y16_);
|
||||
const half2 y64 = __halves2half2(y64_, y64_);
|
||||
const half z1_ = __float2half_rn(-1024.0f - 2.0f);
|
||||
const half z4_ = __float2half_rn(-1024.0f / 4.0f - 2.0f);
|
||||
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 2.0f);
|
||||
const half z64_ = __float2half_rn(-1024.0f / 64.0f - 2.0f);
|
||||
const half2 z1 = __halves2half2(z1_, z1_);
|
||||
const half2 z4 = __halves2half2(z4_, z4_);
|
||||
const half2 z16 = __halves2half2(z16_, z16_);
|
||||
const half2 z64 = __halves2half2(z64_, z64_);
|
||||
|
||||
uint32_t qa = q_0;
|
||||
half2_uint32 q0((qa & 0x00030003) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||
half2_uint32 q1((qa & 0x000c000c) | c0); // half2(q[ 2], q[ 3]) * 4 + 1024
|
||||
half2_uint32 q2((qa & 0x00300030) | c0); // half2(q[ 4], q[ 5]) * 16 + 1024
|
||||
half2_uint32 q3((qa & 0x00c000c0) | c0); // half2(q[ 6], q[ 7]) * 64 + 1024
|
||||
qa >>= 8;
|
||||
half2_uint32 q4((qa & 0x00030003) | c0); // half2(q[ 8], q[ 8]) + 1024
|
||||
half2_uint32 q5((qa & 0x000c000c) | c0); // half2(q[10], q[11]) * 4 + 1024
|
||||
half2_uint32 q6((qa & 0x00300030) | c0); // half2(q[12], q[13]) * 16 + 1024
|
||||
half2_uint32 q7((qa & 0x00c000c0) | c0); // half2(q[14], q[15]) * 64 + 1024
|
||||
|
||||
dq[0] = __hadd2(q0.as_half2, z1);
|
||||
dq[1] = __hfma2(q1.as_half2, y4, z4);
|
||||
dq[2] = __hfma2(q2.as_half2, y16, z16);
|
||||
dq[3] = __hfma2(q3.as_half2, y64, z64);
|
||||
dq[4] = __hadd2(q4.as_half2, z1);
|
||||
dq[5] = __hfma2(q5.as_half2, y4, z4);
|
||||
dq[6] = __hfma2(q6.as_half2, y16, z16);
|
||||
dq[7] = __hfma2(q7.as_half2, y64, z64);
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
__forceinline__ __device__ void shuffle_2bit_16
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_2bit_16
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[8],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
half dqh[16];
|
||||
for (int i = 0; i < 16; i++) dqh[i] = dq_ns(exb(q_0, i * 2, 0x03), 2);
|
||||
|
||||
for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
169
autogptq_extension/exllamav2/cuda/quant/qdq_3.cuh
Normal file
169
autogptq_extension/exllamav2/cuda/quant/qdq_3.cuh
Normal file
|
@ -0,0 +1,169 @@
|
|||
#ifndef _qdq_3_cuh
|
||||
#define _qdq_3_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
#include "../../config.h"
|
||||
|
||||
#if QMODE_3BIT == 1
|
||||
|
||||
// Permutation:
|
||||
//
|
||||
// v9997775 55333111 u8886664 44222000 (u, v lsb)
|
||||
// vjjjhhhf ffdddbbb uiiiggge eecccaaa
|
||||
// vtttrrrp ppnnnlll usssqqqo oommmkkk
|
||||
|
||||
__forceinline__ __device__ void shuffle_3bit_32
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
uint32_t qa = q[0 * stride];
|
||||
uint32_t qb = q[1 * stride];
|
||||
uint32_t qc = q[2 * stride];
|
||||
|
||||
// qa: aa999888 77766655 54443332 22111000
|
||||
// qb: lkkkjjji iihhhggg fffeeedd dcccbbba
|
||||
// qc: vvvuuutt tsssrrrq qqpppooo nnnmmmll
|
||||
|
||||
uint32_t qd = qc >> 26;
|
||||
qc <<= 4;
|
||||
qc |= qb >> 28;
|
||||
qb <<= 2;
|
||||
qb |= qa >> 30;
|
||||
|
||||
// qa: ..999888 77766655 54443332 22111000
|
||||
// qb: ..jjjiii hhhgggff feeedddc ccbbbaaa
|
||||
// qc: ..tttsss rrrqqqpp pooonnnm mmlllkkk
|
||||
// qd: vvvuuu
|
||||
|
||||
uint32_t za = 0;
|
||||
uint32_t zb = 0;
|
||||
uint32_t zc = 0;
|
||||
|
||||
for (int i = 0; i < 5; i++) { uint32_t t0 = qa & 0x07; uint32_t t1 = (qa & 0x38) >> 3; qa >>= 6; za |= (t0 << (i * 3)); za |= (t1 << (i * 3 + 16)); }
|
||||
for (int i = 0; i < 5; i++) { uint32_t t0 = qb & 0x07; uint32_t t1 = (qb & 0x38) >> 3; qb >>= 6; zb |= (t0 << (i * 3)); zb |= (t1 << (i * 3 + 16)); }
|
||||
for (int i = 0; i < 5; i++) { uint32_t t0 = qc & 0x07; uint32_t t1 = (qc & 0x38) >> 3; qc >>= 6; zc |= (t0 << (i * 3)); zc |= (t1 << (i * 3 + 16)); }
|
||||
|
||||
// za: 9997775 55333111 8886664 44222000
|
||||
// zb: jjjhhhf ffdddbbb iiiggge eecccaaa
|
||||
// zc: tttrrrp ppnnnlll sssqqqo oommmkkk
|
||||
// qd: vvvuuu
|
||||
|
||||
za |= ((qd & 0x01) >> 0) << 15;
|
||||
zb |= ((qd & 0x02) >> 1) << 15;
|
||||
zc |= ((qd & 0x04) >> 2) << 15;
|
||||
za |= ((qd & 0x08) >> 3) << 31;
|
||||
zb |= ((qd & 0x10) >> 4) << 31;
|
||||
zc |= ((qd & 0x20) >> 5) << 31;
|
||||
|
||||
// za: v9997775 55333111 u8886664 44222000 (u, v lsb)
|
||||
// zb: vjjjhhhf ffdddbbb uiiiggge eecccaaa
|
||||
// zc: vtttrrrp ppnnnlll usssqqqo oommmkkk
|
||||
|
||||
q[0 * stride] = za;
|
||||
q[1 * stride] = zb;
|
||||
q[2 * stride] = zc;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_3bit_32
|
||||
(
|
||||
const uint32_t q_0,
|
||||
const uint32_t q_1,
|
||||
const uint32_t q_2,
|
||||
half2 (&dq)[16],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y8_ = __float2half_rn(1.0f / 8.0f);
|
||||
const half y64_ = __float2half_rn(1.0f / 64.0f);
|
||||
const half2 y8 = __halves2half2(y8_, y8_);
|
||||
const half2 y64 = __halves2half2(y64_, y64_);
|
||||
const half z1_ = __float2half_rn(-1024.0f - 4.0f);
|
||||
const half z8_ = __float2half_rn(-1024.0f / 8.0f - 4.0f);
|
||||
const half z64_ = __float2half_rn(-1024.0f / 64.0f - 4.0f);
|
||||
const half2 z1 = __halves2half2(z1_, z1_);
|
||||
const half2 z8 = __halves2half2(z8_, z8_);
|
||||
const half2 z64 = __halves2half2(z64_, z64_);
|
||||
|
||||
uint32_t qa = q_0;
|
||||
uint32_t qb = q_1;
|
||||
uint32_t qc = q_2;
|
||||
|
||||
half2_uint32 q0((qa & 0x00070007) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||
half2_uint32 q1((qa & 0x00380038) | c0); // half2(q[ 2], q[ 3]) * 8 + 1024
|
||||
qa >>= 6;
|
||||
half2_uint32 q2((qa & 0x00070007) | c0); // half2(q[ 4], q[ 5]) + 1024
|
||||
half2_uint32 q3((qa & 0x00380038) | c0); // half2(q[ 6], q[ 7]) * 8 + 1024
|
||||
half2_uint32 q4((qa & 0x01c001c0) | c0); // half2(q[ 8], q[ 9]) * 64 + 1024
|
||||
qa >>= 9;
|
||||
qa &= 0x00010001;
|
||||
half2_uint32 q5((qb & 0x00070007) | c0); // half2(q[10], q[11]) + 1024
|
||||
half2_uint32 q6((qb & 0x00380038) | c0); // half2(q[12], q[13]) * 8 + 1024
|
||||
qb >>= 6;
|
||||
half2_uint32 q7((qb & 0x00070007) | c0); // half2(q[14], q[15]) + 1024
|
||||
half2_uint32 q8((qb & 0x00380038) | c0); // half2(q[16], q[17]) * 8 + 1024
|
||||
half2_uint32 q9((qb & 0x01c001c0) | c0); // half2(q[18], q[19]) * 64 + 1024
|
||||
qb >>= 8;
|
||||
qb &= 0x00020002;
|
||||
half2_uint32 q10((qc & 0x00070007) | c0); // half2(q[20], q[21]) + 1024
|
||||
half2_uint32 q11((qc & 0x00380038) | c0); // half2(q[22], q[23]) * 8 + 1024
|
||||
qc >>= 6;
|
||||
half2_uint32 q12((qc & 0x00070007) | c0); // half2(q[24], q[25]) + 1024
|
||||
half2_uint32 q13((qc & 0x00380038) | c0); // half2(q[26], q[27]) * 8 + 1024
|
||||
half2_uint32 q14((qc & 0x01c001c0) | c0); // half2(q[28], q[29]) * 64 + 1024
|
||||
qc >>= 7;
|
||||
qc &= 0x00040004;
|
||||
half2_uint32 q15((qa | qb | qc) | c0);
|
||||
|
||||
dq[ 0] = __hadd2( q0.as_half2, z1);
|
||||
dq[ 1] = __hfma2( q1.as_half2, y8, z8);
|
||||
dq[ 2] = __hadd2( q2.as_half2, z1);
|
||||
dq[ 3] = __hfma2( q3.as_half2, y8, z8);
|
||||
dq[ 4] = __hfma2( q4.as_half2, y64, z64);
|
||||
dq[ 5] = __hadd2( q5.as_half2, z1);
|
||||
dq[ 6] = __hfma2( q6.as_half2, y8, z8);
|
||||
dq[ 7] = __hadd2( q7.as_half2, z1);
|
||||
dq[ 8] = __hfma2( q8.as_half2, y8, z8);
|
||||
dq[ 9] = __hfma2( q9.as_half2, y64, z64);
|
||||
dq[10] = __hadd2(q10.as_half2, z1);
|
||||
dq[11] = __hfma2(q11.as_half2, y8, z8);
|
||||
dq[12] = __hadd2(q12.as_half2, z1);
|
||||
dq[13] = __hfma2(q13.as_half2, y8, z8);
|
||||
dq[14] = __hfma2(q14.as_half2, y64, z64);
|
||||
dq[15] = __hadd2(q15.as_half2, z1);
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
__forceinline__ __device__ void shuffle_3bit_32
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_3bit_32
|
||||
(
|
||||
const uint32_t q_0,
|
||||
const uint32_t q_1,
|
||||
const uint32_t q_2,
|
||||
half2 (&dq)[16],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
half dqh[32];
|
||||
for (int i = 0; i < 10; i++) dqh[ i] = dq_ns(exb( q_0, i * 3 , 0x07), 4);
|
||||
dqh[10 ] = dq_ns(exb(q_1, q_0, 30, 0x07), 4);
|
||||
for (int i = 0; i < 10; i++) dqh[11 + i] = dq_ns(exb( q_1, i * 3 + 1, 0x07), 4);
|
||||
dqh[21 ] = dq_ns(exb(q_2, q_1, 31, 0x07), 4);
|
||||
for (int i = 0; i < 10; i++) dqh[22 + i] = dq_ns(exb( q_2, i * 3 + 2, 0x07), 4);
|
||||
|
||||
for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
227
autogptq_extension/exllamav2/cuda/quant/qdq_4.cuh
Normal file
227
autogptq_extension/exllamav2/cuda/quant/qdq_4.cuh
Normal file
|
@ -0,0 +1,227 @@
|
|||
#ifndef _qdq_4_cuh
|
||||
#define _qdq_4_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
#include "../../config.h"
|
||||
|
||||
#if QMODE_4BIT == 1
|
||||
|
||||
// Permutation:
|
||||
//
|
||||
// 77775555 33331111 66664444 22220000
|
||||
|
||||
__forceinline__ __device__ void shuffle_4bit_8
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
uint32_t qa = q[0];
|
||||
uint32_t qb = 0;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++)
|
||||
{
|
||||
uint32_t qa0 = qa & 0x0f;
|
||||
uint32_t qa1 = (qa & 0xf0) >> 4;
|
||||
qa >>= 8;
|
||||
qb |= (qa1 << (i * 4 + 16));
|
||||
qb |= (qa0 << (i * 4));
|
||||
}
|
||||
q[0] = qb;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[4],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y16_ = __float2half_rn(1.0f / 16.0f);
|
||||
const half2 y16 = __halves2half2(y16_, y16_);
|
||||
const half z1_ = __float2half_rn(-1024.0f - 8.0f);
|
||||
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
|
||||
const half2 z1 = __halves2half2(z1_, z1_);
|
||||
const half2 z16 = __halves2half2(z16_, z16_);
|
||||
|
||||
uint32_t qa = q_0;
|
||||
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
|
||||
qa >>= 8;
|
||||
half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024
|
||||
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024
|
||||
|
||||
dq[0] = __hadd2(q0.as_half2, z1);
|
||||
dq[1] = __hfma2(q1.as_half2, y16, z16);
|
||||
dq[2] = __hadd2(q2.as_half2, z1);
|
||||
dq[3] = __hfma2(q3.as_half2, y16, z16);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
|
||||
(
|
||||
const uint32_t zero,
|
||||
const half scale,
|
||||
half2 (&z1z16)[2],
|
||||
half2 (&y1y16)[2]
|
||||
)
|
||||
{
|
||||
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
|
||||
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||
|
||||
half2 scale2 = __half2half2(scale);
|
||||
|
||||
z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
|
||||
z1z16[1] = __hmul2(scale2, __half2half2(z16));
|
||||
|
||||
const half y1 = __float2half_rn(1.0f);
|
||||
const half y16 = __float2half_rn(1.0f / 16.0f);
|
||||
|
||||
y1y16[0] = __hmul2(scale2, __half2half2(y1));
|
||||
y1y16[1] = __hmul2(scale2, __half2half2(y16));
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero
|
||||
(
|
||||
const uint32_t zero,
|
||||
half2(&z1z16)[2],
|
||||
half2(&y1y16)[2]
|
||||
)
|
||||
{
|
||||
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
|
||||
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
|
||||
|
||||
z1z16[0] = __half2half2(z1.as_half);
|
||||
z1z16[1] = __half2half2(z16);
|
||||
|
||||
const half y1 = __float2half_rn(1.0f);
|
||||
const half y16 = __float2half_rn(1.0f / 16.0f);
|
||||
|
||||
y1y16[0] = __half2half2(y1);
|
||||
y1y16[1] = __half2half2(y16);
|
||||
}
|
||||
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_gptq
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[4],
|
||||
half2 (&z1z16)[2],
|
||||
half2 (&y1y16)[2],
|
||||
int stride,
|
||||
bool scaled
|
||||
)
|
||||
{
|
||||
const uint32_t c0 = 0x64006400;
|
||||
|
||||
uint32_t qa = q_0;
|
||||
half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
|
||||
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
|
||||
qa >>= 8;
|
||||
half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
|
||||
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
|
||||
|
||||
if (scaled)
|
||||
{
|
||||
dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
|
||||
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
|
||||
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
|
||||
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
|
||||
}
|
||||
else
|
||||
{
|
||||
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
|
||||
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
|
||||
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
|
||||
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
__forceinline__ __device__ void shuffle_4bit_8
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[4],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
half dqh[8];
|
||||
for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);
|
||||
|
||||
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
|
||||
(
|
||||
const uint32_t zero,
|
||||
const half scale,
|
||||
half2 (&z1)[2],
|
||||
half2 (&y1)[2]
|
||||
)
|
||||
{
|
||||
half z = __int2half_rn(-((int)zero));
|
||||
z = __hmul(z, scale);
|
||||
z1[0] = __half2half2(z);
|
||||
y1[0] = __half2half2(scale);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_prep_zero
|
||||
(
|
||||
const uint32_t zero,
|
||||
half2(&z1)[2],
|
||||
half2(&y1)[2]
|
||||
)
|
||||
{
|
||||
half z = __int2half_rn(-((int)zero));
|
||||
z1[0] = __half2half2(z);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_4bit_8_gptq
|
||||
(
|
||||
const uint32_t q_0,
|
||||
half2 (&dq)[4],
|
||||
half2 (&z1)[2],
|
||||
half2 (&y1)[2],
|
||||
int stride,
|
||||
bool scaled
|
||||
)
|
||||
{
|
||||
half2 dqh2[8];
|
||||
|
||||
uint32_t qa = q_0;
|
||||
for (int i = 0; i < 4; i++)
|
||||
{
|
||||
half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
|
||||
half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
|
||||
dqh2[i] = __halves2half2(d0, d1);
|
||||
}
|
||||
|
||||
if (scaled)
|
||||
{
|
||||
dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
|
||||
dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
|
||||
dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
|
||||
dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
|
||||
}
|
||||
else
|
||||
{
|
||||
dq[0] = __hadd2(dqh2[0], z1[0]);
|
||||
dq[1] = __hadd2(dqh2[1], z1[0]);
|
||||
dq[2] = __hadd2(dqh2[2], z1[0]);
|
||||
dq[3] = __hadd2(dqh2[3], z1[0]);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
207
autogptq_extension/exllamav2/cuda/quant/qdq_5.cuh
Normal file
207
autogptq_extension/exllamav2/cuda/quant/qdq_5.cuh
Normal file
|
@ -0,0 +1,207 @@
|
|||
#ifndef _qdq_5_cuh
|
||||
#define _qdq_5_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
#include "../../config.h"
|
||||
|
||||
#if QMODE_5BIT == 1
|
||||
|
||||
// Permutation:
|
||||
//
|
||||
// v5555533 33311111 u4444422 22200000 (u, v lsb)
|
||||
// vbbbbb99 99977777 uaaaaa88 88866666
|
||||
// vhhhhhff fffddddd ugggggee eeeccccc
|
||||
// vnnnnnll llljjjjj ummmmmkk kkkiiiii
|
||||
// vtttttrr rrrppppp usssssqq qqqooooo
|
||||
|
||||
__forceinline__ __device__ void shuffle_5bit_32
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
uint32_t qa = q[0 * stride];
|
||||
uint32_t qb = q[1 * stride];
|
||||
uint32_t qc = q[2 * stride];
|
||||
uint32_t qd = q[3 * stride];
|
||||
uint32_t qe = q[4 * stride];
|
||||
|
||||
// qa: 66555554 44443333 32222211 11100000
|
||||
// qb: ccccbbbb baaaaa99 99988888 77777666
|
||||
// qc: jiiiiihh hhhggggg fffffeee eedddddc
|
||||
// qd: pppooooo nnnnnmmm mmlllllk kkkkjjjj
|
||||
// qe: vvvvvuuu uuttttts ssssrrrr rqqqqqpp
|
||||
|
||||
uint32_t qf = qe >> 22;
|
||||
qe <<= 8;
|
||||
qe |= qd >> 24;
|
||||
qd <<= 6;
|
||||
qd |= qc >> 26;
|
||||
qc <<= 4;
|
||||
qc |= qb >> 28;
|
||||
qb <<= 2;
|
||||
qb |= qa >> 30;
|
||||
|
||||
// qa: 555554 44443333 32222211 11100000
|
||||
// qb: bbbbba aaaa9999 98888877 77766666
|
||||
// qc: hhhhhg ggggffff feeeeedd dddccccc
|
||||
// qd: nnnnnm mmmmllll lkkkkkjj jjjiiiii
|
||||
// qe: ttttts ssssrrrr rqqqqqpp pppooooo
|
||||
// qf: vv vvvuuuuu
|
||||
|
||||
uint32_t za = 0;
|
||||
uint32_t zb = 0;
|
||||
uint32_t zc = 0;
|
||||
uint32_t zd = 0;
|
||||
uint32_t ze = 0;
|
||||
|
||||
for (int i = 0; i < 3; i++) { uint32_t t0 = qa & 0x1f; uint32_t t1 = (qa & 0x3e0) >> 5; qa >>= 10; za |= (t0 << (i * 5)); za |= (t1 << (i * 5 + 16)); }
|
||||
for (int i = 0; i < 3; i++) { uint32_t t0 = qb & 0x1f; uint32_t t1 = (qb & 0x3e0) >> 5; qb >>= 10; zb |= (t0 << (i * 5)); zb |= (t1 << (i * 5 + 16)); }
|
||||
for (int i = 0; i < 3; i++) { uint32_t t0 = qc & 0x1f; uint32_t t1 = (qc & 0x3e0) >> 5; qc >>= 10; zc |= (t0 << (i * 5)); zc |= (t1 << (i * 5 + 16)); }
|
||||
for (int i = 0; i < 3; i++) { uint32_t t0 = qd & 0x1f; uint32_t t1 = (qd & 0x3e0) >> 5; qd >>= 10; zd |= (t0 << (i * 5)); zd |= (t1 << (i * 5 + 16)); }
|
||||
for (int i = 0; i < 3; i++) { uint32_t t0 = qe & 0x1f; uint32_t t1 = (qe & 0x3e0) >> 5; qe >>= 10; ze |= (t0 << (i * 5)); ze |= (t1 << (i * 5 + 16)); }
|
||||
|
||||
// za: 5555533 33311111 4444422 22200000
|
||||
// zb: bbbbb99 99977777 aaaaa88 88866666
|
||||
// zc: hhhhhff fffddddd gggggee eeeccccc
|
||||
// zd: nnnnnll llljjjjj mmmmmkk kkkiiiii
|
||||
// ze: tttttrr rrrppppp sssssqq qqqooooo
|
||||
// qf: vv vvvuuuuu
|
||||
|
||||
za |= ((qf & 0x001) >> 0) << 15;
|
||||
zb |= ((qf & 0x002) >> 1) << 15;
|
||||
zc |= ((qf & 0x004) >> 2) << 15;
|
||||
zd |= ((qf & 0x008) >> 3) << 15;
|
||||
ze |= ((qf & 0x010) >> 4) << 15;
|
||||
za |= ((qf & 0x020) >> 5) << 31;
|
||||
zb |= ((qf & 0x040) >> 6) << 31;
|
||||
zc |= ((qf & 0x080) >> 7) << 31;
|
||||
zd |= ((qf & 0x100) >> 8) << 31;
|
||||
ze |= ((qf & 0x200) >> 9) << 31;
|
||||
|
||||
// za: v5555533 33311111 u4444422 22200000 (u, v lsb)
|
||||
// zb: vbbbbb99 99977777 uaaaaa88 88866666
|
||||
// zc: vhhhhhff fffddddd ugggggee eeeccccc
|
||||
// zd: vnnnnnll llljjjjj ummmmmkk kkkiiiii
|
||||
// ze: vtttttrr rrrppppp usssssqq qqqooooo
|
||||
|
||||
q[0 * stride] = za;
|
||||
q[1 * stride] = zb;
|
||||
q[2 * stride] = zc;
|
||||
q[3 * stride] = zd;
|
||||
q[4 * stride] = ze;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_5bit_32
|
||||
(
|
||||
const uint32_t q_0,
|
||||
const uint32_t q_1,
|
||||
const uint32_t q_2,
|
||||
const uint32_t q_3,
|
||||
const uint32_t q_4,
|
||||
half2 (&dq)[16],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
const uint32_t c0 = 0x64006400;
|
||||
const half y32_ = __float2half_rn(1.0f / 32.0f);
|
||||
const half2 y32 = __halves2half2(y32_, y32_);
|
||||
const half z1_ = __float2half_rn(-1024.0f - 16.0f);
|
||||
const half z32_ = __float2half_rn(-1024.0f / 32.0f - 16.0f);
|
||||
const half2 z1 = __halves2half2(z1_, z1_);
|
||||
const half2 z32 = __halves2half2(z32_, z32_);
|
||||
|
||||
uint32_t qa = q_0;
|
||||
uint32_t qb = q_1;
|
||||
uint32_t qc = q_2;
|
||||
uint32_t qd = q_3;
|
||||
uint32_t qe = q_4;
|
||||
|
||||
half2_uint32 q0 ((qa & 0x001f001f) | c0); // half2(q[ 0], q[ 1]) + 1024
|
||||
half2_uint32 q1 ((qa & 0x03e003e0) | c0); // half2(q[ 2], q[ 3]) * 32 + 1024
|
||||
qa >>= 10;
|
||||
half2_uint32 q2 ((qa & 0x001f001f) | c0); // half2(q[ 4], q[ 5]) + 1024
|
||||
qa >>= 5;
|
||||
qa &= 0x00010001;
|
||||
half2_uint32 q3 ((qb & 0x001f001f) | c0); // half2(q[ 6], q[ 7]) + 1024
|
||||
half2_uint32 q4 ((qb & 0x03e003e0) | c0); // half2(q[ 8], q[ 9]) * 32 + 1024
|
||||
qb >>= 10;
|
||||
half2_uint32 q5 ((qb & 0x001f001f) | c0); // half2(q[10], q[11]) + 1024
|
||||
qb >>= 4;
|
||||
qb &= 0x00020002;
|
||||
half2_uint32 q6 ((qc & 0x001f001f) | c0); // half2(q[12], q[13]) + 1024
|
||||
half2_uint32 q7 ((qc & 0x03e003e0) | c0); // half2(q[14], q[15]) * 32 + 1024
|
||||
qc >>= 10;
|
||||
half2_uint32 q8 ((qc & 0x001f001f) | c0); // half2(q[16], q[17]) + 1024
|
||||
qc >>= 3;
|
||||
qc &= 0x00040004;
|
||||
half2_uint32 q9 ((qd & 0x001f001f) | c0); // half2(q[18], q[19]) + 1024
|
||||
half2_uint32 q10((qd & 0x03e003e0) | c0); // half2(q[20], q[21]) * 32 + 1024
|
||||
qd >>= 10;
|
||||
half2_uint32 q11((qd & 0x001f001f) | c0); // half2(q[22], q[23]) + 1024
|
||||
qd >>= 2;
|
||||
qd &= 0x00080008;
|
||||
half2_uint32 q12((qe & 0x001f001f) | c0); // half2(q[24], q[25]) + 1024
|
||||
half2_uint32 q13((qe & 0x03e003e0) | c0); // half2(q[26], q[27]) * 32 + 1024
|
||||
qe >>= 10;
|
||||
half2_uint32 q14((qe & 0x001f001f) | c0); // half2(q[28], q[29]) + 1024
|
||||
qe >>= 1;
|
||||
qe &= 0x00100010;
|
||||
half2_uint32 q15((qa | qb | qc | qd | qe) | c0);
|
||||
|
||||
dq[ 0] = __hadd2( q0.as_half2, z1);
|
||||
dq[ 1] = __hfma2( q1.as_half2, y32, z32);
|
||||
dq[ 2] = __hadd2( q2.as_half2, z1);
|
||||
dq[ 3] = __hadd2( q3.as_half2, z1);
|
||||
dq[ 4] = __hfma2( q4.as_half2, y32, z32);
|
||||
dq[ 5] = __hadd2( q5.as_half2, z1);
|
||||
dq[ 6] = __hadd2( q6.as_half2, z1);
|
||||
dq[ 7] = __hfma2( q7.as_half2, y32, z32);
|
||||
dq[ 8] = __hadd2( q8.as_half2, z1);
|
||||
dq[ 9] = __hadd2( q9.as_half2, z1);
|
||||
dq[10] = __hfma2(q10.as_half2, y32, z32);
|
||||
dq[11] = __hadd2(q11.as_half2, z1);
|
||||
dq[12] = __hadd2(q12.as_half2, z1);
|
||||
dq[13] = __hfma2(q13.as_half2, y32, z32);
|
||||
dq[14] = __hadd2(q14.as_half2, z1);
|
||||
dq[15] = __hadd2(q15.as_half2, z1);
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
__forceinline__ __device__ void shuffle_5bit_32
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_5bit_32
|
||||
(
|
||||
const uint32_t q_0,
|
||||
const uint32_t q_1,
|
||||
const uint32_t q_2,
|
||||
const uint32_t q_3,
|
||||
const uint32_t q_4,
|
||||
half2 (&dq)[16],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
half dqh[32];
|
||||
for (int i = 0; i < 6; i++) dqh[ i] = dq_ns(exb( q_0, i * 5 , 0x1f), 16);
|
||||
dqh[ 6 ] = dq_ns(exb(q_1, q_0, 30, 0x1f), 16);
|
||||
for (int i = 0; i < 5; i++) dqh[ 7 + i] = dq_ns(exb( q_1, i * 5 + 3, 0x1f), 16);
|
||||
dqh[12 ] = dq_ns(exb(q_2, q_1, 28, 0x1f), 16);
|
||||
for (int i = 0; i < 6; i++) dqh[13 + i] = dq_ns(exb( q_2, i * 5 + 1, 0x1f), 16);
|
||||
dqh[19 ] = dq_ns(exb(q_3, q_2, 31, 0x1f), 16);
|
||||
for (int i = 0; i < 5; i++) dqh[20 + i] = dq_ns(exb( q_3, i * 5 + 4, 0x1f), 16);
|
||||
dqh[25 ] = dq_ns(exb(q_4, q_3, 29, 0x1f), 16);
|
||||
for (int i = 0; i < 6; i++) dqh[26 + i] = dq_ns(exb( q_4, i * 5 + 2, 0x1f), 16);
|
||||
|
||||
for (int i = 0; i < 16; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
44
autogptq_extension/exllamav2/cuda/quant/qdq_6.cuh
Normal file
44
autogptq_extension/exllamav2/cuda/quant/qdq_6.cuh
Normal file
|
@ -0,0 +1,44 @@
|
|||
#ifndef _qdq_6_cuh
|
||||
#define _qdq_6_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
#include "../../config.h"
|
||||
|
||||
#if QMODE_6BIT == 1
|
||||
|
||||
// Not implemented
|
||||
|
||||
#else
|
||||
|
||||
__forceinline__ __device__ void shuffle_6bit_16
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_6bit_16
|
||||
(
|
||||
const uint32_t q_0,
|
||||
const uint32_t q_1,
|
||||
const uint32_t q_2,
|
||||
half2 (&dq)[8],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
half dqh[16];
|
||||
for (int i = 0; i < 5; i++) dqh[ i] = dq_ns(exb( q_0, i * 6 , 0x3f), 32);
|
||||
dqh[ 5 ] = dq_ns(exb(q_1, q_0, 30, 0x3f), 32);
|
||||
for (int i = 0; i < 4; i++) dqh[ 6 + i] = dq_ns(exb( q_1, i * 6 + 4, 0x3f), 32);
|
||||
dqh[10 ] = dq_ns(exb(q_2, q_1, 28, 0x3f), 32);
|
||||
for (int i = 0; i < 5; i++) dqh[11 + i] = dq_ns(exb( q_2, i * 6 + 2, 0x3f), 32);
|
||||
|
||||
for (int i = 0; i < 8; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
|
38
autogptq_extension/exllamav2/cuda/quant/qdq_8.cuh
Normal file
38
autogptq_extension/exllamav2/cuda/quant/qdq_8.cuh
Normal file
|
@ -0,0 +1,38 @@
|
|||
#ifndef _qdq_8_cuh
|
||||
#define _qdq_8_cuh
|
||||
|
||||
#include "qdq_util.cuh"
|
||||
#include "../../config.h"
|
||||
|
||||
#if QMODE_8BIT == 1
|
||||
|
||||
// Not implemented
|
||||
|
||||
#else
|
||||
|
||||
__forceinline__ __device__ void shuffle_8bit_4
|
||||
(
|
||||
uint32_t* q,
|
||||
int stride
|
||||
)
|
||||
{
|
||||
}
|
||||
|
||||
__forceinline__ __device__ void dequant_8bit_8
|
||||
(
|
||||
const uint32_t q_0,
|
||||
const uint32_t q_1,
|
||||
half2 (&dq)[4],
|
||||
int stride
|
||||
)
|
||||
{
|
||||
half dqh[8];
|
||||
for (int i = 0; i < 4; i++) dqh[i ] = dq_ns(exb(q_0, i * 8, 0xff), 128);
|
||||
for (int i = 0; i < 4; i++) dqh[i + 4] = dq_ns(exb(q_1, i * 8, 0xff), 128);
|
||||
|
||||
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
#endif
|
51
autogptq_extension/exllamav2/cuda/quant/qdq_util.cuh
Normal file
51
autogptq_extension/exllamav2/cuda/quant/qdq_util.cuh
Normal file
|
@ -0,0 +1,51 @@
|
|||
#ifndef _qdq_util_cuh
|
||||
#define _qdq_util_cuh
|
||||
|
||||
union half2_uint32
|
||||
{
|
||||
uint32_t as_uint32;
|
||||
half2 as_half2;
|
||||
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
|
||||
__device__ half2_uint32(half2 val) : as_half2(val) {}
|
||||
};
|
||||
|
||||
union half_uint16
|
||||
{
|
||||
uint16_t as_uint16;
|
||||
half as_half;
|
||||
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
|
||||
__device__ half_uint16(half val) : as_half(val) {}
|
||||
};
|
||||
|
||||
// Max_scale premultiplied by 1/256
|
||||
|
||||
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
|
||||
{
|
||||
int qs_i = qs + 1;
|
||||
half qs_h = __int2half_rn(qs_i * qs_i);
|
||||
qs_h = __hmul(qs_h, max_scale);
|
||||
return qs_h;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
|
||||
{
|
||||
return __hmul(__int2half_rn(q - qzero), scale);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ half dq_ns(const int q, const int qzero)
|
||||
{
|
||||
//return __hsub(__int2half_rn(q), __int2half_rn(qzero));
|
||||
return __int2half_rn(q - qzero);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
|
||||
{
|
||||
return (int)((q >> shift) & mask);
|
||||
}
|
||||
|
||||
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
|
||||
{
|
||||
return (int)(__funnelshift_rc(q0, q1, shift) & mask);
|
||||
}
|
||||
|
||||
#endif
|
32
autogptq_extension/exllamav2/cuda/util.cuh
Normal file
32
autogptq_extension/exllamav2/cuda/util.cuh
Normal file
|
@ -0,0 +1,32 @@
|
|||
|
||||
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))
|
||||
|
||||
#define DBGS(__x) printf("%s\n", __x)
|
||||
#define DBGI(__x) printf("%s: %i\n", #__x, __x)
|
||||
#define DBGI2(__x, __y) printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y)
|
||||
#define DBGI3(__x, __y, __z) printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z)
|
||||
#define DBGX(__x) printf("%s: %x\n", #__x, __x)
|
||||
#define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y)
|
||||
#define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z)
|
||||
#define DBGF(__x) printf("%s: %f\n", #__x, __x)
|
||||
#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y)
|
||||
#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z)
|
||||
#define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x))
|
||||
#define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y))
|
||||
#define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z))
|
||||
|
||||
#define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y))
|
||||
#define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z))
|
||||
|
||||
__forceinline__ __device__ half dq_scale_(const int qs, const half max_scale)
|
||||
{
|
||||
half qs_h = __hmul(__int2half_rn(qs + 1), __float2half_rn(1.0f / 16.0f));
|
||||
qs_h = __hmul(qs_h, qs_h);
|
||||
qs_h = __hmul(qs_h, max_scale);
|
||||
return qs_h;
|
||||
}
|
||||
|
||||
__forceinline__ __device__ float clamp(float x, float a, float b)
|
||||
{
|
||||
return fmaxf(a, fminf(b, x));
|
||||
}
|
134
autogptq_extension/exllamav2/ext.cpp
Normal file
134
autogptq_extension/exllamav2/ext.cpp
Normal file
|
@ -0,0 +1,134 @@
|
|||
#include <torch/extension.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
|
||||
#include "config.h"
|
||||
|
||||
#include "cuda/q_matrix.cuh"
|
||||
#include "cuda/q_gemm.cuh"
|
||||
|
||||
#include "cpp/util.h"
|
||||
|
||||
// Some decluttering macros
|
||||
|
||||
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
|
||||
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
|
||||
|
||||
|
||||
// Quant matrix
|
||||
|
||||
uintptr_t make_q_matrix
|
||||
(
|
||||
torch::Tensor q_weight,
|
||||
torch::Tensor q_perm,
|
||||
torch::Tensor q_invperm,
|
||||
torch::Tensor q_scale,
|
||||
torch::Tensor q_scale_max,
|
||||
torch::Tensor q_groups,
|
||||
torch::Tensor gptq_qzeros,
|
||||
torch::Tensor gptq_scales,
|
||||
torch::Tensor gptq_g_idx,
|
||||
torch::Tensor temp_dq
|
||||
)
|
||||
{
|
||||
TORCH_CHECK_DTYPE(q_weight, kInt);
|
||||
TORCH_CHECK_DTYPE_OPT(q_perm, kShort);
|
||||
TORCH_CHECK_DTYPE_OPT(q_invperm, kShort);
|
||||
TORCH_CHECK_DTYPE_OPT(q_scale, kInt);
|
||||
TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf);
|
||||
TORCH_CHECK_DTYPE_OPT(q_groups, kShort);
|
||||
TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt);
|
||||
TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf);
|
||||
TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt);
|
||||
|
||||
TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1);
|
||||
|
||||
int device = q_weight.device().index();
|
||||
int width = q_weight.size(1);
|
||||
int groups;
|
||||
int height;
|
||||
|
||||
if (!q_scale.device().is_meta())
|
||||
{
|
||||
TORCH_CHECK_SHAPES(q_weight, 1, q_scale, 1, 8);
|
||||
TORCH_CHECK_SHAPES(q_scale_max, 0, q_scale, 0, 1);
|
||||
groups = q_scale.size(0);
|
||||
height = q_invperm.size(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
TORCH_CHECK_SHAPES(q_weight, 1, gptq_qzeros, 1, 8);
|
||||
TORCH_CHECK_SHAPES(q_weight, 1, gptq_scales, 1, 1);
|
||||
groups = gptq_qzeros.size(0);
|
||||
height = q_weight.size(0) * 8;
|
||||
}
|
||||
|
||||
TORCH_CHECK(temp_dq.size(0) >= width * height, "Insufficient size of temp_dq buffer")
|
||||
|
||||
QMatrix* m = new QMatrix
|
||||
(
|
||||
device,
|
||||
height,
|
||||
width,
|
||||
groups,
|
||||
(uint32_t*) q_weight.data_ptr(),
|
||||
q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(),
|
||||
q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(),
|
||||
q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(),
|
||||
q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(),
|
||||
q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(),
|
||||
gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(),
|
||||
gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(),
|
||||
gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(),
|
||||
(half*) temp_dq.data_ptr()
|
||||
);
|
||||
|
||||
return reinterpret_cast<uintptr_t> (m);
|
||||
}
|
||||
|
||||
void gemm_half_q_half
|
||||
(
|
||||
torch::Tensor a,
|
||||
uintptr_t b,
|
||||
torch::Tensor c,
|
||||
bool force_cuda
|
||||
)
|
||||
{
|
||||
QMatrix* qm = reinterpret_cast<QMatrix*> (b);
|
||||
|
||||
TORCH_CHECK_DTYPE(a, kHalf);
|
||||
TORCH_CHECK_DTYPE(c, kHalf);
|
||||
TORCH_CHECK_SHAPES(a, 0, c, 0, 1);
|
||||
TORCH_CHECK(qm->height == a.size(1), "a and b have incompatible shapes")
|
||||
TORCH_CHECK(qm->width == c.size(1), "b and c have incompatible shapes")
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
|
||||
|
||||
gemm_half_q_half_cuda
|
||||
(
|
||||
at::cuda::getCurrentCUDABlasHandle(),
|
||||
(const half*) a.data_ptr(),
|
||||
qm,
|
||||
(half*) c.data_ptr(),
|
||||
c.size(0), // m
|
||||
c.size(1), // n
|
||||
a.size(1), // k
|
||||
true,
|
||||
NULL,
|
||||
force_cuda
|
||||
);
|
||||
}
|
||||
|
||||
// Bindings
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("make_q_matrix", &make_q_matrix, "make_q_matrix");
|
||||
m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half");
|
||||
}
|
11
setup.py
11
setup.py
|
@ -158,6 +158,17 @@ if BUILD_CUDA_EXT:
|
|||
extra_link_args=extra_link_args
|
||||
)
|
||||
)
|
||||
extensions.append(
|
||||
cpp_extension.CUDAExtension(
|
||||
"exllamav2_kernels",
|
||||
[
|
||||
"autogptq_extension/exllamav2/ext.cpp",
|
||||
"autogptq_extension/exllamav2/cuda/q_matrix.cu",
|
||||
"autogptq_extension/exllamav2/cuda/q_gemm.cu",
|
||||
],
|
||||
extra_link_args=extra_link_args
|
||||
)
|
||||
)
|
||||
|
||||
additional_setup_kwargs = {
|
||||
"ext_modules": extensions,
|
||||
|
|
233
tests/test_q4.py
233
tests/test_q4.py
|
@ -143,7 +143,7 @@ class TestsQ4Exllama(unittest.TestCase):
|
|||
n = 1024
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=4)
|
||||
linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=4, disable_exllama=False, disable_exllamav2=True)
|
||||
|
||||
linear = linear_class(
|
||||
bits=4,
|
||||
|
@ -197,7 +197,7 @@ class TestsQ4Exllama(unittest.TestCase):
|
|||
revision = "actorder"
|
||||
model_basename = "vicuna-13B-1.1-GPTQ-4bit-128g.latest"
|
||||
|
||||
model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=False, inject_fused_mlp=True, model_basename=model_basename, disable_exllama=False)
|
||||
model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=False, inject_fused_mlp=True, model_basename=model_basename, disable_exllama=False, disable_exllamav2=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
inp = tokenizer(prompt, return_tensors="pt").to(device)
|
||||
|
@ -227,7 +227,7 @@ class TestsQ4Exllama(unittest.TestCase):
|
|||
|
||||
model_id = "TheBloke/WizardLM-7B-uncensored-GPTQ"
|
||||
model_basename = "model"
|
||||
model_q = AutoGPTQForCausalLM.from_quantized(model_id, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=True, inject_fused_mlp=True, model_basename=model_basename)
|
||||
model_q = AutoGPTQForCausalLM.from_quantized(model_id, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=True, inject_fused_mlp=True, model_basename=model_basename, disable_exllama=False, disable_exllamav2=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
inp = tokenizer(prompt, return_tensors="pt").to(device)
|
||||
|
@ -249,7 +249,7 @@ class TestsQ4Exllama(unittest.TestCase):
|
|||
revision = "actorder"
|
||||
model_basename = "vicuna-13B-1.1-GPTQ-4bit-128g.latest"
|
||||
|
||||
model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=False, inject_fused_mlp=True, model_basename=model_basename, disable_exllama=False)
|
||||
model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=False, inject_fused_mlp=True, model_basename=model_basename, disable_exllama=False, disable_exllamav2=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
inp = tokenizer(prompt, return_tensors="pt").to(device)
|
||||
|
@ -338,7 +338,7 @@ class TestsQ4CUDA(unittest.TestCase):
|
|||
n = 256
|
||||
device = "cuda"
|
||||
|
||||
linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=4, disable_exllama=True)
|
||||
linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=4, disable_exllamav2=True)
|
||||
|
||||
linear = linear_class(
|
||||
bits=4,
|
||||
|
@ -369,3 +369,226 @@ class TestsQ4CUDA(unittest.TestCase):
|
|||
reference = self.REFERENCE_OLD_NO_HALF.to(device)
|
||||
|
||||
self.assertTrue(torch.allclose(res, reference), get_diff(res, reference))
|
||||
|
||||
|
||||
class TestsQ4ExllamaV2(unittest.TestCase):
|
||||
|
||||
# reference generated with cuda_old
|
||||
REFERENCE = torch.Tensor([5.8398, 6.8555, 7.2734, 6.4219, 6.2070, 5.8203, 6.5664, 6.4219, 6.2148,
|
||||
5.3281, 5.7578, 7.5312, 8.1016, 6.1133, 7.2031, 6.6484, 6.5156, 6.0117,
|
||||
6.0312, 6.1914, 6.2109, 6.8125, 5.8125, 7.1172, 7.3125, 6.7305, 5.9961,
|
||||
6.5117, 6.1914, 5.9648, 7.1680, 6.4766, 7.2070, 6.5469, 6.7734, 6.4219,
|
||||
6.8086, 7.0469, 5.9297, 6.4727, 6.2539, 5.9570, 7.2383, 5.8945, 6.0820,
|
||||
5.7969, 7.1094, 6.2188, 6.7500, 7.3555, 6.2930, 6.7734, 5.9219, 7.4805,
|
||||
6.8750, 6.4102, 6.5898, 6.5469, 7.6016, 6.7461, 5.9492, 7.2227, 5.8164,
|
||||
5.4570, 6.2930, 7.3984, 6.0938, 7.3984, 5.9609, 6.3516, 6.5664, 5.7969,
|
||||
7.1250, 6.0781, 6.7930, 5.9492, 6.1641, 6.5898, 6.0586, 6.3359, 6.7930,
|
||||
7.0469, 6.0664, 6.3320, 5.4414, 6.7617, 5.1641, 7.2891, 6.8516, 6.5312,
|
||||
5.6914, 7.3711, 6.8203, 5.9492, 7.0781, 6.3164, 7.1992, 7.1133, 7.4219,
|
||||
7.5586, 7.1836, 6.9102, 6.4844, 6.9805, 6.1953, 6.5156, 5.4844, 6.6602,
|
||||
6.6719, 7.9844, 6.4727, 6.6367, 6.2227, 6.4531, 5.0625, 6.4609, 6.7031,
|
||||
6.6445, 6.5234, 6.8633, 6.6055, 5.6055, 6.4453, 7.2617, 6.3945, 6.6367,
|
||||
6.1055, 7.0664, 6.0820, 6.6875, 6.1445, 6.8672, 6.2070, 6.8828, 6.1484,
|
||||
6.7070, 6.8516, 6.2734, 7.1055, 7.0586, 6.9648, 5.9727, 6.1016, 6.8750,
|
||||
7.0078, 7.1523, 5.7383, 5.9531, 6.5508, 7.5352, 6.1602, 6.2578, 6.3906,
|
||||
5.7383, 6.7031, 5.7344, 6.3516, 5.2852, 7.5312, 6.4531, 6.6406, 6.2266,
|
||||
6.1094, 5.9102, 5.7617, 6.3789, 7.0508, 6.3750, 6.3320, 6.8555, 6.7266,
|
||||
7.0352, 7.7695, 6.3984, 6.5039, 6.8320, 6.1602, 6.0312, 6.3828, 6.9023,
|
||||
7.4336, 7.3711, 6.1016, 7.0703, 6.3281, 6.8281, 6.4922, 5.9453, 5.1016,
|
||||
6.7188, 6.1406, 6.6289, 7.2695, 6.2070, 6.7070, 7.2930, 7.1836, 6.3828,
|
||||
6.1992, 6.7070, 7.8008, 7.7773, 5.6602, 7.0273, 6.6172, 6.0898, 5.3516,
|
||||
7.3359, 5.9727, 6.0078, 7.0586, 6.3086, 6.8555, 7.2617, 7.3477, 6.3828,
|
||||
7.1133, 6.6328, 7.3516, 6.9141, 7.2031, 6.9805, 6.1719, 6.7812, 8.3047,
|
||||
6.5898, 6.3633, 6.2539, 7.2773, 6.5938, 6.4141, 6.8203, 6.8906, 7.8828,
|
||||
5.9609, 6.4180, 7.3984, 5.7539, 7.1758, 6.6641, 6.9062, 6.2578, 7.5508,
|
||||
6.1719, 6.5742, 5.9375, 6.7891, 6.2109, 6.5039, 6.8750, 6.2031, 6.8828,
|
||||
7.1094, 5.9570, 7.2969, 6.6797, 6.8828, 5.5430, 6.9648, 5.8398, 6.5430,
|
||||
6.3945, 6.5664, 5.8086, 6.6172, 7.0586, 6.8867, 6.0820, 5.8125, 6.7070,
|
||||
7.5742, 6.2578, 6.1328, 6.5391, 5.4531, 6.8242, 6.6953, 6.8008, 6.3398,
|
||||
6.4805, 7.2266, 6.3281, 6.6875, 6.4688, 5.9414, 7.4297, 5.8711, 6.0625,
|
||||
5.8750, 6.5664, 5.8867, 6.3477, 6.1133, 6.9453, 5.0547, 6.7812, 6.4922,
|
||||
7.2422, 5.4688, 6.2109, 7.2148, 6.1758, 5.9297, 7.1953, 5.5195, 6.3203,
|
||||
5.9961, 7.9297, 6.2695, 6.4414, 6.7266, 7.1875, 7.3203, 5.4062, 6.0625,
|
||||
7.0898, 5.3828, 5.6133, 6.0742, 6.6836, 5.7109, 7.2852, 7.7539, 7.5820,
|
||||
6.4258, 5.9336, 6.3750, 6.3555, 7.5469, 6.2539, 6.5898, 6.4102, 7.0469,
|
||||
5.7344, 7.2031, 6.7969, 5.6836, 7.6523, 6.9297, 7.8672, 6.4766, 6.3008,
|
||||
7.0977, 6.5430, 7.0938, 5.8398, 6.9883, 6.5312, 6.3203, 6.3594, 5.4062,
|
||||
6.9688, 5.7930, 6.3164, 6.5547, 7.1992, 5.8750, 6.3008, 6.7930, 6.0391,
|
||||
7.4766, 6.6094, 6.5625, 5.9805, 6.2422, 7.2109, 6.6875, 5.3047, 7.6211,
|
||||
5.9453, 6.5625, 6.1641, 6.1250, 6.5977, 7.7422, 7.0742, 5.6875, 6.2656,
|
||||
6.6250, 6.8945, 5.7070, 6.3203, 5.7500, 6.2695, 6.2773, 6.8516, 6.4883,
|
||||
7.0000, 6.7578, 6.1875, 5.9844, 5.5703, 6.7188, 5.5273, 5.3438, 7.2500,
|
||||
6.7852, 6.5195, 6.8125, 6.0664, 6.7852, 7.0000, 7.0781, 6.8477, 7.2930,
|
||||
6.3438, 7.1523, 6.3281, 6.8047, 7.3203, 5.3359, 6.1484, 6.5586, 7.3828,
|
||||
6.2344, 7.1523, 6.4102, 5.5898, 7.0195, 7.1172, 5.8008, 6.5742, 6.2891,
|
||||
8.0312, 6.9023, 6.5898, 7.1953, 6.7266, 6.0078, 5.5430, 6.4766, 6.4258,
|
||||
5.9648, 8.0859, 5.0547, 7.2188, 7.4375, 6.5156, 5.9922, 6.3281, 6.2852,
|
||||
6.7734, 6.2461, 6.9805, 5.4648, 5.8867, 6.8242, 6.3008, 6.3281, 7.3047,
|
||||
7.1836, 6.5195, 6.6328, 6.7188, 5.4336, 6.5078, 5.3477, 5.5508, 7.3125,
|
||||
5.8750, 6.5195, 6.2383, 6.3594, 6.0898, 6.4141, 5.9844, 6.6250, 7.7109,
|
||||
6.0391, 7.2344, 5.9453, 5.9453, 7.0586, 5.6641, 7.2773, 6.5195, 7.2227,
|
||||
6.3359, 5.3203, 6.4375, 7.2383, 6.4023, 6.2148, 7.3750, 5.8164, 6.2109,
|
||||
6.5430, 5.8164, 6.1680, 6.7656, 6.0820, 6.1094, 6.5312, 6.8906, 6.8320,
|
||||
6.1289, 6.3125, 7.6797, 6.3008, 6.0000, 7.3320, 6.7852, 6.9297, 6.6328,
|
||||
6.2266, 5.1602, 6.2031, 7.0547, 5.9492, 6.0703, 6.0977, 6.8086, 6.0742,
|
||||
6.0195, 7.0625, 6.5781, 5.7461, 6.1562, 7.0430, 6.7148, 6.5312, 6.5820,
|
||||
6.4570, 7.5508, 5.6289, 6.0547, 6.5000, 7.3125, 5.8477, 5.9297, 6.2578,
|
||||
6.0078, 5.9922, 7.3398, 7.4922, 7.8906, 7.5547, 5.4648, 6.5156, 6.3242,
|
||||
6.1094, 6.9219, 6.7227, 6.6836, 7.4023, 5.9648, 7.2383, 6.7695, 6.6797,
|
||||
7.0547, 6.3047, 6.4688, 6.9961, 6.0391, 5.9727, 6.8398, 6.7422, 5.7656,
|
||||
5.4766, 6.7852, 7.0820, 5.3516, 7.6523, 5.1562, 6.6445, 6.1211, 6.2695,
|
||||
6.0703, 6.3594, 6.4062, 6.3398, 5.7578, 6.5391, 6.2500, 6.5742, 6.5000,
|
||||
7.5625, 7.0117, 6.5547, 7.1250, 6.4453, 6.6094, 6.1875, 6.4219, 6.6172,
|
||||
6.4336, 6.5703, 6.1758, 6.4219, 6.6016, 6.7383, 6.7070, 6.1328, 5.5586,
|
||||
6.6367, 6.3789, 6.2578, 5.5039, 6.6172, 6.4648, 5.8086, 7.2031, 5.8125,
|
||||
6.3711, 7.6758, 7.1289, 5.8086, 6.3008, 6.2109, 6.1602, 6.1797, 7.2305,
|
||||
6.7266, 6.2422, 5.6719, 6.7070, 6.9414, 6.8594, 7.4023, 7.2109, 6.0156,
|
||||
6.6680, 6.6172, 7.1250, 6.6523, 6.9531, 6.7617, 6.4961, 6.9414, 5.7188,
|
||||
7.6367, 6.5469, 6.2305, 6.4414, 7.4648, 5.9102, 6.2461, 6.1367, 6.8203,
|
||||
6.5703, 6.8867, 7.0000, 6.7539, 6.1719, 6.5469, 6.2422, 5.4297, 5.7305,
|
||||
5.1641, 6.1875, 7.0312, 6.6484, 6.0234, 7.4102, 6.8711, 6.3086, 6.3711,
|
||||
6.7344, 6.6992, 5.9766, 7.3906, 7.1875, 6.4883, 6.3984, 7.3438, 6.9688,
|
||||
6.9062, 6.4375, 6.7891, 7.0117, 6.4883, 5.7500, 7.0898, 7.0742, 6.7070,
|
||||
5.8750, 6.0469, 6.6445, 5.2773, 6.8984, 6.1641, 7.0508, 7.4609, 5.0273,
|
||||
6.7734, 6.4531, 5.7656, 6.5312, 7.4648, 6.1250, 6.5625, 7.1367, 6.0625,
|
||||
6.1211, 6.9766, 6.6758, 6.3164, 6.8828, 6.8203, 6.7500, 6.5352, 7.3008,
|
||||
6.7852, 6.1914, 5.0508, 6.7188, 7.1172, 6.8008, 6.8086, 5.4883, 6.9180,
|
||||
6.5742, 6.1719, 7.0469, 7.1523, 5.9492, 5.8594, 6.8320, 6.1719, 6.2031,
|
||||
6.8398, 7.3008, 6.6289, 6.4922, 6.0000, 5.4766, 6.3320, 6.5117, 6.2812,
|
||||
7.5742, 6.3516, 7.0039, 6.4570, 7.1523, 7.6289, 6.2578, 7.1875, 6.4844,
|
||||
5.7930, 6.7070, 7.5508, 7.1797, 6.0430, 6.8711, 6.5742, 7.5781, 6.4766,
|
||||
6.5391, 6.9453, 6.1992, 6.6367, 6.2812, 6.0234, 6.6953, 7.0312, 6.2031,
|
||||
6.5625, 6.6719, 6.1719, 6.5586, 5.7031, 7.4609, 6.6211, 7.7227, 6.9141,
|
||||
6.0469, 6.2500, 5.3828, 6.0078, 5.8164, 5.8867, 6.1523, 6.6523, 6.6953,
|
||||
7.3125, 6.4844, 5.9570, 5.9531, 6.2109, 5.5039, 6.5117, 6.8203, 6.6133,
|
||||
6.4766, 5.9297, 7.1445, 7.1914, 6.0117, 6.8281, 6.7422, 6.1328, 6.9805,
|
||||
6.5625, 6.9180, 7.1133, 7.3359, 5.7617, 5.8711, 6.4961, 6.5859, 6.2422,
|
||||
6.5273, 6.7461, 6.6992, 6.7695, 6.6289, 5.9453, 5.9805, 7.1172, 6.6719,
|
||||
6.0039, 7.6875, 6.7812, 7.8359, 6.9531, 7.4336, 7.6602, 6.8164, 7.3945,
|
||||
7.1602, 6.8789, 5.0078, 6.0547, 6.8086, 6.7070, 6.4688, 6.4492, 6.6172,
|
||||
5.5625, 6.6914, 6.4297, 5.7461, 5.3359, 6.8750, 6.4609, 7.4062, 5.2070,
|
||||
6.0820, 6.7383, 6.5703, 6.1797, 6.7070, 6.5977, 5.9961, 6.6328, 6.9375,
|
||||
6.3906, 6.6484, 4.9609, 6.6445, 6.5898, 7.1875, 7.5195, 6.7969, 6.1367,
|
||||
6.8906, 7.4297, 6.3633, 6.0508, 6.5000, 6.4648, 6.7539, 6.7109, 5.8086,
|
||||
6.6016, 7.1133, 4.8672, 6.6367, 6.1641, 5.1758, 6.9453, 6.3242, 7.0664,
|
||||
6.4805, 6.3516, 6.7383, 8.4688, 6.7305, 5.9844, 6.5938, 7.2969, 6.5977,
|
||||
7.5898, 6.2969, 6.8672, 6.6680, 7.1289, 6.6875, 5.4258, 8.1875, 8.0391,
|
||||
7.7969, 6.6445, 7.0703, 7.3359, 6.9805, 6.6328, 6.5352, 6.2422, 5.5820,
|
||||
6.8633, 6.8047, 6.5703, 6.0117, 6.7539, 7.1719, 6.8438, 7.3633, 6.6016,
|
||||
7.2070, 6.4727, 5.8008, 7.4062, 7.4805, 6.6445, 5.9023, 6.3984, 6.9961,
|
||||
6.6680, 6.8242, 6.7148, 6.6172, 6.9727, 6.8320, 5.9766, 6.6133, 5.5977,
|
||||
6.7773, 7.3906, 6.9219, 7.0781, 6.6914, 5.7539, 6.7969, 6.8008, 5.8047,
|
||||
7.1055, 6.4961, 6.0352, 5.6211, 7.4414, 7.0703, 6.1172, 6.7461, 6.4492,
|
||||
7.7148, 6.4258, 6.0039, 6.5156, 7.2188, 7.4531, 7.4844, 7.5938, 7.4023,
|
||||
6.7617, 6.0078, 6.3320, 5.8906, 7.5977, 5.6523, 6.7734, 6.3008, 5.2227,
|
||||
7.1719, 7.1289, 6.6602, 5.4609, 7.0312, 6.0820, 6.1719, 6.0000, 6.5547,
|
||||
6.6328, 7.0547, 7.0859, 6.2656, 5.5234, 6.0273, 6.7891, 7.1875, 6.9531,
|
||||
6.8203, 6.3516, 6.1172, 6.4648, 6.9180, 7.3906, 6.2812, 5.7109, 6.1484,
|
||||
6.9102, 6.8711, 7.0156, 6.1445, 5.8867, 6.3828, 5.9961, 6.6914, 6.7891,
|
||||
7.0820, 6.6719, 6.9297, 6.3750, 6.7578, 6.4883, 6.2227, 6.2305, 6.0508,
|
||||
6.6484, 5.7578, 7.2070, 7.2383, 6.9375, 7.2578, 6.5312, 6.0312, 6.7930,
|
||||
6.2578, 7.0625, 7.2148, 6.4961, 7.0703, 6.4727, 7.3906]).to(torch.float16)
|
||||
|
||||
def test_exllamav2(self):
|
||||
from auto_gptq.nn_modules.qlinear.qlinear_exllamav2 import QuantLinear
|
||||
|
||||
group_size = 128
|
||||
|
||||
m = 1
|
||||
k = 1024
|
||||
n = 1024
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
linear_class = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=4)
|
||||
|
||||
linear = linear_class(
|
||||
bits=4,
|
||||
group_size=group_size,
|
||||
infeatures=k,
|
||||
outfeatures=n,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.assertTrue(isinstance(linear, QuantLinear))
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
linear.qweight = torch.randint(-100, 100, size=linear.qweight.shape, dtype=torch.int32)
|
||||
linear.scales = linear.scales + 0.002
|
||||
|
||||
linear = linear.eval()
|
||||
linear = linear.to(device)
|
||||
|
||||
linear = autogptq_post_init(linear, use_act_order=False)
|
||||
|
||||
inp = torch.rand(1, m, k, dtype=torch.float16).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
res = linear(inp)[0][0]
|
||||
|
||||
reference = self.REFERENCE.to(device)
|
||||
|
||||
self.assertTrue(torch.allclose(res, reference, rtol=3e-5, atol=2e-2), get_diff(res, reference))
|
||||
|
||||
def test_generation_no_act_order(self):
|
||||
prompt = "I am in Paris and"
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
# Reference generated with the cuda-old kernel
|
||||
reference_output = "<s> I am in Paris and I am going to the Louvre Museum. What time does it open and what is the best way to get there?\nThe Louvre Museum in Paris is open from 9:00 AM to 6:00 PM every day except for Tuesdays. The best way to get"
|
||||
|
||||
model_id = "TheBloke/WizardLM-7B-uncensored-GPTQ"
|
||||
model_basename = "model"
|
||||
|
||||
model_q = AutoGPTQForCausalLM.from_quantized(model_id, device="cuda:0", use_triton=False, use_safetensors=True, model_basename=model_basename)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
inp = tokenizer(prompt, return_tensors="pt").to(device)
|
||||
|
||||
res = model_q.generate(**inp, num_beams=1, min_new_tokens=60, max_new_tokens=60)
|
||||
|
||||
predicted_text = tokenizer.decode(res[0])
|
||||
|
||||
|
||||
self.assertEqual(predicted_text, reference_output)
|
||||
|
||||
def test_generation_with_act_order(self):
|
||||
prompt = "I am in Paris and"
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
# Reference generated with the cuda-old kernel
|
||||
reference_output = "<s> I am in Paris and it is a beautiful day. I am sitting in a café, drinking coffee and writing this book. I am surrounded by the sights and sounds of the city, and I am filled with a sense of contentment and gratitude.\n\nI am grateful for the opportunity to live and"
|
||||
|
||||
model_id = "TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g"
|
||||
revision = "actorder"
|
||||
model_basename = "vicuna-13B-1.1-GPTQ-4bit-128g.latest"
|
||||
|
||||
model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=False, inject_fused_mlp=True, model_basename=model_basename)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
inp = tokenizer(prompt, return_tensors="pt").to(device)
|
||||
|
||||
res = model_q.generate(**inp, num_beams=1, min_new_tokens=60, max_new_tokens=60)
|
||||
|
||||
predicted_text = tokenizer.decode(res[0])
|
||||
|
||||
self.assertEqual(predicted_text, reference_output)
|
||||
|
||||
def test_exllama_buffer_size(self):
|
||||
# prompt = "I'm in Paris and" * 450
|
||||
prompt = "I'm in Paris and" * 1000
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
model_id = "TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g"
|
||||
revision = "actorder"
|
||||
model_basename = "vicuna-13B-1.1-GPTQ-4bit-128g.latest"
|
||||
|
||||
model_q = AutoGPTQForCausalLM.from_quantized(model_id, revision=revision, device="cuda:0", use_triton=False, use_safetensors=True, inject_fused_attention=True, inject_fused_mlp=True, model_basename=model_basename)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
inp = tokenizer(prompt, return_tensors="pt").to(device)
|
||||
|
||||
self.assertTrue(inp["input_ids"].shape[1] > 2048) # 2048 is the default max_input_length for LLama
|
||||
|
||||
res = model_q.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)
|
Loading…
Add table
Reference in a new issue