Merge remote-tracking branch 'qwopqwop200/main' into main
This commit is contained in:
commit
6a9d80eddc
27 changed files with 2498 additions and 79 deletions
|
@ -13,6 +13,7 @@ import torch.nn as nn
|
|||
import transformers
|
||||
from accelerate.hooks import remove_hook_from_module
|
||||
from safetensors.torch import save_file as safe_save
|
||||
from safetensors.torch import load_file as safe_load
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
|
||||
from transformers.utils.hub import PushToHubMixin, cached_file, create_repo, create_commit, CommitOperationAdd
|
||||
from transformers.utils.generic import ContextManagers
|
||||
|
@ -687,7 +688,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
device: Optional[Union[str, int]] = None,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
use_triton: bool = False,
|
||||
torch_dtype: torch.dtype = torch.float16,
|
||||
use_qigen: bool = False,
|
||||
torch_dtype: Optional[torch.dtype] = None,
|
||||
inject_fused_attention: bool = True,
|
||||
inject_fused_mlp: bool = True,
|
||||
use_cuda_fp16: bool = True,
|
||||
|
@ -725,6 +727,12 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
"_raise_exceptions_for_missing_entries": False,
|
||||
"_commit_hash": commit_hash,
|
||||
}
|
||||
if use_qigen:
|
||||
logger.warning("QIgen is active. Ignores all settings related to cuda.")
|
||||
inject_fused_attention = False
|
||||
inject_fused_mlp = False
|
||||
use_triton = False
|
||||
disable_exllama = True
|
||||
|
||||
if use_triton and not TRITON_AVAILABLE:
|
||||
logger.warning("Triton is not installed, reset use_triton to False.")
|
||||
|
@ -803,6 +811,13 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
def skip(*args, **kwargs):
|
||||
pass
|
||||
|
||||
if torch_dtype is None:
|
||||
if not use_qigen:
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
torch_dtype = torch.float32
|
||||
|
||||
if not use_qigen:
|
||||
torch.nn.init.kaiming_uniform_ = skip
|
||||
torch.nn.init.uniform_ = skip
|
||||
torch.nn.init.normal_ = skip
|
||||
|
@ -880,7 +895,46 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
offload_buffers=True
|
||||
)
|
||||
model = simple_dispatch_model(model, device_map)
|
||||
else:
|
||||
if quantize_config.desc_act:
|
||||
NotImplementedError('desc_act=True is not yet supported.')
|
||||
model = AutoModelForCausalLM.from_config(
|
||||
config,
|
||||
trust_remote_code=trust_remote_code,
|
||||
torch_dtype=torch_dtype
|
||||
)
|
||||
|
||||
layers = find_layers(model)
|
||||
ignore_layers = [cls.lm_head_name] + cls.outside_layer_modules
|
||||
for name in list(layers.keys()):
|
||||
if any([name.startswith(ignore_layer) for ignore_layer in ignore_layers]):
|
||||
logger.info(f"{name} not been quantized, will be ignored when make_quant.")
|
||||
del layers[name]
|
||||
|
||||
if model_save_name.endswith('.safetensors'):
|
||||
checkpoint = safe_load(model_save_name)
|
||||
else:
|
||||
checkpoint = torch.load(model_save_name)
|
||||
make_quant(
|
||||
model,
|
||||
layers,
|
||||
quantize_config.bits,
|
||||
quantize_config.group_size,
|
||||
use_triton=use_triton,
|
||||
disable_exllama=disable_exllama,
|
||||
use_cuda_fp16=use_cuda_fp16,
|
||||
desc_act=quantize_config.desc_act,
|
||||
trainable=trainable,
|
||||
use_qigen=True
|
||||
)
|
||||
preprocess_checkpoint_qigen(
|
||||
model,
|
||||
layers,
|
||||
quantize_config.bits,
|
||||
quantize_config.group_size,
|
||||
checkpoint
|
||||
)
|
||||
model.load_state_dict(checkpoint)
|
||||
# == step4: set seqlen == #
|
||||
model_config = model.config.to_dict()
|
||||
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
|
||||
|
|
|
@ -6,11 +6,11 @@ import torch
|
|||
import torch.nn as nn
|
||||
from transformers import AutoConfig
|
||||
import transformers
|
||||
import cQIGen as qinfer
|
||||
|
||||
from ._const import SUPPORTED_MODELS, CPU, CUDA_0, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
|
||||
from ..utils.import_utils import dynamically_import_QuantLinear
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -58,11 +58,12 @@ def make_quant(
|
|||
name='',
|
||||
use_triton: bool = False,
|
||||
disable_exllama: bool = False,
|
||||
use_qigen: bool = False,
|
||||
use_cuda_fp16: bool = True,
|
||||
desc_act: bool = False,
|
||||
trainable: bool = False
|
||||
):
|
||||
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
|
||||
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, use_qigen=use_qigen)
|
||||
|
||||
if isinstance(module, QuantLinear):
|
||||
return
|
||||
|
@ -81,7 +82,7 @@ def make_quant(
|
|||
elif isinstance(tmp,transformers.pytorch_utils.Conv1D):
|
||||
in_features = tmp.weight.shape[0]
|
||||
out_features = tmp.weight.shape[1]
|
||||
if (not(desc_act) or group_size == -1) and not use_triton:
|
||||
if (not(desc_act) or group_size == -1) and not use_triton and not use_qigen:
|
||||
new_layer = QuantLinear(
|
||||
bits, group_size, in_features, out_features, True, use_cuda_fp16=use_cuda_fp16, trainable=trainable
|
||||
)
|
||||
|
@ -101,8 +102,73 @@ def make_quant(
|
|||
desc_act=desc_act,
|
||||
trainable=trainable,
|
||||
disable_exllama=disable_exllama,
|
||||
use_qigen=use_qigen
|
||||
)
|
||||
|
||||
def process_zeros_scales(zeros, scales, bits, out_features):
|
||||
if zeros.dtype != torch.float32:
|
||||
new_zeros = torch.zeros_like(scales).float().contiguous()
|
||||
if bits == 4:
|
||||
qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
|
||||
elif bits == 2:
|
||||
qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
|
||||
elif bits == 3:
|
||||
logger.info("Unpacking zeros for 3 bits")
|
||||
new_scales = scales.contiguous()
|
||||
else:
|
||||
if scales.shape[1] != out_features:
|
||||
new_scales = scales.transpose(0,1).contiguous()
|
||||
else:
|
||||
new_scales = scales.contiguous()
|
||||
if zeros.shape[1] != out_features:
|
||||
new_zeros = zeros.transpose(0,1).contiguous()
|
||||
else:
|
||||
new_zeros = zeros.contiguous()
|
||||
|
||||
return new_zeros, new_scales
|
||||
|
||||
def preprocess_checkpoint_qigen(
|
||||
module,
|
||||
names,
|
||||
bits,
|
||||
group_size,
|
||||
checkpoint,
|
||||
name='',
|
||||
):
|
||||
QuantLinear = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=bits, disable_exllama=False, use_qigen=True)
|
||||
if isinstance(module, QuantLinear):
|
||||
in_features = module.infeatures
|
||||
out_features = module.outfeatures
|
||||
|
||||
checkpoint[name + '.zeros'],checkpoint[name + '.scales'] = process_zeros_scales(checkpoint[name + '.qzeros'],checkpoint[name + '.scales'].float(), bits, out_features)
|
||||
del checkpoint[name + '.qzeros']
|
||||
del checkpoint[name + '.g_idx']
|
||||
if name + '.bias' in checkpoint:
|
||||
checkpoint[name + '.bias'] = checkpoint[name + '.bias'].float()
|
||||
else:
|
||||
checkpoint[name + '.bias'] = torch.zeros(out_features)
|
||||
checkpoint_qweight = checkpoint[name + '.qweight'].int().contiguous()
|
||||
if bits == 4:
|
||||
qweight = torch.zeros(int(in_features // 8 * out_features)).int().contiguous()
|
||||
qinfer.pack4(checkpoint_qweight, qweight, in_features // 8, out_features, module.mb, module.tb, module.cutoff)# * (module.tt//tb))
|
||||
elif bits == 3:
|
||||
qweight = torch.zeros(int(in_features // 32 * 3 * out_features)).int().contiguous()
|
||||
qinfer.pack3(checkpoint_qweight, qweight, in_features // 32 * 3, out_features, module.mb // 32 * 3, module.tb, module.cutoff)
|
||||
elif bits == 2:
|
||||
qweight = torch.zeros(int(in_features // 16 * out_features)).int().contiguous()
|
||||
qinfer.pack2(checkpoint_qweight, qweight, in_features // 16, out_features, module.mb, module.tb, module.cutoff)# * (module.tt//tb))
|
||||
checkpoint[name + '.qweight'] = qweight
|
||||
return
|
||||
|
||||
for name1, child in module.named_children():
|
||||
preprocess_checkpoint_qigen(
|
||||
child,
|
||||
names,
|
||||
bits,
|
||||
group_size,
|
||||
checkpoint,
|
||||
name + '.' + name1 if name != '' else name1,
|
||||
)
|
||||
|
||||
def pack_model(
|
||||
model,
|
||||
|
@ -287,6 +353,7 @@ __all__ = [
|
|||
"get_module_by_name_prefix",
|
||||
"get_module_by_name_suffix",
|
||||
"make_quant",
|
||||
"preprocess_checkpoint_qigen",
|
||||
"pack_model",
|
||||
"autogptq_post_init",
|
||||
"check_and_get_model_type",
|
||||
|
|
258
auto_gptq/nn_modules/qlinear/qlinear_qigen.py
Normal file
258
auto_gptq/nn_modules/qlinear/qlinear_qigen.py
Normal file
|
@ -0,0 +1,258 @@
|
|||
from copy import deepcopy
|
||||
import torch
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
import gc
|
||||
|
||||
import cQIGen as qinfer
|
||||
import math
|
||||
import numpy as np
|
||||
from gekko import GEKKO
|
||||
from logging import getLogger
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def mem_model(N, M, T, mu, tu, bits, l1, p, gs):
|
||||
m = GEKKO() # create GEKKO model
|
||||
#cinfergen if bits==3:
|
||||
# tu = tu*3
|
||||
B = m.Const(value=bits)
|
||||
TP = m.Const(value=T//p)
|
||||
k = m.Var(1,integer=True,lb=1)
|
||||
z = m.Var(1,integer=True,lb=1)
|
||||
w = m.Var(1,integer=True,lb=1)
|
||||
y = m.Var(1,integer=True,lb=1)
|
||||
v = m.Var(1,integer=True,lb=1)
|
||||
mb = m.Var(mu,integer=True,lb=1)
|
||||
if gs != -1:
|
||||
gg = m.Var(1,integer=True,lb=1)
|
||||
tb = m.Var(tu,integer=True,lb=1,ub=int(T/p))
|
||||
L = m.Var(integer=True,lb=0,ub=l1)
|
||||
m.Equation(L == 32 * mb * N + B * mb * tb + 32 * tb * N)
|
||||
m.Equation(mb * k == M)
|
||||
if gs != -1:
|
||||
m.Equation(gs * gg == mb)
|
||||
# m.Equation(tb * z == T)
|
||||
m.Equation(tb * z == TP)
|
||||
m.Equation(mu * w == mb)
|
||||
m.Equation(tu * y == tb)
|
||||
# m.Equation(tb * v == tt)
|
||||
m.Maximize(L)
|
||||
m.options.SOLVER = 1
|
||||
m.solver_options = ['minlp_maximum_iterations 1000', \
|
||||
# minlp iterations with integer solution
|
||||
'minlp_max_iter_with_int_sol 10', \
|
||||
# treat minlp as nlp
|
||||
'minlp_as_nlp 0', \
|
||||
# nlp sub-problem max iterations
|
||||
'nlp_maximum_iterations 100', \
|
||||
# 1 = depth first, 2 = breadth first
|
||||
'minlp_branch_method 2', \
|
||||
# maximum deviation from whole number
|
||||
'minlp_integer_tol 0.00', \
|
||||
# covergence tolerance
|
||||
'minlp_gap_tol 0.01']
|
||||
try:
|
||||
m.solve(disp=False)
|
||||
except:
|
||||
try:
|
||||
m.solver_options = ['minlp_maximum_iterations 1000', \
|
||||
# minlp iterations with integer solution
|
||||
'minlp_max_iter_with_int_sol 10', \
|
||||
# treat minlp as nlp
|
||||
'minlp_as_nlp 0', \
|
||||
# nlp sub-problem max iterations
|
||||
'nlp_maximum_iterations 100', \
|
||||
# 1 = depth first, 2 = breadth first
|
||||
'minlp_branch_method 1', \
|
||||
# maximum deviation from whole number
|
||||
'minlp_integer_tol 0.00', \
|
||||
# covergence tolerance
|
||||
'minlp_gap_tol 0.01']
|
||||
m.solve(disp=False)
|
||||
except:
|
||||
# mytb = T//p
|
||||
mytb = tu
|
||||
if gs != -1:
|
||||
mymb = gs
|
||||
while 32 * (mymb + gs) * N + bits * (mymb + gs) * mytb + 32 * mytb * N < l1:
|
||||
mymb += gs
|
||||
while M % mymb != 0:
|
||||
mymb -= gs
|
||||
return (int(mymb), int(mytb))
|
||||
else:
|
||||
mymb = mu
|
||||
while 32 * (mymb + mu) * N + bits * (mymb + mu) * mytb + 32 * mytb * N < l1:
|
||||
mymb += mu
|
||||
while M % mymb != 0:
|
||||
mymb -= mu
|
||||
return (int(mymb), int(mytb))
|
||||
|
||||
return (int(mb.value[0]), int(tb.value[0]))
|
||||
|
||||
params = {}
|
||||
|
||||
def compute_reductions(x, gs=-1, cpp=True):
|
||||
if cpp:
|
||||
if len(x.shape) != 1:
|
||||
rows, cols = x.shape
|
||||
else:
|
||||
rows = 1
|
||||
cols = x.shape[0]
|
||||
if gs == -1:
|
||||
out = torch.zeros(rows).float().contiguous()
|
||||
mygs = cols
|
||||
else:
|
||||
out = torch.zeros(rows, cols // gs).float().contiguous()
|
||||
mygs = gs
|
||||
|
||||
qinfer.compute_reduction_cpp(x, out, rows, cols, mygs)
|
||||
return out
|
||||
if gs == -1:
|
||||
if len(x.shape) != 1:
|
||||
return torch.sum(x,1)
|
||||
else:
|
||||
return torch.sum(x)
|
||||
else:
|
||||
if len(x.shape) != 1:
|
||||
rows, cols = x.shape
|
||||
out = torch.zeros(rows, cols // gs).float().contiguous()
|
||||
for i in range(cols // gs):
|
||||
out[:,i] = torch.sum(x[:,i*gs:(i+1)*gs],1)
|
||||
return out
|
||||
else:
|
||||
cols = x.shape[0]
|
||||
out = torch.zeros(cols // gs).float().contiguous()
|
||||
for i in range(cols // gs):
|
||||
out[i] = torch.sum(x[i*gs:(i+1)*gs])
|
||||
return out
|
||||
|
||||
def process_zeros_scales(zeros, scales, bits, M):
|
||||
if zeros.dtype != torch.float32:
|
||||
new_zeros = torch.zeros_like(scales).float().contiguous()
|
||||
if bits == 4:
|
||||
qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
|
||||
elif bits == 2:
|
||||
qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
|
||||
elif bits == 3:
|
||||
logger.info("Unpacking zeros for 3 bits")
|
||||
new_scales = scales.contiguous()
|
||||
else:
|
||||
if scales.shape[1] != M:
|
||||
new_scales = scales.transpose(0,1).contiguous()
|
||||
else:
|
||||
new_scales = scales.contiguous()
|
||||
if zeros.shape[1] != M:
|
||||
new_zeros = zeros.transpose(0,1).contiguous()
|
||||
else:
|
||||
new_zeros = zeros.contiguous()
|
||||
|
||||
return new_zeros, new_scales
|
||||
|
||||
class QuantLinear(nn.Module):
|
||||
QUANT_TYPE = "qigen"
|
||||
|
||||
def __init__(self, bits, group_size, infeatures, outfeatures, bias=None, trainable=False, hint=1, p=8, l1=2**18):
|
||||
super().__init__()
|
||||
if bits not in [2, 4]:
|
||||
raise NotImplementedError("Only 2,4 bits are supported.")
|
||||
if trainable:
|
||||
raise NotImplementedError("Qigen kernel does not support training.")
|
||||
self.bits = bits
|
||||
pack = 32 // bits
|
||||
|
||||
self.infeatures = infeatures
|
||||
self.outfeatures = outfeatures
|
||||
|
||||
n = hint
|
||||
m = self.infeatures
|
||||
t = self.outfeatures
|
||||
|
||||
#registers for now are fixed
|
||||
if bits == 3:
|
||||
packed = 32
|
||||
unroll = 3
|
||||
nu = 1 #args.n
|
||||
mu = 32
|
||||
tu = 32
|
||||
else:
|
||||
packed = 32 // bits
|
||||
unroll = 2
|
||||
nu = 1 #args.n
|
||||
mu = 16
|
||||
tu = 32
|
||||
|
||||
nb = n # it's always small for transformers
|
||||
|
||||
global params
|
||||
if (m,t) in params:
|
||||
mb = params[(m,t)][0]
|
||||
tb = params[(m,t)][1]
|
||||
else:
|
||||
mb, tb = mem_model(n, m, t, mu, tu, bits, l1, p, group_size)
|
||||
params[(m,t)] = (mb,tb)
|
||||
|
||||
split = np.ones(p)
|
||||
split = split * tb
|
||||
while np.sum(split) < t:
|
||||
split = split + tb
|
||||
|
||||
idx = p - 1
|
||||
while np.sum(split) > t:
|
||||
split[idx] = split[idx] - tb
|
||||
idx = idx - 1
|
||||
|
||||
assert(np.sum(split) == t)
|
||||
|
||||
split = split.astype(int)
|
||||
self.tt = int(split[0])
|
||||
|
||||
if split[0] == split[-1]:
|
||||
self.cutoff = int(p+1)
|
||||
else:
|
||||
self.cutoff = int(idx + 1)
|
||||
|
||||
self.mb = mb #// packed
|
||||
self.tb = tb
|
||||
|
||||
self.group_size = group_size
|
||||
|
||||
self.register_buffer('bias', torch.zeros(self.outfeatures))
|
||||
self.register_buffer('zeros', torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float32))
|
||||
self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float32))
|
||||
if bits == 4:
|
||||
self.register_buffer('qweight', torch.zeros(int(self.infeatures // packed * self.outfeatures)).int().contiguous())
|
||||
elif bits == 3:
|
||||
self.register_buffer('qweight', torch.zeros(int(self.infeatures // packed * 3 * self.outfeatures)).int().contiguous())
|
||||
elif bits == 2:
|
||||
self.register_buffer('qweight', torch.zeros(int(self.infeatures // packed * self.outfeatures)).int().contiguous())
|
||||
|
||||
def forward(self, x):
|
||||
out_shape = x.shape[:-1] + (self.outfeatures,)
|
||||
x = x.reshape((-1, x.shape[-1])).to(torch.float32)
|
||||
B = x.shape[0]
|
||||
new_x = x.T.contiguous()
|
||||
out = torch.zeros((B, self.outfeatures), dtype=torch.float32)
|
||||
sums = compute_reductions(x,gs=self.group_size,cpp=True).contiguous()
|
||||
if self.group_size == -1:
|
||||
if self.bits == 4:
|
||||
qinfer.forward4(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
|
||||
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.cutoff)
|
||||
elif self.bits == 2:
|
||||
qinfer.forward2(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
|
||||
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.cutoff)
|
||||
elif self.bits == 3:
|
||||
qinfer.forward3(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
|
||||
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.cutoff)
|
||||
else:
|
||||
if self.bits == 4:
|
||||
qinfer.forward_gs4(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
|
||||
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.group_size, self.cutoff)
|
||||
elif self.bits == 2:
|
||||
qinfer.forward_gs2(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
|
||||
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.group_size, self.cutoff)
|
||||
elif self.bits == 3:
|
||||
qinfer.forward_gs3(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
|
||||
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.group_size, self.cutoff)
|
||||
return out.reshape(out_shape)
|
|
@ -28,7 +28,10 @@ except:
|
|||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool = False):
|
||||
def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool = False, use_qigen: bool = False):
|
||||
if use_qigen:
|
||||
from ..nn_modules.qlinear.qlinear_qigen import QuantLinear
|
||||
else:
|
||||
if use_triton:
|
||||
if torch.version.hip:
|
||||
logger.warning("Running GPTQ triton version on AMD GPUs is untested and may result in errors or wrong predictions. Please use use_triton=False.")
|
||||
|
|
1484
autogptq_extension/qigen/generate.py
Normal file
1484
autogptq_extension/qigen/generate.py
Normal file
File diff suppressed because it is too large
Load diff
149
autogptq_extension/qigen/intrin.py
Normal file
149
autogptq_extension/qigen/intrin.py
Normal file
|
@ -0,0 +1,149 @@
|
|||
|
||||
def load_int(to, address, const=True):
|
||||
if const:
|
||||
return f"const __m256i {to} = _mm256_loadu_si256({address});"
|
||||
else:
|
||||
return f"__m256i {to} = _mm256_loadu_si256({address});"
|
||||
|
||||
def load_fp(to, address, const=True):
|
||||
if const:
|
||||
return f"const __m256 {to} = _mm256_loadu_ps({address});"
|
||||
else:
|
||||
return f"__m256 {to} = _mm256_loadu_ps({address});"
|
||||
|
||||
# to = a * b + c
|
||||
def vfma(to, a, b, c):
|
||||
return f"__m256 {to} = _mm256_fmadd_ps({a}, {b}, {c});"
|
||||
|
||||
def vsrli(to, a, b):
|
||||
return f"const __m256i {to} = _mm256_srli_epi32({a}, {b});"
|
||||
|
||||
def vand(to, a, b):
|
||||
return f"const __m256i {to} = _mm256_and_si256({a}, {b});"
|
||||
|
||||
def vbroadcast_fp(to, a):
|
||||
return f"const __m256 {to} = _mm256_set1_ps({a});"
|
||||
|
||||
def vbroadcast_int32(to, a):
|
||||
return f"__m256i {to} = _mm256_set1_epi32({a});"
|
||||
|
||||
def vsetzero(to):
|
||||
return f"__m256 {to} = _mm256_setzero_ps();"
|
||||
|
||||
def vcvtepi32_ps(to, a):
|
||||
return f"const __m256 {to} = _mm256_cvtepi32_ps({a});"
|
||||
|
||||
def _256extractf128_ps(to, a, imm):
|
||||
return f"const __m128 {to} = _mm256_extractf128_ps({a}, {imm});"
|
||||
|
||||
def _256castps256_ps128(to, a):
|
||||
return f"const __m128 {to} = _mm256_castps256_ps128({a});"
|
||||
|
||||
def _add_ps(to, a, b):
|
||||
return f"const __m128 {to} = _mm_add_ps({a}, {b});"
|
||||
|
||||
def _movehl_ps(to, a, b):
|
||||
return f"const __m128 {to} = _mm_movehl_ps({a}, {b});"
|
||||
|
||||
def _shuffle_ps(to, a, b, imm):
|
||||
return f"const __m128 {to} = _mm_shuffle_ps({a}, {b}, {imm});"
|
||||
|
||||
def _cvtss_f32(to, a):
|
||||
return f"const float {to} = _mm_cvtss_f32({a});"
|
||||
|
||||
def _reduce8_acc(a, b, c, d, e, f, g, h):
|
||||
res = ""
|
||||
res += _256extractf128_ps("hi_quad0", a, 1)
|
||||
res += _256extractf128_ps("hi_quad1", b, 1)
|
||||
res += _256extractf128_ps("hi_quad2", c, 1)
|
||||
res += _256extractf128_ps("hi_quad3", d, 1)
|
||||
res += _256extractf128_ps("hi_quad4", e, 1)
|
||||
res += _256extractf128_ps("hi_quad5", f, 1)
|
||||
res += _256extractf128_ps("hi_quad6", g, 1)
|
||||
res += _256extractf128_ps("hi_quad7", h, 1)
|
||||
|
||||
res += _256castps256_ps128("lo_quad0", a)
|
||||
res += _256castps256_ps128("lo_quad1", b)
|
||||
res += _256castps256_ps128("lo_quad2", c)
|
||||
res += _256castps256_ps128("lo_quad3", d)
|
||||
res += _256castps256_ps128("lo_quad4", e)
|
||||
res += _256castps256_ps128("lo_quad5", f)
|
||||
res += _256castps256_ps128("lo_quad6", g)
|
||||
res += _256castps256_ps128("lo_quad7", h)
|
||||
|
||||
res += _add_ps("sum_quad0", "lo_quad0", "hi_quad0")
|
||||
res += _add_ps("sum_quad1", "lo_quad1", "hi_quad1")
|
||||
res += _add_ps("sum_quad2", "lo_quad2", "hi_quad2")
|
||||
res += _add_ps("sum_quad3", "lo_quad3", "hi_quad3")
|
||||
res += _add_ps("sum_quad4", "lo_quad4", "hi_quad4")
|
||||
res += _add_ps("sum_quad5", "lo_quad5", "hi_quad5")
|
||||
res += _add_ps("sum_quad6", "lo_quad6", "hi_quad6")
|
||||
res += _add_ps("sum_quad7", "lo_quad7", "hi_quad7")
|
||||
|
||||
res += _movehl_ps("hi_dual0", "sum_quad0", "sum_quad0")
|
||||
res += _movehl_ps("hi_dual1", "sum_quad1", "sum_quad1")
|
||||
res += _movehl_ps("hi_dual2", "sum_quad2", "sum_quad2")
|
||||
res += _movehl_ps("hi_dual3", "sum_quad3", "sum_quad3")
|
||||
res += _movehl_ps("hi_dual4", "sum_quad4", "sum_quad4")
|
||||
res += _movehl_ps("hi_dual5", "sum_quad5", "sum_quad5")
|
||||
res += _movehl_ps("hi_dual6", "sum_quad6", "sum_quad6")
|
||||
res += _movehl_ps("hi_dual7", "sum_quad7", "sum_quad7")
|
||||
|
||||
res += _add_ps("sum_dual0", "sum_quad0", "hi_dual0")
|
||||
res += _add_ps("sum_dual1", "sum_quad1", "hi_dual1")
|
||||
res += _add_ps("sum_dual2", "sum_quad2", "hi_dual2")
|
||||
res += _add_ps("sum_dual3", "sum_quad3", "hi_dual3")
|
||||
res += _add_ps("sum_dual4", "sum_quad4", "hi_dual4")
|
||||
res += _add_ps("sum_dual5", "sum_quad5", "hi_dual5")
|
||||
res += _add_ps("sum_dual6", "sum_quad6", "hi_dual6")
|
||||
res += _add_ps("sum_dual7", "sum_quad7", "hi_dual7")
|
||||
|
||||
res += _shuffle_ps("hi0", "sum_dual0", "sum_dual0", 0x1)
|
||||
res += _shuffle_ps("hi1", "sum_dual1", "sum_dual1", 0x1)
|
||||
res += _shuffle_ps("hi2", "sum_dual2", "sum_dual2", 0x1)
|
||||
res += _shuffle_ps("hi3", "sum_dual3", "sum_dual3", 0x1)
|
||||
res += _shuffle_ps("hi4", "sum_dual4", "sum_dual4", 0x1)
|
||||
res += _shuffle_ps("hi5", "sum_dual5", "sum_dual5", 0x1)
|
||||
res += _shuffle_ps("hi6", "sum_dual6", "sum_dual6", 0x1)
|
||||
res += _shuffle_ps("hi7", "sum_dual7", "sum_dual7", 0x1)
|
||||
|
||||
res += _add_ps("sum0", "sum_dual0", "hi0")
|
||||
res += _add_ps("sum1", "sum_dual1", "hi1")
|
||||
res += _add_ps("sum2", "sum_dual2", "hi2")
|
||||
res += _add_ps("sum3", "sum_dual3", "hi3")
|
||||
res += _add_ps("sum4", "sum_dual4", "hi4")
|
||||
res += _add_ps("sum5", "sum_dual5", "hi5")
|
||||
res += _add_ps("sum6", "sum_dual6", "hi6")
|
||||
res += _add_ps("sum7", "sum_dual7", "hi7")
|
||||
|
||||
res += _cvtss_f32(f"f{a}", "sum0")
|
||||
res += _cvtss_f32(f"f{b}", "sum1")
|
||||
res += _cvtss_f32(f"f{c}", "sum2")
|
||||
res += _cvtss_f32(f"f{d}", "sum3")
|
||||
res += _cvtss_f32(f"f{e}", "sum4")
|
||||
res += _cvtss_f32(f"f{f}", "sum5")
|
||||
res += _cvtss_f32(f"f{g}", "sum6")
|
||||
res += _cvtss_f32(f"f{h}", "sum7")
|
||||
|
||||
return res
|
||||
|
||||
acc_idx = 0
|
||||
def _reduce_add(a):
|
||||
global acc_idx
|
||||
res = ""
|
||||
res += _256extractf128_ps(f"hi_quad{acc_idx}", a, 1)
|
||||
res += _256castps256_ps128(f"lo_quad{acc_idx}", a)
|
||||
res += _add_ps(f"sum_quad{acc_idx}", f"lo_quad{acc_idx}", f"hi_quad{acc_idx}")
|
||||
res += _movehl_ps(f"hi_dual{acc_idx}", f"sum_quad{acc_idx}", f"sum_quad{acc_idx}")
|
||||
res += _add_ps(f"sum_dual{acc_idx}", f"sum_quad{acc_idx}", f"hi_dual{acc_idx}")
|
||||
res += _shuffle_ps(f"hi{acc_idx}", f"sum_dual{acc_idx}", f"sum_dual{acc_idx}", 0x1)
|
||||
res += _add_ps(f"sum{acc_idx}", f"sum_dual{acc_idx}", f"hi{acc_idx}")
|
||||
res += _cvtss_f32(f"f{a}", f"sum{acc_idx}")
|
||||
acc_idx += 1
|
||||
return res
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
302
autogptq_extension/qigen/mmm.cpp
Normal file
302
autogptq_extension/qigen/mmm.cpp
Normal file
|
@ -0,0 +1,302 @@
|
|||
#include <iostream>
|
||||
#include "forward.h"
|
||||
#include <cstring>
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
#include <chrono>
|
||||
#include <fstream>
|
||||
|
||||
#define mymin(a,b) ((a)<(b)?(a):(b))
|
||||
#define mymax(a,b) ((a)>(b)?(a):(b))
|
||||
|
||||
void print_matrix(std::string name, float* A, int N, int M){
|
||||
std::cout<<name<<std::endl;
|
||||
for(int i = 0; i < N; i++){
|
||||
for(int j = 0; j < M; j++){
|
||||
std::cout << A[i*M+j] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout<<std::endl;
|
||||
}
|
||||
|
||||
void oracle_mmadd(float* A, float* B, float* bias, float* C, int n, int m, int t){
|
||||
// triple loop matmul and add bias
|
||||
for (int i = 0; i < n; i++){
|
||||
for (int j = 0; j < t; j++){
|
||||
float sum = 0;
|
||||
for (int k = 0; k < m; k++){
|
||||
sum += A[i*m+k] * B[k*t+j];
|
||||
}
|
||||
C[i*t+j] += sum + bias[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void compute_reduction(float *in, float *out, int n, int m, int gs){
|
||||
int ng;
|
||||
if(gs == -1){
|
||||
ng = 1;
|
||||
gs = m;
|
||||
}else{
|
||||
ng = m/gs;
|
||||
}
|
||||
for(int i = 0; i < n; i++){
|
||||
for(int j0 = 0; j0 < m; j0+=gs){
|
||||
int j = j0/gs;
|
||||
out[i*ng+j] = 0;
|
||||
for(int j1 = j0; j1 < j0+gs; j1++){
|
||||
out[i*ng+j] += in[i*m+j1];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_sim(float* A, float* BQ, float* scales, float* zeros, int n, int m, int bits, int gs){
|
||||
//find scales and zeros arrays
|
||||
if(gs == -1){
|
||||
gs = n;
|
||||
}
|
||||
float range = (1<<bits) - 1;
|
||||
int packed = 32 / bits;
|
||||
|
||||
for(int i0 = 0; i0 < n; i0+=gs){
|
||||
int row = i0/gs;
|
||||
for(int j = 0; j < m; j++){
|
||||
float min = A[i0*m + j];
|
||||
float max = A[i0*m + j];
|
||||
for(int i1 = i0; i1 < i0+gs; i1++){
|
||||
min = mymin(min, A[i1*m+j]);
|
||||
max = mymax(max, A[i1*m+j]);
|
||||
}
|
||||
scales[row*m + j] = (max-min)/range;
|
||||
zeros[row*m + j ] = min;
|
||||
}
|
||||
for(int j = 0; j < m; j++){
|
||||
for (int i1 = i0; i1 < i0+gs; i1++){
|
||||
uint32_t acc = 0;
|
||||
int temp = (A[i1*m+j] - zeros[row*m+j])/scales[row*m+j];
|
||||
float val = ((float) temp + zeros[row*m+j]) * scales[row*m+j];
|
||||
BQ[i1*m+j] = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
void quantize(float* A, int* BQ, float* scales, float* zeros, int n, int m, int bits, int gs){
|
||||
//find scales and zeros arrays
|
||||
if(gs == -1){
|
||||
gs = n;
|
||||
}
|
||||
float range = (1<<bits) - 1;
|
||||
int packed = 32 / bits;
|
||||
|
||||
for(int i0 = 0; i0 < n; i0+=gs){
|
||||
int row = i0/gs;
|
||||
for(int j = 0; j < m; j++){
|
||||
float min = A[i0*m + j];
|
||||
float max = A[i0*m + j];
|
||||
for(int i1 = i0; i1 < i0+gs; i1++){
|
||||
min = mymin(min, A[i1*m+j]);
|
||||
max = mymax(max, A[i1*m+j]);
|
||||
}
|
||||
scales[row*m + j] = (max-min)/range;
|
||||
zeros[row*m + j ] = min;
|
||||
}
|
||||
for(int j = 0; j < m; j++){
|
||||
if(bits == 3){
|
||||
for (int i1 = i0; i1 < i0+gs; i1+=32){
|
||||
uint32_t acc = 0;
|
||||
int temp0 = ((int)((A[(i1+0)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 0;
|
||||
int temp1 = ((int)((A[(i1+1)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 3;
|
||||
int temp2 = ((int)((A[(i1+2)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 6;
|
||||
int temp3 = ((int)((A[(i1+3)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 9;
|
||||
int temp4 = ((int)((A[(i1+4)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 12;
|
||||
int temp5 = ((int)((A[(i1+5)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 15;
|
||||
int temp6 = ((int)((A[(i1+6)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 18;
|
||||
int temp7 = ((int)((A[(i1+7)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 21;
|
||||
int temp8 = ((int)((A[(i1+8)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 24;
|
||||
int temp9 = ((int)((A[(i1+9)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 27;
|
||||
int temp10_0 = ((int)((A[(i1+10)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 30;
|
||||
int temp10_1 = ((int)((A[(i1+10)*m+j] - zeros[row*m+j])/scales[row*m+j])) >> 2;
|
||||
int temp11 = ((int)((A[(i1+11)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 1;
|
||||
int temp12 = ((int)((A[(i1+12)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 4;
|
||||
int temp13 = ((int)((A[(i1+13)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 7;
|
||||
int temp14 = ((int)((A[(i1+14)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 10;
|
||||
int temp15 = ((int)((A[(i1+15)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 13;
|
||||
int temp16 = ((int)((A[(i1+16)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 16;
|
||||
int temp17 = ((int)((A[(i1+17)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 19;
|
||||
int temp18 = ((int)((A[(i1+18)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 22;
|
||||
int temp19 = ((int)((A[(i1+19)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 25;
|
||||
int temp20 = ((int)((A[(i1+20)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 28;
|
||||
int temp21_0 = ((int)((A[(i1+21)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 31;
|
||||
int temp21_1 = ((int)((A[(i1+21)*m+j] - zeros[row*m+j])/scales[row*m+j])) >> 1;
|
||||
int temp22 = ((int)((A[(i1+22)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 2;
|
||||
int temp23 = ((int)((A[(i1+23)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 5;
|
||||
int temp24 = ((int)((A[(i1+24)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 8;
|
||||
int temp25 = ((int)((A[(i1+25)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 11;
|
||||
int temp26 = ((int)((A[(i1+26)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 14;
|
||||
int temp27 = ((int)((A[(i1+27)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 17;
|
||||
int temp28 = ((int)((A[(i1+28)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 20;
|
||||
int temp29 = ((int)((A[(i1+29)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 23;
|
||||
int temp30 = ((int)((A[(i1+30)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 26;
|
||||
int temp31 = ((int)((A[(i1+31)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 29;
|
||||
|
||||
int acc0 = 0, acc1 = 0, acc2 = 0;
|
||||
|
||||
acc0 |= temp0;
|
||||
acc0 |= temp1;
|
||||
acc0 |= temp2;
|
||||
acc0 |= temp3;
|
||||
acc0 |= temp4;
|
||||
acc0 |= temp5;
|
||||
acc0 |= temp6;
|
||||
acc0 |= temp7;
|
||||
acc0 |= temp8;
|
||||
acc0 |= temp9;
|
||||
acc0 |= temp10_0;
|
||||
|
||||
acc1 |= temp10_1;
|
||||
acc1 |= temp11;
|
||||
acc1 |= temp12;
|
||||
acc1 |= temp13;
|
||||
acc1 |= temp14;
|
||||
acc1 |= temp15;
|
||||
acc1 |= temp16;
|
||||
acc1 |= temp17;
|
||||
acc1 |= temp18;
|
||||
acc1 |= temp19;
|
||||
acc1 |= temp20;
|
||||
acc1 |= temp21_0;
|
||||
|
||||
acc2 |= temp21_1;
|
||||
acc2 |= temp22;
|
||||
acc2 |= temp23;
|
||||
acc2 |= temp24;
|
||||
acc2 |= temp25;
|
||||
acc2 |= temp26;
|
||||
acc2 |= temp27;
|
||||
acc2 |= temp28;
|
||||
acc2 |= temp29;
|
||||
acc2 |= temp30;
|
||||
acc2 |= temp31;
|
||||
|
||||
BQ[(3*i1/32)*m+j] = acc0;
|
||||
BQ[(3*i1/32+1)*m+j] = acc1;
|
||||
BQ[(3*i1/32+2)*m+j] = acc2;
|
||||
}
|
||||
|
||||
}else{
|
||||
for (int i1 = i0; i1 < i0+gs; i1+=packed){
|
||||
uint32_t acc = 0;
|
||||
for (int i2 = i1; i2 < i1+packed; i2++){
|
||||
int temp = (A[i2*m+j] - zeros[row*m+j])/scales[row*m+j];
|
||||
acc = acc | (temp << (bits*(i2-i1)));
|
||||
}
|
||||
BQ[(i1/packed)*m+j] = acc;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[]){
|
||||
// read n m t from args
|
||||
if(argc == 0){std::cout << "Parameters not given\n"; return 0;}
|
||||
int n = atoi(argv[1]);
|
||||
int m = atoi(argv[2]);
|
||||
int t = atoi(argv[3]);
|
||||
int bits = atoi(argv[4]);
|
||||
int gs = atoi(argv[5]);
|
||||
int ng;
|
||||
if(gs == -1){
|
||||
ng = 1;
|
||||
}else{
|
||||
ng = m/gs;
|
||||
}
|
||||
float* A = new float[n*m];
|
||||
float* AB = new float[n*m];
|
||||
float* B = new float[m*t];
|
||||
float* BQS = new float[m*t];
|
||||
float* scales = new float[t*ng];
|
||||
float* zeros = new float[t*ng];
|
||||
int* BQ = new int[m*t/8];
|
||||
int* BQB = new int[m*t/8];
|
||||
float* sums = new float[n*ng];
|
||||
float* bias = new float[t];
|
||||
float* C = new float[n*t];
|
||||
float* CB = new float[n*t];
|
||||
float* C2 = new float[n*t];
|
||||
srand(1);
|
||||
for (int i = 0; i < n*m; i++){
|
||||
A[i] = (float)rand() / RAND_MAX;
|
||||
}
|
||||
for (int i = 0; i < t*m; i++){
|
||||
B[i] = (float)rand() / RAND_MAX;
|
||||
}
|
||||
for (int i = 0; i < t; i++){
|
||||
bias[i] = (float)rand() / RAND_MAX;
|
||||
}
|
||||
for (int i = 0; i < n*t; i++){
|
||||
C[i] = 0.0;
|
||||
C2[i] = 0.0;
|
||||
}
|
||||
quantize_sim(B,BQS,scales,zeros,m,t,bits,gs);
|
||||
quantize(B,BQ,scales,zeros,m,t,bits,gs);
|
||||
|
||||
quantize_sim(B,BQS,scales,zeros,m,t,bits,gs);
|
||||
quantize(B,BQ,scales,zeros,m,t,bits,gs);
|
||||
oracle_mmadd(A, BQS, bias, C, n, m, t);
|
||||
pack_input(A,AB);
|
||||
pack_qw(BQ,BQB);
|
||||
pack_output(C,CB);
|
||||
|
||||
compute_reduction(A,sums,n,m,gs);
|
||||
qforward(AB,BQB,scales,zeros,bias,sums,C2,n,m,t);
|
||||
|
||||
float norm = 0.0;
|
||||
for (int i = 0; i < n*t; i++){
|
||||
norm += (C[i] - C2[i]) * (C[i] - C2[i]);
|
||||
}
|
||||
if(norm / (n*t) < 0.0001){
|
||||
int iter = 30;
|
||||
for(int _ = 0; _ < iter; _++){
|
||||
qforward(AB,BQB,scales,zeros,bias,sums,C2,n,m,t);
|
||||
}
|
||||
|
||||
int num_runs = 15;
|
||||
std::vector<long int> runs(num_runs);
|
||||
for(int r = 0; r < num_runs; r++){
|
||||
auto start = std::chrono::high_resolution_clock::now();
|
||||
for(int _ = 0; _ < iter; _++){
|
||||
qforward(AB,BQB,scales,zeros,bias,sums,C2,n,m,t);
|
||||
}
|
||||
auto end = std::chrono::high_resolution_clock::now();
|
||||
runs[r] = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count();
|
||||
|
||||
}
|
||||
|
||||
std::sort(runs.begin(), runs.end());
|
||||
|
||||
float cycles_final = runs[num_runs/2 + 1] / iter;
|
||||
|
||||
std::ofstream outfile;
|
||||
outfile.open("./autogptq_extension/qigen/tmp.csv", std::ios_base::app);
|
||||
|
||||
print_parameters();
|
||||
outfile << cycles_final << std::endl;
|
||||
}else{
|
||||
float cycles_final = int(10e12);
|
||||
|
||||
std::ofstream outfile;
|
||||
outfile.open("./autogptq_extension/qigen/tmp.csv", std::ios_base::app);
|
||||
|
||||
print_parameters();
|
||||
outfile << cycles_final << std::endl;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
85
autogptq_extension/qigen/template.py
Normal file
85
autogptq_extension/qigen/template.py
Normal file
|
@ -0,0 +1,85 @@
|
|||
|
||||
def includes():
|
||||
out = " \
|
||||
#include <torch/all.h>\n \
|
||||
#include <torch/python.h>\n \
|
||||
#include <omp.h>\n \
|
||||
#include <cmath>\n \
|
||||
#include <immintrin.h>\n \
|
||||
\n \
|
||||
#define mymin(a,b) ((a)<(b)?(a):(b))\n \
|
||||
#define mymax(a,b) ((a)>(b)?(a):(b))\n \
|
||||
"
|
||||
return out
|
||||
|
||||
|
||||
def module(bits_list=[4, 2]):
|
||||
out = 'PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n'
|
||||
for bits in bits_list:
|
||||
out += ' m.def("forward{}", &forward{}_cpu);\n'.format(bits, bits)
|
||||
|
||||
for bits in bits_list:
|
||||
out += ' m.def("unpack_zeros{}", &unpack_zeros{});\n'.format(bits, bits)
|
||||
|
||||
for bits in bits_list:
|
||||
out += ' m.def("forward_gs{}", &forward{}_gs_cpu);\n'.format(bits, bits)
|
||||
|
||||
for bits in bits_list:
|
||||
out += ' m.def("pack{}", &pack{}_w_cpu);\n'.format(bits, bits)
|
||||
|
||||
out += 'm.def("compute_reduction_cpp", &compute_reduction);\n'
|
||||
out += 'm.def("unquantize_sim", &unquantize_sim);\n'
|
||||
|
||||
# if oracle:
|
||||
# out += ' m.def("forward4_oracle", &forward4_oracle_cpu);\n'
|
||||
|
||||
|
||||
out += 'm.def("quant_scalar_scaled", &quant_scalar_cpu);\n'
|
||||
|
||||
out += '}\n'
|
||||
return out
|
||||
|
||||
def quant_scalar():
|
||||
out = " \
|
||||
void quantize_scalar(float* A, int* BQ, float* scales, float* zeros, int n, int m, int bits){ \n \
|
||||
//find scales and zeros arrays \n \
|
||||
//quantize \n \
|
||||
int pack = 32/bits;\n \
|
||||
for (int j = 0; j < m; j++){\n \
|
||||
for (int i = 0; i < n; i+=pack){\n \
|
||||
uint32_t acc = 0;\n \
|
||||
for (int ii = i; ii < i+pack; ii++){\n \
|
||||
float ftemp = std::round((A[ii*m+j] + zeros[j])/scales[j]);\n \
|
||||
int temp = (int)ftemp;\n \
|
||||
acc = acc | (temp << (bits*(ii-i)));\n \
|
||||
}\n \
|
||||
BQ[(i/pack)*m+j] = acc;\n \
|
||||
//BQ[0] = acc;\n \
|
||||
}\n \
|
||||
}\n \
|
||||
}\n \
|
||||
\n \
|
||||
void quant_scalar_cpu(\n \
|
||||
torch::Tensor in, torch::Tensor out, \n \
|
||||
torch::Tensor scales, torch::Tensor zeros, int bits\n \
|
||||
) {\n \
|
||||
\n \
|
||||
int N = in.size(0);\n \
|
||||
int M = in.size(1);\n \
|
||||
\n \
|
||||
float* input = in.data_ptr<float>(); \n \
|
||||
float* s = scales.data_ptr<float>();\n \
|
||||
float* z = zeros.data_ptr<float>();\n \
|
||||
int* O = out.data_ptr<int>();\n \
|
||||
\n \
|
||||
quantize_scalar(input, O, s, z, N, M, bits);\n \
|
||||
\n \
|
||||
}\n"
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
39
setup.py
39
setup.py
|
@ -1,8 +1,12 @@
|
|||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from setuptools import setup, find_packages
|
||||
from setuptools import setup, Extension, find_packages
|
||||
import subprocess
|
||||
import math
|
||||
|
||||
os.environ["CC"] = "g++"
|
||||
os.environ["CXX"] = "g++"
|
||||
|
||||
common_setup_kwargs = {
|
||||
"version": "0.4.1",
|
||||
|
@ -69,12 +73,15 @@ if BUILD_CUDA_EXT:
|
|||
requirements = [
|
||||
"accelerate>=0.19.0",
|
||||
"datasets",
|
||||
"sentencepiece",
|
||||
"numpy",
|
||||
"rouge",
|
||||
"gekko",
|
||||
"torch>=1.13.0",
|
||||
"safetensors",
|
||||
"transformers>=4.31.0",
|
||||
"peft"
|
||||
"peft",
|
||||
"tqdm",
|
||||
]
|
||||
|
||||
extras_require = {
|
||||
|
@ -88,6 +95,9 @@ additional_setup_kwargs = dict()
|
|||
if BUILD_CUDA_EXT:
|
||||
from torch.utils import cpp_extension
|
||||
|
||||
p = int(subprocess.run("cat /proc/cpuinfo | grep cores | head -1", shell=True, check=True, text=True, stdout=subprocess.PIPE).stdout.split(" ")[2])
|
||||
|
||||
subprocess.call(["python", "./autogptq_extension/qigen/generate.py", "--module", "--search", "--p", str(p)])
|
||||
if not ROCM_VERSION:
|
||||
from distutils.sysconfig import get_python_lib
|
||||
conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
|
||||
|
@ -100,16 +110,23 @@ if BUILD_CUDA_EXT:
|
|||
cpp_extension.CUDAExtension(
|
||||
"autogptq_cuda_64",
|
||||
[
|
||||
"autogptq_cuda/autogptq_cuda_64.cpp",
|
||||
"autogptq_cuda/autogptq_cuda_kernel_64.cu"
|
||||
"autogptq_extension/cuda_64/autogptq_cuda_64.cpp",
|
||||
"autogptq_extension/cuda_64/autogptq_cuda_kernel_64.cu"
|
||||
]
|
||||
),
|
||||
cpp_extension.CUDAExtension(
|
||||
"autogptq_cuda_256",
|
||||
[
|
||||
"autogptq_cuda/autogptq_cuda_256.cpp",
|
||||
"autogptq_cuda/autogptq_cuda_kernel_256.cu"
|
||||
"autogptq_extension/cuda_256/autogptq_cuda_256.cpp",
|
||||
"autogptq_extension/cuda_256/autogptq_cuda_kernel_256.cu"
|
||||
]
|
||||
),
|
||||
cpp_extension.CppExtension(
|
||||
"cQIGen",
|
||||
[
|
||||
'autogptq_extension/qigen/backend.cpp'
|
||||
],
|
||||
extra_compile_args = ["-O3", "-mavx", "-mavx2", "-mfma", "-march=native", "-ffast-math", "-ftree-vectorize", "-faligned-new", "-std=c++17", "-fopenmp", "-fno-signaling-nans", "-fno-trapping-math"]
|
||||
)
|
||||
]
|
||||
|
||||
|
@ -126,11 +143,11 @@ if BUILD_CUDA_EXT:
|
|||
cpp_extension.CUDAExtension(
|
||||
"exllama_kernels",
|
||||
[
|
||||
"autogptq_cuda/exllama/exllama_ext.cpp",
|
||||
"autogptq_cuda/exllama/cuda_buffers.cu",
|
||||
"autogptq_cuda/exllama/cuda_func/column_remap.cu",
|
||||
"autogptq_cuda/exllama/cuda_func/q4_matmul.cu",
|
||||
"autogptq_cuda/exllama/cuda_func/q4_matrix.cu"
|
||||
"autogptq_extension/exllama/exllama_ext.cpp",
|
||||
"autogptq_extension/exllama/cuda_buffers.cu",
|
||||
"autogptq_extension/exllama/cuda_func/column_remap.cu",
|
||||
"autogptq_extension/exllama/cuda_func/q4_matmul.cu",
|
||||
"autogptq_extension/exllama/cuda_func/q4_matrix.cu"
|
||||
],
|
||||
extra_link_args=extra_link_args
|
||||
)
|
||||
|
|
Loading…
Add table
Reference in a new issue