Merge remote-tracking branches 'laaza/Mistral' and 'laaza/MPT'

commit 9fb99f61e7

5 changed files with 45 additions and 1 deletion
@@ -12,4 +12,6 @@ from .gpt_bigcode import *
 from .codegen import *
 from .baichuan import *
 from .internlm import *
 from .qwen import *
+from .mistral import *
+from .mpt import *
@@ -21,11 +21,15 @@ SUPPORTED_MODELS = [
     "baichuan",
     "internlm",
     "qwen",
+    "mpt",
 ]
 if compare_transformers_version("v4.28.0", op="ge"):
     SUPPORTED_MODELS.append("llama")
 if compare_transformers_version("v4.33.0", op="ge"):
     SUPPORTED_MODELS.append("falcon")
+if compare_transformers_version("v4.34.0", op="ge"):
+    SUPPORTED_MODELS.append("mistral")
+
 
 EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048
 
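Note: Mistral architecture support landed in transformers v4.34.0, which is why its registration is version-gated while "mpt" joins the base list unconditionally. A minimal sketch of what the gate amounts to at import time; the packaging-based comparison is an assumption for illustration, AutoGPTQ's own compare_transformers_version helper is the authoritative check:

# Sketch of the version gate above (assumes the `packaging` dependency).
import transformers
from packaging.version import parse

SUPPORTED_MODELS = ["baichuan", "internlm", "qwen", "mpt"]

# Roughly equivalent to compare_transformers_version("v4.34.0", op="ge"):
if parse(transformers.__version__) >= parse("4.34.0"):
    SUPPORTED_MODELS.append("mistral")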
@@ -16,6 +16,8 @@ from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
 from .baichuan import BaiChuanGPTQForCausalLM
 from .internlm import InternLMGPTQForCausalLM
 from .qwen import QwenGPTQForCausalLM
+from .mistral import MistralGPTQForCausalLM
+from .mpt import MPTGPTQForCausalLM
 
 GPTQ_CAUSAL_LM_MODEL_MAP = {
     "bloom": BloomGPTQForCausalLM,

@@ -33,6 +35,8 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
     "baichuan": BaiChuanGPTQForCausalLM,
     "internlm": InternLMGPTQForCausalLM,
     "qwen": QwenGPTQForCausalLM,
+    "mistral": MistralGPTQForCausalLM,
+    "mpt": MPTGPTQForCausalLM,
 }
 
 
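GPTQ_CAUSAL_LM_MODEL_MAP is what lets the auto class route a checkpoint to the right wrapper: the loader reads model_type from the checkpoint's config and looks it up in this map. A rough sketch of that dispatch; the helper name is illustrative, not AutoGPTQ's actual internal function:

# Illustrative dispatch: resolve a wrapper class from config.model_type.
from transformers import AutoConfig

def resolve_gptq_class(model_name_or_path: str, model_map: dict):
    config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
    model_type = config.model_type  # e.g. "mistral" or "mpt"
    if model_type not in model_map:
        raise TypeError(f"{model_type} isn't supported yet.")
    return model_map[model_type]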
auto_gptq/modeling/mistral.py (new file, 16 lines)

@@ -0,0 +1,16 @@
+from ._base import *
+
+
+class MistralGPTQForCausalLM(BaseGPTQForCausalLM):
+    layer_type = "MistralDecoderLayer"
+    layers_block_name = "model.layers"
+    outside_layer_modules = ["model.embed_tokens", "model.norm"]
+    inside_layer_modules = [
+        ["self_attn.k_proj", "self_attn.v_proj", "self_attn.q_proj"],
+        ["self_attn.o_proj"],
+        ["mlp.up_proj", "mlp.gate_proj"],
+        ["mlp.down_proj"],
+    ]
+
+
+__all__ = ["MistralGPTQForCausalLM"]
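With MistralGPTQForCausalLM registered, a Mistral checkpoint quantizes through the normal public entry point. A minimal sketch using AutoGPTQ's documented API; the checkpoint id, calibration text, and output directory are placeholders, and transformers >= 4.34.0 is assumed:

from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

model_id = "mistralai/Mistral-7B-v0.1"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)

quantize_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False)
model = AutoGPTQForCausalLM.from_pretrained(model_id, quantize_config)

# Calibration data: tokenized examples with input_ids and attention_mask.
examples = [tokenizer("auto_gptq quantizes transformer models to 4 bits.", return_tensors="pt")]
model.quantize(examples)
model.save_quantized("mistral-7b-gptq-4bit")  # placeholder output dir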
auto_gptq/modeling/mpt.py (new file, 18 lines)

@@ -0,0 +1,18 @@
+from auto_gptq.modeling import BaseGPTQForCausalLM
+
+
+class MPTGPTQForCausalLM(BaseGPTQForCausalLM):
+    layer_type = "MPTBlock"
+    layers_block_name = "transformer.blocks"
+    outside_layer_modules = [
+        "transformer.wte", "transformer.norm_f"
+    ]
+
+    inside_layer_modules = [
+        ["attn.Wqkv"],
+        ["attn.out_proj"],
+        ["ffn.up_proj"],
+        ["ffn.down_proj"]
+    ]
+
+__all__ = ["MPTGPTQForCausalLM"]
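Note the absolute import here (from auto_gptq.modeling import BaseGPTQForCausalLM) versus the relative from ._base import * in mistral.py; both resolve to the same base class. As in the Mistral class, each attribute describes where the quantizable Linear layers live: layers_block_name points at the ModuleList of MPTBlocks, every name in inside_layer_modules is resolved relative to one block, and each inner list is a group quantized in the same pass. A small sketch of that resolution, with model construction omitted and names taken from the diff:

# Sketch: resolving inside_layer_modules names relative to each block.
import torch.nn as nn

def iter_quantizable_layers(model: nn.Module, layers_block_name: str, groups):
    blocks = model.get_submodule(layers_block_name)  # e.g. transformer.blocks
    for i, block in enumerate(blocks):
        for group in groups:
            for name in group:  # e.g. "attn.Wqkv" inside one MPTBlock
                yield f"{layers_block_name}.{i}.{name}", block.get_submodule(name)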