save_quantized method: support sharded checkpoints
parent bf70350153
commit fc1184e7bc
1 changed file with 90 additions and 44 deletions
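In short: save_quantized now accepts max_shard_size and model_base_name, splits the state dict with transformers' shard_checkpoint, writes each shard plus a *.index.json when more than one shard is produced, and removes stale shard files left over from a previous save; save_pretrained forwards **kwargs so it exposes the same options. A rough usage sketch of the new interface (the model name, calibration text, and output directory below are placeholders, not part of this commit):

    from transformers import AutoTokenizer
    from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

    pretrained = "facebook/opt-125m"  # placeholder model
    tokenizer = AutoTokenizer.from_pretrained(pretrained)
    examples = [tokenizer("auto-gptq lets you quantize transformers models.", return_tensors="pt")]

    model = AutoGPTQForCausalLM.from_pretrained(pretrained, BaseQuantizeConfig(bits=4, group_size=128))
    model.quantize(examples)

    # New in this commit: max_shard_size is forwarded to shard_checkpoint, so a large
    # checkpoint is written as numbered shards plus an index file instead of one file.
    model.save_quantized("opt-125m-4bit-128g", use_safetensors=True, max_shard_size="2GB")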
@@ -2,6 +2,7 @@ import copy
 import json
 import warnings
 import os
+import re
 from dataclasses import dataclass, field, fields
 from logging import getLogger
 from os.path import join, isfile, isdir
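The new `import re` is used further down in this diff to recognize shard files left behind by a previous save. A small sketch of that pattern (the file names are illustrative):

    import re

    # Same pattern as in save_quantized below: shard files look like <base>-00001-of-00002.<ext>
    reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
    print(reg.fullmatch("gptq_model-4bit-128g-00001-of-00002") is not None)  # True  -> candidate for cleanup
    print(reg.fullmatch("gptq_model-4bit-128g") is not None)                 # False -> single-file checkpoint, kept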
@@ -17,7 +18,7 @@ from safetensors.torch import load_file as safe_load
 from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
 from transformers.utils.hub import PushToHubMixin, cached_file, create_repo, create_commit, CommitOperationAdd
 from transformers.utils.generic import ContextManagers
-from transformers.modeling_utils import no_init_weights
+from transformers.modeling_utils import no_init_weights, shard_checkpoint
 
 from ._const import *
 from ._utils import *
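shard_checkpoint, newly imported here, is the transformers helper that does the actual splitting. The sketch below shows how the tuple it returns is consumed; the sizes and names are illustrative, and the helper's exact behaviour is assumed from the transformers version this commit targets:

    import torch
    from transformers.modeling_utils import shard_checkpoint

    # A toy state dict that is larger than the (deliberately tiny) shard limit,
    # so it gets split across several numbered shard files.
    state_dict = {f"layer.{i}.weight": torch.zeros(1024, 1024) for i in range(8)}
    shards, index = shard_checkpoint(state_dict, max_shard_size="16MB", weights_name="gptq_model.bin")

    for shard_file, shard in shards.items():
        print(shard_file, len(shard))   # shard file name -> number of tensors it holds
    # index is None when everything fits in a single file; otherwise it records, per
    # weight name, which shard holds it (this is what gets written to *.index.json).
    print(index is not None)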
@@ -527,59 +528,104 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
             repo_type="model",
         )
 
-    def save_quantized(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None):
+    def save_quantized(
+        self,
+        save_dir: str,
+        use_safetensors: bool = False,
+        safetensors_metadata: Optional[Dict[str, str]] = None,
+        max_shard_size: str = "10GB",
+        model_base_name: Optional[str] = None
+    ):
         """save quantized model and configs to local disk"""
         os.makedirs(save_dir, exist_ok=True)
 
         if not self.quantized:
             raise EnvironmentError("can only save quantized model, please execute .quantize first.")
 
-        self.model.to(CPU)
+        if model_base_name is None:
+            model_base_name = (
+                self.quantize_config.model_file_base_name or
+                f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
+            )
 
-        model_base_name = self.quantize_config.model_file_base_name or f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
+        state_dict = self.model.state_dict()
         if use_safetensors:
-            model_save_name = model_base_name + ".safetensors"
-            state_dict = self.model.state_dict()
             state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
-            if safetensors_metadata is None:
-                safetensors_metadata = {}
-            elif not isinstance(safetensors_metadata, dict):
-                raise TypeError("safetensors_metadata must be a dictionary.")
-            else:
-                logger.debug(f"Received safetensors_metadata: {safetensors_metadata}")
-                new_safetensors_metadata = {}
-                converted_keys = False
-                for key, value in safetensors_metadata.items():
-                    if not isinstance(key, str) or not isinstance(value, str):
-                        converted_keys = True
-                        try:
-                            new_key = str(key)
-                            new_value = str(value)
-                        except Exception as e:
-                            raise TypeError(f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}")
-                        if new_key in new_safetensors_metadata:
-                            logger.warning(f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
-                        new_safetensors_metadata[new_key] = new_value
-                safetensors_metadata = new_safetensors_metadata
-                if converted_keys:
-                    logger.debug(f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")
-
-            # Format is required to enable Accelerate to load the metadata
-            # otherwise it raises an OSError
-            safetensors_metadata['format'] = "pt"
-
-            # Store the quantization configuration as safetensors metadata
-            from auto_gptq import __version__
-            safetensors_metadata['auto_gptq_version'] = str(__version__)
-            safetensors_metadata['gptq_bits'] = str(self.quantize_config.bits)
-            safetensors_metadata['gptq_group_size'] = str(self.quantize_config.group_size)
-            safetensors_metadata['gptq_desc_act'] = str(self.quantize_config.desc_act)
-            safetensors_metadata['gptq_damp_percent'] = str(self.quantize_config.damp_percent)
-
-            safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
+            model_save_name = model_base_name + ".safetensors"
         else:
             model_save_name = model_base_name + ".bin"
-            torch.save(self.model.state_dict(), join(save_dir, model_save_name))
+
+        # Shard checkpoint
+        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=model_save_name)
+
+        # Clean the folder from a previous save
+        for filename in os.listdir(save_dir):
+            full_filename = join(save_dir, filename)
+
+            # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
+            filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
+            reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
+
+            if (
+                filename.startswith(model_base_name)
+                and isfile(full_filename)
+                and filename not in shards.keys()
+                and reg.fullmatch(filename_no_suffix) is not None
+            ):
+                os.remove(full_filename)
+
+        # Save the model
+        for shard_file, shard in shards.items():
+            if use_safetensors:
+                if safetensors_metadata is None:
+                    safetensors_metadata = {}
+                elif not isinstance(safetensors_metadata, dict):
+                    raise TypeError("safetensors_metadata must be a dictionary.")
+                else:
+                    logger.debug(f"Received safetensors_metadata: {safetensors_metadata}")
+                    new_safetensors_metadata = {}
+                    converted_keys = False
+                    for key, value in safetensors_metadata.items():
+                        if not isinstance(key, str) or not isinstance(value, str):
+                            converted_keys = True
+                            try:
+                                new_key = str(key)
+                                new_value = str(value)
+                            except Exception as e:
+                                raise TypeError(
+                                    f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}")
+                            if new_key in new_safetensors_metadata:
+                                logger.warning(
+                                    f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
+                            new_safetensors_metadata[new_key] = new_value
+                    safetensors_metadata = new_safetensors_metadata
+                    if converted_keys:
+                        logger.debug(
+                            f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")
+
+                # Format is required to enable Accelerate to load the metadata
+                # otherwise it raises an OSError
+                safetensors_metadata['format'] = "pt"
+
+                # Store the quantization configuration as safetensors metadata
+                from auto_gptq import __version__
+                safetensors_metadata['auto_gptq_version'] = str(__version__)
+                safetensors_metadata['gptq_bits'] = str(self.quantize_config.bits)
+                safetensors_metadata['gptq_group_size'] = str(self.quantize_config.group_size)
+                safetensors_metadata['gptq_desc_act'] = str(self.quantize_config.desc_act)
+                safetensors_metadata['gptq_damp_percent'] = str(self.quantize_config.damp_percent)
+
+                safe_save(shard, join(save_dir, shard_file), safetensors_metadata)
+            else:
+                torch.save(shard, join(save_dir, shard_file))
+
+        if index is not None:
+            index_save_name = model_save_name + ".index.json"
+            index_save_path = join(save_dir, index_save_name)
+            # Save the index as well
+            with open(index_save_path, "w", encoding="utf-8") as f:
+                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                f.write(content)
 
         self.model.config.save_pretrained(save_dir)
         self.quantize_config.save_pretrained(save_dir)
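When more than one shard is written, the method above also saves `<model_save_name>.index.json`. A hedged sketch of inspecting that file after a save like the one at the top of this page (the directory and file name assume bits=4, group_size=128, use_safetensors=True, and that sharding actually kicked in):

    import json

    with open("opt-125m-4bit-128g/gptq_model-4bit-128g.safetensors.index.json", encoding="utf-8") as f:
        index = json.load(f)

    # The index produced by shard_checkpoint is expected to map each weight name to the
    # shard file that contains it (plus some size metadata), which is what loaders use
    # to pull tensors from the right shard.
    for name, shard_file in list(index["weight_map"].items())[:5]:
        print(name, "->", shard_file)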
@@ -589,7 +635,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
     def save_pretrained(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None, **kwargs):
         """alias of save_quantized"""
         logger.warning("you are using save_pretrained, which will re-direct to save_quantized.")
-        self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
+        self.save_quantized(save_dir, use_safetensors, safetensors_metadata, **kwargs)
 
     @classmethod
     def from_pretrained(
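Because save_pretrained now forwards **kwargs, the alias picks up the new arguments as well; continuing the sketch from the top of this page (with `model` being the quantized model from there, and the directory a placeholder):

    model.save_pretrained("opt-125m-4bit-128g", use_safetensors=True, max_shard_size="2GB")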