Merge remote-tracking branches 'laaza/Mistral' and 'laaza/MPT'

This commit is contained in:
Automation Pipeline 2023-10-22 07:53:59 -04:00
commit 9fb99f61e7
5 changed files with 45 additions and 1 deletion

View file

@ -13,3 +13,5 @@ from .codegen import *
from .baichuan import *
from .internlm import *
from .qwen import *
from .mistral import *
from .mpt import *

View file

@ -21,11 +21,15 @@ SUPPORTED_MODELS = [
"baichuan",
"internlm",
"qwen",
"mpt",
]
# Model types that are registered only when the installed transformers
# release satisfies the given minimum ("ge") version check.
if compare_transformers_version("v4.28.0", op="ge"):
    SUPPORTED_MODELS.append("llama")

if compare_transformers_version("v4.33.0", op="ge"):
    SUPPORTED_MODELS.append("falcon")

if compare_transformers_version("v4.34.0", op="ge"):
    SUPPORTED_MODELS.append("mistral")

# NOTE(review): presumably the default input-length cap for the exllama
# backend -- confirm against the code that reads this constant.
EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048

View file

@ -16,6 +16,8 @@ from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
from .baichuan import BaiChuanGPTQForCausalLM
from .internlm import InternLMGPTQForCausalLM
from .qwen import QwenGPTQForCausalLM
from .mistral import MistralGPTQForCausalLM
from .mpt import MPTGPTQForCausalLM
GPTQ_CAUSAL_LM_MODEL_MAP = {
"bloom": BloomGPTQForCausalLM,
@ -33,6 +35,8 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
"baichuan": BaiChuanGPTQForCausalLM,
"internlm": InternLMGPTQForCausalLM,
"qwen": QwenGPTQForCausalLM,
"mistral": MistralGPTQForCausalLM,
"mpt": MPTGPTQForCausalLM,
}

View file

@ -0,0 +1,16 @@
from ._base import *
class MistralGPTQForCausalLM(BaseGPTQForCausalLM):
    """Declarative module layout for GPTQ handling of Mistral causal LMs.

    The class adds no behaviour of its own; it only overrides the layout
    attributes below, which are presumably read by ``BaseGPTQForCausalLM``
    (verify against ``_base``).
    """

    # Class name of a single repeated decoder block.
    layer_type = "MistralDecoderLayer"

    # Dotted attribute path to the container holding the decoder blocks.
    layers_block_name = "model.layers"

    # Modules that live outside the repeated decoder blocks.
    outside_layer_modules = ["model.embed_tokens", "model.norm"]

    # Module names inside each decoder block, kept as ordered groups.
    # NOTE(review): group order looks significant to the base class --
    # do not reorder without checking ``_base``.
    inside_layer_modules = [
        [
            "self_attn.k_proj",
            "self_attn.v_proj",
            "self_attn.q_proj",
        ],
        [
            "self_attn.o_proj",
        ],
        [
            "mlp.up_proj",
            "mlp.gate_proj",
        ],
        [
            "mlp.down_proj",
        ],
    ]


__all__ = ["MistralGPTQForCausalLM"]

18
auto_gptq/modeling/mpt.py Normal file
View file

@ -0,0 +1,18 @@
# Import the base class straight from the sibling ``_base`` module instead of
# going through the ``auto_gptq.modeling`` package root: the package
# ``__init__`` imports this module, so the absolute form only works when the
# package has already bound ``BaseGPTQForCausalLM`` at that point.
from ._base import BaseGPTQForCausalLM


class MPTGPTQForCausalLM(BaseGPTQForCausalLM):
    """Declarative module layout for GPTQ handling of MPT causal LMs.

    The class adds no behaviour of its own; it only overrides the layout
    attributes below, which are presumably read by ``BaseGPTQForCausalLM``
    (verify against ``_base``).
    """

    # Class name of a single repeated transformer block.
    layer_type = "MPTBlock"

    # Dotted attribute path to the container holding the transformer blocks.
    layers_block_name = "transformer.blocks"

    # Modules that live outside the repeated transformer blocks.
    outside_layer_modules = [
        "transformer.wte",
        "transformer.norm_f",
    ]

    # Module names inside each transformer block, kept as ordered groups.
    # NOTE(review): group order looks significant to the base class --
    # do not reorder without checking ``_base``.
    inside_layer_modules = [
        ["attn.Wqkv"],
        ["attn.out_proj"],
        ["ffn.up_proj"],
        ["ffn.down_proj"],
    ]


__all__ = ["MPTGPTQForCausalLM"]