Pre-merge dev branch

This commit is contained in:
oobabooga 2025-06-05 10:47:07 -07:00
parent 1f3b1a1b94
commit 27affa9db7

View file

@ -467,19 +467,21 @@ def get_stopping_strings(state):
return result return result
def add_message_version(history, row_idx, is_current=True): def add_message_version(history, role, row_idx, is_current=True):
key = f"assistant_{row_idx}" key = f"{role}_{row_idx}"
if 'metadata' not in history:
history['metadata'] = {}
if key not in history['metadata']: if key not in history['metadata']:
history['metadata'][key] = {} history['metadata'][key] = {}
if "versions" not in history['metadata'][key]: if "versions" not in history['metadata'][key]:
history['metadata'][key]["versions"] = [] history['metadata'][key]["versions"] = []
current_content = history['internal'][row_idx][1] # Determine which index to use for content based on role
current_visible = history['visible'][row_idx][1] content_idx = 0 if role == 'user' else 1
current_content = history['internal'][row_idx][content_idx]
current_visible = history['visible'][row_idx][content_idx]
# Always add the current message as a new version entry.
# The timestamp will differentiate it even if content is identical to a previous version.
history['metadata'][key]["versions"].append({ history['metadata'][key]["versions"].append({
"content": current_content, "content": current_content,
"visible_content": current_visible, "visible_content": current_visible,
@ -534,6 +536,13 @@ def add_message_attachment(history, row_idx, file_path, is_user=True):
"type": "application/pdf", "type": "application/pdf",
"content": content, "content": content,
} }
elif file_extension == '.docx':
content = extract_docx_text(path)
attachment = {
"name": filename,
"type": "application/docx",
"content": content,
}
else: else:
# Default handling for text files # Default handling for text files
with open(path, 'r', encoding='utf-8') as f: with open(path, 'r', encoding='utf-8') as f:
@ -569,6 +578,79 @@ def extract_pdf_text(pdf_path):
return f"[Error extracting PDF text: {str(e)}]" return f"[Error extracting PDF text: {str(e)}]"
def extract_docx_text(docx_path):
    """
    Extract text from a .docx file, including headers,
    body (paragraphs and tables), and footers.
    """
    try:
        import docx

        document = docx.Document(docx_path)
        collected = []

        def _collect_paragraphs(paragraphs):
            # Keep only paragraphs that contain visible text.
            for paragraph in paragraphs:
                stripped = paragraph.text.strip()
                if stripped:
                    collected.append(stripped)

        # 1) Non-empty header paragraphs from every section
        for section in document.sections:
            _collect_paragraphs(section.header.paragraphs)

        # 2) Body blocks (paragraphs and tables) in document order
        for element in document.element.body.iterchildren():
            if isinstance(element, docx.oxml.text.paragraph.CT_P):
                _collect_paragraphs([docx.text.paragraph.Paragraph(element, document)])
            elif isinstance(element, docx.oxml.table.CT_Tbl):
                tbl = docx.table.Table(element, document)
                for tbl_row in tbl.rows:
                    collected.append("\t".join(cell.text.strip() for cell in tbl_row.cells))

        # 3) Non-empty footer paragraphs from every section
        for section in document.sections:
            _collect_paragraphs(section.footer.paragraphs)

        return "\n".join(collected)
    except Exception as e:
        logger.error(f"Error extracting text from DOCX: {e}")
        return f"[Error extracting DOCX text: {str(e)}]"
def generate_search_query(user_message, state):
    """Generate a search query from user message using the LLM"""
    # Ask the model to rewrite the message as a short search query
    augmented_message = f"{user_message}\n\n=====\n\nPlease turn the message above into a short web search query in the same language as the message. Respond with only the search query, nothing else."

    # Clone the state with a tight generation budget; the full history is kept
    search_state = state.copy()
    search_state['max_new_tokens'] = 64
    search_state['auto_max_new_tokens'] = False
    search_state['enable_thinking'] = False

    # Build the prompt from the existing history plus the augmented message
    formatted_prompt = generate_chat_prompt(augmented_message, search_state)

    query = ""
    for partial in generate_reply(formatted_prompt, search_state, stopping_strings=[], is_chat=True):
        query = partial

    query = query.strip()
    # Drop one pair of surrounding double quotes, if present
    if query.startswith('"') and query.endswith('"') and len(query) >= 2:
        query = query[1:-1]

    return query
def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False): def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False):
# Handle dict format with text and files # Handle dict format with text and files
files = [] files = []
@ -614,7 +696,9 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
state['image_attachments'] = image_attachments state['image_attachments'] = image_attachments
# Add web search results as attachments if enabled # Add web search results as attachments if enabled
add_web_search_attachments(output, row_idx, text, state) if state.get('enable_web_search', False):
search_query = generate_search_query(text, state)
add_web_search_attachments(output, row_idx, text, search_query, state)
# Apply extensions # Apply extensions
text, visible_text = apply_extensions('chat_input', text, visible_text, state) text, visible_text = apply_extensions('chat_input', text, visible_text, state)
@ -638,9 +722,18 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
if regenerate: if regenerate:
row_idx = len(output['internal']) - 1 row_idx = len(output['internal']) - 1
# Store the first response as a version before regenerating # Store the old response as a version before regenerating
if not output['metadata'].get(f"assistant_{row_idx}", {}).get('versions'): if not output['metadata'].get(f"assistant_{row_idx}", {}).get('versions'):
add_message_version(output, row_idx, is_current=False) add_message_version(output, "assistant", row_idx, is_current=False)
# Add new empty version (will be filled during streaming)
key = f"assistant_{row_idx}"
output['metadata'][key]["versions"].append({
"content": "",
"visible_content": "",
"timestamp": get_current_timestamp()
})
output['metadata'][key]["current_version_index"] = len(output['metadata'][key]["versions"]) - 1
if loading_message: if loading_message:
yield { yield {
@ -672,7 +765,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
# Add timestamp for assistant's response at the start of generation # Add timestamp for assistant's response at the start of generation
row_idx = len(output['internal']) - 1 row_idx = len(output['internal']) - 1
update_message_metadata(output['metadata'], "assistant", row_idx, timestamp=get_current_timestamp()) update_message_metadata(output['metadata'], "assistant", row_idx, timestamp=get_current_timestamp(), model_name=shared.model_name)
# Generate # Generate
reply = None reply = None
@ -694,33 +787,51 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
if _continue: if _continue:
output['internal'][-1] = [text, last_reply[0] + reply] output['internal'][-1] = [text, last_reply[0] + reply]
output['visible'][-1] = [visible_text, last_reply[1] + visible_reply] output['visible'][-1] = [visible_text, last_reply[1] + visible_reply]
if is_stream:
yield output
elif not (j == 0 and visible_reply.strip() == ''): elif not (j == 0 and visible_reply.strip() == ''):
output['internal'][-1] = [text, reply.lstrip(' ')] output['internal'][-1] = [text, reply.lstrip(' ')]
output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')] output['visible'][-1] = [visible_text, visible_reply.lstrip(' ')]
# Keep version metadata in sync during streaming (for regeneration)
if regenerate:
row_idx = len(output['internal']) - 1
key = f"assistant_{row_idx}"
current_idx = output['metadata'][key]['current_version_index']
output['metadata'][key]['versions'][current_idx].update({
'content': output['internal'][row_idx][1],
'visible_content': output['visible'][row_idx][1]
})
if is_stream: if is_stream:
yield output yield output
# Add the newly generated response as a version (only for regeneration) output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True)
# Final sync for version metadata (in case streaming was disabled)
if regenerate: if regenerate:
row_idx = len(output['internal']) - 1 row_idx = len(output['internal']) - 1
add_message_version(output, row_idx, is_current=True) key = f"assistant_{row_idx}"
current_idx = output['metadata'][key]['current_version_index']
output['metadata'][key]['versions'][current_idx].update({
'content': output['internal'][row_idx][1],
'visible_content': output['visible'][row_idx][1]
})
output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True)
yield output yield output
def impersonate_wrapper(text, state): def impersonate_wrapper(textbox, state):
text = textbox['text']
static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) static_output = chat_html_wrapper(state['history'], state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
prompt = generate_chat_prompt('', state, impersonate=True) prompt = generate_chat_prompt('', state, impersonate=True)
stopping_strings = get_stopping_strings(state) stopping_strings = get_stopping_strings(state)
yield text + '...', static_output textbox['text'] = text + '...'
yield textbox, static_output
reply = None reply = None
for reply in generate_reply(prompt + text, state, stopping_strings=stopping_strings, is_chat=True): for reply in generate_reply(prompt + text, state, stopping_strings=stopping_strings, is_chat=True):
yield (text + reply).lstrip(' '), static_output textbox['text'] = (text + reply).lstrip(' ')
yield textbox, static_output
if shared.stop_everything: if shared.stop_everything:
return return
@ -769,7 +880,9 @@ def generate_chat_reply_wrapper(text, state, regenerate=False, _continue=False):
last_save_time = time.monotonic() last_save_time = time.monotonic()
save_interval = 8 save_interval = 8
for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True, for_ui=True)): for i, history in enumerate(generate_chat_reply(text, state, regenerate, _continue, loading_message=True, for_ui=True)):
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']), history yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'], last_message_only=(i > 0)), history
if i == 0:
time.sleep(0.125) # We need this to make sure the first update goes through
current_time = time.monotonic() current_time = time.monotonic()
# Save on first iteration or if save_interval seconds have passed # Save on first iteration or if save_interval seconds have passed
@ -800,9 +913,12 @@ def remove_last_message(history):
return html.unescape(last[0]), history return html.unescape(last[0]), history
def send_dummy_message(textbox, state): def send_dummy_message(text, state):
history = state['history'] history = state['history']
text = textbox['text']
# Handle both dict and string inputs
if isinstance(text, dict):
text = text['text']
# Initialize metadata if not present # Initialize metadata if not present
if 'metadata' not in history: if 'metadata' not in history:
@ -816,9 +932,12 @@ def send_dummy_message(textbox, state):
return history return history
def send_dummy_reply(textbox, state): def send_dummy_reply(text, state):
history = state['history'] history = state['history']
text = textbox['text']
# Handle both dict and string inputs
if isinstance(text, dict):
text = text['text']
# Initialize metadata if not present # Initialize metadata if not present
if 'metadata' not in history: if 'metadata' not in history:
@ -1487,76 +1606,76 @@ def handle_edit_message_click(state):
if message_index >= len(history['internal']): if message_index >= len(history['internal']):
html_output = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) html_output = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
return [history, html_output, gr.update()] return [history, html_output]
# Use the role passed from frontend role_idx = 0 if role == "user" else 1
is_user_msg = (role == "user")
role_idx = 0 if is_user_msg else 1
# For assistant messages, save the original version BEFORE updating content if 'metadata' not in history:
if not is_user_msg: history['metadata'] = {}
if not history['metadata'].get(f"assistant_{message_index}", {}).get('versions'):
add_message_version(history, message_index, is_current=False) key = f"{role}_{message_index}"
if key not in history['metadata']:
history['metadata'][key] = {}
# If no versions exist yet for this message, store the current (pre-edit) content as the first version.
if "versions" not in history['metadata'][key] or not history['metadata'][key]["versions"]:
original_content = history['internal'][message_index][role_idx]
original_visible = history['visible'][message_index][role_idx]
original_timestamp = history['metadata'][key].get('timestamp', get_current_timestamp())
history['metadata'][key]["versions"] = [{
"content": original_content,
"visible_content": original_visible,
"timestamp": original_timestamp
}]
# NOW update the message content
history['internal'][message_index][role_idx] = apply_extensions('input', new_text, state, is_chat=True) history['internal'][message_index][role_idx] = apply_extensions('input', new_text, state, is_chat=True)
history['visible'][message_index][role_idx] = html.escape(new_text) history['visible'][message_index][role_idx] = html.escape(new_text)
# Branch if editing user message, add version if editing assistant message add_message_version(history, role, message_index, is_current=True)
if is_user_msg:
# Branch like branch-here
history['visible'] = history['visible'][:message_index + 1]
history['internal'] = history['internal'][:message_index + 1]
new_unique_id = datetime.now().strftime('%Y%m%d-%H-%M-%S')
save_history(history, new_unique_id, state['character_menu'], state['mode'])
histories = find_all_histories_with_first_prompts(state)
past_chats_update = gr.update(choices=histories, value=new_unique_id)
state['unique_id'] = new_unique_id
elif not is_user_msg:
# Add the new version as current
add_message_version(history, message_index, is_current=True)
past_chats_update = gr.update()
else:
past_chats_update = gr.update()
save_history(history, state['unique_id'], state['character_menu'], state['mode']) save_history(history, state['unique_id'], state['character_menu'], state['mode'])
html_output = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) html_output = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
return [history, html_output, past_chats_update] return [history, html_output]
def handle_navigate_version_click(state): def handle_navigate_version_click(state):
history = state['history'] history = state['history']
message_index = int(state['navigate_message_index']) message_index = int(state['navigate_message_index'])
direction = state['navigate_direction'] direction = state['navigate_direction']
role = state['navigate_message_role']
# Get assistant message metadata if not role:
key = f"assistant_{message_index}" logger.error("Role not provided for version navigation.")
if key not in history['metadata'] or 'versions' not in history['metadata'][key]: html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
# No versions to navigate return [history, html]
key = f"{role}_{message_index}"
if 'metadata' not in history or key not in history['metadata'] or 'versions' not in history['metadata'][key]:
html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
return [history, html] return [history, html]
metadata = history['metadata'][key] metadata = history['metadata'][key]
current_idx = metadata.get('current_version_index', 0)
versions = metadata['versions'] versions = metadata['versions']
# Default to the last version if current_version_index is not set
current_idx = metadata.get('current_version_index', len(versions) - 1 if versions else 0)
# Calculate new index
if direction == 'left': if direction == 'left':
new_idx = max(0, current_idx - 1) new_idx = max(0, current_idx - 1)
else: # right else: # right
new_idx = min(len(versions) - 1, current_idx + 1) new_idx = min(len(versions) - 1, current_idx + 1)
if new_idx == current_idx: if new_idx == current_idx:
# No change needed
html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])
return [history, html] return [history, html]
# Update history with new version msg_content_idx = 0 if role == 'user' else 1 # 0 for user content, 1 for assistant content in the pair
version = versions[new_idx] version_to_load = versions[new_idx]
history['internal'][message_index][1] = version['content'] history['internal'][message_index][msg_content_idx] = version_to_load['content']
history['visible'][message_index][1] = version['visible_content'] history['visible'][message_index][msg_content_idx] = version_to_load['visible_content']
metadata['current_version_index'] = new_idx metadata['current_version_index'] = new_idx
update_message_metadata(history['metadata'], role, message_index, timestamp=version_to_load['timestamp'])
# Redraw and save # Redraw and save
html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu']) html = redraw_html(history, state['name1'], state['name2'], state['mode'], state['chat_style'], state['character_menu'])