mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-06-08 14:46:14 -04:00
Add multimodal support (llama.cpp)
This commit is contained in:
parent
6c3590ba9a
commit
f92e1f44a0
4 changed files with 117 additions and 36 deletions
13
css/main.css
13
css/main.css
|
@ -1550,3 +1550,16 @@ strong {
|
||||||
color: var(--body-text-color-subdued);
|
color: var(--body-text-color-subdued);
|
||||||
margin-top: 4px;
|
margin-top: 4px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.image-attachment {
|
||||||
|
flex-direction: column;
|
||||||
|
}
|
||||||
|
|
||||||
|
.attachment-image {
|
||||||
|
border-radius: 16px;
|
||||||
|
margin-bottom: 5px;
|
||||||
|
object-fit: cover;
|
||||||
|
object-position: center;
|
||||||
|
border: 2px solid var(--border-color-primary);
|
||||||
|
aspect-ratio: 1 / 1;
|
||||||
|
}
|
||||||
|
|
|
@ -220,13 +220,22 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||||
# Add attachment content if present
|
# Add attachment content if present
|
||||||
if user_key in metadata and "attachments" in metadata[user_key]:
|
if user_key in metadata and "attachments" in metadata[user_key]:
|
||||||
attachments_text = ""
|
attachments_text = ""
|
||||||
|
image_refs = ""
|
||||||
|
|
||||||
for attachment in metadata[user_key]["attachments"]:
|
for attachment in metadata[user_key]["attachments"]:
|
||||||
|
if attachment.get("type") == "image":
|
||||||
|
# Add image reference for multimodal models
|
||||||
|
image_refs += f"[img-{attachment['image_id']}]"
|
||||||
|
else:
|
||||||
|
# Handle text/PDF attachments as before
|
||||||
filename = attachment.get("name", "file")
|
filename = attachment.get("name", "file")
|
||||||
content = attachment.get("content", "")
|
content = attachment.get("content", "")
|
||||||
attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
|
attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
|
||||||
|
|
||||||
|
if image_refs or attachments_text:
|
||||||
|
enhanced_user_msg = f"{image_refs}{user_msg}"
|
||||||
if attachments_text:
|
if attachments_text:
|
||||||
enhanced_user_msg = f"{user_msg}\n\nATTACHMENTS:\n{attachments_text}"
|
enhanced_user_msg += f"\n\nATTACHMENTS:\n{attachments_text}"
|
||||||
|
|
||||||
messages.insert(insert_pos, {"role": "user", "content": enhanced_user_msg})
|
messages.insert(insert_pos, {"role": "user", "content": enhanced_user_msg})
|
||||||
|
|
||||||
|
@ -240,22 +249,29 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||||
has_attachments = user_key in metadata and "attachments" in metadata[user_key]
|
has_attachments = user_key in metadata and "attachments" in metadata[user_key]
|
||||||
|
|
||||||
if (user_input or has_attachments) and not impersonate and not _continue:
|
if (user_input or has_attachments) and not impersonate and not _continue:
|
||||||
# For the current user input being processed, check if we need to add attachments
|
|
||||||
if not impersonate and not _continue and len(history_data.get('metadata', {})) > 0:
|
|
||||||
current_row_idx = len(history)
|
current_row_idx = len(history)
|
||||||
user_key = f"user_{current_row_idx}"
|
user_key = f"user_{current_row_idx}"
|
||||||
|
|
||||||
|
enhanced_user_input = user_input
|
||||||
|
|
||||||
if user_key in metadata and "attachments" in metadata[user_key]:
|
if user_key in metadata and "attachments" in metadata[user_key]:
|
||||||
attachments_text = ""
|
attachments_text = ""
|
||||||
|
image_refs = ""
|
||||||
|
|
||||||
for attachment in metadata[user_key]["attachments"]:
|
for attachment in metadata[user_key]["attachments"]:
|
||||||
|
if attachment.get("type") == "image":
|
||||||
|
image_refs += f"[img-{attachment['image_id']}]"
|
||||||
|
else:
|
||||||
filename = attachment.get("name", "file")
|
filename = attachment.get("name", "file")
|
||||||
content = attachment.get("content", "")
|
content = attachment.get("content", "")
|
||||||
attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
|
attachments_text += f"\nName: {filename}\nContents:\n\n=====\n{content}\n=====\n\n"
|
||||||
|
|
||||||
|
if image_refs or attachments_text:
|
||||||
|
enhanced_user_input = f"{image_refs}{user_input}"
|
||||||
if attachments_text:
|
if attachments_text:
|
||||||
user_input = f"{user_input}\n\nATTACHMENTS:\n{attachments_text}"
|
enhanced_user_input += f"\n\nATTACHMENTS:\n{attachments_text}"
|
||||||
|
|
||||||
messages.append({"role": "user", "content": user_input})
|
messages.append({"role": "user", "content": enhanced_user_input})
|
||||||
|
|
||||||
def make_prompt(messages):
|
def make_prompt(messages):
|
||||||
if state['mode'] == 'chat-instruct' and _continue:
|
if state['mode'] == 'chat-instruct' and _continue:
|
||||||
|
@ -493,26 +509,43 @@ def add_message_attachment(history, row_idx, file_path, is_user=True):
|
||||||
file_extension = path.suffix.lower()
|
file_extension = path.suffix.lower()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Handle different file types
|
# Handle image files
|
||||||
if file_extension == '.pdf':
|
if file_extension in ['.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif']:
|
||||||
|
# Convert image to base64
|
||||||
|
with open(path, 'rb') as f:
|
||||||
|
image_data = base64.b64encode(f.read()).decode('utf-8')
|
||||||
|
|
||||||
|
# Generate unique image ID
|
||||||
|
image_id = len([att for att in history['metadata'][key]["attachments"] if att.get("type") == "image"]) + 1
|
||||||
|
|
||||||
|
attachment = {
|
||||||
|
"name": filename,
|
||||||
|
"type": "image",
|
||||||
|
"image_data": image_data,
|
||||||
|
"image_id": image_id,
|
||||||
|
"file_path": str(path) # For UI preview
|
||||||
|
}
|
||||||
|
|
||||||
|
elif file_extension == '.pdf':
|
||||||
# Process PDF file
|
# Process PDF file
|
||||||
content = extract_pdf_text(path)
|
content = extract_pdf_text(path)
|
||||||
file_type = "application/pdf"
|
attachment = {
|
||||||
|
"name": filename,
|
||||||
|
"type": "application/pdf",
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
# Default handling for text files
|
# Default handling for text files
|
||||||
with open(path, 'r', encoding='utf-8') as f:
|
with open(path, 'r', encoding='utf-8') as f:
|
||||||
content = f.read()
|
content = f.read()
|
||||||
file_type = "text/plain"
|
|
||||||
|
|
||||||
# Add attachment
|
|
||||||
attachment = {
|
attachment = {
|
||||||
"name": filename,
|
"name": filename,
|
||||||
"type": file_type,
|
"type": "text/plain",
|
||||||
"content": content,
|
"content": content,
|
||||||
}
|
}
|
||||||
|
|
||||||
history['metadata'][key]["attachments"].append(attachment)
|
history['metadata'][key]["attachments"].append(attachment)
|
||||||
return content # Return the content for reuse
|
return attachment # Return the attachment for reuse
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error processing attachment {filename}: {e}")
|
logger.error(f"Error processing attachment {filename}: {e}")
|
||||||
return None
|
return None
|
||||||
|
@ -567,6 +600,19 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
|
||||||
for file_path in files:
|
for file_path in files:
|
||||||
add_message_attachment(output, row_idx, file_path, is_user=True)
|
add_message_attachment(output, row_idx, file_path, is_user=True)
|
||||||
|
|
||||||
|
# Collect image attachments for llama.cpp
|
||||||
|
image_attachments = []
|
||||||
|
if 'metadata' in output:
|
||||||
|
user_key = f"user_{row_idx}"
|
||||||
|
if user_key in output['metadata'] and "attachments" in output['metadata'][user_key]:
|
||||||
|
for attachment in output['metadata'][user_key]["attachments"]:
|
||||||
|
if attachment.get("type") == "image":
|
||||||
|
image_attachments.append(attachment)
|
||||||
|
|
||||||
|
# Add image attachments to state for the generation
|
||||||
|
if image_attachments:
|
||||||
|
state['image_attachments'] = image_attachments
|
||||||
|
|
||||||
# Add web search results as attachments if enabled
|
# Add web search results as attachments if enabled
|
||||||
add_web_search_attachments(output, row_idx, text, state)
|
add_web_search_attachments(output, row_idx, text, state)
|
||||||
|
|
||||||
|
|
|
@ -372,7 +372,17 @@ def format_message_attachments(history, role, index):
|
||||||
for attachment in attachments:
|
for attachment in attachments:
|
||||||
name = html.escape(attachment["name"])
|
name = html.escape(attachment["name"])
|
||||||
|
|
||||||
# Make clickable if URL exists
|
if attachment.get("type") == "image":
|
||||||
|
# Show image preview
|
||||||
|
file_path = attachment.get("file_path", "")
|
||||||
|
attachments_html += (
|
||||||
|
f'<div class="attachment-box image-attachment">'
|
||||||
|
f'<img src="file/{file_path}" alt="{name}" class="attachment-image" />'
|
||||||
|
f'<div class="attachment-name">{name}</div>'
|
||||||
|
f'</div>'
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Make clickable if URL exists (web search)
|
||||||
if "url" in attachment:
|
if "url" in attachment:
|
||||||
name = f'<a href="{html.escape(attachment["url"])}" target="_blank" rel="noopener noreferrer">{name}</a>'
|
name = f'<a href="{html.escape(attachment["url"])}" target="_blank" rel="noopener noreferrer">{name}</a>'
|
||||||
|
|
||||||
|
@ -382,6 +392,7 @@ def format_message_attachments(history, role, index):
|
||||||
f'<div class="attachment-name">{name}</div>'
|
f'<div class="attachment-name">{name}</div>'
|
||||||
f'</div>'
|
f'</div>'
|
||||||
)
|
)
|
||||||
|
|
||||||
attachments_html += '</div>'
|
attachments_html += '</div>'
|
||||||
return attachments_html
|
return attachments_html
|
||||||
|
|
||||||
|
|
|
@ -140,6 +140,17 @@ class LlamaServer:
|
||||||
"cache_prompt": True
|
"cache_prompt": True
|
||||||
})
|
})
|
||||||
|
|
||||||
|
# Add image data if present
|
||||||
|
if 'image_attachments' in state:
|
||||||
|
image_data = []
|
||||||
|
for attachment in state['image_attachments']:
|
||||||
|
image_data.append({
|
||||||
|
"data": attachment['image_data'],
|
||||||
|
"id": attachment['image_id']
|
||||||
|
})
|
||||||
|
if image_data:
|
||||||
|
payload["image_data"] = image_data
|
||||||
|
|
||||||
if shared.args.verbose:
|
if shared.args.verbose:
|
||||||
logger.info("GENERATE_PARAMS=")
|
logger.info("GENERATE_PARAMS=")
|
||||||
printable_payload = {k: v for k, v in payload.items() if k != "prompt"}
|
printable_payload = {k: v for k, v in payload.items() if k != "prompt"}
|
||||||
|
|
Loading…
Add table
Reference in a new issue