text-generation-webui-mirror/modules/llama_cpp_python_hijack.py

69 lines
2 KiB
Python

import importlib
import platform
from modules import shared
from modules.cache_utils import process_llamacpp_cache
imported_module = None
def llama_cpp_lib():
global imported_module
# Determine the platform
is_macos = platform.system() == 'Darwin'
# Define the library names based on the platform
if is_macos:
lib_names = [
(None, 'llama_cpp')
]
else:
lib_names = [
('cpu', 'llama_cpp'),
('tensorcores', 'llama_cpp_cuda_tensorcores'),
(None, 'llama_cpp_cuda'),
(None, 'llama_cpp')
]
for arg, lib_name in lib_names:
should_import = (arg is None or getattr(shared.args, arg))
if should_import:
if imported_module and imported_module != lib_name:
# Conflict detected, raise an exception
raise Exception(f"Cannot import `{lib_name}` because `{imported_module}` is already imported. Switching to a different version of llama-cpp-python currently requires a server restart.")
try:
return_lib = importlib.import_module(lib_name)
imported_module = lib_name
monkey_patch_llama_cpp_python(return_lib)
return return_lib
except ImportError:
continue
return None
def monkey_patch_llama_cpp_python(lib):
if getattr(lib.Llama, '_is_patched', False):
# If the patch is already applied, do nothing
return
def my_generate(self, *args, **kwargs):
if shared.args.streaming_llm:
new_sequence = args[0]
past_sequence = self._input_ids
# Do the cache trimming for StreamingLLM
process_llamacpp_cache(self, new_sequence, past_sequence)
for output in self.original_generate(*args, **kwargs):
yield output
lib.Llama.original_generate = lib.Llama.generate
lib.Llama.generate = my_generate
# Set the flag to indicate that the patch has been applied
lib.Llama._is_patched = True