Add support for InternLM

This commit is contained in:
tc 2023-07-07 09:25:40 -07:00
parent 590219d048
commit e28e8ee809
4 changed files with 46 additions and 7 deletions

View file

@@ -11,3 +11,4 @@ from .rw import *
from .gpt_bigcode import *
from .codegen import *
from .baichuan import *
from .internlm import *

View file

@@ -7,7 +7,20 @@ from ..utils.import_utils import compare_transformers_version
CPU = device("cpu")
CUDA_0 = device("cuda:0")
SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss", "gpt_bigcode", "codegen", "RefinedWebModel", "RefinedWeb", "baichuan"]
SUPPORTED_MODELS = [
"bloom",
"gptj",
"gpt2",
"gpt_neox",
"opt",
"moss",
"gpt_bigcode",
"codegen",
"RefinedWebModel",
"RefinedWeb",
"baichuan",
"internlm",
]
if compare_transformers_version("v4.28.0", op="ge"):
SUPPORTED_MODELS.append("llama")

View file

@@ -14,7 +14,7 @@ from .opt import OPTGPTQForCausalLM
from .rw import RWGPTQForCausalLM
from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
from .baichuan import BaiChuanGPTQForCausalLM
from .internlm import InternLMGPTQForCausalLM
GPTQ_CAUSAL_LM_MODEL_MAP = {
"bloom": BloomGPTQForCausalLM,
@@ -27,8 +27,9 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
"gpt_bigcode": GPTBigCodeGPTQForCausalLM,
"codegen": CodeGenGPTQForCausalLM,
"RefinedWebModel": RWGPTQForCausalLM,
"RefinedWeb":RWGPTQForCausalLM,
"baichuan":BaiChuanGPTQForCausalLM
"RefinedWeb": RWGPTQForCausalLM,
"baichuan": BaiChuanGPTQForCausalLM,
"internlm": InternLMGPTQForCausalLM,
}
@@ -49,7 +50,9 @@ class AutoGPTQForCausalLM:
trust_remote_code: bool = False,
**model_init_kwargs
) -> BaseGPTQForCausalLM:
model_type = check_and_get_model_type(pretrained_model_name_or_path, trust_remote_code)
model_type = check_and_get_model_type(
pretrained_model_name_or_path, trust_remote_code
)
return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
pretrained_model_name_or_path=pretrained_model_name_or_path,
quantize_config=quantize_config,
@@ -79,9 +82,15 @@ class AutoGPTQForCausalLM:
trainable: bool = False,
**kwargs
) -> BaseGPTQForCausalLM:
model_type = check_and_get_model_type(save_dir or model_name_or_path, trust_remote_code)
model_type = check_and_get_model_type(
save_dir or model_name_or_path, trust_remote_code
)
quant_func = GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized
keywords = {key: kwargs[key] for key in signature(quant_func).parameters if key in kwargs}
keywords = {
key: kwargs[key]
for key in signature(quant_func).parameters
if key in kwargs
}
return quant_func(
model_name_or_path=model_name_or_path,
save_dir=save_dir,

View file

@@ -0,0 +1,16 @@
from ._base import *
class InternLMGPTQForCausalLM(BaseGPTQForCausalLM):
    """GPTQ quantization wrapper for InternLM causal language models.

    Purely declarative: it supplies the InternLM-specific module names that
    the shared machinery in ``BaseGPTQForCausalLM`` uses to locate and
    quantize the model's layers. Mirrors the sibling per-architecture
    wrappers (e.g. ``BaiChuanGPTQForCausalLM``) registered in
    ``GPTQ_CAUSAL_LM_MODEL_MAP``.
    """

    # Class name of one transformer block in the HF InternLM implementation.
    layer_type = "InternLMDecoderLayer"
    # Attribute path (from the model root) to the list of transformer blocks.
    layers_block_name = "model.layers"
    # Modules outside the repeated blocks (token embedding and final norm);
    # presumably handled separately from per-layer quantization — exact
    # semantics are defined by BaseGPTQForCausalLM.
    outside_layer_modules = ["model.embed_tokens", "model.norm"]
    # Linear submodules inside each decoder layer, grouped into sub-lists.
    # NOTE(review): the grouping/order appears to follow the forward-pass
    # order (attention projections, attention output, MLP in, MLP out), as
    # in the other model wrappers — do not reorder without checking the base
    # class's use of this list.
    inside_layer_modules = [
        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
        ["self_attn.o_proj"],
        ["mlp.up_proj", "mlp.gate_proj"],
        ["mlp.down_proj"],
    ]


# Public API of this module: only the wrapper class is exported.
__all__ = ["InternLMGPTQForCausalLM"]