Tools support for OpenAI compatible API (#6827)

commit fa960496d5 (parent ed6e16191d)
Author: Jonas
Date: 2025-05-08 17:30:27 +02:00 (committed by GitHub)
4 changed files with 209 additions and 17 deletions
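
In short: the OpenAI-compatible /v1/chat/completions endpoint now accepts a tools array, passes the signatures through to the instruction template, parses tool calls out of the generated answer, and reports them with finish_reason "tool_calls". A minimal client-side sketch of the kind of request this enables (host and port are the webui defaults; get_weather is a hypothetical tool):

import requests

resp = requests.post(
    "http://127.0.0.1:5000/v1/chat/completions",  # assumed default host/port
    json={
        "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
        "tools": [{
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }],
    },
)
choice = resp.json()["choices"][0]
print(choice["finish_reason"])  # "tool_calls" when a call was parsed out of the answer
print(choice["tool_calls"])     # parsed calls with generated id/index, arguments as a JSON string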

extensions/openai/completions.py

@@ -1,11 +1,14 @@
 import copy
 import time
+import json
 from collections import deque

 import tiktoken

 from extensions.openai.errors import InvalidRequestError
-from extensions.openai.utils import debug_msg
+from extensions.openai.utils import debug_msg, getToolCallId, parseToolCall
+from extensions.openai.typing import ToolDefinition
+from pydantic import ValidationError
 from modules import shared
 from modules.chat import (
     generate_chat_prompt,
@@ -99,19 +102,24 @@ def convert_history(history):
             user_input = content
             user_input_last = True
             if current_message:
-                chat_dialogue.append([current_message, ''])
+                chat_dialogue.append([current_message, '', ''])
                 current_message = ""

             current_message = content
         elif role == "assistant":
+            if "tool_calls" in entry and isinstance(entry["tool_calls"], list) and len(entry["tool_calls"]) > 0 and content.strip() == "":
+                continue  # skip tool calls
             current_reply = content
             user_input_last = False
             if current_message:
-                chat_dialogue.append([current_message, current_reply])
+                chat_dialogue.append([current_message, current_reply, ''])
                 current_message = ""
                 current_reply = ""
             else:
-                chat_dialogue.append(['', current_reply])
+                chat_dialogue.append(['', current_reply, ''])
+        elif role == "tool":
+            user_input_last = False
+            chat_dialogue.append(['', '', content])

         elif role == "system":
             system_message += f"\n{content}" if system_message else content
@@ -131,6 +139,10 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
     if 'messages' not in body:
         raise InvalidRequestError(message="messages is required", param='messages')

+    tools = None
+    if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and len(body['tools']) > 0:
+        tools = validateTools(body['tools'])  # raises InvalidRequestError if validation fails
+
     messages = body['messages']
     for m in messages:
         if 'role' not in m:
@@ -188,6 +200,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
         'custom_system_message': custom_system_message,
         'chat_template_str': chat_template_str,
         'chat-instruct_command': chat_instruct_command,
+        'tools': tools,
         'history': history,
         'stream': stream
     })
@@ -200,7 +213,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
     requested_model = generate_params.pop('model')
     logprob_proc = generate_params.pop('logprob_proc', None)

-    def chat_streaming_chunk(content):
+    def chat_streaming_chunk(content, chunk_tool_calls=None):
         # begin streaming
         chunk = {
             "id": cmpl_id,
@@ -210,7 +223,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
             resp_list: [{
                 "index": 0,
                 "finish_reason": None,
-                "delta": {'role': 'assistant', 'content': content},
+                "delta": {'role': 'assistant', 'content': content, 'tool_calls': chunk_tool_calls},
             }],
         }
@@ -219,6 +232,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
             chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
         # else:
         #     chunk[resp_list][0]["logprobs"] = None
+
         return chunk

     # generate reply #######################################
@@ -227,8 +241,6 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
         yield {'prompt': prompt}
         return

-    debug_msg({'prompt': prompt, 'generate_params': generate_params})
-
     if stream:
         yield chat_streaming_chunk('')
@@ -238,8 +250,23 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
     answer = ''
     seen_content = ''
+
+    tool_calls = []
+    end_last_tool_call = 0
+    supported_tools = [x["function"]["name"] for x in tools] if tools is not None else None
+
     for a in generator:
         answer = a['internal'][-1][1]

+        if supported_tools is not None:
+            tool_call = parseToolCall(answer[end_last_tool_call:], supported_tools) if len(answer) > 0 else []
+            if len(tool_call) > 0:
+                for tc in tool_call:
+                    tc["id"] = getToolCallId()
+                    tc["index"] = str(len(tool_calls))
+                    tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
+                    tool_calls.append(tc)
+                end_last_tool_call = len(answer)
+
         if stream:
             len_seen = len(seen_content)
             new_content = answer[len_seen:]
@@ -247,18 +274,25 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
             if not new_content or chr(0xfffd) in new_content:  # partial unicode character, don't send it yet.
                 continue

-            seen_content = answer
             chunk = chat_streaming_chunk(new_content)
+            seen_content = answer
             yield chunk

+        # stop generation if tool_calls were generated previously
+        if len(tool_calls) > 0:
+            break
+
     token_count = len(encode(prompt)[0])
     completion_token_count = len(encode(answer)[0])
     stop_reason = "stop"
+    if len(tool_calls) > 0:
+        stop_reason = "tool_calls"
+
     if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
         stop_reason = "length"

     if stream:
-        chunk = chat_streaming_chunk('')
+        chunk = chat_streaming_chunk('', tool_calls)
         chunk[resp_list][0]['finish_reason'] = stop_reason
         chunk['usage'] = {
             "prompt_tokens": token_count,
@@ -276,7 +310,8 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
         resp_list: [{
             "index": 0,
             "finish_reason": stop_reason,
-            "message": {"role": "assistant", "content": answer}
+            "message": {"role": "assistant", "content": answer},
+            "tool_calls": tool_calls
         }],
         "usage": {
             "prompt_tokens": token_count,
@@ -465,3 +500,19 @@ def completions(body: dict, is_legacy: bool = False) -> dict:
 def stream_completions(body: dict, is_legacy: bool = False):
     for resp in completions_common(body, is_legacy, stream=True):
         yield resp
+
+
+def validateTools(tools: list[dict]):
+    # Validate each tool definition in the JSON array
+    valid_tools = None
+    for idx in range(len(tools)):
+        tool = tools[idx]
+        try:
+            tool_definition = ToolDefinition(**tool)
+            if valid_tools is None:
+                valid_tools = []
+            valid_tools.append(tool)
+        except ValidationError:
+            raise InvalidRequestError(message=f"Invalid tool specification at index {idx}.", param='tools')
+
+    return valid_tools
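
A rough sketch of how validateTools behaves, assuming the post-commit imports; the tool dicts below are illustrative only:

from extensions.openai.completions import validateTools
from extensions.openai.errors import InvalidRequestError

good = {"type": "function", "function": {
    "name": "get_weather",
    "description": "Get the current weather for a city",
    "parameters": {"type": "object"},
}}
bad = {"type": "function", "function": {"name": "broken"}}  # missing description/parameters

assert validateTools([good]) == [good]
try:
    validateTools([good, bad])  # ToolDefinition(**bad) raises ValidationError
except InvalidRequestError as err:
    print(err.message)  # "Invalid tool specification at index 1." (assuming the error keeps its message attribute)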

extensions/openai/typing.py

@@ -1,8 +1,8 @@
 import json
 import time
-from typing import Dict, List
+from typing import Dict, List, Optional

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, validator


 class GenerationOptions(BaseModel):
@@ -54,6 +54,48 @@ class GenerationOptions(BaseModel):
     grammar_string: str = ""


+class ToolDefinition(BaseModel):
+    function: 'ToolFunction'
+    type: str
+
+
+class ToolFunction(BaseModel):
+    description: str
+    name: str
+    parameters: 'ToolParameters'
+
+
+class ToolParameters(BaseModel):
+    properties: Optional[Dict[str, 'ToolProperty']] = None
+    required: Optional[list[str]] = None
+    type: str
+    description: Optional[str] = None
+
+
+class ToolProperty(BaseModel):
+    description: Optional[str] = None
+    type: Optional[str] = None  # we are faced with definitions like anyOf, e.g. {'type': 'function', 'function': {'name': 'git_create_branch', 'description': 'Creates a new branch from an optional base branch', 'parameters': {'type': 'object', 'properties': {'repo_path': {'title': 'Repo Path', 'type': 'string'}, 'branch_name': {'title': 'Branch Name', 'type': 'string'}, 'base_branch': {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'default': None, 'title': 'Base Branch'}}, 'required': ['repo_path', 'branch_name'], 'title': 'GitCreateBranch'}}}
+
+
+class FunctionCall(BaseModel):
+    name: str
+    arguments: Optional[str] = None
+    parameters: Optional[str] = None
+
+    @validator('arguments', allow_reuse=True)
+    def checkPropertyArgsOrParams(cls, v, values, **kwargs):
+        if not v and not values.get('parameters'):
+            raise ValueError("At least one of 'arguments' or 'parameters' must be provided as property in FunctionCall type")
+        return v
+
+
+class ToolCall(BaseModel):
+    id: str
+    index: int
+    type: str
+    function: FunctionCall
+
+
 class CompletionRequestParams(BaseModel):
     model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
     prompt: str | List[str]
@@ -92,6 +134,7 @@ class ChatCompletionRequestParams(BaseModel):
     frequency_penalty: float | None = 0
     function_call: str | dict | None = Field(default=None, description="Unused parameter.")
     functions: List[dict] | None = Field(default=None, description="Unused parameter.")
+    tools: List[dict] | None = Field(default=None, description="Tools signatures passed via MCP.")
     logit_bias: dict | None = None
     max_tokens: int | None = None
     n: int | None = Field(default=1, description="Unused parameter.")
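
The Optional fields on ToolProperty exist because of schemas like the git_create_branch example quoted in the comment above; a small sketch (pydantic ignores the extra anyOf/default/title keys by default):

from extensions.openai.typing import ToolParameters

params = ToolParameters(**{
    "type": "object",
    "properties": {
        "repo_path": {"title": "Repo Path", "type": "string"},
        "base_branch": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": None, "title": "Base Branch"},
    },
    "required": ["repo_path"],
})
print(params.properties["base_branch"].type)  # None: no top-level 'type', only anyOf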

extensions/openai/utils.py

@@ -1,6 +1,9 @@
 import base64
 import os
 import time
+import json
+import random
+import re
 import traceback

 from typing import Callable, Optional
@@ -52,3 +55,94 @@ def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_star
         time.sleep(3)

     raise Exception('Could not start cloudflared.')
+
+
+def getToolCallId() -> str:
+    letter_bytes = "abcdefghijklmnopqrstuvwxyz0123456789"
+    b = [random.choice(letter_bytes) for _ in range(8)]
+    return "call_" + "".join(b).lower()
+
+
+def checkAndSanitizeToolCallCandidate(candidate_dict: dict, tool_names: list[str]):
+    # check if property 'function' exists and is a dictionary, otherwise adapt dict
+    if 'function' not in candidate_dict and 'name' in candidate_dict and isinstance(candidate_dict['name'], str):
+        candidate_dict = {"type": "function", "function": candidate_dict}
+    if 'function' in candidate_dict and isinstance(candidate_dict['function'], str):
+        candidate_dict['name'] = candidate_dict['function']
+        del candidate_dict['function']
+        candidate_dict = {"type": "function", "function": candidate_dict}
+    if 'function' in candidate_dict and isinstance(candidate_dict['function'], dict):
+        # check if 'name' exists within 'function' and is part of known tools
+        if 'name' in candidate_dict['function'] and candidate_dict['function']['name'] in tool_names:
+            candidate_dict["type"] = "function"  # ensure required property 'type' exists and has the right value
+            # map property 'parameters' used by some older models to 'arguments'
+            if "arguments" not in candidate_dict["function"] and "parameters" in candidate_dict["function"]:
+                candidate_dict["function"]["arguments"] = candidate_dict["function"]["parameters"]
+                del candidate_dict["function"]["parameters"]
+            return candidate_dict
+    return None
+
+
+def parseToolCall(answer: str, tool_names: list[str]):
+    matches = []
+
+    # abort on very short answers to save computation cycles
+    if len(answer) < 10:
+        return matches
+
+    # Define the regex patterns to find the JSON content wrapped in <function>, <tools>, <tool_call>, and other tags observed from various models
+    patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]
+
+    for pattern in patterns:
+        for match in re.finditer(pattern, answer, re.DOTALL):
+            # print(match.group(2))
+            if match.group(2) is None:
+                continue
+
+            # remove backtick wraps if present
+            candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip())
+            candidate = re.sub(r"```$", "", candidate.strip())
+            # unwrap inner tags
+            candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL)
+
+            # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
+            if re.search(r"\}\s*\n\s*\{", candidate) is not None:
+                candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
+            if not candidate.strip().startswith("["):
+                candidate = "[" + candidate + "]"
+
+            candidates = []
+            try:
+                # parse the candidate JSON into a dictionary
+                candidates = json.loads(candidate)
+                if not isinstance(candidates, list):
+                    candidates = [candidates]
+            except json.JSONDecodeError:
+                # Ignore invalid JSON silently
+                continue
+
+            for candidate_dict in candidates:
+                checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names)
+                if checked_candidate is not None:
+                    matches.append(checked_candidate)
+
+    # last resort if nothing has been mapped: LLM might have produced plain json tool call without xml-like tags
+    if len(matches) == 0:
+        try:
+            candidate = answer
+            # llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
+            if re.search(r"\}\s*\n\s*\{", candidate) is not None:
+                candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
+            if not candidate.strip().startswith("["):
+                candidate = "[" + candidate + "]"
+            # parse the candidate JSON into a dictionary
+            candidates = json.loads(candidate)
+            if not isinstance(candidates, list):
+                candidates = [candidates]
+            for candidate_dict in candidates:
+                checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names)
+                if checked_candidate is not None:
+                    matches.append(checked_candidate)
+        except json.JSONDecodeError:
+            # Ignore invalid JSON silently
+            pass
+
+    return matches
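
To illustrate the parsing path, a sketch of what parseToolCall returns for a typical tagged answer (tool name hypothetical):

from extensions.openai.utils import parseToolCall

answer = '<tool_call>{"name": "get_weather", "arguments": {"city": "Paris"}}</tool_call>'
print(parseToolCall(answer, ["get_weather"]))
# [{'type': 'function', 'function': {'name': 'get_weather', 'arguments': {'city': 'Paris'}}}]
# 'arguments' is still a dict at this point; chat_completions_common json.dumps()-es it
# and attaches 'id' (via getToolCallId) and 'index' before sending the call to the client.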

modules/chat.py

@@ -145,7 +145,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
         instruct_renderer = partial(
             instruction_template.render,
             builtin_tools=None,
-            tools=None,
+            tools=state['tools'] if 'tools' in state else None,
             tools_in_user_message=False,
             add_generation_prompt=False
         )
@@ -171,9 +171,13 @@ def generate_chat_prompt(user_input, state, **kwargs):
         messages.append({"role": "system", "content": context})

     insert_pos = len(messages)
-    for user_msg, assistant_msg in reversed(history):
-        user_msg = user_msg.strip()
-        assistant_msg = assistant_msg.strip()
+    for entry in reversed(history):
+        user_msg = entry[0].strip()
+        assistant_msg = entry[1].strip()
+        tool_msg = entry[2].strip() if len(entry) > 2 else ''
+
+        if tool_msg:
+            messages.insert(insert_pos, {"role": "tool", "content": tool_msg})

         if assistant_msg:
             messages.insert(insert_pos, {"role": "assistant", "content": assistant_msg})
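
Net effect on the history format: each entry is now a three-element list instead of a [user, assistant] pair, with the third slot carrying tool output; generate_chat_prompt renders a non-empty third slot as a role "tool" message. Schematically (values illustrative):

# [user_msg, assistant_msg, tool_msg]; empty strings fill the unused slots.
entry_user_assistant = ["What is the weather in Paris?", "It is 21 degrees.", ""]
entry_tool_result = ["", "", '{"temperature_c": 21}']  # rendered as {"role": "tool", ...}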