import math
import os
import time
from logging import getLogger

import torch
import torch.nn as nn
import transformers

from .quant import *

logger = getLogger(__name__)

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


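# GPTQ quantizer for a single Linear/Conv layer: add_batch() accumulates a Hessian
# estimate of the layer inputs from calibration data, and fasterquant() quantizes the
# weights column by column with error compensation against that Hessian.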
class GPTQ:
    def __init__(self, layer):
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0

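    # Accumulate the scaled input Gram matrix: after each call, H equals
    # 2 / nsamples * X X^T over all calibration inputs seen so far.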
    def add_batch(self, inp, out):
        if os.environ.get("DEBUG"):
            self.inp1 = inp
            self.out1 = out
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        if isinstance(self.layer, nn.Conv2d):
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride
            )
            inp = unfold(inp)
            inp = inp.permute([1, 0, 2])
            inp = inp.flatten(1)
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        # inp = inp.float()
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        # self.H += 2 / self.nsamples * inp.matmul(inp.t())
        self.H += inp.matmul(inp.t())

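    # Quantize the layer weights in place. Returns the per-group scales, zero points,
    # and the per-column group index.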
    def fasterquant(
        self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False
    ):
        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        tick = time.time()

        if not self.quantizer.ready():
            self.quantizer.find_params(W, weight=True)

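        # Columns whose Hessian diagonal is zero received no activation; zero the
        # corresponding weights and set the diagonal to 1 so the inverse stays finite.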
        H = self.H
        del self.H
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

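        # Activation-order ("act-order"): quantize columns with the largest Hessian
        # diagonal first, which empirically tends to reduce the overall error.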
        if actorder:
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]

        Losses = torch.zeros_like(W)
        Q = torch.zeros_like(W)

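        # Dampen the Hessian diagonal for numerical stability, then compute the upper
        # Cholesky factor of H^-1; its rows drive the error-compensation updates below.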
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        g_idx = []
        scale = []
        zero = []
        now_idx = 1

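        # Quantize the columns in blocks of `blocksize`: errors are propagated eagerly
        # within a block and lazily to the remaining columns once the block is done.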
        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if groupsize != -1:
                    if (i1 + i) % groupsize == 0:
                        self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)

                    if ((i1 + i) // groupsize) - now_idx == -1:
                        scale.append(self.quantizer.scale)
                        zero.append(self.quantizer.zero)
                        now_idx += 1

                q = quantize(
                    w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq
                ).flatten()
                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d ** 2

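                # Spread this column's quantization error over the not-yet-quantized
                # columns of the block, weighted by the corresponding row of Hinv1.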
                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1
            Losses[:, i1:i2] = Losses1 / 2

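            # Propagate the block's accumulated error to all columns that come after it.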
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

            if os.environ.get("DEBUG"):
                self.layer.weight.data[:, :i2] = Q[:, :i2]
                self.layer.weight.data[:, i2:] = W[:, i2:]
                logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
                logger.debug(torch.sum(Losses))

        torch.cuda.synchronize()
        logger.info(f'duration: {(time.time() - tick)}')
        logger.info(f'avg loss: {torch.sum(Losses).item() / self.nsamples}')

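        # Build the per-column group index and, if act-order was used, undo the column
        # permutation before writing the quantized weights back into the layer.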
        groupsize = groupsize if groupsize != -1 else self.columns
        g_idx = [i // groupsize for i in range(self.columns)]
        g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
        if actorder:
            invperm = torch.argsort(perm)
            Q = Q[:, invperm]
            g_idx = g_idx[invperm]

        if isinstance(self.layer, transformers.Conv1D):
            Q = Q.t()
        self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)
        if os.environ.get("DEBUG"):
            logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))

        if scale == []:
            scale.append(self.quantizer.scale)
            zero.append(self.quantizer.zero)
        scale = torch.cat(scale, dim=1)
        zero = torch.cat(zero, dim=1)
        return scale, zero, g_idx

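    # Drop the Hessian and cached debug activations so their GPU memory can be reclaimed.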
    def free(self):
        if os.environ.get("DEBUG"):
            self.inp1 = None
            self.out1 = None
        self.H = None
        self.Losses = None
        self.Trace = None
        torch.cuda.empty_cache()


__all__ = ["GPTQ"]