Mirror of https://github.com/oobabooga/text-generation-webui.git, synced 2025-06-07 14:17:09 -04:00
ExLlamaV3: Add kv cache quantization (#6903)

commit d4017fbb6d
parent d4b1e31c49

4 changed files with 32 additions and 3 deletions
modules/exllamav3_hf.py

@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, Union
 
 import torch
 from exllamav3 import Cache, Config, Model
+from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
 from torch.nn import CrossEntropyLoss
 from transformers import (
     GenerationConfig,

@@ -39,7 +40,33 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
             logger.warning(f"max_num_tokens must be a multiple of 256. Adjusting from {max_tokens} to {adjusted_tokens}")
             max_tokens = adjusted_tokens
 
-        self.ex_cache = Cache(self.ex_model, max_num_tokens=max_tokens)
+        # Parse cache type
+        cache_type = shared.args.cache_type.lower()
+        cache_kwargs = {}
+        if cache_type == 'fp16':
+            layer_type = CacheLayer_fp16
+        elif cache_type.startswith('q'):
+            layer_type = CacheLayer_quant
+            if '_' in cache_type:
+                # Different bits for k and v (e.g., q4_q8)
+                k_part, v_part = cache_type.split('_')
+                k_bits = int(k_part[1:])
+                v_bits = int(v_part[1:])
+            else:
+                # Same bits for k and v (e.g., q4)
+                k_bits = v_bits = int(cache_type[1:])
+
+            # Validate bit ranges
+            if not (2 <= k_bits <= 8 and 2 <= v_bits <= 8):
+                logger.warning(f"Invalid quantization bits: k_bits={k_bits}, v_bits={v_bits}. Must be between 2 and 8. Falling back to fp16.")
+                layer_type = CacheLayer_fp16
+            else:
+                cache_kwargs = {'k_bits': k_bits, 'v_bits': v_bits}
+        else:
+            logger.warning(f"Unrecognized cache type: {cache_type}. Falling back to fp16.")
+            layer_type = CacheLayer_fp16
+
+        self.ex_cache = Cache(self.ex_model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs)
 
         # Create load parameters dictionary
         load_params = {'progressbar': True}
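For reference, the selection logic added above can be illustrated in isolation. The following is a minimal, dependency-free sketch of the same parsing rules; parse_cache_type() and its return format are invented for illustration, whereas the real code picks CacheLayer_fp16 or CacheLayer_quant from exllamav3 and logs a warning on fallback.

# Simplified sketch of the cache_type parsing rules introduced in this commit.
# It mirrors the branch structure above but avoids the exllamav3 dependency.
def parse_cache_type(cache_type: str):
    """Return ('fp16', None, None) or ('quant', k_bits, v_bits)."""
    cache_type = cache_type.lower()
    if cache_type == 'fp16':
        return ('fp16', None, None)
    if cache_type.startswith('q'):
        if '_' in cache_type:
            # Different bits for k and v (e.g., q4_q8)
            k_part, v_part = cache_type.split('_')
            k_bits, v_bits = int(k_part[1:]), int(v_part[1:])
        else:
            # Same bits for k and v (e.g., q4)
            k_bits = v_bits = int(cache_type[1:])
        if not (2 <= k_bits <= 8 and 2 <= v_bits <= 8):
            return ('fp16', None, None)  # out-of-range bits fall back to fp16
        return ('quant', k_bits, v_bits)
    return ('fp16', None, None)  # unrecognized values fall back to fp16

if __name__ == '__main__':
    for value in ('fp16', 'q4', 'q4_q8', 'q1', 'foo'):
        print(value, '->', parse_cache_type(value))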
modules/loaders.py

@@ -13,6 +13,7 @@ loaders_and_params = OrderedDict({
         'cache_type',
         'tensor_split',
         'extra_flags',
+        'streaming_llm',
         'rope_freq_base',
         'compress_pos_emb',
         'flash_attn',
@@ -49,6 +50,7 @@ loaders_and_params = OrderedDict({
     ],
     'ExLlamav3_HF': [
         'ctx_size',
+        'cache_type',
         'gpu_split',
         'cfg_cache',
         'trust_remote_code',
modules/shared.py

@@ -132,7 +132,7 @@ group.add_argument('--streaming-llm', action='store_true', help='Activate Stream
 # Cache
 group = parser.add_argument_group('Context and cache management')
 group.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=8192, help='Context size in tokens.')
-group.add_argument('--cache_type', type=str, default='fp16', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
+group.add_argument('--cache_type', type=str, default='fp16', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8).')
 
 # Speculative decoding
 group = parser.add_argument_group('Speculative decoding')
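The updated help text implies a simple grammar for ExLlamaV3 cache types: fp16, q2 through q8, or two such q-values joined by an underscore for separate k/v bits. A launch command along the lines of python server.py --loader ExLlamav3_HF --cache_type q4_q8 (hypothetical example) would therefore request a 4-bit key / 8-bit value cache. Below is an illustrative validator for that grammar; the dictionary and function names are invented and not part of the codebase.

import re

# Options listed in the updated --cache_type help text (illustrative helper only).
FIXED_CHOICES = {
    'llama.cpp': {'fp16', 'q8_0', 'q4_0'},
    'ExLlamaV2': {'fp16', 'fp8', 'q8', 'q6', 'q4'},
}

def cache_type_is_valid(loader: str, value: str) -> bool:
    if loader == 'ExLlamaV3':
        # fp16, q2..q8, or qX_qY with X and Y in 2..8 (e.g. q4_q8)
        return value == 'fp16' or re.fullmatch(r'q[2-8](_q[2-8])?', value) is not None
    return value in FIXED_CHOICES.get(loader, set())

print(cache_type_is_valid('ExLlamaV3', 'q4_q8'))  # True
print(cache_type_is_valid('llama.cpp', 'q4_q8'))  # False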
modules/ui_model_menu.py

@@ -52,7 +52,7 @@ def create_ui():
     shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size)
     shared.gradio['hqq_backend'] = gr.Dropdown(label="hqq_backend", choices=["PYTORCH", "PYTORCH_COMPILE", "ATEN"], value=shared.args.hqq_backend)
     shared.gradio['ctx_size'] = gr.Number(label='ctx_size', precision=0, step=256, value=shared.args.ctx_size, info='Context length. ⚠️ Lower this value if you can\'t load the model. Common values: 2048, 4096, 8192, 16384, 32768, 65536.')
-    shared.gradio['cache_type'] = gr.Dropdown(label="cache_type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q6', 'q4'], value=shared.args.cache_type, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
+    shared.gradio['cache_type'] = gr.Dropdown(label="cache_type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q7', 'q6', 'q5', 'q4', 'q3', 'q2'], value=shared.args.cache_type, allow_custom_value=True, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; ExLlamaV3 - fp16, q2 to q8. For ExLlamaV3, you can type custom combinations for separate k/v bits (e.g. q4_q8).')
     shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')
     shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
     shared.gradio['extra_flags'] = gr.Textbox(label='extra-flags', info='Additional flags to pass to llama-server. Format: "flag1=value1;flag2;flag3=value3". Example: "override-tensor=exps=CPU"')
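One detail worth noting in the dropdown change: the choices list only contains the single-bit presets, so allow_custom_value=True is what lets a user type a combined k/v setting such as q4_q8 directly into the field. A minimal standalone Gradio sketch of that pattern (not part of the webui; the surrounding Blocks/launch scaffolding is invented for the demo):

import gradio as gr

with gr.Blocks() as demo:
    cache_type = gr.Dropdown(
        label="cache_type",
        choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q7', 'q6', 'q5', 'q4', 'q3', 'q2'],
        value='fp16',
        allow_custom_value=True,  # permits typed values like q4_q8 that are not in choices
    )

if __name__ == '__main__':
    demo.launch()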