from inspect import signature
from typing import Dict, Optional, Union

from ._base import BaseQuantizeConfig, BaseGPTQForCausalLM
from ._utils import check_and_get_model_type
from .bloom import BloomGPTQForCausalLM
from .gpt_neox import GPTNeoXGPTQForCausalLM
from .gptj import GPTJGPTQForCausalLM
from .gpt2 import GPT2GPTQForCausalLM
from .llama import LlamaGPTQForCausalLM
from .moss import MOSSGPTQForCausalLM
from .opt import OPTGPTQForCausalLM
from .rw import RWGPTQForCausalLM

# Maps a model config's `model_type` field to its GPTQ wrapper class.
GPTQ_CAUSAL_LM_MODEL_MAP = {
    "bloom": BloomGPTQForCausalLM,
    "gpt_neox": GPTNeoXGPTQForCausalLM,
    "gptj": GPTJGPTQForCausalLM,
    "gpt2": GPT2GPTQForCausalLM,
    "llama": LlamaGPTQForCausalLM,
    "opt": OPTGPTQForCausalLM,
    "moss": MOSSGPTQForCausalLM,
    "RefinedWebModel": RWGPTQForCausalLM
}


class AutoGPTQForCausalLM:
    def __init__(self):
        raise EnvironmentError(
            "AutoGPTQForCausalLM is not designed to be instantiated directly.\n"
            "Use `AutoGPTQForCausalLM.from_pretrained` if you want to quantize a pretrained model.\n"
            "Use `AutoGPTQForCausalLM.from_quantized` if you want to run inference with a quantized model."
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        quantize_config: BaseQuantizeConfig,
        max_memory: Optional[dict] = None,
        trust_remote_code: bool = False,
        **model_init_kwargs
    ) -> BaseGPTQForCausalLM:
        # Dispatch to the model-specific GPTQ class based on the detected model type.
        model_type = check_and_get_model_type(pretrained_model_name_or_path, trust_remote_code)
        return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            quantize_config=quantize_config,
            max_memory=max_memory,
            trust_remote_code=trust_remote_code,
            **model_init_kwargs
        )

    @classmethod
    def from_quantized(
        cls,
        save_dir: str,
        device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None,
        max_memory: Optional[dict] = None,
        device: Optional[Union[str, int]] = None,
        low_cpu_mem_usage: bool = False,
        use_triton: bool = False,
        inject_fused_attention: bool = True,
        inject_fused_mlp: bool = True,
        use_cuda_fp16: bool = True,
        quantize_config: Optional[BaseQuantizeConfig] = None,
        model_basename: Optional[str] = None,
        use_safetensors: bool = False,
        trust_remote_code: bool = False,
        warmup_triton: bool = False,
        **kwargs
    ) -> BaseGPTQForCausalLM:
        model_type = check_and_get_model_type(save_dir, trust_remote_code)
        quant_func = GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized
        # Forward only the extra kwargs that the target class's `from_quantized`
        # actually accepts, since model-specific subclasses may add parameters.
        keywords = {key: kwargs[key] for key in signature(quant_func).parameters if key in kwargs}
        return quant_func(
            save_dir=save_dir,
            device_map=device_map,
            max_memory=max_memory,
            device=device,
            low_cpu_mem_usage=low_cpu_mem_usage,
            use_triton=use_triton,
            inject_fused_attention=inject_fused_attention,
            inject_fused_mlp=inject_fused_mlp,
            use_cuda_fp16=use_cuda_fp16,
            quantize_config=quantize_config,
            model_basename=model_basename,
            use_safetensors=use_safetensors,
            trust_remote_code=trust_remote_code,
            warmup_triton=warmup_triton,
            **keywords
        )


__all__ = ["AutoGPTQForCausalLM"]
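

# ---------------------------------------------------------------------------
# Usage sketch (only runs when this module is executed directly): a minimal
# end-to-end example of quantizing a model with `from_pretrained` and
# reloading it with `from_quantized`. The model id, calibration text, and
# output directory below are placeholders, and a single calibration example
# is far smaller than a realistic calibration set.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    pretrained_model_name = "facebook/opt-125m"  # placeholder model id
    quantized_model_dir = "opt-125m-4bit"  # placeholder output directory

    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
    # GPTQ calibrates on tokenized examples; use many real samples in practice.
    examples = [tokenizer("auto-gptq is an easy-to-use model quantization library.")]

    quantize_config = BaseQuantizeConfig(bits=4, group_size=128)

    # Quantize the full-precision model, then persist the quantized weights.
    model = AutoGPTQForCausalLM.from_pretrained(pretrained_model_name, quantize_config)
    model.quantize(examples)
    model.save_quantized(quantized_model_dir)

    # Reload the quantized checkpoint for inference on a single GPU.
    model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir, device="cuda:0")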