optimize import and format code
parent c6dee93f5d
commit 73cb1dbf09

2 changed files with 42 additions and 17 deletions
@@ -11,7 +11,7 @@ try:
     import quant_cuda

     _quant_cuda_available = True
-except:
+except ImportError:
    logger.warning('CUDA extension not installed.')
    _quant_cuda_available = False

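The substantive change in this hunk is narrowing the bare `except:` to `except ImportError:`, so only a missing extension is tolerated while other errors raised inside `quant_cuda` still propagate. A minimal sketch of the resulting guarded-import pattern (the logger setup is an assumption, not shown in the diff):

    import logging

    logger = logging.getLogger(__name__)

    try:
        import quant_cuda  # optional compiled CUDA extension

        _quant_cuda_available = True
    except ImportError:
        # Only an absent module is tolerated; real failures still surface.
        logger.warning('CUDA extension not installed.')
        _quant_cuda_available = False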
@@ -92,9 +92,12 @@ class QuantLinear(nn.Module):

         intweight = []
         for idx in range(self.infeatures):
-            intweight.append(torch.round(
-                (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(
-                torch.int)[:, None])
+            intweight.append(
+                torch.round(
+                    (
+                            linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]
+                ).to(torch.int)[:, None]
+            )
         intweight = torch.cat(intweight, dim=1)
         intweight = intweight.t().contiguous()
         intweight = intweight.numpy().astype(np.uint32)
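The reformatted loop quantizes one input column at a time: each weight column is shifted by its group's zero point and divided by its group's scale before rounding. A hedged, vectorized sketch of the same arithmetic (the `(n_groups, out)` shapes for `scales` and `scale_zeros` are assumptions inferred from the indexing):

    import torch

    def quantize_columns(weight, scales, scale_zeros, g_idx):
        # weight: (out, in); scales, scale_zeros: (n_groups, out); g_idx: (in,)
        s = scales[g_idx].t()        # per-column scale, (out, in)
        z = scale_zeros[g_idx].t()   # per-column zero offset, (out, in)
        intweight = torch.round((weight + z) / s).to(torch.int)
        return intweight.t().contiguous()  # (in, out), as after the loop's cat/t()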
@@ -272,6 +272,7 @@ try:
             tl.store(c_ptrs, accumulator, mask=c_mask)
 except ImportError:
     logger.warning('triton not installed.')
+    raise


 def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
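The added bare `raise` re-throws the caught `ImportError` after logging, so the module fails loudly instead of continuing with the Triton kernels undefined. The pattern in isolation:

    import logging

    logger = logging.getLogger(__name__)

    try:
        import triton  # required by the kernels above
    except ImportError:
        logger.warning('triton not installed.')
        raise  # re-raise so callers see the original ImportError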
@@ -336,23 +337,40 @@ class QuantLinearFunction(torch.autograd.Function):

 class QuantLinear(nn.Module):

-    def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
+    def __init__(
+            self,
+            bits,
+            groupsize,
+            infeatures,
+            outfeatures,
+            bias
+    ):
         super().__init__()
         if bits not in [2, 4, 8]:
             raise NotImplementedError("Only 2,4,8 bits are supported.")
+
         self.infeatures = infeatures
         self.outfeatures = outfeatures
         self.bits = bits
-        self.maxq = 2 ** self.bits - 1
         self.groupsize = groupsize if groupsize != -1 else infeatures
+        self.maxq = 2 ** self.bits - 1

-        self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
-        self.register_buffer('qzeros',
-                             torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits),
-                                         dtype=torch.int32))
-        self.register_buffer('scales',
-                             torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
-        self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
+        self.register_buffer(
+            'qweight',
+            torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
+        )
+        self.register_buffer(
+            'qzeros',
+            torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)
+        )
+        self.register_buffer(
+            'scales',
+            torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)
+        )
+        self.register_buffer(
+            'g_idx',
+            torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)
+        )
         if bias:
             self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
         else:
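Beyond the reflowed `register_buffer` calls, this hunk moves `self.maxq = 2 ** self.bits - 1` below the `groupsize` assignment and keeps a single copy of it. The buffer shapes encode the packing: each `int32` holds `32 // bits` quantized values, and scales/zeros are stored per group. A sketch of the resulting shapes (the concrete sizes are illustrative, not from the diff):

    import math
    import torch

    bits, groupsize, infeatures, outfeatures = 4, 128, 4096, 4096
    n_groups = math.ceil(infeatures / groupsize)  # 32 groups

    qweight = torch.zeros((infeatures // 32 * bits, outfeatures), dtype=torch.int32)  # (512, 4096)
    qzeros = torch.zeros((n_groups, outfeatures // 32 * bits), dtype=torch.int32)     # (32, 512)
    scales = torch.zeros((n_groups, outfeatures), dtype=torch.float16)                # (32, 4096)
    g_idx = torch.tensor([i // groupsize for i in range(infeatures)], dtype=torch.int32)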
@@ -370,9 +388,12 @@ class QuantLinear(nn.Module):

         intweight = []
         for idx in range(self.infeatures):
-            intweight.append(torch.round(
-                (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(
-                torch.int)[:, None])
+            intweight.append(
+                torch.round(
+                    (
+                            linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]
+                ).to(torch.int)[:, None]
+            )
         intweight = torch.cat(intweight, dim=1)
         intweight = intweight.t().contiguous()
         intweight = intweight.numpy().astype(np.uint32)
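This hunk applies the same `intweight` reformatting in the second changed file. For context, the resulting `uint32` matrix is what later gets bit-packed into `qweight`; a hypothetical helper (`pack_rows` is not part of the diff) sketching that layout, assuming values are packed least-significant bits first:

    import numpy as np

    def pack_rows(intweight, bits):
        # intweight: (infeatures, outfeatures) uint32, each value < 2 ** bits
        rows, cols = intweight.shape
        per_word = 32 // bits
        packed = np.zeros((rows // per_word, cols), dtype=np.uint32)
        for i in range(rows):
            word, slot = divmod(i, per_word)
            packed[word] |= intweight[i] << np.uint32(slot * bits)
        return packed  # matches qweight's (infeatures // 32 * bits, outfeatures)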
@@ -422,8 +443,9 @@ class QuantLinear(nn.Module):
             self.bits,
             self.maxq
         )
+        out = out.reshape(out_shape)
         out = out + self.bias if self.bias is not None else out
-        return out.reshape(out_shape)
+        return out


 def autotune_warmup_linear(model, transpose=False, seqlen=2048):
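In the forward path, the reshape now happens before the bias add and `out` is returned directly; the result is unchanged, since the bias broadcasts over the last dimension either way. A sketch of the full method under that reading (the signature and the `QuantLinearFunction.apply` call are assumptions based on the visible argument list):

    def forward(self, x):
        out_shape = x.shape[:-1] + (self.outfeatures,)
        # Flatten leading dims, run the quantized matmul, then restore the shape.
        out = QuantLinearFunction.apply(
            x.reshape(-1, x.shape[-1]),
            self.qweight, self.scales, self.qzeros, self.g_idx,
            self.bits, self.maxq,
        )
        out = out.reshape(out_shape)
        out = out + self.bias if self.bias is not None else out
        return out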