import json
import os
from dataclasses import dataclass, field, fields
from logging import getLogger
from os.path import join
from typing import Dict, List, Optional, Union

import accelerate
import torch
import torch.nn as nn
import transformers
from accelerate.hooks import remove_hook_from_module, remove_hook_from_submodules
from safetensors.torch import save_file as safe_save
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.utils.hub import PushToHubMixin

from ._const import *
from ._utils import *
from ..quantization import GPTQ

logger = getLogger(__name__)


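# Quantization hyper-parameters for GPTQ, roughly: `bits` is the target weight precision,
# `damp_percent` is the fraction used to dampen the Hessian diagonal before inversion,
# `desc_act` enables activation-order ("act-order") quantization, and `group_size` sets how
# many input features share one set of quantization parameters (-1 disables grouping).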
@dataclass
class BaseQuantizeConfig(PushToHubMixin):
    bits: int = field(default=4, metadata={"choices": [2, 3, 4, 8]})
    damp_percent: float = field(default=0.01)
    desc_act: bool = field(default=True)
    group_size: int = field(default=-1)

    def __post_init__(self):
        fields_info = fields(self)

        if self.bits not in fields_info[0].metadata["choices"]:
            raise ValueError(f"only support quantize to {fields_info[0].metadata['choices']} bits.")
        if not (0 < self.damp_percent < 1):
            raise ValueError("damp_percent must be between 0 and 1.")
        if self.group_size != -1 and self.group_size <= 0:
            raise ValueError("unless equal to -1, group_size must be greater than 0.")

    def save_pretrained(self, save_dir: str, **kwargs):
        with open(join(save_dir, "quantize_config.json"), "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def from_pretrained(cls, save_dir: str):
        with open(join(save_dir, "quantize_config.json"), "r", encoding="utf-8") as f:
            return cls(**json.load(f))

    def to_dict(self):
        return {
            "bits": self.bits,
            "damp_percent": self.damp_percent,
            "desc_act": self.desc_act,
            "group_size": self.group_size
        }


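# Base wrapper for GPTQ causal LMs. Concrete subclasses are expected to fill in the class
# attributes below: `layer_type` (the decoder layer class name, used as accelerate's
# no-split module), `layers_block_name` (dotted path to the ModuleList of decoder layers),
# `outside_layer_modules` (modules outside that list, e.g. embeddings and final norm, which
# are not quantized) and `inside_layer_modules` (groups of Linear submodule names inside
# each decoder layer to quantize, listed in execution order).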
class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
    layer_type: str = None
    layers_block_name: str = None
    outside_layer_modules: List[str] = None
    inside_layer_modules: List[List[str]] = None
    lm_head_name: str = "lm_head"

    def __init__(self, model: PreTrainedModel, quantized: bool, quantize_config: BaseQuantizeConfig):
        super().__init__()

        self.model = model
        self.model_type = self.model.config.model_type
        self._quantized = quantized
        self.quantize_config = quantize_config
        self.config = self.model.config

    @property
    def quantized(self):
        return self._quantized

    @property
    def hf_device_map(self):
        return getattr(self.model, "hf_device_map", None)

    @staticmethod
    def _resize_attention_mask(attention_mask: List[torch.LongTensor]):
        return attention_mask

    @staticmethod
    def _resize_position_ids(position_ids: List[torch.LongTensor]):
        return position_ids

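    # quantize() implements sequential, layer-by-layer GPTQ calibration:
    #   1. run the calibration examples once to capture the first decoder layer's inputs;
    #   2. for each decoder layer, collect statistics via forward hooks, quantize its Linear
    #      submodules with GPTQ, then re-run the layer to produce inputs for the next layer;
    #   3. pack the quantized weights into low-bit storage with `pack_model`.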
    @torch.no_grad()
    def quantize(
        self,
        examples: List[Dict[str, torch.LongTensor]],
        use_triton: bool = False,
        autotune_warmup_after_quantized: bool = False,
        cache_examples_on_gpu: bool = True
    ):
        if self.quantized:
            raise EnvironmentError("can't execute quantize because the model is already quantized.")

        device_map = self.hf_device_map
        if device_map:
            for name, device in device_map.items():
                if device == "cpu":
                    module = get_module_by_name(self.model, name)
                    remove_hook_from_module(module, recurse=True)
                    accelerate.cpu_offload_with_hook(module, CUDA_0)

        layer_inputs = []
        attention_masks = []
        position_ids = []
        layer_input_kwargs = []
        layer_outputs = []

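        # LayerHijacker temporarily wraps the first decoder layer: its forward() records the
        # hidden states, attention mask, position ids and any extra kwargs for every example,
        # then raises ValueError so the rest of the (still unquantized) forward pass is skipped.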
        class LayerHijacker(nn.Module):
            """hijack layer's forward pass to cache data"""

            def __init__(self, m, device):
                super().__init__()
                self.module = m
                self.data_device = device if cache_examples_on_gpu else CPU

            def forward(self, inp=None, **kwargs):
                if inp is None:  # some models pass all arguments as keyword arguments in the forward call
                    for kwarg_name in ["hidden_states"]:
                        if kwarg_name in kwargs:
                            inp = kwargs[kwarg_name]
                            break
                bsz = inp.size(0)
                for i in range(bsz):
                    layer_inputs.append(move_to_device(inp[i].unsqueeze(0), self.data_device))
                    attention_masks.append(kwargs["attention_mask"][i].to(self.data_device))
                    if (pos_ids := kwargs.get("position_ids", None)) is not None:
                        position_ids.append(move_to_device(pos_ids[i].unsqueeze(0), self.data_device))
                    one_kwargs = dict()
                    for k, v in kwargs.items():  # make sure other arguments are also captured
                        if k not in ["hidden_states", "attention_mask", "position_ids"]:
                            if isinstance(v, torch.Tensor):
                                one_kwargs[k] = move_to_device(v[i].unsqueeze(0), self.data_device)
                            else:
                                one_kwargs[k] = v
                    layer_input_kwargs.append(one_kwargs)
                raise ValueError  # deliberately abort the forward pass once inputs are cached

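        # Disable the KV cache and push the calibration examples through the model once; the
        # hijacked first layer caches its inputs and aborts each forward pass, and the expected
        # ValueError is swallowed below.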
        forward_pass_use_cache = self.model.config.use_cache
        self.model.config.use_cache = False

        num_examples = len(examples)
        layers = get_module_by_name(self.model, self.layers_block_name)

        force_layer_back_to_cpu = False
        if get_device(layers[0]) == CPU:
            layers[0] = layers[0].to(CUDA_0)
            force_layer_back_to_cpu = True

        cur_layer_device = get_device(layers[0])
        ori_outside_layer_module_devices = {}
        for module_name in self.outside_layer_modules:
            module = get_module_by_name(self.model, module_name)

            if module is None:
                continue

            ori_outside_layer_module_devices[module_name] = get_device(module)
            if module is not None:
                move_to_device(module, cur_layer_device)

        # get inputs for first layer
        layers[0] = LayerHijacker(layers[0], cur_layer_device)
        for example in examples:
            for k, v in example.items():
                if len(v.shape) == 1:
                    v = v.unsqueeze(0)
                example[k] = move_to_device(v, cur_layer_device)
            try:
                self.model(**example)
            except ValueError:
                pass
        layers[0] = layers[0].module

        move_to_device(layers[0], CPU if force_layer_back_to_cpu else cur_layer_device)
        for module_name in self.outside_layer_modules:
            module = get_module_by_name(self.model, module_name)
            if module is not None:
                move_to_device(module, ori_outside_layer_module_devices[module_name])

        torch.cuda.empty_cache()

        # resize attention mask and position ids for some special models
        attention_masks = self._resize_attention_mask(attention_masks)
        position_ids = self._resize_position_ids(position_ids)

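        # Main loop: quantize the decoder layers one at a time. `quantizers` maps each quantized
        # submodule's full name to (quantizer, scale, zero, g_idx), which `pack_model` consumes.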
        quantizers = {}
        for i in range(len(layers)):
            logger.info(f"Start quantizing layer {i + 1}/{len(layers)}")
            layer = layers[i]
            force_layer_back_to_cpu = False
            if get_device(layer) == CPU:
                move_to_device(layer, CUDA_0)
                force_layer_back_to_cpu = True
            cur_layer_device = get_device(layer)

            full = find_layers(layer)
            for names in self.inside_layer_modules:
                subset = {n: full[n] for n in names}
                gptq = {}
                for name in subset:
                    gptq[name] = GPTQ(subset[name])
                    gptq[name].quantizer.configure(
                        self.quantize_config.bits,
                        perchannel=True,
                        sym=True,
                        mse=False
                    )

                def add_batch(name):
                    def tmp(_, inp, out):
                        gptq[name].add_batch(inp[0].data, out.data)

                    return tmp

                handles = []
                for name in subset:
                    handles.append(subset[name].register_forward_hook(add_batch(name)))
                for j in range(num_examples):
                    layer_input = move_to_device(layer_inputs[j], cur_layer_device)
                    layer_attention_mask = move_to_device(attention_masks[j], cur_layer_device)
                    additional_layer_inputs = {
                        "attention_mask": layer_attention_mask
                    }
                    if (
                        layer_position_ids := None if not position_ids
                        else move_to_device(position_ids[j], cur_layer_device)
                    ) is not None:
                        additional_layer_inputs["position_ids"] = layer_position_ids
                    for k, v in layer_input_kwargs[j].items():
                        if isinstance(v, torch.Tensor):
                            additional_layer_inputs[k] = move_to_device(v, cur_layer_device)
                        else:
                            additional_layer_inputs[k] = v
                    layer(layer_input, **additional_layer_inputs)
                for h in handles:
                    h.remove()

                for name in subset:
                    logger.info(f'Quantizing {name} in layer {i + 1}/{len(layers)}...')
                    scale, zero, g_idx = gptq[name].fasterquant(
                        percdamp=self.quantize_config.damp_percent,
                        groupsize=self.quantize_config.group_size,
                        actorder=self.quantize_config.desc_act
                    )
                    quantizers[f'{self.layers_block_name}.{i}.{name}'] = (
                        gptq[name].quantizer.to(CPU if force_layer_back_to_cpu else cur_layer_device),
                        move_to_device(scale, CPU if force_layer_back_to_cpu else cur_layer_device),
                        move_to_device(zero, CPU if force_layer_back_to_cpu else cur_layer_device),
                        move_to_device(g_idx, CPU if force_layer_back_to_cpu else cur_layer_device)
                    )
                    gptq[name].free()

            for j in range(num_examples):
                layer_input = move_to_device(layer_inputs[j], cur_layer_device)
                layer_attention_mask = move_to_device(attention_masks[j], cur_layer_device)
                additional_layer_inputs = {
                    "attention_mask": layer_attention_mask
                }
                if (
                    layer_position_ids := None if not position_ids
                    else move_to_device(position_ids[j], cur_layer_device)
                ) is not None:
                    additional_layer_inputs["position_ids"] = layer_position_ids
                for k, v in layer_input_kwargs[j].items():
                    if isinstance(v, torch.Tensor):
                        additional_layer_inputs[k] = move_to_device(v, cur_layer_device)
                    else:
                        additional_layer_inputs[k] = v
                layer_output = move_to_device(
                    layer(layer_input, **additional_layer_inputs)[0],
                    cur_layer_device if cache_examples_on_gpu else CPU
                )
                layer_outputs.append(layer_output)

            layers[i] = move_to_device(layer, CPU if force_layer_back_to_cpu else cur_layer_device)
            del layer
            del gptq
            del layer_inputs
            layer_inputs, layer_outputs = layer_outputs, []
            torch.cuda.empty_cache()

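        # Pack the collected quantizers into low-bit QuantLinear modules, replacing the
        # original Linear layers in place.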
        pack_model(
            model=self.model,
            quantizers=quantizers,
            bits=self.quantize_config.bits,
            group_size=self.quantize_config.group_size,
            use_triton=use_triton,
            autotune_warmup=autotune_warmup_after_quantized,
            force_layer_back_to_cpu=force_layer_back_to_cpu
        )
        if device_map:
            self.model = remove_hook_from_module(self.model, recurse=True)
            self.model = accelerate.dispatch_model(self.model, device_map, offload_buffers=True)
        self.model.config.use_cache = forward_pass_use_cache

        self._quantized = True

        torch.cuda.empty_cache()

    @property
    def device(self):
        return self.model.device

    def to(self, device: Union[str, torch.device]):
        self.model.to(device)

    def forward(self, **kwargs):
        return self.model(**kwargs)

    def generate(self, **kwargs):
        """shortcut for model.generate"""
        with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type):
            return self.model.generate(**kwargs)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """shortcut for model.prepare_inputs_for_generation"""
        return self.model.prepare_inputs_for_generation(*args, **kwargs)

    def save_quantized(self, save_dir: str, use_safetensors: bool = False):
        """save quantized model and configs to local disk"""
        os.makedirs(save_dir, exist_ok=True)

        if not self.quantized:
            raise EnvironmentError("can only save quantized model, please execute .quantize first.")

        self.model.to(CPU)

        model_save_name = f"gptq_model-{self.quantize_config.bits}bit"
        if use_safetensors:
            model_save_name += ".safetensors"
            state_dict = self.model.state_dict()
            state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
            safe_save(state_dict, join(save_dir, model_save_name))
        else:
            model_save_name += ".bin"
            torch.save(self.model.state_dict(), join(save_dir, model_save_name))

        self.model.config.save_pretrained(save_dir)
        self.quantize_config.save_pretrained(save_dir)

    def save_pretrained(self, save_dir: str, use_safetensors: bool = False, **kwargs):
        """alias of save_quantized"""
        logger.warning("you are using save_pretrained, which will redirect to save_quantized.")
        self.save_quantized(save_dir, use_safetensors)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        quantize_config: BaseQuantizeConfig,
        max_memory: Optional[dict] = None,
        **model_init_kwargs
    ):
        """load un-quantized pretrained model to cpu"""

        if not torch.cuda.is_available():
            raise EnvironmentError("Loading a pretrained model for quantization requires CUDA to be available.")

        def skip(*args, **kwargs):
            pass

        torch.nn.init.kaiming_uniform_ = skip
        torch.nn.init.uniform_ = skip
        torch.nn.init.normal_ = skip

        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
        if config.model_type not in SUPPORTED_MODELS:
            raise TypeError(f"{config.model_type} isn't supported yet.")

        # enforce some values regardless of what the user specified
        model_init_kwargs["torch_dtype"] = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        model_init_kwargs["trust_remote_code"] = True
        if max_memory:
            if "disk" in max_memory:
                raise NotImplementedError("disk offload is not supported yet.")
            with accelerate.init_empty_weights():
                model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
            model.tie_weights()

            max_memory = accelerate.utils.get_balanced_memory(
                model,
                max_memory=max_memory,
                no_split_module_classes=[cls.layer_type],
                dtype=model_init_kwargs["torch_dtype"],
                low_zero=False
            )
            model_init_kwargs["device_map"] = accelerate.infer_auto_device_map(
                model,
                max_memory=max_memory,
                no_split_module_classes=[cls.layer_type],
                dtype=model_init_kwargs["torch_dtype"]
            )
            model_init_kwargs["low_cpu_mem_usage"] = True

            del model
        else:
            model_init_kwargs["device_map"] = None
            model_init_kwargs["low_cpu_mem_usage"] = False

        torch.cuda.empty_cache()

        model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **model_init_kwargs)
        model_config = model.config.to_dict()
        seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
        if any([k in model_config for k in seq_len_keys]):
            for key in seq_len_keys:
                if key in model_config:
                    model.seqlen = model_config[key]
                    break
        else:
            logger.warning("can't get model's sequence length from model config, will set to 4096.")
            model.seqlen = 4096
        model.eval()

        return cls(model, False, quantize_config)

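    # from_quantized() rebuilds the model skeleton on the meta device, swaps the target Linear
    # modules for quantized ones via `make_quant`, then loads the saved checkpoint and
    # dispatches it onto the requested device(s) with accelerate.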
    @classmethod
    def from_quantized(
        cls,
        save_dir: str,
        device: str = "cpu",
        use_safetensors: bool = False,
        use_triton: bool = False,
        max_memory: Optional[dict] = None,
        device_map: Optional[str] = None
    ):
        """load quantized model from local disk"""
        if use_triton:
            from ..nn_modules.qlinear_triton import autotune_warmup_linear

            logger.warning("use_triton will force moving the whole model to GPU, make sure you have enough VRAM.")
            device = "cuda:0"

        config = AutoConfig.from_pretrained(save_dir, trust_remote_code=True)
        if config.model_type not in SUPPORTED_MODELS:
            raise TypeError(f"{config.model_type} isn't supported yet.")

        quantize_config = BaseQuantizeConfig.from_pretrained(save_dir)

        model_save_name = join(save_dir, f"gptq_model-{quantize_config.bits}bit")
        if use_safetensors:
            model_save_name += ".safetensors"
        else:
            model_save_name += ".bin"

        def skip(*args, **kwargs):
            pass

        torch.nn.init.kaiming_uniform_ = skip
        torch.nn.init.uniform_ = skip
        torch.nn.init.normal_ = skip

        transformers.modeling_utils._init_weights = False
        with accelerate.init_empty_weights():
            torch.set_default_dtype(torch.half)
            model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
            torch.set_default_dtype(torch.float)

        layers = find_layers(model)
        ignore_layers = [cls.lm_head_name] + cls.outside_layer_modules
        for name in list(layers.keys()):
            if any([name.startswith(ignore_layer) for ignore_layer in ignore_layers]):
                logger.info(f"{name} has not been quantized, will be ignored when make_quant.")
                del layers[name]

        with accelerate.init_empty_weights():
            make_quant(model, layers, quantize_config.bits, quantize_config.group_size, use_triton=use_triton)
        model.tie_weights()

        if max_memory and not device_map:
            device_map = "auto"
        if not max_memory and not device_map:
            device_map = {"": device}

        model = accelerate.load_checkpoint_and_dispatch(
            model, model_save_name, device_map, max_memory, no_split_module_classes=[cls.layer_type]
        )

        model_config = model.config.to_dict()
        seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
        if any([k in model_config for k in seq_len_keys]):
            for key in seq_len_keys:
                if key in model_config:
                    model.seqlen = model_config[key]
                    break
        else:
            logger.warning("can't get model's sequence length from model config, will set to 4096.")
            model.seqlen = 4096

        model.eval()

        if use_triton:
            autotune_warmup_linear(model, seqlen=model.seqlen)

        return cls(model, True, quantize_config)


__all__ = ["BaseGPTQForCausalLM", "BaseQuantizeConfig"]
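# Typical end-to-end usage, as a minimal commented sketch. `SomeGPTQForCausalLM` stands in
# for a hypothetical concrete subclass of BaseGPTQForCausalLM; the model path, tokenizer and
# calibration text are placeholders:
#
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("model_name_or_path")
#     enc = tokenizer("auto-gptq is an easy-to-use quantization package.", return_tensors="pt")
#     examples = [{"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}]
#
#     quantize_config = BaseQuantizeConfig(bits=4, group_size=128)
#     model = SomeGPTQForCausalLM.from_pretrained("model_name_or_path", quantize_config)
#     model.quantize(examples)
#     model.save_quantized("quantized_model_dir", use_safetensors=True)
#
#     model = SomeGPTQForCausalLM.from_quantized("quantized_model_dir", device="cuda:0", use_safetensors=True)
#     prompt = tokenizer("Hello", return_tensors="pt").to(model.device)
#     print(tokenizer.decode(model.generate(**prompt)[0]))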