Add initial support for MPT
This commit is contained in:
parent
393a2fbac2
commit
fb380fb9c2
4 changed files with 23 additions and 2 deletions
|
@ -7,3 +7,4 @@ from .gptj import *
|
|||
from .llama import *
|
||||
from .moss import *
|
||||
from .opt import *
|
||||
from .mpt import *
|
|
@ -6,7 +6,7 @@ from transformers import __version__ as transformers_version
|
|||
CPU = device("cpu")
|
||||
CUDA_0 = device("cuda:0")
|
||||
|
||||
SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss"]
|
||||
SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss", "mpt"]
|
||||
if parse_version(transformers_version) >= parse_version("v4.28.0"):
|
||||
SUPPORTED_MODELS.append("llama")
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ from .gpt2 import GPT2GPTQForCausalLM
|
|||
from .llama import LlamaGPTQForCausalLM
|
||||
from .moss import MOSSGPTQForCausalLM
|
||||
from .opt import OPTGPTQForCausalLM
|
||||
from .mpt import MPTGPTQForCausalLM
|
||||
|
||||
|
||||
GPTQ_CAUSAL_LM_MODEL_MAP = {
|
||||
|
@ -18,7 +19,8 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
|
|||
"gpt2": GPT2GPTQForCausalLM,
|
||||
"llama": LlamaGPTQForCausalLM,
|
||||
"opt": OPTGPTQForCausalLM,
|
||||
"moss": MOSSGPTQForCausalLM
|
||||
"moss": MOSSGPTQForCausalLM,
|
||||
"mpt": MPTGPTQForCausalLM
|
||||
}
|
||||
|
||||
|
||||
|
|
18
auto_gptq/modeling/mpt.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
from auto_gptq.modeling import BaseGPTQForCausalLM
|
||||
|
||||
|
||||
class MPTGPTQForCausalLM(BaseGPTQForCausalLM):
    """Declarative GPTQ model description for MPT causal language models.

    The class adds no methods of its own; it only sets the class attributes
    below. How they are interpreted is decided by ``BaseGPTQForCausalLM``,
    which is defined outside this file -- the per-attribute notes are
    therefore assumptions to confirm against the base class.
    """

    # Presumably the class name of one repeated transformer block -- TODO confirm.
    layer_type = "MPTBlock"

    # Presumably the dotted attribute path to the container of those blocks.
    layers_block_name = "transformer.blocks"

    # Dotted paths of modules that live outside the repeated blocks.
    outside_layer_modules = ["transformer.wte", "transformer.norm_f"]

    # Module paths relative to a single block, one group per inner list.
    # NOTE(review): group order looks significant (attention before ffn);
    # keep it unchanged unless the base class says otherwise.
    inside_layer_modules = [
        ["attn.Wqkv"],
        ["attn.out_proj"],
        ["ffn.up_proj"],
        ["ffn.down_proj"],
    ]
|
||||
|
||||
# Limit what a star-import of this module exports to the MPT wrapper class.
__all__ = ["MPTGPTQForCausalLM"]
|
Loading…
Add table
Reference in a new issue