GPTQ backward compatibility support

Author: qwopqwop200, 2023-09-08 10:16:29 +09:00 (committed by GitHub)
Parent: 9e0682a63e
Commit: 94de4ef185

@@ -939,6 +939,27 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             checkpoint
         )
         model.load_state_dict(checkpoint)
+        # Preprocessing for backward compatibility
+        if quantize_config.sym:
+            QuantLinear = dynamically_import_QuantLinear(use_triton=use_triton, disable_exllama=disable_exllama, use_qigen=use_qigen,
+                                                         desc_act=quantize_config.desc_act, group_size=quantize_config.group_size, bits=quantize_config.bits)
+            for name, submodule in model.named_modules():
+                if isinstance(submodule, QuantLinear):
+                    if use_qigen:
+                        submodule.zeros.data = torch.full_like(submodule.zeros.data, (torch.tensor(2 ** quantize_config.bits - 1) + 1) / 2)
+                    else:
+                        if quantize_config.bits == 2:
+                            submodule.qzeros.data = torch.full_like(submodule.qzeros.data, -1431655766)
+                        elif quantize_config.bits == 3:
+                            submodule.qzeros.data[:, range(0, submodule.qzeros.data.shape[1], 3)] = 613566756
+                            submodule.qzeros.data[:, range(1, submodule.qzeros.data.shape[1], 3)] = 1227133513
+                            submodule.qzeros.data[:, range(2, submodule.qzeros.data.shape[1], 3)] = -1840700270
+                        elif quantize_config.bits == 4:
+                            submodule.qzeros.data = torch.full_like(submodule.qzeros.data, -2004318072)
+                        elif quantize_config.bits == 8:
+                            submodule.qzeros.data = torch.full_like(submodule.qzeros.data, -2139062144)
+                        else:
+                            raise NotImplementedError("Only 2,3,4,8 bits are supported.")
         # == step4: set seqlen == #
         model_config = model.config.to_dict()
         seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
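
The constants written into qzeros are the symmetric zero point, 2^(bits-1), replicated across every packed field of an int32; the qigen branch stores the same value unpacked, since (2^bits - 1 + 1) / 2 = 2^(bits-1). As a sanity check, here is a minimal sketch that re-derives the magic numbers above. pack_int32 is a helper written only for this illustration, not an AutoGPTQ API, and the packing layout (least-significant field first) is an assumption matching the GPTQ storage format:

import struct

def pack_int32(values, bits):
    # Pack 32 // bits unsigned fields of width `bits` into one 32-bit word,
    # least-significant field first, then reinterpret the word as signed int32.
    word = 0
    for i, v in enumerate(values):
        word |= (v & ((1 << bits) - 1)) << (i * bits)
    return struct.unpack("<i", struct.pack("<I", word & 0xFFFFFFFF))[0]

# 2-, 4-, and 8-bit fields tile an int32 exactly, so one constant covers
# the whole qzeros tensor.
for bits, expected in [(2, -1431655766), (4, -2004318072), (8, -2139062144)]:
    zero = 2 ** (bits - 1)          # symmetric zero point, e.g. 8 for 4-bit
    assert pack_int32([zero] * (32 // bits), bits) == expected

# 3-bit fields straddle int32 boundaries: 32 zero points of value 4 (0b100)
# span three consecutive words, which is why the 3-bit branch writes three
# distinct constants into every third column of qzeros.
stream = 0
for i in range(32):
    stream |= 4 << (3 * i)          # 96-bit little-endian bitstream
words = [(stream >> (32 * w)) & 0xFFFFFFFF for w in range(3)]
signed = [struct.unpack("<i", struct.pack("<I", w))[0] for w in words]
assert signed == [613566756, 1227133513, -1840700270]

Because symmetric quantization uses the same zero point for every group, the patch can ignore whatever an older checkpoint stored in qzeros and simply overwrite it with these constants, which is what makes the fix-up a pure in-place rewrite with no repacking.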