Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2025-06-08 06:35:57 -04:00)
ExLlamaV3: Add kv cache quantization (#6903)

Commit d4017fbb6d, parent d4b1e31c49
4 changed files with 32 additions and 3 deletions
@@ -5,6 +5,7 @@ from typing import Any, Dict, Optional, Union
 import torch
 from exllamav3 import Cache, Config, Model
+from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
 from torch.nn import CrossEntropyLoss
 from transformers import (
     GenerationConfig,

@@ -39,7 +40,33 @@ class Exllamav3HF(PreTrainedModel, GenerationMixin):
             logger.warning(f"max_num_tokens must be a multiple of 256. Adjusting from {max_tokens} to {adjusted_tokens}")
             max_tokens = adjusted_tokens

-        self.ex_cache = Cache(self.ex_model, max_num_tokens=max_tokens)
+        # Parse cache type
+        cache_type = shared.args.cache_type.lower()
+        cache_kwargs = {}
+        if cache_type == 'fp16':
+            layer_type = CacheLayer_fp16
+        elif cache_type.startswith('q'):
+            layer_type = CacheLayer_quant
+            if '_' in cache_type:
+                # Different bits for k and v (e.g., q4_q8)
+                k_part, v_part = cache_type.split('_')
+                k_bits = int(k_part[1:])
+                v_bits = int(v_part[1:])
+            else:
+                # Same bits for k and v (e.g., q4)
+                k_bits = v_bits = int(cache_type[1:])
+
+            # Validate bit ranges
+            if not (2 <= k_bits <= 8 and 2 <= v_bits <= 8):
+                logger.warning(f"Invalid quantization bits: k_bits={k_bits}, v_bits={v_bits}. Must be between 2 and 8. Falling back to fp16.")
+                layer_type = CacheLayer_fp16
+            else:
+                cache_kwargs = {'k_bits': k_bits, 'v_bits': v_bits}
+        else:
+            logger.warning(f"Unrecognized cache type: {cache_type}. Falling back to fp16.")
+            layer_type = CacheLayer_fp16
+
+        self.ex_cache = Cache(self.ex_model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs)

         # Create load parameters dictionary
         load_params = {'progressbar': True}
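
For reference, the cache-type parsing rules introduced above can be exercised on their own. The sketch below mirrors the same logic in a standalone helper (parse_cache_type is a hypothetical name used only for illustration, not part of this commit):

# Standalone sketch of the cache_type parsing rules added in this commit.
# parse_cache_type is a hypothetical helper, not part of the loader code.
def parse_cache_type(cache_type: str):
    cache_type = cache_type.lower()
    if cache_type == 'fp16':
        return 'fp16', None, None
    if cache_type.startswith('q'):
        if '_' in cache_type:
            # Different bits for k and v (e.g. 'q4_q8')
            k_part, v_part = cache_type.split('_')
            k_bits, v_bits = int(k_part[1:]), int(v_part[1:])
        else:
            # Same bits for k and v (e.g. 'q4')
            k_bits = v_bits = int(cache_type[1:])
        if 2 <= k_bits <= 8 and 2 <= v_bits <= 8:
            return 'quant', k_bits, v_bits
    # Anything else falls back to fp16, as the warnings above do
    return 'fp16', None, None


assert parse_cache_type('q4') == ('quant', 4, 4)
assert parse_cache_type('q4_q8') == ('quant', 4, 8)
assert parse_cache_type('q9') == ('fp16', None, None)

As in the loader code above, strings that start with 'q' but do not follow the 'qN' or 'qN_qM' pattern (for example the llama.cpp-style 'q4_0') would raise a ValueError from int() in this sketch; only the ExLlamaV3 naming is handled.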

@@ -13,6 +13,7 @@ loaders_and_params = OrderedDict({
         'cache_type',
         'tensor_split',
         'extra_flags',
+        'streaming_llm',
         'rope_freq_base',
         'compress_pos_emb',
         'flash_attn',

@@ -49,6 +50,7 @@ loaders_and_params = OrderedDict({
     ],
     'ExLlamav3_HF': [
         'ctx_size',
+        'cache_type',
         'gpu_split',
         'cfg_cache',
         'trust_remote_code',
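
As a quick check of the loaders change (assuming the OrderedDict above lives at modules.loaders.loaders_and_params, which the hunk headers suggest but do not show), one could list which loaders now expose the cache_type field:

# Assumes the OrderedDict shown above is importable as modules.loaders.loaders_and_params.
from modules.loaders import loaders_and_params

loaders_with_cache_type = [
    name for name, params in loaders_and_params.items()
    if 'cache_type' in params
]
print(loaders_with_cache_type)  # expected to include 'ExLlamav3_HF' after this commit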

@@ -132,7 +132,7 @@ group.add_argument('--streaming-llm', action='store_true', help='Activate Stream
 # Cache
 group = parser.add_argument_group('Context and cache management')
 group.add_argument('--ctx-size', '--n_ctx', '--max_seq_len', type=int, default=8192, help='Context size in tokens.')
-group.add_argument('--cache_type', type=str, default='fp16', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
+group.add_argument('--cache_type', type=str, default='fp16', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; ExLlamaV3 - fp16, q2 to q8 (can specify k_bits and v_bits separately, e.g. q4_q8).')

 # Speculative decoding
 group = parser.add_argument_group('Speculative decoding')
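
To illustrate what an extended value such as --cache_type q4_q8 maps to, the sketch below uses only the Cache and CacheLayer_quant API already shown in the first hunk; build_quantized_cache is a hypothetical helper and ex_model stands in for an already-loaded ExLlamaV3 model:

# Sketch only: how a q4_q8 cache_type translates to the Cache construction used
# in the loader hunk above. build_quantized_cache is illustrative, not part of
# the commit; ex_model is assumed to be a loaded exllamav3 Model.
from exllamav3 import Cache
from exllamav3.cache import CacheLayer_quant


def build_quantized_cache(ex_model, max_tokens: int, k_bits: int, v_bits: int) -> Cache:
    # Mirrors: Cache(self.ex_model, max_num_tokens=max_tokens, layer_type=layer_type, **cache_kwargs)
    return Cache(ex_model, max_num_tokens=max_tokens, layer_type=CacheLayer_quant, k_bits=k_bits, v_bits=v_bits)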

@@ -52,7 +52,7 @@ def create_ui():
             shared.gradio['batch_size'] = gr.Slider(label="batch_size", minimum=1, maximum=4096, step=1, value=shared.args.batch_size)
             shared.gradio['hqq_backend'] = gr.Dropdown(label="hqq_backend", choices=["PYTORCH", "PYTORCH_COMPILE", "ATEN"], value=shared.args.hqq_backend)
             shared.gradio['ctx_size'] = gr.Number(label='ctx_size', precision=0, step=256, value=shared.args.ctx_size, info='Context length. ⚠️ Lower this value if you can\'t load the model. Common values: 2048, 4096, 8192, 16384, 32768, 65536.')
-            shared.gradio['cache_type'] = gr.Dropdown(label="cache_type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q6', 'q4'], value=shared.args.cache_type, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
+            shared.gradio['cache_type'] = gr.Dropdown(label="cache_type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q7', 'q6', 'q5', 'q4', 'q3', 'q2'], value=shared.args.cache_type, allow_custom_value=True, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; ExLlamaV3 - fp16, q2 to q8. For ExLlamaV3, you can type custom combinations for separate k/v bits (e.g. q4_q8).')
             shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')
             shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
             shared.gradio['extra_flags'] = gr.Textbox(label='extra-flags', info='Additional flags to pass to llama-server. Format: "flag1=value1;flag2;flag3=value3". Example: "override-tensor=exps=CPU"')
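
The allow_custom_value=True flag added to the dropdown is what lets users type combinations that are not in the choices list, such as 'q4_q8'. A minimal standalone sketch of that behaviour (not part of the commit, shortened choices for brevity):

# Minimal gradio sketch (not from the commit) showing how allow_custom_value
# permits free-form entries such as 'q4_q8' in a dropdown.
import gradio as gr

with gr.Blocks() as demo:
    cache_type = gr.Dropdown(
        label="cache_type",
        choices=['fp16', 'q8', 'q4'],
        value='fp16',
        allow_custom_value=True,  # users may type values outside `choices`
    )

demo.launch()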