Merge remote-tracking branches 'laaza/Mistral' and 'laaza/MPT'

Commit 9fb99f61e7 by Automation Pipeline, 2023-10-22 07:53:59 -04:00
5 changed files with 45 additions and 1 deletion

auto_gptq/modeling/__init__.py

@@ -13,3 +13,5 @@ from .codegen import *
 from .baichuan import *
 from .internlm import *
 from .qwen import *
+from .mistral import *
+from .mpt import *

auto_gptq/modeling/_const.py

@@ -21,11 +21,15 @@ SUPPORTED_MODELS = [
     "baichuan",
     "internlm",
     "qwen",
+    "mpt",
 ]

 if compare_transformers_version("v4.28.0", op="ge"):
     SUPPORTED_MODELS.append("llama")
 if compare_transformers_version("v4.33.0", op="ge"):
     SUPPORTED_MODELS.append("falcon")
+if compare_transformers_version("v4.34.0", op="ge"):
+    SUPPORTED_MODELS.append("mistral")

 EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048
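The new guard mirrors the existing llama and falcon checks: Mistral modeling code first shipped in transformers v4.34.0, so the registration is skipped on older installs. For reference, a minimal sketch of what such a version gate can look like; the real compare_transformers_version helper lives elsewhere in auto_gptq and its exact signature may differ, so treat this reimplementation as an assumption:

import operator
import transformers
from packaging import version

def compare_transformers_version(ref: str = "v4.34.0", op: str = "ge") -> bool:
    # strip the leading "v" so packaging can parse the tag, then apply
    # the named comparison operator (e.g. "ge" -> >=)
    return getattr(operator, op)(
        version.parse(transformers.__version__),
        version.parse(ref.lstrip("v")),
    )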

auto_gptq/modeling/auto.py

@@ -16,6 +16,8 @@ from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
 from .baichuan import BaiChuanGPTQForCausalLM
 from .internlm import InternLMGPTQForCausalLM
 from .qwen import QwenGPTQForCausalLM
+from .mistral import MistralGPTQForCausalLM
+from .mpt import MPTGPTQForCausalLM

 GPTQ_CAUSAL_LM_MODEL_MAP = {
     "bloom": BloomGPTQForCausalLM,
@@ -33,6 +35,8 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
     "baichuan": BaiChuanGPTQForCausalLM,
     "internlm": InternLMGPTQForCausalLM,
     "qwen": QwenGPTQForCausalLM,
+    "mistral": MistralGPTQForCausalLM,
+    "mpt": MPTGPTQForCausalLM,
 }
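GPTQ_CAUSAL_LM_MODEL_MAP is the dispatch table: AutoGPTQForCausalLM reads model_type from a checkpoint's config and routes to the matching wrapper class, so the two entries above are what make loading Mistral and MPT checkpoints work. Roughly, and only as a sketch (the actual lookup in auto.py differs in details):

from transformers import AutoConfig

def resolve_gptq_class(model_name_or_path: str):
    # the config's model_type field (e.g. "mistral", "mpt") keys the map
    config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
    if config.model_type not in GPTQ_CAUSAL_LM_MODEL_MAP:
        raise TypeError(f"{config.model_type} isn't supported yet.")
    return GPTQ_CAUSAL_LM_MODEL_MAP[config.model_type]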

auto_gptq/modeling/mistral.py Normal file

@@ -0,0 +1,16 @@
+from ._base import *
+
+
+class MistralGPTQForCausalLM(BaseGPTQForCausalLM):
+    layer_type = "MistralDecoderLayer"
+    layers_block_name = "model.layers"
+    outside_layer_modules = ["model.embed_tokens", "model.norm"]
+    inside_layer_modules = [
+        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
+        ["self_attn.o_proj"],
+        ["mlp.up_proj", "mlp.gate_proj"],
+        ["mlp.down_proj"],
+    ]
+
+
+__all__ = ["MistralGPTQForCausalLM"]
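These four attributes are all a new architecture needs: layers_block_name points at the decoder stack that gets quantized layer by layer, outside_layer_modules lists the modules kept outside that loop (embeddings and the final norm), and each inner list of inside_layer_modules groups projections that are quantized together, in forward order. With the class registered, quantizing a Mistral checkpoint follows the standard AutoGPTQ flow; a sketch, where the model id and the one-sample calibration set are placeholders (a real run needs a few hundred calibration examples):

from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from transformers import AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)

quantize_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False)
model = AutoGPTQForCausalLM.from_pretrained(model_id, quantize_config)

# single sample keeps the sketch short; use a real calibration set in practice
examples = [tokenizer("auto-gptq quantizes a model from a small calibration set.", return_tensors="pt")]
model.quantize(examples)
model.save_quantized("mistral-7b-4bit-gptq")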

auto_gptq/modeling/mpt.py Normal file

@@ -0,0 +1,18 @@
+from auto_gptq.modeling import BaseGPTQForCausalLM
+
+
+class MPTGPTQForCausalLM(BaseGPTQForCausalLM):
+    layer_type = "MPTBlock"
+    layers_block_name = "transformer.blocks"
+    outside_layer_modules = [
+        "transformer.wte", "transformer.norm_f"
+    ]
+
+    inside_layer_modules = [
+        ["attn.Wqkv"],
+        ["attn.out_proj"],
+        ["ffn.up_proj"],
+        ["ffn.down_proj"]
+    ]
+
+__all__ = ["MPTGPTQForCausalLM"]
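Unlike the Llama-style layout above, MPT fuses the query/key/value projections into a single attn.Wqkv module, so it forms a quantization group of one, and its feed-forward lives under ffn rather than mlp. A quick way to sanity-check these module paths against a real checkpoint (the checkpoint id is an assumption, loading it pulls the full weights, and MPT repos need trust_remote_code on transformers versions without native MPT support):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("mosaicml/mpt-7b", trust_remote_code=True)
block = model.transformer.blocks[0]
print(type(block).__name__)                         # expected: MPTBlock
print([name for name, _ in block.named_modules()])  # should include attn.Wqkv, ffn.up_proj, ...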