diff --git a/auto_gptq/modeling/_base.py b/auto_gptq/modeling/_base.py
index 944655d..d13c651 100644
--- a/auto_gptq/modeling/_base.py
+++ b/auto_gptq/modeling/_base.py
@@ -10,6 +10,7 @@ import torch.nn as nn
 import transformers
 from safetensors.torch import load_file as safe_load, save_file as safe_save
 from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
+from transformers.utils.hub import PushToHubMixin
 
 from ._const import *
 from ._utils import *
@@ -19,7 +20,7 @@ logger = getLogger(__name__)
 
 
 @dataclass
-class BaseQuantizeConfig:
+class BaseQuantizeConfig(PushToHubMixin):
     bits: int = field(default=4, metadata={"choices": [2, 3, 4, 8]})
     damp_percent: float = field(default=0.01)
     desc_act: bool = field(default=True)
@@ -35,7 +36,7 @@ class BaseQuantizeConfig:
         if self.group_size != -1 and self.group_size <= 0:
             raise ValueError("unless equal to -1, group_size must greater then 0.")
 
-    def save_pretrained(self, save_dir: str):
+    def save_pretrained(self, save_dir: str, **kwargs):
         with open(join(save_dir, "quantize_config.json"), "w", encoding="utf-8") as f:
             json.dump(self.to_dict(), f, indent=2)
 
@@ -53,7 +54,7 @@ class BaseQuantizeConfig:
         }
 
 
-class BaseGPTQForCausalLM(nn.Module):
+class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
     layers_block_name: str = None
     outside_layer_modules: List[str] = None
     inside_layer_modules: List[List[str]] = None
@@ -293,6 +294,11 @@ class BaseGPTQForCausalLM(nn.Module):
         self.model.config.save_pretrained(save_dir)
         self.quantize_config.save_pretrained(save_dir)
 
+    def save_pretrained(self, save_dir: str, use_safetensors: bool = False, **kwargs):
+        """alias of save_quantized"""
+        logger.warning("you are using save_pretrained, which will re-direct to save_quantized.")
+        self.save_quantized(save_dir, use_safetensors)
+
     @classmethod
     def from_pretrained(
         cls,
@@ -322,7 +328,7 @@ class BaseGPTQForCausalLM(nn.Module):
         model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **model_init_kwargs)
 
         model_config = model.config.to_dict()
-        seq_len_keys = ["max_position_embeddings", "seq_length"]
+        seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
         if any([k in model_config for k in seq_len_keys]):
             for key in seq_len_keys:
                 if key in model_config:
diff --git a/auto_gptq/modeling/auto.py b/auto_gptq/modeling/auto.py
index fd6e044..8ba4da6 100644
--- a/auto_gptq/modeling/auto.py
+++ b/auto_gptq/modeling/auto.py
@@ -1,4 +1,4 @@
-from ._base import BaseQuantizeConfig
+from ._base import BaseQuantizeConfig, BaseGPTQForCausalLM
 from ._utils import check_and_get_model_type
 from .bloom import BloomGPTQForCausalLM
 from .gpt_neox import GPTNeoXGPTQForCausalLM
@@ -33,7 +33,7 @@ class AutoGPTQForCausalLM:
         quantize_config: BaseQuantizeConfig,
         bf16: bool = False,
         **model_init_kwargs
-    ):
+    ) -> BaseGPTQForCausalLM:
         model_type = check_and_get_model_type(pretrained_model_name_or_path)
         return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
             pretrained_model_name_or_path=pretrained_model_name_or_path,
@@ -49,7 +49,7 @@ class AutoGPTQForCausalLM:
         device: str = "cpu",
         use_safetensors: bool = False,
         use_triton: bool = False
-    ):
+    ) -> BaseGPTQForCausalLM:
         model_type = check_and_get_model_type(save_dir)
         return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
             save_dir=save_dir,
diff --git a/examples/README.md b/examples/README.md
index 1954c2b..4aa1a65 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -51,3 +51,22 @@ CUDA_VISIBLE_DEVICES=0 python run_text_summarization_task.py --base_model_dir PA
 ```
 
 Use `--help` flag to see detailed descriptions for more command arguments.
+
+## Push To Hub
+> Commands in this chapter should be run under the `push_to_hub` folder.
+
+You can upload and share your quantized model on the Hugging Face Hub by using the `push_to_hub` function.
+
+`push_quantized_model_to_hf_hub.py` provides a simple example of uploading the quantized model, tokenizer, and configs at once.
+
+First, you need to log in: run the following command in the virtual environment where Hugging Face Transformers is installed:
+```shell
+huggingface-cli login
+```
+
+Then run the script like this:
+```shell
+python push_quantized_model_to_hf_hub.py --quantized_model_dir PATH/TO/QUANTIZED/MODEL/DIR --tokenizer_dir PATH/TO/TOKENIZER/DIR --repo_id REPO/ID
+```
+
+Use `--help` flag to see detailed descriptions for more command arguments.
\ No newline at end of file
diff --git a/examples/push_to_hub/push_quantized_model_to_hf_hub.py b/examples/push_to_hub/push_quantized_model_to_hf_hub.py
new file mode 100644
index 0000000..4fb08f2
--- /dev/null
+++ b/examples/push_to_hub/push_quantized_model_to_hf_hub.py
@@ -0,0 +1,55 @@
+from argparse import ArgumentParser
+
+from auto_gptq import AutoGPTQForCausalLM
+from transformers import AutoTokenizer
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument("--quantized_model_dir", type=str, help="Directory that saves quantized model.")
+    parser.add_argument("--repo_id", type=str, help="The name of the repository you want to push to.")
+    parser.add_argument(
+        "--tokenizer_dir",
+        type=str,
+        default=None,
+        help="Directory that saves tokenizer, defaults to None, will not upload tokenizer if not specified."
+    )
+    parser.add_argument("--commit_message", type=str, default=None, help="Message to commit while pushing.")
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cpu",
+        choices=["cpu", "cuda"],
+        help="Which device to load the model."
+    )
+    parser.add_argument(
+        "--private",
+        action="store_true",
+        help="Whether or not the repository created should be private."
+    )
+    parser.add_argument(
+        "--use_temp_dir",
+        action="store_true",
+        help="Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub."
+    )
+    args = parser.parse_args()
+
+    push_to_hub_kwargs = {
+        "repo_id": args.repo_id,
+        "commit_message": args.commit_message,
+        "private": args.private,
+        "use_temp_dir": args.use_temp_dir
+    }
+
+    model = AutoGPTQForCausalLM.from_quantized(args.quantized_model_dir, device=args.device)
+    model.push_to_hub(**push_to_hub_kwargs)
+    model.config.push_to_hub(**push_to_hub_kwargs)
+    model.quantize_config.push_to_hub(**push_to_hub_kwargs)
+
+    if args.tokenizer_dir:
+        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
+        tokenizer.push_to_hub(**push_to_hub_kwargs)
+
+
+if __name__ == "__main__":
+    main()
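
For reference, here is a minimal sketch (not part of the patch) of how the `push_to_hub` support added above would be used directly from Python, mirroring what `push_quantized_model_to_hf_hub.py` does; the local directories and `repo_id` below are placeholder values, not names taken from the repository.

```python
from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoTokenizer

quantized_model_dir = "path/to/quantized/model"   # placeholder: output dir of a previous quantization run
tokenizer_dir = "path/to/tokenizer"               # placeholder: dir holding the base model's tokenizer
repo_id = "your-username/your-quantized-model"    # placeholder: Hub repository to create or update

# Load the already-quantized model; CPU is enough when the goal is only to upload.
model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cpu")

# BaseGPTQForCausalLM and BaseQuantizeConfig now inherit PushToHubMixin, so the
# model weights, the model config, and the quantize config can each be pushed.
model.push_to_hub(repo_id, commit_message="Upload GPTQ-quantized model", use_temp_dir=True)
model.config.push_to_hub(repo_id, use_temp_dir=True)
model.quantize_config.push_to_hub(repo_id, use_temp_dir=True)

# Optionally push the tokenizer used with the base model as well.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
tokenizer.push_to_hub(repo_id, use_temp_dir=True)
```

Note that pushing the model goes through `PushToHubMixin.push_to_hub`, which calls `save_pretrained` in a temporary directory; this is why the patch adds the `save_pretrained` alias that redirects to `save_quantized`.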