Compare commits

...

10 commits

Author SHA1 Message Date
oobabooga
63f88b0a9d
Merge 683e02f33d into ad6d0218ae 2025-06-02 02:44:51 +00:00
oobabooga
683e02f33d Add a comment 2025-06-01 19:44:43 -07:00
oobabooga
18c0eb7e38 Fix counting tokens with an empty history 2025-06-01 19:36:14 -07:00
oobabooga
c9510f36e9 Merge branch 'dev' into last-message-only 2025-06-01 19:30:21 -07:00
oobabooga
ad6d0218ae Fix after 219f0a7731 2025-06-01 19:27:14 -07:00
oobabooga
92adceb7b5 UI: Fix the model downloader progress bar 2025-06-01 19:22:21 -07:00
oobabooga
7a81beb0c1 Turn long pasted text into an attachment automatically 2025-06-01 18:26:14 -07:00
oobabooga
405004e622 Remove debug statement 2025-06-01 17:44:43 -07:00
oobabooga
e1cea34d4d Prevent history loss if backend is restarted but UI is not refreshed 2025-06-01 17:44:01 -07:00
oobabooga
cc0a8ba19d Add a 0.05s sleep to make sure the first update goes through 2025-06-01 17:34:09 -07:00
7 changed files with 164 additions and 54 deletions

View file

@ -32,6 +32,7 @@ class ModelDownloader:
self.max_retries = max_retries
self.session = self.get_session()
self._progress_bar_slots = None
self.progress_queue = None
def get_session(self):
session = requests.Session()
@ -218,33 +219,45 @@ class ModelDownloader:
max_retries = self.max_retries
attempt = 0
file_downloaded_count_for_progress = 0
try:
while attempt < max_retries:
attempt += 1
session = self.session
headers = {}
mode = 'wb'
current_file_size_on_disk = 0
try:
if output_path.exists() and not start_from_scratch:
# Resume download
r = session.get(url, stream=True, timeout=20)
total_size = int(r.headers.get('content-length', 0))
if output_path.stat().st_size >= total_size:
current_file_size_on_disk = output_path.stat().st_size
r_head = session.head(url, timeout=20)
r_head.raise_for_status()
total_size = int(r_head.headers.get('content-length', 0))
if current_file_size_on_disk >= total_size and total_size > 0:
if self.progress_queue is not None and total_size > 0:
self.progress_queue.put((1.0, str(filename)))
return
headers = {'Range': f'bytes={output_path.stat().st_size}-'}
headers = {'Range': f'bytes={current_file_size_on_disk}-'}
mode = 'ab'
with session.get(url, stream=True, headers=headers, timeout=30) as r:
r.raise_for_status() # If status is not 2xx, raise an error
total_size = int(r.headers.get('content-length', 0))
block_size = 1024 * 1024 # 1MB
r.raise_for_status()
total_size_from_stream = int(r.headers.get('content-length', 0))
if mode == 'ab':
effective_total_size = current_file_size_on_disk + total_size_from_stream
else:
effective_total_size = total_size_from_stream
filename_str = str(filename) # Convert PosixPath to string if necessary
block_size = 1024 * 1024
filename_str = str(filename)
tqdm_kwargs = {
'total': total_size,
'total': effective_total_size,
'initial': current_file_size_on_disk if mode == 'ab' else 0,
'unit': 'B',
'unit_scale': True,
'unit_divisor': 1024,
@ -261,16 +274,20 @@ class ModelDownloader:
})
with open(output_path, mode) as f:
if mode == 'ab':
f.seek(current_file_size_on_disk)
with tqdm.tqdm(**tqdm_kwargs) as t:
count = 0
file_downloaded_count_for_progress = current_file_size_on_disk
for data in r.iter_content(block_size):
f.write(data)
t.update(len(data))
if total_size != 0 and self.progress_bar is not None:
count += len(data)
self.progress_bar(float(count) / float(total_size), f"{filename_str}")
if effective_total_size != 0 and self.progress_queue is not None:
file_downloaded_count_for_progress += len(data)
progress_fraction = float(file_downloaded_count_for_progress) / float(effective_total_size)
self.progress_queue.put((progress_fraction, filename_str))
break
break # Exit loop if successful
except (RequestException, ConnectionError, Timeout) as e:
print(f"Error downloading {filename}: {e}.")
print(f"That was attempt {attempt}/{max_retries}.", end=' ')
@ -295,10 +312,9 @@ class ModelDownloader:
finally:
print(f"\nDownload of {len(file_list)} files to {output_folder} completed.")
def download_model_files(self, model, branch, links, sha256, output_folder, progress_bar=None, start_from_scratch=False, threads=4, specific_file=None, is_llamacpp=False):
self.progress_bar = progress_bar
def download_model_files(self, model, branch, links, sha256, output_folder, progress_queue=None, start_from_scratch=False, threads=4, specific_file=None, is_llamacpp=False):
self.progress_queue = progress_queue
# Create the folder and write the metadata
output_folder.mkdir(parents=True, exist_ok=True)
if not is_llamacpp:

