Estimate the VRAM for GGUF models + autoset gpu-layers (#6980)

oobabooga 2025-05-16 00:07:37 -03:00 committed by GitHub
parent c4a715fd1e
commit 5534d01da0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 193 additions and 4 deletions

View file

@@ -569,7 +569,7 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* {
}

.dark .message-body :not(pre) > code {
    background-color: rgb(255 255 255 / 12.5%);
    background-color: rgb(255 255 255 / 10%);
}

#chat-input {
@@ -1386,3 +1386,15 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* {
strong {
    font-weight: bold;
}

.min.svelte-1ybaih5 {
    min-height: 0;
}

#vram-info .value {
    color: #008d00;
}

.dark #vram-info .value {
    color: #07ff07;
}

View file

@@ -282,8 +282,10 @@ class LlamaServer:
            cmd.append("--no-kv-offload")
        if shared.args.row_split:
            cmd += ["--split-mode", "row"]
        cache_type = "fp16"
        if shared.args.cache_type != "fp16" and shared.args.cache_type in llamacpp_valid_cache_types:
            cmd += ["--cache-type-k", shared.args.cache_type, "--cache-type-v", shared.args.cache_type]
            cache_type = shared.args.cache_type
        if shared.args.compress_pos_emb != 1:
            cmd += ["--rope-freq-scale", str(1.0 / shared.args.compress_pos_emb)]
        if shared.args.rope_freq_base > 0:
@@ -343,6 +345,7 @@ class LlamaServer:
            print(' '.join(str(item) for item in cmd[1:]))
            print()

        logger.info(f"Using gpu_layers={shared.args.gpu_layers} | ctx_size={shared.args.ctx_size} | cache_type={cache_type}")
        # Start the server with pipes for output
        self.process = subprocess.Popen(
            cmd,

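For context, here is a minimal standalone sketch of how these flags end up on the llama.cpp server command line for one hypothetical configuration (a q8_0 cache and compress_pos_emb=2); the binary name, model path, and settings are placeholders, not taken from this commit:

cmd = ["llama-server", "--model", "model.gguf"]  # placeholder binary and model path

cache_type = "fp16"
requested_cache = "q8_0"  # assumed user setting
if requested_cache != "fp16":
    cmd += ["--cache-type-k", requested_cache, "--cache-type-v", requested_cache]
    cache_type = requested_cache

compress_pos_emb = 2  # assumed user setting
if compress_pos_emb != 1:
    cmd += ["--rope-freq-scale", str(1.0 / compress_pos_emb)]

print(cmd)
# ['llama-server', '--model', 'model.gguf', '--cache-type-k', 'q8_0',
#  '--cache-type-v', 'q8_0', '--rope-freq-scale', '0.5']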
View file

@@ -71,7 +71,6 @@ def llama_cpp_server_loader(model_name):
    else:
        model_file = sorted(Path(f'{shared.args.model_dir}/{model_name}').glob('*.gguf'))[0]
    logger.info(f"llama.cpp weights detected: \"{model_file}\"")
    try:
        model = LlamaServer(model_file)
        return model, model

View file

@@ -1,7 +1,11 @@
import functools
import json
import re
import subprocess
from math import exp
from pathlib import Path

import gradio as gr
import yaml

from modules import chat, loaders, metadata_gguf, shared, ui
@@ -216,7 +220,17 @@ def apply_model_settings_to_state(model, state):
    for k in model_settings:
        if k in state:
            state[k] = model_settings[k]
            if k == 'gpu_layers':
                available_vram = get_nvidia_free_vram()
                n_layers = model_settings[k]
                if available_vram > 0:
                    tolerance = 906
                    while n_layers > 0 and estimate_vram(model, n_layers, state['ctx_size'], state['cache_type']) > available_vram - tolerance:
                        n_layers -= 1
                state[k] = gr.update(value=n_layers, maximum=model_settings[k])
            else:
                state[k] = model_settings[k]

    return state
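The loop above is the whole auto-fit policy: keep the model's suggested gpu_layers, but walk it down until the VRAM estimate fits into free VRAM minus a fixed 906 MiB safety margin. A standalone sketch with a made-up linear estimator (names and numbers are illustrative only, not from this commit):

def fit_gpu_layers(estimate_fn, max_layers, available_vram_mib, tolerance_mib=906):
    # Walk the layer count down until the estimate fits under the budget.
    n_layers = max_layers
    while n_layers > 0 and estimate_fn(n_layers) > available_vram_mib - tolerance_mib:
        n_layers -= 1
    return n_layers

# Hypothetical estimator: ~350 MiB per offloaded layer plus 1500 MiB of overhead.
print(fit_gpu_layers(lambda n: 1500 + 350 * n, max_layers=40, available_vram_mib=8192))  # 16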
@@ -277,3 +291,138 @@ def save_instruction_template(model, template):
        yield (f"Instruction template for `{model}` unset in `{p}`, as the value for template was `{template}`.")
    else:
        yield (f"Instruction template for `{model}` saved to `{p}` as `{template}`.")


@functools.lru_cache(maxsize=None)
def get_gguf_metadata_cached(model_file):
    return metadata_gguf.load_metadata(model_file)


def get_model_size_mb(model_file: Path) -> float:
    filename = model_file.name

    # Check for multipart pattern
    match = re.match(r'(.+)-\d+-of-\d+\.gguf$', filename)
    if match:
        # It's a multipart file, find all matching parts
        base_pattern = match.group(1)
        part_files = sorted(model_file.parent.glob(f'{base_pattern}-*-of-*.gguf'))
        total_size = sum(p.stat().st_size for p in part_files)
    else:
        # Single part
        total_size = model_file.stat().st_size

    return total_size / (1024 ** 2)  # Return size in MB
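As a quick illustration of the multipart handling, the regex splits a llama.cpp-style shard name into the base pattern that then drives the glob above; the filename here is hypothetical:

import re

filename = "mymodel-00001-of-00003.gguf"  # hypothetical shard name
match = re.match(r'(.+)-\d+-of-\d+\.gguf$', filename)
print(match.group(1))  # mymodel  -> glob pattern becomes 'mymodel-*-of-*.gguf'
print(re.match(r'(.+)-\d+-of-\d+\.gguf$', "mymodel.gguf"))  # None -> treated as a single file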

def estimate_vram(gguf_file, gpu_layers, ctx_size, cache_type):
    model_file = Path(f'{shared.args.model_dir}/{gguf_file}')
    metadata = get_gguf_metadata_cached(model_file)
    size_in_mb = get_model_size_mb(model_file)

    # Extract values from metadata
    n_layers = None
    n_kv_heads = None
    embedding_dim = None
    context_length = None
    feed_forward_dim = None

    for key, value in metadata.items():
        if key.endswith('.block_count'):
            n_layers = value
        elif key.endswith('.attention.head_count_kv'):
            n_kv_heads = value
        elif key.endswith('.embedding_length'):
            embedding_dim = value
        elif key.endswith('.context_length'):
            context_length = value
        elif key.endswith('.feed_forward_length'):
            feed_forward_dim = value

    if gpu_layers > n_layers:
        gpu_layers = n_layers

    # Convert cache_type to numeric
    if cache_type == 'q4_0':
        cache_type = 4
    elif cache_type == 'q8_0':
        cache_type = 8
    else:
        cache_type = 16

    # Derived features
    size_per_layer = size_in_mb / max(n_layers, 1e-6)
    context_per_layer = context_length / max(n_layers, 1e-6)
    ffn_per_embedding = feed_forward_dim / max(embedding_dim, 1e-6)
    kv_cache_factor = n_kv_heads * cache_type * ctx_size

    # Helper function for smaller
    def smaller(x, y):
        return 1 if x < y else 0

    # Calculate VRAM using the model
    # Details: https://oobabooga.github.io/blog/posts/gguf-vram-formula/
    vram = (
        (size_per_layer - 21.19195204848197)
        * exp(0.0001047328491557063 * size_in_mb * smaller(ffn_per_embedding, 2.671096993407845))
        + 0.0006621544775632052 * context_per_layer
        + 3.34664386576376e-05 * kv_cache_factor
    ) * (1.363306170123392 + gpu_layers) + 1255.163594536052

    return vram
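To make the regression concrete, here is a standalone sketch that plugs assumed metadata for a hypothetical ~13 GB, 40-layer GGUF into the same formula; the constants are copied from the function above, while the inputs are invented for illustration:

from math import exp

# Assumed inputs, not taken from any real model file
size_in_mb = 13000.0
n_layers = 40
n_kv_heads = 8
embedding_dim = 5120
feed_forward_dim = 13824
context_length = 32768
gpu_layers = 40
ctx_size = 8192
cache_bits = 16  # fp16 KV cache

size_per_layer = size_in_mb / n_layers
context_per_layer = context_length / n_layers
ffn_per_embedding = feed_forward_dim / embedding_dim
kv_cache_factor = n_kv_heads * cache_bits * ctx_size
smaller = 1 if ffn_per_embedding < 2.671096993407845 else 0

vram = (
    (size_per_layer - 21.19195204848197)
    * exp(0.0001047328491557063 * size_in_mb * smaller)
    + 0.0006621544775632052 * context_per_layer
    + 3.34664386576376e-05 * kv_cache_factor
) * (1.363306170123392 + gpu_layers) + 1255.163594536052

print(f"Estimated VRAM: {vram:.0f} MiB")  # roughly 15 GiB for these made-up numbers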

def get_nvidia_free_vram():
    """
    Calculates the total free VRAM across all NVIDIA GPUs by parsing nvidia-smi output.

    Returns:
        int: The total free VRAM in MiB summed across all detected NVIDIA GPUs.
             Returns -1 if nvidia-smi command fails (not found, error, etc.).
             Returns 0 if nvidia-smi succeeds but no GPU memory info found.
    """
    try:
        # Execute nvidia-smi command
        result = subprocess.run(
            ['nvidia-smi'],
            capture_output=True,
            text=True,
            check=False
        )

        # Check if nvidia-smi returned an error
        if result.returncode != 0:
            return -1

        # Parse the output for memory usage patterns
        output = result.stdout

        # Find memory usage like "XXXXMiB / YYYYMiB"
        # Captures used and total memory for each GPU
        matches = re.findall(r"(\d+)\s*MiB\s*/\s*(\d+)\s*MiB", output)
        if not matches:
            # No GPUs found in expected format
            return 0

        total_free_vram_mib = 0
        for used_mem_str, total_mem_str in matches:
            try:
                used_mib = int(used_mem_str)
                total_mib = int(total_mem_str)
                total_free_vram_mib += (total_mib - used_mib)
            except ValueError:
                # Skip malformed entries
                pass

        return total_free_vram_mib
    except FileNotFoundError:
        # nvidia-smi not found (likely no NVIDIA drivers installed)
        return -1
    except Exception:
        # Handle any other unexpected exceptions
        return -1
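A quick check of the parsing step on a made-up fragment of nvidia-smi's table output (only the "MiB / MiB" pairs matter; wattage and other numbers are ignored by the pattern):

import re

sample = """
| 30%   45C    P8    25W / 350W |   1234MiB / 24576MiB |      0%      Default |
| 35%   50C    P2   120W / 350W |   8000MiB / 24576MiB |     40%      Default |
"""

matches = re.findall(r"(\d+)\s*MiB\s*/\s*(\d+)\s*MiB", sample)
print(matches)  # [('1234', '24576'), ('8000', '24576')]
print(sum(int(total) - int(used) for used, total in matches))  # 39918 MiB free across both GPUs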

View file

@@ -11,6 +11,7 @@ from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model
from modules.models_settings import (
    apply_model_settings_to_state,
    estimate_vram,
    get_model_metadata,
    save_instruction_template,
    save_model_settings,
@@ -44,6 +45,7 @@ def create_ui():
                shared.gradio['hqq_backend'] = gr.Dropdown(label="hqq_backend", choices=["PYTORCH", "PYTORCH_COMPILE", "ATEN"], value=shared.args.hqq_backend)

            with gr.Column():
                shared.gradio['vram_info'] = gr.HTML(value=lambda: estimate_vram_wrapper(shared.args.model, shared.args.gpu_layers, shared.args.ctx_size, shared.args.cache_type))
                shared.gradio['flash_attn'] = gr.Checkbox(label="flash-attn", value=shared.args.flash_attn, info='Use flash-attention.')
                shared.gradio['streaming_llm'] = gr.Checkbox(label="streaming-llm", value=shared.args.streaming_llm, info='Activate StreamingLLM to avoid re-evaluating the entire prompt when old messages are removed.')
                shared.gradio['load_in_8bit'] = gr.Checkbox(label="load-in-8bit", value=shared.args.load_in_8bit)
@@ -105,7 +107,6 @@ def create_ui():
                    ui.create_refresh_button(shared.gradio['lora_menu'], lambda: None, lambda: {'choices': utils.get_available_loras(), 'value': shared.lora_names}, 'refresh-button', interactive=not mu)
                    shared.gradio['lora_menu_apply'] = gr.Button(value='Apply LoRAs', elem_classes='refresh-button', interactive=not mu)

        with gr.Column():
            with gr.Tab("Download"):
                shared.gradio['custom_model_menu'] = gr.Textbox(label="Download model or LoRA", info="Enter the Hugging Face username/model path, for instance: facebook/galactica-125m. To specify a branch, add it at the end after a \":\" character like this: facebook/galactica-125m:main. To download a single file, enter its name in the second box.", interactive=not mu)
@@ -148,6 +149,11 @@ def create_event_handlers():
        ui.gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
        save_model_settings, gradio('model_menu', 'interface_state'), gradio('model_status'), show_progress=False)

    shared.gradio['model_menu'].change(estimate_vram_wrapper, gradio('model_menu', 'gpu_layers', 'ctx_size', 'cache_type'), gradio('vram_info'), show_progress=False)
    shared.gradio['gpu_layers'].change(estimate_vram_wrapper, gradio('model_menu', 'gpu_layers', 'ctx_size', 'cache_type'), gradio('vram_info'), show_progress=False)
    shared.gradio['ctx_size'].change(estimate_vram_wrapper, gradio('model_menu', 'gpu_layers', 'ctx_size', 'cache_type'), gradio('vram_info'), show_progress=False)
    shared.gradio['cache_type'].change(estimate_vram_wrapper, gradio('model_menu', 'gpu_layers', 'ctx_size', 'cache_type'), gradio('vram_info'), show_progress=False)

    if not shared.args.portable:
        shared.gradio['lora_menu_apply'].click(load_lora_wrapper, gradio('lora_menu'), gradio('model_status'), show_progress=False)
@@ -275,6 +281,14 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
        yield traceback.format_exc().replace('\n', '\n\n')


def estimate_vram_wrapper(model, gpu_layers, ctx_size, cache_type):
    if model in ["None", None]:
        return "<div id=\"vram-info\">Estimated VRAM to load the model:</div>"

    result = estimate_vram(model, gpu_layers, ctx_size, cache_type)
    return f"<div id=\"vram-info\">Estimated VRAM to load the model: <span class=\"value\">{result:.0f} MiB</span></div>"


def update_truncation_length(current_length, state):
    if 'loader' in state:
        if state['loader'].lower().startswith('exllama') or state['loader'] == 'llama.cpp':

View file

@@ -49,8 +49,10 @@ from modules.extensions import apply_extensions
from modules.LoRA import add_lora_to_model
from modules.models import load_model, unload_model_if_idle
from modules.models_settings import (
    estimate_vram,
    get_fallback_settings,
    get_model_metadata,
    get_nvidia_free_vram,
    update_model_parameters
)
from modules.shared import do_cmd_flags_warnings
@@ -248,6 +250,16 @@ if __name__ == "__main__":
        model_settings = get_model_metadata(model_name)
        update_model_parameters(model_settings, initial=True)  # hijack the command-line arguments

        if 'gpu_layers' not in shared.provided_arguments:
            available_vram = get_nvidia_free_vram()
            if available_vram > 0:
                n_layers = model_settings['gpu_layers']
                tolerance = 906
                while n_layers > 0 and estimate_vram(model_name, n_layers, shared.args.ctx_size, shared.args.cache_type) > available_vram - tolerance:
                    n_layers -= 1
                shared.args.gpu_layers = n_layers

        # Load the model
        shared.model, shared.tokenizer = load_model(model_name)

        if shared.args.lora: