Add support for CodeGen/2
This commit is contained in:
parent
560cf92d7d
commit
b8187ff05a
4 changed files with 21 additions and 2 deletions
|
@ -7,3 +7,4 @@ from .gptj import *
|
||||||
from .llama import *
|
from .llama import *
|
||||||
from .moss import *
|
from .moss import *
|
||||||
from .opt import *
|
from .opt import *
|
||||||
|
from .codegen import *
|
||||||
|
|
|
@ -6,7 +6,7 @@ from transformers import __version__ as transformers_version
|
||||||
CPU = device("cpu")
|
CPU = device("cpu")
|
||||||
CUDA_0 = device("cuda:0")
|
CUDA_0 = device("cuda:0")
|
||||||
|
|
||||||
SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss"]
|
SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss", "codegen"]
|
||||||
if parse_version(transformers_version) >= parse_version("v4.28.0"):
|
if parse_version(transformers_version) >= parse_version("v4.28.0"):
|
||||||
SUPPORTED_MODELS.append("llama")
|
SUPPORTED_MODELS.append("llama")
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ from .gpt2 import GPT2GPTQForCausalLM
|
||||||
from .llama import LlamaGPTQForCausalLM
|
from .llama import LlamaGPTQForCausalLM
|
||||||
from .moss import MOSSGPTQForCausalLM
|
from .moss import MOSSGPTQForCausalLM
|
||||||
from .opt import OPTGPTQForCausalLM
|
from .opt import OPTGPTQForCausalLM
|
||||||
|
from .codegen import CodeGenGPTQForCausalLM
|
||||||
|
|
||||||
|
|
||||||
GPTQ_CAUSAL_LM_MODEL_MAP = {
|
GPTQ_CAUSAL_LM_MODEL_MAP = {
|
||||||
|
@ -18,7 +19,8 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
|
||||||
"gpt2": GPT2GPTQForCausalLM,
|
"gpt2": GPT2GPTQForCausalLM,
|
||||||
"llama": LlamaGPTQForCausalLM,
|
"llama": LlamaGPTQForCausalLM,
|
||||||
"opt": OPTGPTQForCausalLM,
|
"opt": OPTGPTQForCausalLM,
|
||||||
"moss": MOSSGPTQForCausalLM
|
"moss": MOSSGPTQForCausalLM,
|
||||||
|
"codegen": CodeGenGPTQForCausalLM
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
16
auto_gptq/modeling/codegen.py
Normal file
16
auto_gptq/modeling/codegen.py
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
from ._base import *
|
||||||
|
|
||||||
|
|
||||||
|
class CodeGenGPTQForCausalLM(BaseGPTQForCausalLM):
    """GPTQ quantization support for Salesforce CodeGen causal-LM models.

    Pure configuration subclass: it only tells the base class how to walk a
    CodeGen model's module tree. All quantization logic lives in
    ``BaseGPTQForCausalLM``.
    """

    # Class name of the repeated transformer decoder block in the HF model.
    layer_type = "CodeGenBlock"

    # Attribute path from the model root to the list of decoder blocks.
    layers_block_name = "transformer.h"

    # Module paths outside the repeated blocks (presumably the token
    # embedding and the final layer norm — names suggest so; confirm
    # against the transformers CodeGen implementation).
    outside_layer_modules = ["transformer.wte", "transformer.ln_f"]

    # Linear submodules inside each decoder block, grouped in the order
    # they should be quantized: attention projections first, then the MLP.
    inside_layer_modules = [
        ["attn.qkv_proj"],
        ["attn.out_proj"],
        ["mlp.fc_in"],
        ["mlp.fc_out"],
    ]


__all__ = ["CodeGenGPTQForCausalLM"]
|
Loading…
Add table
Reference in a new issue