support falcon

commit bcb345fb35
parent 4d5b4fa5c6
qwopqwop200 authored 2023-05-27 07:53:39 +09:00, committed by GitHub
4 changed files with 20 additions and 2 deletions


@@ -7,3 +7,4 @@ from .gptj import *
 from .llama import *
 from .moss import *
 from .opt import *
+from .rw import *


@@ -7,7 +7,7 @@ from ..utils.import_utils import compare_transformers_version
 CPU = device("cpu")
 CUDA_0 = device("cuda:0")
-SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss"]
+SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss", "RefinedWebModel"]
 if compare_transformers_version("v4.28.0", op="ge"):
     SUPPORTED_MODELS.append("llama")


@@ -9,6 +9,7 @@ from .gpt2 import GPT2GPTQForCausalLM
 from .llama import LlamaGPTQForCausalLM
 from .moss import MOSSGPTQForCausalLM
 from .opt import OPTGPTQForCausalLM
+from .rw import RWGPTQForCausalLM
 from inspect import signature
 
 GPTQ_CAUSAL_LM_MODEL_MAP = {
@@ -18,7 +19,8 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
     "gpt2": GPT2GPTQForCausalLM,
     "llama": LlamaGPTQForCausalLM,
     "opt": OPTGPTQForCausalLM,
-    "moss": MOSSGPTQForCausalLM
+    "moss": MOSSGPTQForCausalLM,
+    "RefinedWebModel": RWGPTQForCausalLM
 }

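For context, the map above is keyed by the checkpoint's config.model_type, so the new "RefinedWebModel" entry is what routes Falcon/RefinedWeb checkpoints to RWGPTQForCausalLM. A minimal sketch of that dispatch, with check_and_get_model_type as an illustrative helper name rather than a confirmed API:

    from transformers import AutoConfig

    def check_and_get_model_type(model_dir: str) -> str:
        # Read model_type from the checkpoint config; RefinedWebModel
        # checkpoints ship custom modeling code, hence trust_remote_code=True.
        config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        if config.model_type not in SUPPORTED_MODELS:
            raise TypeError(f"{config.model_type} isn't supported yet.")
        return config.model_type

    # A Falcon checkpoint reports model_type == "RefinedWebModel",
    # so it resolves to RWGPTQForCausalLM through the map above.
    model_class = GPTQ_CAUSAL_LM_MODEL_MAP[check_and_get_model_type("tiiuae/falcon-7b")]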
auto_gptq/modeling/rw.py (new file, 15 lines)

@@ -0,0 +1,15 @@
+from ._base import *
+
+
+class RWGPTQForCausalLM(BaseGPTQForCausalLM):
+    layer_type = "DecoderLayer"
+    layers_block_name = "transformer.h"
+    outside_layer_modules = ["transformer.word_embeddings", "transformer.ln_f"]
+    inside_layer_modules = [
+        ["self_attention.query_key_value"],
+        ["self_attention.dense"],
+        ["mlp.dense_h_to_4h"],
+        ["mlp.dense_4h_to_h"]
+    ]
+
+__all__ = ["RWGPTQForCausalLM"]
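
With these changes, a Falcon/RefinedWeb checkpoint should quantize along the same path as the other supported architectures. A minimal usage sketch, assuming the AutoGPTQForCausalLM / BaseQuantizeConfig API from the project README; tiiuae/falcon-7b is only an example checkpoint, and trust_remote_code=True is needed because RefinedWebModel ships custom modeling code:

    from transformers import AutoTokenizer
    from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

    model_id = "tiiuae/falcon-7b"  # example RefinedWebModel checkpoint

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    quantize_config = BaseQuantizeConfig(bits=4, group_size=128)

    # Dispatches to RWGPTQForCausalLM via the new "RefinedWebModel" map entry.
    model = AutoGPTQForCausalLM.from_pretrained(
        model_id, quantize_config, trust_remote_code=True
    )

    # Calibration data: tokenized examples with input_ids and attention_mask.
    examples = [tokenizer("auto_gptq is an easy-to-use quantization package.", return_tensors="pt")]
    model.quantize(examples)
    model.save_quantized("falcon-7b-4bit-gptq")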