API: Fix llama.cpp generating after disconnect, improve disconnect detection, fix deadlock on simultaneous requests

This commit is contained in:
oobabooga 2025-05-13 11:23:33 -07:00
parent 62c774bf24
commit c375b69413

View file

@@ -14,6 +14,7 @@ from fastapi.requests import Request
from fastapi.responses import JSONResponse
from pydub import AudioSegment
from sse_starlette import EventSourceResponse
from starlette.concurrency import iterate_in_threadpool
import extensions.openai.completions as OAIcompletions
import extensions.openai.images as OAIimages
@@ -115,7 +116,7 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
async def generator():
async with streaming_semaphore:
response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
for resp in response:
async for resp in iterate_in_threadpool(response):
disconnected = await request.is_disconnected()
if disconnected:
break
@@ -125,7 +126,12 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
return EventSourceResponse(generator()) # SSE streaming
else:
response = OAIcompletions.completions(to_dict(request_data), is_legacy=is_legacy)
response = await asyncio.to_thread(
OAIcompletions.completions,
to_dict(request_data),
is_legacy=is_legacy
)
return JSONResponse(response)
@@ -138,7 +144,7 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
async def generator():
async with streaming_semaphore:
response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
for resp in response:
async for resp in iterate_in_threadpool(response):
disconnected = await request.is_disconnected()
if disconnected:
break
@@ -148,7 +154,12 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
return EventSourceResponse(generator()) # SSE streaming
else:
response = OAIcompletions.chat_completions(to_dict(request_data), is_legacy=is_legacy)
response = await asyncio.to_thread(
OAIcompletions.chat_completions,
to_dict(request_data),
is_legacy=is_legacy
)
return JSONResponse(response)