Merge remote-tracking branch 'qwopqwop200/main' into main

Author: qwopqwop200
Date: 2023-08-25 18:06:03 +09:00
Commit: 6a9d80eddc
27 changed files with 2498 additions and 79 deletions


@@ -13,6 +13,7 @@ import torch.nn as nn
import transformers
from accelerate.hooks import remove_hook_from_module
from safetensors.torch import save_file as safe_save
from safetensors.torch import load_file as safe_load
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.utils.hub import PushToHubMixin, cached_file, create_repo, create_commit, CommitOperationAdd
from transformers.utils.generic import ContextManagers
@@ -687,7 +688,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
device: Optional[Union[str, int]] = None,
low_cpu_mem_usage: bool = False,
use_triton: bool = False,
torch_dtype: torch.dtype = torch.float16,
use_qigen: bool = False,
torch_dtype: Optional[torch.dtype] = None,
inject_fused_attention: bool = True,
inject_fused_mlp: bool = True,
use_cuda_fp16: bool = True,
@@ -725,6 +727,12 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
}
if use_qigen:
logger.warning("QIgen is active. Ignores all settings related to cuda.")
inject_fused_attention = False
inject_fused_mlp = False
use_triton = False
disable_exllama = True
if use_triton and not TRITON_AVAILABLE:
logger.warning("Triton is not installed, reset use_triton to False.")
@@ -803,6 +811,13 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
def skip(*args, **kwargs):
pass
if torch_dtype is None:
if not use_qigen:
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
if not use_qigen:
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
@@ -880,7 +895,46 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
offload_buffers=True
)
model = simple_dispatch_model(model, device_map)
else:
if quantize_config.desc_act:
raise NotImplementedError('desc_act=True is not yet supported.')
model = AutoModelForCausalLM.from_config(
config,
trust_remote_code=trust_remote_code,
torch_dtype=torch_dtype
)
layers = find_layers(model)
ignore_layers = [cls.lm_head_name] + cls.outside_layer_modules
for name in list(layers.keys()):
if any([name.startswith(ignore_layer) for ignore_layer in ignore_layers]):
logger.info(f"{name} not been quantized, will be ignored when make_quant.")
del layers[name]
if model_save_name.endswith('.safetensors'):
checkpoint = safe_load(model_save_name)
else:
checkpoint = torch.load(model_save_name)
make_quant(
model,
layers,
quantize_config.bits,
quantize_config.group_size,
use_triton=use_triton,
disable_exllama=disable_exllama,
use_cuda_fp16=use_cuda_fp16,
desc_act=quantize_config.desc_act,
trainable=trainable,
use_qigen=True
)
preprocess_checkpoint_qigen(
model,
layers,
quantize_config.bits,
quantize_config.group_size,
checkpoint
)
model.load_state_dict(checkpoint)
# == step4: set seqlen == #
model_config = model.config.to_dict()
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]


@@ -6,11 +6,11 @@ import torch
import torch.nn as nn
from transformers import AutoConfig
import transformers
import cQIGen as qinfer
from ._const import SUPPORTED_MODELS, CPU, CUDA_0, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
from ..utils.import_utils import dynamically_import_QuantLinear
logger = getLogger(__name__)
@@ -58,11 +58,12 @@ def make_quant(
name='',
use_triton: bool = False,
disable_exllama: bool = False,
use_qigen: bool = False,
use_cuda_fp16: bool = True,
desc_act: bool = False,
trainable: bool = False
):
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama)
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, desc_act=desc_act, group_size=group_size, bits=bits, disable_exllama=disable_exllama, use_qigen=use_qigen)
if isinstance(module, QuantLinear):
return
@@ -81,7 +82,7 @@ def make_quant(
elif isinstance(tmp,transformers.pytorch_utils.Conv1D):
in_features = tmp.weight.shape[0]
out_features = tmp.weight.shape[1]
if (not(desc_act) or group_size == -1) and not use_triton:
if (not(desc_act) or group_size == -1) and not use_triton and not use_qigen:
new_layer = QuantLinear(
bits, group_size, in_features, out_features, True, use_cuda_fp16=use_cuda_fp16, trainable=trainable
)
@@ -101,8 +102,73 @@ def make_quant(
desc_act=desc_act,
trainable=trainable,
disable_exllama=disable_exllama,
use_qigen=use_qigen
)
def process_zeros_scales(zeros, scales, bits, out_features):
if zeros.dtype != torch.float32:
new_zeros = torch.zeros_like(scales).float().contiguous()
if bits == 4:
qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
elif bits == 2:
qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
elif bits == 3:
logger.info("Unpacking zeros for 3 bits")
new_scales = scales.contiguous()
else:
if scales.shape[1] != out_features:
new_scales = scales.transpose(0,1).contiguous()
else:
new_scales = scales.contiguous()
if zeros.shape[1] != out_features:
new_zeros = zeros.transpose(0,1).contiguous()
else:
new_zeros = zeros.contiguous()
return new_zeros, new_scales
def preprocess_checkpoint_qigen(
module,
names,
bits,
group_size,
checkpoint,
name='',
):
QuantLinear = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=bits, disable_exllama=False, use_qigen=True)
if isinstance(module, QuantLinear):
in_features = module.infeatures
out_features = module.outfeatures
checkpoint[name + '.zeros'],checkpoint[name + '.scales'] = process_zeros_scales(checkpoint[name + '.qzeros'],checkpoint[name + '.scales'].float(), bits, out_features)
del checkpoint[name + '.qzeros']
del checkpoint[name + '.g_idx']
if name + '.bias' in checkpoint:
checkpoint[name + '.bias'] = checkpoint[name + '.bias'].float()
else:
checkpoint[name + '.bias'] = torch.zeros(out_features)
checkpoint_qweight = checkpoint[name + '.qweight'].int().contiguous()
if bits == 4:
qweight = torch.zeros(int(in_features // 8 * out_features)).int().contiguous()
qinfer.pack4(checkpoint_qweight, qweight, in_features // 8, out_features, module.mb, module.tb, module.cutoff)# * (module.tt//tb))
elif bits == 3:
qweight = torch.zeros(int(in_features // 32 * 3 * out_features)).int().contiguous()
qinfer.pack3(checkpoint_qweight, qweight, in_features // 32 * 3, out_features, module.mb // 32 * 3, module.tb, module.cutoff)
elif bits == 2:
qweight = torch.zeros(int(in_features // 16 * out_features)).int().contiguous()
qinfer.pack2(checkpoint_qweight, qweight, in_features // 16, out_features, module.mb, module.tb, module.cutoff)# * (module.tt//tb))
checkpoint[name + '.qweight'] = qweight
return
for name1, child in module.named_children():
preprocess_checkpoint_qigen(
child,
names,
bits,
group_size,
checkpoint,
name + '.' + name1 if name != '' else name1,
)
def pack_model(
model,
@@ -287,6 +353,7 @@ __all__ = [
"get_module_by_name_prefix",
"get_module_by_name_suffix",
"make_quant",
"preprocess_checkpoint_qigen",
"pack_model",
"autogptq_post_init",
"check_and_get_model_type",


@@ -0,0 +1,258 @@
from copy import deepcopy
import torch
from torch import nn
from tqdm import tqdm
import gc
import cQIGen as qinfer
import math
import numpy as np
from gekko import GEKKO
from logging import getLogger
logger = getLogger(__name__)
def mem_model(N, M, T, mu, tu, bits, l1, p, gs):
m = GEKKO() # create GEKKO model
#cinfergen if bits==3:
# tu = tu*3
B = m.Const(value=bits)
TP = m.Const(value=T//p)
k = m.Var(1,integer=True,lb=1)
z = m.Var(1,integer=True,lb=1)
w = m.Var(1,integer=True,lb=1)
y = m.Var(1,integer=True,lb=1)
v = m.Var(1,integer=True,lb=1)
mb = m.Var(mu,integer=True,lb=1)
if gs != -1:
gg = m.Var(1,integer=True,lb=1)
tb = m.Var(tu,integer=True,lb=1,ub=int(T/p))
L = m.Var(integer=True,lb=0,ub=l1)
m.Equation(L == 32 * mb * N + B * mb * tb + 32 * tb * N)
m.Equation(mb * k == M)
if gs != -1:
m.Equation(gs * gg == mb)
# m.Equation(tb * z == T)
m.Equation(tb * z == TP)
m.Equation(mu * w == mb)
m.Equation(tu * y == tb)
# m.Equation(tb * v == tt)
m.Maximize(L)
m.options.SOLVER = 1
m.solver_options = ['minlp_maximum_iterations 1000', \
# minlp iterations with integer solution
'minlp_max_iter_with_int_sol 10', \
# treat minlp as nlp
'minlp_as_nlp 0', \
# nlp sub-problem max iterations
'nlp_maximum_iterations 100', \
# 1 = depth first, 2 = breadth first
'minlp_branch_method 2', \
# maximum deviation from whole number
'minlp_integer_tol 0.00', \
# convergence tolerance
'minlp_gap_tol 0.01']
try:
m.solve(disp=False)
except:
try:
m.solver_options = ['minlp_maximum_iterations 1000', \
# minlp iterations with integer solution
'minlp_max_iter_with_int_sol 10', \
# treat minlp as nlp
'minlp_as_nlp 0', \
# nlp sub-problem max iterations
'nlp_maximum_iterations 100', \
# 1 = depth first, 2 = breadth first
'minlp_branch_method 1', \
# maximum deviation from whole number
'minlp_integer_tol 0.00', \
# convergence tolerance
'minlp_gap_tol 0.01']
m.solve(disp=False)
except:
# mytb = T//p
mytb = tu
if gs != -1:
mymb = gs
while 32 * (mymb + gs) * N + bits * (mymb + gs) * mytb + 32 * mytb * N < l1:
mymb += gs
while M % mymb != 0:
mymb -= gs
return (int(mymb), int(mytb))
else:
mymb = mu
while 32 * (mymb + mu) * N + bits * (mymb + mu) * mytb + 32 * mytb * N < l1:
mymb += mu
while M % mymb != 0:
mymb -= mu
return (int(mymb), int(mytb))
return (int(mb.value[0]), int(tb.value[0]))
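mem_model maximizes the L1 working set L = 32*mb*N + bits*mb*tb + 32*tb*N (counted in bits) under the divisibility constraints, and the except branches fall back to a greedy search when GEKKO cannot find a solution. A minimal re-derivation of that fallback for the gs == -1 case, with purely illustrative numbers:

# Sketch of the greedy fallback above (no group size): grow the row tile mb
# in steps of mu while the tile's working set still fits in the L1 budget,
# then shrink it until it evenly divides the input dimension M.
def fallback_tiles(N, M, mu, tu, bits, l1):
    tb = tu
    mb = mu
    while 32 * (mb + mu) * N + bits * (mb + mu) * tb + 32 * tb * N < l1:
        mb += mu
    while M % mb != 0:
        mb -= mu
    return mb, tb

print(fallback_tiles(1, 4096, 16, 32, 4, 2**18))   # -> (1024, 32)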
params = {}
def compute_reductions(x, gs=-1, cpp=True):
if cpp:
if len(x.shape) != 1:
rows, cols = x.shape
else:
rows = 1
cols = x.shape[0]
if gs == -1:
out = torch.zeros(rows).float().contiguous()
mygs = cols
else:
out = torch.zeros(rows, cols // gs).float().contiguous()
mygs = gs
qinfer.compute_reduction_cpp(x, out, rows, cols, mygs)
return out
if gs == -1:
if len(x.shape) != 1:
return torch.sum(x,1)
else:
return torch.sum(x)
else:
if len(x.shape) != 1:
rows, cols = x.shape
out = torch.zeros(rows, cols // gs).float().contiguous()
for i in range(cols // gs):
out[:,i] = torch.sum(x[:,i*gs:(i+1)*gs],1)
return out
else:
cols = x.shape[0]
out = torch.zeros(cols // gs).float().contiguous()
for i in range(cols // gs):
out[i] = torch.sum(x[i*gs:(i+1)*gs])
return out
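The pure-Python branches above double as a reference for what compute_reduction_cpp returns: one partial sum per quantization group, per input row. A small illustrative call using the non-cpp path (the cpp path needs the compiled cQIGen extension):

import torch

x = torch.arange(16, dtype=torch.float32).reshape(2, 8)
out = compute_reductions(x, gs=4, cpp=False)   # group size 4 -> 2 groups per row
print(out.shape)   # torch.Size([2, 2])
print(out)         # tensor([[ 6., 22.], [38., 54.]])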
def process_zeros_scales(zeros, scales, bits, M):
if zeros.dtype != torch.float32:
new_zeros = torch.zeros_like(scales).float().contiguous()
if bits == 4:
qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
elif bits == 2:
qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
elif bits == 3:
logger.info("Unpacking zeros for 3 bits")
new_scales = scales.contiguous()
else:
if scales.shape[1] != M:
new_scales = scales.transpose(0,1).contiguous()
else:
new_scales = scales.contiguous()
if zeros.shape[1] != M:
new_zeros = zeros.transpose(0,1).contiguous()
else:
new_zeros = zeros.contiguous()
return new_zeros, new_scales
class QuantLinear(nn.Module):
QUANT_TYPE = "qigen"
def __init__(self, bits, group_size, infeatures, outfeatures, bias=None, trainable=False, hint=1, p=8, l1=2**18):
super().__init__()
if bits not in [2, 4]:
raise NotImplementedError("Only 2,4 bits are supported.")
if trainable:
raise NotImplementedError("Qigen kernel does not support training.")
self.bits = bits
pack = 32 // bits
self.infeatures = infeatures
self.outfeatures = outfeatures
n = hint
m = self.infeatures
t = self.outfeatures
#registers for now are fixed
if bits == 3:
packed = 32
unroll = 3
nu = 1 #args.n
mu = 32
tu = 32
else:
packed = 32 // bits
unroll = 2
nu = 1 #args.n
mu = 16
tu = 32
nb = n # it's always small for transformers
global params
if (m,t) in params:
mb = params[(m,t)][0]
tb = params[(m,t)][1]
else:
mb, tb = mem_model(n, m, t, mu, tu, bits, l1, p, group_size)
params[(m,t)] = (mb,tb)
split = np.ones(p)
split = split * tb
while np.sum(split) < t:
split = split + tb
idx = p - 1
while np.sum(split) > t:
split[idx] = split[idx] - tb
idx = idx - 1
assert(np.sum(split) == t)
split = split.astype(int)
self.tt = int(split[0])
if split[0] == split[-1]:
self.cutoff = int(p+1)
else:
self.cutoff = int(idx + 1)
self.mb = mb #// packed
self.tb = tb
self.group_size = group_size
self.register_buffer('bias', torch.zeros(self.outfeatures))
self.register_buffer('zeros', torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float32))
self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float32))
if bits == 4:
self.register_buffer('qweight', torch.zeros(int(self.infeatures // packed * self.outfeatures)).int().contiguous())
elif bits == 3:
self.register_buffer('qweight', torch.zeros(int(self.infeatures // packed * 3 * self.outfeatures)).int().contiguous())
elif bits == 2:
self.register_buffer('qweight', torch.zeros(int(self.infeatures // packed * self.outfeatures)).int().contiguous())
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
x = x.reshape((-1, x.shape[-1])).to(torch.float32)
B = x.shape[0]
new_x = x.T.contiguous()
out = torch.zeros((B, self.outfeatures), dtype=torch.float32)
sums = compute_reductions(x,gs=self.group_size,cpp=True).contiguous()
if self.group_size == -1:
if self.bits == 4:
qinfer.forward4(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.cutoff)
elif self.bits == 2:
qinfer.forward2(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.cutoff)
elif self.bits == 3:
qinfer.forward3(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.cutoff)
else:
if self.bits == 4:
qinfer.forward_gs4(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.group_size, self.cutoff)
elif self.bits == 2:
qinfer.forward_gs2(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.group_size, self.cutoff)
elif self.bits == 3:
qinfer.forward_gs3(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.group_size, self.cutoff)
return out.reshape(out_shape)
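Why forward precomputes sums (the per-group reductions of the input): within one group the dequantized weight is an affine function of its integer code, so each group's contribution to x @ W splits into a quantized dot product plus a zero term that only needs the group's input sum. A tiny numerical check of that identity, with illustrative shapes and a generic w = q*s + z convention rather than the exact QIGen layout:

import torch

gs, out_features = 4, 3
x_g = torch.randn(gs)                                     # one group of one input row
q_g = torch.randint(0, 16, (gs, out_features)).float()    # 4-bit integer codes
s = torch.randn(out_features)                             # per-(group, column) scales
z = torch.randn(out_features)                             # per-(group, column) zero terms

w_g = q_g * s + z                                         # dequantized weights
lhs = x_g @ w_g
rhs = s * (x_g @ q_g) + z * x_g.sum()
assert torch.allclose(lhs, rhs, atol=1e-5)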


@@ -28,7 +28,10 @@ except:
logger = getLogger(__name__)
def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool = False):
def dynamically_import_QuantLinear(use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool = False, use_qigen: bool = False):
if use_qigen:
from ..nn_modules.qlinear.qlinear_qigen import QuantLinear
else:
if use_triton:
if torch.version.hip:
logger.warning("Running GPTQ triton version on AMD GPUs is untested and may result in errors or wrong predictions. Please use use_triton=False.")

File diff suppressed because it is too large.


@@ -0,0 +1,149 @@
def load_int(to, address, const=True):
if const:
return f"const __m256i {to} = _mm256_loadu_si256({address});"
else:
return f"__m256i {to} = _mm256_loadu_si256({address});"
def load_fp(to, address, const=True):
if const:
return f"const __m256 {to} = _mm256_loadu_ps({address});"
else:
return f"__m256 {to} = _mm256_loadu_ps({address});"
# to = a * b + c
def vfma(to, a, b, c):
return f"__m256 {to} = _mm256_fmadd_ps({a}, {b}, {c});"
def vsrli(to, a, b):
return f"const __m256i {to} = _mm256_srli_epi32({a}, {b});"
def vand(to, a, b):
return f"const __m256i {to} = _mm256_and_si256({a}, {b});"
def vbroadcast_fp(to, a):
return f"const __m256 {to} = _mm256_set1_ps({a});"
def vbroadcast_int32(to, a):
return f"__m256i {to} = _mm256_set1_epi32({a});"
def vsetzero(to):
return f"__m256 {to} = _mm256_setzero_ps();"
def vcvtepi32_ps(to, a):
return f"const __m256 {to} = _mm256_cvtepi32_ps({a});"
def _256extractf128_ps(to, a, imm):
return f"const __m128 {to} = _mm256_extractf128_ps({a}, {imm});"
def _256castps256_ps128(to, a):
return f"const __m128 {to} = _mm256_castps256_ps128({a});"
def _add_ps(to, a, b):
return f"const __m128 {to} = _mm_add_ps({a}, {b});"
def _movehl_ps(to, a, b):
return f"const __m128 {to} = _mm_movehl_ps({a}, {b});"
def _shuffle_ps(to, a, b, imm):
return f"const __m128 {to} = _mm_shuffle_ps({a}, {b}, {imm});"
def _cvtss_f32(to, a):
return f"const float {to} = _mm_cvtss_f32({a});"
def _reduce8_acc(a, b, c, d, e, f, g, h):
res = ""
res += _256extractf128_ps("hi_quad0", a, 1)
res += _256extractf128_ps("hi_quad1", b, 1)
res += _256extractf128_ps("hi_quad2", c, 1)
res += _256extractf128_ps("hi_quad3", d, 1)
res += _256extractf128_ps("hi_quad4", e, 1)
res += _256extractf128_ps("hi_quad5", f, 1)
res += _256extractf128_ps("hi_quad6", g, 1)
res += _256extractf128_ps("hi_quad7", h, 1)
res += _256castps256_ps128("lo_quad0", a)
res += _256castps256_ps128("lo_quad1", b)
res += _256castps256_ps128("lo_quad2", c)
res += _256castps256_ps128("lo_quad3", d)
res += _256castps256_ps128("lo_quad4", e)
res += _256castps256_ps128("lo_quad5", f)
res += _256castps256_ps128("lo_quad6", g)
res += _256castps256_ps128("lo_quad7", h)
res += _add_ps("sum_quad0", "lo_quad0", "hi_quad0")
res += _add_ps("sum_quad1", "lo_quad1", "hi_quad1")
res += _add_ps("sum_quad2", "lo_quad2", "hi_quad2")
res += _add_ps("sum_quad3", "lo_quad3", "hi_quad3")
res += _add_ps("sum_quad4", "lo_quad4", "hi_quad4")
res += _add_ps("sum_quad5", "lo_quad5", "hi_quad5")
res += _add_ps("sum_quad6", "lo_quad6", "hi_quad6")
res += _add_ps("sum_quad7", "lo_quad7", "hi_quad7")
res += _movehl_ps("hi_dual0", "sum_quad0", "sum_quad0")
res += _movehl_ps("hi_dual1", "sum_quad1", "sum_quad1")
res += _movehl_ps("hi_dual2", "sum_quad2", "sum_quad2")
res += _movehl_ps("hi_dual3", "sum_quad3", "sum_quad3")
res += _movehl_ps("hi_dual4", "sum_quad4", "sum_quad4")
res += _movehl_ps("hi_dual5", "sum_quad5", "sum_quad5")
res += _movehl_ps("hi_dual6", "sum_quad6", "sum_quad6")
res += _movehl_ps("hi_dual7", "sum_quad7", "sum_quad7")
res += _add_ps("sum_dual0", "sum_quad0", "hi_dual0")
res += _add_ps("sum_dual1", "sum_quad1", "hi_dual1")
res += _add_ps("sum_dual2", "sum_quad2", "hi_dual2")
res += _add_ps("sum_dual3", "sum_quad3", "hi_dual3")
res += _add_ps("sum_dual4", "sum_quad4", "hi_dual4")
res += _add_ps("sum_dual5", "sum_quad5", "hi_dual5")
res += _add_ps("sum_dual6", "sum_quad6", "hi_dual6")
res += _add_ps("sum_dual7", "sum_quad7", "hi_dual7")
res += _shuffle_ps("hi0", "sum_dual0", "sum_dual0", 0x1)
res += _shuffle_ps("hi1", "sum_dual1", "sum_dual1", 0x1)
res += _shuffle_ps("hi2", "sum_dual2", "sum_dual2", 0x1)
res += _shuffle_ps("hi3", "sum_dual3", "sum_dual3", 0x1)
res += _shuffle_ps("hi4", "sum_dual4", "sum_dual4", 0x1)
res += _shuffle_ps("hi5", "sum_dual5", "sum_dual5", 0x1)
res += _shuffle_ps("hi6", "sum_dual6", "sum_dual6", 0x1)
res += _shuffle_ps("hi7", "sum_dual7", "sum_dual7", 0x1)
res += _add_ps("sum0", "sum_dual0", "hi0")
res += _add_ps("sum1", "sum_dual1", "hi1")
res += _add_ps("sum2", "sum_dual2", "hi2")
res += _add_ps("sum3", "sum_dual3", "hi3")
res += _add_ps("sum4", "sum_dual4", "hi4")
res += _add_ps("sum5", "sum_dual5", "hi5")
res += _add_ps("sum6", "sum_dual6", "hi6")
res += _add_ps("sum7", "sum_dual7", "hi7")
res += _cvtss_f32(f"f{a}", "sum0")
res += _cvtss_f32(f"f{b}", "sum1")
res += _cvtss_f32(f"f{c}", "sum2")
res += _cvtss_f32(f"f{d}", "sum3")
res += _cvtss_f32(f"f{e}", "sum4")
res += _cvtss_f32(f"f{f}", "sum5")
res += _cvtss_f32(f"f{g}", "sum6")
res += _cvtss_f32(f"f{h}", "sum7")
return res
acc_idx = 0
def _reduce_add(a):
global acc_idx
res = ""
res += _256extractf128_ps(f"hi_quad{acc_idx}", a, 1)
res += _256castps256_ps128(f"lo_quad{acc_idx}", a)
res += _add_ps(f"sum_quad{acc_idx}", f"lo_quad{acc_idx}", f"hi_quad{acc_idx}")
res += _movehl_ps(f"hi_dual{acc_idx}", f"sum_quad{acc_idx}", f"sum_quad{acc_idx}")
res += _add_ps(f"sum_dual{acc_idx}", f"sum_quad{acc_idx}", f"hi_dual{acc_idx}")
res += _shuffle_ps(f"hi{acc_idx}", f"sum_dual{acc_idx}", f"sum_dual{acc_idx}", 0x1)
res += _add_ps(f"sum{acc_idx}", f"sum_dual{acc_idx}", f"hi{acc_idx}")
res += _cvtss_f32(f"f{a}", f"sum{acc_idx}")
acc_idx += 1
return res
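These helpers emit AVX2 source as strings. As an example, tracing _reduce_add("acc0") with acc_idx == 0 produces the statements below (shown one per line here for readability; the helpers return them concatenated), i.e. a horizontal sum of the eight float lanes of acc0 into the scalar facc0:

const __m128 hi_quad0 = _mm256_extractf128_ps(acc0, 1);
const __m128 lo_quad0 = _mm256_castps256_ps128(acc0);
const __m128 sum_quad0 = _mm_add_ps(lo_quad0, hi_quad0);
const __m128 hi_dual0 = _mm_movehl_ps(sum_quad0, sum_quad0);
const __m128 sum_dual0 = _mm_add_ps(sum_quad0, hi_dual0);
const __m128 hi0 = _mm_shuffle_ps(sum_dual0, sum_dual0, 1);
const __m128 sum0 = _mm_add_ps(sum_dual0, hi0);
const float facc0 = _mm_cvtss_f32(sum0);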


@@ -0,0 +1,302 @@
#include <iostream>
#include "forward.h"
#include <cstring>
#include <algorithm>
#include <vector>
#include <chrono>
#include <fstream>
#define mymin(a,b) ((a)<(b)?(a):(b))
#define mymax(a,b) ((a)>(b)?(a):(b))
void print_matrix(std::string name, float* A, int N, int M){
std::cout<<name<<std::endl;
for(int i = 0; i < N; i++){
for(int j = 0; j < M; j++){
std::cout << A[i*M+j] << " ";
}
std::cout << std::endl;
}
std::cout<<std::endl;
}
void oracle_mmadd(float* A, float* B, float* bias, float* C, int n, int m, int t){
// triple loop matmul and add bias
for (int i = 0; i < n; i++){
for (int j = 0; j < t; j++){
float sum = 0;
for (int k = 0; k < m; k++){
sum += A[i*m+k] * B[k*t+j];
}
C[i*t+j] += sum + bias[j];
}
}
}
void compute_reduction(float *in, float *out, int n, int m, int gs){
int ng;
if(gs == -1){
ng = 1;
gs = m;
}else{
ng = m/gs;
}
for(int i = 0; i < n; i++){
for(int j0 = 0; j0 < m; j0+=gs){
int j = j0/gs;
out[i*ng+j] = 0;
for(int j1 = j0; j1 < j0+gs; j1++){
out[i*ng+j] += in[i*m+j1];
}
}
}
}
void quantize_sim(float* A, float* BQ, float* scales, float* zeros, int n, int m, int bits, int gs){
//find scales and zeros arrays
if(gs == -1){
gs = n;
}
float range = (1<<bits) - 1;
int packed = 32 / bits;
for(int i0 = 0; i0 < n; i0+=gs){
int row = i0/gs;
for(int j = 0; j < m; j++){
float min = A[i0*m + j];
float max = A[i0*m + j];
for(int i1 = i0; i1 < i0+gs; i1++){
min = mymin(min, A[i1*m+j]);
max = mymax(max, A[i1*m+j]);
}
scales[row*m + j] = (max-min)/range;
zeros[row*m + j ] = min;
}
for(int j = 0; j < m; j++){
for (int i1 = i0; i1 < i0+gs; i1++){
uint32_t acc = 0;
int temp = (A[i1*m+j] - zeros[row*m+j])/scales[row*m+j];
float val = ((float) temp + zeros[row*m+j]) * scales[row*m+j];
BQ[i1*m+j] = val;
}
}
}
}
void quantize(float* A, int* BQ, float* scales, float* zeros, int n, int m, int bits, int gs){
//find scales and zeros arrays
if(gs == -1){
gs = n;
}
float range = (1<<bits) - 1;
int packed = 32 / bits;
for(int i0 = 0; i0 < n; i0+=gs){
int row = i0/gs;
for(int j = 0; j < m; j++){
float min = A[i0*m + j];
float max = A[i0*m + j];
for(int i1 = i0; i1 < i0+gs; i1++){
min = mymin(min, A[i1*m+j]);
max = mymax(max, A[i1*m+j]);
}
scales[row*m + j] = (max-min)/range;
zeros[row*m + j ] = min;
}
for(int j = 0; j < m; j++){
if(bits == 3){
for (int i1 = i0; i1 < i0+gs; i1+=32){
uint32_t acc = 0;
int temp0 = ((int)((A[(i1+0)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 0;
int temp1 = ((int)((A[(i1+1)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 3;
int temp2 = ((int)((A[(i1+2)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 6;
int temp3 = ((int)((A[(i1+3)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 9;
int temp4 = ((int)((A[(i1+4)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 12;
int temp5 = ((int)((A[(i1+5)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 15;
int temp6 = ((int)((A[(i1+6)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 18;
int temp7 = ((int)((A[(i1+7)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 21;
int temp8 = ((int)((A[(i1+8)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 24;
int temp9 = ((int)((A[(i1+9)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 27;
int temp10_0 = ((int)((A[(i1+10)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 30;
int temp10_1 = ((int)((A[(i1+10)*m+j] - zeros[row*m+j])/scales[row*m+j])) >> 2;
int temp11 = ((int)((A[(i1+11)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 1;
int temp12 = ((int)((A[(i1+12)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 4;
int temp13 = ((int)((A[(i1+13)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 7;
int temp14 = ((int)((A[(i1+14)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 10;
int temp15 = ((int)((A[(i1+15)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 13;
int temp16 = ((int)((A[(i1+16)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 16;
int temp17 = ((int)((A[(i1+17)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 19;
int temp18 = ((int)((A[(i1+18)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 22;
int temp19 = ((int)((A[(i1+19)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 25;
int temp20 = ((int)((A[(i1+20)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 28;
int temp21_0 = ((int)((A[(i1+21)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 31;
int temp21_1 = ((int)((A[(i1+21)*m+j] - zeros[row*m+j])/scales[row*m+j])) >> 1;
int temp22 = ((int)((A[(i1+22)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 2;
int temp23 = ((int)((A[(i1+23)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 5;
int temp24 = ((int)((A[(i1+24)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 8;
int temp25 = ((int)((A[(i1+25)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 11;
int temp26 = ((int)((A[(i1+26)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 14;
int temp27 = ((int)((A[(i1+27)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 17;
int temp28 = ((int)((A[(i1+28)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 20;
int temp29 = ((int)((A[(i1+29)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 23;
int temp30 = ((int)((A[(i1+30)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 26;
int temp31 = ((int)((A[(i1+31)*m+j] - zeros[row*m+j])/scales[row*m+j])) << 29;
int acc0 = 0, acc1 = 0, acc2 = 0;
acc0 |= temp0;
acc0 |= temp1;
acc0 |= temp2;
acc0 |= temp3;
acc0 |= temp4;
acc0 |= temp5;
acc0 |= temp6;
acc0 |= temp7;
acc0 |= temp8;
acc0 |= temp9;
acc0 |= temp10_0;
acc1 |= temp10_1;
acc1 |= temp11;
acc1 |= temp12;
acc1 |= temp13;
acc1 |= temp14;
acc1 |= temp15;
acc1 |= temp16;
acc1 |= temp17;
acc1 |= temp18;
acc1 |= temp19;
acc1 |= temp20;
acc1 |= temp21_0;
acc2 |= temp21_1;
acc2 |= temp22;
acc2 |= temp23;
acc2 |= temp24;
acc2 |= temp25;
acc2 |= temp26;
acc2 |= temp27;
acc2 |= temp28;
acc2 |= temp29;
acc2 |= temp30;
acc2 |= temp31;
BQ[(3*i1/32)*m+j] = acc0;
BQ[(3*i1/32+1)*m+j] = acc1;
BQ[(3*i1/32+2)*m+j] = acc2;
}
}else{
for (int i1 = i0; i1 < i0+gs; i1+=packed){
uint32_t acc = 0;
for (int i2 = i1; i2 < i1+packed; i2++){
int temp = (A[i2*m+j] - zeros[row*m+j])/scales[row*m+j];
acc = acc | (temp << (bits*(i2-i1)));
}
BQ[(i1/packed)*m+j] = acc;
}
}
}
}
}
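The generic packing branch above (used for 2 and 4 bits) simply ORs 32/bits consecutive quantized values into one 32-bit word, least-significant field first. The same step in Python for a single column and eight illustrative 4-bit codes:

bits = 4
codes = [3, 0, 15, 7, 1, 12, 5, 9]        # quantized integers in [0, 2**bits)
acc = 0
for k, q in enumerate(codes):
    acc |= q << (bits * k)                # pack least-significant field first
print(hex(acc))                           # 0x95c17f03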
int main(int argc, char *argv[]){
// read n m t from args
if(argc < 6){std::cout << "Parameters not given\n"; return 0;}
int n = atoi(argv[1]);
int m = atoi(argv[2]);
int t = atoi(argv[3]);
int bits = atoi(argv[4]);
int gs = atoi(argv[5]);
int ng;
if(gs == -1){
ng = 1;
}else{
ng = m/gs;
}
float* A = new float[n*m];
float* AB = new float[n*m];
float* B = new float[m*t];
float* BQS = new float[m*t];
float* scales = new float[t*ng];
float* zeros = new float[t*ng];
int* BQ = new int[m*t/8];
int* BQB = new int[m*t/8];
float* sums = new float[n*ng];
float* bias = new float[t];
float* C = new float[n*t];
float* CB = new float[n*t];
float* C2 = new float[n*t];
srand(1);
for (int i = 0; i < n*m; i++){
A[i] = (float)rand() / RAND_MAX;
}
for (int i = 0; i < t*m; i++){
B[i] = (float)rand() / RAND_MAX;
}
for (int i = 0; i < t; i++){
bias[i] = (float)rand() / RAND_MAX;
}
for (int i = 0; i < n*t; i++){
C[i] = 0.0;
C2[i] = 0.0;
}
quantize_sim(B,BQS,scales,zeros,m,t,bits,gs);
quantize(B,BQ,scales,zeros,m,t,bits,gs);
quantize_sim(B,BQS,scales,zeros,m,t,bits,gs);
quantize(B,BQ,scales,zeros,m,t,bits,gs);
oracle_mmadd(A, BQS, bias, C, n, m, t);
pack_input(A,AB);
pack_qw(BQ,BQB);
pack_output(C,CB);
compute_reduction(A,sums,n,m,gs);
qforward(AB,BQB,scales,zeros,bias,sums,C2,n,m,t);
float norm = 0.0;
for (int i = 0; i < n*t; i++){
norm += (C[i] - C2[i]) * (C[i] - C2[i]);
}
if(norm / (n*t) < 0.0001){
int iter = 30;
for(int _ = 0; _ < iter; _++){
qforward(AB,BQB,scales,zeros,bias,sums,C2,n,m,t);
}
int num_runs = 15;
std::vector<long int> runs(num_runs);
for(int r = 0; r < num_runs; r++){
auto start = std::chrono::high_resolution_clock::now();
for(int _ = 0; _ < iter; _++){
qforward(AB,BQB,scales,zeros,bias,sums,C2,n,m,t);
}
auto end = std::chrono::high_resolution_clock::now();
runs[r] = std::chrono::duration_cast<std::chrono::nanoseconds>(end - start).count();
}
std::sort(runs.begin(), runs.end());
float cycles_final = runs[num_runs/2 + 1] / iter;
std::ofstream outfile;
outfile.open("./autogptq_extension/qigen/tmp.csv", std::ios_base::app);
print_parameters();
outfile << cycles_final << std::endl;
}else{
float cycles_final = 10e12f; // sentinel value: correctness check failed
std::ofstream outfile;
outfile.open("./autogptq_extension/qigen/tmp.csv", std::ios_base::app);
print_parameters();
outfile << cycles_final << std::endl;
}
return 0;
}


@@ -0,0 +1,85 @@
def includes():
out = " \
#include <torch/all.h>\n \
#include <torch/python.h>\n \
#include <omp.h>\n \
#include <cmath>\n \
#include <immintrin.h>\n \
\n \
#define mymin(a,b) ((a)<(b)?(a):(b))\n \
#define mymax(a,b) ((a)>(b)?(a):(b))\n \
"
return out
def module(bits_list=[4, 2]):
out = 'PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {\n'
for bits in bits_list:
out += ' m.def("forward{}", &forward{}_cpu);\n'.format(bits, bits)
for bits in bits_list:
out += ' m.def("unpack_zeros{}", &unpack_zeros{});\n'.format(bits, bits)
for bits in bits_list:
out += ' m.def("forward_gs{}", &forward{}_gs_cpu);\n'.format(bits, bits)
for bits in bits_list:
out += ' m.def("pack{}", &pack{}_w_cpu);\n'.format(bits, bits)
out += 'm.def("compute_reduction_cpp", &compute_reduction);\n'
out += 'm.def("unquantize_sim", &unquantize_sim);\n'
# if oracle:
# out += ' m.def("forward4_oracle", &forward4_oracle_cpu);\n'
out += 'm.def("quant_scalar_scaled", &quant_scalar_cpu);\n'
out += '}\n'
return out
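module() assembles the pybind11 registrations for the cQIGen extension; with the default bits_list it emits roughly the following (whitespace aside), which is where the qinfer.forward*/forward_gs*/pack*/unpack_zeros* entry points used elsewhere in this change come from:

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward4", &forward4_cpu);
  m.def("forward2", &forward2_cpu);
  m.def("unpack_zeros4", &unpack_zeros4);
  m.def("unpack_zeros2", &unpack_zeros2);
  m.def("forward_gs4", &forward4_gs_cpu);
  m.def("forward_gs2", &forward2_gs_cpu);
  m.def("pack4", &pack4_w_cpu);
  m.def("pack2", &pack2_w_cpu);
  m.def("compute_reduction_cpp", &compute_reduction);
  m.def("unquantize_sim", &unquantize_sim);
  m.def("quant_scalar_scaled", &quant_scalar_cpu);
}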
def quant_scalar():
out = " \
void quantize_scalar(float* A, int* BQ, float* scales, float* zeros, int n, int m, int bits){ \n \
//find scales and zeros arrays \n \
//quantize \n \
int pack = 32/bits;\n \
for (int j = 0; j < m; j++){\n \
for (int i = 0; i < n; i+=pack){\n \
uint32_t acc = 0;\n \
for (int ii = i; ii < i+pack; ii++){\n \
float ftemp = std::round((A[ii*m+j] + zeros[j])/scales[j]);\n \
int temp = (int)ftemp;\n \
acc = acc | (temp << (bits*(ii-i)));\n \
}\n \
BQ[(i/pack)*m+j] = acc;\n \
//BQ[0] = acc;\n \
}\n \
}\n \
}\n \
\n \
void quant_scalar_cpu(\n \
torch::Tensor in, torch::Tensor out, \n \
torch::Tensor scales, torch::Tensor zeros, int bits\n \
) {\n \
\n \
int N = in.size(0);\n \
int M = in.size(1);\n \
\n \
float* input = in.data_ptr<float>(); \n \
float* s = scales.data_ptr<float>();\n \
float* z = zeros.data_ptr<float>();\n \
int* O = out.data_ptr<int>();\n \
\n \
quantize_scalar(input, O, s, z, N, M, bits);\n \
\n \
}\n"
return out


@ -1,8 +1,12 @@
import os
import sys
from pathlib import Path
from setuptools import setup, find_packages
from setuptools import setup, Extension, find_packages
import subprocess
import math
os.environ["CC"] = "g++"
os.environ["CXX"] = "g++"
common_setup_kwargs = {
"version": "0.4.1",
@@ -69,12 +73,15 @@ if BUILD_CUDA_EXT:
requirements = [
"accelerate>=0.19.0",
"datasets",
"sentencepiece",
"numpy",
"rouge",
"gekko",
"torch>=1.13.0",
"safetensors",
"transformers>=4.31.0",
"peft"
"peft",
"tqdm",
]
extras_require = {
@@ -88,6 +95,9 @@ additional_setup_kwargs = dict()
if BUILD_CUDA_EXT:
from torch.utils import cpp_extension
p = int(subprocess.run("cat /proc/cpuinfo | grep cores | head -1", shell=True, check=True, text=True, stdout=subprocess.PIPE).stdout.split(" ")[2])
subprocess.call(["python", "./autogptq_extension/qigen/generate.py", "--module", "--search", "--p", str(p)])
if not ROCM_VERSION:
from distutils.sysconfig import get_python_lib
conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
@@ -100,16 +110,23 @@ if BUILD_CUDA_EXT:
cpp_extension.CUDAExtension(
"autogptq_cuda_64",
[
"autogptq_cuda/autogptq_cuda_64.cpp",
"autogptq_cuda/autogptq_cuda_kernel_64.cu"
"autogptq_extension/cuda_64/autogptq_cuda_64.cpp",
"autogptq_extension/cuda_64/autogptq_cuda_kernel_64.cu"
]
),
cpp_extension.CUDAExtension(
"autogptq_cuda_256",
[
"autogptq_cuda/autogptq_cuda_256.cpp",
"autogptq_cuda/autogptq_cuda_kernel_256.cu"
"autogptq_extension/cuda_256/autogptq_cuda_256.cpp",
"autogptq_extension/cuda_256/autogptq_cuda_kernel_256.cu"
]
),
cpp_extension.CppExtension(
"cQIGen",
[
'autogptq_extension/qigen/backend.cpp'
],
extra_compile_args = ["-O3", "-mavx", "-mavx2", "-mfma", "-march=native", "-ffast-math", "-ftree-vectorize", "-faligned-new", "-std=c++17", "-fopenmp", "-fno-signaling-nans", "-fno-trapping-math"]
)
]
@@ -126,11 +143,11 @@ if BUILD_CUDA_EXT:
cpp_extension.CUDAExtension(
"exllama_kernels",
[
"autogptq_cuda/exllama/exllama_ext.cpp",
"autogptq_cuda/exllama/cuda_buffers.cu",
"autogptq_cuda/exllama/cuda_func/column_remap.cu",
"autogptq_cuda/exllama/cuda_func/q4_matmul.cu",
"autogptq_cuda/exllama/cuda_func/q4_matrix.cu"
"autogptq_extension/exllama/exllama_ext.cpp",
"autogptq_extension/exllama/cuda_buffers.cu",
"autogptq_extension/exllama/cuda_func/column_remap.cu",
"autogptq_extension/exllama/cuda_func/q4_matmul.cu",
"autogptq_extension/exllama/cuda_func/q4_matrix.cu"
],
extra_link_args=extra_link_args
)
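With these setup.py changes, building the package also generates and compiles the cQIGen CPU extension (generate.py is run first, then backend.cpp is built with the AVX2/FMA flags above). A quick post-install smoke test, assuming the build succeeded:

# The symbol names come from the pybind bindings emitted by generate.py.
import cQIGen as qinfer

print(hasattr(qinfer, "forward4"), hasattr(qinfer, "pack4"))   # expected: True True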