Merge pull request #18 from PanQiWei/push_to_hub_integration

push_to_hub integration
潘其威(William) 2023-04-26 17:52:45 +08:00 committed by GitHub
commit 1acb0c5eba
7 changed files with 96 additions and 12 deletions

View file

@@ -10,6 +10,7 @@ import torch.nn as nn
import transformers
from safetensors.torch import load_file as safe_load, save_file as safe_save
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.utils.hub import PushToHubMixin
from ._const import *
from ._utils import *
@@ -19,7 +20,7 @@ logger = getLogger(__name__)
@dataclass
class BaseQuantizeConfig:
class BaseQuantizeConfig(PushToHubMixin):
bits: int = field(default=4, metadata={"choices": [2, 3, 4, 8]})
damp_percent: float = field(default=0.01)
desc_act: bool = field(default=True)
@@ -35,7 +36,7 @@ class BaseQuantizeConfig:
if self.group_size != -1 and self.group_size <= 0:
raise ValueError("unless equal to -1, group_size must be greater than 0.")
def save_pretrained(self, save_dir: str):
def save_pretrained(self, save_dir: str, **kwargs):
with open(join(save_dir, "quantize_config.json"), "w", encoding="utf-8") as f:
json.dump(self.to_dict(), f, indent=2)
@@ -53,7 +54,7 @@ class BaseQuantizeConfig:
}
class BaseGPTQForCausalLM(nn.Module):
class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
layers_block_name: str = None
outside_layer_modules: List[str] = None
inside_layer_modules: List[List[str]] = None
@@ -293,6 +294,11 @@ class BaseGPTQForCausalLM(nn.Module):
self.model.config.save_pretrained(save_dir)
self.quantize_config.save_pretrained(save_dir)
def save_pretrained(self, save_dir: str, use_safetensors: bool = False, **kwargs):
"""alias of save_quantized"""
logger.warning("you are using save_pretrained, which will redirect to save_quantized.")
self.save_quantized(save_dir, use_safetensors)
@classmethod
def from_pretrained(
cls,
@@ -322,7 +328,7 @@ class BaseGPTQForCausalLM(nn.Module):
model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **model_init_kwargs)
model_config = model.config.to_dict()
seq_len_keys = ["max_position_embeddings", "seq_length"]
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
if any([k in model_config for k in seq_len_keys]):
for key in seq_len_keys:
if key in model_config:
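
For orientation, here is a minimal sketch of how these changes are used on an already-quantized model: `save_pretrained` now redirects to `save_quantized`, and the new `PushToHubMixin` base class is what makes the usual `push_to_hub` API available on both the model and its quantize config. The directory names below are hypothetical.

```python
from auto_gptq import AutoGPTQForCausalLM

# Load an already-quantized model from a local directory (hypothetical path).
model = AutoGPTQForCausalLM.from_quantized("path/to/quantized-model", device="cpu")

# save_pretrained redirects to save_quantized, which writes the quantized weights
# together with config.json and quantize_config.json into the target directory.
model.save_pretrained("path/to/export-dir")
```

A full upload walkthrough is given by the `push_quantized_model_to_hf_hub.py` example added later in this pull request.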

View file

@@ -1,4 +1,4 @@
from ._base import BaseQuantizeConfig
from ._base import BaseQuantizeConfig, BaseGPTQForCausalLM
from ._utils import check_and_get_model_type
from .bloom import BloomGPTQForCausalLM
from .gpt_neox import GPTNeoXGPTQForCausalLM
@@ -33,7 +33,7 @@ class AutoGPTQForCausalLM:
quantize_config: BaseQuantizeConfig,
bf16: bool = False,
**model_init_kwargs
):
) -> BaseGPTQForCausalLM:
model_type = check_and_get_model_type(pretrained_model_name_or_path)
return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
pretrained_model_name_or_path=pretrained_model_name_or_path,
@@ -49,7 +49,7 @@ class AutoGPTQForCausalLM:
device: str = "cpu",
use_safetensors: bool = False,
use_triton: bool = False
):
) -> BaseGPTQForCausalLM:
model_type = check_and_get_model_type(save_dir)
return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
save_dir=save_dir,

View file

@@ -57,3 +57,22 @@ CUDA_VISIBLE_DEVICES=0 python run_text_summarization_task.py --base_model_dir PA
```
Use `--help` flag to see detailed descriptions for more command arguments.
## Push To Hub
> Commands in this chapter should be run under the `push_to_hub` folder.
You can upload and share your quantized model on the Hugging Face Hub by using the `push_to_hub` function.
`push_quantized_model_to_hf_hub.py` provides a simple example of uploading a quantized model, tokenizer and configs at once.
First, you need to log in. Run the following command in the virtual environment where Hugging Face Transformers is installed:
```shell
huggingface-cli login
```
Then run the script like this:
```shell
python push_quantized_model_to_hf_hub.py --quantized_model_dir PATH/TO/QUANTIZED/MODEL/DIR --tokenizer_dir PATH/TO/TOKENIZER/DIR --repo_id REPO/ID
```
Use the `--help` flag to see detailed descriptions for more command arguments.
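
The same upload can also be scripted directly in Python; the following is a minimal sketch of what the example script does, using the placeholder paths and repo id from the command above:

```python
from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoTokenizer

# Load the quantized model from disk (placeholder path).
model = AutoGPTQForCausalLM.from_quantized("PATH/TO/QUANTIZED/MODEL/DIR", device="cpu")

# Push the model weights, the model config and the quantize config to one repo.
model.push_to_hub("REPO/ID")
model.config.push_to_hub("REPO/ID")
model.quantize_config.push_to_hub("REPO/ID")

# Optionally push the tokenizer to the same repo.
tokenizer = AutoTokenizer.from_pretrained("PATH/TO/TOKENIZER/DIR")
tokenizer.push_to_hub("REPO/ID")
```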

View file

@@ -1,9 +1,10 @@
import datasets
from argparse import ArgumentParser
from transformers import AutoTokenizer
import torch
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from auto_gptq.eval_tasks import LanguageModelingTask
from transformers import AutoTokenizer
DATASET = "tatsu-lab/alpaca"
@@ -63,6 +64,7 @@ def main():
task.model = None
model.cpu()
del model
torch.cuda.empty_cache()
model = AutoGPTQForCausalLM.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton)
task.model = model
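
For clarity, the sequence these evaluation scripts now share when swapping the full-precision model for its quantized counterpart is shown below with explanatory comments. This is a fragment of the scripts, not standalone code: `task`, `model` and `args` come from earlier in each script, and the model directory is a placeholder.

```python
import torch
from auto_gptq import AutoGPTQForCausalLM

# Detach the full-precision model from the task and free its GPU memory
# before the quantized weights are loaded.
task.model = None
model.cpu()                # move parameters back to host memory
del model                  # drop the last Python reference to the model
torch.cuda.empty_cache()   # release cached CUDA blocks (the call added in this PR)
# Reload the quantized model onto the GPU and hand it back to the task.
model = AutoGPTQForCausalLM.from_quantized(
    "PATH/TO/QUANTIZED/MODEL/DIR", device="cuda:0", use_triton=args.use_triton
)
task.model = model
```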

View file

@@ -2,10 +2,10 @@ from argparse import ArgumentParser
from functools import partial
import datasets
from transformers import AutoTokenizer
import torch
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from auto_gptq.eval_tasks import SequenceClassificationTask
from transformers import AutoTokenizer
DATASET = "cardiffnlp/tweet_sentiment_multilingual"
@@ -67,6 +67,7 @@ def main():
task.model = None
model.cpu()
del model
torch.cuda.empty_cache()
model = AutoGPTQForCausalLM.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton)
task.model = model

View file

@@ -2,10 +2,10 @@ import os
from argparse import ArgumentParser
import datasets
from transformers import AutoTokenizer, GenerationConfig
import torch
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from auto_gptq.eval_tasks import TextSummarizationTask
from transformers import AutoTokenizer, GenerationConfig
os.system("pip install py7zr")
@@ -61,6 +61,7 @@ def main():
task.model = None
model.cpu()
del model
torch.cuda.empty_cache()
model = AutoGPTQForCausalLM.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton)
task.model = model

View file

@@ -0,0 +1,55 @@
from argparse import ArgumentParser
from auto_gptq import AutoGPTQForCausalLM
from transformers import AutoTokenizer
def main():
parser = ArgumentParser()
parser.add_argument("--quantized_model_dir", type=str, help="Directory that saves quantized model.")
parser.add_argument("--repo_id", type=str, help="The name of the repository you want to push to.")
parser.add_argument(
"--tokenizer_dir",
type=str,
default=None,
help="Directory that saves tokenizer, defaults to None, will not upload tokenizer if not specified."
)
parser.add_argument("--commit_message", type=str, default=None, help="Message to commit while pushing.")
parser.add_argument(
"--device",
type=str,
default="cpu",
choices=["cpu", "cuda"],
help="Which device to load the model."
)
parser.add_argument(
"--private",
action="store_true",
help="Whether or not the repository created should be private."
)
parser.add_argument(
"--use_temp_dir",
action="store_true",
help="Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub."
)
args = parser.parse_args()
push_to_hub_kwargs = {
"repo_id": args.repo_id,
"commit_message": args.commit_message,
"private": args.private,
"use_temp_dir": args.use_temp_dir
}
model = AutoGPTQForCausalLM.from_quantized(args.quantized_model_dir, device=args.device)
model.push_to_hub(**push_to_hub_kwargs)
model.config.push_to_hub(**push_to_hub_kwargs)
model.quantize_config.push_to_hub(**push_to_hub_kwargs)
if args.tokenizer_dir:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
tokenizer.push_to_hub(**push_to_hub_kwargs)
if __name__ == "__main__":
main()