mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-06-07 14:17:09 -04:00
Integrate with the API
This commit is contained in:
parent
f92e1f44a0
commit
2e21b1f5e3
1 changed file with 125 additions and 4 deletions
@@ -1,8 +1,11 @@
+import base64
 import copy
 import json
 import time
 from collections import deque
+from datetime import datetime
 
+import requests
 import tiktoken
 from pydantic import ValidationError
 
@@ -16,6 +19,7 @@ from modules.chat import (
     load_character_memoized,
     load_instruction_template_memoized
 )
+from modules.logging_colors import logger
 from modules.presets import load_preset_memoized
 from modules.text_generation import decode, encode, generate_reply
 
@@ -82,6 +86,67 @@ def process_parameters(body, is_legacy=False):
     return generate_params
 
 
+def get_current_timestamp():
+    """Returns the current time in 24-hour format"""
+    return datetime.now().strftime('%b %d, %Y %H:%M')
+
+
+def process_image_url(url, image_id):
+    """Process an image URL and return attachment data"""
+    try:
+        if url.startswith("data:"):
+            # Handle data URL (data:image/jpeg;base64,...)
+            if "base64," in url:
+                image_data = url.split("base64,", 1)[1]
+            else:
+                raise ValueError("Unsupported data URL format")
+        else:
+            # Handle regular URL - download image
+            response = requests.get(url, timeout=10)
+            response.raise_for_status()
+            image_data = base64.b64encode(response.content).decode('utf-8')
+
+        return {
+            "name": f"image_{image_id}",
+            "type": "image",
+            "image_data": image_data,
+            "image_id": image_id,
+            "file_path": f"api_image_{image_id}",  # Add this for consistency with UI
+        }
+    except Exception as e:
+        logger.error(f"Error processing image URL {url}: {e}")
+        return None
+
+
+def process_multimodal_content(content):
+    """Process multimodal content and return text content and attachments"""
+    if isinstance(content, str):
+        return content, []
+
+    if isinstance(content, list):
+        text_content = ""
+        image_refs = ""
+        attachments = []
+
+        for item in content:
+            if item.get("type") == "text":
+                text_content += item.get("text", "")
+            elif item.get("type") == "image_url":
+                image_url = item.get("image_url", {}).get("url", "")
+                if image_url:
+                    attachment = process_image_url(image_url, len(attachments) + 1)
+                    if attachment:
+                        attachments.append(attachment)
+                        image_refs += f"[img-{attachment['image_id']}]"
+                    else:
+                        # Log warning but continue processing
+                        logger.warning(f"Failed to process image URL: {image_url}")
+
+        return f"{image_refs}{text_content}", attachments
+
+    return str(content), []
+
+
 def convert_history(history):
     '''
     Chat histories in this program are in the format [message, reply].
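Taken together, the helpers added in this hunk turn an OpenAI-style content list into the prompt text and the attachment records the rest of the pipeline expects. A minimal round-trip sketch (the base64 payload is an illustrative placeholder, not a decodable image):

    # Hypothetical usage of the helpers above.
    content = [
        {"type": "text", "text": "What is in this picture?"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}},
    ]

    text, attachments = process_multimodal_content(content)
    # text        -> "[img-1]What is in this picture?"
    # attachments -> [{"name": "image_1", "type": "image", "image_data": "iVBORw0KGgo=",
    #                  "image_id": 1, "file_path": "api_image_1"}]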
@@ -93,26 +158,46 @@ def convert_history(history):
     user_input = ""
     user_input_last = True
     system_message = ""
+    metadata = {}
 
+    # Keep track of attachments for the current message being built
+    pending_attachments = []
+
     for entry in history:
         content = entry["content"]
         role = entry["role"]
 
         if role == "user":
-            user_input = content
+            # Process multimodal content
+            processed_content, attachments = process_multimodal_content(content)
+            user_input = processed_content
             user_input_last = True
 
             if current_message:
                 chat_dialogue.append([current_message, '', ''])
                 current_message = ""
-            current_message = content
+
+            current_message = processed_content
+            pending_attachments = attachments  # Store attachments for when message is added
+
         elif role == "assistant":
             if "tool_calls" in entry and isinstance(entry["tool_calls"], list) and len(entry["tool_calls"]) > 0 and content.strip() == "":
                 continue  # skip tool calls
             current_reply = content
             user_input_last = False
             if current_message:
+                row_idx = len(chat_dialogue)  # Calculate index here, right before adding
                 chat_dialogue.append([current_message, current_reply, ''])
+
+                # Add attachments to metadata if any
+                if pending_attachments:
+                    user_key = f"user_{row_idx}"
+                    metadata[user_key] = {
+                        "timestamp": get_current_timestamp(),
+                        "attachments": pending_attachments
+                    }
+                    pending_attachments = []  # Clear pending attachments
+
                 current_message = ""
                 current_reply = ""
             else:
@@ -123,10 +208,19 @@ def convert_history(history):
         elif role == "system":
             system_message += f"\n{content}" if system_message else content
 
+    # Handle case where there's a pending user message at the end
+    if current_message and pending_attachments:
+        row_idx = len(chat_dialogue)  # This will be the index when the message is processed
+        user_key = f"user_{row_idx}"
+        metadata[user_key] = {
+            "timestamp": get_current_timestamp(),
+            "attachments": pending_attachments
+        }
+
     if not user_input_last:
         user_input = ""
 
-    return user_input, system_message, {'internal': chat_dialogue, 'visible': copy.deepcopy(chat_dialogue)}
+    return user_input, system_message, {'internal': chat_dialogue, 'visible': copy.deepcopy(chat_dialogue), 'metadata': metadata}
 
 
 def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False) -> dict:
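The metadata keys follow the UI convention of addressing a message by its row index in the internal history, so a user turn destined for row N is stored under user_N. A sketch of what convert_history now returns for a single multimodal message, reusing the content list from the example above (the timestamp is illustrative, in the get_current_timestamp() format):

    messages = [{"role": "user", "content": content}]
    user_input, system_message, history = convert_history(messages)

    # user_input          -> "[img-1]What is in this picture?"
    # history['internal'] -> []  (a trailing user message is returned as user_input)
    # history['metadata'] -> {"user_0": {"timestamp": "Jun 07, 2025 14:17",
    #                                    "attachments": [...]}}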
@@ -150,9 +244,23 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
         elif m['role'] == 'function':
             raise InvalidRequestError(message="role: function is not supported.", param='messages')
 
-        if 'content' not in m and "image_url" not in m:
+        # Handle multimodal content validation
+        content = m.get('content')
+        if content is None:
             raise InvalidRequestError(message="messages: missing content", param='messages')
 
+        # Validate multimodal content structure
+        if isinstance(content, list):
+            for item in content:
+                if not isinstance(item, dict) or 'type' not in item:
+                    raise InvalidRequestError(message="messages: invalid content item format", param='messages')
+                if item['type'] not in ['text', 'image_url']:
+                    raise InvalidRequestError(message="messages: unsupported content type", param='messages')
+                if item['type'] == 'text' and 'text' not in item:
+                    raise InvalidRequestError(message="messages: missing text in content item", param='messages')
+                if item['type'] == 'image_url' and ('image_url' not in item or 'url' not in item['image_url']):
+                    raise InvalidRequestError(message="messages: missing image_url in content item", param='messages')
+
     # Chat Completions
     object_type = 'chat.completion' if not stream else 'chat.completion.chunk'
     created_time = int(time.time())
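These checks accept the same message shape as the OpenAI vision API, so the endpoint can be exercised with a plain HTTP request. A sketch, assuming the API extension's default address of http://127.0.0.1:5000 (adjust host, port, and flags to your setup; the image URL is illustrative):

    import requests

    payload = {
        "mode": "instruct",
        "messages": [
            {"role": "user", "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ]},
        ],
    }

    response = requests.post("http://127.0.0.1:5000/v1/chat/completions", json=payload, timeout=60)
    print(response.json()["choices"][0]["message"]["content"])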
@@ -189,6 +297,15 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
     # History
     user_input, custom_system_message, history = convert_history(messages)
 
+    # Collect image attachments for multimodal support
+    image_attachments = []
+    if 'metadata' in history:
+        for key, value in history['metadata'].items():
+            if 'attachments' in value:
+                for attachment in value['attachments']:
+                    if attachment.get('type') == 'image':
+                        image_attachments.append(attachment)
+
     generate_params.update({
         'mode': body['mode'],
         'name1': name1,
@@ -205,6 +322,10 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
         'stream': stream
     })
 
+    # Add image attachments to state for llama.cpp multimodal support
+    if image_attachments:
+        generate_params['image_attachments'] = image_attachments
+
     max_tokens = generate_params['max_new_tokens']
     if max_tokens in [None, 0]:
         generate_params['max_new_tokens'] = 512
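Since the collected attachments travel inside generate_params, a multimodal-capable llama.cpp backend can consume them while the client sees nothing beyond the OpenAI protocol. The official openai Python client should therefore work unchanged; a sketch under the same default-endpoint assumption, with placeholder model and key values (the local server answers with whatever model is loaded and, unless an API key was configured, does not validate the key):

    from openai import OpenAI

    client = OpenAI(base_url="http://127.0.0.1:5000/v1", api_key="unused")

    completion = client.chat.completions.create(
        model="unused",  # placeholder; the loaded model answers
        messages=[
            {"role": "user", "content": [
                {"type": "text", "text": "What is in this picture?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ]},
        ],
    )
    print(completion.choices[0].message.content)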