Improve the web search query generation

This commit is contained in:
oobabooga 2025-05-28 18:14:51 -07:00
parent 27641ac182
commit 3eb0b77427
2 changed files with 28 additions and 26 deletions

View file

@ -538,6 +538,27 @@ def extract_pdf_text(pdf_path):
return f"[Error extracting PDF text: {str(e)}]"
def generate_search_query(user_message, state):
"""Generate a search query from user message using the LLM"""
# Augment the user message with search instruction
augmented_message = f"{user_message}\n\n=====\n\nPlease turn the message above into a short web search query in the same language as the message. Respond with only the search query, nothing else."
# Use a minimal state for search query generation but keep the full history
search_state = state.copy()
search_state['max_new_tokens'] = 64
search_state['auto_max_new_tokens'] = False
search_state['enable_thinking'] = False
# Generate the full prompt using existing history + augmented message
formatted_prompt = generate_chat_prompt(augmented_message, search_state)
query = ""
for reply in generate_reply(formatted_prompt, search_state, stopping_strings=[], is_chat=True):
query = reply.strip()
return query
def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False):
# Handle dict format with text and files
files = []
@ -570,7 +591,9 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
add_message_attachment(output, row_idx, file_path, is_user=True)
# Add web search results as attachments if enabled
add_web_search_attachments(output, row_idx, text, state)
if state.get('enable_web_search', False):
search_query = generate_search_query(text, state)
add_web_search_attachments(output, row_idx, text, search_query, state)
# Apply extensions
text, visible_text = apply_extensions('chat_input', text, visible_text, state)

View file

@ -13,22 +13,6 @@ def get_current_timestamp():
return datetime.now().strftime('%b %d, %Y %H:%M')
def generate_search_query(user_message, state):
"""Generate a search query from user message using the LLM"""
search_prompt = f"{user_message}\n\n=====\n\nPlease turn the message above into a short web search query in the same language as the message. Respond with only the search query, nothing else."
# Use a minimal state for search query generation
search_state = state.copy()
search_state['max_new_tokens'] = 64
search_state['temperature'] = 0.1
query = ""
for reply in generate_reply(search_prompt, search_state, stopping_strings=[], is_chat=False):
query = reply.strip()
return query
def download_web_page(url, timeout=10):
"""Download and extract text from a web page"""
try:
@ -82,19 +66,14 @@ def perform_web_search(query, num_pages=3):
return []
def add_web_search_attachments(history, row_idx, user_message, state):
def add_web_search_attachments(history, row_idx, user_message, search_query, state):
"""Perform web search and add results as attachments"""
if not state.get('enable_web_search', False):
if not search_query:
logger.warning("No search query provided")
return
try:
# Generate search query
search_query = generate_search_query(user_message, state)
if not search_query:
logger.warning("Failed to generate search query")
return
logger.info(f"Generated search query: {search_query}")
logger.info(f"Using search query: {search_query}")
# Perform web search
num_pages = int(state.get('web_search_pages', 3))