Modify qlinear_cuda for tracing the GPTQ model (#367)
Changes:
-- The change to torch.bitwise_and is made because, when tracing this model, the current usage of torch.bitwise_and results in an in-place variant of the op, which causes a failure in the downstream lowering pipeline of the traced model via Torch-MLIR and IREE-SHARK. The op usage is therefore changed so that it no longer produces an in-place variant.
-- The change to the torch.matmul call in the forward function is made because it currently assumes the weights will always be of fp16 type, so executing the model with float32 weights results in an error. The change casts the LHS of the matmul to the same type as the RHS.
Neither change affects the model's behavior in any way.
Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
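For illustration only, a minimal standalone sketch of the two patterns described above (not taken from the commit; the tensor names and shapes here are invented for the example):

import torch

bits = 4

# Trace-friendly bitwise_and: the in-place form
#     torch.bitwise_and(zeros, (2 ** bits) - 1, out=zeros)
# traces to an in-place variant of the op, which the Torch-MLIR/IREE-SHARK
# lowering pipeline cannot handle; the out-of-place assignment below computes
# the same values. (`zeros` is a stand-in tensor, not the real qzeros layout.)
zeros = torch.randint(0, 127, (2, 3), dtype=torch.int8)
zeros = torch.bitwise_and(zeros, (2 ** bits) - 1)

# Dtype-agnostic matmul: instead of hard-coding x.half(), cast the LHS to the
# RHS (weight) dtype, so float32 weights work as well as fp16 ones.
x = torch.randn(5, 8)        # activations, float32 in this example
weights = torch.randn(8, 6)  # dequantized weights, float32 in this example
out = torch.matmul(x.to(weights.dtype), weights)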
parent 51c043c6be
commit e4b2493733
2 changed files with 8 additions and 8 deletions
@@ -219,7 +219,7 @@ class QuantLinear(nn.Module):
                 torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
                 self.wf.unsqueeze(0)
             ).to(torch.int16 if self.bits == 8 else torch.int8)
-            torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)
+            zeros = torch.bitwise_and(zeros, (2 ** self.bits) - 1)
 
             zeros = zeros + 1
             zeros = zeros.reshape(self.scales.shape)
@@ -228,7 +228,7 @@ class QuantLinear(nn.Module):
                 torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
                 self.wf.unsqueeze(-1)
             ).to(torch.int16 if self.bits == 8 else torch.int8)
-            torch.bitwise_and(weight, (2 ** self.bits) - 1, out=weight)
+            weight = torch.bitwise_and(weight, (2 ** self.bits) - 1)
         elif self.bits == 3:
             zeros = self.qzeros.reshape(
                 self.qzeros.shape[0], self.qzeros.shape[1] // 3, 3, 1
@@ -267,10 +267,10 @@ class QuantLinear(nn.Module):
                 g_idx_i = self.g_idx[i*num_dim:(i+1)*num_dim]
                 weights.append(scale_i[g_idx_i.long()] * (weight_i - zeros_i[g_idx_i.long()]))
             weights = torch.cat(weights,dim=1)
-            out = torch.matmul(x.half(), weights)
+            out = torch.matmul(x.to(weights.dtype), weights)
         out = out.half().reshape(out_shape)
         out = out + self.bias if self.bias is not None else out
-        return out
+        return out.to(x.dtype)
 
 
 __all__ = ["QuantLinear"]
@@ -229,7 +229,7 @@ class QuantLinear(nn.Module):
 
         if self.bits in [2,4,8]:
             zeros = torch.bitwise_right_shift(torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits), self.wf.unsqueeze(0)).to(torch.int16 if self.bits == 8 else torch.int8)
-            torch.bitwise_and(zeros, (2 ** self.bits) - 1, out=zeros)
+            zeros = torch.bitwise_and(zeros, (2 ** self.bits) - 1)
 
             zeros = zeros + 1
             zeros = zeros.reshape(-1, 1, zeros.shape[1] * zeros.shape[2])
@@ -238,7 +238,7 @@ class QuantLinear(nn.Module):
             scales = scales.reshape(-1, 1, scales.shape[-1])
 
             weight = torch.bitwise_right_shift(torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1), self.wf.unsqueeze(-1)).to(torch.int16 if self.bits == 8 else torch.int8)
-            torch.bitwise_and(weight,(2 ** self.bits) - 1, out=weight)
+            weight = torch.bitwise_and(weight,(2 ** self.bits) - 1)
             weight = weight.reshape(-1, self.group_size, weight.shape[2])
         elif self.bits == 3:
             zeros = self.qzeros.reshape(self.qzeros.shape[0], self.qzeros.shape[1]//3, 3, 1).expand(-1, -1, -1, 12)
@@ -266,10 +266,10 @@ class QuantLinear(nn.Module):
             weight = (scales * (weight - zeros))
             weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
 
-        out = torch.matmul(x.half(), weight)
+        out = torch.matmul(x.to(weight.dtype), weight)
         out = out.half().reshape(out_shape)
         out = out + self.bias if self.bias is not None else out
-        return out
+        return out.to(x.dtype)
 
 
 __all__ = ["QuantLinear"]