Update _base.py

Fix a problem where the file extension was recursively appended to quantize_config.model_file_base_name on repeated saves.
潘其威(William) 2023-05-30 07:26:42 +08:00 committed by GitHub
parent cfa7271617
commit 15db2cdc44


@@ -504,20 +504,20 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         self.model.to(CPU)
-        model_save_name = self.quantize_config.model_file_base_name or f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
+        model_base_name = self.quantize_config.model_file_base_name or f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
         if use_safetensors:
-            model_save_name += ".safetensors"
+            model_save_name = model_base_name + ".safetensors"
             state_dict = self.model.state_dict()
             state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
             safe_save(state_dict, join(save_dir, model_save_name))
         else:
-            model_save_name += ".bin"
+            model_save_name = model_base_name + ".bin"
             torch.save(self.model.state_dict(), join(save_dir, model_save_name))
         self.model.config.save_pretrained(save_dir)
         self.quantize_config.save_pretrained(save_dir)
         self.quantize_config.model_name_or_path = save_dir
-        self.quantize_config.model_file_base_name = model_save_name
+        self.quantize_config.model_file_base_name = model_base_name

     def save_pretrained(self, save_dir: str, use_safetensors: bool = False, **kwargs):
         """alias of save_quantized"""