Tools support for OpenAI compatible API (#6827)

This commit is contained in:
Jonas 2025-05-08 17:30:27 +02:00 committed by GitHub
parent ed6e16191d
commit fa960496d5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 209 additions and 17 deletions

View file

@ -1,11 +1,14 @@
import copy
import time
import json
from collections import deque
import tiktoken
from extensions.openai.errors import InvalidRequestError
from extensions.openai.utils import debug_msg
from extensions.openai.utils import debug_msg, getToolCallId, parseToolCall
from extensions.openai.typing import ToolDefinition
from pydantic import ValidationError
from modules import shared
from modules.chat import (
generate_chat_prompt,
@ -99,19 +102,24 @@ def convert_history(history):
user_input = content
user_input_last = True
if current_message:
chat_dialogue.append([current_message, ''])
chat_dialogue.append([current_message, '', ''])
current_message = ""
current_message = content
elif role == "assistant":
if "tool_calls" in entry and isinstance(entry["tool_calls"], list) and len(entry["tool_calls"]) > 0 and content.strip() == "":
continue # skip tool calls
current_reply = content
user_input_last = False
if current_message:
chat_dialogue.append([current_message, current_reply])
chat_dialogue.append([current_message, current_reply, ''])
current_message = ""
current_reply = ""
else:
chat_dialogue.append(['', current_reply])
chat_dialogue.append(['', current_reply, ''])
elif role == "tool":
user_input_last = False
chat_dialogue.append(['', '', content])
elif role == "system":
system_message += f"\n{content}" if system_message else content
@ -131,6 +139,10 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
if 'messages' not in body:
raise InvalidRequestError(message="messages is required", param='messages')
tools = None
if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and len(body['tools']) > 0:
tools = validateTools(body['tools']) # raises InvalidRequestError if validation fails
messages = body['messages']
for m in messages:
if 'role' not in m:
@ -188,6 +200,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
'custom_system_message': custom_system_message,
'chat_template_str': chat_template_str,
'chat-instruct_command': chat_instruct_command,
'tools': tools,
'history': history,
'stream': stream
})
@ -200,7 +213,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
requested_model = generate_params.pop('model')
logprob_proc = generate_params.pop('logprob_proc', None)
def chat_streaming_chunk(content):
def chat_streaming_chunk(content, chunk_tool_calls=None):
# begin streaming
chunk = {
"id": cmpl_id,
@ -210,7 +223,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
resp_list: [{
"index": 0,
"finish_reason": None,
"delta": {'role': 'assistant', 'content': content},
"delta": {'role': 'assistant', 'content': content, 'tool_calls': chunk_tool_calls},
}],
}
@ -219,6 +232,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
# else:
# chunk[resp_list][0]["logprobs"] = None
return chunk
# generate reply #######################################
@ -227,8 +241,6 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
yield {'prompt': prompt}
return
debug_msg({'prompt': prompt, 'generate_params': generate_params})
if stream:
yield chat_streaming_chunk('')
@ -238,8 +250,23 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
answer = ''
seen_content = ''
tool_calls = []
end_last_tool_call = 0
supported_tools = [x["function"]["name"] for x in tools] if tools is not None else None
for a in generator:
answer = a['internal'][-1][1]
if supported_tools is not None:
tool_call = parseToolCall(answer[end_last_tool_call:], supported_tools) if len(answer) > 0 else []
if len(tool_call) > 0:
for tc in tool_call:
tc["id"] = getToolCallId()
tc["index"] = str(len(tool_calls))
tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
tool_calls.append(tc)
end_last_tool_call = len(answer)
if stream:
len_seen = len(seen_content)
new_content = answer[len_seen:]
@ -247,18 +274,25 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
continue
seen_content = answer
chunk = chat_streaming_chunk(new_content)
seen_content = answer
yield chunk
# stop generation if tool_calls were generated previously
if len(tool_calls) > 0:
break
token_count = len(encode(prompt)[0])
completion_token_count = len(encode(answer)[0])
stop_reason = "stop"
if len(tool_calls) > 0:
stop_reason = "tool_calls"
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
stop_reason = "length"
if stream:
chunk = chat_streaming_chunk('')
chunk = chat_streaming_chunk('', tool_calls)
chunk[resp_list][0]['finish_reason'] = stop_reason
chunk['usage'] = {
"prompt_tokens": token_count,
@ -276,7 +310,8 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
resp_list: [{
"index": 0,
"finish_reason": stop_reason,
"message": {"role": "assistant", "content": answer}
"message": {"role": "assistant", "content": answer},
"tool_calls": tool_calls
}],
"usage": {
"prompt_tokens": token_count,
@ -465,3 +500,19 @@ def completions(body: dict, is_legacy: bool = False) -> dict:
def stream_completions(body: dict, is_legacy: bool = False):
for resp in completions_common(body, is_legacy, stream=True):
yield resp
def validateTools(tools: list[dict]):
# Validate each tool definition in the JSON array
valid_tools = None
for idx in range(len(tools)):
tool = tools[idx]
try:
tool_definition = ToolDefinition(**tool)
if valid_tools is None:
valid_tools = []
valid_tools.append(tool)
except ValidationError:
raise InvalidRequestError(message=f"Invalid tool specification at index {idx}.", param='tools')
return valid_tools

View file

@ -1,8 +1,8 @@
import json
import time
from typing import Dict, List
from typing import Dict, List, Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator
class GenerationOptions(BaseModel):
@ -54,6 +54,48 @@ class GenerationOptions(BaseModel):
grammar_string: str = ""
class ToolDefinition(BaseModel):
function: 'ToolFunction'
type: str
class ToolFunction(BaseModel):
description: str
name: str
parameters: 'ToolParameters'
class ToolParameters(BaseModel):
properties: Optional[Dict[str, 'ToolProperty']] = None
required: Optional[list[str]] = None
type: str
description: Optional[str] = None
class ToolProperty(BaseModel):
description: Optional[str] = None
type: Optional[str] = None # we are faced with definitions like anyOf, e.g. {'type': 'function', 'function': {'name': 'git_create_branch', 'description': 'Creates a new branch from an optional base branch', 'parameters': {'type': 'object', 'properties': {'repo_path': {'title': 'Repo Path', 'type': 'string'}, 'branch_name': {'title': 'Branch Name', 'type': 'string'}, 'base_branch': {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'default': None, 'title': 'Base Branch'}}, 'required': ['repo_path', 'branch_name'], 'title': 'GitCreateBranch'}}}
class FunctionCall(BaseModel):
name: str
arguments: Optional[str] = None
parameters: Optional[str] = None
@validator('arguments', allow_reuse=True)
def checkPropertyArgsOrParams(cls, v, values, **kwargs):
if not v and not values.get('parameters'):
raise ValueError("At least one of 'arguments' or 'parameters' must be provided as property in FunctionCall type")
return v
class ToolCall(BaseModel):
id: str
index: int
type: str
function: FunctionCall
class CompletionRequestParams(BaseModel):
model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
prompt: str | List[str]
@ -92,6 +134,7 @@ class ChatCompletionRequestParams(BaseModel):
frequency_penalty: float | None = 0
function_call: str | dict | None = Field(default=None, description="Unused parameter.")
functions: List[dict] | None = Field(default=None, description="Unused parameter.")
tools: List[dict] | None = Field(default=None, description="Tools signatures passed via MCP.")
logit_bias: dict | None = None
max_tokens: int | None = None
n: int | None = Field(default=1, description="Unused parameter.")

View file

@ -1,6 +1,9 @@
import base64
import os
import time
import json
import random
import re
import traceback
from typing import Callable, Optional
@ -52,3 +55,94 @@ def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_star
time.sleep(3)
raise Exception('Could not start cloudflared.')
def getToolCallId() -> str:
letter_bytes = "abcdefghijklmnopqrstuvwxyz0123456789"
b = [random.choice(letter_bytes) for _ in range(8)]
return "call_" + "".join(b).lower()
def checkAndSanitizeToolCallCandidate(candidate_dict: dict, tool_names: list[str]):
# check if property 'function' exists and is a dictionary, otherwise adapt dict
if 'function' not in candidate_dict and 'name' in candidate_dict and isinstance(candidate_dict['name'], str):
candidate_dict = {"type": "function", "function": candidate_dict}
if 'function' in candidate_dict and isinstance(candidate_dict['function'], str):
candidate_dict['name'] = candidate_dict['function']
del candidate_dict['function']
candidate_dict = {"type": "function", "function": candidate_dict}
if 'function' in candidate_dict and isinstance(candidate_dict['function'], dict):
# check if 'name' exists within 'function' and is part of known tools
if 'name' in candidate_dict['function'] and candidate_dict['function']['name'] in tool_names:
candidate_dict["type"] = "function" # ensure required property 'type' exists and has the right value
# map property 'parameters' used by some older models to 'arguments'
if "arguments" not in candidate_dict["function"] and "parameters" in candidate_dict["function"]:
candidate_dict["function"]["arguments"] = candidate_dict["function"]["parameters"]
del candidate_dict["function"]["parameters"]
return candidate_dict
return None
def parseToolCall(answer: str, tool_names: list[str]):
matches = []
# abort on very short answers to save computation cycles
if len(answer) < 10:
return matches
# Define the regex pattern to find the JSON content wrapped in <function>, <tools>, <tool_call>, and other tags observed from various models
patterns = [ r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>" ]
for pattern in patterns:
for match in re.finditer(pattern, answer, re.DOTALL):
# print(match.group(2))
if match.group(2) is None:
continue
# remove backtick wraps if present
candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip())
candidate = re.sub(r"```$", "", candidate.strip())
# unwrap inner tags
candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL)
# llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
if re.search(r"\}\s*\n\s*\{", candidate) is not None:
candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
if not candidate.strip().startswith("["):
candidate = "[" + candidate + "]"
candidates = []
try:
# parse the candidate JSON into a dictionary
candidates = json.loads(candidate)
if not isinstance(candidates, list):
candidates = [candidates]
except json.JSONDecodeError:
# Ignore invalid JSON silently
continue
for candidate_dict in candidates:
checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names)
if checked_candidate is not None:
matches.append(checked_candidate)
# last resort if nothing has been mapped: LLM might have produced plain json tool call without xml-like tags
if len(matches) == 0:
try:
candidate = answer
# llm might have generated multiple json objects separated by linebreaks, check for this pattern and try parsing each object individually
if re.search(r"\}\s*\n\s*\{", candidate) is not None:
candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
if not candidate.strip().startswith("["):
candidate = "[" + candidate + "]"
# parse the candidate JSON into a dictionary
candidates = json.loads(candidate)
if not isinstance(candidates, list):
candidates = [candidates]
for candidate_dict in candidates:
checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names)
if checked_candidate is not None:
matches.append(checked_candidate)
except json.JSONDecodeError:
# Ignore invalid JSON silently
pass
return matches

View file

@ -145,7 +145,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
instruct_renderer = partial(
instruction_template.render,
builtin_tools=None,
tools=None,
tools=state['tools'] if 'tools' in state else None,
tools_in_user_message=False,
add_generation_prompt=False
)
@ -171,9 +171,13 @@ def generate_chat_prompt(user_input, state, **kwargs):
messages.append({"role": "system", "content": context})
insert_pos = len(messages)
for user_msg, assistant_msg in reversed(history):
user_msg = user_msg.strip()
assistant_msg = assistant_msg.strip()
for entry in reversed(history):
user_msg = entry[0].strip()
assistant_msg = entry[1].strip()
tool_msg = entry[2].strip() if len(entry) > 2 else ''
if tool_msg:
messages.insert(insert_pos, {"role": "tool", "content": tool_msg})
if assistant_msg:
messages.insert(insert_pos, {"role": "assistant", "content": assistant_msg})