# AutoGPTQ/auto_gptq/modeling/_base.py

import json
import os
from dataclasses import dataclass, field, fields
from logging import getLogger
from os.path import join
from typing import Dict, List, Union

import torch
import torch.nn as nn
import transformers
from safetensors.torch import load_file as safe_load, save_file as safe_save
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel

from ._const import *
from ._utils import *
from ..quantization import GPTQ, Quantizer

logger = getLogger(__name__)
@dataclass
class BaseQuantizeConfig:
    bits: int = field(default=4, metadata={"choices": [2, 3, 4, 8]})
    damp_percent: float = field(default=0.01)
    desc_act: bool = field(default=True)
    group_size: int = field(default=-1)

    def __post_init__(self):
        fields_info = fields(self)

        if self.bits not in fields_info[0].metadata["choices"]:
            raise ValueError(f"only support quantize to {fields_info[0].metadata['choices']} bits.")
        if not (0 < self.damp_percent < 1):
            raise ValueError("damp_percent must be between 0 and 1.")
        if self.group_size != -1 and self.group_size <= 0:
            raise ValueError("unless equal to -1, group_size must be greater than 0.")

    def save_pretrained(self, save_dir: str):
        with open(join(save_dir, "quantize_config.json"), "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def from_pretrained(cls, save_dir: str):
        with open(join(save_dir, "quantize_config.json"), "r", encoding="utf-8") as f:
            return cls(**json.load(f))

    def to_dict(self):
        return {
            "bits": self.bits,
            "damp_percent": self.damp_percent,
            "desc_act": self.desc_act,
            "group_size": self.group_size
        }
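
# A minimal usage sketch for BaseQuantizeConfig (illustrative only; "save_dir" is a
# hypothetical directory that must already exist, since save_pretrained does not create it):
#
#     quantize_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False)
#     quantize_config.save_pretrained(save_dir)                  # writes save_dir/quantize_config.json
#     reloaded = BaseQuantizeConfig.from_pretrained(save_dir)    # reads it back
#     assert reloaded.to_dict() == quantize_config.to_dict()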

class BaseGPTQForCausalLM(nn.Module):
    # these attributes describe one model architecture and are overridden by concrete subclasses
    layers_block_name: str = None
    outside_layer_modules: List[str] = None
    inside_layer_modules: List[List[str]] = None
    lm_head_name: str = "lm_head"

    def __init__(self, model: PreTrainedModel, quantized: bool, quantize_config: BaseQuantizeConfig):
        super().__init__()

        self.model = model
        self.model_type = self.model.config.model_type
        self._quantized = quantized
        self.quantize_config = quantize_config
        self.config = self.model.config

    @property
    def quantized(self):
        return self._quantized

    def _move_outside_layer_modules(self, device):
        for module_name in self.outside_layer_modules:
            module = get_module_by_name(self.model, module_name)
            if module is not None:
                module.to(device)

    @staticmethod
    def _resize_attention_mask(attention_mask: List[torch.LongTensor]):
        return attention_mask

    @staticmethod
    def _resize_position_ids(position_ids: List[torch.LongTensor]):
        return position_ids
    @torch.no_grad()
    def quantize(self, examples: List[Dict[str, torch.LongTensor]], use_triton: bool = False):
        if self.quantized:
            raise EnvironmentError("can't execute quantize because the model is already quantized.")

        layer_inputs = []
        attention_masks = []
        position_ids = []
        layer_outputs = []
        layer_input_kwargs = []

        class LayerHijacker(nn.Module):
            """hijack layer's forward pass to cache data"""

            def __init__(self, m):
                super().__init__()
                self.module = m

            def forward(self, inp=None, **kwargs):
                if inp is None:  # some models use all key-value arguments in forward pass call
                    for kwarg_name in ["hidden_states"]:
                        if kwarg_name in kwargs:
                            inp = kwargs[kwarg_name]
                            break
                bsz = inp.size(0)
                for i in range(bsz):
                    layer_inputs.append(inp[i].unsqueeze(0).to(CPU))
                    attention_masks.append(kwargs["attention_mask"][i].to(CPU))
                    if (pos_ids := kwargs.get("position_ids", None)) is not None:
                        position_ids.append(pos_ids[i].unsqueeze(0).to(CPU))
                    one_kwargs = dict()
                    for k, v in kwargs.items():  # make sure other arguments are also captured
                        if k not in ["hidden_states", "attention_mask", "position_ids"]:
                            if isinstance(v, torch.Tensor):
                                one_kwargs[k] = v[i].unsqueeze(0).to(CPU)
                            else:
                                one_kwargs[k] = v
                    layer_input_kwargs.append(one_kwargs)
                # interrupt the outer forward pass once the first layer's inputs are cached
                raise ValueError

        forward_pass_use_cache = self.model.config.use_cache
        self.model.config.use_cache = False

        num_examples = len(examples)
        layers = get_module_by_name(self.model, self.layers_block_name)

        layers[0] = layers[0].to(CUDA)
        self._move_outside_layer_modules(CUDA)

        # get inputs for the first layer
        layers[0] = LayerHijacker(layers[0])
        for example in examples:
            for k, v in example.items():
                if len(v.shape) == 1:
                    v = v.unsqueeze(0)
                example[k] = v.to(CUDA)
            try:
                self.model(**example)
            except ValueError:
                pass
        layers[0] = layers[0].module

        layers[0] = layers[0].cpu()
        self._move_outside_layer_modules(CPU)
        torch.cuda.empty_cache()

        # resize attention mask and position ids for some special models
        attention_masks = self._resize_attention_mask(attention_masks)
        position_ids = self._resize_position_ids(position_ids)

        quantizers = {}
        for i in range(len(layers)):
            logger.info(f"Start quantizing layer {i + 1}/{len(layers)}")
            layer = layers[i].to(CUDA)

            full = find_layers(layer)
            for names in self.inside_layer_modules:
                subset = {n: full[n] for n in names}
                gptq = {}
                for name in subset:
                    gptq[name] = GPTQ(subset[name])
                    gptq[name].quantizer = Quantizer()
                    gptq[name].quantizer.configure(
                        self.quantize_config.bits,
                        perchannel=True,
                        sym=True,
                        mse=False
                    )

                def add_batch(name):
                    def tmp(_, inp, out):
                        gptq[name].add_batch(inp[0].data, out.data)
                    return tmp

                handles = []
                for name in subset:
                    handles.append(subset[name].register_forward_hook(add_batch(name)))
                for j in range(num_examples):
                    layer_input = layer_inputs[j].to(CUDA)
                    layer_attention_mask = attention_masks[j].to(CUDA)
                    additional_layer_inputs = {
                        "attention_mask": layer_attention_mask
                    }
                    if (layer_position_ids := None if not position_ids else position_ids[j].to(CUDA)) is not None:
                        additional_layer_inputs["position_ids"] = layer_position_ids
                    for k, v in layer_input_kwargs[j].items():
                        if isinstance(v, torch.Tensor):
                            additional_layer_inputs[k] = v.to(CUDA)
                        else:
                            additional_layer_inputs[k] = v
                    # run the layer only to feed the forward hooks; the output itself is discarded
                    layer(layer_input, **additional_layer_inputs)[0][0].cpu()
                for h in handles:
                    h.remove()

                for name in subset:
                    logger.info(f"Quantizing {name} in layer {i + 1}/{len(layers)}...")
                    scale, zero, g_idx = gptq[name].fasterquant(
                        percdamp=self.quantize_config.damp_percent,
                        groupsize=self.quantize_config.group_size,
                        actorder=self.quantize_config.desc_act
                    )
                    quantizers[f"{self.layers_block_name}.{i}.{name}"] = (
                        gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu()
                    )
                    gptq[name].free()

            # the outputs of the now-quantized layer become the inputs of the next layer
            for j in range(num_examples):
                layer_input = layer_inputs[j].to(CUDA)
                layer_attention_mask = attention_masks[j].to(CUDA)
                additional_layer_inputs = {
                    "attention_mask": layer_attention_mask
                }
                if (layer_position_ids := None if not position_ids else position_ids[j].to(CUDA)) is not None:
                    additional_layer_inputs["position_ids"] = layer_position_ids
                for k, v in layer_input_kwargs[j].items():
                    if isinstance(v, torch.Tensor):
                        additional_layer_inputs[k] = v.to(CUDA)
                    else:
                        additional_layer_inputs[k] = v
                layer_output = layer(layer_input, **additional_layer_inputs)[0][0].cpu()
                layer_outputs.append(layer_output.unsqueeze(0))

            layers[i] = layer.to(CPU)
            del layer
            del gptq
            torch.cuda.empty_cache()

            layer_inputs, layer_outputs = layer_outputs, []

        pack_model(
            model=self.model,
            quantizers=quantizers,
            bits=self.quantize_config.bits,
            group_size=self.quantize_config.group_size,
            use_triton=use_triton
        )

        self._quantized = True
        self.model.config.use_cache = forward_pass_use_cache
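
    # Hedged sketch of how quantize() is typically driven (everything below is
    # illustrative and not part of this module): each calibration example is a
    # dict of LongTensors, such as a Hugging Face tokenizer produces, e.g.
    #
    #     tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
    #     enc = tokenizer("auto_gptq is an easy-to-use quantization package", return_tensors="pt")
    #     examples = [{"input_ids": enc["input_ids"], "attention_mask": enc["attention_mask"]}]
    #     model.quantize(examples)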
    @property
    def device(self):
        return self.model.device

    def to(self, device: Union[str, torch.device]):
        self.model.to(device)

    def forward(self, **kwargs):
        return self.model(**kwargs)

    def generate(self, **kwargs):
        """shortcut for model.generate"""
        return self.model.generate(**kwargs)

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """shortcut for model.prepare_inputs_for_generation"""
        return self.model.prepare_inputs_for_generation(*args, **kwargs)

    def save_quantized(self, save_dir: str, use_safetensors: bool = False):
        """save quantized model and configs to local disk"""
        os.makedirs(save_dir, exist_ok=True)

        if not self.quantized:
            raise EnvironmentError("can only save quantized model, please execute .quantize first.")

        self.model.to(CPU)

        model_save_name = f"gptq_model-{self.quantize_config.bits}bit"
        if use_safetensors:
            model_save_name += ".safetensors"
            state_dict = self.model.state_dict()
            state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
            safe_save(state_dict, join(save_dir, model_save_name))
        else:
            model_save_name += ".bin"
            torch.save(self.model.state_dict(), join(save_dir, model_save_name))

        self.model.config.save_pretrained(save_dir)
        self.quantize_config.save_pretrained(save_dir)
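
    # For reference, after a successful save_quantized(save_dir) call the directory
    # contains (file names inferred from the code above):
    #     gptq_model-{bits}bit.bin or gptq_model-{bits}bit.safetensors  -- quantized weights
    #     config.json                                                   -- the transformers model config
    #     quantize_config.json                                          -- the BaseQuantizeConfig used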
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        quantize_config: BaseQuantizeConfig,
        bf16: bool = False,
        **model_init_kwargs
    ):
        """load un-quantized pretrained model to cpu"""

        def skip(*args, **kwargs):
            pass

        torch.nn.init.kaiming_uniform_ = skip
        torch.nn.init.uniform_ = skip
        torch.nn.init.normal_ = skip

        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
        if config.model_type not in SUPPORTED_MODELS:
            raise TypeError(f"{config.model_type} isn't supported yet.")

        # enforce some values despite what the user specified
        model_init_kwargs["device_map"] = None
        model_init_kwargs["torch_dtype"] = torch.bfloat16 if bf16 else torch.float16
        model_init_kwargs["low_cpu_mem_usage"] = False
        model_init_kwargs["trust_remote_code"] = True

        model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, **model_init_kwargs)

        model_config = model.config.to_dict()
        seq_len_keys = ["max_position_embeddings", "seq_length"]
        if any([k in model_config for k in seq_len_keys]):
            for key in seq_len_keys:
                if key in model_config:
                    model.seqlen = model_config[key]
                    break
        else:
            logger.warning("can't get model's sequence length from model config, will set to 4096.")
            model.seqlen = 4096
        model.eval()

        return cls(model, False, quantize_config)
    @classmethod
    def from_quantized(
        cls,
        save_dir: str,
        device: str = "cpu",
        use_safetensors: bool = False,
        use_triton: bool = False
    ):
        """load quantized model from local disk"""
        config = AutoConfig.from_pretrained(save_dir, trust_remote_code=True)
        if config.model_type not in SUPPORTED_MODELS:
            raise TypeError(f"{config.model_type} isn't supported yet.")

        quantize_config = BaseQuantizeConfig.from_pretrained(save_dir)

        model_save_name = join(save_dir, f"gptq_model-{quantize_config.bits}bit")
        if use_safetensors:
            model_save_name += ".safetensors"
        else:
            model_save_name += ".bin"

        def skip(*args, **kwargs):
            pass

        torch.nn.init.kaiming_uniform_ = skip
        torch.nn.init.uniform_ = skip
        torch.nn.init.normal_ = skip

        transformers.modeling_utils._init_weights = False
        torch.set_default_dtype(torch.half)
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
        torch.set_default_dtype(torch.float)
        model = model.eval()

        layers = find_layers(model)
        for name in [cls.lm_head_name]:
            if name in layers:
                del layers[name]
        make_quant(model, layers, quantize_config.bits, quantize_config.group_size, use_triton=use_triton)

        if model_save_name.endswith('.safetensors'):
            model.load_state_dict(safe_load(model_save_name, "cpu"))
        else:
            model.load_state_dict(torch.load(model_save_name))

        model_config = model.config.to_dict()
        seq_len_keys = ["max_position_embeddings", "seq_length"]
        if any([k in model_config for k in seq_len_keys]):
            for key in seq_len_keys:
                if key in model_config:
                    model.seqlen = model_config[key]
                    break
        else:
            logger.warning("can't get model's sequence length from model config, will set to 4096.")
            model.seqlen = 4096
        model.eval()

        model.to(device)
        return cls(model, True, quantize_config)

__all__ = ["BaseGPTQForCausalLM", "BaseQuantizeConfig"]
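
# End-to-end usage sketch (illustrative only): this base class is not used directly;
# a concrete subclass fills in layers_block_name / outside_layer_modules /
# inside_layer_modules for one architecture. "SomeGPTQForCausalLM", "tokenizer",
# "examples" and the paths below are hypothetical placeholders.
#
#     quantize_config = BaseQuantizeConfig(bits=4, group_size=128)
#     model = SomeGPTQForCausalLM.from_pretrained(pretrained_model_name_or_path, quantize_config)
#     model.quantize(examples)                            # examples: list of tokenized calibration samples
#     model.save_quantized(save_dir, use_safetensors=True)
#
#     model = SomeGPTQForCausalLM.from_quantized(save_dir, device="cuda:0", use_safetensors=True)
#     inputs = tokenizer("auto_gptq is", return_tensors="pt").to(model.device)
#     print(tokenizer.decode(model.generate(**inputs)[0]))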