Store previous reply versions on regenerate (#7004)

This commit is contained in:
oobabooga 2025-05-20 12:51:28 -03:00 committed by GitHub
parent c25a381540
commit 616ea6966d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -365,6 +365,34 @@ def get_stopping_strings(state):
return result
def add_message_version(history, row_idx, is_current=True):
"""Add the current message as a version in the history metadata"""
if 'metadata' not in history:
history['metadata'] = {}
if row_idx >= len(history['internal']) or not history['internal'][row_idx][1].strip():
return # Skip if row doesn't exist or message is empty
key = f"assistant_{row_idx}"
# Initialize metadata structures if needed
if key not in history['metadata']:
history['metadata'][key] = {"timestamp": get_current_timestamp()}
if "versions" not in history['metadata'][key]:
history['metadata'][key]["versions"] = []
# Add current message as a version
history['metadata'][key]["versions"].append({
"content": history['internal'][row_idx][1],
"visible_content": history['visible'][row_idx][1],
"timestamp": get_current_timestamp()
})
# Update index if this is the current version
if is_current:
history['metadata'][key]["current_version_index"] = len(history['metadata'][key]["versions"]) - 1
def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_message=True, for_ui=False):
history = state['history']
output = copy.deepcopy(history)
@ -405,6 +433,10 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
text, visible_text = output['internal'][-1][0], output['visible'][-1][0]
if regenerate:
row_idx = len(output['internal']) - 1
# Store the existing response as a version before regenerating
add_message_version(output, row_idx, is_current=False)
if loading_message:
yield {
'visible': output['visible'][:-1] + [[visible_text, shared.processing_message]],
@ -465,6 +497,11 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
if is_stream:
yield output
# Add the newly generated response as a version (only for regeneration)
if regenerate:
row_idx = len(output['internal']) - 1
add_message_version(output, row_idx, is_current=True)
output['visible'][-1][1] = apply_extensions('output', output['visible'][-1][1], state, is_chat=True)
yield output