support static_groups

This commit is contained in:
qwopqwop200 2023-08-07 16:25:44 +09:00 committed by GitHub
parent d427489911
commit 6233afce3b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -60,7 +60,7 @@ class GPTQ:
self.H += inp.matmul(inp.t())
def fasterquant(
self, blocksize=128, percdamp=.01, group_size=-1, actorder=False
self, blocksize=128, percdamp=.01, group_size=-1, actorder=False, static_groups=False
):
W = self.layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d):
@ -80,10 +80,26 @@ class GPTQ:
H[dead, dead] = 1
W[:, dead] = 0
g_idx = []
scale = []
zero = []
now_idx = 1
if static_groups:
import copy
groups = []
for i in range(0, self.columns, group_size):
quantizer = copy.deepcopy(self.quantizer)
quantizer.find_params(W[:, i:(i + group_size)], weight=True)
scale.append(quantizer.scale)
zero.append(quantizer.zero)
groups.append(quantizer)
if actorder:
perm = torch.argsort(torch.diag(H), descending=True)
W = W[:, perm]
H = H[perm][:, perm]
invperm = torch.argsort(perm)
Losses = torch.zeros_like(W)
Q = torch.zeros_like(W)
@ -96,11 +112,6 @@ class GPTQ:
H = torch.linalg.cholesky(H, upper=True)
Hinv = H
g_idx = []
scale = []
zero = []
now_idx = 1
for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
count = i2 - i1
@ -116,14 +127,20 @@ class GPTQ:
d = Hinv1[i, i]
if group_size != -1:
if (i1 + i) % group_size == 0:
self.quantizer.find_params(W[:, (i1 + i):(i1 + i + group_size)], weight=True)
if ((i1 + i) // group_size) - now_idx == -1:
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)
now_idx += 1
if not static_groups:
if (i1 + i) % group_size == 0:
self.quantizer.find_params(W[:, (i1 + i):(i1 + i + group_size)], weight=True)
if ((i1 + i) // group_size) - now_idx == -1:
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)
now_idx += 1
else:
idx = i1 + i
if actorder:
idx = perm[idx]
self.quantizer = groups[idx // group_size]
q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
Q1[:, i] = q
Losses1[:, i] = (w - q) ** 2 / d ** 2
@ -148,10 +165,12 @@ class GPTQ:
logger.info(f'avg loss: {torch.sum(Losses).item() / self.nsamples}')
group_size = group_size if group_size != -1 else self.columns
g_idx = [i // group_size for i in range(self.columns)]
if static_groups and actorder:
g_idx = [perm[i] // group_size for i in range(self.columns)]
else:
g_idx = [i // group_size for i in range(self.columns)]
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
if actorder:
invperm = torch.argsort(perm)
Q = Q[:, invperm]
g_idx = g_idx[invperm]