support falcon
This commit is contained in:
parent
4d5b4fa5c6
commit
bcb345fb35
4 changed files with 20 additions and 2 deletions
|
@ -7,3 +7,4 @@ from .gptj import *
|
||||||
from .llama import *
|
from .llama import *
|
||||||
from .moss import *
|
from .moss import *
|
||||||
from .opt import *
|
from .opt import *
|
||||||
|
from .rw import *
|
|
@ -7,7 +7,7 @@ from ..utils.import_utils import compare_transformers_version
|
||||||
CPU = device("cpu")
|
CPU = device("cpu")
|
||||||
CUDA_0 = device("cuda:0")
|
CUDA_0 = device("cuda:0")
|
||||||
|
|
||||||
SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss"]
|
SUPPORTED_MODELS = ["bloom", "gptj", "gpt2", "gpt_neox", "opt", "moss", "RefinedWebModel"]
|
||||||
if compare_transformers_version("v4.28.0", op="ge"):
|
if compare_transformers_version("v4.28.0", op="ge"):
|
||||||
SUPPORTED_MODELS.append("llama")
|
SUPPORTED_MODELS.append("llama")
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ from .gpt2 import GPT2GPTQForCausalLM
|
||||||
from .llama import LlamaGPTQForCausalLM
|
from .llama import LlamaGPTQForCausalLM
|
||||||
from .moss import MOSSGPTQForCausalLM
|
from .moss import MOSSGPTQForCausalLM
|
||||||
from .opt import OPTGPTQForCausalLM
|
from .opt import OPTGPTQForCausalLM
|
||||||
|
from .rw import RWGPTQForCausalLM
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
|
|
||||||
GPTQ_CAUSAL_LM_MODEL_MAP = {
|
GPTQ_CAUSAL_LM_MODEL_MAP = {
|
||||||
|
@ -18,7 +19,8 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
|
||||||
"gpt2": GPT2GPTQForCausalLM,
|
"gpt2": GPT2GPTQForCausalLM,
|
||||||
"llama": LlamaGPTQForCausalLM,
|
"llama": LlamaGPTQForCausalLM,
|
||||||
"opt": OPTGPTQForCausalLM,
|
"opt": OPTGPTQForCausalLM,
|
||||||
"moss": MOSSGPTQForCausalLM
|
"moss": MOSSGPTQForCausalLM,
|
||||||
|
"RefinedWebModel": RWGPTQForCausalLM
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
15
auto_gptq/modeling/rw.py
Normal file
15
auto_gptq/modeling/rw.py
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
from ._base import *


class RWGPTQForCausalLM(BaseGPTQForCausalLM):
    """GPTQ causal-LM wrapper for RW ("RefinedWebModel", i.e. Falcon) checkpoints.

    Purely declarative: the attributes below name the modules inside the
    Hugging Face RW model graph. They are presumably consumed by
    BaseGPTQForCausalLM's quantization machinery — confirm against the base
    class, which is not visible in this file.
    """

    # Class name of each repeated transformer block in the HF implementation.
    layer_type = "DecoderLayer"
    # Attribute path from the model root to the list of transformer blocks.
    layers_block_name = "transformer.h"
    # Modules outside the repeated blocks (token embeddings, final layer norm);
    # these are typically excluded from quantization.
    outside_layer_modules = ["transformer.word_embeddings", "transformer.ln_f"]
    # Linear submodules inside each block, grouped per quantization step.
    # NOTE(review): the group order here likely drives the sequential
    # quantization order — do not reorder without checking the base class.
    inside_layer_modules = [
        ["self_attention.query_key_value"],
        ["self_attention.dense"],
        ["mlp.dense_h_to_4h"],
        ["mlp.dense_4h_to_h"]
    ]


__all__ = ["RWGPTQForCausalLM"]
|
Loading…
Add table
Reference in a new issue