Compare commits


3 commits

Author       SHA1         Message                                                        Date
student686   22af50bab0   add new args of save_quantized method to push_to_hub method   2023-10-07 13:59:53 +08:00
student686   fc1184e7bc   save_quantized method support shard checkpoint                 2023-10-07 13:48:45 +08:00
student686   bf70350153   bump transformers version to 4.34.0                            2023-10-07 13:47:37 +08:00
2 changed files with 94 additions and 46 deletions

@@ -2,6 +2,7 @@ import copy
 import json
 import warnings
 import os
+import re
 from dataclasses import dataclass, field, fields
 from logging import getLogger
 from os.path import join, isfile, isdir
@@ -17,7 +18,7 @@ from safetensors.torch import load_file as safe_load
 from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
 from transformers.utils.hub import PushToHubMixin, cached_file, create_repo, create_commit, CommitOperationAdd
 from transformers.utils.generic import ContextManagers
-from transformers.modeling_utils import no_init_weights
+from transformers.modeling_utils import no_init_weights, shard_checkpoint
 from ._const import *
 from ._utils import *
@@ -467,6 +468,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         private: Optional[bool] = None,
         token: Optional[Union[bool, str]] = None,
         create_pr: Optional[bool] = False,
+        max_shard_size: str = "10GB",
+        model_base_name: Optional[str] = None
     ) -> str:
         """
         Upload the model to the Hugging Face Hub.
@@ -504,7 +507,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         if save_dir is not None:
             logger.info(f"Saving model to {save_dir}")
-            self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
+            self.save_quantized(save_dir, use_safetensors, safetensors_metadata, max_shard_size, model_base_name)

         repo_url = create_repo(
             repo_id=repo_id, token=token, private=private, exist_ok=True, repo_type="model"
@@ -527,59 +530,104 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             repo_type="model",
         )

-    def save_quantized(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None):
+    def save_quantized(
+        self,
+        save_dir: str,
+        use_safetensors: bool = False,
+        safetensors_metadata: Optional[Dict[str, str]] = None,
+        max_shard_size: str = "10GB",
+        model_base_name: Optional[str] = None
+    ):
         """save quantized model and configs to local disk"""
         os.makedirs(save_dir, exist_ok=True)

         if not self.quantized:
             raise EnvironmentError("can only save quantized model, please execute .quantize first.")

-        self.model.to(CPU)
-
-        model_base_name = self.quantize_config.model_file_base_name or f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
+        if model_base_name is None:
+            model_base_name = (
+                self.quantize_config.model_file_base_name or
+                f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
+            )
+
+        state_dict = self.model.state_dict()

         if use_safetensors:
-            model_save_name = model_base_name + ".safetensors"
-            state_dict = self.model.state_dict()
             state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
-            if safetensors_metadata is None:
-                safetensors_metadata = {}
-            elif not isinstance(safetensors_metadata, dict):
-                raise TypeError("safetensors_metadata must be a dictionary.")
-            else:
-                logger.debug(f"Received safetensors_metadata: {safetensors_metadata}")
-                new_safetensors_metadata = {}
-                converted_keys = False
-                for key, value in safetensors_metadata.items():
-                    if not isinstance(key, str) or not isinstance(value, str):
-                        converted_keys = True
-                        try:
-                            new_key = str(key)
-                            new_value = str(value)
-                        except Exception as e:
-                            raise TypeError(f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}")
-                        if new_key in new_safetensors_metadata:
-                            logger.warning(f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
-                        new_safetensors_metadata[new_key] = new_value
-                safetensors_metadata = new_safetensors_metadata
-                if converted_keys:
-                    logger.debug(f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")
-
-            # Format is required to enable Accelerate to load the metadata
-            # otherwise it raises an OSError
-            safetensors_metadata['format'] = "pt"
-
-            # Store the quantization configuration as safetensors metadata
-            from auto_gptq import __version__
-            safetensors_metadata['auto_gptq_version'] = str(__version__)
-            safetensors_metadata['gptq_bits'] = str(self.quantize_config.bits)
-            safetensors_metadata['gptq_group_size'] = str(self.quantize_config.group_size)
-            safetensors_metadata['gptq_desc_act'] = str(self.quantize_config.desc_act)
-            safetensors_metadata['gptq_damp_percent'] = str(self.quantize_config.damp_percent)
-            safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
+            model_save_name = model_base_name + ".safetensors"
         else:
             model_save_name = model_base_name + ".bin"
-            torch.save(self.model.state_dict(), join(save_dir, model_save_name))
+
+        # Shard checkpoint
+        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=model_save_name)
+
+        # Clean the folder from a previous save
+        for filename in os.listdir(save_dir):
+            full_filename = join(save_dir, filename)
+
+            # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
+            filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
+            reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
+
+            if (
+                filename.startswith(model_base_name)
+                and isfile(full_filename)
+                and filename not in shards.keys()
+                and reg.fullmatch(filename_no_suffix) is not None
+            ):
+                os.remove(full_filename)
+
+        # Save the model
+        for shard_file, shard in shards.items():
+            if use_safetensors:
+                if safetensors_metadata is None:
+                    safetensors_metadata = {}
+                elif not isinstance(safetensors_metadata, dict):
+                    raise TypeError("safetensors_metadata must be a dictionary.")
+                else:
+                    logger.debug(f"Received safetensors_metadata: {safetensors_metadata}")
+                    new_safetensors_metadata = {}
+                    converted_keys = False
+                    for key, value in safetensors_metadata.items():
+                        if not isinstance(key, str) or not isinstance(value, str):
+                            converted_keys = True
+                            try:
+                                new_key = str(key)
+                                new_value = str(value)
+                            except Exception as e:
+                                raise TypeError(
+                                    f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}")
+                            if new_key in new_safetensors_metadata:
+                                logger.warning(
+                                    f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
+                            new_safetensors_metadata[new_key] = new_value
+                    safetensors_metadata = new_safetensors_metadata
+                    if converted_keys:
+                        logger.debug(
+                            f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")
+
+                # Format is required to enable Accelerate to load the metadata
+                # otherwise it raises an OSError
+                safetensors_metadata['format'] = "pt"
+
+                # Store the quantization configuration as safetensors metadata
+                from auto_gptq import __version__
+                safetensors_metadata['auto_gptq_version'] = str(__version__)
+                safetensors_metadata['gptq_bits'] = str(self.quantize_config.bits)
+                safetensors_metadata['gptq_group_size'] = str(self.quantize_config.group_size)
+                safetensors_metadata['gptq_desc_act'] = str(self.quantize_config.desc_act)
+                safetensors_metadata['gptq_damp_percent'] = str(self.quantize_config.damp_percent)
+                safe_save(shard, join(save_dir, shard_file), safetensors_metadata)
+            else:
+                torch.save(shard, join(save_dir, shard_file))
+
+        if index is not None:
+            index_save_name = model_save_name + ".index.json"
+            index_save_path = join(save_dir, index_save_name)
+            # Save the index as well
+            with open(index_save_path, "w", encoding="utf-8") as f:
+                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                f.write(content)

         self.model.config.save_pretrained(save_dir)
         self.quantize_config.save_pretrained(save_dir)
@@ -589,7 +637,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
     def save_pretrained(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None, **kwargs):
         """alias of save_quantized"""
        logger.warning("you are using save_pretrained, which will re-direct to save_quantized.")
-        self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
+        self.save_quantized(save_dir, use_safetensors, safetensors_metadata, **kwargs)

     @classmethod
     def from_pretrained(
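For context, here is a minimal sketch of how the new arguments could be used once this change is in place. It follows AutoGPTQ's usual quantize-then-save flow from the project README; the model id, calibration sentence, output directory, and repo name are placeholders, and max_shard_size / model_base_name are the parameters added by these commits.

from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

pretrained_model_id = "facebook/opt-125m"  # placeholder model
quantize_config = BaseQuantizeConfig(bits=4, group_size=128)

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_id)
model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_id, quantize_config)

# quantize() expects a list of tokenized calibration examples
examples = [tokenizer("auto-gptq is an easy-to-use model quantization library.")]
model.quantize(examples)

# New in this change: max_shard_size splits the checkpoint into multiple files,
# and model_base_name overrides the default gptq_model-{bits}bit-{group_size}g prefix.
model.save_quantized(
    "opt-125m-4bit-gptq",
    use_safetensors=True,
    max_shard_size="100MB",           # each shard at most ~100MB
    model_base_name="my_gptq_model",  # optional custom file prefix
)

# push_to_hub now forwards the same arguments to save_quantized, so a sharded
# upload could look like this (repo id is a placeholder and requires a HF token):
model.push_to_hub(
    "someuser/opt-125m-4bit-gptq",
    save_dir="opt-125m-4bit-gptq",
    use_safetensors=True,
    max_shard_size="100MB",
)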

@@ -80,7 +80,7 @@ requirements = [
     "gekko",
     "torch>=1.13.0",
     "safetensors",
-    "transformers>=4.31.0",
+    "transformers>=4.34.0",
     "peft",
     "tqdm",
 ]
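The new save path delegates the actual splitting to shard_checkpoint from transformers.modeling_utils, imported in the first file above. Below is a small standalone sketch of what that helper returns; the tensor sizes, shard limit, and file name are illustrative only.

import torch
from transformers.modeling_utils import shard_checkpoint

# A toy state dict: 8 float32 tensors of ~4MB each, ~32MB total.
state_dict = {f"layer{i}.weight": torch.zeros(1024, 1024) for i in range(8)}

# Splits the state dict into shards no larger than max_shard_size; shard files are
# named like <base>-00001-of-00004.safetensors when more than one shard is needed.
shards, index = shard_checkpoint(
    state_dict,
    max_shard_size="10MB",
    weights_name="gptq_model-4bit-128g.safetensors",
)

print(list(shards.keys()))  # shard file names mapped to partial state dicts
print(index is None)        # index is None when everything fits in a single shard
# When sharded, index["weight_map"] maps each parameter name to its shard file;
# save_quantized writes this index out as <weights_name>.index.json.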