from logging import getLogger
from os.path import join, isfile
from typing import Optional, Union

import accelerate
import torch
import transformers
from transformers import AutoConfig, AutoModelForCausalLM

from ._const import *
from ._utils import *

from ._base import *
from ..nn_modules.fused_llama_attn import FusedLlamaAttentionForQuantizedModel
from ..nn_modules.fused_llama_mlp import FusedLlamaMLPForQuantizedModel

logger = getLogger(__name__)


class LlamaGPTQForCausalLM(BaseGPTQForCausalLM):
    # Class name of the transformers decoder layer this model is built from.
    layer_type = "LlamaDecoderLayer"
    # Attribute path to the list of decoder layers inside the HF model.
    layers_block_name = "model.layers"
    # Modules that live outside the repeated decoder layers (embeddings, final norm).
    outside_layer_modules = ["model.embed_tokens", "model.norm"]
    # Linear submodules inside each decoder layer, grouped for sequential quantization.
    inside_layer_modules = [
        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
        ["self_attn.o_proj"],
        ["mlp.up_proj", "mlp.gate_proj"],
        ["mlp.down_proj"]
    ]

    # Fused replacement modules used when fused attention/MLP injection is enabled.
    fused_attn_module_type = FusedLlamaAttentionForQuantizedModel
    fused_mlp_module_type = FusedLlamaMLPForQuantizedModel

    @classmethod
    def from_quantized(
        cls,
        save_dir: str,
        device_map: Optional[str] = None,
        max_memory: Optional[dict] = None,
        device: Optional[Union[str, int]] = None,
        strict: bool = True,
        use_triton: bool = False,
        inject_fused_attention: bool = False,
        inject_fused_mlp: bool = False,
        use_cuda_fp16: bool = True,
        quantize_config: Optional[BaseQuantizeConfig] = None,
        model_basename: Optional[str] = None,
        use_safetensors: bool = False,
        trust_remote_code: bool = False,
        warmup_triton: bool = True,
        **kwargs
    ):
        # Delegate loading of the quantized checkpoint to the base class.
        model = super(LlamaGPTQForCausalLM, cls).from_quantized(
            save_dir=save_dir,
            device_map=device_map,
            max_memory=max_memory,
            device=device,
            strict=strict,
            use_triton=use_triton,
            inject_fused_attention=inject_fused_attention,
            inject_fused_mlp=inject_fused_mlp,
            use_cuda_fp16=use_cuda_fp16,
            quantize_config=quantize_config,
            model_basename=model_basename,
            use_safetensors=use_safetensors,
            trust_remote_code=trust_remote_code,
            warmup_triton=warmup_triton,
            **kwargs
        )

        # Warm up the fused Triton MLP kernels so the first forward pass
        # doesn't pay the autotuning cost.
        if use_triton and warmup_triton and inject_fused_mlp:
            from ..nn_modules.fused_llama_mlp import autotune_warmup_fused
            autotune_warmup_fused(model, seqlen=model.seqlen)

        return model


__all__ = ["LlamaGPTQForCausalLM"]
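
# Usage sketch (not executed on import): assuming this module lives at the
# package path auto_gptq.modeling.llama and that a GPTQ-quantized LLaMA
# checkpoint already exists under the hypothetical directory
# "quantized-llama-7b", loading and generating could look roughly like:
#
#   from transformers import AutoTokenizer
#   from auto_gptq.modeling.llama import LlamaGPTQForCausalLM
#
#   tokenizer = AutoTokenizer.from_pretrained("quantized-llama-7b")
#   model = LlamaGPTQForCausalLM.from_quantized(
#       "quantized-llama-7b",
#       device="cuda:0",
#       use_safetensors=True,
#       use_triton=False,
#   )
#   inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
#   print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))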