View file

@ -865,6 +865,46 @@ function navigateLastAssistantMessage(direction) {
return false;
}
//------------------------------------------------
// Paste Handler for Long Text
//------------------------------------------------
// Pasted text longer than this many characters is converted into a
// "pasted_text.txt" attachment instead of being inserted into the textbox.
const MAX_PLAIN_TEXT_LENGTH = 2500;

// Hooks the chat textarea's "paste" event. Until both the textarea and the
// file-upload input exist in the DOM, the setup re-schedules itself every
// 500 ms and does nothing else.
function setupPasteHandler() {
  const chatTextarea = document.querySelector("#chat-input textarea[data-testid=\"textbox\"]");
  const uploadInput = document.querySelector("#chat-input input[data-testid=\"file-upload\"]");

  // One (or both) of the elements is not rendered yet: try again shortly.
  if (!(chatTextarea && uploadInput)) {
    setTimeout(setupPasteHandler, 500);
    return;
  }

  chatTextarea.addEventListener("paste", async (event) => {
    const pastedText = event.clipboardData?.getData("text");

    // Nothing pasted as text, or short enough: keep the default paste behavior.
    if (!pastedText || pastedText.length <= MAX_PLAIN_TEXT_LENGTH) {
      return;
    }

    // Stop the long text from landing in the textarea.
    event.preventDefault();

    const attachment = new File([pastedText], "pasted_text.txt", {
      type: "text/plain",
      lastModified: Date.now()
    });

    // A DataTransfer is used to build the FileList assigned to the input.
    const transfer = new DataTransfer();
    transfer.items.add(attachment);
    uploadInput.files = transfer.files;

    // Fire a bubbling "change" event so listeners on the input (or its
    // ancestors) see the new file.
    uploadInput.dispatchEvent(new Event("change", { bubbles: true }));
  });
}
// Install the paste handler right away unless the document is still loading,
// in which case defer it to DOMContentLoaded.
if (document.readyState !== "loading") {
  setupPasteHandler();
} else {
  document.addEventListener("DOMContentLoaded", setupPasteHandler);
}
//------------------------------------------------
// Tooltips
//------------------------------------------------

View file

@ -826,6 +826,8 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
save_interval = 8
for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True, for_ui=True)):
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'], last_message_only=(i > 0)), history
if i == 0:
time.sleep(0.05) # We need this to make sure the first update goes through
current_time = time.monotonic()
# Save on first iteration or if save_interval seconds have passed

View file

@ -116,7 +116,7 @@ def unload_model(keep_model_name=False):
return
is_llamacpp = (shared.model.__class__.__name__ == 'LlamaServer')
if shared.args.loader == 'ExLlamav3_HF':
if shared.model.__class__.__name__ == 'Exllamav3HF':
shared.model.unload()
shared.model = shared.tokenizer = None

View file

@ -6,6 +6,7 @@ import yaml
import extensions
from modules import shared
from modules.chat import load_history
with open(Path(__file__).resolve().parent / '../css/NotoSans/stylesheet.css', 'r') as f:
css = f.read()
@ -269,6 +270,10 @@ def gather_interface_values(*args):
if not shared.args.multi_user:
shared.persistent_interface_state = output
# Prevent history loss if backend is restarted but UI is not refreshed
if output['history'] is None and output['unique_id'] is not None:
output['history'] = load_history(output['unique_id'], output['character_menu'], output['mode'])
return output

View file

@ -18,7 +18,7 @@ def create_ui():
mu = shared.args.multi_user
shared.gradio['Chat input'] = gr.State()
shared.gradio['history'] = gr.State()
shared.gradio['history'] = gr.State({'internal': [], 'visible': [], 'metadata': {}})
with gr.Tab('Chat', id='Chat', elem_id='chat-tab'):
with gr.Row(elem_id='past-chats-row', elem_classes=['pretty_scrollbar']):

View file

