support falcon

qwopqwop200 2023-05-27 07:53:39 +09:00 committed by GitHub
parent 4d5b4fa5c6
commit bcb345fb35
4 changed files with 20 additions and 2 deletions

auto_gptq/modeling/__init__.py

@@ -7,3 +7,4 @@ from .gptj import *
 from .llama import *
 from .moss import *
 from .opt import *
+from .rw import *

auto_gptq/modeling/_const.py

@@ -7,7 +7,7 @@ from ..utils.import_utils import compare_transformers_version
 CPU = device("cpu")
 CUDA_0 = device("cuda:0")

-SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss"]
+SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss", "RefinedWebModel"]
 if compare_transformers_version("v4.28.0", op="ge"):
     SUPPORTED_MODELS.append("llama")
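Why "RefinedWebModel" rather than "falcon": at the time of this commit, Falcon checkpoints shipped their own modeling code, and their config reported model_type "RefinedWebModel" (the 7B line; the 40B line used "RefinedWeb"). A minimal check, using tiiuae/falcon-7b as an illustrative checkpoint:

from transformers import AutoConfig

# trust_remote_code is required because the checkpoint carries custom modeling code.
config = AutoConfig.from_pretrained("tiiuae/falcon-7b", trust_remote_code=True)
# Reported "RefinedWebModel" for the original Falcon releases; later
# transformers versions integrate Falcon natively and report "falcon".
print(config.model_type)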

auto_gptq/modeling/auto.py

@@ -9,6 +9,7 @@ from .gpt2 import GPT2GPTQForCausalLM
 from .llama import LlamaGPTQForCausalLM
 from .moss import MOSSGPTQForCausalLM
 from .opt import OPTGPTQForCausalLM
+from .rw import RWGPTQForCausalLM
 from inspect import signature

 GPTQ_CAUSAL_LM_MODEL_MAP = {
@@ -18,7 +19,8 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
     "gpt2": GPT2GPTQForCausalLM,
     "llama": LlamaGPTQForCausalLM,
     "opt": OPTGPTQForCausalLM,
-    "moss": MOSSGPTQForCausalLM
+    "moss": MOSSGPTQForCausalLM,
+    "RefinedWebModel": RWGPTQForCausalLM
 }
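For context, a simplified sketch of how this map is consumed when a model is loaded; the helper name and signature here are illustrative, not verbatim source, and tiiuae/falcon-7b is an example checkpoint:

from transformers import AutoConfig

def get_model_type(model_name_or_path, trust_remote_code=False):
    # Read model_type from the checkpoint's config and gate on SUPPORTED_MODELS.
    config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
    if config.model_type not in SUPPORTED_MODELS:
        raise TypeError(f"{config.model_type} isn't supported yet.")
    return config.model_type

# A Falcon checkpoint reports "RefinedWebModel", so the lookup now
# resolves to RWGPTQForCausalLM instead of raising.
model_class = GPTQ_CAUSAL_LM_MODEL_MAP[get_model_type("tiiuae/falcon-7b", trust_remote_code=True)]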

auto_gptq/modeling/rw.py (new file)

@@ -0,0 +1,15 @@
+from ._base import *
+
+
+class RWGPTQForCausalLM(BaseGPTQForCausalLM):
+    layer_type = "DecoderLayer"
+    layers_block_name = "transformer.h"
+    outside_layer_modules = ["transformer.word_embeddings", "transformer.ln_f"]
+    inside_layer_modules = [
+        ["self_attention.query_key_value"],
+        ["self_attention.dense"],
+        ["mlp.dense_h_to_4h"],
+        ["mlp.dense_4h_to_h"]
+    ]
+
+__all__ = ["RWGPTQForCausalLM"]
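With the new model class registered, quantizing a Falcon checkpoint should work through the usual AutoGPTQ flow. A hedged end-to-end sketch: the model id tiiuae/falcon-7b, the calibration sentence, and the output directory are placeholders, and trust_remote_code is assumed to be forwarded to transformers since RefinedWebModel checkpoints rely on custom modeling code:

from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
from transformers import AutoTokenizer

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

quantize_config = BaseQuantizeConfig(bits=4, group_size=128)
model = AutoGPTQForCausalLM.from_pretrained(model_id, quantize_config, trust_remote_code=True)

# Calibration batch: activations from these inputs drive GPTQ, which walks
# transformer.h layer by layer in the inside_layer_modules order above.
examples = [tokenizer("auto-gptq is an easy-to-use quantization package.", return_tensors="pt")]

model.quantize(examples)
model.save_quantized("falcon-7b-4bit-128g")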