Update _base.py
fix problem that recursively adding file extension to model_base_name
This commit is contained in:
parent
cfa7271617
commit
15db2cdc44
1 changed files with 4 additions and 4 deletions
|
@ -504,20 +504,20 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
|
||||
self.model.to(CPU)
|
||||
|
||||
model_save_name = self.quantize_config.model_file_base_name or f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
|
||||
model_base_name = self.quantize_config.model_file_base_name or f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
|
||||
if use_safetensors:
|
||||
model_save_name += ".safetensors"
|
||||
model_save_name = model_base_name + ".safetensors"
|
||||
state_dict = self.model.state_dict()
|
||||
state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
|
||||
safe_save(state_dict, join(save_dir, model_save_name))
|
||||
else:
|
||||
model_save_name += ".bin"
|
||||
model_save_name = model_base_name + ".bin"
|
||||
torch.save(self.model.state_dict(), join(save_dir, model_save_name))
|
||||
|
||||
self.model.config.save_pretrained(save_dir)
|
||||
self.quantize_config.save_pretrained(save_dir)
|
||||
self.quantize_config.model_name_or_path = save_dir
|
||||
self.quantize_config.model_file_base_name = model_save_name
|
||||
self.quantize_config.model_file_base_name = model_base_name
|
||||
|
||||
def save_pretrained(self, save_dir: str, use_safetensors: bool = False, **kwargs):
|
||||
"""alias of save_quantized"""
|
||||
|
|
Loading…
Add table
Reference in a new issue