import gc
import math
from copy import deepcopy
from logging import getLogger

import numpy as np
import torch
from gekko import GEKKO
from torch import nn
from tqdm import tqdm

logger = getLogger(__name__)


try:
    import cQIGen as qinfer
except ImportError:
    logger.error('cQIGen not installed.')
    raise


def mem_model(N, M, T, mu, tu, bits, l1, p, gs):
    """Pick tile sizes (mb, tb) for an N x M x T GEMM so that one tile of
    activations, packed weights, and outputs fits in the L1 budget `l1`
    (in bits), using a GEKKO MINLP model with a greedy fallback."""
    m = GEKKO()  # create GEKKO model
    # cinfergen: if bits == 3: tu = tu * 3
    B = m.Const(value=bits)
    TP = m.Const(value=T // p)
    k = m.Var(1, integer=True, lb=1)
    z = m.Var(1, integer=True, lb=1)
    w = m.Var(1, integer=True, lb=1)
    y = m.Var(1, integer=True, lb=1)
    v = m.Var(1, integer=True, lb=1)
    mb = m.Var(mu, integer=True, lb=1)
    if gs != -1:
        gg = m.Var(1, integer=True, lb=1)
    tb = m.Var(tu, integer=True, lb=1, ub=int(T / p))
    L = m.Var(integer=True, lb=0, ub=l1)

    # Tile footprint in bits: fp32 inputs, packed weights, fp32 outputs.
    m.Equation(L == 32 * mb * N + B * mb * tb + 32 * tb * N)
    m.Equation(mb * k == M)  # mb must divide M
    if gs != -1:
        m.Equation(gs * gg == mb)  # mb must be a multiple of the group size
    # m.Equation(tb * z == T)
    m.Equation(tb * z == TP)  # tb must divide the per-thread share T // p
    m.Equation(mu * w == mb)  # mb must be a multiple of the register tile mu
    m.Equation(tu * y == tb)  # tb must be a multiple of the register tile tu
    # m.Equation(tb * v == tt)
    m.Maximize(L)

    m.options.SOLVER = 1
    m.solver_options = [
        'minlp_maximum_iterations 1000',
        # minlp iterations with integer solution
        'minlp_max_iter_with_int_sol 10',
        # treat minlp as nlp
        'minlp_as_nlp 0',
        # nlp sub-problem max iterations
        'nlp_maximum_iterations 100',
        # 1 = depth first, 2 = breadth first
        'minlp_branch_method 2',
        # maximum deviation from whole number
        'minlp_integer_tol 0.00',
        # convergence tolerance
        'minlp_gap_tol 0.01',
    ]
    try:
        m.solve(disp=False)
    except Exception:
        try:
            # Retry with depth-first branching.
            m.solver_options = [
                'minlp_maximum_iterations 1000',
                'minlp_max_iter_with_int_sol 10',
                'minlp_as_nlp 0',
                'nlp_maximum_iterations 100',
                'minlp_branch_method 1',
                'minlp_integer_tol 0.00',
                'minlp_gap_tol 0.01',
            ]
            m.solve(disp=False)
        except Exception:
            # Both solver runs failed: grow mb greedily in steps of gs (or mu)
            # while the tile still fits in L1, then shrink until mb divides M.
            # mytb = T // p
            mytb = tu
            if gs != -1:
                mymb = gs
                while 32 * (mymb + gs) * N + bits * (mymb + gs) * mytb + 32 * mytb * N < l1:
                    mymb += gs
                while M % mymb != 0:
                    mymb -= gs
            else:
                mymb = mu
                while 32 * (mymb + mu) * N + bits * (mymb + mu) * mytb + 32 * mytb * N < l1:
                    mymb += mu
                while M % mymb != 0:
                    mymb -= mu
            return (int(mymb), int(mytb))

    return (int(mb.value[0]), int(tb.value[0]))
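

# A worked instance of the footprint constraint above (illustrative numbers,
# not taken from the source): with N=1 and bits=4, a tile of mb=256, tb=512
# costs 32*256 + 4*256*512 + 32*512 = 548864 bits, which overflows an
# l1 = 2**18 = 262144-bit budget, so the solver must settle for a smaller
# tile, e.g.:
#
#   mb, tb = mem_model(N=1, M=4096, T=4096, mu=16, tu=32, bits=4,
#                      l1=2**18, p=8, gs=128)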


# Cache of solved (mb, tb) tile sizes, keyed by (infeatures, outfeatures).
params = {}


def compute_reductions(x, gs=-1, cpp=True):
    """Sum the activations along the input dimension, either globally
    (gs == -1) or per group of `gs` columns, using the C++ kernel when
    `cpp` is True and a pure-PyTorch fallback otherwise."""
    if cpp:
        if len(x.shape) != 1:
            rows, cols = x.shape
        else:
            rows = 1
            cols = x.shape[0]
        if gs == -1:
            out = torch.zeros(rows).float().contiguous()
            mygs = cols
        else:
            out = torch.zeros(rows, cols // gs).float().contiguous()
            mygs = gs

        qinfer.compute_reduction_cpp(x, out, rows, cols, mygs)
        return out
    if gs == -1:
        if len(x.shape) != 1:
            return torch.sum(x, 1)
        else:
            return torch.sum(x)
    else:
        if len(x.shape) != 1:
            rows, cols = x.shape
            out = torch.zeros(rows, cols // gs).float().contiguous()
            for i in range(cols // gs):
                out[:, i] = torch.sum(x[:, i * gs:(i + 1) * gs], 1)
            return out
        else:
            cols = x.shape[0]
            out = torch.zeros(cols // gs).float().contiguous()
            for i in range(cols // gs):
                out[i] = torch.sum(x[i * gs:(i + 1) * gs])
            return out
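

# Sanity sketch for the pure-PyTorch path: with gs=2, each output entry is
# the sum of one group of two input columns.
#
#   x = torch.arange(8, dtype=torch.float32).reshape(2, 4)
#   compute_reductions(x, gs=2, cpp=False)
#   # tensor([[ 1.,  5.],
#   #         [ 9., 13.]])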


def process_zeros_scales(zeros, scales, bits, M):
    """Bring the zero-points and scales into the float32, (groups, M)-shaped
    contiguous layout the kernels expect, unpacking quantized zeros first."""
    if zeros.dtype != torch.float32:
        new_zeros = torch.zeros_like(scales).float().contiguous()
        if bits == 4:
            qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
        elif bits == 2:
            qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
        elif bits == 3:
            # 3-bit zeros are left unset here; only the info message is emitted.
            logger.info("Unpacking zeros for 3 bits")
        new_scales = scales.contiguous()
    else:
        if scales.shape[1] != M:
            new_scales = scales.transpose(0, 1).contiguous()
        else:
            new_scales = scales.contiguous()
        if zeros.shape[1] != M:
            new_zeros = zeros.transpose(0, 1).contiguous()
        else:
            new_zeros = zeros.contiguous()

    return new_zeros, new_scales
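

# A minimal pure-PyTorch sketch of what unpack_zeros4 is assumed to compute
# (8 nibbles per int32; the exact bit order inside cQIGen is an assumption,
# and unpack_zeros4_ref is a hypothetical name used only for illustration):
#
#   def unpack_zeros4_ref(packed, out, rows, cols):
#       for j in range(cols):
#           out[:, j] = ((packed[:, j // 8] >> (4 * (j % 8))) & 0xF).float()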


class QuantLinear(nn.Module):
    QUANT_TYPE = "qigen"

    def __init__(self, bits, group_size, infeatures, outfeatures, bias=None, trainable=False, hint=1, p=8, l1=2**18):
        super().__init__()
        if bits not in [2, 4]:
            raise NotImplementedError("Only 2,4 bits are supported.")
        if trainable:
            raise NotImplementedError("Qigen kernel does not support training.")
        self.bits = bits
        pack = 32 // bits

        self.infeatures = infeatures
        self.outfeatures = outfeatures

        n = hint
        m = self.infeatures
        t = self.outfeatures

        # Register tiles are fixed for now.
        # (The 3-bit branches are kept for completeness even though the
        # constructor currently rejects bits == 3.)
        if bits == 3:
            packed = 32
            unroll = 3
            nu = 1  # args.n
            mu = 32
            tu = 32
        else:
            packed = 32 // bits
            unroll = 2
            nu = 1  # args.n
            mu = 16
            tu = 32

        nb = n  # the batch dimension is always small for transformers

        # Reuse cached tile sizes for this (infeatures, outfeatures) shape,
        # otherwise solve for them.
        global params
        if (m, t) in params:
            mb, tb = params[(m, t)]
        else:
            mb, tb = mem_model(n, m, t, mu, tu, bits, l1, p, group_size)
            params[(m, t)] = (mb, tb)

        # Split the output dimension across p threads in multiples of tb:
        # grow every share until the split covers t, then trim from the end
        # until it matches exactly; see the worked example after this method.
        split = np.ones(p) * tb
        while np.sum(split) < t:
            split = split + tb

        idx = p - 1
        while np.sum(split) > t:
            split[idx] = split[idx] - tb
            idx = idx - 1

        assert np.sum(split) == t

        split = split.astype(int)
        self.tt = int(split[0])

        if split[0] == split[-1]:
            self.cutoff = int(p + 1)
        else:
            self.cutoff = int(idx + 1)

        self.mb = mb  # // packed
        self.tb = tb

        self.group_size = group_size

        # Note: the `bias` argument is currently unused; the buffer is always
        # registered and zero-initialized.
        self.register_buffer('bias', torch.zeros(self.outfeatures))
        self.register_buffer('zeros', torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float32))
        self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.group_size), outfeatures), dtype=torch.float32))
        if bits == 4:
            self.register_buffer('qweight', torch.zeros(int(self.infeatures // packed * self.outfeatures)).int().contiguous())
        elif bits == 3:
            self.register_buffer('qweight', torch.zeros(int(self.infeatures // packed * 3 * self.outfeatures)).int().contiguous())
        elif bits == 2:
            self.register_buffer('qweight', torch.zeros(int(self.infeatures // packed * self.outfeatures)).int().contiguous())
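
    # Worked example of the split above (illustrative numbers): with p=4
    # threads, tb=32 and t=192, the grow loop reaches [64, 64, 64, 64]
    # (sum 256 > 192) and the trim loop reduces it to [64, 64, 32, 32],
    # so self.tt = 64 and self.cutoff = 2: threads with index >= cutoff
    # handle the smaller tail tiles.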

    def forward(self, x):
        out_shape = x.shape[:-1] + (self.outfeatures,)
        x = x.reshape((-1, x.shape[-1])).to(torch.float32)
        B = x.shape[0]
        new_x = x.T.contiguous()
        out = torch.zeros((B, self.outfeatures), dtype=torch.float32)
        # Per-row (or per-group) input sums, used by the kernels to fold the
        # zero-points into the quantized matmul.
        sums = compute_reductions(x, gs=self.group_size, cpp=True).contiguous()
        if self.group_size == -1:
            if self.bits == 4:
                qinfer.forward4(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
                                B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.cutoff)
            elif self.bits == 2:
                qinfer.forward2(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
                                B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.cutoff)
            elif self.bits == 3:
                qinfer.forward3(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
                                B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt, self.cutoff)
        else:
            if self.bits == 4:
                qinfer.forward_gs4(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
                                   B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt,
                                   self.group_size, self.cutoff)
            elif self.bits == 2:
                qinfer.forward_gs2(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
                                   B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt,
                                   self.group_size, self.cutoff)
            elif self.bits == 3:
                qinfer.forward_gs3(new_x, self.qweight, out, self.bias, self.scales, self.zeros, sums,
                                   B, self.infeatures, self.outfeatures, B, self.mb, self.tb, self.tt,
                                   self.group_size, self.cutoff)
        return out.reshape(out_shape)
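

# Usage sketch (hypothetical shapes; `qweight`, `scales`, `zeros` and `bias`
# must be populated by the packing utilities elsewhere in the repo):
#
#   layer = QuantLinear(bits=4, group_size=128, infeatures=4096,
#                       outfeatures=4096)
#   y = layer(torch.randn(2, 4096))  # -> shape (2, 4096), float32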