# text-generation-webui-mirror/modules/models.py
import sys
import time
from pathlib import Path

import modules.shared as shared
from modules.logging_colors import logger
from modules.models_settings import get_model_metadata

last_generation_time = time.time()
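# NOTE: last_generation_time is only read here; it is presumably refreshed by
# the generation path elsewhere in the project (an assumption based on its use
# in unload_model_if_idle below).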

def load_model(model_name, loader=None):
    logger.info(f"Loading \"{model_name}\"")
    t0 = time.time()

    shared.is_seq2seq = False
    shared.model_name = model_name
    load_func_map = {
        'llama.cpp': llama_cpp_server_loader,
        'Transformers': transformers_loader,
        'ExLlamav3_HF': ExLlamav3_HF_loader,
        'ExLlamav2_HF': ExLlamav2_HF_loader,
        'ExLlamav2': ExLlamav2_loader,
        'TensorRT-LLM': TensorRT_LLM_loader,
    }

    metadata = get_model_metadata(model_name)

    # Resolve the loader: explicit argument > command-line flag > model metadata.
    if loader is None:
        if shared.args.loader is not None:
            loader = shared.args.loader
        else:
            loader = metadata['loader']

    if loader is None:
        logger.error('The path to the model does not exist. Exiting.')
        raise ValueError

    # llama.cpp samples in its own server process; the other backends need the
    # transformers sampler hijack, applied once per process.
    if loader != 'llama.cpp' and 'sampler_hijack' not in sys.modules:
        from modules import sampler_hijack
        sampler_hijack.hijack_samplers()

    shared.args.loader = loader
    output = load_func_map[loader](model_name)
    if type(output) is tuple:
        model, tokenizer = output
    else:
        model = output
        if model is None:
            return None, None
        else:
            # Loaders that return a bare model get a Transformers tokenizer.
            from modules.transformers_loader import load_tokenizer
            tokenizer = load_tokenizer(model_name)

    shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
    if loader.lower().startswith('exllama') or loader.lower().startswith('tensorrt') or loader == 'llama.cpp':
        shared.settings['truncation_length'] = shared.args.ctx_size

    logger.info(f"Loaded \"{model_name}\" in {(time.time()-t0):.2f} seconds.")
    logger.info(f"LOADER: \"{loader}\"")
    logger.info(f"TRUNCATION LENGTH: {shared.settings['truncation_length']}")
    logger.info(f"INSTRUCTION TEMPLATE: \"{metadata['instruction_template']}\"")

    return model, tokenizer
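
# A minimal usage sketch (assumes the rest of the app has initialized
# shared.args, e.g. via the server startup code; 'MyModel' is a placeholder
# directory or GGUF file under shared.args.model_dir):
#
#     shared.model, shared.tokenizer = load_model('MyModel', loader='llama.cpp')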

def llama_cpp_server_loader(model_name):
    from modules.llama_cpp_server import LlamaServer

    # model_name may be a single GGUF file or a folder; for a folder, take the
    # first *.gguf in sorted order (the first shard of a split model).
    path = Path(f'{shared.args.model_dir}/{model_name}')
    if path.is_file():
        model_file = path
    else:
        model_file = sorted(Path(f'{shared.args.model_dir}/{model_name}').glob('*.gguf'))[0]

    try:
        model = LlamaServer(model_file)
        # The server object doubles as the tokenizer.
        return model, model
    except Exception as e:
        logger.error(f"Error loading the model with llama.cpp: {str(e)}")
        # Falls through to an implicit None, which load_model() turns into (None, None).

def transformers_loader(model_name):
    from modules.transformers_loader import load_model_HF
    return load_model_HF(model_name)


def ExLlamav3_HF_loader(model_name):
    from modules.exllamav3_hf import Exllamav3HF
    return Exllamav3HF.from_pretrained(model_name)


def ExLlamav2_HF_loader(model_name):
    from modules.exllamav2_hf import Exllamav2HF
    return Exllamav2HF.from_pretrained(model_name)


def ExLlamav2_loader(model_name):
    from modules.exllamav2 import Exllamav2Model
    model, tokenizer = Exllamav2Model.from_pretrained(model_name)
    return model, tokenizer


def TensorRT_LLM_loader(model_name):
    try:
        from modules.tensorrt_llm import TensorRTLLMModel
    except ModuleNotFoundError:
        raise ModuleNotFoundError("Failed to import 'tensorrt_llm'. Please install it manually following the instructions in the TensorRT-LLM GitHub repository.")

    model = TensorRTLLMModel.from_pretrained(model_name)
    return model
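
# Note that TensorRT_LLM_loader returns a bare model rather than a
# (model, tokenizer) tuple, so load_model() falls back to loading a
# Transformers tokenizer for it via load_tokenizer().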

def unload_model(keep_model_name=False):
    if shared.model is None:
        return

    is_llamacpp = (shared.model.__class__.__name__ == 'LlamaServer')
    if shared.model.__class__.__name__ == 'Exllamav3HF':
        shared.model.unload()

    shared.model = shared.tokenizer = None
    shared.lora_names = []
    shared.model_dirty_from_training = False

    # llama.cpp runs as a separate server process, so there is no in-process
    # PyTorch cache to clear for it.
    if not is_llamacpp:
        from modules.torch_utils import clear_torch_cache
        clear_torch_cache()

    if not keep_model_name:
        shared.model_name = 'None'

def reload_model():
    unload_model()
    shared.model, shared.tokenizer = load_model(shared.model_name)
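
# Example (a sketch): apply a new context size to the current model by
# reloading it; shared.args.ctx_size is the flag load_model() reads above.
#
#     shared.args.ctx_size = 8192
#     reload_model()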

def unload_model_if_idle():
    global last_generation_time

    logger.info(f"Setting a timeout of {shared.args.idle_timeout} minutes to unload the model in case of inactivity.")

    while True:
        # Holding the generation lock guarantees the model is never unloaded
        # in the middle of a generation.
        shared.generation_lock.acquire()
        try:
            if time.time() - last_generation_time > shared.args.idle_timeout * 60:
                if shared.model is not None:
                    logger.info("Unloading the model due to inactivity.")
                    unload_model(keep_model_name=True)
        finally:
            shared.generation_lock.release()

        # Poll once a minute.
        time.sleep(60)
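
# A minimal sketch of how this watcher would be started (the real wiring lives
# in the server startup code; the exact call site is an assumption here):
#
#     import threading
#     if shared.args.idle_timeout > 0:
#         threading.Thread(target=unload_model_if_idle, daemon=True).start()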