save_quantized: support sharded checkpoints

student686 2023-10-07 13:48:45 +08:00
parent bf70350153
commit fc1184e7bc


@@ -2,6 +2,7 @@ import copy
import json
import warnings
import os
import re
from dataclasses import dataclass, field, fields
from logging import getLogger
from os.path import join, isfile, isdir
@@ -17,7 +18,7 @@ from safetensors.torch import load_file as safe_load
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.utils.hub import PushToHubMixin, cached_file, create_repo, create_commit, CommitOperationAdd
from transformers.utils.generic import ContextManagers
from transformers.modeling_utils import no_init_weights
from transformers.modeling_utils import no_init_weights, shard_checkpoint
from ._const import *
from ._utils import *
@@ -527,20 +528,55 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
repo_type="model",
)
def save_quantized(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None):
def save_quantized(
self,
save_dir: str,
use_safetensors: bool = False,
safetensors_metadata: Optional[Dict[str, str]] = None,
max_shard_size: str = "10GB",
model_base_name: Optional[str] = None
):
"""save quantized model and configs to local disk"""
os.makedirs(save_dir, exist_ok=True)
if not self.quantized:
raise EnvironmentError("can only save quantized model, please execute .quantize first.")
self.model.to(CPU)
if model_base_name is None:
model_base_name = (
self.quantize_config.model_file_base_name or
f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
)
model_base_name = self.quantize_config.model_file_base_name or f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
if use_safetensors:
model_save_name = model_base_name + ".safetensors"
state_dict = self.model.state_dict()
if use_safetensors:
state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
model_save_name = model_base_name + ".safetensors"
else:
model_save_name = model_base_name + ".bin"
# Shard checkpoint
shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=model_save_name)
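# Note: transformers' shard_checkpoint splits the state dict into pieces no larger
# than max_shard_size and returns (shards, index): `shards` maps each shard file
# name (e.g. "<base>-00001-of-00002.safetensors") to its partial state dict, and
# `index` is None when everything fits in a single file, otherwise a dict with
# "metadata" (total size) and "weight_map" (parameter name -> shard file).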
# Clean the folder from a previous save
for filename in os.listdir(save_dir):
full_filename = join(save_dir, filename)
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
if (
filename.startswith(model_base_name)
and isfile(full_filename)
and filename not in shards.keys()
and reg.fullmatch(filename_no_suffix) is not None
):
os.remove(full_filename)
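# e.g. with model_base_name "gptq_model-4bit-128g", a stale
# "gptq_model-4bit-128g-00003-of-00003.safetensors" left over from an earlier,
# larger save is deleted, while config files and unrelated files in save_dir
# (which fail the prefix or regex checks) are kept.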
# Save the model
for shard_file, shard in shards.items():
if use_safetensors:
if safetensors_metadata is None:
safetensors_metadata = {}
elif not isinstance(safetensors_metadata, dict):
@@ -556,13 +592,16 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
new_key = str(key)
new_value = str(value)
except Exception as e:
raise TypeError(f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}")
raise TypeError(
f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}")
if new_key in new_safetensors_metadata:
logger.warning(f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
logger.warning(
f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
new_safetensors_metadata[new_key] = new_value
safetensors_metadata = new_safetensors_metadata
if converted_keys:
logger.debug(f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")
logger.debug(
f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")
# Format is required to enable Accelerate to load the metadata
# otherwise it raises an OSError
@@ -576,10 +615,17 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
safetensors_metadata['gptq_desc_act'] = str(self.quantize_config.desc_act)
safetensors_metadata['gptq_damp_percent'] = str(self.quantize_config.damp_percent)
safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
safe_save(shard, join(save_dir, shard_file), safetensors_metadata)
else:
model_save_name = model_base_name + ".bin"
torch.save(self.model.state_dict(), join(save_dir, model_save_name))
torch.save(shard, join(save_dir, shard_file))
if index is not None:
index_save_name = model_save_name + ".index.json"
index_save_path = join(save_dir, index_save_name)
# Save the index as well
with open(index_save_path, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
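# The index file ("<model_save_name>.index.json") uses the standard transformers
# sharded-checkpoint layout: a "metadata" entry with the total checkpoint size and
# a "weight_map" from each parameter name to the shard file that stores it.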
self.model.config.save_pretrained(save_dir)
self.quantize_config.save_pretrained(save_dir)
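
For context, a minimal usage sketch of the new parameters, following the usual AutoGPTQ quantize-then-save flow; the base model name, calibration text, output directory, and shard size below are placeholders, not part of this commit:

    from transformers import AutoTokenizer
    from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

    pretrained = "facebook/opt-125m"  # placeholder base model
    tokenizer = AutoTokenizer.from_pretrained(pretrained, use_fast=True)
    # a single placeholder calibration example; real runs use a proper calibration set
    examples = [tokenizer("auto-gptq adds sharded checkpoint saving.")]

    quantize_config = BaseQuantizeConfig(bits=4, group_size=128)
    model = AutoGPTQForCausalLM.from_pretrained(pretrained, quantize_config)
    model.quantize(examples)

    # New in this commit: split the quantized checkpoint into shards of a chosen size.
    model.save_quantized("opt-125m-4bit-gptq", use_safetensors=True, max_shard_size="2GB")
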
@@ -589,7 +635,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
def save_pretrained(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None, **kwargs):
"""alias of save_quantized"""
logger.warning("you are using save_pretrained, which will re-direct to save_quantized.")
self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
self.save_quantized(save_dir, use_safetensors, safetensors_metadata, **kwargs)
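# Forwarding **kwargs means save_pretrained can now also pass max_shard_size
# and model_base_name straight through to save_quantized.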
@classmethod
def from_pretrained(