@ -1,4 +1,6 @@
import importlib
import queue
import threading
import traceback
from functools import partial
from pathlib import Path
@ -205,48 +207,51 @@ def load_lora_wrapper(selected_loras):
def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), return_links=False, check=False):
downloader_module = importlib.import_module("download-model")
downloader = downloader_module.ModelDownloader()
update_queue = queue.Queue()
try:
# Handle direct GGUF URLs
if repo_id.startswith("https://") and ("huggingface.co" in repo_id) and (repo_id.endswith(".gguf") or repo_id.endswith(".gguf?download=true")):
try:
path = repo_id.split("huggingface.co/")[1]
# Extract the repository ID (first two parts of the path)
parts = path.split("/")
if len(parts) >= 2:
extracted_repo_id = f"{parts[0]}/{parts[1]}"
# Extract the filename (last part of the path)
filename = repo_id.split("/")[-1]
if "?download=true" in filename:
filename = filename.replace("?download=true", "")
filename = repo_id.split("/")[-1].replace("?download=true", "")
repo_id = extracted_repo_id
specific_file = filename
except:
pass
except Exception as e:
yield f"Error parsing GGUF URL: {e}"
progress(0.0)
return
if repo_id == "":
yield ("Please enter a model path")
if not repo_id:
yield "Please enter a model path."
progress(0.0)
return
repo_id = repo_id.strip()
specific_file = specific_file.strip()
downloader = importlib.import_module("download-model").ModelDownloader()
progress(0.0)
progress(0.0, "Preparing download...")
model, branch = downloader.sanitize_model_and_branch_names(repo_id, None)
yield ("Getting the download links from Hugging Face")
yield "Getting download links from Hugging Face..."
links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=False, specific_file=specific_file)
if not links:
yield "No files found to download for the given model/criteria."
progress(0.0)
return
# Check for multiple GGUF files
gguf_files = [link for link in links if link.lower().endswith('.gguf')]
if len(gguf_files) > 1 and not specific_file:
output = "Multiple GGUF files found. Please copy one of the following filenames to the 'File name' field:\n\n```\n"
for link in gguf_files:
output += f"{Path(link).name}\n"
output += "```"
yield output
return
@ -255,17 +260,13 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
output = "```\n"
for link in links:
output += f"{Path(link).name}" + "\n"
output += "```"
yield output
return
yield ("Getting the output folder")
yield "Determining output folder..."
output_folder = downloader.get_output_folder(
model,
branch,
is_lora,
is_llamacpp=is_llamacpp,
model, branch, is_lora, is_llamacpp=is_llamacpp,
model_dir=shared.args.model_dir if shared.args.model_dir != shared.args_defaults.model_dir else None
)
@ -275,19 +276,65 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
output_folder = Path(shared.args.lora_dir)
if check:
progress(0.5)
yield ("Checking previously downloaded files")
yield "Checking previously downloaded files..."
progress(0.5, "Verifying files...")
downloader.check_model_files(model, branch, links, sha256, output_folder)
progress(1.0)
else:
yield (f"Downloading file{'s' if len(links) > 1 else ''} to `{output_folder}/`")
downloader.download_model_files(model, branch, links, sha256, output_folder, progress_bar=progress, threads=4, is_llamacpp=is_llamacpp)
progress(1.0, "Verification complete.")
yield "File check complete."
return
yield (f"Model successfully saved to `{output_folder}/`.")
except:
progress(1.0)
yield traceback.format_exc().replace('\n', '\n\n')
yield ""
progress(0.0, "Download starting...")
def downloader_thread_target():
try:
downloader.download_model_files(
model, branch, links, sha256, output_folder,
progress_queue=update_queue,
threads=4,
is_llamacpp=is_llamacpp,
specific_file=specific_file
)
update_queue.put(("COMPLETED", f"Model successfully saved to `{output_folder}/`."))
except Exception as e:
tb_str = traceback.format_exc().replace('\n', '\n\n')
update_queue.put(("ERROR", tb_str))
download_thread = threading.Thread(target=downloader_thread_target)
download_thread.start()
while True:
try:
message = update_queue.get(timeout=0.2)
if not isinstance(message, tuple) or len(message) != 2:
continue
msg_identifier, data = message
if msg_identifier == "COMPLETED":
progress(1.0, "Download complete!")
yield data
break
elif msg_identifier == "ERROR":
progress(0.0, "Error occurred")
yield data
break
elif isinstance(msg_identifier, float):
progress_value = msg_identifier
description_str = data
progress(progress_value, f"Downloading: {description_str}")
except queue.Empty:
if not download_thread.is_alive():
yield "Download process finished."
break
download_thread.join()
except Exception as e:
progress(0.0)
tb_str = traceback.format_exc().replace('\n', '\n\n')
yield tb_str
def update_truncation_length(current_length, state):