diff --git a/auto_gptq/modeling/__init__.py b/auto_gptq/modeling/__init__.py
index 7e7624d..be0e4e5 100644
--- a/auto_gptq/modeling/__init__.py
+++ b/auto_gptq/modeling/__init__.py
@@ -10,3 +10,4 @@ from .opt import *
 from .rw import *
 from .gpt_bigcode import *
 from .codegen import *
+from .baichuan import *
diff --git a/auto_gptq/modeling/_const.py b/auto_gptq/modeling/_const.py
index e38a51f..cb84ef7 100644
--- a/auto_gptq/modeling/_const.py
+++ b/auto_gptq/modeling/_const.py
@@ -7,7 +7,7 @@ from ..utils.import_utils import compare_transformers_version
 CPU = device("cpu")
 CUDA_0 = device("cuda:0")
 
-SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss", "gpt_bigcode", "codegen", "RefinedWebModel", "RefinedWeb"]
+SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss", "gpt_bigcode", "codegen", "RefinedWebModel", "RefinedWeb", "baichuan"]
 
 if compare_transformers_version("v4.28.0", op="ge"):
     SUPPORTED_MODELS.append("llama")
diff --git a/auto_gptq/modeling/auto.py b/auto_gptq/modeling/auto.py
index 936d1f1..e501830 100644
--- a/auto_gptq/modeling/auto.py
+++ b/auto_gptq/modeling/auto.py
@@ -13,6 +13,7 @@ from .moss import MOSSGPTQForCausalLM
 from .opt import OPTGPTQForCausalLM
 from .rw import RWGPTQForCausalLM
 from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
+from .baichuan import BaichuanGPTQForCausalLM
 
 
 GPTQ_CAUSAL_LM_MODEL_MAP = {
@@ -26,7 +27,8 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
     "gpt_bigcode": GPTBigCodeGPTQForCausalLM,
     "codegen": CodeGenGPTQForCausalLM,
     "RefinedWebModel": RWGPTQForCausalLM,
-    "RefinedWeb":RWGPTQForCausalLM
+    "RefinedWeb":RWGPTQForCausalLM,
+    "baichuan":BaichuanGPTQForCausalLM
 }
 
 
diff --git a/auto_gptq/modeling/baichuan.py b/auto_gptq/modeling/baichuan.py
new file mode 100644
index 0000000..8a01c1c
--- /dev/null
+++ b/auto_gptq/modeling/baichuan.py
@@ -0,0 +1,16 @@
+from ._base import *
+
+
+class BaichuanGPTQForCausalLM(BaseGPTQForCausalLM):
+    layer_type = "DecoderLayer"
+    layers_block_name = "model.layers"
+    outside_layer_modules = ["model.embed_tokens", "model.norm"]
+    inside_layer_modules = [
+        ["self_attn.W_pack"],
+        ["self_attn.o_proj"],
+        ["mlp.up_proj", "mlp.gate_proj"],
+        ["mlp.down_proj"]
+    ]
+
+
+__all__ = ["BaichuanGPTQForCausalLM"]