diff --git a/modules/shared.py b/modules/shared.py
index 4e0a20db..a6c0cbe9 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -47,6 +47,7 @@ settings = {
     'max_new_tokens_max': 4096,
     'prompt_lookup_num_tokens': 0,
     'max_tokens_second': 0,
+    'max_updates_second': 12,
     'auto_max_new_tokens': True,
     'ban_eos_token': False,
     'add_bos_token': True,
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 00b9275a..962311df 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -65,39 +65,41 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False, for_ui=False):
             all_stop_strings += st
 
     shared.stop_everything = False
+    last_update = -1
     reply = ''
     is_stream = state['stream']
     if len(all_stop_strings) > 0 and not state['stream']:
         state = copy.deepcopy(state)
         state['stream'] = True
 
+    min_update_interval = 0
+    if state.get('max_updates_second', 0) > 0:
+        min_update_interval = 1 / state['max_updates_second']
+
     # Generate
-    last_update = -1
-    latency_threshold = 1 / 1000
     for reply in generate_func(question, original_question, state, stopping_strings, is_chat=is_chat):
-        cur_time = time.monotonic()
         reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
         if escape_html:
             reply = html.escape(reply)
 
         if is_stream:
+            cur_time = time.time()
+
             # Limit number of tokens/second to make text readable in real time
             if state['max_tokens_second'] > 0:
                 diff = 1 / state['max_tokens_second'] - (cur_time - last_update)
                 if diff > 0:
                     time.sleep(diff)
 
-                last_update = time.monotonic()
+                last_update = time.time()
                 yield reply
 
             # Limit updates to avoid lag in the Gradio UI
             # API updates are not limited
             else:
-                # If 'generate_func' takes less than 0.001 seconds to yield the next token
-                # (equivalent to more than 1000 tok/s), assume that the UI is lagging behind and skip yielding
-                if (cur_time - last_update) > latency_threshold:
+                if cur_time - last_update > min_update_interval:
+                    last_update = cur_time
                     yield reply
 
-                    last_update = time.monotonic()
         if stop_found or (state['max_tokens_second'] > 0 and shared.stop_everything):
             break
diff --git a/modules/ui.py b/modules/ui.py
index eeb6ce92..25f93612 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -192,6 +192,7 @@ def list_interface_input_elements():
         'max_new_tokens',
         'prompt_lookup_num_tokens',
         'max_tokens_second',
+        'max_updates_second',
         'do_sample',
         'dynamic_temperature',
         'temperature_last',
diff --git a/modules/ui_parameters.py b/modules/ui_parameters.py
index 84f9fbfc..733d0901 100644
--- a/modules/ui_parameters.py
+++ b/modules/ui_parameters.py
@@ -71,6 +71,8 @@ def create_ui(default_preset):
                     shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], value=shared.settings['max_new_tokens'], step=1, label='max_new_tokens', info='⚠️ Setting this too high can cause prompt truncation.')
                     shared.gradio['prompt_lookup_num_tokens'] = gr.Slider(value=shared.settings['prompt_lookup_num_tokens'], minimum=0, maximum=10, step=1, label='prompt_lookup_num_tokens', info='Activates Prompt Lookup Decoding.')
                     shared.gradio['max_tokens_second'] = gr.Slider(value=shared.settings['max_tokens_second'], minimum=0, maximum=20, step=1, label='Maximum tokens/second', info='To make text readable in real time.')
+                    shared.gradio['max_updates_second'] = gr.Slider(value=shared.settings['max_updates_second'], minimum=0, maximum=24, step=1, label='Maximum UI updates/second', info='Set this if you experience lag in the UI during streaming.')
+
                 with gr.Column():
                     with gr.Row():
                         with gr.Column():
diff --git a/user_data/settings-template.yaml b/user_data/settings-template.yaml
index db481e84..ce0f77e1 100644
--- a/user_data/settings-template.yaml
+++ b/user_data/settings-template.yaml
@@ -18,6 +18,7 @@ max_new_tokens_min: 1
 max_new_tokens_max: 4096
 prompt_lookup_num_tokens: 0
 max_tokens_second: 0
+max_updates_second: 12
 auto_max_new_tokens: true
 ban_eos_token: false
 add_bos_token: true
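
Note for reviewers: the core of this change is a time-based throttle on generator yields. It replaces the old heuristic (skip a yield if the previous token arrived less than 1 ms ago) with a user-configurable cap: compute a minimum interval as 1 / max_updates_second and suppress any UI update that arrives sooner than that interval after the last one, with 0 disabling the cap. Because each yield carries the full cumulative reply, suppressed updates are not lost; they are simply coalesced into the next update that goes through. Below is a minimal, standalone sketch of the same pattern, not the repository's exact control flow; throttled_stream and fake_token_stream are illustrative names, not part of the codebase.

import time

def throttled_stream(token_iter, max_updates_second=12):
    # Mirror of the patch's logic: 0 disables the cap entirely.
    min_update_interval = 0
    if max_updates_second > 0:
        min_update_interval = 1 / max_updates_second

    last_update = -1  # far in the past, so the first chunk is always yielded
    reply = ''
    for token in token_iter:
        reply += token
        cur_time = time.time()
        # Drop intermediate updates that arrive faster than the cap;
        # 'reply' is cumulative, so nothing is lost when a yield is skipped.
        if cur_time - last_update > min_update_interval:
            last_update = cur_time
            yield reply

    # Emit the final, complete reply so the last tokens always appear.
    yield reply

# Example: ~100 tokens/second in, at most ~12 UI updates/second out.
def fake_token_stream():
    for i in range(100):
        time.sleep(0.01)
        yield f'token{i} '

for partial in throttled_stream(fake_token_stream()):
    print(f'\r{len(partial)} chars received', end='')

The default of 12 updates/second (and the slider maximum of 24) trades a slightly chunkier stream for far fewer Gradio re-renders, which is what the old latency_threshold heuristic was approximating indirectly.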