Merge remote-tracking branch 'qwopqwop200/main' into main
commit 6a9d80eddc
27 changed files with 2498 additions and 79 deletions
auto_gptq/modeling/_base.py
@@ -13,6 +13,7 @@ import torch.nn as nn
 import transformers
 from accelerate.hooks import remove_hook_from_module
 from safetensors.torch import save_file as safe_save
+from safetensors.torch import load_file as safe_load
 from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
 from transformers.utils.hub import PushToHubMixin, cached_file, create_repo, create_commit, CommitOperationAdd
 from transformers.utils.generic import ContextManagers
@@ -687,7 +688,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         device: Optional[Union[str, int]] = None,
         low_cpu_mem_usage: bool = False,
         use_triton: bool = False,
-        torch_dtype: torch.dtype = torch.float16,
+        use_qigen: bool = False,
+        torch_dtype: Optional[torch.dtype] = None,
         inject_fused_attention: bool = True,
         inject_fused_mlp: bool = True,
         use_cuda_fp16: bool = True,
@@ -725,6 +727,12 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             "_raise_exceptions_for_missing_entries": False,
             "_commit_hash": commit_hash,
         }
+        if use_qigen:
+            logger.warning("QIgen is active. Ignores all settings related to cuda.")
+            inject_fused_attention = False
+            inject_fused_mlp = False
+            use_triton = False
+            disable_exllama = True
 
         if use_triton and not TRITON_AVAILABLE:
             logger.warning("Triton is not installed, reset use_triton to False.")
@@ -802,18 +810,94 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         # == step2: convert model to gptq-model (replace Linear with QuantLinear) == #
         def skip(*args, **kwargs):
             pass
 
-        torch.nn.init.kaiming_uniform_ = skip
-        torch.nn.init.uniform_ = skip
-        torch.nn.init.normal_ = skip
-
-        transformers.modeling_utils._init_weights = False
-        init_contexts = [no_init_weights()]
-        if low_cpu_mem_usage:
-            init_contexts.append(accelerate.init_empty_weights(include_buffers=False))
-
-        with ContextManagers(init_contexts):
+        if torch_dtype is None:
+            if not use_qigen:
+                torch_dtype = torch.float16
+            else:
+                torch_dtype = torch.float32
+
+        if not use_qigen:
+            torch.nn.init.kaiming_uniform_ = skip
+            torch.nn.init.uniform_ = skip
+            torch.nn.init.normal_ = skip
+
+            transformers.modeling_utils._init_weights = False
+
+            init_contexts = [no_init_weights()]
+            if low_cpu_mem_usage:
+                init_contexts.append(accelerate.init_empty_weights(include_buffers=False))
+
+            with ContextManagers(init_contexts):
+                model = AutoModelForCausalLM.from_config(
+                    config,
+                    trust_remote_code=trust_remote_code,
+                    torch_dtype=torch_dtype
+                )
+
+                layers = find_layers(model)
+                ignore_layers = [cls.lm_head_name] + cls.outside_layer_modules
+                for name in list(layers.keys()):
+                    if any([name.startswith(ignore_layer) for ignore_layer in ignore_layers]):
+                        logger.info(f"{name} not been quantized, will be ignored when make_quant.")
+                        del layers[name]
+
+                make_quant(
+                    model,
+                    layers,
+                    quantize_config.bits,
+                    quantize_config.group_size,
+                    use_triton=use_triton,
+                    disable_exllama=disable_exllama,
+                    use_cuda_fp16=use_cuda_fp16,
+                    desc_act=quantize_config.desc_act,
+                    trainable=trainable
+                )
+                model.tie_weights()
+
+            # == step3: load checkpoint and dispatch == #
+            if isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
+                raise ValueError(
+                    "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
+                    "'sequential'."
+                )
+            if isinstance(device_map, dict):
+                max_memory = None
+            else:
+                if device is None and not device_map and not max_memory:
+                    device_map = "auto"
+                if device is not None:
+                    device = torch.device(device)
+                    if not max_memory and not device_map:
+                        device_map = {"": device.index if device.type == "cuda" else device.type}
+                if not isinstance(device_map, dict) and device_map != "sequential":
+                    max_memory = accelerate.utils.get_balanced_memory(
+                        model=model,
+                        max_memory=max_memory,
+                        no_split_module_classes=[cls.layer_type],
+                        low_zero=(device_map == "balanced_low_0")
+                    )
+            if not isinstance(device_map, dict):
+                device_map = accelerate.infer_auto_device_map(
+                    model,
+                    max_memory=max_memory,
+                    no_split_module_classes=[cls.layer_type]
+                )
+
+            if low_cpu_mem_usage:
+                make_sure_no_tensor_in_meta_device(model, use_triton, quantize_config.desc_act, quantize_config.group_size, bits=quantize_config.bits)
+
+            accelerate.utils.modeling.load_checkpoint_in_model(
+                model,
+                checkpoint=model_save_name,
+                device_map=device_map,
+                offload_state_dict=True,
+                offload_buffers=True
+            )
+            model = simple_dispatch_model(model, device_map)
+        else:
+            if quantize_config.desc_act:
+                NotImplementedError('desc_act=True is not yet supported.')
             model = AutoModelForCausalLM.from_config(
                 config,
                 trust_remote_code=trust_remote_code,
@@ -826,7 +910,11 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                 if any([name.startswith(ignore_layer) for ignore_layer in ignore_layers]):
                     logger.info(f"{name} not been quantized, will be ignored when make_quant.")
                     del layers[name]
 
+            if model_save_name.endswith('.safetensors'):
+                checkpoint = safe_load(model_save_name)
+            else:
+                checkpoint = torch.load(model_save_name)
             make_quant(
                 model,
                 layers,
@@ -836,52 +924,18 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                 disable_exllama=disable_exllama,
                 use_cuda_fp16=use_cuda_fp16,
                 desc_act=quantize_config.desc_act,
-                trainable=trainable
+                trainable=trainable,
+                use_qigen=True
             )
-            model.tie_weights()
-
-        # == step3: load checkpoint and dispatch == #
-        if isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
-            raise ValueError(
-                "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or "
-                "'sequential'."
-            )
-        if isinstance(device_map, dict):
-            max_memory = None
-        else:
-            if device is None and not device_map and not max_memory:
-                device_map = "auto"
-            if device is not None:
-                device = torch.device(device)
-                if not max_memory and not device_map:
-                    device_map = {"": device.index if device.type == "cuda" else device.type}
-            if not isinstance(device_map, dict) and device_map != "sequential":
-                max_memory = accelerate.utils.get_balanced_memory(
-                    model=model,
-                    max_memory=max_memory,
-                    no_split_module_classes=[cls.layer_type],
-                    low_zero=(device_map == "balanced_low_0")
-                )
-        if not isinstance(device_map, dict):
-            device_map = accelerate.infer_auto_device_map(
-                model,
-                max_memory=max_memory,
-                no_split_module_classes=[cls.layer_type]
-            )
-
-        if low_cpu_mem_usage:
-            make_sure_no_tensor_in_meta_device(model, use_triton, quantize_config.desc_act, quantize_config.group_size, bits=quantize_config.bits)
-
-        accelerate.utils.modeling.load_checkpoint_in_model(
-            model,
-            checkpoint=model_save_name,
-            device_map=device_map,
-            offload_state_dict=True,
-            offload_buffers=True
-        )
-        model = simple_dispatch_model(model, device_map)
-
-        # == step4: set seqlen == #
+            preprocess_checkpoint_qigen(
+                model,
+                layers,
+                quantize_config.bits,
+                quantize_config.group_size,
+                checkpoint
+            )
+            model.load_state_dict(checkpoint)
+
+        # == step4: set seqlen == #
         model_config = model.config.to_dict()
         seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
         if any([k in model_config for k in seq_len_keys]):
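Illustrative aside (not part of the commit): with the hunks above applied, `from_quantized` gains a CPU-only path. A minimal usage sketch, assuming the cQIGen extension was built by setup.py and that a 4-bit GPTQ checkpoint without act-order exists at the hypothetical path `./opt-125m-4bit`:

    # Sketch only: load a quantized model entirely on CPU through the QIGen kernels.
    from auto_gptq import AutoGPTQForCausalLM
    from transformers import AutoTokenizer

    model_dir = "./opt-125m-4bit"  # placeholder path, not shipped with this commit
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoGPTQForCausalLM.from_quantized(
        model_dir,
        device="cpu",
        use_qigen=True,  # new flag from this merge; forces triton/exllama/fused paths off
    )
    ids = tokenizer("The CPU kernel says", return_tensors="pt").input_ids
    print(tokenizer.decode(model.generate(ids, max_new_tokens=16)[0]))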
auto_gptq/modeling/_utils.py
@@ -6,11 +6,11 @@ import torch
 import torch.nn as nn
 from transformers import AutoConfig
 import transformers
+import cQIGen as qinfer
 
 from ._const import SUPPORTED_MODELS, CPU, CUDA_0, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
 from ..utils.import_utils import dynamically_import_QuantLinear
 
 logger = getLogger(__name__)
 
@@ -58,11 +58,12 @@ def make_quant(
     name='',
     use_triton: bool = False,
     disable_exllama: bool = False,
+    use_qigen: bool = False,
     use_cuda_fp16: bool = True,
     desc_act: bool = False,
     trainable: bool = False
 ):
-    QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
+    QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, use_qigen=use_qigen)
 
     if isinstance(module, QuantLinear):
         return
@@ -81,7 +82,7 @@ def make_quant(
         elif isinstance(tmp,transformers.pytorch_utils.Conv1D):
             in_features = tmp.weight.shape[0]
             out_features = tmp.weight.shape[1]
-        if (not(desc_act) or group_size == -1) and not use_triton:
+        if (not(desc_act) or group_size == -1) and not use_triton and not use_qigen:
             new_layer = QuantLinear(
                 bits, group_size, in_features, out_features, True, use_cuda_fp16=use_cuda_fp16, trainable=trainable
             )
@@ -101,8 +102,73 @@ def make_quant(
             desc_act=desc_act,
             trainable=trainable,
             disable_exllama=disable_exllama,
+            use_qigen=use_qigen
         )
 
+def process_zeros_scales(zeros, scales, bits, out_features):
+    if zeros.dtype != torch.float32:
+        new_zeros = torch.zeros_like(scales).float().contiguous()
+        if bits == 4:
+            qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
+        elif bits == 2:
+            qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
+        elif bits == 3:
+            logger.info("Unpacking zeros for 3 bits")
+        new_scales = scales.contiguous()
+    else:
+        if scales.shape[1] != out_features:
+            new_scales = scales.transpose(0,1).contiguous()
+        else:
+            new_scales = scales.contiguous()
+        if zeros.shape[1] != out_features:
+            new_zeros = zeros.transpose(0,1).contiguous()
+        else:
+            new_zeros = zeros.contiguous()
+
+    return new_zeros, new_scales
+
+def preprocess_checkpoint_qigen(
+    module,
+    names,
+    bits,
+    group_size,
+    checkpoint,
+    name='',
+):
+    QuantLinear = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=bits, disable_exllama=False, use_qigen=True)
+    if isinstance(module, QuantLinear):
+        in_features = module.infeatures
+        out_features = module.outfeatures
+
+        checkpoint[name + '.zeros'],checkpoint[name + '.scales'] = process_zeros_scales(checkpoint[name + '.qzeros'],checkpoint[name + '.scales'].float(), bits, out_features)
+        del checkpoint[name + '.qzeros']
+        del checkpoint[name + '.g_idx']
+        if name + '.bias' in checkpoint:
+            checkpoint[name + '.bias'] = checkpoint[name + '.bias'].float()
+        else:
+            checkpoint[name + '.bias'] = torch.zeros(out_features)
+        checkpoint_qweight = checkpoint[name + '.qweight'].int().contiguous()
+        if bits == 4:
+            qweight = torch.zeros(int(in_features // 8 * out_features)).int().contiguous()
+            qinfer.pack4(checkpoint_qweight, qweight, in_features // 8, out_features, module.mb, module.tb, module.cutoff)# * (module.tt//tb))
+        elif bits == 3:
+            qweight = torch.zeros(int(in_features // 32 * 3 * out_features)).int().contiguous()
+            qinfer.pack3(checkpoint_qweight, qweight, in_features // 32 * 3, out_features, module.mb // 32 * 3, module.tb, module.cutoff)
+        elif bits == 2:
+            qweight = torch.zeros(int(in_features // 16 * out_features)).int().contiguous()
+            qinfer.pack2(checkpoint_qweight, qweight, in_features // 16, out_features, module.mb, module.tb, module.cutoff)# * (module.tt//tb))
+        checkpoint[name + '.qweight'] = qweight
+        return
+
+    for name1, child in module.named_children():
+        preprocess_checkpoint_qigen(
+            child,
+            names,
+            bits,
+            group_size,
+            checkpoint,
+            name + '.' + name1 if name != '' else name1,
+        )
+
 def pack_model(
     model,
@@ -287,6 +353,7 @@ __all__ = [
     "get_module_by_name_prefix",
     "get_module_by_name_suffix",
     "make_quant",
+    "preprocess_checkpoint_qigen",
     "pack_model",
     "autogptq_post_init",
     "check_and_get_model_type",
auto_gptq/nn_modules/qlinear/qlinear_qigen.py (new file, 258 lines)
@@ -0,0 +1,258 @@
from copy import deepcopy
import torch
from torch import nn
from tqdm import tqdm
import gc

import cQIGen as qinfer
import math
import numpy as np
from gekko import GEKKO
from logging import getLogger

logger = getLogger(__name__)


def mem_model(N, M, T, mu, tu, bits, l1, p, gs):
    m = GEKKO() # create GEKKO model
    #cinfergen if bits==3:
    #    tu = tu*3
    B = m.Const(value=bits)
    TP = m.Const(value=T//p)
    k = m.Var(1,integer=True,lb=1)
    z = m.Var(1,integer=True,lb=1)
    w = m.Var(1,integer=True,lb=1)
    y = m.Var(1,integer=True,lb=1)
    v = m.Var(1,integer=True,lb=1)
    mb = m.Var(mu,integer=True,lb=1)
    if gs != -1:
        gg = m.Var(1,integer=True,lb=1)
    tb = m.Var(tu,integer=True,lb=1,ub=int(T/p))
    L = m.Var(integer=True,lb=0,ub=l1)
    m.Equation(L == 32 * mb * N + B * mb * tb + 32 * tb * N)
    m.Equation(mb * k == M)
    if gs != -1:
        m.Equation(gs * gg == mb)
    # m.Equation(tb * z == T)
    m.Equation(tb * z == TP)
    m.Equation(mu * w == mb)
    m.Equation(tu * y == tb)
    # m.Equation(tb * v == tt)
    m.Maximize(L)
    m.options.SOLVER = 1
    m.solver_options = ['minlp_maximum_iterations 1000', \
                        # minlp iterations with integer solution
                        'minlp_max_iter_with_int_sol 10', \
                        # treat minlp as nlp
                        'minlp_as_nlp 0', \
                        # nlp sub-problem max iterations
                        'nlp_maximum_iterations 100', \
                        # 1 = depth first, 2 = breadth first
                        'minlp_branch_method 2', \
                        # maximum deviation from whole number
                        'minlp_integer_tol 0.00', \
                        # covergence tolerance
                        'minlp_gap_tol 0.01']
    try:
        m.solve(disp=False)
    except:
        try:
            m.solver_options = ['minlp_maximum_iterations 1000', \
                                # minlp iterations with integer solution
                                'minlp_max_iter_with_int_sol 10', \
                                # treat minlp as nlp
                                'minlp_as_nlp 0', \
                                # nlp sub-problem max iterations
                                'nlp_maximum_iterations 100', \
                                # 1 = depth first, 2 = breadth first
                                'minlp_branch_method 1', \
                                # maximum deviation from whole number
                                'minlp_integer_tol 0.00', \
                                # covergence tolerance
                                'minlp_gap_tol 0.01']
            m.solve(disp=False)
        except:
            # mytb = T//p
            mytb = tu
            if gs != -1:
                mymb = gs
                while 32 * (mymb + gs) * N + bits * (mymb + gs) * mytb + 32 * mytb * N < l1:
                    mymb += gs
                while M % mymb != 0:
                    mymb -= gs
                return (int(mymb), int(mytb))
            else:
                mymb = mu
                while 32 * (mymb + mu) * N + bits * (mymb + mu) * mytb + 32 * mytb * N < l1:
                    mymb += mu
                while M % mymb != 0:
                    mymb -= mu
                return (int(mymb), int(mytb))

    return (int(mb.value[0]), int(tb.value[0]))

params = {}

def compute_reductions(x, gs=-1, cpp=True):
    if cpp:
        if len(x.shape) != 1:
            rows, cols = x.shape
        else:
            rows = 1
            cols = x.shape[0]
        if gs == -1:
            out = torch.zeros(rows).float().contiguous()
            mygs = cols
        else:
            out = torch.zeros(rows, cols // gs).float().contiguous()
            mygs = gs

        qinfer.compute_reduction_cpp(x, out, rows, cols, mygs)
        return out
    if gs == -1:
        if len(x.shape) != 1:
            return torch.sum(x,1)
        else:
            return torch.sum(x)
    else:
        if len(x.shape) != 1:
            rows, cols = x.shape
            out = torch.zeros(rows, cols // gs).float().contiguous()
            for i in range(cols // gs):
                out[:,i] = torch.sum(x[:,i*gs:(i+1)*gs],1)
            return out
        else:
            cols = x.shape[0]
            out = torch.zeros(cols // gs).float().contiguous()
            for i in range(cols // gs):
                out[i] = torch.sum(x[i*gs:(i+1)*gs])
            return out

def process_zeros_scales(zeros, scales, bits, M):
    if zeros.dtype != torch.float32:
        new_zeros = torch.zeros_like(scales).float().contiguous()
        if bits == 4:
            qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
        elif bits == 2:
            qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
        elif bits == 3:
            logger.info("Unpacking zeros for 3 bits")
        new_scales = scales.contiguous()
    else:
        if scales.shape[1] != M:
            new_scales = scales.transpose(0,1).contiguous()
        else:
            new_scales = scales.contiguous()
        if zeros.shape[1] != M:
            new_zeros = zeros.transpose(0,1).contiguous()
        else:
            new_zeros = zeros.contiguous()

    return new_zeros, new_scales

class QuantLinear(nn.Module):
    QUANT_TYPE = "qigen"

    def __init__(self, bits, group_size, infeatures, outfeatures, bias=None, trainable=False, hint=1, p=8, l1=2**18):
        super().__init__()
        if bits not in [2, 4]:
            raise NotImplementedError("Only 2,4 bits are supported.")
        if trainable:
            raise NotImplementedError("Qigen kernel does not support training.")
        self.bits = bits
        pack = 32 // bits

        self.infeatures = infeatures
        self.outfeatures = outfeatures

        n = hint
        m = self.infeatures
        t = self.outfeatures

        #registers for now are fixed
        if bits == 3:
            packed = 32
            unroll = 3
            nu = 1 #args.n
            mu = 32
            tu = 32
        else:
            packed = 32 // bits
            unroll = 2
            nu = 1 #args.n
            mu = 16
            tu = 32

        nb = n # it's always small for transformers

        global params
        if (m,t) in params:
            mb = params[(m,t)][0]
            tb = params[(m,t)][1]
        else:
            mb, tb = mem_model(n, m, t, mu, tu, bits, l1, p, group_size)
            params[(m,t)] = (mb,tb)

        split = np.ones(p)
        split = split * tb
        while np.sum(split) < t:
            split = split + tb

        idx = p - 1
        while np.sum(split) > t:
            split[idx] = split[idx] - tb
            idx = idx - 1

        assert(np.sum(split) == t)

        split = split.astype(int)
        self.tt = int(split[0])

        if split[0] == split[-1]:
            self.cutoff = int(p+1)
        else:
            self.cutoff = int(idx + 1)

        self.mb = mb #// packed
        self.tb = tb

        self.group_size = group_size

        self.register_buffer('bias', torch.zeros(self.outfeatures))
        self.register_buffer('zeros', torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float32))
        self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float32))
        if bits == 4:
            self.register_buffer('qweight', torch.zeros(int(self.infeatures // packed * self.outfeatures)).int().contiguous())
        elif bits == 3:
            self.register_buffer('qweight', torch.zeros(int(self.infeatures // packed * 3 * self.outfeatures)).int().contiguous())
        elif bits == 2:
            self.register_buffer('qweight', torch.zeros(int(self.infeatures // packed * self.outfeatures)).int().contiguous())

    def forward(self, x):
        out_shape = x.shape[:-1] + (self.outfeatures,)
        x = x.reshape((-1, x.shape[-1])).to(torch.float32)
        B = x.shape[0]
        new_x = x.T.contiguous()
        out = torch.zeros((B, self.outfeatures), dtype=torch.float32)
        sums = compute_reductions(x,gs=self.group_size,cpp=True).contiguous()
        if self.group_size == -1:
            if self.bits == 4:
                qinfer.forward4(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
                                B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.cutoff)
            elif self.bits == 2:
                qinfer.forward2(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
                                B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.cutoff)
            elif self.bits == 3:
                qinfer.forward3(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
                                B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.cutoff)
        else:
            if self.bits == 4:
                qinfer.forward_gs4(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
                                   B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.group_size, self.cutoff)
            elif self.bits == 2:
                qinfer.forward_gs2(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
                                   B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.group_size, self.cutoff)
            elif self.bits == 3:
                qinfer.forward_gs3(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
                                   B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.group_size, self.cutoff)
        return out.reshape(out_shape)
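Illustrative aside (not part of the commit): a small sketch of driving the new layer in isolation. It assumes the cQIGen pybind module has been compiled and that the layer's buffers have already been filled by `preprocess_checkpoint_qigen`; a freshly constructed layer only holds zero-initialized buffers, so the printed output is only meaningful for shapes.

    import torch
    from auto_gptq.nn_modules.qlinear.qlinear_qigen import QuantLinear

    # 4-bit layer with group size 128; __init__ runs the GEKKO memory model to pick mb/tb.
    layer = QuantLinear(bits=4, group_size=128, infeatures=4096, outfeatures=4096)
    x = torch.randn(2, 4096)   # forward casts inputs to float32 internally
    y = layer(x)               # dispatches to qinfer.forward_gs4 under the hood
    print(y.shape)             # torch.Size([2, 4096])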
auto_gptq/utils/import_utils.py
@@ -28,19 +28,22 @@ except:
 logger = getLogger(__name__)
 
 
-def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool = False):
-    if use_triton:
-        if torch.version.hip:
-            logger.warning("Running GPTQ triton version on AMD GPUs is untested and may result in errors or wrong predictions. Please use use_triton=False.")
-
-        from ..nn_modules.qlinear.qlinear_triton import QuantLinear
+def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool = False, use_qigen: bool = False):
+    if use_qigen:
+        from ..nn_modules.qlinear.qlinear_qigen import QuantLinear
     else:
-        if bits == 4 and not disable_exllama and EXLLAMA_KERNELS_AVAILABLE:
-            from ..nn_modules.qlinear.qlinear_exllama import QuantLinear
-        elif not desc_act or group_size == -1:
-            from ..nn_modules.qlinear.qlinear_cuda_old import QuantLinear
+        if use_triton:
+            if torch.version.hip:
+                logger.warning("Running GPTQ triton version on AMD GPUs is untested and may result in errors or wrong predictions. Please use use_triton=False.")
+
+            from ..nn_modules.qlinear.qlinear_triton import QuantLinear
         else:
-            from ..nn_modules.qlinear.qlinear_cuda import QuantLinear
+            if bits == 4 and not disable_exllama and EXLLAMA_KERNELS_AVAILABLE:
+                from ..nn_modules.qlinear.qlinear_exllama import QuantLinear
+            elif not desc_act or group_size == -1:
+                from ..nn_modules.qlinear.qlinear_cuda_old import QuantLinear
+            else:
+                from ..nn_modules.qlinear.qlinear_cuda import QuantLinear
 
     return QuantLinear
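Illustrative aside (not part of the commit): after this change the resolver checks `use_qigen` first, then Triton, then the exllama/CUDA kernels. A quick sketch of how it behaves for two flag combinations, assuming the corresponding kernel modules are importable in the current environment:

    from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

    # CPU path added by this merge.
    qigen_cls = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=128, bits=4, use_qigen=True)

    # Previous behaviour is unchanged when use_qigen is left at its default.
    gpu_cls = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=128, bits=4, disable_exllama=True)

    print(qigen_cls.QUANT_TYPE)  # "qigen"
    print(gpu_cls.QUANT_TYPE)    # one of the CUDA variants, depending on what is installed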
autogptq_extension/qigen/generate.py (new file, 1484 lines)
File diff suppressed because it is too large.
autogptq_extension/qigen/intrin.py (new file, 149 lines)
@@ -0,0 +1,149 @@

def load_int(to, address, const=True):
    if const:
        return f"const __m256i {to} = _mm256_loadu_si256({address});"
    else:
        return f"__m256i {to} = _mm256_loadu_si256({address});"

def load_fp(to, address, const=True):
    if const:
        return f"const __m256 {to} = _mm256_loadu_ps({address});"
    else:
        return f"__m256 {to} = _mm256_loadu_ps({address});"

# to = a * b + c
def vfma(to, a, b, c):
    return f"__m256 {to} = _mm256_fmadd_ps({a}, {b}, {c});"

def vsrli(to, a, b):
    return f"const __m256i {to} = _mm256_srli_epi32({a}, {b});"

def vand(to, a, b):
    return f"const __m256i {to} = _mm256_and_si256({a}, {b});"

def vbroadcast_fp(to, a):
    return f"const __m256 {to} = _mm256_set1_ps({a});"

def vbroadcast_int32(to, a):
    return f"__m256i {to} = _mm256_set1_epi32({a});"

def vsetzero(to):
    return f"__m256 {to} = _mm256_setzero_ps();"

def vcvtepi32_ps(to, a):
    return f"const __m256 {to} = _mm256_cvtepi32_ps({a});"

def _256extractf128_ps(to, a, imm):
    return f"const __m128 {to} = _mm256_extractf128_ps({a}, {imm});"

def _256castps256_ps128(to, a):
    return f"const __m128 {to} = _mm256_castps256_ps128({a});"

def _add_ps(to, a, b):
    return f"const __m128 {to} = _mm_add_ps({a}, {b});"

def _movehl_ps(to, a, b):
    return f"const __m128 {to} = _mm_movehl_ps({a}, {b});"

def _shuffle_ps(to, a, b, imm):
    return f"const __m128 {to} = _mm_shuffle_ps({a}, {b}, {imm});"

def _cvtss_f32(to, a):
    return f"const float {to} = _mm_cvtss_f32({a});"

def _reduce8_acc(a, b, c, d, e, f, g, h):
    res = ""
    res += _256extractf128_ps("hi_quad0", a, 1)
    res += _256extractf128_ps("hi_quad1", b, 1)
    res += _256extractf128_ps("hi_quad2", c, 1)
    res += _256extractf128_ps("hi_quad3", d, 1)
    res += _256extractf128_ps("hi_quad4", e, 1)
    res += _256extractf128_ps("hi_quad5", f, 1)
    res += _256extractf128_ps("hi_quad6", g, 1)
    res += _256extractf128_ps("hi_quad7", h, 1)

    res += _256castps256_ps128("lo_quad0", a)
    res += _256castps256_ps128("lo_quad1", b)
    res += _256castps256_ps128("lo_quad2", c)
    res += _256castps256_ps128("lo_quad3", d)
    res += _256castps256_ps128("lo_quad4", e)
    res += _256castps256_ps128("lo_quad5", f)
    res += _256castps256_ps128("lo_quad6", g)
    res += _256castps256_ps128("lo_quad7", h)

    res += _add_ps("sum_quad0", "lo_quad0", "hi_quad0")
    res += _add_ps("sum_quad1", "lo_quad1", "hi_quad1")
    res += _add_ps("sum_quad2", "lo_quad2", "hi_quad2")
    res += _add_ps("sum_quad3", "lo_quad3", "hi_quad3")
    res += _add_ps("sum_quad4", "lo_quad4", "hi_quad4")
    res += _add_ps("sum_quad5", "lo_quad5", "hi_quad5")
    res += _add_ps("sum_quad6", "lo_quad6", "hi_quad6")
    res += _add_ps("sum_quad7", "lo_quad7", "hi_quad7")

    res += _movehl_ps("hi_dual0", "sum_quad0", "sum_quad0")
    res += _movehl_ps("hi_dual1", "sum_quad1", "sum_quad1")
    res += _movehl_ps("hi_dual2", "sum_quad2", "sum_quad2")
    res += _movehl_ps("hi_dual3", "sum_quad3", "sum_quad3")
    res += _movehl_ps("hi_dual4", "sum_quad4", "sum_quad4")
    res += _movehl_ps("hi_dual5", "sum_quad5", "sum_quad5")
    res += _movehl_ps("hi_dual6", "sum_quad6", "sum_quad6")
    res += _movehl_ps("hi_dual7", "sum_quad7", "sum_quad7")

    res += _add_ps("sum_dual0", "sum_quad0", "hi_dual0")
    res += _add_ps("sum_dual1", "sum_quad1", "hi_dual1")
    res += _add_ps("sum_dual2", "sum_quad2", "hi_dual2")
    res += _add_ps("sum_dual3", "sum_quad3", "hi_dual3")
    res += _add_ps("sum_dual4", "sum_quad4", "hi_dual4")
    res += _add_ps("sum_dual5", "sum_quad5", "hi_dual5")
    res += _add_ps("sum_dual6", "sum_quad6", "hi_dual6")
    res += _add_ps("sum_dual7", "sum_quad7", "hi_dual7")

    res += _shuffle_ps("hi0", "sum_dual0", "sum_dual0", 0x1)
    res += _shuffle_ps("hi1", "sum_dual1", "sum_dual1", 0x1)
    res += _shuffle_ps("hi2", "sum_dual2", "sum_dual2", 0x1)
    res += _shuffle_ps("hi3", "sum_dual3", "sum_dual3", 0x1)
    res += _shuffle_ps("hi4", "sum_dual4", "sum_dual4", 0x1)
    res += _shuffle_ps("hi5", "sum_dual5", "sum_dual5", 0x1)
    res += _shuffle_ps("hi6", "sum_dual6", "sum_dual6", 0x1)
    res += _shuffle_ps("hi7", "sum_dual7", "sum_dual7", 0x1)

    res += _add_ps("sum0", "sum_dual0", "hi0")
    res += _add_ps("sum1", "sum_dual1", "hi1")
    res += _add_ps("sum2", "sum_dual2", "hi2")
    res += _add_ps("sum3", "sum_dual3", "hi3")
    res += _add_ps("sum4", "sum_dual4", "hi4")
    res += _add_ps("sum5", "sum_dual5", "hi5")
    res += _add_ps("sum6", "sum_dual6", "hi6")
    res += _add_ps("sum7", "sum_dual7", "hi7")

    res += _cvtss_f32(f"f{a}", "sum0")
    res += _cvtss_f32(f"f{b}", "sum1")
    res += _cvtss_f32(f"f{c}", "sum2")
    res += _cvtss_f32(f"f{d}", "sum3")
    res += _cvtss_f32(f"f{e}", "sum4")
    res += _cvtss_f32(f"f{f}", "sum5")
    res += _cvtss_f32(f"f{g}", "sum6")
    res += _cvtss_f32(f"f{h}", "sum7")

    return res

acc_idx = 0
def _reduce_add(a):
    global acc_idx
    res = ""
    res += _256extractf128_ps(f"hi_quad{acc_idx}", a, 1)
    res += _256castps256_ps128(f"lo_quad{acc_idx}", a)
    res += _add_ps(f"sum_quad{acc_idx}", f"lo_quad{acc_idx}", f"hi_quad{acc_idx}")
    res += _movehl_ps(f"hi_dual{acc_idx}", f"sum_quad{acc_idx}", f"sum_quad{acc_idx}")
    res += _add_ps(f"sum_dual{acc_idx}", f"sum_quad{acc_idx}", f"hi_dual{acc_idx}")
    res += _shuffle_ps(f"hi{acc_idx}", f"sum_dual{acc_idx}", f"sum_dual{acc_idx}", 0x1)
    res += _add_ps(f"sum{acc_idx}", f"sum_dual{acc_idx}", f"hi{acc_idx}")
    res += _cvtss_f32(f"f{a}", f"sum{acc_idx}")
    acc_idx += 1
    return res
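Illustrative aside (not part of the commit): these helpers only emit C source strings; generate.py stitches them into kernels. A tiny sketch of what a couple of calls return, assuming it is run from the autogptq_extension/qigen directory so the module imports directly; the commented strings are simply what the f-strings above evaluate to:

    from intrin import load_fp, vfma, _reduce_add  # run from autogptq_extension/qigen/

    print(load_fp("w", "&weights[i]"))
    # const __m256 w = _mm256_loadu_ps(&weights[i]);
    print(vfma("acc0", "w", "x", "acc0"))
    # __m256 acc0 = _mm256_fmadd_ps(w, x, acc0);
    print(_reduce_add("acc0"))
    # emits the extract/add/shuffle sequence that ends in _mm_cvtss_f32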
autogptq_extension/qigen/mmm.cpp (new file, 302 lines)
@@ -0,0 +1,302 @@
#include <iostream>
#include "forward.h"
#include <cstring>
#include <algorithm>
#include <vector>
#include <chrono>
#include <fstream>

#define mymin(a,b) ((a)<(b)?(a):(b))
#define mymax(a,b) ((a)>(b)?(a):(b))

void print_matrix(std::string name, float* A, int N, int M){
    std::cout<<name<<std::endl;
    for(int i = 0; i < N; i++){
        for(int j = 0; j < M; j++){
            std::cout << A[i*M+j] << " ";
        }
        std::cout << std::endl;
    }
    std::cout<<std::endl;
}

void oracle_mmadd(float* A, float* B, float* bias, float* C, int n, int m, int t){
    // triple loop matmul and add bias
    for (int i = 0; i < n; i++){
        for (int j = 0; j < t; j++){
            float sum = 0;
            for (int k = 0; k < m; k++){
                sum += A[i*m+k] * B[k*t+j];
            }
            C[i*t+j] += sum + bias[j];
        }
    }
}

void compute_reduction(float *in, float *out, int n, int m, int gs){
    int ng;
    if(gs == -1){
        ng = 1;
        gs = m;
    }else{
        ng = m/gs;
    }
    for(int i = 0; i < n; i++){
        for(int j0 = 0; j0 < m; j0+=gs){
            int j = j0/gs;
            out[i*ng+j] = 0;
            for(int j1 = j0; j1 < j0+gs; j1++){
                out[i*ng+j] += in[i*m+j1];
            }
        }
    }
}

void quantize_sim(float* A, float* BQ, float* scales, float* zeros, int n, int m, int bits, int gs){
    //find scales and zeros arrays
    if(gs == -1){
        gs = n;
    }
    float range = (1<<bits) - 1;
    int packed = 32 / bits;

    for(int i0 = 0; i0 < n; i0+=gs){
        int row = i0/gs;
        for(int j = 0; j < m; j++){
            float min = A[i0*m + j];
            float max = A[i0*m + j];
            for(int i1 = i0; i1 < i0+gs; i1++){
                min = mymin(min, A[i1*m+j]);
                max = mymax(max, A[i1*m+j]);
            }
            scales[row*m + j] = (max-min)/range;
            zeros[row*m + j ] = min;
        }
        for(int j = 0; j < m; j++){
            for (int i1 = i0; i1 < i0+gs; i1++){
                uint32_t acc = 0;
                int temp = (A[i1*m+j] - zeros[row*m+j])/scales[row*m+j];
                float val = ((float) temp + zeros[row*m+j]) * scales[row*m+j];
                BQ[i1*m+j] = val;
            }
        }
    }

}

void quantize(float* A, int* BQ, float* scales, float* zeros, int n, int m, int bits, int gs){
    //find scales and zeros arrays
    if(gs == -1){
        gs = n;
    }
    float range = (1<<bits) - 1;
    int packed = 32 / bits;

    for(int i0 = 0; i0 < n; i0+=gs){
        int row = i0/gs;
        for(int j = 0; j < m; j++){
            float min = A[i0*m + j];
            float max = A[i0*m + j];
            for(int i1 = i0; i1 < i0+gs; i1++){
                min = mymin(min, A[i1*m+j]);
                max = mymax(max, A[i1*m+j]);
            }
            scales[row*m + j] = (max-min)/range;
            zeros[row*m + j ] = min;
        }
        for(int j = 0; j < m; j++){
            if(bits == 3){
                for (int i1 = i0; i1 < i0+gs; i1+=32){
                    uint32_t acc = 0;
                    int temp0 = ((int)((A[(i1+0)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 0;
                    int temp1 = ((int)((A[(i1+1)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 3;
                    int temp2 = ((int)((A[(i1+2)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 6;
                    int temp3 = ((int)((A[(i1+3)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 9;
                    int temp4 = ((int)((A[(i1+4)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 12;
                    int temp5 = ((int)((A[(i1+5)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 15;
                    int temp6 = ((int)((A[(i1+6)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 18;
                    int temp7 = ((int)((A[(i1+7)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 21;
                    int temp8 = ((int)((A[(i1+8)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 24;
                    int temp9 = ((int)((A[(i1+9)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 27;
                    int temp10_0 = ((int)((A[(i1+10)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 30;
                    int temp10_1 = ((int)((A[(i1+10)*m+j] - zeros[row*m+j])/scales[row*m+j])) >> 2;
                    int temp11 = ((int)((A[(i1+11)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 1;
                    int temp12 = ((int)((A[(i1+12)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 4;
                    int temp13 = ((int)((A[(i1+13)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 7;
                    int temp14 = ((int)((A[(i1+14)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 10;
                    int temp15 = ((int)((A[(i1+15)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 13;
                    int temp16 = ((int)((A[(i1+16)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 16;
                    int temp17 = ((int)((A[(i1+17)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 19;
                    int temp18 = ((int)((A[(i1+18)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 22;
                    int temp19 = ((int)((A[(i1+19)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 25;
                    int temp20 = ((int)((A[(i1+20)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 28;
                    int temp21_0 = ((int)((A[(i1+21)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 31;
                    int temp21_1 = ((int)((A[(i1+21)*m+j] - zeros[row*m+j])/scales[row*m+j])) >> 1;
                    int temp22 = ((int)((A[(i1+22)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 2;
                    int temp23 = ((int)((A[(i1+23)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 5;
                    int temp24 = ((int)((A[(i1+24)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 8;
                    int temp25 = ((int)((A[(i1+25)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 11;
                    int temp26 = ((int)((A[(i1+26)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 14;
                    int temp27 = ((int)((A[(i1+27)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 17;
                    int temp28 = ((int)((A[(i1+28)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 20;
                    int temp29 = ((int)((A[(i1+29)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 23;
                    int temp30 = ((int)((A[(i1+30)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 26;
                    int temp31 = ((int)((A[(i1+31)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 29;

                    int acc0 = 0, acc1 = 0, acc2 = 0;

                    acc0 |= temp0;
                    acc0 |= temp1;
                    acc0 |= temp2;
                    acc0 |= temp3;
                    acc0 |= temp4;
                    acc0 |= temp5;
                    acc0 |= temp6;
                    acc0 |= temp7;
                    acc0 |= temp8;
                    acc0 |= temp9;
                    acc0 |= temp10_0;

                    acc1 |= temp10_1;
                    acc1 |= temp11;
                    acc1 |= temp12;
                    acc1 |= temp13;
                    acc1 |= temp14;
                    acc1 |= temp15;
                    acc1 |= temp16;
                    acc1 |= temp17;
                    acc1 |= temp18;
                    acc1 |= temp19;
                    acc1 |= temp20;
                    acc1 |= temp21_0;

                    acc2 |= temp21_1;
                    acc2 |= temp22;
                    acc2 |= temp23;
                    acc2 |= temp24;
                    acc2 |= temp25;
                    acc2 |= temp26;
                    acc2 |= temp27;
                    acc2 |= temp28;
                    acc2 |= temp29;
                    acc2 |= temp30;
                    acc2 |= temp31;

                    BQ[(3*i1/32)*m+j] = acc0;
                    BQ[(3*i1/32+1)*m+j] = acc1;
                    BQ[(3*i1/32+2)*m+j] = acc2;
                }

            }else{
                for (int i1 = i0; i1 < i0+gs; i1+=packed){
                    uint32_t acc = 0;
                    for (int i2 = i1; i2 < i1+packed; i2++){
                        int temp = (A[i2*m+j] - zeros[row*m+j])/scales[row*m+j];
                        acc = acc | (temp << (bits*(i2-i1)));
                    }
                    BQ[(i1/packed)*m+j] = acc;
                }
            }
        }
    }

}

int main(int argc, char *argv[]){
    // read n m t from args
    if(argc == 0){std::cout << "Parameters not given\n"; return 0;}
    int n = atoi(argv[1]);
    int m = atoi(argv[2]);
    int t = atoi(argv[3]);
    int bits = atoi(argv[4]);
    int gs = atoi(argv[5]);
    int ng;
    if(gs == -1){
        ng = 1;
    }else{
        ng = m/gs;
    }
    float* A = new float[n*m];
    float* AB = new float[n*m];
    float* B = new float[m*t];
    float* BQS = new float[m*t];
    float* scales = new float[t*ng];
    float* zeros = new float[t*ng];
    int* BQ = new int[m*t/8];
    int* BQB = new int[m*t/8];
    float* sums = new float[n*ng];
    float* bias = new float[t];
    float* C = new float[n*t];
    float* CB = new float[n*t];
    float* C2 = new float[n*t];
    srand(1);
    for (int i = 0; i < n*m; i++){
        A[i] = (float)rand() / RAND_MAX;
    }
    for (int i = 0; i < t*m; i++){
        B[i] = (float)rand() / RAND_MAX;
    }
    for (int i = 0; i < t; i++){
        bias[i] = (float)rand() / RAND_MAX;
    }
    for (int i = 0; i < n*t; i++){
        C[i] = 0.0;
        C2[i] = 0.0;
    }
    quantize_sim(B,BQS,scales,zeros,m,t,bits,gs);
    quantize(B,BQ,scales,zeros,m,t,bits,gs);

    quantize_sim(B,BQS,scales,zeros,m,t,bits,gs);
    quantize(B,BQ,scales,zeros,m,t,bits,gs);
    oracle_mmadd(A, BQS, bias, C, n, m, t);
    pack_input(A,AB);
    pack_qw(BQ,BQB);
    pack_output(C,CB);

    compute_reduction(A,sums,n,m,gs);
    qforward(AB,BQB,scales,zeros,bias,sums,C2,n,m,t);

    float norm = 0.0;
    for (int i = 0; i < n*t; i++){
        norm += (C[i] - C2[i]) * (C[i] - C2[i]);
    }
    if(norm / (n*t) < 0.0001){
        int iter = 30;
        for(int _ = 0; _ < iter; _++){
            qforward(AB,BQB,scales,zeros,bias,sums,C2,n,m,t);
        }

        int num_runs = 15;
        std::vector<long int> runs(num_runs);
        for(int r = 0; r < num_runs; r++){
            auto start = std::chrono::high_resolution_clock::now();
            for(int _ = 0; _ < iter; _++){
                qforward(AB,BQB,scales,zeros,bias,sums,C2,n,m,t);
            }
            auto end = std::chrono::high_resolution_clock::now();
            runs[r] = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count();

        }

        std::sort(runs.begin(), runs.end());

        float cycles_final = runs[num_runs/2 + 1] / iter;

        std::ofstream outfile;
        outfile.open("./autogptq_extension/qigen/tmp.csv", std::ios_base::app);

        print_parameters();
        outfile << cycles_final << std::endl;
    }else{
        float cycles_final = int(10e12);

        std::ofstream outfile;
        outfile.open("./autogptq_extension/qigen/tmp.csv", std::ios_base::app);

        print_parameters();
        outfile << cycles_final << std::endl;
    }

    return 0;
}
autogptq_extension/qigen/template.py (new file, 85 lines)
@@ -0,0 +1,85 @@

def includes():
    out = " \
#include <torch/all.h>\n \
#include <torch/python.h>\n \
#include <omp.h>\n \
#include <cmath>\n \
#include <immintrin.h>\n \
\n \
#define mymin(a,b) ((a)<(b)?(a):(b))\n \
#define mymax(a,b) ((a)>(b)?(a):(b))\n \
"
    return out


def module(bits_list=[4, 2]):
    out = 'PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n'
    for bits in bits_list:
        out += ' m.def("forward{}", &forward{}_cpu);\n'.format(bits, bits)

    for bits in bits_list:
        out += ' m.def("unpack_zeros{}", &unpack_zeros{});\n'.format(bits, bits)

    for bits in bits_list:
        out += ' m.def("forward_gs{}", &forward{}_gs_cpu);\n'.format(bits, bits)

    for bits in bits_list:
        out += ' m.def("pack{}", &pack{}_w_cpu);\n'.format(bits, bits)

    out += 'm.def("compute_reduction_cpp", &compute_reduction);\n'
    out += 'm.def("unquantize_sim", &unquantize_sim);\n'

    # if oracle:
    #     out += ' m.def("forward4_oracle", &forward4_oracle_cpu);\n'

    out += 'm.def("quant_scalar_scaled", &quant_scalar_cpu);\n'

    out += '}\n'
    return out

def quant_scalar():
    out = " \
void quantize_scalar(float* A, int* BQ, float* scales, float* zeros, int n, int m, int bits){ \n \
    //find scales and zeros arrays \n \
    //quantize \n \
    int pack = 32/bits;\n \
    for (int j = 0; j < m; j++){\n \
        for (int i = 0; i < n; i+=pack){\n \
            uint32_t acc = 0;\n \
            for (int ii = i; ii < i+pack; ii++){\n \
                float ftemp = std::round((A[ii*m+j] + zeros[j])/scales[j]);\n \
                int temp = (int)ftemp;\n \
                acc = acc | (temp << (bits*(ii-i)));\n \
            }\n \
            BQ[(i/pack)*m+j] = acc;\n \
            //BQ[0] = acc;\n \
        }\n \
    }\n \
}\n \
\n \
void quant_scalar_cpu(\n \
    torch::Tensor in, torch::Tensor out, \n \
    torch::Tensor scales, torch::Tensor zeros, int bits\n \
) {\n \
\n \
    int N = in.size(0);\n \
    int M = in.size(1);\n \
\n \
    float* input = in.data_ptr<float>(); \n \
    float* s = scales.data_ptr<float>();\n \
    float* z = zeros.data_ptr<float>();\n \
    int* O = out.data_ptr<int>();\n \
\n \
    quantize_scalar(input, O, s, z, N, M, bits);\n \
\n \
}\n"

    return out
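Illustrative aside (not part of the commit): `module()` just renders the pybind11 registration block as a string, which generate.py writes into the extension source. A short sketch of inspecting that string, assuming it is run from the autogptq_extension/qigen directory and the default `bits_list=[4, 2]`:

    from template import module  # run from autogptq_extension/qigen/

    src = module([4, 2])
    print(src.splitlines()[0])                          # PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    print(' m.def("forward4", &forward4_cpu);' in src)  # True
    print(' m.def("pack2", &pack2_w_cpu);' in src)      # True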
setup.py
@@ -1,8 +1,12 @@
 import os
 import sys
 from pathlib import Path
-from setuptools import setup, find_packages
+from setuptools import setup, Extension, find_packages
+import subprocess
+import math
+
+os.environ["CC"] = "g++"
+os.environ["CXX"] = "g++"
 
 common_setup_kwargs = {
     "version": "0.4.1",
@@ -69,12 +73,15 @@ if BUILD_CUDA_EXT:
 requirements = [
     "accelerate>=0.19.0",
     "datasets",
+    "sentencepiece",
     "numpy",
     "rouge",
+    "gekko",
     "torch>=1.13.0",
     "safetensors",
     "transformers>=4.31.0",
-    "peft"
+    "peft",
+    "tqdm",
 ]
 
 extras_require = {
@@ -88,6 +95,9 @@ additional_setup_kwargs = dict()
 if BUILD_CUDA_EXT:
     from torch.utils import cpp_extension
+
+    p = int(subprocess.run("cat /proc/cpuinfo | grep cores | head -1", shell=True, check=True, text=True, stdout=subprocess.PIPE).stdout.split(" ")[2])
+    subprocess.call(["python", "./autogptq_extension/qigen/generate.py", "--module", "--search", "--p", str(p)])
     if not ROCM_VERSION:
         from distutils.sysconfig import get_python_lib
         conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
@@ -100,16 +110,23 @@ if BUILD_CUDA_EXT:
         cpp_extension.CUDAExtension(
             "autogptq_cuda_64",
             [
-                "autogptq_cuda/autogptq_cuda_64.cpp",
-                "autogptq_cuda/autogptq_cuda_kernel_64.cu"
+                "autogptq_extension/cuda_64/autogptq_cuda_64.cpp",
+                "autogptq_extension/cuda_64/autogptq_cuda_kernel_64.cu"
             ]
         ),
         cpp_extension.CUDAExtension(
             "autogptq_cuda_256",
             [
-                "autogptq_cuda/autogptq_cuda_256.cpp",
-                "autogptq_cuda/autogptq_cuda_kernel_256.cu"
+                "autogptq_extension/cuda_256/autogptq_cuda_256.cpp",
+                "autogptq_extension/cuda_256/autogptq_cuda_kernel_256.cu"
             ]
+        ),
+        cpp_extension.CppExtension(
+            "cQIGen",
+            [
+                'autogptq_extension/qigen/backend.cpp'
+            ],
+            extra_compile_args = ["-O3", "-mavx", "-mavx2", "-mfma", "-march=native", "-ffast-math", "-ftree-vectorize", "-faligned-new", "-std=c++17", "-fopenmp", "-fno-signaling-nans", "-fno-trapping-math"]
         )
     ]
@@ -126,11 +143,11 @@ if BUILD_CUDA_EXT:
         cpp_extension.CUDAExtension(
             "exllama_kernels",
             [
-                "autogptq_cuda/exllama/exllama_ext.cpp",
-                "autogptq_cuda/exllama/cuda_buffers.cu",
-                "autogptq_cuda/exllama/cuda_func/column_remap.cu",
-                "autogptq_cuda/exllama/cuda_func/q4_matmul.cu",
-                "autogptq_cuda/exllama/cuda_func/q4_matrix.cu"
+                "autogptq_extension/exllama/exllama_ext.cpp",
+                "autogptq_extension/exllama/cuda_buffers.cu",
+                "autogptq_extension/exllama/cuda_func/column_remap.cu",
+                "autogptq_extension/exllama/cuda_func/q4_matmul.cu",
+                "autogptq_extension/exllama/cuda_func/q4_matrix.cu"
             ],
             extra_link_args=extra_link_args
         )