support static_groups
This commit is contained in:
parent
d427489911
commit
6233afce3b
1 changed files with 35 additions and 16 deletions
|
@ -60,7 +60,7 @@ class GPTQ:
|
||||||
self.H += inp.matmul(inp.t())
|
self.H += inp.matmul(inp.t())
|
||||||
|
|
||||||
def fasterquant(
|
def fasterquant(
|
||||||
self, blocksize=128, percdamp=.01, group_size=-1, actorder=False
|
self, blocksize=128, percdamp=.01, group_size=-1, actorder=False, static_groups=False
|
||||||
):
|
):
|
||||||
W = self.layer.weight.data.clone()
|
W = self.layer.weight.data.clone()
|
||||||
if isinstance(self.layer, nn.Conv2d):
|
if isinstance(self.layer, nn.Conv2d):
|
||||||
|
@ -80,10 +80,26 @@ class GPTQ:
|
||||||
H[dead, dead] = 1
|
H[dead, dead] = 1
|
||||||
W[:, dead] = 0
|
W[:, dead] = 0
|
||||||
|
|
||||||
|
g_idx = []
|
||||||
|
scale = []
|
||||||
|
zero = []
|
||||||
|
now_idx = 1
|
||||||
|
|
||||||
|
if static_groups:
|
||||||
|
import copy
|
||||||
|
groups = []
|
||||||
|
for i in range(0, self.columns, group_size):
|
||||||
|
quantizer = copy.deepcopy(self.quantizer)
|
||||||
|
quantizer.find_params(W[:, i:(i + group_size)], weight=True)
|
||||||
|
scale.append(quantizer.scale)
|
||||||
|
zero.append(quantizer.zero)
|
||||||
|
groups.append(quantizer)
|
||||||
|
|
||||||
if actorder:
|
if actorder:
|
||||||
perm = torch.argsort(torch.diag(H), descending=True)
|
perm = torch.argsort(torch.diag(H), descending=True)
|
||||||
W = W[:, perm]
|
W = W[:, perm]
|
||||||
H = H[perm][:, perm]
|
H = H[perm][:, perm]
|
||||||
|
invperm = torch.argsort(perm)
|
||||||
|
|
||||||
Losses = torch.zeros_like(W)
|
Losses = torch.zeros_like(W)
|
||||||
Q = torch.zeros_like(W)
|
Q = torch.zeros_like(W)
|
||||||
|
@ -96,11 +112,6 @@ class GPTQ:
|
||||||
H = torch.linalg.cholesky(H, upper=True)
|
H = torch.linalg.cholesky(H, upper=True)
|
||||||
Hinv = H
|
Hinv = H
|
||||||
|
|
||||||
g_idx = []
|
|
||||||
scale = []
|
|
||||||
zero = []
|
|
||||||
now_idx = 1
|
|
||||||
|
|
||||||
for i1 in range(0, self.columns, blocksize):
|
for i1 in range(0, self.columns, blocksize):
|
||||||
i2 = min(i1 + blocksize, self.columns)
|
i2 = min(i1 + blocksize, self.columns)
|
||||||
count = i2 - i1
|
count = i2 - i1
|
||||||
|
@ -116,13 +127,19 @@ class GPTQ:
|
||||||
d = Hinv1[i, i]
|
d = Hinv1[i, i]
|
||||||
|
|
||||||
if group_size != -1:
|
if group_size != -1:
|
||||||
if (i1 + i) % group_size == 0:
|
if not static_groups:
|
||||||
self.quantizer.find_params(W[:, (i1 + i):(i1 + i + group_size)], weight=True)
|
if (i1 + i) % group_size == 0:
|
||||||
|
self.quantizer.find_params(W[:, (i1 + i):(i1 + i + group_size)], weight=True)
|
||||||
|
|
||||||
if ((i1 + i) // group_size) - now_idx == -1:
|
if ((i1 + i) // group_size) - now_idx == -1:
|
||||||
scale.append(self.quantizer.scale)
|
scale.append(self.quantizer.scale)
|
||||||
zero.append(self.quantizer.zero)
|
zero.append(self.quantizer.zero)
|
||||||
now_idx += 1
|
now_idx += 1
|
||||||
|
else:
|
||||||
|
idx = i1 + i
|
||||||
|
if actorder:
|
||||||
|
idx = perm[idx]
|
||||||
|
self.quantizer = groups[idx // group_size]
|
||||||
|
|
||||||
q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
|
q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
|
||||||
Q1[:, i] = q
|
Q1[:, i] = q
|
||||||
|
@ -148,10 +165,12 @@ class GPTQ:
|
||||||
logger.info(f'avg loss: {torch.sum(Losses).item() / self.nsamples}')
|
logger.info(f'avg loss: {torch.sum(Losses).item() / self.nsamples}')
|
||||||
|
|
||||||
group_size = group_size if group_size != -1 else self.columns
|
group_size = group_size if group_size != -1 else self.columns
|
||||||
g_idx = [i // group_size for i in range(self.columns)]
|
if static_groups and actorder:
|
||||||
|
g_idx = [perm[i] // group_size for i in range(self.columns)]
|
||||||
|
else:
|
||||||
|
g_idx = [i // group_size for i in range(self.columns)]
|
||||||
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
|
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
|
||||||
if actorder:
|
if actorder:
|
||||||
invperm = torch.argsort(perm)
|
|
||||||
Q = Q[:, invperm]
|
Q = Q[:, invperm]
|
||||||
g_idx = g_idx[invperm]
|
g_idx = g_idx[invperm]
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue