mirror of https://github.com/oobabooga/text-generation-webui.git
synced 2025-06-07 14:17:09 -04:00

Tools support for OpenAI compatible API (#6827)

parent ed6e16191d
commit fa960496d5

4 changed files with 209 additions and 17 deletions
extensions/openai/completions.py

```diff
@@ -1,11 +1,14 @@
 import copy
 import time
+import json
 from collections import deque

 import tiktoken

 from extensions.openai.errors import InvalidRequestError
-from extensions.openai.utils import debug_msg
+from extensions.openai.utils import debug_msg, getToolCallId, parseToolCall
+from extensions.openai.typing import ToolDefinition
+from pydantic import ValidationError
 from modules import shared
 from modules.chat import (
     generate_chat_prompt,
```
```diff
@@ -99,19 +102,24 @@ def convert_history(history):
             user_input = content
             user_input_last = True
             if current_message:
-                chat_dialogue.append([current_message, ''])
+                chat_dialogue.append([current_message, '', ''])
                 current_message = ""

             current_message = content
         elif role == "assistant":
+            if "tool_calls" in entry and isinstance(entry["tool_calls"], list) and len(entry["tool_calls"]) > 0 and content.strip() == "":
+                continue  # skip tool calls
             current_reply = content
             user_input_last = False
             if current_message:
-                chat_dialogue.append([current_message, current_reply])
+                chat_dialogue.append([current_message, current_reply, ''])
                 current_message = ""
                 current_reply = ""
             else:
-                chat_dialogue.append(['', current_reply])
+                chat_dialogue.append(['', current_reply, ''])
+        elif role == "tool":
+            user_input_last = False
+            chat_dialogue.append(['', '', content])
         elif role == "system":
             system_message += f"\n{content}" if system_message else content

```
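For orientation: with this change each internal history row becomes a three-element list, with the new third slot carrying tool output. A minimal sketch of the resulting shape (values are hypothetical):

```python
# Each chat_dialogue row gains a third slot for tool results:
#   before: [user_msg, assistant_msg]
#   after:  [user_msg, assistant_msg, tool_msg]
chat_dialogue = [
    ["What is the weather in Paris?", "", ""],  # user turn, reply still pending
    ["", "", "18 degrees C, clear"],            # standalone tool-result turn (role == "tool")
]
```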
```diff
@@ -131,6 +139,10 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False):
     if 'messages' not in body:
         raise InvalidRequestError(message="messages is required", param='messages')

+    tools = None
+    if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and len(body['tools']) > 0:
+        tools = validateTools(body['tools'])  # raises InvalidRequestError if validation fails
+
     messages = body['messages']
     for m in messages:
         if 'role' not in m:
```
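For illustration, a request that exercises the new `tools` field might look like the sketch below; the `get_weather` tool and its schema are hypothetical, but the shape follows the `ToolDefinition` model added in extensions/openai/typing.py (see further down):

```python
body = {
    "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "get_weather",  # hypothetical tool, for illustration only
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string", "description": "City name"}},
                "required": ["city"],
            },
        },
    }],
}
```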
```diff
@@ -188,6 +200,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False):
         'custom_system_message': custom_system_message,
         'chat_template_str': chat_template_str,
         'chat-instruct_command': chat_instruct_command,
+        'tools': tools,
         'history': history,
         'stream': stream
     })
```
```diff
@@ -200,7 +213,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False):
     requested_model = generate_params.pop('model')
     logprob_proc = generate_params.pop('logprob_proc', None)

-    def chat_streaming_chunk(content):
+    def chat_streaming_chunk(content, chunk_tool_calls=None):
         # begin streaming
         chunk = {
             "id": cmpl_id,
```
```diff
@@ -210,7 +223,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False):
             resp_list: [{
                 "index": 0,
                 "finish_reason": None,
-                "delta": {'role': 'assistant', 'content': content},
+                "delta": {'role': 'assistant', 'content': content, 'tool_calls': chunk_tool_calls},
             }],
         }
```
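Assuming `resp_list` resolves to `"choices"` for the non-legacy endpoint, a streamed delta carrying a detected tool call would look roughly like this (ids and values illustrative; `created` and `model` fields omitted):

```python
chunk = {
    "id": "chatcmpl-1234567890",  # cmpl_id, placeholder value here
    "object": "chat.completion.chunk",
    "choices": [{
        "index": 0,
        "finish_reason": None,
        "delta": {
            "role": "assistant",
            "content": "",
            "tool_calls": [{
                "id": "call_ab12cd34",  # produced by getToolCallId()
                "index": "0",
                "type": "function",
                "function": {"name": "get_weather", "arguments": "{\"city\": \"Paris\"}"},
            }],
        },
    }],
}
```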
```diff
@@ -219,6 +232,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False):
             chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
         # else:
         #     chunk[resp_list][0]["logprobs"] = None
+
         return chunk

     # generate reply #######################################
```
```diff
@@ -227,8 +241,6 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False):
         yield {'prompt': prompt}
         return

     debug_msg({'prompt': prompt, 'generate_params': generate_params})

-    if stream:
-        yield chat_streaming_chunk('')

```
```diff
@@ -238,8 +250,23 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False):
     answer = ''
     seen_content = ''

+    tool_calls = []
+    end_last_tool_call = 0
+    supported_tools = [x["function"]["name"] for x in tools] if tools is not None else None
+
     for a in generator:
         answer = a['internal'][-1][1]

+        if supported_tools is not None:
+            tool_call = parseToolCall(answer[end_last_tool_call:], supported_tools) if len(answer) > 0 else []
+            if len(tool_call) > 0:
+                for tc in tool_call:
+                    tc["id"] = getToolCallId()
+                    tc["index"] = str(len(tool_calls))
+                    tc["function"]["arguments"] = json.dumps(tc["function"]["arguments"])
+                    tool_calls.append(tc)
+                end_last_tool_call = len(answer)
+
         if stream:
             len_seen = len(seen_content)
             new_content = answer[len_seen:]
```
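The `end_last_tool_call` offset ensures each generator snapshot is only scanned past the last confirmed call, so a tool call is never collected twice. A self-contained sketch of that windowing with a stubbed parser:

```python
def scan_incrementally(snapshots, parse):
    # mirrors the loop above: only the text after the last confirmed call is re-parsed
    tool_calls, end_last_tool_call = [], 0
    for answer in snapshots:  # each snapshot is the full answer generated so far
        found = parse(answer[end_last_tool_call:]) if answer else []
        if found:
            tool_calls.extend(found)
            end_last_tool_call = len(answer)
    return tool_calls


# the stub "parses" a marker token; the real loop calls parseToolCall instead
calls = scan_incrementally(
    ["Let me", "Let me check <TC>", "Let me check <TC> done"],
    lambda text: ["stub_call"] if "<TC>" in text else [],
)
assert calls == ["stub_call"]  # found once, not re-reported by later snapshots
```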
```diff
@@ -247,18 +274,25 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False):
             if not new_content or chr(0xfffd) in new_content:  # partial unicode character, don't send it yet.
                 continue

-            seen_content = answer
             chunk = chat_streaming_chunk(new_content)
+
+            seen_content = answer
             yield chunk

+        # stop generation if tool_calls were generated previously
+        if len(tool_calls) > 0:
+            break
+
     token_count = len(encode(prompt)[0])
     completion_token_count = len(encode(answer)[0])
     stop_reason = "stop"
+    if len(tool_calls) > 0:
+        stop_reason = "tool_calls"
     if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
         stop_reason = "length"

     if stream:
-        chunk = chat_streaming_chunk('')
+        chunk = chat_streaming_chunk('', tool_calls)
         chunk[resp_list][0]['finish_reason'] = stop_reason
         chunk['usage'] = {
             "prompt_tokens": token_count,
```
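Note the ordering of the checks above: the length test runs after the tool-call test, so a generation that both emitted tool calls and ran into `truncation_length` or `max_new_tokens` is reported with `finish_reason` set to `"length"` rather than `"tool_calls"`.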
```diff
@@ -276,7 +310,8 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False):
         resp_list: [{
             "index": 0,
             "finish_reason": stop_reason,
-            "message": {"role": "assistant", "content": answer}
+            "message": {"role": "assistant", "content": answer},
+            "tool_calls": tool_calls
         }],
         "usage": {
             "prompt_tokens": token_count,
```
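For the non-streaming path, the assembled response then looks roughly like this (values illustrative). Note that this diff attaches `tool_calls` at the choice level, as a sibling of `"message"`, whereas the reference OpenAI API nests it inside the message object; strictly compatible clients may need to account for that.

```python
response = {
    "id": "chatcmpl-1234567890",
    "object": "chat.completion",
    "choices": [{
        "index": 0,
        "finish_reason": "tool_calls",
        "message": {"role": "assistant", "content": ""},
        "tool_calls": [{  # sibling of "message" in this implementation
            "id": "call_ab12cd34",
            "index": "0",
            "type": "function",
            "function": {"name": "get_weather", "arguments": "{\"city\": \"Paris\"}"},
        }],
    }],
    "usage": {"prompt_tokens": 320, "completion_tokens": 24, "total_tokens": 344},
}
```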
```diff
@@ -465,3 +500,19 @@ def completions(body: dict, is_legacy: bool = False) -> dict:
 def stream_completions(body: dict, is_legacy: bool = False):
     for resp in completions_common(body, is_legacy, stream=True):
         yield resp
+
+
+def validateTools(tools: list[dict]):
+    # validate each tool definition in the JSON array
+    valid_tools = None
+    for idx in range(len(tools)):
+        tool = tools[idx]
+        try:
+            tool_definition = ToolDefinition(**tool)
+            if valid_tools is None:
+                valid_tools = []
+            valid_tools.append(tool)
+        except ValidationError:
+            raise InvalidRequestError(message=f"Invalid tool specification at index {idx}.", param='tools')
+
+    return valid_tools
```
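A quick sketch of how `validateTools` behaves for well-formed and malformed input (the `get_weather` definition is hypothetical):

```python
from extensions.openai.completions import validateTools

tools = [{"type": "function",
          "function": {"name": "get_weather",
                       "description": "Look up the current weather for a city",
                       "parameters": {"type": "object"}}}]

validateTools(tools)                   # returns the list unchanged
validateTools([{"type": "function"}])  # raises InvalidRequestError: no "function" member
```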
extensions/openai/typing.py

```diff
@@ -1,8 +1,8 @@
 import json
 import time
-from typing import Dict, List
+from typing import Dict, List, Optional

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, validator


 class GenerationOptions(BaseModel):
```
```diff
@@ -54,6 +54,48 @@ class GenerationOptions(BaseModel):
     grammar_string: str = ""


+class ToolDefinition(BaseModel):
+    function: 'ToolFunction'
+    type: str
+
+
+class ToolFunction(BaseModel):
+    description: str
+    name: str
+    parameters: 'ToolParameters'
+
+
+class ToolParameters(BaseModel):
+    properties: Optional[Dict[str, 'ToolProperty']] = None
+    required: Optional[list[str]] = None
+    type: str
+    description: Optional[str] = None
+
+
+class ToolProperty(BaseModel):
+    description: Optional[str] = None
+    # 'type' must stay optional: we are faced with definitions using anyOf instead, e.g.
+    # {'type': 'function', 'function': {'name': 'git_create_branch', 'description': 'Creates a new branch from an optional base branch', 'parameters': {'type': 'object', 'properties': {'repo_path': {'title': 'Repo Path', 'type': 'string'}, 'branch_name': {'title': 'Branch Name', 'type': 'string'}, 'base_branch': {'anyOf': [{'type': 'string'}, {'type': 'null'}], 'default': None, 'title': 'Base Branch'}}, 'required': ['repo_path', 'branch_name'], 'title': 'GitCreateBranch'}}}
+    type: Optional[str] = None
+
+
+class FunctionCall(BaseModel):
+    name: str
+    arguments: Optional[str] = None
+    parameters: Optional[str] = None
+
+    @validator('arguments', allow_reuse=True)
+    def checkPropertyArgsOrParams(cls, v, values, **kwargs):
+        if not v and not values.get('parameters'):
+            raise ValueError("At least one of 'arguments' or 'parameters' must be provided as a property in the FunctionCall type")
+        return v
+
+
+class ToolCall(BaseModel):
+    id: str
+    index: int
+    type: str
+    function: FunctionCall
+
+
 class CompletionRequestParams(BaseModel):
     model: str | None = Field(default=None, description="Unused parameter. To change the model, use the /v1/internal/model/load endpoint.")
     prompt: str | List[str]
```
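`ToolProperty` keeps `type` optional precisely so that schemas using `anyOf`, like the `git_create_branch` example quoted in the comment above, still validate (extra keys such as `anyOf`, `default`, and `title` are ignored by pydantic by default). A quick check:

```python
from extensions.openai.typing import ToolDefinition

tool = {"type": "function",
        "function": {"name": "git_create_branch",
                     "description": "Creates a new branch from an optional base branch",
                     "parameters": {"type": "object",
                                    "properties": {"repo_path": {"type": "string"},
                                                   "base_branch": {"anyOf": [{"type": "string"},
                                                                             {"type": "null"}]}},
                                    "required": ["repo_path"]}}}

ToolDefinition(**tool)  # validates; "base_branch" has no top-level "type", which is allowed
```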
```diff
@@ -92,6 +134,7 @@ class ChatCompletionRequestParams(BaseModel):
     frequency_penalty: float | None = 0
     function_call: str | dict | None = Field(default=None, description="Unused parameter.")
     functions: List[dict] | None = Field(default=None, description="Unused parameter.")
+    tools: List[dict] | None = Field(default=None, description="Tools signatures passed via MCP.")
     logit_bias: dict | None = None
     max_tokens: int | None = None
     n: int | None = Field(default=1, description="Unused parameter.")
```
extensions/openai/utils.py

```diff
@@ -1,6 +1,9 @@
 import base64
 import os
 import time
+import json
+import random
+import re
 import traceback
 from typing import Callable, Optional
```
```diff
@@ -52,3 +55,94 @@ def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
         time.sleep(3)

     raise Exception('Could not start cloudflared.')
+
+
+def getToolCallId() -> str:
+    letter_bytes = "abcdefghijklmnopqrstuvwxyz0123456789"
+    b = [random.choice(letter_bytes) for _ in range(8)]
+    return "call_" + "".join(b).lower()
```
```diff
+def checkAndSanitizeToolCallCandidate(candidate_dict: dict, tool_names: list[str]):
+    # check if property 'function' exists and is a dictionary, otherwise adapt the dict
+    if 'function' not in candidate_dict and 'name' in candidate_dict and isinstance(candidate_dict['name'], str):
+        candidate_dict = {"type": "function", "function": candidate_dict}
+    if 'function' in candidate_dict and isinstance(candidate_dict['function'], str):
+        candidate_dict['name'] = candidate_dict['function']
+        del candidate_dict['function']
+        candidate_dict = {"type": "function", "function": candidate_dict}
+    if 'function' in candidate_dict and isinstance(candidate_dict['function'], dict):
+        # check if 'name' exists within 'function' and is part of the known tools
+        if 'name' in candidate_dict['function'] and candidate_dict['function']['name'] in tool_names:
+            candidate_dict["type"] = "function"  # ensure required property 'type' exists and has the right value
+            # map property 'parameters' used by some older models to 'arguments'
+            if "arguments" not in candidate_dict["function"] and "parameters" in candidate_dict["function"]:
+                candidate_dict["function"]["arguments"] = candidate_dict["function"]["parameters"]
+                del candidate_dict["function"]["parameters"]
+            return candidate_dict
+    return None
```
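The sanitizer accepts several shapes that models emit in the wild and normalizes them to the OpenAI tool-call layout; a sketch (assuming a single advertised `get_weather` tool):

```python
from extensions.openai.utils import checkAndSanitizeToolCallCandidate

tool_names = ["get_weather"]

# bare function object -> wrapped
checkAndSanitizeToolCallCandidate({"name": "get_weather", "arguments": {}}, tool_names)
# -> {'type': 'function', 'function': {'name': 'get_weather', 'arguments': {}}}

# "function" given as a plain string -> converted and wrapped
checkAndSanitizeToolCallCandidate({"function": "get_weather"}, tool_names)
# -> {'type': 'function', 'function': {'name': 'get_weather'}}

# unknown tool name -> rejected
checkAndSanitizeToolCallCandidate({"function": {"name": "unknown_tool"}}, tool_names)
# -> None
```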
```diff
+def parseToolCall(answer: str, tool_names: list[str]):
+    matches = []
+
+    # abort on very short answers to save computation cycles
+    if len(answer) < 10:
+        return matches
+
+    # regex patterns to find JSON content wrapped in <function>, <tools>, <tool_call> and other tags observed from various models
+    patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]
+
+    for pattern in patterns:
+        for match in re.finditer(pattern, answer, re.DOTALL):
+            # print(match.group(2))
+            if match.group(2) is None:
+                continue
+            # remove backtick wraps if present
+            candidate = re.sub(r"^```(json|xml|python[^\n]*)\n", "", match.group(2).strip())
+            candidate = re.sub(r"```$", "", candidate.strip())
+            # unwrap inner tags
+            candidate = re.sub(pattern, r"\2", candidate.strip(), flags=re.DOTALL)
+            # the LLM might have generated multiple JSON objects separated by line breaks; join them into one array
+            if re.search(r"\}\s*\n\s*\{", candidate) is not None:
+                candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
+            if not candidate.strip().startswith("["):
+                candidate = "[" + candidate + "]"
+
+            candidates = []
+            try:
+                # parse the candidate JSON into a list of dictionaries
+                candidates = json.loads(candidate)
+                if not isinstance(candidates, list):
+                    candidates = [candidates]
+            except json.JSONDecodeError:
+                # ignore invalid JSON silently
+                continue
+
+            for candidate_dict in candidates:
+                checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names)
+                if checked_candidate is not None:
+                    matches.append(checked_candidate)
+
+    # last resort if nothing has been mapped: the LLM might have produced a plain JSON tool call without xml-like tags
+    if len(matches) == 0:
+        try:
+            candidate = answer
+            # again, multiple JSON objects may be separated by line breaks
+            if re.search(r"\}\s*\n\s*\{", candidate) is not None:
+                candidate = re.sub(r"\}\s*\n\s*\{", "},\n{", candidate)
+            if not candidate.strip().startswith("["):
+                candidate = "[" + candidate + "]"
+            candidates = json.loads(candidate)
+            if not isinstance(candidates, list):
+                candidates = [candidates]
+            for candidate_dict in candidates:
+                checked_candidate = checkAndSanitizeToolCallCandidate(candidate_dict, tool_names)
+                if checked_candidate is not None:
+                    matches.append(checked_candidate)
+        except json.JSONDecodeError:
+            # ignore invalid JSON silently
+            pass
+
+    return matches
```
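End to end, a tag-wrapped call in a model answer is recovered like this (sketch; the `<tool_call>` wrapper is one of the tag styles the regexes above match). Note that `parseToolCall` returns `arguments` as parsed JSON; the caller in completions.py re-serializes it with `json.dumps` and attaches `id` and `index`:

```python
from extensions.openai.utils import parseToolCall

answer = '''Sure, let me check.
<tool_call>
{"name": "get_weather", "arguments": {"city": "Paris"}}
</tool_call>'''

parseToolCall(answer, ["get_weather"])
# -> [{'type': 'function', 'function': {'name': 'get_weather',
#                                       'arguments': {'city': 'Paris'}}}]
```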
modules/chat.py

```diff
@@ -145,7 +145,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
         instruct_renderer = partial(
             instruction_template.render,
             builtin_tools=None,
-            tools=None,
+            tools=state['tools'] if 'tools' in state else None,
             tools_in_user_message=False,
             add_generation_prompt=False
         )
```
```diff
@@ -171,9 +171,13 @@ def generate_chat_prompt(user_input, state, **kwargs):
         messages.append({"role": "system", "content": context})

     insert_pos = len(messages)
-    for user_msg, assistant_msg in reversed(history):
-        user_msg = user_msg.strip()
-        assistant_msg = assistant_msg.strip()
+    for entry in reversed(history):
+        user_msg = entry[0].strip()
+        assistant_msg = entry[1].strip()
+        tool_msg = entry[2].strip() if len(entry) > 2 else ''
+
+        if tool_msg:
+            messages.insert(insert_pos, {"role": "tool", "content": tool_msg})

         if assistant_msg:
             messages.insert(insert_pos, {"role": "assistant", "content": assistant_msg})
```
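Each three-element history row now expands into up to three chat-template messages. Because the loop walks the history in reverse and keeps inserting at the same `insert_pos`, the per-row order comes out as user, assistant, tool (the user insert sits in the code just below this hunk). A worked sketch:

```python
entry = ["What is the weather in Paris?",  # user
         "",                               # assistant text (empty: the turn was a tool call)
         "18 degrees C, clear"]            # tool result

# after the insertions above, the rendered messages for this row are:
#   {"role": "user", "content": "What is the weather in Paris?"}
#   {"role": "tool", "content": "18 degrees C, clear"}
# (the empty assistant_msg is skipped by the `if assistant_msg:` guard)
```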