mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-06-07 14:17:09 -04:00
Merge 1f3b1a1b94
into 28e6bd4fcd
This commit is contained in:
commit
ed6ecf1df4
5 changed files with 202 additions and 41 deletions
13
css/main.css
13
css/main.css
|
@ -1551,3 +1551,16 @@ strong {
|
|||
color: var(--body-text-color-subdued);
|
||||
margin-top: 4px;
|
||||
}
|
||||
|
||||
.image-attachment {
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.attachment-image {
|
||||
border-radius: 16px;
|
||||
margin-bottom: 5px;
|
||||
object-fit: cover;
|
||||
object-position: center;
|
||||
border: 2px solid var(--border-color-primary);
|
||||
aspect-ratio: 1 / 1;
|
||||
}
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import base64
|
||||
import copy
|
||||
import json
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
import requests
|
||||
import tiktoken
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
@ -16,6 +18,7 @@ from modules.chat import (
|
|||
load_character_memoized,
|
||||
load_instruction_template_memoized
|
||||
)
|
||||
from modules.logging_colors import logger
|
||||
from modules.presets import load_preset_memoized
|
||||
from modules.text_generation import decode, encode, generate_reply
|
||||
|
||||
|
@ -82,6 +85,50 @@ def process_parameters(body, is_legacy=False):
|
|||
return generate_params
|
||||
|
||||
|
||||
def process_image_url(url, image_id):
    """Process an image URL and return attachment data for llama.cpp.

    Args:
        url: Either a ``data:`` URL that carries a base64 payload, or any
            other URL, which is downloaded with ``requests``.
        image_id: Identifier stored alongside the image data.

    Returns:
        ``{"image_data": <base64 string>, "image_id": image_id}`` on success,
        or ``None`` when the URL cannot be processed (the failure is logged).
    """
    try:
        if url.startswith("data:"):
            _, marker, payload = url.partition("base64,")
            if not marker:
                raise ValueError("Unsupported data URL format")

            # The base64 alphabet contains no whitespace, so line wraps or
            # trailing newlines in the payload can be dropped safely.
            image_data = "".join(payload.split())
            if not image_data:
                raise ValueError("Empty base64 payload in data URL")
        else:
            # NOTE(review): this fetches a caller-supplied URL with no
            # response size cap; consider restricting hosts/sizes if this is
            # reachable by untrusted clients.
            headers = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36'}
            response = requests.get(url, timeout=10, headers=headers)
            response.raise_for_status()
            image_data = base64.b64encode(response.content).decode('utf-8')

        return {"image_data": image_data, "image_id": image_id}
    except Exception as e:
        # Deliberately best-effort: a bad image must not abort the request.
        # Truncate the URL so a failing data: URL does not dump its whole
        # base64 payload into the log.
        shown_url = url if isinstance(url, str) else repr(url)
        if len(shown_url) > 200:
            shown_url = shown_url[:200] + "..."

        logger.error(f"Error processing image URL {shown_url}: {e}")
        return None
|
||||
|
||||
|
||||
def process_multimodal_content(content):
    """Extract text and images from OpenAI multimodal format.

    Args:
        content: A plain string, or a list of content parts shaped like
            ``{"type": "text", "text": ...}`` or
            ``{"type": "image_url", "image_url": {"url": ...}}``. Any other
            value is converted with ``str()``.

    Returns:
        A ``(text, images)`` tuple. ``text`` is the concatenation of all text
        parts in order; ``images`` is the list of attachment dicts returned
        by ``process_image_url`` for the images that were processed
        successfully.
    """
    if isinstance(content, str):
        return content, []

    if not isinstance(content, list):
        return str(content), []

    text_parts = []
    images = []

    for item in content:
        # Malformed parts are skipped instead of crashing the whole request.
        if not isinstance(item, dict):
            continue

        item_type = item.get("type")
        if item_type == "text":
            text = item.get("text")
            if text is not None:
                text_parts.append(str(text))
        elif item_type == "image_url":
            image_ref = item.get("image_url")
            if isinstance(image_ref, dict):
                image_url = image_ref.get("url")
            elif isinstance(image_ref, str):
                # Tolerate the shorthand where the URL is given directly.
                image_url = image_ref
            else:
                image_url = None

            if image_url:
                # IDs are 1-based and restart for every call.
                # NOTE(review): callers that merge images from several
                # messages may end up with repeated IDs - confirm.
                image = process_image_url(image_url, len(images) + 1)
                if image:
                    images.append(image)

    return "".join(text_parts), images
|
||||
|
||||
|
||||
def convert_history(history):
|
||||
'''
|
||||
Chat histories in this program are in the format [message, reply].
|
||||
|
@ -93,19 +140,29 @@ def convert_history(history):
|
|||
user_input = ""
|
||||
user_input_last = True
|
||||
system_message = ""
|
||||
all_images = [] # Simple list to collect all images
|
||||
|
||||
for entry in history:
|
||||
content = entry["content"]
|
||||
role = entry["role"]
|
||||
|
||||
if role == "user":
|
||||
user_input = content
|
||||
# Process multimodal content
|
||||
processed_content, images = process_multimodal_content(content)
|
||||
if images:
|
||||
image_refs = "".join(f"[img-{img['image_id']}]" for img in images)
|
||||
processed_content = f"{processed_content} {image_refs}"
|
||||
|
||||
user_input = processed_content
|
||||
user_input_last = True
|
||||
all_images.extend(images) # Add any images to our collection
|
||||
|
||||
if current_message:
|
||||
chat_dialogue.append([current_message, '', ''])
|
||||
current_message = ""
|
||||
|
||||
current_message = content
|
||||
current_message = processed_content
|
||||
|
||||
elif role == "assistant":
|
||||
if "tool_calls" in entry and isinstance(entry["tool_calls"], list) and len(entry["tool_calls"]) > 0 and content.strip() == "":
|
||||
continue # skip tool calls
|
||||
|
@ -126,7 +183,11 @@ def convert_history(history):
|
|||
if not user_input_last:
|
||||
user_input = ""
|
||||
|
||||
return user_input, system_message, {'internal': chat_dialogue, 'visible': copy.deepcopy(chat_dialogue)}
|
||||
return user_input, system_message, {
|
||||
'internal': chat_dialogue,
|
||||
'visible': copy.deepcopy(chat_dialogue),
|
||||
'images': all_images # Simple list of all images from the conversation
|
||||
}
|
||||
|
||||
|
||||
def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False) -> dict:
|
||||
|
@ -150,9 +211,23 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
|||
elif m['role'] == 'function':
|
||||
raise InvalidRequestError(message="role: function is not supported.", param='messages')
|
||||
|
||||
if 'content' not in m and "image_url" not in m:
|
||||
# Handle multimodal content validation
|
||||
content = m.get('content')
|
||||
if content is None:
|
||||
raise InvalidRequestError(message="messages: missing content", param='messages')
|
||||
|
||||
# Validate multimodal content structure
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if not isinstance(item, dict) or 'type' not in item:
|
||||
raise InvalidRequestError(message="messages: invalid content item format", param='messages')
|
||||
if item['type'] not in ['text', 'image_url']:
|
||||
raise InvalidRequestError(message="messages: unsupported content type", param='messages')
|
||||
if item['type'] == 'text' and 'text' not in item:
|
||||
raise InvalidRequestError(message="messages: missing text in content item", param='messages')
|
||||
if item['type'] == 'image_url' and ('image_url' not in item or 'url' not in item['image_url']):
|
||||
raise InvalidRequestError(message="messages: missing image_url in content item", param='messages')
|
||||
|
||||
# Chat Completions
|
||||
object_type = 'chat.completion' if not stream else 'chat.completion.chunk'
|
||||
created_time = int(time.time())
|
||||
|
@ -205,6 +280,10 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
|
|||
'stream': stream
|
||||
})
|
||||
|
||||
# Add images to state for llama.cpp multimodal support
|
||||
if history.get('images'):
|
||||
generate_params['image_attachments'] = history['images']
|
||||
|
||||
max_tokens = generate_params['max_new_tokens']
|
||||
if max_tokens in [None, 0]:
|
||||
generate_params['max_new_tokens'] = 512
|
||||
|
|
|
@ -220,13 +220,22 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
# Add attachment content if present
|
||||
if user_key in metadata and "attachments" in metadata[user_key]:
|
||||
attachments_text = ""
|
||||
image_refs = ""
|
||||
|
||||
for attachment in metadata[user_key]["attachments"]:
|
||||
if attachment.get("type") == "image":
|
||||
# Add image reference for multimodal models
|
||||
image_refs += f"[img-{attachment['image_id']}]"
|
||||
else:
|
||||
# Handle text/PDF attachments as before
|
||||
filename = attachment.get("name", "file")
|
||||
content = attachment.get("content", "")
|
||||
attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
|
||||
|
||||
if image_refs or attachments_text:
|
||||
enhanced_user_msg = f"{user_msg} {image_refs}"
|
||||
if attachments_text:
|
||||
enhanced_user_msg = f"{user_msg}\n\nATTACHMENTS:\n{attachments_text}"
|
||||
enhanced_user_msg += f"\n\nATTACHMENTS:\n{attachments_text}"
|
||||
|
||||
messages.insert(insert_pos, {"role": "user", "content": enhanced_user_msg})
|
||||
|
||||
|
@ -240,22 +249,29 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
|||
has_attachments = user_key in metadata and "attachments" in metadata[user_key]
|
||||
|
||||
if (user_input or has_attachments) and not impersonate and not _continue:
|
||||
# For the current user input being processed, check if we need to add attachments
|
||||
if not impersonate and not _continue and len(history_data.get('metadata', {})) > 0:
|
||||
current_row_idx = len(history)
|
||||
user_key = f"user_{current_row_idx}"
|
||||
|
||||
enhanced_user_input = user_input
|
||||
|
||||
if user_key in metadata and "attachments" in metadata[user_key]:
|
||||
attachments_text = ""
|
||||
image_refs = ""
|
||||
|
||||
for attachment in metadata[user_key]["attachments"]:
|
||||
if attachment.get("type") == "image":
|
||||
image_refs += f"[img-{attachment['image_id']}]"
|
||||
else:
|
||||
filename = attachment.get("name", "file")
|
||||
content = attachment.get("content", "")
|
||||
attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
|
||||
|
||||
if image_refs or attachments_text:
|
||||
enhanced_user_input = f"{user_input} {image_refs}"
|
||||
if attachments_text:
|
||||
user_input = f"{user_input}\n\nATTACHMENTS:\n{attachments_text}"
|
||||
enhanced_user_input += f"\n\nATTACHMENTS:\n{attachments_text}"
|
||||
|
||||
messages.append({"role": "user", "content": user_input})
|
||||
messages.append({"role": "user", "content": enhanced_user_input})
|
||||
|
||||
def make_prompt(messages):
|
||||
if state['mode'] == 'chat-instruct' and _continue:
|
||||
|
@ -495,26 +511,43 @@ def add_message_attachment(history, row_idx, file_path, is_user=True):
|
|||
file_extension = path.suffix.lower()
|
||||
|
||||
try:
|
||||
# Handle different file types
|
||||
if file_extension == '.pdf':
|
||||
# Handle image files
|
||||
if file_extension in ['.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif']:
|
||||
# Convert image to base64
|
||||
with open(path, 'rb') as f:
|
||||
image_data = base64.b64encode(f.read()).decode('utf-8')
|
||||
|
||||
# Generate unique image ID
|
||||
image_id = len([att for att in history['metadata'][key]["attachments"] if att.get("type") == "image"]) + 1
|
||||
|
||||
attachment = {
|
||||
"name": filename,
|
||||
"type": "image",
|
||||
"image_data": image_data,
|
||||
"image_id": image_id,
|
||||
"file_path": str(path) # For UI preview
|
||||
}
|
||||
|
||||
elif file_extension == '.pdf':
|
||||
# Process PDF file
|
||||
content = extract_pdf_text(path)
|
||||
file_type = "application/pdf"
|
||||
attachment = {
|
||||
"name": filename,
|
||||
"type": "application/pdf",
|
||||
"content": content,
|
||||
}
|
||||
else:
|
||||
# Default handling for text files
|
||||
with open(path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
file_type = "text/plain"
|
||||
|
||||
# Add attachment
|
||||
attachment = {
|
||||
"name": filename,
|
||||
"type": file_type,
|
||||
"type": "text/plain",
|
||||
"content": content,
|
||||
}
|
||||
|
||||
history['metadata'][key]["attachments"].append(attachment)
|
||||
return content # Return the content for reuse
|
||||
return attachment # Return the attachment for reuse
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing attachment {filename}: {e}")
|
||||
return None
|
||||
|
@ -590,6 +623,19 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
|||
for file_path in files:
|
||||
add_message_attachment(output, row_idx, file_path, is_user=True)
|
||||
|
||||
# Collect image attachments for llama.cpp
|
||||
image_attachments = []
|
||||
if 'metadata' in output:
|
||||
user_key = f"user_{row_idx}"
|
||||
if user_key in output['metadata'] and "attachments" in output['metadata'][user_key]:
|
||||
for attachment in output['metadata'][user_key]["attachments"]:
|
||||
if attachment.get("type") == "image":
|
||||
image_attachments.append(attachment)
|
||||
|
||||
# Add image attachments to state for the generation
|
||||
if image_attachments:
|
||||
state['image_attachments'] = image_attachments
|
||||
|
||||
# Add web search results as attachments if enabled
|
||||
if state.get('enable_web_search', False):
|
||||
search_query = generate_search_query(text, state)
|
||||
|
|
|
@ -372,7 +372,17 @@ def format_message_attachments(history, role, index):
|
|||
for attachment in attachments:
|
||||
name = html.escape(attachment["name"])
|
||||
|
||||
# Make clickable if URL exists
|
||||
if attachment.get("type") == "image":
|
||||
# Show image preview
|
||||
file_path = attachment.get("file_path", "")
|
||||
attachments_html += (
|
||||
f'<div class="attachment-box image-attachment">'
|
||||
f'<img src="file/{file_path}" alt="{name}" class="attachment-image" />'
|
||||
f'<div class="attachment-name">{name}</div>'
|
||||
f'</div>'
|
||||
)
|
||||
else:
|
||||
# Make clickable if URL exists (web search)
|
||||
if "url" in attachment:
|
||||
name = f'<a href="{html.escape(attachment["url"])}" target="_blank" rel="noopener noreferrer">{name}</a>'
|
||||
|
||||
|
@ -382,6 +392,7 @@ def format_message_attachments(history, role, index):
|
|||
f'<div class="attachment-name">{name}</div>'
|
||||
f'</div>'
|
||||
)
|
||||
|
||||
attachments_html += '</div>'
|
||||
return attachments_html
|
||||
|
||||
|
|
|
@ -121,6 +121,18 @@ class LlamaServer:
|
|||
to_ban = [[int(token_id), False] for token_id in state['custom_token_bans'].split(',')]
|
||||
payload["logit_bias"] = to_ban
|
||||
|
||||
# Add image data if present
|
||||
if 'image_attachments' in state:
|
||||
image_data = []
|
||||
for attachment in state['image_attachments']:
|
||||
image_data.append({
|
||||
"data": attachment['image_data'],
|
||||
"id": attachment['image_id']
|
||||
})
|
||||
|
||||
if image_data:
|
||||
payload["image_data"] = image_data
|
||||
|
||||
return payload
|
||||
|
||||
def generate_with_streaming(self, prompt, state):
|
||||
|
@ -142,7 +154,7 @@ class LlamaServer:
|
|||
|
||||
if shared.args.verbose:
|
||||
logger.info("GENERATE_PARAMS=")
|
||||
printable_payload = {k: v for k, v in payload.items() if k != "prompt"}
|
||||
printable_payload = {k: v for k, v in payload.items() if k not in ["prompt", "image_data"]}
|
||||
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(printable_payload)
|
||||
print()
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue