diff --git a/css/main.css b/css/main.css index a3fa9753..1545a74b 100644 --- a/css/main.css +++ b/css/main.css @@ -1291,3 +1291,11 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* { .dark .footer-button:hover svg { stroke: rgb(209 213 219); } + +.tgw-accordion { + padding: 10px 12px !important; +} + +.dark .tgw-accordion { + border: 1px solid var(--border-color-dark); +} diff --git a/modules/llama_cpp_server.py b/modules/llama_cpp_server.py index c88f945d..ecc543f3 100644 --- a/modules/llama_cpp_server.py +++ b/modules/llama_cpp_server.py @@ -6,6 +6,7 @@ import subprocess import sys import threading import time +from pathlib import Path import llama_cpp_binaries import requests @@ -281,6 +282,25 @@ class LlamaServer: cmd += ["--rope-freq-scale", str(1.0 / shared.args.compress_pos_emb)] if shared.args.rope_freq_base > 0: cmd += ["--rope-freq-base", str(shared.args.rope_freq_base)] + if shared.args.model_draft not in [None, 'None']: + path = Path(shared.args.model_draft) + if not path.exists(): + path = Path(f'{shared.args.model_dir}/{shared.args.model_draft}') + + if path.is_file(): + model_file = path + else: + model_file = sorted(Path(f'{shared.args.model_dir}/{shared.args.model_draft}').glob('*.gguf'))[0] + + cmd += ["--model-draft", model_file] + if shared.args.draft_max > 0: + cmd += ["--draft-max", str(shared.args.draft_max)] + if shared.args.gpu_layers_draft > 0: + cmd += ["--gpu-layers-draft", str(shared.args.gpu_layers_draft)] + if shared.args.device_draft: + cmd += ["--device-draft", shared.args.device_draft] + if shared.args.ctx_size_draft > 0: + cmd += ["--ctx-size-draft", str(shared.args.ctx_size_draft)] env = os.environ.copy() if os.name == 'posix': diff --git a/modules/loaders.py b/modules/loaders.py index 7d6afe80..167b2c98 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -20,6 +20,12 @@ loaders_and_params = OrderedDict({ 'no_mmap', 'mlock', 'numa', + 'model_draft', + 'draft_max', + 'gpu_layers_draft', + 'device_draft', + 'ctx_size_draft', + 'speculative_decoding_accordion' ], 'Transformers': [ 'gpu_split', diff --git a/modules/shared.py b/modules/shared.py index 08268ae0..e531cd3c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -13,6 +13,7 @@ from modules.logging_colors import logger model = None tokenizer = None model_name = 'None' +draft_model_name = 'None' is_seq2seq = False model_dirty_from_training = False lora_names = [] @@ -127,6 +128,14 @@ group.add_argument('--numa', action='store_true', help='Activate NUMA task alloc group.add_argument('--no-kv-offload', action='store_true', help='Do not offload the K, Q, V to the GPU. This saves VRAM but reduces the performance.') group.add_argument('--row-split', action='store_true', help='Split the model by rows across GPUs. This may improve multi-gpu performance.') +# Speculative decoding +group = parser.add_argument_group('Speculative decoding') +group.add_argument('--model-draft', type=str, default=None, help='Path to the draft model for speculative decoding.') +group.add_argument('--draft-max', type=int, default=4, help='Number of tokens to draft for speculative decoding.') +group.add_argument('--gpu-layers-draft', type=int, default=0, help='Number of layers to offload to the GPU for the draft model.') +group.add_argument('--device-draft', type=str, default=None, help='Comma-separated list of devices to use for offloading the draft model.') +group.add_argument('--ctx-size-draft', type=int, default=0, help='Size of the prompt context for the draft model. If 0, uses the same as the main model.') + # ExLlamaV2 group = parser.add_argument_group('ExLlamaV2') group.add_argument('--gpu-split', type=str, help='Comma-separated list of VRAM (in GB) to use per GPU device for model layers. Example: 20,7,7.') diff --git a/modules/training.py b/modules/training.py index c6c380a3..69142463 100644 --- a/modules/training.py +++ b/modules/training.py @@ -52,7 +52,7 @@ def create_ui(): with gr.Column(): always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name is the same, checking will replace the existing file, and unchecking will load and continue from it (the rank must be the same).', elem_classes=['no-background']) - with gr.Accordion(label='Target Modules', open=False): + with gr.Accordion(label='Target Modules', open=False, elem_classes='tgw-accordion'): gr.Markdown("Selects which modules to target in training. Targeting more modules is closer to a full fine-tune at the cost of increased VRAM requirements and adapter size.\nNOTE: Only works for model_id='llama', other types will retain default training behavior and not use these settings.") with gr.Row(): with gr.Column(): @@ -86,7 +86,7 @@ def create_ui(): with gr.Row(): lr_scheduler_type = gr.Dropdown(label='LR Scheduler', value='linear', choices=['linear', 'constant', 'constant_with_warmup', 'cosine', 'cosine_with_restarts', 'polynomial', 'inverse_sqrt'], info='Learning rate scheduler - defines how the learning rate changes over time. "Constant" means never change, "linear" means to go in a straight line from the learning rate down to 0, cosine follows a curve, etc.', elem_classes=['slim-dropdown']) - with gr.Accordion(label='Advanced Options', open=False): + with gr.Accordion(label='Advanced Options', open=False, elem_classes='tgw-accordion'): with gr.Row(): with gr.Column(): lora_dropout = gr.Slider(label='LoRA Dropout', minimum=0.0, maximum=1.0, step=0.025, value=0.05, info='Percentage probability for dropout of LoRA layers. This can help reduce overfitting. Most users should leave at default.') diff --git a/modules/ui.py b/modules/ui.py index d5caaeaa..6fc5e955 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -145,6 +145,11 @@ def list_model_elements(): 'cpp_runner', 'trust_remote_code', 'no_use_fast', + 'model_draft', + 'draft_max', + 'gpu_layers_draft', + 'device_draft', + 'ctx_size_draft', ] return elements diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index b4af771c..1b0c25fa 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -92,6 +92,17 @@ def create_ui(): shared.gradio['exllamav2_info'] = gr.Markdown("ExLlamav2_HF is recommended over ExLlamav2 for better integration with extensions and more consistent sampling behavior across loaders.") shared.gradio['tensorrt_llm_info'] = gr.Markdown('* TensorRT-LLM has to be installed manually in a separate Python 3.10 environment at the moment. For a guide, consult the description of [this PR](https://github.com/oobabooga/text-generation-webui/pull/5715). \n\n* `max_seq_len` is only used when `cpp-runner` is checked.\n\n* `cpp_runner` does not support streaming at the moment.') + # Speculative decoding + with gr.Accordion("Speculative decoding", open=False, elem_classes='tgw-accordion') as shared.gradio['speculative_decoding_accordion']: + with gr.Row(): + shared.gradio['model_draft'] = gr.Dropdown(label="model-draft", choices=utils.get_available_models(), value=lambda: shared.draft_model_name, elem_classes='slim-dropdown', interactive=not mu) + ui.create_refresh_button(shared.gradio['model_draft'], lambda: None, lambda: {'choices': utils.get_available_models()}, 'refresh-button', interactive=not mu) + + shared.gradio['draft_max'] = gr.Number(label="draft-max", precision=0, step=1, value=shared.args.draft_max, info='Number of tokens to draft for speculative decoding.') + shared.gradio['gpu_layers_draft'] = gr.Slider(label="gpu-layers-draft", minimum=0, maximum=256, value=shared.args.gpu_layers_draft, info='Number of layers to offload to the GPU for the draft model.') + shared.gradio['device_draft'] = gr.Textbox(label="device-draft", value=shared.args.device_draft, info='Comma-separated list of devices to use for offloading the draft model.') + shared.gradio['ctx_size_draft'] = gr.Number(label="ctx-size-draft", precision=0, step=256, value=shared.args.ctx_size_draft, info='Size of the prompt context for the draft model. If 0, uses the same as the main model.') + with gr.Column(): with gr.Row(): shared.gradio['autoload_model'] = gr.Checkbox(value=shared.settings['autoload_model'], label='Autoload the model', info='Whether to load the model as soon as it is selected in the Model dropdown.', interactive=not mu)