GPTQ backward compatibility support
This commit is contained in:
parent
9e0682a63e
commit
94de4ef185
1 changed file with 21 additions and 0 deletions
|
@@ -939,6 +939,27 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
||||||
checkpoint
|
checkpoint
|
||||||
)
|
)
|
||||||
model.load_state_dict(checkpoint)
|
model.load_state_dict(checkpoint)
|
||||||
|
# Preprocessing for backward compatibility
|
||||||
|
if quantize_config.sym:
|
||||||
|
QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, disable_exllama=disable_exllama, use_qigen=use_qigen,
|
||||||
|
desc_act=quantize_config.desc_act, group_size=quantize_config.group_size, bits=quantize_config.bits)
|
||||||
|
for name, submodule in model.named_modules():
|
||||||
|
if isinstance(submodule, QuantLinear):
|
||||||
|
if use_qigen:
|
||||||
|
submodule.zeros.data = torch.full_like(submodule.zeros.data, (torch.tensor(2 ** quantize_config.bits - 1) + 1) / 2)
|
||||||
|
else:
|
||||||
|
if quantize_config.bits == 2:
|
||||||
|
submodule.qzeros.data = torch.full_like(submodule.qzeros.data, -1431655766)
|
||||||
|
elif quantize_config.bits == 3:
|
||||||
|
submodule.qzeros.data[:,range(0,submodule.qzeros.data.shape[1],3)] = 613566756
|
||||||
|
submodule.qzeros.data[:,range(1,submodule.qzeros.data.shape[1],3)] = 1227133513
|
||||||
|
submodule.qzeros.data[:,range(2,submodule.qzeros.data.shape[1],3)] = -1840700270
|
||||||
|
elif quantize_config.bits == 4:
|
||||||
|
submodule.qzeros.data = torch.full_like(submodule.qzeros.data, -2004318072)
|
||||||
|
elif quantize_config.bits == 8:
|
||||||
|
submodule.qzeros.data = torch.full_like(submodule.qzeros.data, -2139062144)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
|
||||||
# == step4: set seqlen == #
|
# == step4: set seqlen == #
|
||||||
model_config = model.config.to_dict()
|
model_config = model.config.to_dict()
|
||||||
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
|
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
|
||||||
|
|
Loading…
Add table
Reference in a new issue