diff --git a/auto_gptq/nn_modules/qlinear.py b/auto_gptq/nn_modules/qlinear.py
index 593b918..c2705f7 100644
--- a/auto_gptq/nn_modules/qlinear.py
+++ b/auto_gptq/nn_modules/qlinear.py
@@ -11,7 +11,7 @@
 try:
     import quant_cuda
 
     _quant_cuda_available = True
-except:
+except ImportError:
     logger.warning('CUDA extension not installed.')
     _quant_cuda_available = False
@@ -92,9 +92,12 @@ class QuantLinear(nn.Module):
         intweight = []
         for idx in range(self.infeatures):
-            intweight.append(torch.round(
-                (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(
-                torch.int)[:, None])
+            intweight.append(
+                torch.round(
+                    (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
+                    / self.scales[self.g_idx[idx]]
+                ).to(torch.int)[:, None]
+            )
         intweight = torch.cat(intweight, dim=1)
         intweight = intweight.t().contiguous()
         intweight = intweight.numpy().astype(np.uint32)
 
diff --git a/auto_gptq/nn_modules/qlinear_triton.py b/auto_gptq/nn_modules/qlinear_triton.py
index a33690b..6a9178e 100644
--- a/auto_gptq/nn_modules/qlinear_triton.py
+++ b/auto_gptq/nn_modules/qlinear_triton.py
@@ -272,6 +272,7 @@ try:
         tl.store(c_ptrs, accumulator, mask=c_mask)
 except ImportError:
     logger.warning('triton not installed.')
+    raise
 
 
 def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
@@ -336,23 +337,40 @@ class QuantLinearFunction(torch.autograd.Function):
 class QuantLinear(nn.Module):
-    def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
+    def __init__(
+        self,
+        bits,
+        groupsize,
+        infeatures,
+        outfeatures,
+        bias
+    ):
         super().__init__()
         if bits not in [2, 4, 8]:
             raise NotImplementedError("Only 2,4,8 bits are supported.")
+
         self.infeatures = infeatures
         self.outfeatures = outfeatures
         self.bits = bits
-        self.maxq = 2 ** self.bits - 1
         self.groupsize = groupsize if groupsize != -1 else infeatures
+        self.maxq = 2 ** self.bits - 1
 
-        self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
-        self.register_buffer('qzeros',
-                             torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits),
-                                         dtype=torch.int32))
-        self.register_buffer('scales',
-                             torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
-        self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
+        self.register_buffer(
+            'qweight',
+            torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
+        )
+        self.register_buffer(
+            'qzeros',
+            torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)
+        )
+        self.register_buffer(
+            'scales',
+            torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)
+        )
+        self.register_buffer(
+            'g_idx',
+            torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)
+        )
         if bias:
             self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
         else:
             self.bias = None
 
@@ -370,9 +388,12 @@ class QuantLinear(nn.Module):
         intweight = []
         for idx in range(self.infeatures):
-            intweight.append(torch.round(
-                (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(
-                torch.int)[:, None])
+            intweight.append(
+                torch.round(
+                    (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]])
+                    / self.scales[self.g_idx[idx]]
+                ).to(torch.int)[:, None]
+            )
         intweight = torch.cat(intweight, dim=1)
         intweight = intweight.t().contiguous()
         intweight = intweight.numpy().astype(np.uint32)
 
@@ -422,8 +443,9 @@ class QuantLinear(nn.Module):
             self.bits,
             self.maxq
         )
+        out = out.reshape(out_shape)
         out = out + self.bias if self.bias is not None else out
-        return out.reshape(out_shape)
+        return out
 
 
 def autotune_warmup_linear(model, transpose=False, seqlen=2048):
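
Note: the reformatted pack loop in both files performs the same per-column quantization step. Below is a minimal standalone sketch of that step for reference; the toy shapes and the names weight, zeros, scales, and groupsize are illustrative assumptions, not part of the patch:

    import torch

    # Toy dimensions, for illustration only.
    infeatures, outfeatures, groupsize = 8, 4, 4

    weight = torch.randn(outfeatures, infeatures)  # float weight, [out, in]
    g_idx = torch.tensor([i // groupsize for i in range(infeatures)], dtype=torch.int32)
    scales = torch.rand(infeatures // groupsize, outfeatures) + 0.5  # per-group scales
    zeros = torch.randint(0, 16, (infeatures // groupsize, outfeatures))  # integer zero-points
    scale_zeros = zeros * scales  # same shift used by pack()

    # Quantize one input column at a time, mirroring the loop in the diff.
    intweight = []
    for idx in range(infeatures):
        intweight.append(
            torch.round(
                (weight[:, idx] + scale_zeros[g_idx[idx]])
                / scales[g_idx[idx]]
            ).to(torch.int)[:, None]
        )
    intweight = torch.cat(intweight, dim=1)  # [out, in]
    intweight = intweight.t().contiguous()   # [in, out], ready for bit-packing

Adding zeros * scales before dividing folds the zero-point into the rounding, so dequantization can later recover weight ≈ (intweight - zeros) * scales per group.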