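Fix bug: remove the ±1 zero-point offset from qzeros packing/unpacking and change the default `sym` to False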
This commit is contained in:
parent
f752336cda
commit
ad5b0d72ee
6 changed files with 1 addition and 11 deletions
|
@@ -39,7 +39,7 @@ class BaseQuantizeConfig(PushToHubMixin):
|
|||
damp_percent: float = field(default=0.01)
|
||||
desc_act: bool = field(default=True)
|
||||
static_groups: bool = field(default=False)
|
||||
- sym: bool = field(default=True)
|
||||
+ sym: bool = field(default=False)
|
||||
true_sequential: bool = field(default=True)
|
||||
model_name_or_path: Optional[str] = field(default=None)
|
||||
model_file_base_name: Optional[str] = field(default=None)
|
||||
|
|
|
@@ -157,7 +157,6 @@ class QuantLinear(nn.Module):
|
|||
qweight = qweight.astype(np.int32)
|
||||
self.qweight = torch.from_numpy(qweight)
|
||||
|
||||
- zeros -= 1
|
||||
zeros = zeros.numpy().astype(np.uint32)
|
||||
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
|
||||
i = 0
|
||||
|
@@ -221,7 +220,6 @@ class QuantLinear(nn.Module):
|
|||
).to(torch.int16 if self.bits == 8 else torch.int8)
|
||||
torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)
|
||||
|
||||
- zeros = zeros + 1
|
||||
zeros = zeros.reshape(self.scales.shape)
|
||||
|
||||
weight = torch.bitwise_right_shift(
|
||||
|
@@ -239,7 +237,6 @@ class QuantLinear(nn.Module):
|
|||
zeros = zeros & 0x7
|
||||
zeros = torch.cat([zeros[:, :, 0, :11], zeros[:, :, 1, 1:12], zeros[:, :, 2, 1:11]], dim=2)
|
||||
|
||||
- zeros = zeros + 1
|
||||
zeros = zeros.reshape(self.scales.shape)
|
||||
|
||||
weight = self.qweight.reshape(
|
||||
|
|
|
@@ -157,7 +157,6 @@ class QuantLinear(nn.Module):
|
|||
qweight = qweight.astype(np.int32)
|
||||
self.qweight = torch.from_numpy(qweight)
|
||||
|
||||
- zeros -= 1
|
||||
zeros = zeros.numpy().astype(np.uint32)
|
||||
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
|
||||
i = 0
|
||||
|
@@ -231,7 +230,6 @@ class QuantLinear(nn.Module):
|
|||
zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), self.wf.unsqueeze(0)).to(torch.int16 if self.bits == 8 else torch.int8)
|
||||
torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)
|
||||
|
||||
- zeros = zeros + 1
|
||||
zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
|
||||
|
||||
scales = self.scales
|
||||
|
@@ -248,7 +246,6 @@ class QuantLinear(nn.Module):
|
|||
zeros = zeros & 0x7
|
||||
zeros = torch.cat([zeros[:,:,0,:11], zeros[:,:,1,1:12], zeros[:,:,2,1:11]], dim=2)
|
||||
|
||||
- zeros = zeros + 1
|
||||
zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
|
||||
|
||||
scales = self.scales
|
||||
|
|
|
@@ -146,7 +146,6 @@ class QuantLinear(nn.Module):
|
|||
qweight = qweight.astype(np.int32)
|
||||
self.qweight = torch.from_numpy(qweight)
|
||||
|
||||
- zeros -= 1
|
||||
zeros = zeros.numpy().astype(np.uint32)
|
||||
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
|
||||
i = 0
|
||||
|
|
|
@@ -114,7 +114,6 @@ class QuantLinear(nn.Module, TritonModuleMixin):
|
|||
qweight = qweight.astype(np.int32)
|
||||
self.qweight = torch.from_numpy(qweight)
|
||||
|
||||
- zeros -= 1
|
||||
zeros = zeros.numpy().astype(np.uint32)
|
||||
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
|
||||
i = 0
|
||||
|
|
|
@@ -144,7 +144,6 @@ def quant_matmul_248_kernel(
|
|||
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
|
||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||
- zeros = (zeros + 1)
|
||||
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
|
||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||
|
@@ -290,7 +289,6 @@ def transpose_quant_matmul_248_kernel(
|
|||
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
|
||||
|
||||
zeros = (zeros >> zeros_shifter[None, :]) & maxq
|
||||
- zeros = (zeros + 1)
|
||||
|
||||
a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
|
||||
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
|
||||
|
|
Loading…
Add table
Reference in a new issue