Add support for CodeGen/2
This commit is contained in:
parent
560cf92d7d
commit
b8187ff05a
4 changed files with 21 additions and 2 deletions
|
@ -7,3 +7,4 @@ from .gptj import *
|
|||
from .llama import *
|
||||
from .moss import *
|
||||
from .opt import *
|
||||
from .codegen import *
|
||||
|
|
|
@ -6,7 +6,7 @@ from transformers import __version__ as transformers_version
|
|||
CPU = device("cpu")
|
||||
CUDA_0 = device("cuda:0")
|
||||
|
||||
SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss"]
|
||||
SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss", "codegen"]
|
||||
if parse_version(transformers_version) >= parse_version("v4.28.0"):
|
||||
SUPPORTED_MODELS.append("llama")
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ from .gpt2 import GPT2GPTQForCausalLM
|
|||
from .llama import LlamaGPTQForCausalLM
|
||||
from .moss import MOSSGPTQForCausalLM
|
||||
from .opt import OPTGPTQForCausalLM
|
||||
from .codegen import CodeGenGPTQForCausalLM
|
||||
|
||||
|
||||
GPTQ_CAUSAL_LM_MODEL_MAP = {
|
||||
|
@ -18,7 +19,8 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
|
|||
"gpt2": GPT2GPTQForCausalLM,
|
||||
"llama": LlamaGPTQForCausalLM,
|
||||
"opt": OPTGPTQForCausalLM,
|
||||
"moss": MOSSGPTQForCausalLM
|
||||
"moss": MOSSGPTQForCausalLM,
|
||||
"codegen": CodeGenGPTQForCausalLM
|
||||
}
|
||||
|
||||
|
||||
|
|
16
auto_gptq/modeling/codegen.py
Normal file
16
auto_gptq/modeling/codegen.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
from ._base import *
|
||||
|
||||
|
||||
class CodeGenGPTQForCausalLM(BaseGPTQForCausalLM):
|
||||
layer_type = "CodeGenBlock"
|
||||
layers_block_name = "transformer.h"
|
||||
outside_layer_modules = ["transformer.wte", "transformer.ln_f"]
|
||||
inside_layer_modules = [
|
||||
["attn.qkv_proj"],
|
||||
["attn.out_proj"],
|
||||
["mlp.fc_in"],
|
||||
["mlp.fc_out"]
|
||||
]
|
||||
|
||||
|
||||
__all__ = ["CodeGenGPTQForCausalLM"]
|
Loading…
Add table
Reference in a new issue