add disable_exllama argument

This commit is contained in:
PanQiWei 2023-08-09 12:05:15 +08:00
parent 172deae049
commit db9eabfc4b
2 changed files with 10 additions and 4 deletions

View file

@@ -146,7 +146,8 @@ def load_model_tokenizer(
use_safetensors: bool = False,
use_fast_tokenizer: bool = False,
inject_fused_attention: bool = True,
inject_fused_mlp: bool = True
inject_fused_mlp: bool = True,
disable_exllama: bool = False
):
tokenizer = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=tokenizer_name_or_path or model_name_or_path,
@@ -176,7 +177,8 @@ def load_model_tokenizer(
model_basename=model_basename,
use_safetensors=use_safetensors,
trust_remote_code=trust_remote_code,
warmup_triton=False
warmup_triton=False,
disable_exllama=disable_exllama
)
return model, tokenizer
@@ -234,6 +236,7 @@ def main():
parser.add_argument("--use_triton", action="store_true")
parser.add_argument("--use_safetensors", action="store_true")
parser.add_argument("--use_fast_tokenizer", action="store_true")
parser.add_argument("--disable_exllama", action="store_true")
parser.add_argument("--no_inject_fused_attention", action="store_true")
parser.add_argument("--no_inject_fused_mlp", action="store_true")
parser.add_argument("--num_samples", type=int, default=10)
@@ -275,7 +278,8 @@ def main():
use_safetensors=args.use_safetensors,
use_fast_tokenizer=args.use_fast_tokenizer,
inject_fused_attention=not args.no_inject_fused_attention,
inject_fused_mlp=not args.no_inject_fused_mlp
inject_fused_mlp=not args.no_inject_fused_mlp,
disable_exllama=args.disable_exllama
)
end = time.time()
logger.info(f"model and tokenizer loading time: {end - start:.4f}s")

View file

@@ -37,6 +37,7 @@ if __name__ == "__main__":
parser.add_argument("--use_safetensors", action="store_true", help="Whether to use safetensors model file")
parser.add_argument("--use_fast_tokenizer", action="store_true", help="Whether to use fast tokenizer")
parser.add_argument("--trust_remote_code", action="store_true", help="Whether to use remote code")
parser.add_argument("--disable_exllama", action="store_true", help="Whether to disable the exllama kernel")
args = parser.parse_args()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -68,7 +69,8 @@ if __name__ == "__main__":
use_safetensors=args.use_safetensors,
trust_remote_code=args.trust_remote_code,
inject_fused_mlp=False,
inject_fused_attention=False
inject_fused_attention=False,
disable_exllama=args.disable_exllama
)
else:
from transformers import AutoModelForCausalLM