optimize import and format code
commit 73cb1dbf09 (parent c6dee93f5d)
2 changed files with 42 additions and 17 deletions
@@ -11,7 +11,7 @@ try:
     import quant_cuda

     _quant_cuda_available = True
-except:
+except ImportError:
    logger.warning('CUDA extension not installed.')
    _quant_cuda_available = False

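Note on the hunk above: a bare except: also catches KeyboardInterrupt and SystemExit and hides unrelated errors raised while the extension is imported; narrowing it to except ImportError: keeps only the "module is missing" case. A minimal sketch of the resulting guarded-import pattern (quant_cuda is the project's compiled extension; the logger setup here is illustrative):

import logging

logger = logging.getLogger(__name__)

try:
    import quant_cuda  # compiled CUDA kernels; optional at import time
    _quant_cuda_available = True
except ImportError:
    # Only a genuinely missing module lands here; a bare `except:` would
    # also swallow KeyboardInterrupt and real errors inside the extension.
    logger.warning('CUDA extension not installed.')
    _quant_cuda_available = False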
@@ -92,9 +92,12 @@ class QuantLinear(nn.Module):
         intweight = []
         for idx in range(self.infeatures):
-            intweight.append(torch.round(
-                (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(
-                torch.int)[:, None])
+            intweight.append(
+                torch.round(
+                    (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]
+                ).to(torch.int)[:, None]
+            )
         intweight = torch.cat(intweight, dim=1)
         intweight = intweight.t().contiguous()
         intweight = intweight.numpy().astype(np.uint32)
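The reformatting above changes layout only; the math is unchanged. For reference, the loop performs per-group affine quantization: each weight column is shifted by its group's zero offset, divided by its group's scale, and rounded to an integer level, with g_idx mapping each column to its group. A self-contained sketch of that computation (the sizes and random tensors are illustrative assumptions, not values from the commit):

import torch

out_features, in_features, groupsize = 8, 16, 4
weight = torch.randn(out_features, in_features)
g_idx = torch.tensor([i // groupsize for i in range(in_features)])   # column -> group id
scales = torch.rand(in_features // groupsize, out_features) + 0.1    # per-group quantization step
scale_zeros = torch.rand(in_features // groupsize, out_features)     # zero-point times scale

cols = []
for idx in range(in_features):
    g = g_idx[idx]
    q = torch.round((weight[:, idx] + scale_zeros[g]) / scales[g]).to(torch.int)
    cols.append(q[:, None])
intweight = torch.cat(cols, dim=1).t().contiguous()  # (in_features, out_features), as in the diff
print(intweight.shape, intweight.dtype)  # torch.Size([16, 8]) torch.int32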
@@ -272,6 +272,7 @@ try:
         tl.store(c_ptrs, accumulator, mask=c_mask)
 except ImportError:
     logger.warning('triton not installed.')
+    raise


 def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
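The added raise re-raises the ImportError after logging it, so this module now fails at import time when triton is absent instead of exposing a matmul248 that would crash later. A hedged sketch of the warn-then-reraise pattern (the logger setup is illustrative):

import logging

logger = logging.getLogger(__name__)

try:
    import triton  # dependency of everything defined below the try block
except ImportError:
    logger.warning('triton not installed.')
    raise  # propagate, so callers never see a half-initialized module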
@@ -336,23 +337,40 @@ class QuantLinearFunction(torch.autograd.Function):

 class QuantLinear(nn.Module):

-    def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
+    def __init__(
+            self,
+            bits,
+            groupsize,
+            infeatures,
+            outfeatures,
+            bias
+    ):
         super().__init__()
         if bits not in [2, 4, 8]:
             raise NotImplementedError("Only 2,4,8 bits are supported.")

         self.infeatures = infeatures
         self.outfeatures = outfeatures
         self.bits = bits
-        self.maxq = 2 ** self.bits - 1
         self.groupsize = groupsize if groupsize != -1 else infeatures
+        self.maxq = 2 ** self.bits - 1

-        self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
-        self.register_buffer('qzeros',
-                             torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits),
-                                         dtype=torch.int32))
-        self.register_buffer('scales',
-                             torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
-        self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
+        self.register_buffer(
+            'qweight',
+            torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)
+        )
+        self.register_buffer(
+            'qzeros',
+            torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)
+        )
+        self.register_buffer(
+            'scales',
+            torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)
+        )
+        self.register_buffer(
+            'g_idx',
+            torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)
+        )
         if bias:
             self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
         else:
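This hunk reorders self.maxq after self.groupsize and reflows the register_buffer calls; buffer shapes are unchanged. The packing arithmetic those shapes imply: one int32 stores 32 // bits quantized values, so qweight needs infeatures // 32 * self.bits rows, qzeros and scales are stored per group, and maxq = 2 ** bits - 1 is the largest quantization level. A quick shape check under that reading (buffer_shapes and the example sizes are illustrative, not from the commit):

import math

def buffer_shapes(bits, groupsize, infeatures, outfeatures):
    groupsize = groupsize if groupsize != -1 else infeatures
    n_groups = math.ceil(infeatures / groupsize)
    return {
        'qweight': (infeatures // 32 * bits, outfeatures),  # 32 // bits values per int32
        'qzeros': (n_groups, outfeatures // 32 * bits),
        'scales': (n_groups, outfeatures),
        'g_idx': (infeatures,),
        'maxq': 2 ** bits - 1,  # e.g. 15 for 4-bit
    }

print(buffer_shapes(bits=4, groupsize=128, infeatures=4096, outfeatures=4096))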
@@ -370,9 +388,12 @@ class QuantLinear(nn.Module):
         intweight = []
         for idx in range(self.infeatures):
-            intweight.append(torch.round(
-                (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(
-                torch.int)[:, None])
+            intweight.append(
+                torch.round(
+                    (linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]
+                ).to(torch.int)[:, None]
+            )
         intweight = torch.cat(intweight, dim=1)
         intweight = intweight.t().contiguous()
         intweight = intweight.numpy().astype(np.uint32)
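This hunk repeats the intweight reformatting for the triton QuantLinear's pack method. As an aside, the same per-column quantization can be expressed without the Python loop; the following is a hedged vectorized equivalent, not code from this commit:

import torch

def quantize_columns(weight, scales, scale_zeros, g_idx):
    # weight: (out, in); scales / scale_zeros: (groups, out); g_idx: (in,)
    s = scales[g_idx].t()        # (out, in): each column gets its group's scale
    z = scale_zeros[g_idx].t()   # (out, in): each column gets its group's zero offset
    return torch.round((weight + z) / s).to(torch.int).t().contiguous()  # (in, out)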
@@ -422,8 +443,9 @@ class QuantLinear(nn.Module):
             self.bits,
             self.maxq
         )
+        out = out.reshape(out_shape)
-        out = out + self.bias if self.bias is not None else out
-        return out.reshape(out_shape)
+        out = out + self.bias if self.bias is not None else out
+        return out


 def autotune_warmup_linear(model, transpose=False, seqlen=2048):
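The last hunk reshapes the matmul output before adding the bias and returns out directly. Both orders are numerically equivalent here: the bias has shape (outfeatures,), which broadcasts over the leading dimensions either way since the last dimension stays outfeatures; the new order simply reads top to bottom. A minimal sketch of the reordered tail (forward_tail and the toy shapes are illustrative):

import torch

def forward_tail(out_flat, out_shape, bias=None):
    out = out_flat.reshape(out_shape)              # (..., outfeatures)
    out = out + bias if bias is not None else out  # bias broadcasts over leading dims
    return out

y = forward_tail(torch.randn(6, 32), (2, 3, 32), bias=torch.zeros(32))
print(y.shape)  # torch.Size([2, 3, 32])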