mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-06-08 14:46:14 -04:00
Perplexity colors extension updates (#6764)
This commit is contained in:
parent
5bcd2d7ad0
commit
60d67994d9
1 changed file with 55 additions and 25 deletions
|
@ -96,23 +96,42 @@ def logits_processor_modifier(logits_processor_list, input_ids):
|
||||||
logits_processor_list.append(ppl_logits_processor)
|
logits_processor_list.append(ppl_logits_processor)
|
||||||
|
|
||||||
|
|
||||||
|
def get_last_token(text, tokens_list, token_ids_list, token_probs_list):
    """Find which candidate token the generated text ends with.

    The candidates are checked in the order given and the first match wins.

    Args:
        text: The generated text whose final token is being identified.
        tokens_list: Decoded candidate token strings.
        token_ids_list: Token ids, parallel to ``tokens_list``.
        token_probs_list: Probabilities, parallel to ``tokens_list``.

    Returns:
        A ``(token, token_id, prob)`` tuple for the first matching candidate,
        or ``('', -1, 0.0)`` when no candidate matches.
    """
    stripped_text = text.strip()
    for token, token_id, prob in zip(tokens_list, token_ids_list, token_probs_list):
        stripped_token = token.strip()
        if stripped_token:
            # Compare with surrounding whitespace ignored, since tokenization
            # spaces may not survive in the displayed text.
            matched = stripped_text.endswith(stripped_token)
        else:
            # A whitespace-only token strips to '', and str.endswith('') is
            # always True, which would match any text. Compare it against the
            # raw text instead, and never match an empty token.
            matched = bool(token) and text.endswith(token)

        if matched:
            return token, token_id, prob

    # Unknown?
    print("Last token not found in list:", tokens_list)
    return '', -1, 0.0
|
||||||
|
|
||||||
|
|
||||||
def output_modifier(text):
|
def output_modifier(text):
|
||||||
global ppl_logits_processor
|
global ppl_logits_processor
|
||||||
#t0 = time.time()
|
#t0 = time.time()
|
||||||
|
original_text = text
|
||||||
|
|
||||||
if not params['active'] or ppl_logits_processor is None:
|
if not params['active'] or ppl_logits_processor is None:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
# Space at the beginning to account for tokenization spaces...
|
||||||
|
text = ' ' + html.unescape(text)
|
||||||
|
|
||||||
# TODO: It's probably more efficient to do this above rather than modifying all these lists
|
# TODO: It's probably more efficient to do this above rather than modifying all these lists
|
||||||
# Remove last element of perplexities_list, top_token_ids_list, top_tokens_list, top_probs_list since everything is off by one because this extension runs before generation
|
# Remove last element of perplexities_list, top_token_ids_list, top_tokens_list, top_probs_list since everything is off by one because this extension runs before generation
|
||||||
perplexities = ppl_logits_processor.perplexities_list[:-1]
|
perplexities = ppl_logits_processor.perplexities_list
|
||||||
top_token_ids_list = ppl_logits_processor.top_token_ids_list[:-1]
|
top_token_ids_list = ppl_logits_processor.top_token_ids_list
|
||||||
top_tokens_list = [[shared.tokenizer.decode(token_id) for token_id in top_token_ids[0]] for top_token_ids in top_token_ids_list]
|
top_tokens_list = [[shared.tokenizer.decode(token_id) for token_id in top_token_ids[0]] for top_token_ids in top_token_ids_list]
|
||||||
top_probs_list = ppl_logits_processor.top_probs_list[:-1]
|
top_probs_list = ppl_logits_processor.top_probs_list
|
||||||
# Remove first element of generated_token_ids, generated_tokens, selected_probs because they are for the last token of the prompt
|
# Remove first element of generated_token_ids, generated_tokens, selected_probs because they are for the last token of the prompt
|
||||||
gen_token_ids = ppl_logits_processor.generated_token_ids[1:]
|
gen_token_ids = ppl_logits_processor.generated_token_ids[1:]
|
||||||
|
# Add last sampled token, if possible (it could be past the end of the top 5 list)
|
||||||
|
last_token, last_token_id, last_prob = get_last_token(text, top_tokens_list[-1], top_token_ids_list[-1][0], top_probs_list[-1][0])
|
||||||
|
if last_token_id != -1:
|
||||||
|
gen_token_ids.append(last_token_id)
|
||||||
gen_tokens = [shared.tokenizer.decode(token_id) for token_id in gen_token_ids]
|
gen_tokens = [shared.tokenizer.decode(token_id) for token_id in gen_token_ids]
|
||||||
sel_probs = ppl_logits_processor.selected_probs[1:]
|
sel_probs = ppl_logits_processor.selected_probs[1:]
|
||||||
|
if last_token_id != -1:
|
||||||
|
sel_probs.append(last_prob)
|
||||||
|
|
||||||
end_part = '</div></div>' if params['probability_dropdown'] else '</span>' # Helps with finding the index after replacing part of the text.
|
end_part = '</div></div>' if params['probability_dropdown'] else '</span>' # Helps with finding the index after replacing part of the text.
|
||||||
|
|
||||||
|
@ -120,8 +139,7 @@ def output_modifier(text):
|
||||||
# Used to find where the message started generating, for working with "continue" generations
|
# Used to find where the message started generating, for working with "continue" generations
|
||||||
# Doesn't work for longer messages... Not sure how I should handle this
|
# Doesn't work for longer messages... Not sure how I should handle this
|
||||||
full_msg = shared.tokenizer.decode([token_id for token_id in gen_token_ids[:-1]]).strip()
|
full_msg = shared.tokenizer.decode([token_id for token_id in gen_token_ids[:-1]]).strip()
|
||||||
# Space at the beginning to account for tokenization spaces...
|
|
||||||
text = ' ' + html.unescape(text)
|
|
||||||
# There was an issue with tab lengths being off by one...
|
# There was an issue with tab lengths being off by one...
|
||||||
# Seems like it might be model-dependent...
|
# Seems like it might be model-dependent...
|
||||||
#text = re.sub(r'( {3,})', r'\1 ', text)
|
#text = re.sub(r'( {3,})', r'\1 ', text)
|
||||||
|
@ -137,6 +155,7 @@ def output_modifier(text):
|
||||||
#i = 0
|
#i = 0
|
||||||
# Add token index for ability to regenerate from there
|
# Add token index for ability to regenerate from there
|
||||||
nonwhitespace_token_found = False
|
nonwhitespace_token_found = False
|
||||||
|
missing_token_count = 0
|
||||||
for index, token, prob, ppl, top_tokens, top_probs in zip(range(len(gen_tokens)), gen_tokens, sel_probs, perplexities, top_tokens_list, top_probs_list):
|
for index, token, prob, ppl, top_tokens, top_probs in zip(range(len(gen_tokens)), gen_tokens, sel_probs, perplexities, top_tokens_list, top_probs_list):
|
||||||
# Somehow this works without issues, but not sure how...
|
# Somehow this works without issues, but not sure how...
|
||||||
if not nonwhitespace_token_found and token.strip() == '':
|
if not nonwhitespace_token_found and token.strip() == '':
|
||||||
|
@ -153,14 +172,20 @@ def output_modifier(text):
|
||||||
color = probability_color_scale(prob)
|
color = probability_color_scale(prob)
|
||||||
if token.strip() in text[i:]:
|
if token.strip() in text[i:]:
|
||||||
if params['probability_dropdown']:
|
if params['probability_dropdown']:
|
||||||
text = text[:i] + text[i:].replace(token.replace('\n', ''), add_dropdown_html(token, index, color, top_tokens, top_probs[0], ppl), 1)
|
text = text[:i] + text[i:].replace(token.replace('\n', ''), add_dropdown_html(token, index, i, color, top_tokens, top_probs[0], ppl), 1)
|
||||||
else:
|
else:
|
||||||
text = text[:i] + text[i:].replace(token.replace('\n', ''), add_color_html(token, color), 1)
|
text = text[:i] + text[i:].replace(token.replace('\n', ''), add_color_html(token, color), 1)
|
||||||
|
|
||||||
# This might be slightly inefficient
|
# This might be slightly inefficient
|
||||||
i += text[i:].find(end_part) + len(end_part)
|
i += text[i:].find(end_part) + len(end_part)
|
||||||
else:
|
else:
|
||||||
|
missing_token_count += 1
|
||||||
print('Missing token:', token, '...', text[i:i+20])
|
print('Missing token:', token, '...', text[i:i+20])
|
||||||
|
# If there are any missing tokens, then either the tokenization was off, or this is the start of a conversation, or something else went wrong
|
||||||
|
if missing_token_count > 5:
|
||||||
|
print("Canceling token coloring...")
|
||||||
|
return original_text
|
||||||
|
|
||||||
|
|
||||||
# Use full perplexity list for calculating the average here.
|
# Use full perplexity list for calculating the average here.
|
||||||
# Fix issue with mean of empty slice
|
# Fix issue with mean of empty slice
|
||||||
|
@ -236,11 +261,11 @@ def add_color_html(token, color):
|
||||||
# I think the issue is from HTML elements taking up space in the visible history, and things like history deepcopy add latency proportional to the size of the history.
|
# I think the issue is from HTML elements taking up space in the visible history, and things like history deepcopy add latency proportional to the size of the history.
|
||||||
# Potential solution is maybe to modify the main generation code to send just the internal text and not the visible history, to avoid moving too much around.
|
# Potential solution is maybe to modify the main generation code to send just the internal text and not the visible history, to avoid moving too much around.
|
||||||
# I wonder if we can also avoid using deepcopy here.
|
# I wonder if we can also avoid using deepcopy here.
|
||||||
def add_dropdown_html(token, index, color, top_tokens, top_probs, perplexity=0):
|
def add_dropdown_html(token, index, msg_position, color, top_tokens, top_probs, perplexity=0):
|
||||||
#print("Token:", token, token.isspace(), '\n' in token or '\r' in token)
|
#print("Token:", token, token.isspace(), '\n' in token or '\r' in token)
|
||||||
output = ''
|
output = ''
|
||||||
# Use the repr to get characters like \n visible. Exclude the quotes around it
|
# Use the repr to get characters like \n visible. Exclude the quotes around it
|
||||||
output += f'<div class="hoverable" id="tok_{index}"><span style="color: #{color}">{html.escape(repr(token)[1:-1])}</span><div class="dropdown"><table class="dropdown-content"><tbody>'
|
output += f'<div class="hoverable" name="tok_{index}_{msg_position}"><span style="color: #{color}">{html.escape(repr(token)[1:-1])}</span><div class="dropdown"><table class="dropdown-content"><tbody>'
|
||||||
for i, token_option, prob in zip(range(len(top_tokens)), top_tokens, top_probs):
|
for i, token_option, prob in zip(range(len(top_tokens)), top_tokens, top_probs):
|
||||||
# TODO: Bold for selected token?
|
# TODO: Bold for selected token?
|
||||||
# Using divs prevented the problem of divs inside spans causing issues.
|
# Using divs prevented the problem of divs inside spans causing issues.
|
||||||
|
@ -249,7 +274,7 @@ def add_dropdown_html(token, index, color, top_tokens, top_probs, perplexity=0):
|
||||||
row_color = probability_color_scale(prob)
|
row_color = probability_color_scale(prob)
|
||||||
row_class = ' class="selected"' if token_option == token else ''
|
row_class = ' class="selected"' if token_option == token else ''
|
||||||
# This time we want to include the quotes around it so that we can see where the spaces are.
|
# This time we want to include the quotes around it so that we can see where the spaces are.
|
||||||
output += f'<tr{row_class}><td id="opt_{index}_{i}" style="color: #{row_color}">{html.escape(repr(token_option))}</td><td style="color: #{row_color}">{prob:.4f}</td></tr>'
|
output += f'<tr{row_class}><td name="opt_{index}_{i}_{msg_position}" style="color: #{row_color}">{html.escape(repr(token_option))}</td><td style="color: #{row_color}">{prob:.4f}</td></tr>'
|
||||||
if perplexity != 0:
|
if perplexity != 0:
|
||||||
ppl_color = perplexity_color_scale(perplexity)
|
ppl_color = perplexity_color_scale(perplexity)
|
||||||
output += f'<tr><td>Perplexity:</td><td style="color: #{ppl_color}">{perplexity:.4f}</td></tr>'
|
output += f'<tr><td>Perplexity:</td><td style="color: #{ppl_color}">{perplexity:.4f}</td></tr>'
|
||||||
|
@ -324,11 +349,12 @@ function sleep(ms) {
|
||||||
// Note that this will only work as intended on the last agent message
|
// Note that this will only work as intended on the last agent message
|
||||||
document.addEventListener("click", async function(event) {
|
document.addEventListener("click", async function(event) {
|
||||||
//console.log(event.target);
|
//console.log(event.target);
|
||||||
const id = event.target.id;
|
const name = event.target.getAttribute("name");
|
||||||
if (id.includes("opt_")) {
|
if (name != null && name.includes("opt_")) {
|
||||||
const id_parts = id.split("_");
|
const name_parts = name.split("_");
|
||||||
const token_index = id_parts[1];
|
const token_index = name_parts[1];
|
||||||
const option_index = id_parts[2];
|
const option_index = name_parts[2];
|
||||||
|
const msg_pos = name_parts[3];
|
||||||
// Exclude the quotes and convert newlines... Not sure about the newlines though
|
// Exclude the quotes and convert newlines... Not sure about the newlines though
|
||||||
// TODO: Seems like continuing generation from a newline causes problems whether you add it or not!
|
// TODO: Seems like continuing generation from a newline causes problems whether you add it or not!
|
||||||
const token_string = event.target.innerHTML.substring(1, event.target.innerHTML.length-1).replace(new RegExp(String.fromCharCode(92)+String.fromCharCode(92)+"r", "g"), '').replace(new RegExp(String.fromCharCode(92)+String.fromCharCode(92)+"n", "g"), '');
|
const token_string = event.target.innerHTML.substring(1, event.target.innerHTML.length-1).replace(new RegExp(String.fromCharCode(92)+String.fromCharCode(92)+"r", "g"), '').replace(new RegExp(String.fromCharCode(92)+String.fromCharCode(92)+"n", "g"), '');
|
||||||
|
@ -341,8 +367,11 @@ document.addEventListener("click", async function(event) {
|
||||||
var msg_part = msg_parts[i];
|
var msg_part = msg_parts[i];
|
||||||
if (msg_part.nodeType === Node.ELEMENT_NODE) {
|
if (msg_part.nodeType === Node.ELEMENT_NODE) {
|
||||||
if (msg_part.nodeName == "DIV") {
|
if (msg_part.nodeName == "DIV") {
|
||||||
var current_token_index = msg_part.id.split("_")[1];
|
msg_part_name = msg_part.getAttribute("name")
|
||||||
if (current_token_index == token_index) {
|
if (msg_part_name != null) {
|
||||||
|
var current_token_index = msg_part_name.split("_")[1];
|
||||||
|
var current_message_pos = msg_part_name.split("_")[2];
|
||||||
|
if (current_token_index == token_index && current_message_pos == msg_pos) {
|
||||||
// Use the replacement token
|
// Use the replacement token
|
||||||
// TODO: Don't have access to the tokenizer here, and sometimes there needs to be a space added before this token
|
// TODO: Don't have access to the tokenizer here, and sometimes there needs to be a space added before this token
|
||||||
msg_text += token_string //.replace(new RegExp(String.fromCharCode(92)+String.fromCharCode(92)+"r", "g"), '').replace(new RegExp(String.fromCharCode(92)+String.fromCharCode(92)+"n", "g"), '');
|
msg_text += token_string //.replace(new RegExp(String.fromCharCode(92)+String.fromCharCode(92)+"r", "g"), '').replace(new RegExp(String.fromCharCode(92)+String.fromCharCode(92)+"n", "g"), '');
|
||||||
|
@ -354,6 +383,7 @@ document.addEventListener("click", async function(event) {
|
||||||
msg_text += text;
|
msg_text += text;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
else {
|
else {
|
||||||
// Break tag (hacky workaround because the newline literal can't be parsed here)
|
// Break tag (hacky workaround because the newline literal can't be parsed here)
|
||||||
//msg_text += String.fromCharCode(10);
|
//msg_text += String.fromCharCode(10);
|
||||||
|
|
Loading…
Add table
Reference in a new issue