diff --git a/css/main.css b/css/main.css
index a9cb36ab..026ea6c8 100644
--- a/css/main.css
+++ b/css/main.css
@@ -1557,6 +1557,19 @@ strong {
   margin-top: 4px;
 }
 
+.image-attachment {
+  flex-direction: column;
+}
+
+.image-preview {
+  border-radius: 16px;
+  margin-bottom: 5px;
+  object-fit: cover;
+  object-position: center;
+  border: 2px solid var(--border-color-primary);
+  aspect-ratio: 1 / 1;
+}
+
 button:focus {
   outline: none;
 }
diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py
index 5181b18b..4e4e310f 100644
--- a/extensions/openai/completions.py
+++ b/extensions/openai/completions.py
@@ -1,8 +1,10 @@
+import base64
 import copy
 import json
 import time
 from collections import deque
 
+import requests
 import tiktoken
 from pydantic import ValidationError
 
@@ -16,6 +18,7 @@ from modules.chat import (
     load_character_memoized,
     load_instruction_template_memoized
 )
+from modules.logging_colors import logger
 from modules.presets import load_preset_memoized
 from modules.text_generation import decode, encode, generate_reply
 
@@ -82,6 +85,50 @@ def process_parameters(body, is_legacy=False):
     return generate_params
 
 
+def process_image_url(url, image_id):
+    """Process an image URL and return attachment data for llama.cpp"""
+    try:
+        if url.startswith("data:"):
+            if "base64," in url:
+                image_data = url.split("base64,", 1)[1]
+            else:
+                raise ValueError("Unsupported data URL format")
+        else:
+            headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36'}
+            response = requests.get(url, timeout=10, headers=headers)
+            response.raise_for_status()
+            image_data = base64.b64encode(response.content).decode('utf-8')
+
+        return {"image_data": image_data, "image_id": image_id}
+    except Exception as e:
+        logger.error(f"Error processing image URL {url}: {e}")
+        return None
+
+
+def process_multimodal_content(content):
+    """Extract text and images from OpenAI multimodal format"""
+    if isinstance(content, str):
+        return content, []
+
+    if isinstance(content, list):
+        text_content = ""
+        images = []
+
+        for item in content:
+            if item.get("type") == "text":
+                text_content += item.get("text", "")
+            elif item.get("type") == "image_url":
+                image_url = item.get("image_url", {}).get("url", "")
+                if image_url:
+                    image = process_image_url(image_url, len(images) + 1)
+                    if image:
+                        images.append(image)
+
+        return text_content, images
+
+    return str(content), []
+
+
 def convert_history(history):
     '''
     Chat histories in this program are in the format [message, reply].
@@ -93,19 +140,29 @@ def convert_history(history):
     user_input = ""
     user_input_last = True
     system_message = ""
+    all_images = []  # Simple list to collect all images
 
     for entry in history:
         content = entry["content"]
         role = entry["role"]
 
         if role == "user":
-            user_input = content
+            # Process multimodal content
+            processed_content, images = process_multimodal_content(content)
+            if images:
+                image_refs = "".join("<__media__>" for img in images)
+                processed_content = f"{processed_content} {image_refs}"
+
+            user_input = processed_content
             user_input_last = True
+            all_images.extend(images)  # Add any images to our collection
+
             if current_message:
                 chat_dialogue.append([current_message, '', ''])
                 current_message = ""
 
-            current_message = content
+            current_message = processed_content
+
         elif role == "assistant":
             if "tool_calls" in entry and isinstance(entry["tool_calls"], list) and len(entry["tool_calls"]) > 0 and content.strip() == "":
                 continue  # skip tool calls
@@ -126,7 +183,11 @@ def convert_history(history):
     if not user_input_last:
         user_input = ""
 
-    return user_input, system_message, {'internal': chat_dialogue, 'visible': copy.deepcopy(chat_dialogue)}
+    return user_input, system_message, {
+        'internal': chat_dialogue,
+        'visible': copy.deepcopy(chat_dialogue),
+        'images': all_images  # Simple list of all images from the conversation
+    }
 
 
 def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False) -> dict:
@@ -150,9 +211,23 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
         elif m['role'] == 'function':
             raise InvalidRequestError(message="role: function is not supported.", param='messages')
 
-        if 'content' not in m and "image_url" not in m:
+        # Handle multimodal content validation
+        content = m.get('content')
+        if content is None:
             raise InvalidRequestError(message="messages: missing content", param='messages')
 
+        # Validate multimodal content structure
+        if isinstance(content, list):
+            for item in content:
+                if not isinstance(item, dict) or 'type' not in item:
+                    raise InvalidRequestError(message="messages: invalid content item format", param='messages')
+                if item['type'] not in ['text', 'image_url']:
+                    raise InvalidRequestError(message="messages: unsupported content type", param='messages')
+                if item['type'] == 'text' and 'text' not in item:
+                    raise InvalidRequestError(message="messages: missing text in content item", param='messages')
+                if item['type'] == 'image_url' and ('image_url' not in item or 'url' not in item['image_url']):
+                    raise InvalidRequestError(message="messages: missing image_url in content item", param='messages')
+
     # Chat Completions
     object_type = 'chat.completion' if not stream else 'chat.completion.chunk'
     created_time = int(time.time())
@@ -205,6 +280,10 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
         'stream': stream
     })
 
+    # Add images to state for llama.cpp multimodal support
+    if history.get('images'):
+        generate_params['image_attachments'] = history['images']
+
    max_tokens = generate_params['max_new_tokens']
    if max_tokens in [None, 0]:
        generate_params['max_new_tokens'] = 512
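For reference, a minimal client-side sketch of a request that the new validation and process_multimodal_content() path accepts. The endpoint, port, and file name below are assumptions (the OpenAI extension's defaults); the image can be supplied either as a data: URL, as here, or as a plain http(s) URL that process_image_url() downloads.

import base64

import requests

# Encode a local image as a data URL (an http(s) URL would also work)
with open("photo.png", "rb") as f:
    data_url = "data:image/png;base64," + base64.b64encode(f.read()).decode("utf-8")

payload = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {"type": "image_url", "image_url": {"url": data_url}},
            ],
        }
    ],
    "max_tokens": 512,
}

# Assumed default API address for the OpenAI-compatible extension
response = requests.post("http://127.0.0.1:5000/v1/chat/completions", json=payload, timeout=120)
print(response.json()["choices"][0]["message"]["content"])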
"file") - content = attachment.get("content", "") - attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n" + image_refs = "" - if attachments_text: - enhanced_user_msg = f"{user_msg}\n\nATTACHMENTS:\n{attachments_text}" + for attachment in metadata[user_key]["attachments"]: + if attachment.get("type") == "image": + # Add image reference for multimodal models + image_refs += "<__media__>" + else: + # Handle text/PDF attachments as before + filename = attachment.get("name", "file") + content = attachment.get("content", "") + attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n" + + if image_refs or attachments_text: + enhanced_user_msg = f"{user_msg} {image_refs}" + if attachments_text: + enhanced_user_msg += f"\n\nATTACHMENTS:\n{attachments_text}" messages.insert(insert_pos, {"role": "user", "content": enhanced_user_msg}) @@ -240,22 +249,29 @@ def generate_chat_prompt(user_input, state, **kwargs): has_attachments = user_key in metadata and "attachments" in metadata[user_key] if (user_input or has_attachments) and not impersonate and not _continue: - # For the current user input being processed, check if we need to add attachments - if not impersonate and not _continue and len(history_data.get('metadata', {})) > 0: - current_row_idx = len(history) - user_key = f"user_{current_row_idx}" + current_row_idx = len(history) + user_key = f"user_{current_row_idx}" - if user_key in metadata and "attachments" in metadata[user_key]: - attachments_text = "" - for attachment in metadata[user_key]["attachments"]: + enhanced_user_input = user_input + + if user_key in metadata and "attachments" in metadata[user_key]: + attachments_text = "" + image_refs = "" + + for attachment in metadata[user_key]["attachments"]: + if attachment.get("type") == "image": + image_refs += "<__media__>" + else: filename = attachment.get("name", "file") content = attachment.get("content", "") attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n" + if image_refs or attachments_text: + enhanced_user_input = f"{user_input} {image_refs}" if attachments_text: - user_input = f"{user_input}\n\nATTACHMENTS:\n{attachments_text}" + enhanced_user_input += f"\n\nATTACHMENTS:\n{attachments_text}" - messages.append({"role": "user", "content": user_input}) + messages.append({"role": "user", "content": enhanced_user_input}) def make_prompt(messages): if state['mode'] == 'chat-instruct' and _continue: @@ -495,29 +511,63 @@ def add_message_attachment(history, row_idx, file_path, is_user=True): file_extension = path.suffix.lower() try: - # Handle different file types - if file_extension == '.pdf': + # Handle image files + if file_extension in ['.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif']: + # Convert image to base64 + with open(path, 'rb') as f: + image_data = base64.b64encode(f.read()).decode('utf-8') + + # Determine MIME type from extension + mime_type_map = { + '.jpg': 'image/jpeg', + '.jpeg': 'image/jpeg', + '.png': 'image/png', + '.webp': 'image/webp', + '.bmp': 'image/bmp', + '.gif': 'image/gif' + } + mime_type = mime_type_map.get(file_extension, 'image/jpeg') + + # Format as data URL + data_url = f"data:{mime_type};base64,{image_data}" + + # Generate unique image ID + image_id = len([att for att in history['metadata'][key]["attachments"] if att.get("type") == "image"]) + 1 + + attachment = { + "name": filename, + "type": "image", + "image_data": data_url, + "image_id": image_id, + "file_path": str(path) # For UI preview + } + elif file_extension == 
'.pdf': # Process PDF file content = extract_pdf_text(path) - file_type = "application/pdf" + attachment = { + "name": filename, + "type": "application/pdf", + "content": content, + } elif file_extension == '.docx': content = extract_docx_text(path) - file_type = "application/docx" + attachment = { + "name": filename, + "type": "application/docx", + "content": content, + } else: # Default handling for text files with open(path, 'r', encoding='utf-8') as f: content = f.read() - file_type = "text/plain" - - # Add attachment - attachment = { - "name": filename, - "type": file_type, - "content": content, - } + attachment = { + "name": filename, + "type": "text/plain", + "content": content, + } history['metadata'][key]["attachments"].append(attachment) - return content # Return the content for reuse + return attachment # Return the attachment for reuse except Exception as e: logger.error(f"Error processing attachment {filename}: {e}") return None @@ -645,6 +695,19 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess for file_path in files: add_message_attachment(output, row_idx, file_path, is_user=True) + # Collect image attachments for llama.cpp + image_attachments = [] + if 'metadata' in output: + user_key = f"user_{row_idx}" + if user_key in output['metadata'] and "attachments" in output['metadata'][user_key]: + for attachment in output['metadata'][user_key]["attachments"]: + if attachment.get("type") == "image": + image_attachments.append(attachment) + + # Add image attachments to state for the generation + if image_attachments: + state['image_attachments'] = image_attachments + # Add web search results as attachments if enabled if state.get('enable_web_search', False): search_query = generate_search_query(text, state) diff --git a/modules/html_generator.py b/modules/html_generator.py index f90e3b04..770a3b1a 100644 --- a/modules/html_generator.py +++ b/modules/html_generator.py @@ -374,16 +374,27 @@ def format_message_attachments(history, role, index): for attachment in attachments: name = html.escape(attachment["name"]) - # Make clickable if URL exists - if "url" in attachment: - name = f'{name}' + if attachment.get("type") == "image": + # Show image preview + file_path = attachment.get("file_path", "") + attachments_html += ( + f'
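A sketch of the metadata record that add_message_attachment() now stores for an uploaded image (values abridged, path hypothetical). generate_chat_prompt() emits one <__media__> placeholder per image attachment in the user turn, and chatbot_wrapper() copies the same dicts into state['image_attachments'] for the llama.cpp backend.

# Shape of an image attachment after add_message_attachment(output, 0, "photo.png"):
history['metadata']['user_0']["attachments"] = [
    {
        "name": "photo.png",
        "type": "image",
        "image_data": "data:image/png;base64,iVBORw0KGgo...",  # full data URL (truncated here)
        "image_id": 1,
        "file_path": "/tmp/gradio/photo.png",  # hypothetical path, only used for the UI preview
    }
]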
diff --git a/modules/html_generator.py b/modules/html_generator.py
index f90e3b04..770a3b1a 100644
--- a/modules/html_generator.py
+++ b/modules/html_generator.py
@@ -374,16 +374,27 @@ def format_message_attachments(history, role, index):
         for attachment in attachments:
             name = html.escape(attachment["name"])
 
-            # Make clickable if URL exists
-            if "url" in attachment:
-                name = f'<a href="{html.escape(attachment["url"])}" target="_blank" rel="noopener noreferrer">{name}</a>'
+            if attachment.get("type") == "image":
+                # Show image preview
+                file_path = attachment.get("file_path", "")
+                attachments_html += (
+                    f'<div class="attachment-box image-attachment">'
+                    f'<img src="file/{file_path}" class="image-preview" alt="{name}">'
+                    f'<div class="attachment-name">{name}</div>'
+                    f'</div>'
+                )
+            else:
+                # Make clickable if URL exists (web search)
+                if "url" in attachment:
+                    name = f'<a href="{html.escape(attachment["url"])}" target="_blank" rel="noopener noreferrer">{name}</a>'
+
+                attachments_html += (
+                    f'<div class="attachment-box">'
+                    f'<div class="attachment-icon">{attachment_svg}</div>'
+                    f'<div class="attachment-name">{name}</div>'
+                    f'</div>'
+                )
-            attachments_html += (
-                f'<div class="attachment-box">'
-                f'<div class="attachment-icon">{attachment_svg}</div>'
-                f'<div class="attachment-name">{name}</div>'
-                f'</div>'
-            )
 
         attachments_html += '</div>'
         return attachments_html
diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py
index aa712541..ca1b2c47 100644
--- a/modules/llama_cpp_server.py
+++ b/modules/llama_cpp_server.py
@@ -121,6 +121,18 @@ class LlamaServer:
             to_ban = [[int(token_id), False] for token_id in state['custom_token_bans'].split(',')]
             payload["logit_bias"] = to_ban
 
+        # Add image data if present
+        if 'image_attachments' in state:
+            medias = []
+            for attachment in state['image_attachments']:
+                medias.append({
+                    "type": "image",
+                    "data": attachment['image_data']
+                })
+
+            if medias:
+                payload["medias"] = medias
+
         return payload
 
     def generate_with_streaming(self, prompt, state):
@@ -142,7 +154,7 @@ class LlamaServer:
 
         if shared.args.verbose:
             logger.info("GENERATE_PARAMS=")
-            printable_payload = {k: v for k, v in payload.items() if k != "prompt"}
+            printable_payload = {k: v for k, v in payload.items() if k not in ["prompt", "image_data"]}
             pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
             print()
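Rough shape of the llama.cpp server request body once images are attached. The medias field and the <__media__> prompt markers come from the change above; the remaining fields and values are illustrative assumptions, not an exact capture.

payload = {
    "prompt": "USER: What is in this image? <__media__>\nASSISTANT:",  # illustrative template output
    "n_predict": 512,
    "stream": True,
    "medias": [
        {
            "type": "image",
            # base64 image bytes (API route) or data URL (UI attachment), as stored in 'image_data'
            "data": "iVBORw0KGgoAAAANSUhEUgAA...",
        }
    ],
}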