# auto_gptq/modeling/auto.py

from typing import Optional
from ._base import BaseQuantizeConfig, BaseGPTQForCausalLM
from ._utils import check_and_get_model_type
from .bloom import BloomGPTQForCausalLM
from .gpt_neox import GPTNeoXGPTQForCausalLM
from .gptj import GPTJGPTQForCausalLM
from .gpt2 import GPT2GPTQForCausalLM
from .llama import LlamaGPTQForCausalLM
from .moss import MOSSGPTQForCausalLM
from .opt import OPTGPTQForCausalLM
from .codegen import CodeGenGPTQForCausalLM

# Maps the `model_type` reported by a checkpoint's config to its GPTQ wrapper class.
GPTQ_CAUSAL_LM_MODEL_MAP = {
    "bloom": BloomGPTQForCausalLM,
    "gpt_neox": GPTNeoXGPTQForCausalLM,
    "gptj": GPTJGPTQForCausalLM,
    "gpt2": GPT2GPTQForCausalLM,
    "llama": LlamaGPTQForCausalLM,
    "opt": OPTGPTQForCausalLM,
    "moss": MOSSGPTQForCausalLM,
    "codegen": CodeGenGPTQForCausalLM,
}
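
# Both factory classmethods below resolve the concrete wrapper class through this
# map, so AutoGPTQForCausalLM itself is never instantiated.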


class AutoGPTQForCausalLM:
    """Dispatcher that routes to the architecture-specific GPTQ wrapper classes."""

    def __init__(self):
        raise EnvironmentError(
            "AutoGPTQForCausalLM is not designed to be instantiated directly.\n"
            "Use `AutoGPTQForCausalLM.from_pretrained` to quantize a pretrained model,\n"
            "or `AutoGPTQForCausalLM.from_quantized` to run inference with a quantized model."
        )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        quantize_config: BaseQuantizeConfig,
        max_memory: Optional[dict] = None,
        **model_init_kwargs
    ) -> BaseGPTQForCausalLM:
        # Resolve the model type from the checkpoint's config and delegate to the
        # matching wrapper's from_pretrained.
        model_type = check_and_get_model_type(pretrained_model_name_or_path)
        return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            quantize_config=quantize_config,
            max_memory=max_memory,
            **model_init_kwargs
        )

    @classmethod
    def from_quantized(
        cls,
        save_dir: str,
        device: str = "cpu",
        use_safetensors: bool = False,
        use_triton: bool = False,
        max_memory: Optional[dict] = None,
        device_map: Optional[str] = None,
        quantize_config: Optional[BaseQuantizeConfig] = None,
        model_basename: Optional[str] = None,
        trust_remote_code: bool = False
    ) -> BaseGPTQForCausalLM:
        # Resolve the model type from the saved quantized checkpoint and delegate
        # to the matching wrapper's from_quantized.
        model_type = check_and_get_model_type(save_dir)
        return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
            save_dir=save_dir,
            device=device,
            use_safetensors=use_safetensors,
            use_triton=use_triton,
            max_memory=max_memory,
            device_map=device_map,
            quantize_config=quantize_config,
            model_basename=model_basename,
            trust_remote_code=trust_remote_code
        )


__all__ = ["AutoGPTQForCausalLM"]
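
# ---------------------------------------------------------------------------
# Minimal usage sketch (kept as a comment so nothing runs on import). The
# checkpoint paths and the calibration sentence are illustrative assumptions,
# not part of this module; the calls follow the classmethod signatures above
# and the quantize/save_quantized API of BaseGPTQForCausalLM.
#
#     from transformers import AutoTokenizer
#     from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
#
#     tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
#     examples = [tokenizer("auto-gptq is an easy-to-use quantization library.", return_tensors="pt")]
#
#     # Quantize a pretrained model, then persist the quantized weights.
#     quantize_config = BaseQuantizeConfig(bits=4, group_size=128)
#     model = AutoGPTQForCausalLM.from_pretrained("facebook/opt-125m", quantize_config)
#     model.quantize(examples)
#     model.save_quantized("opt-125m-4bit")
#
#     # Reload the quantized weights for inference.
#     model = AutoGPTQForCausalLM.from_quantized("opt-125m-4bit", device="cuda:0")
# ---------------------------------------------------------------------------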