Add support for GPTBigCode

This commit is contained in:
LaaZa 2023-05-08 12:28:29 +03:00
parent 560cf92d7d
commit 63247a0669
4 changed files with 22 additions and 2 deletions

View file

@@ -7,3 +7,4 @@ from .gptj import *
from .llama import * from .llama import *
from .moss import * from .moss import *
from .opt import * from .opt import *
from .gpt_bigcode import *

View file

@@ -6,7 +6,7 @@ from transformers import __version__ as transformers_version
CPU = device("cpu") CPU = device("cpu")
CUDA_0 = device("cuda:0") CUDA_0 = device("cuda:0")
SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss"] SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss", "gpt_bigcode"]
if parse_version(transformers_version) >= parse_version("v4.28.0"): if parse_version(transformers_version) >= parse_version("v4.28.0"):
SUPPORTED_MODELS.append("llama") SUPPORTED_MODELS.append("llama")

View file

@@ -9,6 +9,7 @@ from .gpt2 import GPT2GPTQForCausalLM
from .llama import LlamaGPTQForCausalLM from .llama import LlamaGPTQForCausalLM
from .moss import MOSSGPTQForCausalLM from .moss import MOSSGPTQForCausalLM
from .opt import OPTGPTQForCausalLM from .opt import OPTGPTQForCausalLM
from .gpt_bigcode import GPTBigCodeGPTQForCausalLM
GPTQ_CAUSAL_LM_MODEL_MAP = { GPTQ_CAUSAL_LM_MODEL_MAP = {
@@ -18,7 +19,8 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
"gpt2": GPT2GPTQForCausalLM, "gpt2": GPT2GPTQForCausalLM,
"llama": LlamaGPTQForCausalLM, "llama": LlamaGPTQForCausalLM,
"opt": OPTGPTQForCausalLM, "opt": OPTGPTQForCausalLM,
"moss": MOSSGPTQForCausalLM "moss": MOSSGPTQForCausalLM,
"gpt_bigcode": GPTBigCodeGPTQForCausalLM
} }

View file

@@ -0,0 +1,17 @@
from auto_gptq.modeling import BaseGPTQForCausalLM


class GPTBigCodeGPTQForCausalLM(BaseGPTQForCausalLM):
    """GPTQ causal-LM wrapper for the HF ``gpt_bigcode`` architecture.

    Pure configuration: the class attributes below name the submodules of
    the HuggingFace model so the base class can locate them during
    quantization. No behavior is defined here.
    """

    # Class name of one repeated decoder block in the HF model.
    layer_type = "GPTBigCodeBlock"
    # Attribute path to the module list holding those blocks.
    layers_block_name = "transformer.h"
    # Modules that live outside the repeated blocks
    # (presumably embeddings and the final layer norm — names follow
    # the HF GPTBigCode module layout; verify against the model class).
    outside_layer_modules = [
        "transformer.wpe",
        "transformer.wte",
        "transformer.ln_f",
    ]
    # Linear submodules inside each block, grouped per quantization step.
    inside_layer_modules = [
        ["attn.c_attn"],
        ["attn.c_proj"],
        ["mlp.c_fc"],
        ["mlp.c_proj"],
    ]


__all__ = ["GPTBigCodeGPTQForCausalLM"]