From b2a2ddcb15d5aa993691591104a1c1e54cc11590 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 10 Jan 2023 23:39:50 -0300
Subject: [PATCH] Remove T5 support (it sucks)

---
 server.py | 13 ++-----------
 1 file changed, 2 insertions(+), 11 deletions(-)

diff --git a/server.py b/server.py
index f066f418..5c994306 100644
--- a/server.py
+++ b/server.py
@@ -8,8 +8,7 @@ from pathlib import Path
 import gradio as gr
 import transformers
 from html_generator import *
-from transformers import AutoTokenizer, T5Tokenizer
-from transformers import AutoModelForCausalLM, T5ForConditionalGeneration
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 parser = argparse.ArgumentParser()
@@ -37,8 +36,6 @@ def load_model(model_name):
             model = torch.load(Path(f"torch-dumps/{model_name}.pt"))
         elif model_name.lower().startswith(('gpt-neo', 'opt-', 'galactica')) and any(size in model_name.lower() for size in ('13b', '20b', '30b')):
             model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True)
-        elif model_name in ['flan-t5', 't5-large']:
-            model = T5ForConditionalGeneration.from_pretrained(Path(f"models/{model_name}")).cuda()
         else:
             model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
 
@@ -46,11 +43,7 @@
     else:
         settings = ["low_cpu_mem_usage=True"]
         cuda = ""
-
-        if model_name in ['flan-t5', 't5-large']:
-            command = f"T5ForConditionalGeneration.from_pretrained"
-        else:
-            command = "AutoModelForCausalLM.from_pretrained"
+        command = "AutoModelForCausalLM.from_pretrained"
 
         if args.cpu:
             settings.append("torch_dtype=torch.float32")
@@ -72,8 +65,6 @@ def load_model(model_name):
     # Loading the tokenizer
     if model_name.lower().startswith('gpt4chan') and Path(f"models/gpt-j-6B/").exists():
         tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
-    elif model_name in ['flan-t5', 't5-large']:
-        tokenizer = T5Tokenizer.from_pretrained(Path(f"models/{model_name}/"))
     else:
         tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{model_name}/"))
 