Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2025-06-07 06:06:20 -04:00)

Compare commits: 30 commits, acd91a35d7 ... 200e799603
Commits:

- 200e799603
- b30a73016d
- 7278548cd1
- bb409c926e
- fe7e1a2565
- e8595730b4
- 17c29fa0a2
- dc3094549e
- ace8afb825
- a41da1ec95
- 6e6f9971a2
- 1180bb0d80
- 9bb9ce079e
- 1aa76b3beb
- 1df2b0d3ae
- 62455b415c
- 022664f2bd
- a778270536
- c19b995b8e
- b1495d52e5
- 44a6d8a761
- 4fa52a1302
- 4eecb6611f
- c5e54c0b37
- 14e6baeb48
- bb1905ebc5
- 9b80d1d6c2
- 80cdbe4e09
- 769eee1ff3
- 7c883ef2f0
11 changed files with 318 additions and 288 deletions
README.md (12 changed lines)
```diff
@@ -325,6 +325,18 @@ https://github.com/oobabooga/text-generation-webui/wiki
 ## Downloading models
 
+### Pointing to an existing AI model library
+
+Edit the file `text-generation-webui\user_data\CMD_FLAGS.txt` to include this line:
+
+```
+--model-dir 'D:\MyAIModels\'
+```
+
+Replace `D:\MyAIModels\` with the path to your model library folder. Sub-folders will be automatically parsed to enumerate all existing models.
+
+### Manual model download
+
 Models should be placed in the folder `text-generation-webui/user_data/models`. They are usually downloaded from [Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads).
 
 * GGUF models are a single file and should be placed directly into `user_data/models`. Example:
```
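For illustration only (the folder and file names below are hypothetical), a model library passed via `--model-dir` might be laid out like this; loose GGUF files and model sub-folders are both picked up when the directory is scanned:

```
D:\MyAIModels\
  some-model.Q4_K_M.gguf
  some-transformers-model\
    config.json
    model.safetensors
```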
```diff
@@ -229,10 +229,23 @@ function removeLastClick() {
     document.getElementById("Remove-last").click();
 }
 
-function handleMorphdomUpdate(text) {
+function handleMorphdomUpdate(data) {
+    // Determine target element and use it as query scope
+    var target_element, target_html;
+    if (data.last_message_only) {
+        const childNodes = document.getElementsByClassName("messages")[0].childNodes;
+        target_element = childNodes[childNodes.length - 1];
+        target_html = data.html;
+    } else {
+        target_element = document.getElementById("chat").parentNode;
+        target_html = "<div class=\"prose svelte-1ybaih5\">" + data.html + "</div>";
+    }
+
+    const queryScope = target_element;
+
     // Track open blocks
     const openBlocks = new Set();
-    document.querySelectorAll(".thinking-block").forEach(block => {
+    queryScope.querySelectorAll(".thinking-block").forEach(block => {
         const blockId = block.getAttribute("data-block-id");
         // If block exists and is open, add to open set
         if (blockId && block.hasAttribute("open")) {
@@ -242,7 +255,7 @@ function handleMorphdomUpdate(text) {
 
     // Store scroll positions for any open blocks
     const scrollPositions = {};
-    document.querySelectorAll(".thinking-block[open]").forEach(block => {
+    queryScope.querySelectorAll(".thinking-block[open]").forEach(block => {
         const content = block.querySelector(".thinking-content");
         const blockId = block.getAttribute("data-block-id");
         if (content && blockId) {
@@ -255,8 +268,8 @@ function handleMorphdomUpdate(text) {
     });
 
     morphdom(
-        document.getElementById("chat").parentNode,
-        "<div class=\"prose svelte-1ybaih5\">" + text + "</div>",
+        target_element,
+        target_html,
         {
             onBeforeElUpdated: function(fromEl, toEl) {
                 // Preserve code highlighting
@@ -307,7 +320,7 @@ function handleMorphdomUpdate(text) {
     );
 
     // Add toggle listeners for new blocks
-    document.querySelectorAll(".thinking-block").forEach(block => {
+    queryScope.querySelectorAll(".thinking-block").forEach(block => {
         if (!block._hasToggleListener) {
             block.addEventListener("toggle", function(e) {
                 if (this.open) {
```
```diff
@@ -656,7 +656,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
             update_message_metadata(output['metadata'], "user", row_idx, timestamp=get_current_timestamp())
 
         # *Is typing...*
-        if loading_message:
+        if loading_message and shared.processing_message:
             yield {
                 'visible': output['visible'][:-1] + [[output['visible'][-1][0], shared.processing_message]],
                 'internal': output['internal'],
@@ -680,7 +680,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
                 })
                 output['metadata'][key]["current_version_index"] = len(output['metadata'][key]["versions"]) - 1
 
-        if loading_message:
+        if loading_message and shared.processing_message:
             yield {
                 'visible': output['visible'][:-1] + [[visible_text, shared.processing_message]],
                 'internal': output['internal'][:-1] + [[text, '']],
@@ -825,7 +825,9 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
     last_save_time = time.monotonic()
     save_interval = 8
     for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True, for_ui=True)):
-        yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']), history
+        yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'], last_message_only=(i > 0)), history
+        if i == 0:
+            time.sleep(0.125)  # We need this to make sure the first update goes through
 
         current_time = time.monotonic()
         # Save on first iteration or if save_interval seconds have passed
```
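A minimal, self-contained sketch of the save-throttling pattern used in `generate_chat_reply_wrapper` above (everything except `save_interval` is an illustrative name, not the repo's API):

```python
import time


def autosave_loop(updates, save, save_interval=8):
    """Call `save` on the first update, then at most once every `save_interval` seconds."""
    last_save_time = time.monotonic()
    for i, update in enumerate(updates):
        current_time = time.monotonic()
        if i == 0 or (current_time - last_save_time) >= save_interval:
            save(update)
            last_save_time = current_time


# Example: a long stream of updates triggers only a handful of saves.
saved = []
autosave_loop(range(1000), saved.append)
print(len(saved))  # 1 here, because the whole loop finishes in well under 8 seconds
```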
```diff
@@ -462,64 +462,69 @@ def actions_html(history, i, role, info_message=""):
             f'{version_nav_html}')
 
 
-def generate_instruct_html(history):
-    output = f'<style>{instruct_css}</style><div class="chat" id="chat" data-mode="instruct"><div class="messages">'
+def generate_instruct_html(history, last_message_only=False):
+    if not last_message_only:
+        output = f'<style>{instruct_css}</style><div class="chat" id="chat" data-mode="instruct"><div class="messages">'
+    else:
+        output = ""
 
-    for i in range(len(history['visible'])):
-        row_visible = history['visible'][i]
-        row_internal = history['internal'][i]
-        converted_visible = [convert_to_markdown_wrapped(entry, message_id=i, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
+    def create_message(role, content, raw_content):
+        """Inner function that captures variables from outer scope."""
+        class_name = "user-message" if role == "user" else "assistant-message"
 
-        # Get timestamps
-        user_timestamp = format_message_timestamp(history, "user", i)
-        assistant_timestamp = format_message_timestamp(history, "assistant", i)
+        # Get role-specific data
+        timestamp = format_message_timestamp(history, role, i)
+        attachments = format_message_attachments(history, role, i)
 
-        # Get attachments
-        user_attachments = format_message_attachments(history, "user", i)
-        assistant_attachments = format_message_attachments(history, "assistant", i)
+        # Create info button if timestamp exists
+        info_message = ""
+        if timestamp:
+            tooltip_text = get_message_tooltip(history, role, i)
+            info_message = info_button.replace('title="message"', f'title="{html.escape(tooltip_text)}"')
 
-        # Create info buttons for timestamps if they exist
-        info_message_user = ""
-        if user_timestamp != "":
-            tooltip_text = get_message_tooltip(history, "user", i)
-            info_message_user = info_button.replace('title="message"', f'title="{html.escape(tooltip_text)}"')
-
-        info_message_assistant = ""
-        if assistant_timestamp != "":
-            tooltip_text = get_message_tooltip(history, "assistant", i)
-            info_message_assistant = info_button.replace('title="message"', f'title="{html.escape(tooltip_text)}"')
-
-        if converted_visible[0]:  # Don't display empty user messages
-            output += (
-                f'<div class="user-message" '
-                f'data-raw="{html.escape(row_internal[0], quote=True)}"'
-                f'data-index={i}>'
-                f'<div class="text">'
-                f'<div class="message-body">{converted_visible[0]}</div>'
-                f'{user_attachments}'
-                f'{actions_html(history, i, "user", info_message_user)}'
-                f'</div>'
-                f'</div>'
-            )
-
-        output += (
-            f'<div class="assistant-message" '
-            f'data-raw="{html.escape(row_internal[1], quote=True)}"'
+        return (
+            f'<div class="{class_name}" '
+            f'data-raw="{html.escape(raw_content, quote=True)}"'
             f'data-index={i}>'
             f'<div class="text">'
-            f'<div class="message-body">{converted_visible[1]}</div>'
-            f'{assistant_attachments}'
-            f'{actions_html(history, i, "assistant", info_message_assistant)}'
+            f'<div class="message-body">{content}</div>'
+            f'{attachments}'
+            f'{actions_html(history, i, role, info_message)}'
             f'</div>'
             f'</div>'
         )
 
-    output += "</div></div>"
+    # Determine range
+    start_idx = len(history['visible']) - 1 if last_message_only else 0
+    end_idx = len(history['visible'])
+
+    for i in range(start_idx, end_idx):
+        row_visible = history['visible'][i]
+        row_internal = history['internal'][i]
+
+        # Convert content
+        if last_message_only:
+            converted_visible = [None, convert_to_markdown_wrapped(row_visible[1], message_id=i, use_cache=i != len(history['visible']) - 1)]
+        else:
+            converted_visible = [convert_to_markdown_wrapped(entry, message_id=i, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
+
+        # Generate messages
+        if not last_message_only and converted_visible[0]:
+            output += create_message("user", converted_visible[0], row_internal[0])
+
+        output += create_message("assistant", converted_visible[1], row_internal[1])
+
+    if not last_message_only:
+        output += "</div></div>"
 
     return output
 
 
-def generate_cai_chat_html(history, name1, name2, style, character, reset_cache=False):
-    output = f'<style>{chat_styles[style]}</style><div class="chat" id="chat"><div class="messages">'
+def generate_cai_chat_html(history, name1, name2, style, character, reset_cache=False, last_message_only=False):
+    if not last_message_only:
+        output = f'<style>{chat_styles[style]}</style><div class="chat" id="chat"><div class="messages">'
+    else:
+        output = ""
 
     # We use ?character and ?time.time() to force the browser to reset caches
     img_bot = (
```
```diff
@@ -527,110 +532,117 @@ def generate_cai_chat_html(history, name1, name2, style, character, reset_cache=
         if Path("user_data/cache/pfp_character_thumb.png").exists() else ''
     )
 
-    img_me = (
-        f'<img src="file/user_data/cache/pfp_me.png?{time.time() if reset_cache else ""}">'
-        if Path("user_data/cache/pfp_me.png").exists() else ''
-    )
+    def create_message(role, content, raw_content):
+        """Inner function for CAI-style messages."""
+        circle_class = "circle-you" if role == "user" else "circle-bot"
+        name = name1 if role == "user" else name2
 
-    for i in range(len(history['visible'])):
-        row_visible = history['visible'][i]
-        row_internal = history['internal'][i]
-        converted_visible = [convert_to_markdown_wrapped(entry, message_id=i, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
+        # Get role-specific data
+        timestamp = format_message_timestamp(history, role, i, tooltip_include_timestamp=False)
+        attachments = format_message_attachments(history, role, i)
 
-        # Get timestamps
-        user_timestamp = format_message_timestamp(history, "user", i, tooltip_include_timestamp=False)
-        assistant_timestamp = format_message_timestamp(history, "assistant", i, tooltip_include_timestamp=False)
+        # Get appropriate image
+        if role == "user":
+            img = (f'<img src="file/user_data/cache/pfp_me.png?{time.time() if reset_cache else ""}">'
+                   if Path("user_data/cache/pfp_me.png").exists() else '')
+        else:
+            img = img_bot
 
-        # Get attachments
-        user_attachments = format_message_attachments(history, "user", i)
-        assistant_attachments = format_message_attachments(history, "assistant", i)
-
-        if converted_visible[0]:  # Don't display empty user messages
-            output += (
-                f'<div class="message" '
-                f'data-raw="{html.escape(row_internal[0], quote=True)}"'
-                f'data-index={i}>'
-                f'<div class="circle-you">{img_me}</div>'
-                f'<div class="text">'
-                f'<div class="username">{name1}{user_timestamp}</div>'
-                f'<div class="message-body">{converted_visible[0]}</div>'
-                f'{user_attachments}'
-                f'{actions_html(history, i, "user")}'
-                f'</div>'
-                f'</div>'
-            )
-
-        output += (
+        return (
             f'<div class="message" '
-            f'data-raw="{html.escape(row_internal[1], quote=True)}"'
+            f'data-raw="{html.escape(raw_content, quote=True)}"'
             f'data-index={i}>'
-            f'<div class="circle-bot">{img_bot}</div>'
+            f'<div class="{circle_class}">{img}</div>'
             f'<div class="text">'
-            f'<div class="username">{name2}{assistant_timestamp}</div>'
-            f'<div class="message-body">{converted_visible[1]}</div>'
-            f'{assistant_attachments}'
-            f'{actions_html(history, i, "assistant")}'
+            f'<div class="username">{name}{timestamp}</div>'
+            f'<div class="message-body">{content}</div>'
+            f'{attachments}'
+            f'{actions_html(history, i, role)}'
             f'</div>'
             f'</div>'
         )
 
-    output += "</div></div>"
+    # Determine range
+    start_idx = len(history['visible']) - 1 if last_message_only else 0
+    end_idx = len(history['visible'])
+
+    for i in range(start_idx, end_idx):
+        row_visible = history['visible'][i]
+        row_internal = history['internal'][i]
+
+        # Convert content
+        if last_message_only:
+            converted_visible = [None, convert_to_markdown_wrapped(row_visible[1], message_id=i, use_cache=i != len(history['visible']) - 1)]
+        else:
+            converted_visible = [convert_to_markdown_wrapped(entry, message_id=i, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
+
+        # Generate messages
+        if not last_message_only and converted_visible[0]:
+            output += create_message("user", converted_visible[0], row_internal[0])
+
+        output += create_message("assistant", converted_visible[1], row_internal[1])
+
+    if not last_message_only:
+        output += "</div></div>"
 
     return output
 
 
-def generate_chat_html(history, name1, name2, reset_cache=False):
-    output = f'<style>{chat_styles["wpp"]}</style><div class="chat" id="chat"><div class="messages">'
+def generate_chat_html(history, name1, name2, reset_cache=False, last_message_only=False):
+    if not last_message_only:
+        output = f'<style>{chat_styles["wpp"]}</style><div class="chat" id="chat"><div class="messages">'
+    else:
+        output = ""
 
-    for i in range(len(history['visible'])):
-        row_visible = history['visible'][i]
-        row_internal = history['internal'][i]
-        converted_visible = [convert_to_markdown_wrapped(entry, message_id=i, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
+    def create_message(role, content, raw_content):
+        """Inner function for WPP-style messages."""
+        text_class = "text-you" if role == "user" else "text-bot"
 
-        # Get timestamps
-        user_timestamp = format_message_timestamp(history, "user", i)
-        assistant_timestamp = format_message_timestamp(history, "assistant", i)
+        # Get role-specific data
+        timestamp = format_message_timestamp(history, role, i)
+        attachments = format_message_attachments(history, role, i)
 
-        # Get attachments
-        user_attachments = format_message_attachments(history, "user", i)
-        assistant_attachments = format_message_attachments(history, "assistant", i)
+        # Create info button if timestamp exists
+        info_message = ""
+        if timestamp:
+            tooltip_text = get_message_tooltip(history, role, i)
+            info_message = info_button.replace('title="message"', f'title="{html.escape(tooltip_text)}"')
 
-        # Create info buttons for timestamps if they exist
-        info_message_user = ""
-        if user_timestamp != "":
-            tooltip_text = get_message_tooltip(history, "user", i)
-            info_message_user = info_button.replace('title="message"', f'title="{html.escape(tooltip_text)}"')
-
-        info_message_assistant = ""
-        if assistant_timestamp != "":
-            tooltip_text = get_message_tooltip(history, "assistant", i)
-            info_message_assistant = info_button.replace('title="message"', f'title="{html.escape(tooltip_text)}"')
-
-        if converted_visible[0]:  # Don't display empty user messages
-            output += (
-                f'<div class="message" '
-                f'data-raw="{html.escape(row_internal[0], quote=True)}"'
-                f'data-index={i}>'
-                f'<div class="text-you">'
-                f'<div class="message-body">{converted_visible[0]}</div>'
-                f'{user_attachments}'
-                f'{actions_html(history, i, "user", info_message_user)}'
-                f'</div>'
-                f'</div>'
-            )
-
-        output += (
+        return (
             f'<div class="message" '
-            f'data-raw="{html.escape(row_internal[1], quote=True)}"'
+            f'data-raw="{html.escape(raw_content, quote=True)}"'
             f'data-index={i}>'
-            f'<div class="text-bot">'
-            f'<div class="message-body">{converted_visible[1]}</div>'
-            f'{assistant_attachments}'
-            f'{actions_html(history, i, "assistant", info_message_assistant)}'
+            f'<div class="{text_class}">'
+            f'<div class="message-body">{content}</div>'
+            f'{attachments}'
+            f'{actions_html(history, i, role, info_message)}'
             f'</div>'
             f'</div>'
         )
 
-    output += "</div></div>"
+    # Determine range
+    start_idx = len(history['visible']) - 1 if last_message_only else 0
+    end_idx = len(history['visible'])
+
+    for i in range(start_idx, end_idx):
+        row_visible = history['visible'][i]
+        row_internal = history['internal'][i]
+
+        # Convert content
+        if last_message_only:
+            converted_visible = [None, convert_to_markdown_wrapped(row_visible[1], message_id=i, use_cache=i != len(history['visible']) - 1)]
+        else:
+            converted_visible = [convert_to_markdown_wrapped(entry, message_id=i, use_cache=i != len(history['visible']) - 1) for entry in row_visible]
+
+        # Generate messages
+        if not last_message_only and converted_visible[0]:
+            output += create_message("user", converted_visible[0], row_internal[0])
+
+        output += create_message("assistant", converted_visible[1], row_internal[1])
+
+    if not last_message_only:
+        output += "</div></div>"
 
     return output
```
```diff
@@ -644,15 +656,15 @@ def time_greeting():
         return "Good evening!"
 
 
-def chat_html_wrapper(history, name1, name2, mode, style, character, reset_cache=False):
+def chat_html_wrapper(history, name1, name2, mode, style, character, reset_cache=False, last_message_only=False):
     if len(history['visible']) == 0:
         greeting = f"<div class=\"welcome-greeting\">{time_greeting()} How can I help you today?</div>"
         result = f'<div class="chat" id="chat">{greeting}</div>'
     elif mode == 'instruct':
-        result = generate_instruct_html(history)
+        result = generate_instruct_html(history, last_message_only=last_message_only)
     elif style == 'wpp':
-        result = generate_chat_html(history, name1, name2)
+        result = generate_chat_html(history, name1, name2, last_message_only=last_message_only)
     else:
-        result = generate_cai_chat_html(history, name1, name2, style, character, reset_cache)
+        result = generate_cai_chat_html(history, name1, name2, style, character, reset_cache=reset_cache, last_message_only=last_message_only)
 
-    return {'html': result}
+    return {'html': result, 'last_message_only': last_message_only}
```
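To make the new contract concrete, here is a hedged, self-contained sketch: `chat_html_wrapper` now tags its payload with `last_message_only`, and the frontend's `handleMorphdomUpdate(data)` uses that flag to morph either the whole chat container or only the last message. The renderer below is a stub written for this sketch, not the repo's implementation.

```python
def chat_html_wrapper_stub(history, last_message_only=False):
    """Stub mirroring the return shape {'html': ..., 'last_message_only': ...}."""
    if last_message_only:
        html = f"<div class=\"assistant-message\">{history['visible'][-1][1]}</div>"
    else:
        messages = "".join(
            f"<div class=\"user-message\">{user}</div><div class=\"assistant-message\">{bot}</div>"
            for user, bot in history['visible']
        )
        html = f'<div class="chat" id="chat"><div class="messages">{messages}</div></div>'
    return {'html': html, 'last_message_only': last_message_only}


history = {'internal': [['hi', 'hello']], 'visible': [['hi', 'hello']], 'metadata': {}}
first = chat_html_wrapper_stub(history)                          # i == 0: full page render
later = chat_html_wrapper_stub(history, last_message_only=True)  # i > 0: last message only
print(first['last_message_only'], later['last_message_only'])
```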
```diff
@@ -21,7 +21,7 @@ lora_names = []
 # Generation variables
 stop_everything = False
 generation_lock = None
-processing_message = '*Is typing...*'
+processing_message = ''
 
 # UI variables
 gradio = {}
@@ -47,7 +47,6 @@ settings = {
     'max_new_tokens_max': 4096,
     'prompt_lookup_num_tokens': 0,
     'max_tokens_second': 0,
-    'max_updates_second': 12,
     'auto_max_new_tokens': True,
     'ban_eos_token': False,
     'add_bos_token': True,
```
```diff
@@ -65,41 +65,39 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
         all_stop_strings += st
 
     shared.stop_everything = False
-    last_update = -1
     reply = ''
     is_stream = state['stream']
     if len(all_stop_strings) > 0 and not state['stream']:
         state = copy.deepcopy(state)
         state['stream'] = True
 
-    min_update_interval = 0
-    if state.get('max_updates_second', 0) > 0:
-        min_update_interval = 1 / state['max_updates_second']
-
     # Generate
+    last_update = -1
+    latency_threshold = 1 / 1000
     for reply in generate_func(question, original_question, state, stopping_strings, is_chat=is_chat):
+        cur_time = time.monotonic()
         reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
         if escape_html:
             reply = html.escape(reply)
 
         if is_stream:
-            cur_time = time.time()
-
             # Limit number of tokens/second to make text readable in real time
             if state['max_tokens_second'] > 0:
                 diff = 1 / state['max_tokens_second'] - (cur_time - last_update)
                 if diff > 0:
                     time.sleep(diff)
 
-                last_update = time.time()
+                last_update = time.monotonic()
                 yield reply
 
-            # Limit updates to avoid lag in the Gradio UI
-            # API updates are not limited
             else:
-                if cur_time - last_update > min_update_interval:
-                    last_update = cur_time
+                # If 'generate_func' takes less than 0.001 seconds to yield the next token
+                # (equivalent to more than 1000 tok/s), assume that the UI is lagging behind and skip yielding
+                if (cur_time - last_update) > latency_threshold:
                     yield reply
+                    last_update = time.monotonic()
 
         if stop_found or (state['max_tokens_second'] > 0 and shared.stop_everything):
             break
```
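A self-contained sketch of the throttling idea above (the names below are illustrative, not the repo's API): because `reply` is cumulative, intermediate yields can be skipped whenever the generator runs faster than the ~1000 updates/second latency threshold, and the consumer still receives the complete text.

```python
import time


def throttle_stream(cumulative_replies, latency_threshold=1 / 1000):
    """Yield only when the previous yield was more than `latency_threshold` seconds ago."""
    last_update = -1
    reply = ''
    for reply in cumulative_replies:
        cur_time = time.monotonic()
        if (cur_time - last_update) > latency_threshold:
            yield reply
            last_update = time.monotonic()

    yield reply  # always deliver the final, complete state


updates = list(throttle_stream('x' * i for i in range(1, 5001)))
print(len(updates), len(updates[-1]))  # far fewer UI updates than tokens, full text preserved
```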
```diff
@@ -6,6 +6,7 @@ import yaml
 
 import extensions
 from modules import shared
+from modules.chat import load_history
 
 with open(Path(__file__).resolve().parent / '../css/NotoSans/stylesheet.css', 'r') as f:
     css = f.read()
@@ -194,7 +195,6 @@ def list_interface_input_elements():
         'max_new_tokens',
         'prompt_lookup_num_tokens',
         'max_tokens_second',
-        'max_updates_second',
         'do_sample',
         'dynamic_temperature',
         'temperature_last',
@@ -270,6 +270,10 @@ def gather_interface_values(*args):
     if not shared.args.multi_user:
         shared.persistent_interface_state = output
 
+    # Prevent history loss if backend is restarted but UI is not refreshed
+    if output['history'] is None and output['unique_id'] is not None:
+        output['history'] = load_history(output['unique_id'], output['character_menu'], output['mode'])
+
     return output
 
 
```
```diff
@@ -18,7 +18,7 @@ def create_ui():
     mu = shared.args.multi_user
 
     shared.gradio['Chat input'] = gr.State()
-    shared.gradio['history'] = gr.JSON(visible=False)
+    shared.gradio['history'] = gr.State({'internal': [], 'visible': [], 'metadata': {}})
 
     with gr.Tab('Chat', id='Chat', elem_id='chat-tab'):
         with gr.Row(elem_id='past-chats-row', elem_classes=['pretty_scrollbar']):
@@ -195,7 +195,7 @@ def create_event_handlers():
     shared.reload_inputs = gradio(reload_arr)
 
     # Morph HTML updates instead of updating everything
-    shared.gradio['display'].change(None, gradio('display'), None, js="(data) => handleMorphdomUpdate(data.html)")
+    shared.gradio['display'].change(None, gradio('display'), None, js="(data) => handleMorphdomUpdate(data)")
 
     shared.gradio['Generate'].click(
         ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
```
```diff
@@ -71,8 +71,6 @@ def create_ui(default_preset):
                 shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], value=shared.settings['max_new_tokens'], step=1, label='max_new_tokens', info='⚠️ Setting this too high can cause prompt truncation.')
                 shared.gradio['prompt_lookup_num_tokens'] = gr.Slider(value=shared.settings['prompt_lookup_num_tokens'], minimum=0, maximum=10, step=1, label='prompt_lookup_num_tokens', info='Activates Prompt Lookup Decoding.')
                 shared.gradio['max_tokens_second'] = gr.Slider(value=shared.settings['max_tokens_second'], minimum=0, maximum=20, step=1, label='Maximum tokens/second', info='To make text readable in real time.')
-                shared.gradio['max_updates_second'] = gr.Slider(value=shared.settings['max_updates_second'], minimum=0, maximum=24, step=1, label='Maximum UI updates/second', info='Set this if you experience lag in the UI during streaming.')
 
             with gr.Column():
                 with gr.Row():
                     with gr.Column():
```
one_click.py (253 changed lines)
```diff
@@ -70,12 +70,8 @@ def is_installed():
 def cpu_has_avx2():
     try:
         import cpuinfo
 
         info = cpuinfo.get_cpu_info()
-        if 'avx2' in info['flags']:
-            return True
-        else:
-            return False
+        return 'avx2' in info['flags']
     except:
         return True
 
```
```diff
@@ -83,30 +79,112 @@ def cpu_has_avx2():
 def cpu_has_amx():
     try:
         import cpuinfo
 
         info = cpuinfo.get_cpu_info()
-        if 'amx' in info['flags']:
-            return True
-        else:
-            return False
+        return 'amx' in info['flags']
     except:
         return True
 
 
-def torch_version():
-    site_packages_path = None
-    for sitedir in site.getsitepackages():
-        if "site-packages" in sitedir and conda_env_path in sitedir:
-            site_packages_path = sitedir
-            break
+def load_state():
+    """Load installer state from JSON file"""
+    if os.path.exists(state_file):
+        try:
+            with open(state_file, 'r') as f:
+                return json.load(f)
+        except:
+            return {}
+    return {}
 
-    if site_packages_path:
-        torch_version_file = open(os.path.join(site_packages_path, 'torch', 'version.py')).read().splitlines()
-        torver = [line for line in torch_version_file if line.startswith('__version__')][0].split('__version__ = ')[1].strip("'")
+
+def save_state(state):
+    """Save installer state to JSON file"""
+    with open(state_file, 'w') as f:
+        json.dump(state, f)
+
+
+def get_gpu_choice():
+    """Get GPU choice from state file or ask user"""
+    state = load_state()
+    gpu_choice = state.get('gpu_choice')
+
+    if not gpu_choice:
+        if "GPU_CHOICE" in os.environ:
+            choice = os.environ["GPU_CHOICE"].upper()
+            print_big_message(f"Selected GPU choice \"{choice}\" based on the GPU_CHOICE environment variable.")
+        else:
+            choice = get_user_choice(
+                "What is your GPU?",
+                {
+                    'A': 'NVIDIA - CUDA 12.4',
+                    'B': 'AMD - Linux/macOS only, requires ROCm 6.2.4',
+                    'C': 'Apple M Series',
+                    'D': 'Intel Arc (beta)',
+                    'N': 'CPU mode'
+                },
+            )
+
+        # Convert choice to GPU name
+        gpu_choice = {"A": "NVIDIA", "B": "AMD", "C": "APPLE", "D": "INTEL", "N": "NONE"}[choice]
+
+        # Save choice to state
+        state['gpu_choice'] = gpu_choice
+        save_state(state)
+
+    return gpu_choice
+
+
+def get_pytorch_install_command(gpu_choice):
+    """Get PyTorch installation command based on GPU choice"""
+    base_cmd = f"python -m pip install torch=={TORCH_VERSION} torchvision=={TORCHVISION_VERSION} torchaudio=={TORCHAUDIO_VERSION} "
+
+    if gpu_choice == "NVIDIA":
+        return base_cmd + "--index-url https://download.pytorch.org/whl/cu124"
+    elif gpu_choice == "AMD":
+        return base_cmd + "--index-url https://download.pytorch.org/whl/rocm6.2.4"
+    elif gpu_choice in ["APPLE", "NONE"]:
+        return base_cmd + "--index-url https://download.pytorch.org/whl/cpu"
+    elif gpu_choice == "INTEL":
+        if is_linux():
+            return "python -m pip install torch==2.1.0a0 torchvision==0.16.0a0 torchaudio==2.1.0a0 intel-extension-for-pytorch==2.1.10+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
+        else:
+            return "python -m pip install torch==2.1.0a0 torchvision==0.16.0a0 torchaudio==2.1.0a0 intel-extension-for-pytorch==2.1.10 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
     else:
-        from torch import __version__ as torver
+        return base_cmd
 
-    return torver
+
+def get_pytorch_update_command(gpu_choice):
+    """Get PyTorch update command based on GPU choice"""
+    base_cmd = f"python -m pip install --upgrade torch=={TORCH_VERSION} torchvision=={TORCHVISION_VERSION} torchaudio=={TORCHAUDIO_VERSION}"
+
+    if gpu_choice == "NVIDIA":
+        return f"{base_cmd} --index-url https://download.pytorch.org/whl/cu124"
+    elif gpu_choice == "AMD":
+        return f"{base_cmd} --index-url https://download.pytorch.org/whl/rocm6.2.4"
+    elif gpu_choice in ["APPLE", "NONE"]:
+        return f"{base_cmd} --index-url https://download.pytorch.org/whl/cpu"
+    elif gpu_choice == "INTEL":
+        intel_extension = "intel-extension-for-pytorch==2.1.10+xpu" if is_linux() else "intel-extension-for-pytorch==2.1.10"
+        return f"{base_cmd} {intel_extension} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
+    else:
+        return base_cmd
+
+
+def get_requirements_file(gpu_choice):
+    """Get requirements file path based on GPU choice"""
+    requirements_base = os.path.join("requirements", "full")
+
+    if gpu_choice == "AMD":
+        file_name = f"requirements_amd{'_noavx2' if not cpu_has_avx2() else ''}.txt"
+    elif gpu_choice == "APPLE":
+        file_name = f"requirements_apple_{'intel' if is_x86_64() else 'silicon'}.txt"
+    elif gpu_choice in ["INTEL", "NONE"]:
+        file_name = f"requirements_cpu_only{'_noavx2' if not cpu_has_avx2() else ''}.txt"
+    elif gpu_choice == "NVIDIA":
+        file_name = f"requirements{'_noavx2' if not cpu_has_avx2() else ''}.txt"
+    else:
+        raise ValueError(f"Unknown GPU choice: {gpu_choice}")
+
+    return os.path.join(requirements_base, file_name)
 
 
 def get_current_commit():
```
```diff
@@ -209,28 +287,8 @@ def get_user_choice(question, options_dict):
 def update_pytorch_and_python():
     print_big_message("Checking for PyTorch updates.")
 
-    # Update the Python version. Left here for future reference in case this becomes necessary.
-    # print_big_message("Checking for PyTorch and Python updates.")
-    # current_python_version = f"{sys.version_info.major}.{sys.version_info.minor}"
-    # if current_python_version != PYTHON_VERSION:
-    #     run_cmd(f"conda install -y python={PYTHON_VERSION}", assert_success=True, environment=True)
-
-    torver = torch_version()
-    base_cmd = f"python -m pip install --upgrade torch=={TORCH_VERSION} torchvision=={TORCHVISION_VERSION} torchaudio=={TORCHAUDIO_VERSION}"
-
-    if "+cu" in torver:
-        install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/cu124"
-    elif "+rocm" in torver:
-        install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/rocm6.2.4"
-    elif "+cpu" in torver:
-        install_cmd = f"{base_cmd} --index-url https://download.pytorch.org/whl/cpu"
-    elif "+cxx11" in torver:
-        intel_extension = "intel-extension-for-pytorch==2.1.10+xpu" if is_linux() else "intel-extension-for-pytorch==2.1.10"
-        install_cmd = f"{base_cmd} {intel_extension} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
-    else:
-        install_cmd = base_cmd
-
+    gpu_choice = get_gpu_choice()
+    install_cmd = get_pytorch_update_command(gpu_choice)
     run_cmd(install_cmd, assert_success=True, environment=True)
 
```
```diff
@@ -256,43 +314,11 @@ def install_webui():
     if os.path.isfile(state_file):
         os.remove(state_file)
 
-    # Ask the user for the GPU vendor
-    if "GPU_CHOICE" in os.environ:
-        choice = os.environ["GPU_CHOICE"].upper()
-        print_big_message(f"Selected GPU choice \"{choice}\" based on the GPU_CHOICE environment variable.")
-
-        # Warn about changed meanings and handle old choices
-        if choice == "B":
-            print_big_message("Warning: GPU_CHOICE='B' now means 'AMD' in the new version.")
-        elif choice == "C":
-            print_big_message("Warning: GPU_CHOICE='C' now means 'Apple M Series' in the new version.")
-        elif choice == "D":
-            print_big_message("Warning: GPU_CHOICE='D' now means 'Intel Arc' in the new version.")
-    else:
-        choice = get_user_choice(
-            "What is your GPU?",
-            {
-                'A': 'NVIDIA - CUDA 12.4',
-                'B': 'AMD - Linux/macOS only, requires ROCm 6.2.4',
-                'C': 'Apple M Series',
-                'D': 'Intel Arc (beta)',
-                'N': 'CPU mode'
-            },
-        )
-
-    # Convert choices to GPU names for compatibility
-    gpu_choice_to_name = {
-        "A": "NVIDIA",
-        "B": "AMD",
-        "C": "APPLE",
-        "D": "INTEL",
-        "N": "NONE"
-    }
-
-    selected_gpu = gpu_choice_to_name[choice]
+    # Get GPU choice and save it to state
+    gpu_choice = get_gpu_choice()
 
     # Write a flag to CMD_FLAGS.txt for CPU mode
-    if selected_gpu == "NONE":
+    if gpu_choice == "NONE":
         cmd_flags_path = os.path.join(script_dir, "user_data", "CMD_FLAGS.txt")
         with open(cmd_flags_path, 'r+') as cmd_flags_file:
             if "--cpu" not in cmd_flags_file.read():
```
```diff
@@ -300,34 +326,20 @@ def install_webui():
                 cmd_flags_file.write("\n--cpu\n")
 
     # Handle CUDA version display
-    elif any((is_windows(), is_linux())) and selected_gpu == "NVIDIA":
+    elif any((is_windows(), is_linux())) and gpu_choice == "NVIDIA":
         print("CUDA: 12.4")
 
     # No PyTorch for AMD on Windows (?)
-    elif is_windows() and selected_gpu == "AMD":
+    elif is_windows() and gpu_choice == "AMD":
         print("PyTorch setup on Windows is not implemented yet. Exiting...")
         sys.exit(1)
 
-    # Find the Pytorch installation command
-    install_pytorch = f"python -m pip install torch=={TORCH_VERSION} torchvision=={TORCHVISION_VERSION} torchaudio=={TORCHAUDIO_VERSION} "
-
-    if selected_gpu == "NVIDIA":
-        install_pytorch += "--index-url https://download.pytorch.org/whl/cu124"
-    elif selected_gpu == "AMD":
-        install_pytorch += "--index-url https://download.pytorch.org/whl/rocm6.2.4"
-    elif selected_gpu in ["APPLE", "NONE"]:
-        install_pytorch += "--index-url https://download.pytorch.org/whl/cpu"
-    elif selected_gpu == "INTEL":
-        if is_linux():
-            install_pytorch = "python -m pip install torch==2.1.0a0 torchvision==0.16.0a0 torchaudio==2.1.0a0 intel-extension-for-pytorch==2.1.10+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
-        else:
-            install_pytorch = "python -m pip install torch==2.1.0a0 torchvision==0.16.0a0 torchaudio==2.1.0a0 intel-extension-for-pytorch==2.1.10 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
-
     # Install Git and then Pytorch
     print_big_message("Installing PyTorch.")
+    install_pytorch = get_pytorch_install_command(gpu_choice)
     run_cmd(f"conda install -y ninja git && {install_pytorch} && python -m pip install py-cpuinfo==9.0.0", assert_success=True, environment=True)
 
-    if selected_gpu == "INTEL":
+    if gpu_choice == "INTEL":
         # Install oneAPI dependencies via conda
         print_big_message("Installing Intel oneAPI runtime libraries.")
         run_cmd("conda install -y -c https://software.repos.intel.com/python/conda/ -c conda-forge dpcpp-cpp-rt=2024.0 mkl-dpcpp=2024.0", environment=True)
```
```diff
@@ -349,31 +361,15 @@ def update_requirements(initial_installation=False, pull=True):
         assert_success=True
     )
 
-    torver = torch_version()
-    requirements_base = os.path.join("requirements", "full")
-
-    if "+rocm" in torver:
-        file_name = f"requirements_amd{'_noavx2' if not cpu_has_avx2() else ''}.txt"
-    elif "+cpu" in torver or "+cxx11" in torver:
-        file_name = f"requirements_cpu_only{'_noavx2' if not cpu_has_avx2() else ''}.txt"
-    elif is_macos():
-        file_name = f"requirements_apple_{'intel' if is_x86_64() else 'silicon'}.txt"
-    else:
-        file_name = f"requirements{'_noavx2' if not cpu_has_avx2() else ''}.txt"
-
-    requirements_file = os.path.join(requirements_base, file_name)
-
-    # Load state from JSON file
     current_commit = get_current_commit()
-    wheels_changed = False
-    if os.path.exists(state_file):
-        with open(state_file, 'r') as f:
-            last_state = json.load(f)
-
-        if 'wheels_changed' in last_state or last_state.get('last_installed_commit') != current_commit:
+    wheels_changed = not os.path.exists(state_file)
+    if not wheels_changed:
+        state = load_state()
+        if 'wheels_changed' in state or state.get('last_installed_commit') != current_commit:
             wheels_changed = True
-    else:
-        wheels_changed = True
 
+    gpu_choice = get_gpu_choice()
+    requirements_file = get_requirements_file(gpu_choice)
+
     if pull:
         # Read .whl lines before pulling
```
```diff
@@ -409,19 +405,17 @@ def update_requirements(initial_installation=False, pull=True):
             print_big_message(f"File '{file}' was updated during 'git pull'. Please run the script again.")
 
             # Save state before exiting
-            current_state = {}
+            state = load_state()
             if wheels_changed:
-                current_state['wheels_changed'] = True
-
-            with open(state_file, 'w') as f:
-                json.dump(current_state, f)
-
+                state['wheels_changed'] = True
+            save_state(state)
             sys.exit(1)
 
     # Save current state
-    current_state = {'last_installed_commit': current_commit}
-    with open(state_file, 'w') as f:
-        json.dump(current_state, f)
+    state = load_state()
+    state['last_installed_commit'] = current_commit
+    state.pop('wheels_changed', None)  # Remove wheels_changed flag
+    save_state(state)
 
     if os.environ.get("INSTALL_EXTENSIONS", "").lower() in ("yes", "y", "true", "1", "t", "on"):
         install_extensions_requirements()
```
```diff
@@ -432,11 +426,10 @@ def update_requirements(initial_installation=False, pull=True):
     # Update PyTorch
     if not initial_installation:
         update_pytorch_and_python()
-        torver = torch_version()
        clean_outdated_pytorch_cuda_dependencies()
 
     print_big_message(f"Installing webui requirements from file: {requirements_file}")
-    print(f"TORCH: {torver}\n")
+    print(f"GPU Choice: {gpu_choice}\n")
 
     # Prepare the requirements file
     textgen_requirements = open(requirements_file).read().splitlines()
```
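The new `load_state`/`save_state` helpers above boil down to a small JSON round trip. Here is a self-contained sketch of the same pattern (the file name is illustrative; the installer uses its own `state_file` path):

```python
import json
import os

state_file = "installer_state.json"  # illustrative path for this sketch


def load_state():
    """Return the saved state dict, or {} if the file is missing or unreadable."""
    if os.path.exists(state_file):
        try:
            with open(state_file, 'r') as f:
                return json.load(f)
        except Exception:
            return {}
    return {}


def save_state(state):
    with open(state_file, 'w') as f:
        json.dump(state, f)


# Round trip: remember a choice between installer runs.
state = load_state()
state['gpu_choice'] = state.get('gpu_choice', 'NVIDIA')
save_state(state)
print(load_state())
```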
```diff
@@ -18,7 +18,6 @@ max_new_tokens_min: 1
 max_new_tokens_max: 4096
 prompt_lookup_num_tokens: 0
 max_tokens_second: 0
-max_updates_second: 12
 auto_max_new_tokens: true
 ban_eos_token: false
 add_bos_token: true
```