Add support for Baichuan
This commit is contained in:
parent
9baff43f6f
commit
9fd558f2ba
4 changed files with 21 additions and 2 deletions
|
@ -10,3 +10,4 @@ from .opt import *
|
||||||
from .rw import *
|
from .rw import *
|
||||||
from .gpt_bigcode import *
|
from .gpt_bigcode import *
|
||||||
from .codegen import *
|
from .codegen import *
|
||||||
|
from .baichuan import *
|
||||||
|
|
|
@ -7,7 +7,7 @@ from ..utils.import_utils import compare_transformers_version
|
||||||
CPU = device("cpu")
|
CPU = device("cpu")
|
||||||
CUDA_0 = device("cuda:0")
|
CUDA_0 = device("cuda:0")
|
||||||
|
|
||||||
SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss", "gpt_bigcode", "codegen", "RefinedWebModel", "RefinedWeb"]
|
SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss", "gpt_bigcode", "codegen", "RefinedWebModel", "RefinedWeb", "baichuan"]
|
||||||
if compare_transformers_version("v4.28.0", op="ge"):
|
if compare_transformers_version("v4.28.0", op="ge"):
|
||||||
SUPPORTED_MODELS.append("llama")
|
SUPPORTED_MODELS.append("llama")
|
||||||
|
|
||||||
|
|
|
@ -13,6 +13,7 @@ from .moss import MOSSGPTQForCausalLM
|
||||||
from .opt import OPTGPTQForCausalLM
|
from .opt import OPTGPTQForCausalLM
|
||||||
from .rw import RWGPTQForCausalLM
|
from .rw import RWGPTQForCausalLM
|
||||||
from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
|
from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
|
||||||
|
from .baichuan import BaichuanGPTQForCausalLM
|
||||||
|
|
||||||
|
|
||||||
GPTQ_CAUSAL_LM_MODEL_MAP = {
|
GPTQ_CAUSAL_LM_MODEL_MAP = {
|
||||||
|
@ -26,7 +27,8 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
|
||||||
"gpt_bigcode": GPTBigCodeGPTQForCausalLM,
|
"gpt_bigcode": GPTBigCodeGPTQForCausalLM,
|
||||||
"codegen": CodeGenGPTQForCausalLM,
|
"codegen": CodeGenGPTQForCausalLM,
|
||||||
"RefinedWebModel": RWGPTQForCausalLM,
|
"RefinedWebModel": RWGPTQForCausalLM,
|
||||||
"RefinedWeb":RWGPTQForCausalLM
|
"RefinedWeb":RWGPTQForCausalLM,
|
||||||
|
"baichuan":BaichuanGPTQForCausalLM
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
16
auto_gptq/modeling/baichuan.py
Normal file
16
auto_gptq/modeling/baichuan.py
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
from ._base import *


class BaichuanGPTQForCausalLM(BaseGPTQForCausalLM):
    """GPTQ quantization adapter for Baichuan causal language models.

    Declares the module paths that ``BaseGPTQForCausalLM`` uses to locate
    the quantizable linear layers inside a Baichuan checkpoint; all actual
    quantization logic lives in the base class (defined in ``._base``).
    """

    # Class name of each transformer block in the Baichuan implementation.
    layer_type = "DecoderLayer"

    # Attribute path to the stack of decoder layers on the loaded model.
    layers_block_name = "model.layers"

    # Modules outside the decoder stack (embeddings, final norm);
    # presumably excluded from per-layer quantization — semantics are
    # defined by BaseGPTQForCausalLM, which is not visible here.
    outside_layer_modules = ["model.embed_tokens", "model.norm"]

    # Linear sub-modules targeted within each decoder layer, grouped in
    # execution order: attention input/output, then the MLP.
    # NOTE(review): "self_attn.W_pack" appears to be Baichuan's fused
    # QKV projection (single matrix instead of separate q/k/v_proj) —
    # confirm against the upstream Baichuan modeling code.
    inside_layer_modules = [
        ["self_attn.W_pack"],
        ["self_attn.o_proj"],
        ["mlp.up_proj", "mlp.gate_proj"],
        ["mlp.down_proj"]
    ]


__all__ = ["BaichuanGPTQForCausalLM"]
|
Loading…
Add table
Reference in a new issue