support static_groups
parent d427489911
commit 6233afce3b
1 changed file with 35 additions and 16 deletions
@@ -60,7 +60,7 @@ class GPTQ:
         self.H += inp.matmul(inp.t())

     def fasterquant(
-        self, blocksize=128, percdamp=.01, group_size=-1, actorder=False
+        self, blocksize=128, percdamp=.01, group_size=-1, actorder=False, static_groups=False
     ):
         W = self.layer.weight.data.clone()
         if isinstance(self.layer, nn.Conv2d):
@@ -80,10 +80,26 @@ class GPTQ:
         H[dead, dead] = 1
         W[:, dead] = 0

+        g_idx = []
+        scale = []
+        zero = []
+        now_idx = 1
+
+        if static_groups:
+            import copy
+            groups = []
+            for i in range(0, self.columns, group_size):
+                quantizer = copy.deepcopy(self.quantizer)
+                quantizer.find_params(W[:, i:(i + group_size)], weight=True)
+                scale.append(quantizer.scale)
+                zero.append(quantizer.zero)
+                groups.append(quantizer)
+
         if actorder:
             perm = torch.argsort(torch.diag(H), descending=True)
             W = W[:, perm]
             H = H[perm][:, perm]
+            invperm = torch.argsort(perm)

         Losses = torch.zeros_like(W)
         Q = torch.zeros_like(W)
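Note: the block added above is the core of static_groups. Before act-order permutes any columns, the quantizer is deep-copied and find_params is run on every group_size-wide slice of W in its original column order, so the group boundaries and their scale/zero are fixed up front. A minimal, self-contained sketch of that pre-pass, assuming a plain asymmetric min/max quantizer rather than this repository's Quantizer class (static_group_params and its bits argument are illustrative names only):

import torch

def static_group_params(W: torch.Tensor, group_size: int, bits: int = 4):
    # Per-group asymmetric (scale, zero) over fixed column groups of the original W.
    maxq = 2 ** bits - 1
    scales, zeros = [], []
    for i in range(0, W.shape[1], group_size):
        block = W[:, i:i + group_size]                    # untouched column order
        wmin = block.min(dim=1, keepdim=True).values.clamp(max=0.0)
        wmax = block.max(dim=1, keepdim=True).values.clamp(min=0.0)
        scale = (wmax - wmin).clamp(min=1e-8) / maxq      # guard against zero-width ranges
        zero = torch.round(-wmin / scale)
        scales.append(scale)
        zeros.append(zero)
    return scales, zeros

W = torch.randn(8, 16)
scales, zeros = static_group_params(W, group_size=4)
print(len(scales), scales[0].shape)  # 4 groups, one (8, 1) scale per group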
@@ -96,11 +112,6 @@ class GPTQ:
         H = torch.linalg.cholesky(H, upper=True)
         Hinv = H

-        g_idx = []
-        scale = []
-        zero = []
-        now_idx = 1
-
         for i1 in range(0, self.columns, blocksize):
             i2 = min(i1 + blocksize, self.columns)
             count = i2 - i1
@@ -116,6 +127,7 @@ class GPTQ:
                 d = Hinv1[i, i]

                 if group_size != -1:
+                    if not static_groups:
                         if (i1 + i) % group_size == 0:
                             self.quantizer.find_params(W[:, (i1 + i):(i1 + i + group_size)], weight=True)

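Note: the added guard turns off the dynamic per-group find_params when static_groups is set. Without it, parameters are recomputed every group_size steps of the (possibly permuted) sweep, so a "group" would be whatever columns happen to be adjacent after act-order reordering. A toy, illustrative comparison of the two groupings (names here are placeholders, not code from the repository):

import torch

columns, group_size = 8, 4
perm = torch.randperm(columns)  # stand-in for the act-order permutation

# Dynamic grouping: columns grouped by their position in the permuted sweep.
dynamic = [sorted(perm[i:i + group_size].tolist()) for i in range(0, columns, group_size)]
# Static grouping: columns keep the groups they had in the original order.
static = [list(range(i, i + group_size)) for i in range(0, columns, group_size)]

print(dynamic)  # usually mixes columns from different original groups
print(static)   # [[0, 1, 2, 3], [4, 5, 6, 7]]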
@@ -123,6 +135,11 @@ class GPTQ:
                             scale.append(self.quantizer.scale)
                             zero.append(self.quantizer.zero)
                             now_idx += 1
+                    else:
+                        idx = i1 + i
+                        if actorder:
+                            idx = perm[idx]
+                        self.quantizer = groups[idx // group_size]

                 q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
                 Q1[:, i] = q
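Note: the new else branch picks the pre-computed group quantizer during the column sweep. Under actorder the sweep visits columns in permuted order, so perm[idx] first maps the processing position back to the original column, and integer division by group_size then selects the static group. A small standalone sketch of that index arithmetic, with placeholder strings standing in for the deep-copied quantizers:

import torch

columns, group_size = 16, 4
groups = [f"group-{g} params" for g in range(columns // group_size)]  # stand-ins for copied quantizers
perm = torch.randperm(columns)  # stand-in for argsort(diag(H), descending=True)
actorder = True

for idx in range(columns):
    col = perm[idx].item() if actorder else idx  # original column index of this step
    params = groups[col // group_size]           # static group chosen by original position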
@@ -148,10 +165,12 @@ class GPTQ:
         logger.info(f'avg loss: {torch.sum(Losses).item() / self.nsamples}')

         group_size = group_size if group_size != -1 else self.columns
+        if static_groups and actorder:
+            g_idx = [perm[i] // group_size for i in range(self.columns)]
+        else:
             g_idx = [i // group_size for i in range(self.columns)]
         g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
         if actorder:
-            invperm = torch.argsort(perm)
             Q = Q[:, invperm]
             g_idx = g_idx[invperm]

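Note: with static_groups and actorder, g_idx is first written in processing order (perm[i] // group_size) and only then permuted back with invperm, which is why invperm is now captured right after perm is built instead of being recomputed here. The sketch below (illustrative sizes) checks that after applying invperm every original column is labelled with its fixed group:

import torch

columns, group_size = 8, 2
perm = torch.randperm(columns)
invperm = torch.argsort(perm)

# g_idx in processing (permuted) order, as in the diff above.
g_idx = torch.tensor([perm[i].item() // group_size for i in range(columns)], dtype=torch.int32)
g_idx = g_idx[invperm]  # back to original column order

# Each original column keeps the group it sat in before permutation.
assert torch.equal(g_idx, torch.tensor([i // group_size for i in range(columns)], dtype=torch.int32))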