This commit is contained in:
oobabooga 2025-05-15 21:19:19 -07:00
parent 8cb73b78e1
commit fd61297933
4 changed files with 12 additions and 11 deletions

View file

@@ -1,14 +1,14 @@
import copy
import time
import json
import time
from collections import deque
import tiktoken
from pydantic import ValidationError
from extensions.openai.errors import InvalidRequestError
from extensions.openai.utils import debug_msg, getToolCallId, parseToolCall
from extensions.openai.typing import ToolDefinition
from pydantic import ValidationError
from extensions.openai.utils import debug_msg, getToolCallId, parseToolCall
from modules import shared
from modules.chat import (
generate_chat_prompt,
@@ -141,7 +141,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, p
tools = None
if 'tools' in body and body['tools'] is not None and isinstance(body['tools'], list) and len(body['tools']) > 0:
tools = validateTools(body['tools']) # raises InvalidRequestError if validation fails
tools = validateTools(body['tools']) # raises InvalidRequestError if validation fails
messages = body['messages']
for m in messages:

View file

@@ -1,9 +1,9 @@
import base64
import os
import time
import json
import os
import random
import re
import time
import traceback
from typing import Callable, Optional
@@ -91,7 +91,7 @@ def parseToolCall(answer: str, tool_names: list[str]):
return matches
# Define the regex pattern to find the JSON content wrapped in <function>, <tools>, <tool_call>, and other tags observed from various models
patterns = [ r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>" ]
patterns = [r"(```[^\n]*)\n(.*?)```", r"<([^>]+)>(.*?)</\1>"]
for pattern in patterns:
for match in re.finditer(pattern, answer, re.DOTALL):

View file

@@ -1,10 +1,11 @@
import math
import random
import threading
import torch
import chromadb
import numpy as np
import posthog
import torch
from chromadb.config import Settings
from chromadb.utils import embedding_functions

View file

@@ -1,15 +1,15 @@
from pathlib import Path
import torch
import tensorrt_llm
import torch
from tensorrt_llm.runtime import ModelRunner, ModelRunnerCpp
from modules import shared
from modules.logging_colors import logger
from modules.text_generation import (
get_max_prompt_length,
get_reply_from_output_ids
)
from tensorrt_llm.runtime import ModelRunner, ModelRunnerCpp
class TensorRTLLMModel: