Support Falcon ("RefinedWebModel") architectures for GPTQ quantization
This commit is contained in:
parent
4d5b4fa5c6
commit
bcb345fb35
4 changed files with 20 additions and 2 deletions
|
@ -7,3 +7,4 @@ from .gptj import *
|
|||
from .llama import *
|
||||
from .moss import *
|
||||
from .opt import *
|
||||
from .rw import *
|
|
@ -7,7 +7,7 @@ from ..utils.import_utils import compare_transformers_version
|
|||
CPU = device("cpu")
|
||||
CUDA_0 = device("cuda:0")
|
||||
|
||||
SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss"]
|
||||
SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss", "RefinedWebModel"]
|
||||
if compare_transformers_version("v4.28.0", op="ge"):
|
||||
SUPPORTED_MODELS.append("llama")
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ from .gpt2 import GPT2GPTQForCausalLM
|
|||
from .llama import LlamaGPTQForCausalLM
|
||||
from .moss import MOSSGPTQForCausalLM
|
||||
from .opt import OPTGPTQForCausalLM
|
||||
from .rw import RWGPTQForCausalLM
|
||||
from inspect import signature
|
||||
|
||||
GPTQ_CAUSAL_LM_MODEL_MAP = {
|
||||
|
@ -18,7 +19,8 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
|
|||
"gpt2": GPT2GPTQForCausalLM,
|
||||
"llama": LlamaGPTQForCausalLM,
|
||||
"opt": OPTGPTQForCausalLM,
|
||||
"moss": MOSSGPTQForCausalLM
|
||||
"moss": MOSSGPTQForCausalLM,
|
||||
"RefinedWebModel": RWGPTQForCausalLM
|
||||
}
|
||||
|
||||
|
||||
|
|
15
auto_gptq/modeling/rw.py
Normal file
15
auto_gptq/modeling/rw.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
from ._base import *


class RWGPTQForCausalLM(BaseGPTQForCausalLM):
    """GPTQ causal-LM wrapper for the RW ("RefinedWebModel", i.e. Falcon) architecture.

    This class carries no logic of its own: it only declares the module-path
    configuration that the shared ``BaseGPTQForCausalLM`` machinery uses to
    locate the layers to quantize inside a loaded RW model.
    """

    # Class name of a single repeated transformer block in the RW model.
    layer_type = "DecoderLayer"
    # Attribute path (from the model root) to the list of transformer blocks.
    layers_block_name = "transformer.h"
    # Modules outside the repeated blocks (embeddings, final layer norm);
    # these are needed to run the model but are not quantized per-layer.
    outside_layer_modules = ["transformer.word_embeddings", "transformer.ln_f"]
    # Quantizable linear submodules within each block, grouped in the order
    # they are executed: fused attention QKV, attention output projection,
    # then the two MLP projections.
    inside_layer_modules = [
        ["self_attention.query_key_value"],
        ["self_attention.dense"],
        ["mlp.dense_h_to_4h"],
        ["mlp.dense_4h_to_h"]
    ]


__all__ = ["RWGPTQForCausalLM"]
|
Loading…
Add table
Reference in a new issue