optimize import and format code

PanQiWei 2023-04-26 13:08:47 +08:00
parent c6dee93f5d
commit 73cb1dbf09
2 changed files with 42 additions and 17 deletions

File 1 of 2: CUDA-backed QuantLinear module (uses the quant_cuda extension)

@@ -11,7 +11,7 @@ try:
     import quant_cuda
    _quant_cuda_available = True
-except:
+except ImportError:
    logger.warning('CUDA extension not installed.')
    _quant_cuda_available = False
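
Note: a bare except: here would also swallow unrelated failures raised during the import (a broken CUDA toolkit can surface as an OSError, for instance) and mislabel them as "not installed". Catching ImportError alone keeps the optional-dependency pattern honest. A minimal sketch of the pattern, assuming the module-level logger this file already defines:

    import logging

    logger = logging.getLogger(__name__)

    try:
        import quant_cuda  # compiled extension; present only after building it
        _quant_cuda_available = True
    except ImportError:
        # only a missing module is tolerated; other errors still propagate
        logger.warning('CUDA extension not installed.')
        _quant_cuda_available = False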
@@ -92,9 +92,12 @@ class QuantLinear(nn.Module):
         intweight = []
         for idx in range(self.infeatures):
-            intweight.append(torch.round(
-                (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(
-                torch.int)[:, None])
+            intweight.append(
+                torch.round(
+                    (
+                        linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]
+                ).to(torch.int)[:, None]
+            )
         intweight = torch.cat(intweight, dim=1)
         intweight = intweight.t().contiguous()
         intweight = intweight.numpy().astype(np.uint32)
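
For context, the expression being rewrapped is the affine quantization step q = round((w + zero * scale) / scale), with per-group parameters selected through g_idx; the diff's scale_zeros already holds zero * scale. A self-contained toy round trip with illustrative values:

    import torch

    # toy affine quantization of one weight column, mirroring
    # round((w + scale_zero) / scale) from the hunk above
    w = torch.tensor([0.31, -0.12, 0.07])
    scale = torch.tensor(0.1)
    zero = 8                          # integer zero-point for this group
    scale_zero = scale * zero         # the scale_zeros[g] convention above

    q = torch.round((w + scale_zero) / scale).to(torch.int)
    w_hat = q * scale - scale_zero    # dequantize to check the round trip
    print(q.tolist())                 # [11, 7, 9]
    print(w_hat.tolist())             # ~[0.30, -0.10, 0.10], within one step of scale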

File 2 of 2: Triton-backed QuantLinear module (kernels, matmul248, autotune warmup)

@@ -272,6 +272,7 @@ try:
         tl.store(c_ptrs, accumulator, mask=c_mask)
 except ImportError:
     logger.warning('triton not installed.')
+    raise


 def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
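
The added raise turns a silent degradation into a loud failure: the warning is still logged, but the ImportError now propagates to the caller, which is reasonable here since every definition below the try block depends on triton. The log-and-reraise pattern, sketched with the standard triton imports (assumed from context):

    import logging

    logger = logging.getLogger(__name__)

    try:
        import triton
        import triton.language as tl
    except ImportError:
        # log for context, then let the ImportError propagate, since
        # everything below this point needs triton anyway
        logger.warning('triton not installed.')
        raise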
@@ -336,23 +337,40 @@ class QuantLinearFunction(torch.autograd.Function):
 class QuantLinear(nn.Module):
-    def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
+    def __init__(
+        self,
+        bits,
+        groupsize,
+        infeatures,
+        outfeatures,
+        bias
+    ):
         super().__init__()
         if bits not in [2, 4, 8]:
             raise NotImplementedError("Only 2,4,8 bits are supported.")
         self.infeatures = infeatures
         self.outfeatures = outfeatures
         self.bits = bits
-        self.maxq = 2 ** self.bits - 1
         self.groupsize = groupsize if groupsize != -1 else infeatures
-        self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
-        self.register_buffer('qzeros',
-                             torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits),
-                                         dtype=torch.int32))
-        self.register_buffer('scales',
-                             torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
-        self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
+        self.maxq = 2 ** self.bits - 1
+        self.register_buffer(
+            'qweight',
+            torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
+        )
+        self.register_buffer(
+            'qzeros',
+            torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)
+        )
+        self.register_buffer(
+            'scales',
+            torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)
+        )
+        self.register_buffer(
+            'g_idx',
+            torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)
+        )
         if bias:
             self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
         else:
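
Two things worth noting in this hunk: moving maxq below groupsize only reorders derived attributes (maxq = 2 ** bits - 1 is the largest quantization level), and every buffer shape follows from packing 32 // bits quantized values into each int32. A quick sanity check of the shapes with illustrative sizes:

    import math

    bits, groupsize = 4, 128
    infeatures, outfeatures = 4096, 4096

    # each int32 packs 32 // bits quantized values
    qweight_rows = infeatures // 32 * bits                 # 512
    n_groups = math.ceil(infeatures / groupsize)           # 32
    qzeros_cols = outfeatures // 32 * bits                 # 512
    g_idx = [i // groupsize for i in range(infeatures)]    # input feature -> group

    print((qweight_rows, outfeatures))   # qweight: (512, 4096)
    print((n_groups, qzeros_cols))       # qzeros:  (32, 512)
    print((n_groups, outfeatures))       # scales:  (32, 4096)
    print(g_idx[0], g_idx[-1])           # 0 31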
@@ -370,9 +388,12 @@ class QuantLinear(nn.Module):
         intweight = []
         for idx in range(self.infeatures):
-            intweight.append(torch.round(
-                (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(
-                torch.int)[:, None])
+            intweight.append(
+                torch.round(
+                    (
+                        linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]
+                ).to(torch.int)[:, None]
+            )
         intweight = torch.cat(intweight, dim=1)
         intweight = intweight.t().contiguous()
         intweight = intweight.numpy().astype(np.uint32)
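
This is the same per-column loop as in the first file. Purely as an observation, not part of this commit: the rounding can be done in one vectorized call over the transposed weight, since indexing with g_idx broadcasts the per-group parameters to the full matrix. A toy sketch with hypothetical sizes:

    import numpy as np
    import torch

    # toy sizes: 8 input features, 4 output features, group size 4
    infeatures, outfeatures, groupsize = 8, 4, 4
    weight = torch.randn(outfeatures, infeatures)          # nn.Linear layout
    g_idx = torch.tensor([i // groupsize for i in range(infeatures)])
    scales = torch.full((infeatures // groupsize, outfeatures), 0.05)
    scale_zeros = scales * 8                               # zero-point 8, pre-scaled

    # one round() over the transposed weight replaces the per-column loop:
    # g_idx indexing expands the (n_groups, out) parameters to (in, out)
    intweight = torch.round((weight.t() + scale_zeros[g_idx]) / scales[g_idx]).to(torch.int)
    intweight = intweight.contiguous().numpy().astype(np.uint32)
    print(intweight.shape)   # (8, 4) == (infeatures, outfeatures), as after the loop + t()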
@@ -422,8 +443,9 @@ class QuantLinear(nn.Module):
             self.bits,
             self.maxq
         )
+        out = out.reshape(out_shape)
         out = out + self.bias if self.bias is not None else out
-        return out.reshape(out_shape)
+        return out


 def autotune_warmup_linear(model, transpose=False, seqlen=2048):
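
The reorder in the last hunk is behavior-preserving: the (outfeatures,) bias broadcasts over the last dimension, which the reshape leaves untouched, so adding it before or after the reshape yields identical tensors; the change just makes forward read as reshape, bias, return. A quick check with toy shapes:

    import torch

    # bias addition commutes with the reshape: broadcasting touches only
    # the last dimension, and the reshape leaves that dimension unchanged
    out = torch.randn(6, 4)                # (batch*seq, outfeatures) from the matmul
    bias = torch.randn(4)
    out_shape = (2, 3, 4)                  # (batch, seq, outfeatures)

    a = (out + bias).reshape(out_shape)    # old order: add bias, then reshape
    b = out.reshape(out_shape) + bias      # new order: reshape, then add bias
    print(torch.equal(a, b))               # True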