Add support for CodeGen/2

This commit is contained in:
LaaZa 2023-05-08 17:34:00 +03:00
parent 560cf92d7d
commit b8187ff05a
4 changed files with 21 additions and 2 deletions

View file

@ -7,3 +7,4 @@ from .gptj import *
from .llama import * from .llama import *
from .moss import * from .moss import *
from .opt import * from .opt import *
from .codegen import *

View file

@ -6,7 +6,7 @@ from transformers import __version__ as transformers_version
CPU = device("cpu") CPU = device("cpu")
CUDA_0 = device("cuda:0") CUDA_0 = device("cuda:0")
SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss"] SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss", "codegen"]
if parse_version(transformers_version) >= parse_version("v4.28.0"): if parse_version(transformers_version) >= parse_version("v4.28.0"):
SUPPORTED_MODELS.append("llama") SUPPORTED_MODELS.append("llama")

View file

@ -9,6 +9,7 @@ from .gpt2 import GPT2GPTQForCausalLM
from .llama import LlamaGPTQForCausalLM from .llama import LlamaGPTQForCausalLM
from .moss import MOSSGPTQForCausalLM from .moss import MOSSGPTQForCausalLM
from .opt import OPTGPTQForCausalLM from .opt import OPTGPTQForCausalLM
from .codegen import CodeGenGPTQForCausalLM
GPTQ_CAUSAL_LM_MODEL_MAP = { GPTQ_CAUSAL_LM_MODEL_MAP = {
@ -18,7 +19,8 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
"gpt2": GPT2GPTQForCausalLM, "gpt2": GPT2GPTQForCausalLM,
"llama": LlamaGPTQForCausalLM, "llama": LlamaGPTQForCausalLM,
"opt": OPTGPTQForCausalLM, "opt": OPTGPTQForCausalLM,
"moss": MOSSGPTQForCausalLM "moss": MOSSGPTQForCausalLM,
"codegen": CodeGenGPTQForCausalLM
} }

View file

@ -0,0 +1,16 @@
from ._base import *
class CodeGenGPTQForCausalLM(BaseGPTQForCausalLM):
layer_type = "CodeGenBlock"
layers_block_name = "transformer.h"
outside_layer_modules = ["transformer.wte", "transformer.ln_f"]
inside_layer_modules = [
["attn.qkv_proj"],
["attn.out_proj"],
["mlp.fc_in"],
["mlp.fc_out"]
]
__all__ = ["CodeGenGPTQForCausalLM"]