diff --git a/modules/text_generation.py b/modules/text_generation.py index f302a918..a29b987f 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -101,7 +101,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k) yield formatted_outputs(reply, shared.model_name) else: - yield formatted_outputs(question, shared.model_name) + if not (shared.args.chat or shared.args.cai_chat): + yield formatted_outputs(question, shared.model_name) # RWKV has proper streaming, which is very nice. # No need to generate 8 tokens at a time. for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k): @@ -197,7 +198,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi def generate_with_streaming(**kwargs): return Iteratorize(generate_with_callback, kwargs, callback=None) - yield formatted_outputs(original_question, shared.model_name) + if not (shared.args.chat or shared.args.cai_chat): + yield formatted_outputs(original_question, shared.model_name) with generate_with_streaming(**generate_params) as generator: for output in generator: if shared.soft_prompt: