Mirror of https://github.com/oobabooga/text-generation-webui.git

Commit 5534d01da0 (parent c4a715fd1e)
Estimate the VRAM for GGUF models + autoset gpu-layers (#6980)
6 changed files with 193 additions and 4 deletions

css/main.css (14 changes)
@@ -569,7 +569,7 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* {
}

.dark .message-body :not(pre) > code {
-    background-color: rgb(255 255 255 / 12.5%);
+    background-color: rgb(255 255 255 / 10%);
}

#chat-input {
@@ -1386,3 +1386,15 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* {
strong {
    font-weight: bold;
}

.min.svelte-1ybaih5 {
    min-height: 0;
}

#vram-info .value {
    color: #008d00;
}

.dark #vram-info .value {
    color: #07ff07;
}
@@ -282,8 +282,10 @@ class LlamaServer:
            cmd.append("--no-kv-offload")
        if shared.args.row_split:
            cmd += ["--split-mode", "row"]
        cache_type = "fp16"
        if shared.args.cache_type != "fp16" and shared.args.cache_type in llamacpp_valid_cache_types:
            cmd += ["--cache-type-k", shared.args.cache_type, "--cache-type-v", shared.args.cache_type]
            cache_type = shared.args.cache_type
        if shared.args.compress_pos_emb != 1:
            cmd += ["--rope-freq-scale", str(1.0 / shared.args.compress_pos_emb)]
        if shared.args.rope_freq_base > 0:
@@ -343,6 +345,7 @@ class LlamaServer:
        print(' '.join(str(item) for item in cmd[1:]))
        print()

        logger.info(f"Using gpu_layers={shared.args.gpu_layers} | ctx_size={shared.args.ctx_size} | cache_type={cache_type}")
        # Start the server with pipes for output
        self.process = subprocess.Popen(
            cmd,
@@ -71,7 +71,6 @@ def llama_cpp_server_loader(model_name):
    else:
        model_file = sorted(Path(f'{shared.args.model_dir}/{model_name}').glob('*.gguf'))[0]

    logger.info(f"llama.cpp weights detected: \"{model_file}\"")
    try:
        model = LlamaServer(model_file)
        return model, model
@@ -1,7 +1,11 @@
import functools
import json
import re
import subprocess
from math import exp
from pathlib import Path

import gradio as gr
import yaml

from modules import chat, loaders, metadata_gguf, shared, ui
@@ -216,7 +220,17 @@ def apply_model_settings_to_state(model, state):

    for k in model_settings:
        if k in state:
-           state[k] = model_settings[k]
+           if k == 'gpu_layers':
+               available_vram = get_nvidia_free_vram()
+               n_layers = model_settings[k]
+               if available_vram > 0:
+                   tolerance = 906
+                   while n_layers > 0 and estimate_vram(model, n_layers, state['ctx_size'], state['cache_type']) > available_vram - tolerance:
+                       n_layers -= 1
+
+               state[k] = gr.update(value=n_layers, maximum=model_settings[k])
+           else:
+               state[k] = model_settings[k]

    return state
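The same trimming logic reappears in server.py further down. As a minimal standalone sketch (assuming the estimate_vram function defined later in this file and a free-VRAM figure in MiB), the auto-adjustment amounts to:

def autoset_gpu_layers(model, max_layers, ctx_size, cache_type, available_vram_mib, tolerance_mib=906):
    """Reduce the layer count until the VRAM estimate fits into free VRAM minus a safety margin."""
    n_layers = max_layers
    while n_layers > 0 and estimate_vram(model, n_layers, ctx_size, cache_type) > available_vram_mib - tolerance_mib:
        n_layers -= 1
    return n_layers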
@@ -277,3 +291,138 @@ def save_instruction_template(model, template):
        yield (f"Instruction template for `{model}` unset in `{p}`, as the value for template was `{template}`.")
    else:
        yield (f"Instruction template for `{model}` saved to `{p}` as `{template}`.")


@functools.lru_cache(maxsize=None)
def get_gguf_metadata_cached(model_file):
    return metadata_gguf.load_metadata(model_file)


def get_model_size_mb(model_file: Path) -> float:
    filename = model_file.name

    # Check for multipart pattern
    match = re.match(r'(.+)-\d+-of-\d+\.gguf$', filename)

    if match:
        # It's a multipart file, find all matching parts
        base_pattern = match.group(1)
        part_files = sorted(model_file.parent.glob(f'{base_pattern}-*-of-*.gguf'))
        total_size = sum(p.stat().st_size for p in part_files)
    else:
        # Single part
        total_size = model_file.stat().st_size

    return total_size / (1024 ** 2)  # Return size in MB
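For reference, a quick illustration of the multipart check above; the filenames are hypothetical:

import re

pattern = r'(.+)-\d+-of-\d+\.gguf$'
print(bool(re.match(pattern, 'Example-70B-Q4_K_M-00001-of-00002.gguf')))  # True: sizes of all parts are summed
print(bool(re.match(pattern, 'Example-70B-Q4_K_M.gguf')))                 # False: the single file's size is used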
def estimate_vram(gguf_file, gpu_layers, ctx_size, cache_type):
    model_file = Path(f'{shared.args.model_dir}/{gguf_file}')
    metadata = get_gguf_metadata_cached(model_file)
    size_in_mb = get_model_size_mb(model_file)

    # Extract values from metadata
    n_layers = None
    n_kv_heads = None
    embedding_dim = None
    context_length = None
    feed_forward_dim = None

    for key, value in metadata.items():
        if key.endswith('.block_count'):
            n_layers = value
        elif key.endswith('.attention.head_count_kv'):
            n_kv_heads = value
        elif key.endswith('.embedding_length'):
            embedding_dim = value
        elif key.endswith('.context_length'):
            context_length = value
        elif key.endswith('.feed_forward_length'):
            feed_forward_dim = value

    if gpu_layers > n_layers:
        gpu_layers = n_layers

    # Convert cache_type to numeric
    if cache_type == 'q4_0':
        cache_type = 4
    elif cache_type == 'q8_0':
        cache_type = 8
    else:
        cache_type = 16

    # Derived features
    size_per_layer = size_in_mb / max(n_layers, 1e-6)
    context_per_layer = context_length / max(n_layers, 1e-6)
    ffn_per_embedding = feed_forward_dim / max(embedding_dim, 1e-6)
    kv_cache_factor = n_kv_heads * cache_type * ctx_size

    # Helper function for smaller
    def smaller(x, y):
        return 1 if x < y else 0

    # Calculate VRAM using the model
    # Details: https://oobabooga.github.io/blog/posts/gguf-vram-formula/
    vram = (
        (size_per_layer - 21.19195204848197)
        * exp(0.0001047328491557063 * size_in_mb * smaller(ffn_per_embedding, 2.671096993407845))
        + 0.0006621544775632052 * context_per_layer
        + 3.34664386576376e-05 * kv_cache_factor
    ) * (1.363306170123392 + gpu_layers) + 1255.163594536052

    return vram
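Since the estimate above is affine in gpu_layers (the bracketed per-layer term does not depend on it), the largest layer count that fits a given budget could also be solved for in closed form instead of decrementing one layer at a time, as the other hunks in this commit do. A hypothetical sketch using the same fitted constants:

from math import exp, floor

def max_fitting_gpu_layers(size_in_mb, n_layers, n_kv_heads, embedding_dim,
                           context_length, feed_forward_dim, ctx_size,
                           cache_bits, available_vram_mib, tolerance_mib=906):
    # Per-layer slope of the affine estimate above.
    size_per_layer = size_in_mb / max(n_layers, 1e-6)
    context_per_layer = context_length / max(n_layers, 1e-6)
    ffn_per_embedding = feed_forward_dim / max(embedding_dim, 1e-6)
    kv_cache_factor = n_kv_heads * cache_bits * ctx_size
    slope = (
        (size_per_layer - 21.19195204848197)
        * exp(0.0001047328491557063 * size_in_mb * (1 if ffn_per_embedding < 2.671096993407845 else 0))
        + 0.0006621544775632052 * context_per_layer
        + 3.34664386576376e-05 * kv_cache_factor
    )
    # Largest g with slope * (1.363306... + g) + 1255.163594... <= budget
    budget = available_vram_mib - tolerance_mib - 1255.163594536052
    if slope <= 0:
        return n_layers
    return max(0, min(n_layers, floor(budget / slope - 1.363306170123392)))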
def get_nvidia_free_vram():
    """
    Calculates the total free VRAM across all NVIDIA GPUs by parsing nvidia-smi output.

    Returns:
        int: The total free VRAM in MiB summed across all detected NVIDIA GPUs.
             Returns -1 if the nvidia-smi command fails (not found, error, etc.).
             Returns 0 if nvidia-smi succeeds but no GPU memory info is found.
    """
    try:
        # Execute nvidia-smi command
        result = subprocess.run(
            ['nvidia-smi'],
            capture_output=True,
            text=True,
            check=False
        )

        # Check if nvidia-smi returned an error
        if result.returncode != 0:
            return -1

        # Parse the output for memory usage patterns
        output = result.stdout

        # Find memory usage like "XXXXMiB / YYYYMiB"
        # Captures used and total memory for each GPU
        matches = re.findall(r"(\d+)\s*MiB\s*/\s*(\d+)\s*MiB", output)

        if not matches:
            # No GPUs found in the expected format
            return 0

        total_free_vram_mib = 0
        for used_mem_str, total_mem_str in matches:
            try:
                used_mib = int(used_mem_str)
                total_mib = int(total_mem_str)
                total_free_vram_mib += (total_mib - used_mib)
            except ValueError:
                # Skip malformed entries
                pass

        return total_free_vram_mib

    except FileNotFoundError:
        # nvidia-smi not found (likely no NVIDIA drivers installed)
        return -1
    except Exception:
        # Handle any other unexpected exceptions
        return -1
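A quick illustration of the pattern the function above extracts from nvidia-smi's table output; the sample row below is made up in the general format nvidia-smi prints, not real captured output:

import re

sample = "|  30%   45C    P8    25W / 350W |   1234MiB / 24576MiB |      2%      Default |"
print(re.findall(r"(\d+)\s*MiB\s*/\s*(\d+)\s*MiB", sample))  # [('1234', '24576')] -> 23342 MiB free on this GPU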
@@ -11,6 +11,7 @@ from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model
from modules.models_settings import (
    apply_model_settings_to_state,
    estimate_vram,
    get_model_metadata,
    save_instruction_template,
    save_model_settings,
@@ -44,6 +45,7 @@ def create_ui():
            shared.gradio['hqq_backend'] = gr.Dropdown(label="hqq_backend", choices=["PYTORCH", "PYTORCH_COMPILE", "ATEN"], value=shared.args.hqq_backend)

        with gr.Column():
            shared.gradio['vram_info'] = gr.HTML(value=lambda: estimate_vram_wrapper(shared.args.model, shared.args.gpu_layers, shared.args.ctx_size, shared.args.cache_type))
            shared.gradio['flash_attn'] = gr.Checkbox(label="flash-attn", value=shared.args.flash_attn, info='Use flash-attention.')
            shared.gradio['streaming_llm'] = gr.Checkbox(label="streaming-llm", value=shared.args.streaming_llm, info='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
            shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
@@ -105,7 +107,6 @@ def create_ui():
            ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': utils.get_available_loras(), 'value': shared.lora_names}, 'refresh-button', interactive=not mu)
            shared.gradio['lora_menu_apply'] = gr.Button(value='Apply LoRAs', elem_classes='refresh-button', interactive=not mu)


    with gr.Column():
        with gr.Tab("Download"):
            shared.gradio['custom_model_menu'] = gr.Textbox(label="Download model or LoRA", info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main. To download a single file, enter its name in the second box.", interactive=not mu)
@@ -148,6 +149,11 @@ def create_event_handlers():
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        save_model_settings, gradio('model_menu', 'interface_state'), gradio('model_status'), show_progress=False)

    shared.gradio['model_menu'].change(estimate_vram_wrapper, gradio('model_menu', 'gpu_layers', 'ctx_size', 'cache_type'), gradio('vram_info'), show_progress=False)
    shared.gradio['gpu_layers'].change(estimate_vram_wrapper, gradio('model_menu', 'gpu_layers', 'ctx_size', 'cache_type'), gradio('vram_info'), show_progress=False)
    shared.gradio['ctx_size'].change(estimate_vram_wrapper, gradio('model_menu', 'gpu_layers', 'ctx_size', 'cache_type'), gradio('vram_info'), show_progress=False)
    shared.gradio['cache_type'].change(estimate_vram_wrapper, gradio('model_menu', 'gpu_layers', 'ctx_size', 'cache_type'), gradio('vram_info'), show_progress=False)

    if not shared.args.portable:
        shared.gradio['lora_menu_apply'].click(load_lora_wrapper, gradio('lora_menu'), gradio('model_status'), show_progress=False)
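Since all four inputs re-render the same vram_info HTML with identical arguments, the wiring above could equally be written as a loop; a hypothetical, equivalent variant (not what the commit does) inside the same create_event_handlers body:

    for elem in ('model_menu', 'gpu_layers', 'ctx_size', 'cache_type'):
        shared.gradio[elem].change(
            estimate_vram_wrapper,
            gradio('model_menu', 'gpu_layers', 'ctx_size', 'cache_type'),
            gradio('vram_info'),
            show_progress=False)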
@@ -275,6 +281,14 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
        yield traceback.format_exc().replace('\n', '\n\n')


def estimate_vram_wrapper(model, gpu_layers, ctx_size, cache_type):
    if model in ["None", None]:
        return "<div id=\"vram-info\">Estimated VRAM to load the model:</div>"

    result = estimate_vram(model, gpu_layers, ctx_size, cache_type)
    return f"<div id=\"vram-info\">Estimated VRAM to load the model: <span class=\"value\">{result:.0f} MiB</span></div>"


def update_truncation_length(current_length, state):
    if 'loader' in state:
        if state['loader'].lower().startswith('exllama') or state['loader'] == 'llama.cpp':
server.py (12 changes)
@@ -49,8 +49,10 @@ from modules.extensions import apply_extensions
from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model_if_idle
from modules.models_settings import (
    estimate_vram,
    get_fallback_settings,
    get_model_metadata,
    get_nvidia_free_vram,
    update_model_parameters
)
from modules.shared import do_cmd_flags_warnings
@@ -248,6 +250,16 @@ if __name__ == "__main__":
        model_settings = get_model_metadata(model_name)
        update_model_parameters(model_settings, initial=True)  # hijack the command-line arguments

        if 'gpu_layers' not in shared.provided_arguments:
            available_vram = get_nvidia_free_vram()
            if available_vram > 0:
                n_layers = model_settings['gpu_layers']
                tolerance = 906
                while n_layers > 0 and estimate_vram(model_name, n_layers, shared.args.ctx_size, shared.args.cache_type) > available_vram - tolerance:
                    n_layers -= 1

                shared.args.gpu_layers = n_layers

        # Load the model
        shared.model, shared.tokenizer = load_model(model_name)
        if shared.args.lora: