GPTQ backward compatibility support

parent 9e0682a63e
commit 94de4ef185

1 changed file with 21 additions and 0 deletions
@@ -939,6 +939,27 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             checkpoint
         )
         model.load_state_dict(checkpoint)
+        # Preprocessing for backward compatibility
+        if quantize_config.sym:
+            QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, disable_exllama=disable_exllama, use_qigen=use_qigen,
+                                                         desc_act=quantize_config.desc_act, group_size=quantize_config.group_size, bits=quantize_config.bits)
+            for name, submodule in model.named_modules():
+                if isinstance(submodule, QuantLinear):
+                    if use_qigen:
+                        submodule.zeros.data = torch.full_like(submodule.zeros.data, (torch.tensor(2 ** quantize_config.bits - 1) + 1) / 2)
+                    else:
+                        if quantize_config.bits == 2:
+                            submodule.qzeros.data = torch.full_like(submodule.qzeros.data, -1431655766)
+                        elif quantize_config.bits == 3:
+                            submodule.qzeros.data[:, range(0, submodule.qzeros.data.shape[1], 3)] = 613566756
+                            submodule.qzeros.data[:, range(1, submodule.qzeros.data.shape[1], 3)] = 1227133513
+                            submodule.qzeros.data[:, range(2, submodule.qzeros.data.shape[1], 3)] = -1840700270
+                        elif quantize_config.bits == 4:
+                            submodule.qzeros.data = torch.full_like(submodule.qzeros.data, -2004318072)
+                        elif quantize_config.bits == 8:
+                            submodule.qzeros.data = torch.full_like(submodule.qzeros.data, -2139062144)
+                        else:
+                            raise NotImplementedError("Only 2,3,4,8 bits are supported.")
         # == step4: set seqlen == #
         model_config = model.config.to_dict()
         seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
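Note on the magic constants: each one appears to be the symmetric zero point 2**(bits - 1) replicated into every packed int32 word of qzeros, assuming the usual LSB-first GPTQ packing (32 // bits values per word; 3-bit values straddle word boundaries, which is why three distinct words repeat with a period of three columns). A minimal sketch that re-derives them — the helper names are illustrative, not part of this commit:

def sym_zero(bits: int) -> int:
    # Symmetric zero point: the midpoint of the unsigned range, e.g. 8 for
    # 4-bit; equals the qigen branch's value (2**bits - 1 + 1) / 2.
    return 2 ** (bits - 1)

def packed_qzeros_words(bits: int, n_words: int) -> list[int]:
    # Pack the zero point LSB-first into one long bitstream, slice it into
    # 32-bit words, and reinterpret each word as a signed int32.
    stream = 0
    for i in range((32 * n_words) // bits):
        stream |= sym_zero(bits) << (i * bits)
    words = [(stream >> (32 * w)) & 0xFFFFFFFF for w in range(n_words)]
    return [w - (1 << 32) if w >= (1 << 31) else w for w in words]

# Bit widths that divide 32 repeat a single word everywhere:
assert packed_qzeros_words(2, 1) == [-1431655766]
assert packed_qzeros_words(4, 1) == [-2004318072]
assert packed_qzeros_words(8, 1) == [-2139062144]

# 3-bit values cross word boundaries, so the bit pattern only repeats every
# three words -- hence the three strided column assignments in the patch:
assert packed_qzeros_words(3, 3) == [613566756, 1227133513, -1840700270]

The qigen branch stores zero points unpacked, so it writes the same value directly rather than in packed form — presumably so that older sym=True checkpoints dequantize correctly with the current kernels.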