Add support for Mistral models.

This commit is contained in:
LaaZa 2023-10-04 01:07:55 +03:00
parent 51c043c6be
commit 99acbead42
4 changed files with 23 additions and 1 deletions

View file

@@ -12,4 +12,5 @@ from .gpt_bigcode import *
from .codegen import *
from .baichuan import *
from .internlm import *
from .qwen import *
from .mistral import *

View file

@@ -26,6 +26,9 @@ if compare_transformers_version("v4.28.0", op="ge"):
SUPPORTED_MODELS.append("llama")
if compare_transformers_version("v4.33.0", op="ge"):
SUPPORTED_MODELS.append("falcon")
if compare_transformers_version("v4.34.0", op="ge"):
SUPPORTED_MODELS.append("mistral")
EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048

View file

@@ -16,6 +16,7 @@ from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
from .baichuan import BaiChuanGPTQForCausalLM
from .internlm import InternLMGPTQForCausalLM
from .qwen import QwenGPTQForCausalLM
from .mistral import MistralGPTQForCausalLM
GPTQ_CAUSAL_LM_MODEL_MAP = {
"bloom": BloomGPTQForCausalLM,
@@ -33,6 +34,7 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
"baichuan": BaiChuanGPTQForCausalLM, "baichuan": BaiChuanGPTQForCausalLM,
"internlm": InternLMGPTQForCausalLM, "internlm": InternLMGPTQForCausalLM,
"qwen": QwenGPTQForCausalLM, "qwen": QwenGPTQForCausalLM,
"mistral": MistralGPTQForCausalLM,
}

View file

@@ -0,0 +1,16 @@
from ._base import *


class MistralGPTQForCausalLM(BaseGPTQForCausalLM):
    """GPTQ quantization wrapper for Mistral causal-LM checkpoints.

    Declares, as class attributes, where the quantizable linear modules
    live inside a Transformers Mistral model; the base class consumes
    these to drive quantization and inference.
    """

    # Class name of the repeated decoder layer in the Transformers
    # Mistral implementation, used to locate the layer stack.
    layer_type = "MistralDecoderLayer"
    # Attribute path to the list of decoder layers on the loaded model.
    layers_block_name = "model.layers"
    # Modules that sit outside the decoder-layer stack (embeddings and
    # final norm) — listed separately from the per-layer modules.
    outside_layer_modules = ["model.embed_tokens", "model.norm"]
    # Linear submodules inside each decoder layer, grouped the same way
    # as the sibling model wrappers (attention projections, attention
    # output, MLP in, MLP out). NOTE(review): group order presumably
    # mirrors forward-pass order expected by the base class — confirm
    # against the other *GPTQForCausalLM definitions.
    inside_layer_modules = [
        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
        ["self_attn.o_proj"],
        ["mlp.up_proj", "mlp.gate_proj"],
        ["mlp.down_proj"],
    ]


__all__ = ["MistralGPTQForCausalLM"]