Add support for Falcon as part of Transformers 4.33.0, including new Falcon 180B
parent 1793227283
commit 02a87dce76
3 changed files with 65 additions and 0 deletions
@@ -24,6 +24,8 @@ SUPPORTED_MODELS = [
 ]
 if compare_transformers_version("v4.28.0", op="ge"):
     SUPPORTED_MODELS.append("llama")
+if compare_transformers_version("v4.33.0", op="ge"):
+    SUPPORTED_MODELS.append("falcon")
 
 EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048
 
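Context for the gate above: transformers 4.33.0 is the first release that ships Falcon as a native "falcon" architecture, so it is only appended to SUPPORTED_MODELS when at least that version is installed. The compare_transformers_version helper itself is not part of this diff; the following is only a minimal sketch of the comparison it presumably performs (the _sketch name is hypothetical), using the packaging library:

import operator

import transformers
from packaging.version import parse as parse_version


def compare_transformers_version_sketch(version: str, op: str = "eq") -> bool:
    # Compare the installed transformers version against a tag like "v4.33.0".
    ops = {"eq": operator.eq, "lt": operator.lt, "le": operator.le,
           "gt": operator.gt, "ge": operator.ge}
    return ops[op](parse_version(transformers.__version__),
                   parse_version(version.lstrip("v")))


# Mirrors the new gate: only register Falcon on transformers >= 4.33.0
if compare_transformers_version_sketch("v4.33.0", op="ge"):
    print("falcon would be appended to SUPPORTED_MODELS")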
@@ -29,6 +29,7 @@ GPTQ_CAUSAL_LM_MODEL_MAP = {
     "codegen": CodeGenGPTQForCausalLM,
     "RefinedWebModel": RWGPTQForCausalLM,
     "RefinedWeb": RWGPTQForCausalLM,
+    "falcon": RWGPTQForCausalLM,
     "baichuan": BaiChuanGPTQForCausalLM,
     "internlm": InternLMGPTQForCausalLM,
     "qwen": QwenGPTQForCausalLM,
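Mapping "falcon" to RWGPTQForCausalLM reuses the existing RefinedWeb/RefinedWebModel implementation: Transformers 4.33.0 standardised that architecture under the model_type "falcon", which also covers the new Falcon 180B checkpoints. The loader that consumes GPTQ_CAUSAL_LM_MODEL_MAP is not shown in this diff; a rough sketch of the dispatch it presumably performs (dispatch_gptq_class is a hypothetical helper and assumes the map above is in scope):

from transformers import AutoConfig


def dispatch_gptq_class(model_name_or_path: str):
    # Read model_type from the checkpoint's config and pick the matching GPTQ
    # wrapper class, e.g. "falcon" -> RWGPTQForCausalLM.
    model_type = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True).model_type
    if model_type not in GPTQ_CAUSAL_LM_MODEL_MAP:
        raise TypeError(f"{model_type} isn't supported yet.")
    return GPTQ_CAUSAL_LM_MODEL_MAP[model_type]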
examples/basic_inference.py (new file, 62 lines)

@@ -0,0 +1,62 @@
+from transformers import AutoTokenizer, pipeline, logging
+from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
+import argparse
+
+parser = argparse.ArgumentParser(description='Simple AutoGPTQ example')
+parser.add_argument('model_name_or_path', type=str, help='Model folder or repo')
+parser.add_argument('--model_basename', type=str, help='Model file basename if model is not named gptq_model-Xb-Ygr')
+parser.add_argument('--use_slow', action="store_true", help='Use slow tokenizer')
+parser.add_argument('--use_safetensors', action="store_true", help='Load the model from a safetensors file')
+parser.add_argument('--use_triton', action="store_true", help='Use Triton for inference?')
+parser.add_argument('--bits', type=int, default=4, help='Specify GPTQ bits. Only needed if no quantize_config.json is provided')
+parser.add_argument('--group_size', type=int, default=128, help='Specify GPTQ group_size. Only needed if no quantize_config.json is provided')
+parser.add_argument('--desc_act', action="store_true", help='Specify GPTQ desc_act. Only needed if no quantize_config.json is provided')
+
+args = parser.parse_args()
+
+quantized_model_dir = args.model_name_or_path
+
+tokenizer = AutoTokenizer.from_pretrained(quantized_model_dir, use_fast=not args.use_slow)
+
+# Use the model's quantize_config.json if it ships one, otherwise build the config from CLI arguments
+try:
+    quantize_config = BaseQuantizeConfig.from_pretrained(quantized_model_dir)
+except Exception:
+    quantize_config = BaseQuantizeConfig(
+        bits=args.bits,
+        group_size=args.group_size,
+        desc_act=args.desc_act
+    )
+
+model = AutoGPTQForCausalLM.from_quantized(quantized_model_dir,
+                                           use_safetensors=args.use_safetensors,
+                                           model_basename=args.model_basename,
+                                           device="cuda:0",
+                                           use_triton=args.use_triton,
+                                           quantize_config=quantize_config)
+
+# Prevent printing spurious transformers error when using pipeline with AutoGPTQ
+logging.set_verbosity(logging.CRITICAL)
+
+prompt = "Tell me about AI"
+prompt_template = f'''### Human: {prompt}
+### Assistant:'''
+
+print("*** Pipeline:")
+pipe = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer,
+    max_new_tokens=512,
+    do_sample=True,
+    temperature=0.7,
+    top_p=0.95,
+    repetition_penalty=1.15
+)
+
+print(pipe(prompt_template)[0]['generated_text'])
+
+print("\n\n*** Generate:")
+
+input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.cuda()
+output = model.generate(inputs=input_ids, do_sample=True, temperature=0.7, max_new_tokens=512)
+print(tokenizer.decode(output[0]))
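With the "falcon" map entry above, this script can be pointed at a quantized Falcon checkpoint just like the previously supported architectures. An example invocation, where the path is only a placeholder for a local folder or Hub repo containing GPTQ weights:

    python examples/basic_inference.py /path/to/falcon-gptq --use_safetensors --use_triton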