add new args of save_quantized method to push_to_hub method

This commit is contained in:
student686 2023-10-07 13:59:53 +08:00
parent fc1184e7bc
commit 22af50bab0

View file

@ -468,6 +468,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
private: Optional[bool] = None,
token: Optional[Union[bool, str]] = None,
create_pr: Optional[bool] = False,
max_shard_size: str = "10GB",
model_base_name: Optional[str] = None
) -> str:
"""
Upload the model to the Hugging Face Hub.
@ -505,7 +507,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
if save_dir is not None:
logger.info(f"Saving model to {save_dir}")
self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
self.save_quantized(save_dir, use_safetensors, safetensors_metadata, max_shard_size, model_base_name)
repo_url = create_repo(
repo_id=repo_id, token=token, private=private, exist_ok=True, repo_type="model"