From d4017fbb6d0b9be7a8964ad3fa03db0b373e453d Mon Sep 17 00:00:00 2001
From: oobabooga
Date: Fri, 25 Apr 2025 21:32:00 -0300
Subject: [PATCH] ExLlamaV3: Add kv cache quantization (#6903)

---
 modules/exllamav3_hf.py  | 29 ++++++++++++++++++++++++++++-
 modules/loaders.py       |  2 ++
 modules/shared.py        |  2 +-
 modules/ui_model_menu.py |  2 +-
 4 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/modules/exllamav3_hf.py b/modules/exllamav3_hf.py
index 24ba9e13..f15fc0b2 100644
--- a/modules/exllamav3_hf.py
+++ b/modules/exllamav3_hf.py
@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, Union
 
 import torch
 from exllamav3 import Cache, Config, Model
+from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
 from torch.nn import CrossEntropyLoss
 from transformers import (
     GenerationConfig,
@@ -39,7 +40,33 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
                 logger.warning(f"max_num_tokens must be a multiple of 256. Adjusting from {max_tokens} to {adjusted_tokens}")
                 max_tokens = adjusted_tokens
 
-        self.ex_cache = Cache(self.ex_model, max_num_tokens=max_tokens)
+        # Parse cache type
+        cache_type = shared.args.cache_type.lower()
+        cache_kwargs = {}
+        if cache_type == 'fp16':
+            layer_type = CacheLayer_fp16
+        elif cache_type.startswith('q'):
+            layer_type = CacheLayer_quant
+            if '_' in cache_type:
+                # Different bits for k and v (e.g., q4_q8)
+                k_part, v_part = cache_type.split('_')
+                k_bits = int(k_part[1:])
+                v_bits = int(v_part[1:])
+            else:
+                # Same bits for k and v (e.g., q4)
+                k_bits = v_bits = int(cache_type[1:])
+
+            # Validate bit ranges
+            if not (2 <= k_bits <= 8 and 2 <= v_bits <= 8):
+                logger.warning(f"Invalid quantization bits: k_bits={k_bits}, v_bits={v_bits}. Must be between 2 and 8. Falling back to fp16.")
+                layer_type = CacheLayer_fp16
+            else:
+                cache_kwargs = {'k_bits': k_bits, 'v_bits': v_bits}
+        else:
+            logger.warning(f"Unrecognized cache type: {cache_type}. Falling back to fp16.")
+            layer_type = CacheLayer_fp16
+
+        self.ex_cache = Cache(self.ex_model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs)
 
         # Create load parameters dictionary
         load_params = {'progressbar': True}
diff --git a/modules/loaders.py b/modules/loaders.py
index d8d62bf9..062e4837 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -13,6 +13,7 @@ loaders_and_params = OrderedDict({
         'cache_type',
         'tensor_split',
         'extra_flags',
+        'streaming_llm',
         'rope_freq_base',
         'compress_pos_emb',
         'flash_attn',
@@ -49,6 +50,7 @@ loaders_and_params = OrderedDict({
     ],
     'ExLlamav3_HF': [
         'ctx_size',
+        'cache_type',
         'gpu_split',
         'cfg_cache',
         'trust_remote_code',
diff --git a/modules/shared.py b/modules/shared.py
index 572bfc09..96f65929 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -132,7 +132,7 @@ group.add_argument('--streaming-llm', action='store_true', help='Activate Stream
 # Cache
 group = parser.add_argument_group('Context and cache management')
 group.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=8192, help='Context size in tokens.')
-group.add_argument('--cache_type', type=str, default='fp16', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
+group.add_argument('--cache_type', type=str, default='fp16', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8).')
 
 # Speculative decoding
 group = parser.add_argument_group('Speculative decoding')
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index 9aeb02d1..6bd647c6 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -52,7 +52,7 @@ def create_ui():
         shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size)
         shared.gradio['hqq_backend'] = gr.Dropdown(label="hqq_backend", choices=["PYTORCH", "PYTORCH_COMPILE", "ATEN"], value=shared.args.hqq_backend)
         shared.gradio['ctx_size'] = gr.Number(label='ctx_size', precision=0, step=256, value=shared.args.ctx_size, info='Context length. ⚠️ Lower this value if you can\'t load the model. Common values: 2048, 4096, 8192, 16384, 32768, 65536.')
-        shared.gradio['cache_type'] = gr.Dropdown(label="cache_type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q6', 'q4'], value=shared.args.cache_type, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
+        shared.gradio['cache_type'] = gr.Dropdown(label="cache_type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q7', 'q6', 'q5', 'q4', 'q3', 'q2'], value=shared.args.cache_type, allow_custom_value=True, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; ExLlamaV3 - fp16, q2 to q8. For ExLlamaV3, you can type custom combinations for separate k/v bits (e.g. q4_q8).')
         shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')
         shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
         shared.gradio['extra_flags'] = gr.Textbox(label='extra-flags', info='Additional flags to pass to llama-server. Format: "flag1=value1;flag2;flag3=value3". Example: "override-tensor=exps=CPU"')
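
Note (not part of the patch): the cache-type parsing added to modules/exllamav3_hf.py can be exercised on its own. The minimal sketch below mirrors the patch's logic in a standalone helper; the parse_cache_type name and its (layer_type, kwargs) return value are illustrative only, and it assumes the CacheLayer_fp16 / CacheLayer_quant classes imported above from exllamav3.cache.

    from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant


    def parse_cache_type(cache_type):
        """Illustrative helper: map a cache_type string to a cache layer class
        plus the keyword arguments to pass to Cache()."""
        cache_type = cache_type.lower()
        if cache_type == 'fp16':
            return CacheLayer_fp16, {}

        if cache_type.startswith('q'):
            if '_' in cache_type:
                # Different bits for k and v (e.g., q4_q8)
                k_part, v_part = cache_type.split('_')
                k_bits, v_bits = int(k_part[1:]), int(v_part[1:])
            else:
                # Same bits for k and v (e.g., q4)
                k_bits = v_bits = int(cache_type[1:])

            if 2 <= k_bits <= 8 and 2 <= v_bits <= 8:
                return CacheLayer_quant, {'k_bits': k_bits, 'v_bits': v_bits}

        # Unrecognized types and out-of-range bits fall back to fp16, as in the patch
        return CacheLayer_fp16, {}


    # parse_cache_type('q4_q8') -> (CacheLayer_quant, {'k_bits': 4, 'v_bits': 8})
    # parse_cache_type('q4')    -> (CacheLayer_quant, {'k_bits': 4, 'v_bits': 4})
    # parse_cache_type('fp16')  -> (CacheLayer_fp16, {})

With the patch applied, the same strings are what --cache_type accepts for the ExLlamav3_HF loader (e.g. --cache_type q4_q8), and the cache_type dropdown in the UI accepts typed custom combinations because allow_custom_value=True.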