diff --git a/auto_gptq/modeling/_base.py b/auto_gptq/modeling/_base.py
index c018b97..87e76ec 100644
--- a/auto_gptq/modeling/_base.py
+++ b/auto_gptq/modeling/_base.py
@@ -504,20 +504,20 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
 
         self.model.to(CPU)
 
-        model_save_name = self.quantize_config.model_file_base_name or f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
+        model_base_name = self.quantize_config.model_file_base_name or f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
         if use_safetensors:
-            model_save_name += ".safetensors"
+            model_save_name = model_base_name + ".safetensors"
             state_dict = self.model.state_dict()
             state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
             safe_save(state_dict, join(save_dir, model_save_name))
         else:
-            model_save_name += ".bin"
+            model_save_name = model_base_name + ".bin"
             torch.save(self.model.state_dict(), join(save_dir, model_save_name))
 
         self.model.config.save_pretrained(save_dir)
         self.quantize_config.save_pretrained(save_dir)
         self.quantize_config.model_name_or_path = save_dir
-        self.quantize_config.model_file_base_name = model_save_name
+        self.quantize_config.model_file_base_name = model_base_name
 
     def save_pretrained(self, save_dir: str, use_safetensors: bool = False, **kwargs):
         """alias of save_quantized"""
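To illustrate the bug this patch fixes: the old code appended the file extension to the same variable it then wrote back into quantize_config.model_file_base_name, so the stored "base name" was actually a full filename. Below is a minimal standalone sketch of that pre-patch logic; QuantizeConfig and old_save here are hypothetical stand-ins for demonstration, not the real classes.

    # Sketch of the pre-patch behavior (hypothetical stand-in, not AutoGPTQ code).
    from dataclasses import dataclass

    @dataclass
    class QuantizeConfig:
        bits: int = 4
        group_size: int = 128
        model_file_base_name: str = None

    def old_save(cfg: QuantizeConfig, use_safetensors: bool = True) -> str:
        # Pre-patch: extension is appended to the same variable that is
        # then stored back into model_file_base_name.
        model_save_name = cfg.model_file_base_name or f"gptq_model-{cfg.bits}bit-{cfg.group_size}g"
        model_save_name += ".safetensors" if use_safetensors else ".bin"
        cfg.model_file_base_name = model_save_name  # bug: stores name WITH extension
        return model_save_name

    cfg = QuantizeConfig()
    print(old_save(cfg))  # gptq_model-4bit-128g.safetensors
    print(old_save(cfg))  # gptq_model-4bit-128g.safetensors.safetensors  <- extension doubles

After the patch, the extension-free name lives in model_base_name, the on-disk filename is derived from it as model_save_name, and only the base name is persisted to quantize_config.model_file_base_name, so any later save or load that appends the extension itself no longer compounds it.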