Add initial support for MPT

This commit is contained in:
LaaZa 2023-05-12 14:46:52 +03:00
parent 393a2fbac2
commit fb380fb9c2
4 changed files with 23 additions and 2 deletions

View file

@ -7,3 +7,4 @@ from .gptj import *
from .llama import *
from .moss import *
from .opt import *
from .mpt import *

View file

@ -6,7 +6,7 @@ from transformers import __version__ as transformers_version
CPU = device("cpu")
CUDA_0 = device("cuda:0")
SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss"]
SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss", "mpt"]
if parse_version(transformers_version) >= parse_version("v4.28.0"):
SUPPORTED_MODELS.append("llama")

View file

@ -9,6 +9,7 @@ from .gpt2 import GPT2GPTQForCausalLM
from .llama import LlamaGPTQForCausalLM
from .moss import MOSSGPTQForCausalLM
from .opt import OPTGPTQForCausalLM
from .mpt import MPTGPTQForCausalLM
GPTQ_CAUSAL_LM_MODEL_MAP = {
@ -18,7 +19,8 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
"gpt2": GPT2GPTQForCausalLM,
"llama": LlamaGPTQForCausalLM,
"opt": OPTGPTQForCausalLM,
"moss": MOSSGPTQForCausalLM
"moss": MOSSGPTQForCausalLM,
"mpt": MPTGPTQForCausalLM
}

18
auto_gptq/modeling/mpt.py Normal file
View file

@ -0,0 +1,18 @@
from auto_gptq.modeling import BaseGPTQForCausalLM


class MPTGPTQForCausalLM(BaseGPTQForCausalLM):
    """GPTQ quantization wrapper describing the MPT architecture layout.

    This class carries no logic of its own; it only declares where the
    quantizable pieces of an MPT checkpoint live so the GPTQ base class
    can walk the model. Attribute names and values form the contract
    read by ``BaseGPTQForCausalLM`` and must match the MPT module tree.
    """

    # Class name of one repeated transformer block in MPT.
    layer_type = "MPTBlock"
    # Dotted path to the stack of transformer blocks.
    layers_block_name = "transformer.blocks"
    # Modules outside the block stack (embeddings / final norm) that are
    # needed during quantization but are not themselves quantized per-layer.
    outside_layer_modules = ["transformer.wte", "transformer.norm_f"]
    # Linear submodules inside each block, grouped in forward-pass order:
    # fused QKV projection, attention output, then the two MLP projections.
    inside_layer_modules = [
        ["attn.Wqkv"],
        ["attn.out_proj"],
        ["ffn.up_proj"],
        ["ffn.down_proj"],
    ]


__all__ = ["MPTGPTQForCausalLM"]