diff --git a/download-model.py b/download-model.py index 25517491..576a8b79 100644 --- a/download-model.py +++ b/download-model.py @@ -32,6 +32,7 @@ class ModelDownloader: self.max_retries = max_retries self.session = self.get_session() self._progress_bar_slots = None + self.progress_queue = None def get_session(self): session = requests.Session() @@ -218,33 +219,45 @@ class ModelDownloader: max_retries = self.max_retries attempt = 0 + file_downloaded_count_for_progress = 0 + try: while attempt < max_retries: attempt += 1 session = self.session headers = {} mode = 'wb' + current_file_size_on_disk = 0 try: if output_path.exists() and not start_from_scratch: - # Resume download - r = session.get(url, stream=True, timeout=20) - total_size = int(r.headers.get('content-length', 0)) - if output_path.stat().st_size >= total_size: + current_file_size_on_disk = output_path.stat().st_size + r_head = session.head(url, timeout=20) + r_head.raise_for_status() + total_size = int(r_head.headers.get('content-length', 0)) + + if current_file_size_on_disk >= total_size and total_size > 0: + if self.progress_queue is not None and total_size > 0: + self.progress_queue.put((1.0, str(filename))) return - headers = {'Range': f'bytes={output_path.stat().st_size}-'} + headers = {'Range': f'bytes={current_file_size_on_disk}-'} mode = 'ab' with session.get(url, stream=True, headers=headers, timeout=30) as r: - r.raise_for_status() # If status is not 2xx, raise an error - total_size = int(r.headers.get('content-length', 0)) - block_size = 1024 * 1024 # 1MB + r.raise_for_status() + total_size_from_stream = int(r.headers.get('content-length', 0)) + if mode == 'ab': + effective_total_size = current_file_size_on_disk + total_size_from_stream + else: + effective_total_size = total_size_from_stream - filename_str = str(filename) # Convert PosixPath to string if necessary + block_size = 1024 * 1024 + filename_str = str(filename) tqdm_kwargs = { - 'total': total_size, + 'total': effective_total_size, + 'initial': current_file_size_on_disk if mode == 'ab' else 0, 'unit': 'B', 'unit_scale': True, 'unit_divisor': 1024, @@ -261,16 +274,20 @@ class ModelDownloader: }) with open(output_path, mode) as f: + if mode == 'ab': + f.seek(current_file_size_on_disk) + with tqdm.tqdm(**tqdm_kwargs) as t: - count = 0 + file_downloaded_count_for_progress = current_file_size_on_disk for data in r.iter_content(block_size): f.write(data) t.update(len(data)) - if total_size != 0 and self.progress_bar is not None: - count += len(data) - self.progress_bar(float(count) / float(total_size), f"{filename_str}") + if effective_total_size != 0 and self.progress_queue is not None: + file_downloaded_count_for_progress += len(data) + progress_fraction = float(file_downloaded_count_for_progress) / float(effective_total_size) + self.progress_queue.put((progress_fraction, filename_str)) + break - break # Exit loop if successful except (RequestException, ConnectionError, Timeout) as e: print(f"Error downloading {filename}: {e}.") print(f"That was attempt {attempt}/{max_retries}.", end=' ') @@ -295,10 +312,9 @@ class ModelDownloader: finally: print(f"\nDownload of {len(file_list)} files to {output_folder} completed.") - def download_model_files(self, model, branch, links, sha256, output_folder, progress_bar=None, start_from_scratch=False, threads=4, specific_file=None, is_llamacpp=False): - self.progress_bar = progress_bar + def download_model_files(self, model, branch, links, sha256, output_folder, progress_queue=None, start_from_scratch=False, threads=4, specific_file=None, is_llamacpp=False): + self.progress_queue = progress_queue - # Create the folder and writing the metadata output_folder.mkdir(parents=True, exist_ok=True) if not is_llamacpp: diff --git a/js/main.js b/js/main.js index 05c19571..8090937f 100644 --- a/js/main.js +++ b/js/main.js @@ -865,6 +865,46 @@ function navigateLastAssistantMessage(direction) { return false; } +//------------------------------------------------ +// Paste Handler for Long Text +//------------------------------------------------ + +const MAX_PLAIN_TEXT_LENGTH = 2500; + +function setupPasteHandler() { + const textbox = document.querySelector("#chat-input textarea[data-testid=\"textbox\"]"); + const fileInput = document.querySelector("#chat-input input[data-testid=\"file-upload\"]"); + + if (!textbox || !fileInput) { + setTimeout(setupPasteHandler, 500); + return; + } + + textbox.addEventListener("paste", async (event) => { + const text = event.clipboardData?.getData("text"); + + if (text && text.length > MAX_PLAIN_TEXT_LENGTH) { + event.preventDefault(); + + const file = new File([text], "pasted_text.txt", { + type: "text/plain", + lastModified: Date.now() + }); + + const dataTransfer = new DataTransfer(); + dataTransfer.items.add(file); + fileInput.files = dataTransfer.files; + fileInput.dispatchEvent(new Event("change", { bubbles: true })); + } + }); +} + +if (document.readyState === "loading") { + document.addEventListener("DOMContentLoaded", setupPasteHandler); +} else { + setupPasteHandler(); +} + //------------------------------------------------ // Tooltips //------------------------------------------------ diff --git a/modules/models.py b/modules/models.py index d329ae3c..c1e7fb56 100644 --- a/modules/models.py +++ b/modules/models.py @@ -116,7 +116,7 @@ def unload_model(keep_model_name=False): return is_llamacpp = (shared.model.__class__.__name__ == 'LlamaServer') - if shared.args.loader == 'ExLlamav3_HF': + if shared.model.__class__.__name__ == 'Exllamav3HF': shared.model.unload() shared.model = shared.tokenizer = None diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index 862b3893..2a7d3d9d 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -1,4 +1,6 @@ import importlib +import queue +import threading import traceback from functools import partial from pathlib import Path @@ -205,48 +207,51 @@ def load_lora_wrapper(selected_loras): def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), return_links=False, check=False): + downloader_module = importlib.import_module("download-model") + downloader = downloader_module.ModelDownloader() + update_queue = queue.Queue() + try: # Handle direct GGUF URLs if repo_id.startswith("https://") and ("huggingface.co" in repo_id) and (repo_id.endswith(".gguf") or repo_id.endswith(".gguf?download=true")): try: path = repo_id.split("huggingface.co/")[1] - - # Extract the repository ID (first two parts of the path) parts = path.split("/") if len(parts) >= 2: extracted_repo_id = f"{parts[0]}/{parts[1]}" - - # Extract the filename (last part of the path) - filename = repo_id.split("/")[-1] - if "?download=true" in filename: - filename = filename.replace("?download=true", "") - + filename = repo_id.split("/")[-1].replace("?download=true", "") repo_id = extracted_repo_id specific_file = filename - except: - pass + except Exception as e: + yield f"Error parsing GGUF URL: {e}" + progress(0.0) + return - if repo_id == "": - yield ("Please enter a model path") + if not repo_id: + yield "Please enter a model path." + progress(0.0) return repo_id = repo_id.strip() specific_file = specific_file.strip() - downloader = importlib.import_module("download-model").ModelDownloader() - progress(0.0) + progress(0.0, "Preparing download...") + model, branch = downloader.sanitize_model_and_branch_names(repo_id, None) - - yield ("Getting the download links from Hugging Face") + yield "Getting download links from Hugging Face..." links, sha256, is_lora, is_llamacpp = downloader.get_download_links_from_huggingface(model, branch, text_only=False, specific_file=specific_file) + if not links: + yield "No files found to download for the given model/criteria." + progress(0.0) + return + # Check for multiple GGUF files gguf_files = [link for link in links if link.lower().endswith('.gguf')] if len(gguf_files) > 1 and not specific_file: output = "Multiple GGUF files found. Please copy one of the following filenames to the 'File name' field:\n\n```\n" for link in gguf_files: output += f"{Path(link).name}\n" - output += "```" yield output return @@ -255,17 +260,13 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur output = "```\n" for link in links: output += f"{Path(link).name}" + "\n" - output += "```" yield output return - yield ("Getting the output folder") + yield "Determining output folder..." output_folder = downloader.get_output_folder( - model, - branch, - is_lora, - is_llamacpp=is_llamacpp, + model, branch, is_lora, is_llamacpp=is_llamacpp, model_dir=shared.args.model_dir if shared.args.model_dir != shared.args_defaults.model_dir else None ) @@ -275,19 +276,65 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur output_folder = Path(shared.args.lora_dir) if check: - progress(0.5) - - yield ("Checking previously downloaded files") + yield "Checking previously downloaded files..." + progress(0.5, "Verifying files...") downloader.check_model_files(model, branch, links, sha256, output_folder) - progress(1.0) - else: - yield (f"Downloading file{'s' if len(links) > 1 else ''} to `{output_folder}/`") - downloader.download_model_files(model, branch, links, sha256, output_folder, progress_bar=progress, threads=4, is_llamacpp=is_llamacpp) + progress(1.0, "Verification complete.") + yield "File check complete." + return - yield (f"Model successfully saved to `{output_folder}/`.") - except: - progress(1.0) - yield traceback.format_exc().replace('\n', '\n\n') + yield "" + progress(0.0, "Download starting...") + + def downloader_thread_target(): + try: + downloader.download_model_files( + model, branch, links, sha256, output_folder, + progress_queue=update_queue, + threads=4, + is_llamacpp=is_llamacpp, + specific_file=specific_file + ) + update_queue.put(("COMPLETED", f"Model successfully saved to `{output_folder}/`.")) + except Exception as e: + tb_str = traceback.format_exc().replace('\n', '\n\n') + update_queue.put(("ERROR", tb_str)) + + download_thread = threading.Thread(target=downloader_thread_target) + download_thread.start() + + while True: + try: + message = update_queue.get(timeout=0.2) + if not isinstance(message, tuple) or len(message) != 2: + continue + + msg_identifier, data = message + + if msg_identifier == "COMPLETED": + progress(1.0, "Download complete!") + yield data + break + elif msg_identifier == "ERROR": + progress(0.0, "Error occurred") + yield data + break + elif isinstance(msg_identifier, float): + progress_value = msg_identifier + description_str = data + progress(progress_value, f"Downloading: {description_str}") + + except queue.Empty: + if not download_thread.is_alive(): + yield "Download process finished." + break + + download_thread.join() + + except Exception as e: + progress(0.0) + tb_str = traceback.format_exc().replace('\n', '\n\n') + yield tb_str def update_truncation_length(current_length, state):