Add support for InternLM
This commit is contained in:
parent
590219d048
commit
e28e8ee809
4 changed files with 46 additions and 7 deletions
|
@ -11,3 +11,4 @@ from .rw import *
|
||||||
from .gpt_bigcode import *
|
from .gpt_bigcode import *
|
||||||
from .codegen import *
|
from .codegen import *
|
||||||
from .baichuan import *
|
from .baichuan import *
|
||||||
|
from .internlm import *
|
||||||
|
|
|
@ -7,7 +7,20 @@ from ..utils.import_utils import compare_transformers_version
|
||||||
CPU = device("cpu")
|
CPU = device("cpu")
|
||||||
CUDA_0 = device("cuda:0")
|
CUDA_0 = device("cuda:0")
|
||||||
|
|
||||||
SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss", "gpt_bigcode", "codegen", "RefinedWebModel", "RefinedWeb", "baichuan"]
|
SUPPORTED_MODELS = [
|
||||||
|
"bloom",
|
||||||
|
"gptj",
|
||||||
|
"gpt2",
|
||||||
|
"gpt_neox",
|
||||||
|
"opt",
|
||||||
|
"moss",
|
||||||
|
"gpt_bigcode",
|
||||||
|
"codegen",
|
||||||
|
"RefinedWebModel",
|
||||||
|
"RefinedWeb",
|
||||||
|
"baichuan",
|
||||||
|
"internlm",
|
||||||
|
]
|
||||||
if compare_transformers_version("v4.28.0", op="ge"):
|
if compare_transformers_version("v4.28.0", op="ge"):
|
||||||
SUPPORTED_MODELS.append("llama")
|
SUPPORTED_MODELS.append("llama")
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ from .opt import OPTGPTQForCausalLM
|
||||||
from .rw import RWGPTQForCausalLM
|
from .rw import RWGPTQForCausalLM
|
||||||
from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
|
from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
|
||||||
from .baichuan import BaiChuanGPTQForCausalLM
|
from .baichuan import BaiChuanGPTQForCausalLM
|
||||||
|
from .internlm import InternLMGPTQForCausalLM
|
||||||
|
|
||||||
GPTQ_CAUSAL_LM_MODEL_MAP = {
|
GPTQ_CAUSAL_LM_MODEL_MAP = {
|
||||||
"bloom": BloomGPTQForCausalLM,
|
"bloom": BloomGPTQForCausalLM,
|
||||||
|
@ -27,8 +27,9 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
|
||||||
"gpt_bigcode": GPTBigCodeGPTQForCausalLM,
|
"gpt_bigcode": GPTBigCodeGPTQForCausalLM,
|
||||||
"codegen": CodeGenGPTQForCausalLM,
|
"codegen": CodeGenGPTQForCausalLM,
|
||||||
"RefinedWebModel": RWGPTQForCausalLM,
|
"RefinedWebModel": RWGPTQForCausalLM,
|
||||||
"RefinedWeb":RWGPTQForCausalLM,
|
"RefinedWeb": RWGPTQForCausalLM,
|
||||||
"baichuan":BaiChuanGPTQForCausalLM
|
"baichuan": BaiChuanGPTQForCausalLM,
|
||||||
|
"internlm": InternLMGPTQForCausalLM,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -49,7 +50,9 @@ class AutoGPTQForCausalLM:
|
||||||
trust_remote_code: bool = False,
|
trust_remote_code: bool = False,
|
||||||
**model_init_kwargs
|
**model_init_kwargs
|
||||||
) -> BaseGPTQForCausalLM:
|
) -> BaseGPTQForCausalLM:
|
||||||
model_type = check_and_get_model_type(pretrained_model_name_or_path, trust_remote_code)
|
model_type = check_and_get_model_type(
|
||||||
|
pretrained_model_name_or_path, trust_remote_code
|
||||||
|
)
|
||||||
return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
|
return GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
|
||||||
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||||
quantize_config=quantize_config,
|
quantize_config=quantize_config,
|
||||||
|
@ -79,9 +82,15 @@ class AutoGPTQForCausalLM:
|
||||||
trainable: bool = False,
|
trainable: bool = False,
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> BaseGPTQForCausalLM:
|
) -> BaseGPTQForCausalLM:
|
||||||
model_type = check_and_get_model_type(save_dir or model_name_or_path, trust_remote_code)
|
model_type = check_and_get_model_type(
|
||||||
|
save_dir or model_name_or_path, trust_remote_code
|
||||||
|
)
|
||||||
quant_func = GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized
|
quant_func = GPTQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized
|
||||||
keywords = {key: kwargs[key] for key in signature(quant_func).parameters if key in kwargs}
|
keywords = {
|
||||||
|
key: kwargs[key]
|
||||||
|
for key in signature(quant_func).parameters
|
||||||
|
if key in kwargs
|
||||||
|
}
|
||||||
return quant_func(
|
return quant_func(
|
||||||
model_name_or_path=model_name_or_path,
|
model_name_or_path=model_name_or_path,
|
||||||
save_dir=save_dir,
|
save_dir=save_dir,
|
||||||
|
|
16
auto_gptq/modeling/internlm.py
Normal file
16
auto_gptq/modeling/internlm.py
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
from ._base import *
|
||||||
|
|
||||||
|
|
||||||
|
class InternLMGPTQForCausalLM(BaseGPTQForCausalLM):
|
||||||
|
layer_type = "InternLMDecoderLayer"
|
||||||
|
layers_block_name = "model.layers"
|
||||||
|
outside_layer_modules = ["model.embed_tokens", "model.norm"]
|
||||||
|
inside_layer_modules = [
|
||||||
|
["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
|
||||||
|
["self_attn.o_proj"],
|
||||||
|
["mlp.up_proj", "mlp.gate_proj"],
|
||||||
|
["mlp.down_proj"],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["InternLMGPTQForCausalLM"]
|
Loading…
Add table
Reference in a new issue