diff --git a/auto_gptq/modeling/_base.py b/auto_gptq/modeling/_base.py
index 84e92bf..fd86767 100644
--- a/auto_gptq/modeling/_base.py
+++ b/auto_gptq/modeling/_base.py
@@ -2,6 +2,7 @@ import copy
 import json
 import warnings
 import os
+import re
 from dataclasses import dataclass, field, fields
 from logging import getLogger
 from os.path import join, isfile, isdir
@@ -17,7 +18,7 @@ from safetensors.torch import load_file as safe_load
 from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
 from transformers.utils.hub import PushToHubMixin, cached_file, create_repo, create_commit, CommitOperationAdd
 from transformers.utils.generic import ContextManagers
-from transformers.modeling_utils import no_init_weights
+from transformers.modeling_utils import no_init_weights, shard_checkpoint
 
 from ._const import *
 from ._utils import *
@@ -467,6 +468,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         private: Optional[bool] = None,
         token: Optional[Union[bool, str]] = None,
         create_pr: Optional[bool] = False,
+        max_shard_size: str = "10GB",
+        model_base_name: Optional[str] = None
     ) -> str:
         """
         Upload the model to the Hugging Face Hub.
@@ -504,7 +507,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
         if save_dir is not None:
             logger.info(f"Saving model to {save_dir}")
-            self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
+            self.save_quantized(save_dir, use_safetensors, safetensors_metadata, max_shard_size, model_base_name)
 
         repo_url = create_repo(
             repo_id=repo_id, token=token, private=private, exist_ok=True, repo_type="model"
         )
@@ -527,59 +530,104 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
-    def save_quantized(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None):
+    def save_quantized(
+        self,
+        save_dir: str,
+        use_safetensors: bool = False,
+        safetensors_metadata: Optional[Dict[str, str]] = None,
+        max_shard_size: str = "10GB",
+        model_base_name: Optional[str] = None
+    ):
         """save quantized model and configs to local disk"""
         os.makedirs(save_dir, exist_ok=True)
 
         if not self.quantized:
             raise EnvironmentError("can only save quantized model, please execute .quantize first.")
 
-        self.model.to(CPU)
+        if model_base_name is None:
+            model_base_name = (
+                self.quantize_config.model_file_base_name or
+                f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
+            )
 
-        model_base_name = self.quantize_config.model_file_base_name or f"gptq_model-{self.quantize_config.bits}bit-{self.quantize_config.group_size}g"
+        state_dict = self.model.state_dict()
+
         if use_safetensors:
-            model_save_name = model_base_name + ".safetensors"
-            state_dict = self.model.state_dict()
             state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
-            if safetensors_metadata is None:
-                safetensors_metadata = {}
-            elif not isinstance(safetensors_metadata, dict):
-                raise TypeError("safetensors_metadata must be a dictionary.")
-            else:
-                logger.debug(f"Received safetensors_metadata: {safetensors_metadata}")
-                new_safetensors_metadata = {}
-                converted_keys = False
-                for key, value in safetensors_metadata.items():
-                    if not isinstance(key, str) or not isinstance(value, str):
-                        converted_keys = True
-                        try:
-                            new_key = str(key)
-                            new_value = str(value)
-                        except Exception as e:
-                            raise TypeError(f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}")
-                        if new_key in new_safetensors_metadata:
-                            logger.warning(f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
-                        new_safetensors_metadata[new_key] = new_value
-                safetensors_metadata = new_safetensors_metadata
-                if converted_keys:
-                    logger.debug(f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")
-
-            # Format is required to enable Accelerate to load the metadata
-            # otherwise it raises an OSError
-            safetensors_metadata['format'] = "pt"
-
-            # Store the quantization configuration as safetensors metadata
-            from auto_gptq import __version__
-            safetensors_metadata['auto_gptq_version'] = str(__version__)
-            safetensors_metadata['gptq_bits'] = str(self.quantize_config.bits)
-            safetensors_metadata['gptq_group_size'] = str(self.quantize_config.group_size)
-            safetensors_metadata['gptq_desc_act'] = str(self.quantize_config.desc_act)
-            safetensors_metadata['gptq_damp_percent'] = str(self.quantize_config.damp_percent)
-
-            safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
+            model_save_name = model_base_name + ".safetensors"
         else:
             model_save_name = model_base_name + ".bin"
-            torch.save(self.model.state_dict(), join(save_dir, model_save_name))
+
+        # Shard checkpoint
+        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=model_save_name)
+
+        # Clean the folder from a previous save
+        for filename in os.listdir(save_dir):
+            full_filename = join(save_dir, filename)
+
+            # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
+            filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
+            reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
+
+            if (
+                filename.startswith(model_base_name)
+                and isfile(full_filename)
+                and filename not in shards.keys()
+                and reg.fullmatch(filename_no_suffix) is not None
+            ):
+                os.remove(full_filename)
+
+        # Save the model
+        for shard_file, shard in shards.items():
+            if use_safetensors:
+                if safetensors_metadata is None:
+                    safetensors_metadata = {}
+                elif not isinstance(safetensors_metadata, dict):
+                    raise TypeError("safetensors_metadata must be a dictionary.")
+                else:
+                    logger.debug(f"Received safetensors_metadata: {safetensors_metadata}")
+                    new_safetensors_metadata = {}
+                    converted_keys = False
+                    for key, value in safetensors_metadata.items():
+                        if not isinstance(key, str) or not isinstance(value, str):
+                            converted_keys = True
+                            try:
+                                new_key = str(key)
+                                new_value = str(value)
+                            except Exception as e:
+                                raise TypeError(
+                                    f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}")
+                            if new_key in new_safetensors_metadata:
+                                logger.warning(
+                                    f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
+                            new_safetensors_metadata[new_key] = new_value
+                    safetensors_metadata = new_safetensors_metadata
+                    if converted_keys:
+                        logger.debug(
+                            f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")
+
+                # Format is required to enable Accelerate to load the metadata
+                # otherwise it raises an OSError
+                safetensors_metadata['format'] = "pt"
+
+                # Store the quantization configuration as safetensors metadata
+                from auto_gptq import __version__
+                safetensors_metadata['auto_gptq_version'] = str(__version__)
+                safetensors_metadata['gptq_bits'] = str(self.quantize_config.bits)
+                safetensors_metadata['gptq_group_size'] = str(self.quantize_config.group_size)
+                safetensors_metadata['gptq_desc_act'] = str(self.quantize_config.desc_act)
+                safetensors_metadata['gptq_damp_percent'] = str(self.quantize_config.damp_percent)
+
+                safe_save(shard, join(save_dir, shard_file), safetensors_metadata)
+            else:
+                torch.save(shard, join(save_dir, shard_file))
+
+        if index is not None:
+            index_save_name = model_save_name + ".index.json"
+            index_save_path = join(save_dir, index_save_name)
+            # Save the index as well
+            with open(index_save_path, "w", encoding="utf-8") as f:
+                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                f.write(content)
 
         self.model.config.save_pretrained(save_dir)
         self.quantize_config.save_pretrained(save_dir)
@@ -589,7 +637,7 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
     def save_pretrained(self, save_dir: str, use_safetensors: bool = False, safetensors_metadata: Optional[Dict[str, str]] = None, **kwargs):
         """alias of save_quantized"""
         logger.warning("you are using save_pretrained, which will re-direct to save_quantized.")
-        self.save_quantized(save_dir, use_safetensors, safetensors_metadata)
+        self.save_quantized(save_dir, use_safetensors, safetensors_metadata, **kwargs)
 
     @classmethod
     def from_pretrained(
diff --git a/setup.py b/setup.py
index a15136c..14d424d 100644
--- a/setup.py
+++ b/setup.py
@@ -80,7 +80,7 @@ requirements = [
     "gekko",
     "torch>=1.13.0",
     "safetensors",
-    "transformers>=4.31.0",
+    "transformers>=4.34.0",
     "peft",
     "tqdm",
 ]
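
For reference, a minimal usage sketch of the new save path introduced by this diff. The model identifiers, output directory, and size values below are placeholders chosen for illustration; the only APIs assumed beyond this diff are the existing AutoGPTQ entry points (AutoGPTQForCausalLM.from_pretrained, BaseQuantizeConfig, quantize) and the transformers size-string convention used by max_shard_size (e.g. "2GB", "10GB").

    from transformers import AutoTokenizer
    from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

    # Placeholder names, for illustration only.
    pretrained_model_dir = "facebook/opt-125m"
    quantized_model_dir = "opt-125m-4bit-gptq"

    quantize_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False)
    model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_dir, quantize_config)

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir)
    examples = [tokenizer("auto-gptq is an easy-to-use model quantization library.")]
    model.quantize(examples)

    # New in this diff: if the state dict exceeds max_shard_size, it is split into
    # numbered shards plus a "<base_name>.safetensors.index.json" (or ".bin.index.json")
    # file via transformers' shard_checkpoint; smaller models still produce one file.
    model.save_quantized(
        quantized_model_dir,
        use_safetensors=True,
        max_shard_size="2GB",                      # size string, same convention as transformers
        model_base_name="gptq_model-4bit-128g",    # optional override of the file base name
    )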