AutoGPTQ/auto_gptq/modeling/llama.py

import copy
import json
import os
from dataclasses import dataclass, field, fields
from logging import getLogger
from os.path import join
from typing import Dict, List, Optional, Union

import accelerate
import torch
import torch.nn as nn
import transformers
from accelerate.hooks import remove_hook_from_module
from safetensors.torch import save_file as safe_save
from transformers import AutoConfig, AutoModelForCausalLM, PreTrainedModel
from transformers.utils.hub import PushToHubMixin

from ._const import *
from ._utils import *
from ..quantization import GPTQ
from ..utils.data_utils import collate_data
from ._base import *
from ..nn_modules.fused_mlp_triton import make_fused_mlp, autotune_warmup_fused
from ..nn_modules.fused_attn import make_quant_attn

logger = getLogger(__name__)


class LlamaGPTQForCausalLM(BaseGPTQForCausalLM):
    layer_type = "LlamaDecoderLayer"
    layers_block_name = "model.layers"
    outside_layer_modules = ["model.embed_tokens", "model.norm"]
    inside_layer_modules = [
        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
        ["self_attn.o_proj"],
        ["mlp.up_proj", "mlp.gate_proj"],
        ["mlp.down_proj"]
    ]

    @classmethod
    def from_quantized(
        cls,
        save_dir: str,
        device: str = "cpu",
        use_safetensors: bool = False,
        use_triton: bool = False,
        fused_attn: bool = False,
        fused_mlp: bool = False,
        max_memory: Optional[dict] = None,
        device_map: Optional[str] = None,
        quantize_config: Optional[BaseQuantizeConfig] = None,
        model_basename: Optional[str] = None,
        trust_remote_code: bool = False
    ):
"""load quantized model from local disk"""
if use_triton:
from ..nn_modules.qlinear_triton import autotune_warmup_linear
logger.warning("use_triton will force moving the hole model to GPU, make sure you have enough VRAM.")
device = "cuda:0"

        # load the model config and make sure the architecture is supported
        config = AutoConfig.from_pretrained(save_dir, trust_remote_code=trust_remote_code)
        if config.model_type not in SUPPORTED_MODELS:
            raise TypeError(f"{config.model_type} isn't supported yet.")

        if quantize_config is None:
            quantize_config = BaseQuantizeConfig.from_pretrained(save_dir)

        # resolve the checkpoint file name from bits/group_size unless given explicitly
        if model_basename is None:
            model_basename = f"gptq_model-{quantize_config.bits}bit-{quantize_config.group_size}g"

        model_save_name = join(save_dir, model_basename)
        if use_safetensors:
            model_save_name += ".safetensors"
        else:
            model_save_name += ".bin"

        # skip random weight initialization: quantized weights are loaded from the checkpoint below
        def skip(*args, **kwargs):
            pass

        torch.nn.init.kaiming_uniform_ = skip
        torch.nn.init.uniform_ = skip
        torch.nn.init.normal_ = skip
        transformers.modeling_utils._init_weights = False

        # build the model skeleton on the meta device in half precision
        with accelerate.init_empty_weights():
            torch.set_default_dtype(torch.half)
            model = AutoModelForCausalLM.from_config(config, trust_remote_code=trust_remote_code)
            torch.set_default_dtype(torch.float)

        # collect the linear layers and drop those that were never quantized
        layers = find_layers(model)
        ignore_layers = [cls.lm_head_name] + cls.outside_layer_modules
        for name in list(layers.keys()):
            if any([name.startswith(ignore_layer) for ignore_layer in ignore_layers]):
                logger.info(f"{name} has not been quantized, will be ignored when make_quant.")
                del layers[name]

        # replace the remaining linear layers with quantized linear modules
        with accelerate.init_empty_weights():
            make_quant(model, layers, quantize_config.bits, quantize_config.group_size, use_triton=use_triton, desc_act=quantize_config.desc_act)
        model.tie_weights()

        # resolve the device map, then load and dispatch the quantized checkpoint
        if max_memory and not device_map:
            device_map = "auto"
        if not max_memory and not device_map:
            device_map = {"": device}
        model = accelerate.load_checkpoint_and_dispatch(
            model, model_save_name, device_map, max_memory, no_split_module_classes=[cls.layer_type]
        )

        if fused_attn:
            make_quant_attn(model, use_triton=use_triton, desc_act=quantize_config.desc_act)
        if use_triton and fused_mlp:
            make_fused_mlp(model)

        # infer the model's max sequence length from its config
        model_config = model.config.to_dict()
        seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions"]
        if any([k in model_config for k in seq_len_keys]):
            for key in seq_len_keys:
                if key in model_config:
                    model.seqlen = model_config[key]
                    break
        else:
            logger.warning("can't get model's sequence length from model config, will set to 4096.")
            model.seqlen = 4096
        model.eval()

        # warm up triton kernels so autotune configs are cached before first use
        if use_triton:
            autotune_warmup_linear(model, seqlen=model.seqlen)
            if fused_mlp:
                autotune_warmup_fused(model, seqlen=model.seqlen)

        return cls(model, True, quantize_config)

__all__ = ["LlamaGPTQForCausalLM"]
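
# Usage sketch (illustrative only, not part of the module): the directory path,
# basename and device below are placeholders, and the surrounding call is an
# assumption rather than the project's documented API.
#
#   model = LlamaGPTQForCausalLM.from_quantized(
#       "path/to/quantized-llama",   # directory with quantize_config.json and the checkpoint
#       device="cuda:0",
#       use_safetensors=True,        # expects gptq_model-<bits>bit-<group_size>g.safetensors
#       use_triton=False,
#   )
#   # the wrapped transformers model is available as `model.model`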