mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-06-09 07:07:16 -04:00
parent
7dc87984a2
commit
d37a28730d
3 changed files with 14 additions and 1 deletion
|
@ -6,6 +6,7 @@ import yaml
|
||||||
|
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
|
generation_lock = None
|
||||||
model = None
|
model = None
|
||||||
tokenizer = None
|
tokenizer = None
|
||||||
model_name = "None"
|
model_name = "None"
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import ast
|
import ast
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
|
import threading
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
@ -17,6 +18,15 @@ from modules.logging_colors import logger
|
||||||
from modules.models import clear_torch_cache, local_rank
|
from modules.models import clear_torch_cache, local_rank
|
||||||
|
|
||||||
|
|
||||||
|
def generate_reply(*args, **kwargs):
    """Stream replies from ``_generate_reply`` while holding the generation lock.

    This is a thin generator wrapper: every positional and keyword argument
    is forwarded unchanged to ``_generate_reply``, and every value that
    generator yields is re-yielded to the caller.

    ``shared.generation_lock`` is acquired when the first item is requested
    and released when this generator is exhausted, closed, or raises, so
    callers that go through this function take turns on that lock.

    Note: the ``shared`` module initialises ``generation_lock`` to ``None``
    and the startup code assigns it a ``Lock()``; calling this function
    before that assignment fails.
    """
    # Bind the lock once so the object released is always the object that
    # was acquired, even if ``shared.generation_lock`` is reassigned while
    # a generation is in progress.
    lock = shared.generation_lock
    with lock:
        for result in _generate_reply(*args, **kwargs):
            yield result
|
||||||
|
|
||||||
|
|
||||||
def get_max_prompt_length(state):
|
def get_max_prompt_length(state):
|
||||||
max_length = state['truncation_length'] - state['max_new_tokens']
|
max_length = state['truncation_length'] - state['max_new_tokens']
|
||||||
if shared.soft_prompt:
|
if shared.soft_prompt:
|
||||||
|
@ -154,7 +164,7 @@ def generate_reply_wrapper(question, state, eos_token=None, stopping_strings=Non
|
||||||
yield formatted_outputs(reply, shared.model_name)
|
yield formatted_outputs(reply, shared.model_name)
|
||||||
|
|
||||||
|
|
||||||
def generate_reply(question, state, eos_token=None, stopping_strings=None, is_chat=False):
|
def _generate_reply(question, state, eos_token=None, stopping_strings=None, is_chat=False):
|
||||||
state = apply_extensions('state', state)
|
state = apply_extensions('state', state)
|
||||||
generate_func = apply_extensions('custom_generate_reply')
|
generate_func = apply_extensions('custom_generate_reply')
|
||||||
if generate_func is None:
|
if generate_func is None:
|
||||||
|
|
|
@ -38,6 +38,7 @@ import zipfile
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
|
@ -1075,6 +1076,7 @@ if __name__ == "__main__":
|
||||||
'instruction_template': shared.settings['instruction_template']
|
'instruction_template': shared.settings['instruction_template']
|
||||||
})
|
})
|
||||||
|
|
||||||
|
shared.generation_lock = Lock()
|
||||||
# Launch the web UI
|
# Launch the web UI
|
||||||
create_interface()
|
create_interface()
|
||||||
while True:
|
while True:
|
||||||
|
|
Loading…
Add table
Reference in a new issue