add disable_exllama argument
This commit is contained in:
parent
172deae049
commit
db9eabfc4b
2 changed files with 10 additions and 4 deletions
|
@ -146,7 +146,8 @@ def load_model_tokenizer(
|
|||
use_safetensors: bool = False,
|
||||
use_fast_tokenizer: bool = False,
|
||||
inject_fused_attention: bool = True,
|
||||
inject_fused_mlp: bool = True
|
||||
inject_fused_mlp: bool = True,
|
||||
disable_exllama: bool = False
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
pretrained_model_name_or_path=tokenizer_name_or_path or model_name_or_path,
|
||||
|
@ -176,7 +177,8 @@ def load_model_tokenizer(
|
|||
model_basename=model_basename,
|
||||
use_safetensors=use_safetensors,
|
||||
trust_remote_code=trust_remote_code,
|
||||
warmup_triton=False
|
||||
warmup_triton=False,
|
||||
disable_exllama=disable_exllama
|
||||
)
|
||||
|
||||
return model, tokenizer
|
||||
|
@ -234,6 +236,7 @@ def main():
|
|||
parser.add_argument("--use_triton", action="store_true")
|
||||
parser.add_argument("--use_safetensors", action="store_true")
|
||||
parser.add_argument("--use_fast_tokenizer", action="store_true")
|
||||
parser.add_argument("--disable_exllama", action="store_true")
|
||||
parser.add_argument("--no_inject_fused_attention", action="store_true")
|
||||
parser.add_argument("--no_inject_fused_mlp", action="store_true")
|
||||
parser.add_argument("--num_samples", type=int, default=10)
|
||||
|
@ -275,7 +278,8 @@ def main():
|
|||
use_safetensors=args.use_safetensors,
|
||||
use_fast_tokenizer=args.use_fast_tokenizer,
|
||||
inject_fused_attention=not args.no_inject_fused_attention,
|
||||
inject_fused_mlp=not args.no_inject_fused_mlp
|
||||
inject_fused_mlp=not args.no_inject_fused_mlp,
|
||||
disable_exllama=args.disable_exllama
|
||||
)
|
||||
end = time.time()
|
||||
logger.info(f"model and tokenizer loading time: {end - start:.4f}s")
|
||||
|
|
|
@ -37,6 +37,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--use_safetensors", action="store_true", help="Whether to use safetensors model file")
|
||||
parser.add_argument("--use_fast_tokenizer", action="store_true", help="Wheter to use fast tokenizer")
|
||||
parser.add_argument("--trust_remote_code", action="store_true", help="Whether to use remote code")
|
||||
parser.add_argument("--disable_exllama", action="store_true", help="Whether to use disable exllama kernel")
|
||||
args = parser.parse_args()
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
@ -68,7 +69,8 @@ if __name__ == "__main__":
|
|||
use_safetensors=args.use_safetensors,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
inject_fused_mlp=False,
|
||||
inject_fused_attention=False
|
||||
inject_fused_attention=False,
|
||||
disable_exllama=args.disable_exllama
|
||||
)
|
||||
else:
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
|
Loading…
Add table
Reference in a new issue