diff --git a/server.py b/server.py
index c079a2a..fd2bf7b 100644
--- a/server.py
+++ b/server.py
@@ -247,8 +247,15 @@ def fix_galactica(s):
     s = s.replace(r'$$', r'$')
     return s
 
+def get_max_prompt_length(tokens):
+    global soft_prompt, soft_prompt_tensor
+    max_length = 2048-tokens
+    if soft_prompt:
+        max_length -= soft_prompt_tensor.shape[1]
+    return max_length
+
 def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
-    input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=2048-tokens_to_generate, add_special_tokens=add_special_tokens)
+    input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
     if args.cpu:
         return input_ids
     elif args.deepspeed:
@@ -497,7 +504,8 @@ def generate_chat_prompt(text, tokens, name1, name2, context, history_size, impe
     rows = [f"{context.strip()}\n"]
     i = len(history['internal'])-1
     count = 0
-    while i >= 0 and len(encode(''.join(rows), tokens)[0]) < 2048-tokens:
+    max_length = get_max_prompt_length(tokens)
+    while i >= 0 and len(encode(''.join(rows), tokens)[0]) < max_length:
         rows.insert(1, f"{name2}: {history['internal'][i][1].strip()}\n")
         count += 1
         if not (history['internal'][i][0] == '<|BEGIN-VISIBLE-CHAT|>'):
@@ -515,7 +523,7 @@ def generate_chat_prompt(text, tokens, name1, name2, context, history_size, impe
         rows.append(f"{name1}:")
 
     limit = 2
-    while len(rows) > limit and len(encode(''.join(rows), tokens)[0]) >= 2048-tokens:
+    while len(rows) > limit and len(encode(''.join(rows), tokens)[0]) >= max_length:
         rows.pop(1)
         rows.pop(1)