Add support for InternLM
This commit is contained in:
parent
590219d048
commit
e28e8ee809
4 changed files with 46 additions and 7 deletions
|
@ -11,3 +11,4 @@ from .rw import *
|
|||
from .gpt_bigcode import *
|
||||
from .codegen import *
|
||||
from .baichuan import *
|
||||
from .internlm import *
|
||||
|
|
|
@ -7,7 +7,20 @@ from ..utils.import_utils import compare_transformers_version
|
|||
CPU = device("cpu")
|
||||
CUDA_0 = device("cuda:0")
|
||||
|
||||
SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss", "gpt_bigcode", "codegen", "RefinedWebModel", "RefinedWeb", "baichuan"]
|
||||
# Model-type identifiers (as reported by the HF config) that this package
# can quantize; "llama" is appended conditionally further below.
SUPPORTED_MODELS = [
    "bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss",
    "gpt_bigcode", "codegen",
    "RefinedWebModel", "RefinedWeb",
    "baichuan", "internlm",
]
|
||||
# Register llama support only on transformers >= 4.28.0 — presumably the
# first release shipping the llama architecture; confirm against the
# transformers changelog if bumping this pin.
if compare_transformers_version("v4.28.0", op="ge"):
    SUPPORTED_MODELS.append("llama")
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ from .opt import OPTGPTQForCausalLM
|
|||
from .rw import RWGPTQForCausalLM
|
||||
from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
|
||||
from .baichuan import BaiChuanGPTQForCausalLM
|
||||
|
||||
from .internlm import InternLMGPTQForCausalLM
|
||||
|
||||
GPTQ_CAUSAL_LM_MODEL_MAP = {
|
||||
"bloom": BloomGPTQForCausalLM,
|
||||
|
@ -27,8 +27,9 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
|
|||
"gpt_bigcode": GPTBigCodeGPTQForCausalLM,
|
||||
"codegen": CodeGenGPTQForCausalLM,
|
||||
"RefinedWebModel": RWGPTQForCausalLM,
|
||||
"RefinedWeb":RWGPTQForCausalLM,
|
||||
"baichuan":BaiChuanGPTQForCausalLM
|
||||
"RefinedWeb": RWGPTQForCausalLM,
|
||||
"baichuan": BaiChuanGPTQForCausalLM,
|
||||
"internlm": InternLMGPTQForCausalLM,
|
||||
}
|
||||
|
||||
|
||||
|
@ -49,7 +50,9 @@ class AutoGPTQForCausalLM:
|
|||
trust_remote_code: bool = False,
|
||||
**model_init_kwargs
|
||||
) -> BaseGPTQForCausalLM:
|
||||
model_type = check_and_get_model_type(pretrained_model_name_or_path, trust_remote_code)
|
||||
model_type = check_and_get_model_type(
|
||||
pretrained_model_name_or_path, trust_remote_code
|
||||
)
|
||||
return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
quantize_config=quantize_config,
|
||||
|
@ -79,9 +82,15 @@ class AutoGPTQForCausalLM:
|
|||
trainable: bool = False,
|
||||
**kwargs
|
||||
) -> BaseGPTQForCausalLM:
|
||||
model_type = check_and_get_model_type(save_dir or model_name_or_path, trust_remote_code)
|
||||
model_type = check_and_get_model_type(
|
||||
save_dir or model_name_or_path, trust_remote_code
|
||||
)
|
||||
quant_func = GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized
|
||||
keywords = {key: kwargs[key] for key in signature(quant_func).parameters if key in kwargs}
|
||||
keywords = {
|
||||
key: kwargs[key]
|
||||
for key in signature(quant_func).parameters
|
||||
if key in kwargs
|
||||
}
|
||||
return quant_func(
|
||||
model_name_or_path=model_name_or_path,
|
||||
save_dir=save_dir,
|
||||
|
|
16
auto_gptq/modeling/internlm.py
Normal file
16
auto_gptq/modeling/internlm.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
from ._base import *
|
||||
|
||||
|
||||
class InternLMGPTQForCausalLM(BaseGPTQForCausalLM):
    """GPTQ quantization descriptor for InternLM causal-LM checkpoints.

    Purely declarative: the attributes below tell the ``BaseGPTQForCausalLM``
    machinery where to find the decoder layers and their linear sub-modules
    inside a Hugging Face InternLM model.
    """

    # Class name of the repeated transformer block in the HF implementation.
    layer_type = "InternLMDecoderLayer"
    # Attribute path from the model root to the list of decoder layers.
    layers_block_name = "model.layers"
    # Modules outside the decoder stack — presumably left un-quantized
    # (token embedding and final norm); semantics defined by the base class.
    outside_layer_modules = ["model.embed_tokens", "model.norm"]
    # Linear sub-modules within each decoder layer, grouped per quantization
    # step: attention K/V/Q projections, attention output projection,
    # MLP up/gate projections, MLP down projection.
    inside_layer_modules = [
        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
        ["self_attn.o_proj"],
        ["mlp.up_proj", "mlp.gate_proj"],
        ["mlp.down_proj"],
    ]


__all__ = ["InternLMGPTQForCausalLM"]
|
Loading…
Add table
Reference in a new issue