From 15db2cdc444419e83516d7d98f6f11cd96e90cce Mon Sep 17 00:00:00 2001
From: 潘其威(William) <46810637+PanQiWei@users.noreply.github.com>
Date: Tue, 30 May 2023 07:26:42 +0800
Subject: [PATCH] Update _base.py

Fix a bug where the file extension was repeatedly appended to the stored
model_file_base_name on every save.

---
 auto_gptq/modeling/_base.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/auto_gptq/modeling/_base.py b/auto_gptq/modeling/_base.py
index c018b97..87e76ec 100644
--- a/auto_gptq/modeling/_base.py
+++ b/auto_gptq/modeling/_base.py
@@ -504,20 +504,20 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
 
         self.model.to(CPU)
 
-        model_save_name = self.quantize_config.model_file_base_name or f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
+        model_base_name = self.quantize_config.model_file_base_name or f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
         if use_safetensors:
-            model_save_name += ".safetensors"
+            model_save_name = model_base_name + ".safetensors"
             state_dict = self.model.state_dict()
             state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
             safe_save(state_dict, join(save_dir, model_save_name))
         else:
-            model_save_name += ".bin"
+            model_save_name = model_base_name + ".bin"
             torch.save(self.model.state_dict(), join(save_dir, model_save_name))
 
         self.model.config.save_pretrained(save_dir)
         self.quantize_config.save_pretrained(save_dir)
         self.quantize_config.model_name_or_path = save_dir
-        self.quantize_config.model_file_base_name = model_save_name
+        self.quantize_config.model_file_base_name = model_base_name
 
     def save_pretrained(self, save_dir: str, use_safetensors: bool = False, **kwargs):
         """alias of save_quantized"""
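
For context, a minimal standalone sketch of the bug this patch fixes (the QuantizeConfig class below is a hypothetical stand-in, not AutoGPTQ's real class). Before the change, save_quantized appended the extension to the name in place and then wrote that full file name back into quantize_config.model_file_base_name, so every subsequent save appended another extension. The fix keeps the extension-free base name and the final file name in separate variables.

# Hypothetical stand-in for quantize_config; illustrates the bug, not AutoGPTQ's API.
class QuantizeConfig:
    def __init__(self):
        self.model_file_base_name = None

def save_buggy(cfg, use_safetensors=True):
    # Old behavior: extension appended in place, then the full file
    # name is written back into model_file_base_name.
    model_save_name = cfg.model_file_base_name or "gptq_model-4bit-128g"
    model_save_name += ".safetensors" if use_safetensors else ".bin"
    cfg.model_file_base_name = model_save_name  # extension leaks into the base name
    return model_save_name

def save_fixed(cfg, use_safetensors=True):
    # New behavior: base name and on-disk file name are kept separate.
    model_base_name = cfg.model_file_base_name or "gptq_model-4bit-128g"
    model_save_name = model_base_name + (".safetensors" if use_safetensors else ".bin")
    cfg.model_file_base_name = model_base_name  # base name stays extension-free
    return model_save_name

cfg = QuantizeConfig()
save_buggy(cfg)
print(save_buggy(cfg))
# -> gptq_model-4bit-128g.safetensors.safetensors  (extension stacks up per save)

cfg = QuantizeConfig()
save_fixed(cfg)
print(save_fixed(cfg))
# -> gptq_model-4bit-128g.safetensors  (stable across repeated saves)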