Beginning of multi-user support (#2262)

Adds a lock to generate_reply
This commit is contained in:
flurb18 2023-05-24 08:38:20 -04:00 committed by GitHub
parent 7dc87984a2
commit d37a28730d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 14 additions and 1 deletions

View file

@ -6,6 +6,7 @@ import yaml
from modules.logging_colors import logger from modules.logging_colors import logger
generation_lock = None
model = None model = None
tokenizer = None tokenizer = None
model_name = "None" model_name = "None"

View file

@ -1,6 +1,7 @@
import ast import ast
import random import random
import re import re
import threading
import time import time
import traceback import traceback
@ -17,6 +18,15 @@ from modules.logging_colors import logger
from modules.models import clear_torch_cache, local_rank from modules.models import clear_torch_cache, local_rank
def generate_reply(*args, **kwargs):
shared.generation_lock.acquire()
try:
for result in _generate_reply(*args, **kwargs):
yield result
finally:
shared.generation_lock.release()
def get_max_prompt_length(state): def get_max_prompt_length(state):
max_length = state['truncation_length'] - state['max_new_tokens'] max_length = state['truncation_length'] - state['max_new_tokens']
if shared.soft_prompt: if shared.soft_prompt:
@ -154,7 +164,7 @@ def generate_reply_wrapper(question, state, eos_token=None, stopping_strings=Non
yield formatted_outputs(reply, shared.model_name) yield formatted_outputs(reply, shared.model_name)
def generate_reply(question, state, eos_token=None, stopping_strings=None, is_chat=False): def _generate_reply(question, state, eos_token=None, stopping_strings=None, is_chat=False):
state = apply_extensions('state', state) state = apply_extensions('state', state)
generate_func = apply_extensions('custom_generate_reply') generate_func = apply_extensions('custom_generate_reply')
if generate_func is None: if generate_func is None:

View file

@ -38,6 +38,7 @@ import zipfile
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from threading import Lock
import psutil import psutil
import torch import torch
@ -1075,6 +1076,7 @@ if __name__ == "__main__":
'instruction_template': shared.settings['instruction_template'] 'instruction_template': shared.settings['instruction_template']
}) })
shared.generation_lock = Lock()
# Launch the web UI # Launch the web UI
create_interface() create_interface()
while True: while True: