Merge pull request #18 from PanQiWei/push_to_hub_integration
push_to_hub integration
commit 1acb0c5eba
7 changed files with 96 additions and 12 deletions
@@ -10,6 +10,7 @@ import torch.nn as nn
 import transformers
 from safetensors.torch import load_file as safe_load, save_file as safe_save
 from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
+from transformers.utils.hub import PushToHubMixin

 from ._const import *
 from ._utils import *
@@ -19,7 +20,7 @@ logger = getLogger(__name__)


 @dataclass
-class BaseQuantizeConfig:
+class BaseQuantizeConfig(PushToHubMixin):
     bits: int = field(default=4, metadata={"choices": [2, 3, 4, 8]})
     damp_percent: float = field(default=0.01)
     desc_act: bool = field(default=True)
@@ -35,7 +36,7 @@ class BaseQuantizeConfig:
         if self.group_size != -1 and self.group_size <= 0:
             raise ValueError("unless equal to -1, group_size must greater then 0.")

-    def save_pretrained(self, save_dir: str):
+    def save_pretrained(self, save_dir: str, **kwargs):
         with open(join(save_dir, "quantize_config.json"), "w", encoding="utf-8") as f:
             json.dump(self.to_dict(), f, indent=2)

@@ -53,7 +54,7 @@ class BaseQuantizeConfig:
         }


-class BaseGPTQForCausalLM(nn.Module):
+class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
     layers_block_name: str = None
     outside_layer_modules: List[str] = None
     inside_layer_modules: List[List[str]] = None
@@ -293,6 +294,11 @@ class BaseGPTQForCausalLM(nn.Module):
         self.model.config.save_pretrained(save_dir)
         self.quantize_config.save_pretrained(save_dir)

+    def save_pretrained(self, save_dir: str, use_safetensors: bool = False, **kwargs):
+        """alias of save_quantized"""
+        logger.warning("you are using save_pretrained, which will re-direct to save_quantized.")
+        self.save_quantized(save_dir, use_safetensors)
+
     @classmethod
     def from_pretrained(
         cls,
@@ -322,7 +328,7 @@ class BaseGPTQForCausalLM(nn.Module):

         model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **model_init_kwargs)
         model_config = model.config.to_dict()
-        seq_len_keys = ["max_position_embeddings", "seq_length"]
+        seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
         if any([k in model_config for k in seq_len_keys]):
             for key in seq_len_keys:
                 if key in model_config:
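With `BaseQuantizeConfig` and `BaseGPTQForCausalLM` now inheriting `PushToHubMixin`, a loaded model, its config, and its quantize config all gain a `push_to_hub` method. A minimal sketch of the workflow this enables (the local directory and repo id below are placeholders, not values from this PR):

```python
# Minimal sketch of what the PushToHubMixin integration enables.
# "path/to/quantized-model" and "your-username/your-model-4bit-gptq" are placeholders.
from auto_gptq import AutoGPTQForCausalLM

model = AutoGPTQForCausalLM.from_quantized("path/to/quantized-model", device="cpu")

# PushToHubMixin.push_to_hub saves through save_pretrained before uploading,
# which is why the new save_pretrained alias (redirecting to save_quantized)
# matters: it lets the mixin serialize the quantized weights.
model.push_to_hub("your-username/your-model-4bit-gptq")
model.config.push_to_hub("your-username/your-model-4bit-gptq")
model.quantize_config.push_to_hub("your-username/your-model-4bit-gptq")  # writes quantize_config.json
```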
@@ -1,4 +1,4 @@
-from ._base import BaseQuantizeConfig
+from ._base import BaseQuantizeConfig, BaseGPTQForCausalLM
 from ._utils import check_and_get_model_type
 from .bloom import BloomGPTQForCausalLM
 from .gpt_neox import GPTNeoXGPTQForCausalLM
@@ -33,7 +33,7 @@ class AutoGPTQForCausalLM:
         quantize_config: BaseQuantizeConfig,
         bf16: bool = False,
         **model_init_kwargs
-    ):
+    ) -> BaseGPTQForCausalLM:
         model_type = check_and_get_model_type(pretrained_model_name_or_path)
         return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
             pretrained_model_name_or_path=pretrained_model_name_or_path,
@@ -49,7 +49,7 @@ class AutoGPTQForCausalLM:
         device: str = "cpu",
         use_safetensors: bool = False,
         use_triton: bool = False
-    ):
+    ) -> BaseGPTQForCausalLM:
         model_type = check_and_get_model_type(save_dir)
         return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
             save_dir=save_dir,
@@ -57,3 +57,22 @@ CUDA_VISIBLE_DEVICES=0 python run_text_summarization_task.py --base_model_dir PA
 ```

 Use `--help` flag to see detailed descriptions for more command arguments.
+
+## Push To Hub
+> Commands in this chapter should be run under `push_to_hub` folder.
+
+You can upload and share your quantized model to Hugging Face Hub by using `push_to_hub` function.
+
+`push_quantized_model_to_hf_hub.py` provide a simple example to upload quantized model, tokenizer and configs at once.
+
+First, you need to login, run the following command in the virtual environment where Hugging Face Transformers is installed:
+```shell
+huggingface-cli login
+```
+
+Then run the script like this:
+```shell
+python push_quantized_model_to_hf_hub.py --quantized_model_dir PATH/TO/QUANTIZED/MODEL/DIR --tokenizer_dir PATH/TO/TOKENIZER/DIR --repo_id REPO/ID
+```
+
+Use `--help` flag to see detailed descriptions for more command arguments.
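If you prefer to authenticate from Python rather than the CLI, a roughly equivalent login step is sketched below; it uses `huggingface_hub` (installed alongside Transformers) and is not part of this PR:

```python
# Hypothetical alternative to `huggingface-cli login`; requires a Hugging Face
# access token with write permission.
from huggingface_hub import login

login()  # prompts for the token interactively; or pass token="hf_..." directly
```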
@@ -1,9 +1,10 @@
 import datasets
 from argparse import ArgumentParser
-from transformers import AutoTokenizer

+import torch
 from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
 from auto_gptq.eval_tasks import LanguageModelingTask
+from transformers import AutoTokenizer


 DATASET = "tatsu-lab/alpaca"
@@ -63,6 +64,7 @@ def main():
     task.model = None
     model.cpu()
     del model
+    torch.cuda.empty_cache()

     model = AutoGPTQForCausalLM.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton)
     task.model = model
@@ -2,10 +2,10 @@ from argparse import ArgumentParser
 from functools import partial

 import datasets
-from transformers import AutoTokenizer
+import torch

 from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
 from auto_gptq.eval_tasks import SequenceClassificationTask
+from transformers import AutoTokenizer


 DATASET = "cardiffnlp/tweet_sentiment_multilingual"
@@ -67,6 +67,7 @@ def main():
     task.model = None
     model.cpu()
     del model
+    torch.cuda.empty_cache()

     model = AutoGPTQForCausalLM.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton)
     task.model = model
@@ -2,10 +2,10 @@ import os
 from argparse import ArgumentParser

 import datasets
-from transformers import AutoTokenizer, GenerationConfig
+import torch

 from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
 from auto_gptq.eval_tasks import TextSummarizationTask
+from transformers import AutoTokenizer, GenerationConfig


 os.system("pip install py7zr")
@@ -61,6 +61,7 @@ def main():
     task.model = None
     model.cpu()
     del model
+    torch.cuda.empty_cache()

     model = AutoGPTQForCausalLM.from_quantized(args.quantized_model_dir, device="cuda:0", use_triton=args.use_triton)
     task.model = model
examples/push_to_hub/push_quantized_model_to_hf_hub.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+from argparse import ArgumentParser
+
+from auto_gptq import AutoGPTQForCausalLM
+from transformers import AutoTokenizer
+
+
+def main():
+    parser = ArgumentParser()
+    parser.add_argument("--quantized_model_dir", type=str, help="Directory that saves quantized model.")
+    parser.add_argument("--repo_id", type=str, help="The name of the repository you want to push to.")
+    parser.add_argument(
+        "--tokenizer_dir",
+        type=str,
+        default=None,
+        help="Directory that saves tokenizer, defaults to None, will not upload tokenizer if not specified."
+    )
+    parser.add_argument("--commit_message", type=str, default=None, help="Message to commit while pushing.")
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cpu",
+        choices=["cpu", "cuda"],
+        help="Which device to load the model."
+    )
+    parser.add_argument(
+        "--private",
+        action="store_true",
+        help="Whether or not the repository created should be private."
+    )
+    parser.add_argument(
+        "--use_temp_dir",
+        action="store_true",
+        help="Whether or not to use a temporary directory to store the files saved before they are pushed to the Hub."
+    )
+    args = parser.parse_args()
+
+    push_to_hub_kwargs = {
+        "repo_id": args.repo_id,
+        "commit_message": args.commit_message,
+        "private": args.private,
+        "use_temp_dir": args.use_temp_dir
+    }
+
+    model = AutoGPTQForCausalLM.from_quantized(args.quantized_model_dir, device=args.device)
+    model.push_to_hub(**push_to_hub_kwargs)
+    model.config.push_to_hub(**push_to_hub_kwargs)
+    model.quantize_config.push_to_hub(**push_to_hub_kwargs)
+
+    if args.tokenizer_dir:
+        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir)
+        tokenizer.push_to_hub(**push_to_hub_kwargs)
+
+
+if __name__ == "__main__":
+    main()
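Once pushed, the repository holds the quantized weights, `config.json`, `quantize_config.json`, and the tokenizer files if `--tokenizer_dir` was given. A hedged sketch of how a consumer could pull everything back down and load it; `snapshot_download` comes from `huggingface_hub` and the repo id is a placeholder:

```python
# Sketch of consuming a pushed model: download the whole repo to a local
# directory, then load it with from_quantized as usual.
from huggingface_hub import snapshot_download

from auto_gptq import AutoGPTQForCausalLM

local_dir = snapshot_download(repo_id="your-username/your-model-4bit-gptq")
model = AutoGPTQForCausalLM.from_quantized(local_dir, device="cuda:0")
```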