ExLlamaV3: Add kv cache quantization (#6903)

oobabooga 2025-04-25 21:32:00 -03:00 committed by GitHub
parent d4b1e31c49
commit d4017fbb6d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 32 additions and 3 deletions

@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, Union
import torch
from exllamav3 import Cache, Config, Model
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
from torch.nn import CrossEntropyLoss
from transformers import (
    GenerationConfig,
@@ -39,7 +40,33 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
            logger.warning(f"max_num_tokens must be a multiple of 256. Adjusting from {max_tokens} to {adjusted_tokens}")
            max_tokens = adjusted_tokens

        self.ex_cache = Cache(self.ex_model, max_num_tokens=max_tokens)
        # Parse cache type
        cache_type = shared.args.cache_type.lower()
        cache_kwargs = {}
        if cache_type == 'fp16':
            layer_type = CacheLayer_fp16
        elif cache_type.startswith('q'):
            layer_type = CacheLayer_quant
            if '_' in cache_type:
                # Different bits for k and v (e.g., q4_q8)
                k_part, v_part = cache_type.split('_')
                k_bits = int(k_part[1:])
                v_bits = int(v_part[1:])
            else:
                # Same bits for k and v (e.g., q4)
                k_bits = v_bits = int(cache_type[1:])

            # Validate bit ranges
            if not (2 <= k_bits <= 8 and 2 <= v_bits <= 8):
                logger.warning(f"Invalid quantization bits: k_bits={k_bits}, v_bits={v_bits}. Must be between 2 and 8. Falling back to fp16.")
                layer_type = CacheLayer_fp16
            else:
                cache_kwargs = {'k_bits': k_bits, 'v_bits': v_bits}
        else:
            logger.warning(f"Unrecognized cache type: {cache_type}. Falling back to fp16.")
            layer_type = CacheLayer_fp16

        self.ex_cache = Cache(self.ex_model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs)

        # Create load parameters dictionary
        load_params = {'progressbar': True}
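For reference, a minimal, dependency-free sketch of the parsing rules added above. The helper name parse_cache_type and its return shape are illustrative only (not part of this commit); the real loader maps the result onto CacheLayer_fp16 / CacheLayer_quant and logs a warning before falling back.

# Illustrative sketch of the cache_type parsing added in this commit.
# parse_cache_type() is a hypothetical helper, not code from the repository.
def parse_cache_type(cache_type):
    """Return ('fp16', None, None) or ('quant', k_bits, v_bits)."""
    cache_type = cache_type.lower()
    if cache_type == 'fp16':
        return ('fp16', None, None)
    if cache_type.startswith('q'):
        if '_' in cache_type:
            # Different bits for k and v, e.g. 'q4_q8'
            k_part, v_part = cache_type.split('_')
            k_bits, v_bits = int(k_part[1:]), int(v_part[1:])
        else:
            # Same bits for k and v, e.g. 'q4'
            k_bits = v_bits = int(cache_type[1:])
        if 2 <= k_bits <= 8 and 2 <= v_bits <= 8:
            return ('quant', k_bits, v_bits)
    # Unrecognized or out-of-range values fall back to fp16 (the real code warns first)
    return ('fp16', None, None)


print(parse_cache_type('fp16'))   # ('fp16', None, None)
print(parse_cache_type('q4'))     # ('quant', 4, 4)
print(parse_cache_type('q4_q8'))  # ('quant', 4, 8)
print(parse_cache_type('q9'))     # ('fp16', None, None) -- out of range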

@@ -13,6 +13,7 @@ loaders_and_params = OrderedDict({
        'cache_type',
        'tensor_split',
        'extra_flags',
        'streaming_llm',
        'rope_freq_base',
        'compress_pos_emb',
        'flash_attn',
@@ -49,6 +50,7 @@ loaders_and_params = OrderedDict({
    ],
    'ExLlamav3_HF': [
        'ctx_size',
        'cache_type',
        'gpu_split',
        'cfg_cache',
        'trust_remote_code',

@@ -132,7 +132,7 @@ group.add_argument('--streaming-llm', action='store_true', help='Activate Stream
# Cache
group = parser.add_argument_group('Context and cache management')
group.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=8192, help='Context size in tokens.')
group.add_argument('--cache_type', type=str, default='fp16', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
group.add_argument('--cache_type', type=str, default='fp16', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8).')
# Speculative decoding
group = parser.add_argument_group('Speculative decoding')
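A small aside on the flag itself: since --cache_type is declared as a plain string (no choices=), argparse accepts any value, and per-backend validation happens later in the loader code shown above. A minimal standalone sketch (not repository code):

# Sketch: --cache_type is a free-form string, so argparse accepts any value;
# per-backend validation happens later in the loader.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--cache_type', type=str, default='fp16')

print(parser.parse_args([]).cache_type)                         # 'fp16' (default)
print(parser.parse_args(['--cache_type', 'q4_q8']).cache_type)  # 'q4_q8'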

@@ -52,7 +52,7 @@ def create_ui():
shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size)
shared.gradio['hqq_backend'] = gr.Dropdown(label="hqq_backend", choices=["PYTORCH", "PYTORCH_COMPILE", "ATEN"], value=shared.args.hqq_backend)
shared.gradio['ctx_size'] = gr.Number(label='ctx_size', precision=0, step=256, value=shared.args.ctx_size, info='Context length. ⚠️ Lower this value if you can\'t load the model. Common values: 2048, 4096, 8192, 16384, 32768, 65536.')
shared.gradio['cache_type'] = gr.Dropdown(label="cache_type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q6', 'q4'], value=shared.args.cache_type, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
shared.gradio['cache_type'] = gr.Dropdown(label="cache_type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q7', 'q6', 'q5', 'q4', 'q3', 'q2'], value=shared.args.cache_type, allow_custom_value=True, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; ExLlamaV3 - fp16, q2 to q8. For ExLlamaV3, you can type custom combinations for separate k/v bits (e.g. q4_q8).')
shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')
shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
shared.gradio['extra_flags'] = gr.Textbox(label='extra-flags', info='Additional flags to pass to llama-server. Format: "flag1=value1;flag2;flag3=value3". Example: "override-tensor=exps=CPU"')
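Putting the pieces together: selecting q4_q8 in the UI (or passing --cache_type q4_q8) ultimately constructs a quantized cache roughly as below. Only the Cache(...) keyword arguments are taken from this diff; the model-loading calls and the path are assumptions based on the exllamav3 examples and may not match the library exactly.

# Standalone sketch: what a 'q4_q8' selection ultimately builds.
# Only the Cache(...) arguments come from this commit; Config.from_directory,
# Model.from_config and model.load(...) are assumed from exllamav3's examples.
from exllamav3 import Cache, Config, Model
from exllamav3.cache import CacheLayer_quant

config = Config.from_directory("/path/to/exl3-model")  # placeholder path
model = Model.from_config(config)

cache = Cache(
    model,
    max_num_tokens=8192,          # the loader rounds ctx_size up to a multiple of 256
    layer_type=CacheLayer_quant,
    k_bits=4,                     # 'q4_q8' -> 4-bit keys ...
    v_bits=8,                     # ... and 8-bit values
)

model.load(progressbar=True)      # mirrors load_params = {'progressbar': True} above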