Refactor the transformers loader (#6859)

oobabooga 2025-04-20 13:33:47 -03:00 committed by GitHub
parent 6ba0164c70
commit ae02ffc605
18 changed files with 464 additions and 528 deletions

View file

@@ -7,10 +7,7 @@ from io import BytesIO
import requests
import tiktoken
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import LogitsProcessor, LogitsProcessorList
from extensions.openai.errors import InvalidRequestError
from extensions.openai.utils import debug_msg
@@ -22,54 +19,7 @@ from modules.chat import (
load_instruction_template_memoized
)
from modules.presets import load_preset_memoized
from modules.text_generation import (
decode,
encode,
generate_reply,
get_reply_from_output_ids
)
class LogitsBiasProcessor(LogitsProcessor):
def __init__(self, logit_bias={}):
self.logit_bias = logit_bias
if self.logit_bias:
self.keys = list([int(key) for key in self.logit_bias.keys()])
values = [self.logit_bias[str(key)] for key in self.keys]
self.values = torch.tensor(values, dtype=torch.float, device=shared.model.device)
debug_msg(f"{self})")
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
if self.logit_bias:
debug_msg(logits[0, self.keys], " + ", self.values)
logits[0, self.keys] += self.values
debug_msg(" --> ", logits[0, self.keys])
debug_msg(" max/min ", float(torch.max(logits[0])), float(torch.min(logits[0])))
return logits
def __repr__(self):
return f"<{self.__class__.__name__}(logit_bias={self.logit_bias})>"
class LogprobProcessor(LogitsProcessor):
def __init__(self, logprobs=None):
self.logprobs = logprobs
self.token_alternatives = {}
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
if self.logprobs is not None: # 0-5
log_e_probabilities = F.log_softmax(logits, dim=1)
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]]
top_probs = [float(x) for x in top_values[0]]
self.token_alternatives = dict(zip(top_tokens, top_probs))
debug_msg(repr(self))
return logits
def __repr__(self):
return f"<{self.__class__.__name__}(logprobs={self.logprobs}, token_alternatives={self.token_alternatives})>"
from modules.text_generation import decode, encode, generate_reply
def convert_logprobs_to_tiktoken(model, logprobs):
@@ -107,6 +57,14 @@ def process_parameters(body, is_legacy=False):
elif isinstance(body['stop'], list):
generate_params['custom_stopping_strings'] = body['stop']
if shared.args.loader != 'llama.cpp':
from transformers import LogitsProcessorList
from modules.transformers_loader import (
LogitsBiasProcessor,
LogprobProcessor
)
logits_processor = []
logit_bias = body.get('logit_bias', None)
if logit_bias: # {str: float, ...}
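Note: a minimal sketch (not part of the diff) of how the two lazily imported processors plug into a Hugging Face LogitsProcessorList for non-llama.cpp loaders. The token id and bias value are illustrative, and a Transformers model must already be loaded because LogitsBiasProcessor reads shared.model.device:

from transformers import LogitsProcessorList
from modules.transformers_loader import LogitsBiasProcessor, LogprobProcessor

logprob_proc = LogprobProcessor(logprobs=3)       # record the top 3+1 alternatives per step
processors = LogitsProcessorList([
    LogitsBiasProcessor({"1234": -100.0}),        # token id (str) -> additive bias
    logprob_proc,
])
# After generation, logprob_proc.token_alternatives maps decoded tokens to log-probabilities.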

View file

@@ -16,11 +16,9 @@ from pydub import AudioSegment
from sse_starlette import EventSourceResponse
import extensions.openai.completions as OAIcompletions
import extensions.openai.embeddings as OAIembeddings
import extensions.openai.images as OAIimages
import extensions.openai.logits as OAIlogits
import extensions.openai.models as OAImodels
import extensions.openai.moderations as OAImoderations
from extensions.openai.errors import ServiceUnavailableError
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import _start_cloudflared
@@ -211,6 +209,8 @@ async def handle_image_generation(request: Request):
@app.post("/v1/embeddings", response_model=EmbeddingsResponse, dependencies=check_key)
async def handle_embeddings(request: Request, request_data: EmbeddingsRequest):
import extensions.openai.embeddings as OAIembeddings
input = request_data.input
if not input:
raise HTTPException(status_code=400, detail="Missing required argument input")
@@ -224,6 +224,8 @@ async def handle_embeddings(request: Request, request_data: EmbeddingsRequest):
@app.post("/v1/moderations", dependencies=check_key)
async def handle_moderations(request: Request):
import extensions.openai.moderations as OAImoderations
body = await request.json()
input = body["input"]
if not input:

View file

@@ -2,7 +2,6 @@ from pathlib import Path
import modules.shared as shared
from modules.logging_colors import logger
from modules.models import get_device
def add_lora_to_model(lora_names):
@@ -47,9 +46,10 @@ def add_lora_exllamav2(lora_names):
def add_lora_transformers(lora_names):
from peft import PeftModel
from modules.torch_utils import get_device
prior_set = set(shared.lora_names)
added_set = set(lora_names) - prior_set
removed_set = prior_set - set(lora_names)

View file

@@ -2,9 +2,6 @@ import traceback
from queue import Queue
from threading import Thread
import torch
import transformers
import modules.shared as shared
@@ -12,25 +9,6 @@ class StopNowException(Exception):
pass
class _StopEverythingStoppingCriteria(transformers.StoppingCriteria):
def __init__(self):
transformers.StoppingCriteria.__init__(self)
def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
return shared.stop_everything
class Stream(transformers.StoppingCriteria):
def __init__(self, callback_func=None):
self.callback_func = callback_func
def __call__(self, input_ids, scores) -> bool:
if self.callback_func is not None:
self.callback_func(input_ids[0])
return False
class Iteratorize:
"""

View file

@@ -2,13 +2,12 @@ import datetime
from pathlib import Path
import pandas as pd
import torch
from datasets import load_dataset
from tqdm import tqdm
from modules import shared
from modules.logging_colors import logger
from modules.models import clear_torch_cache, load_model, unload_model
from modules.models import load_model, unload_model
from modules.models_settings import get_model_metadata, update_model_parameters
from modules.text_generation import encode
@@ -39,6 +38,10 @@ def calculate_perplexity(models, input_dataset, stride, _max_length):
https://huggingface.co/docs/transformers/perplexity#calculating-ppl-with-fixedlength-models
'''
import torch
from modules.torch_utils import clear_torch_cache
if shared.args.loader == "llama.cpp":
logger.error("Perplexity evaluation is not implemented for the llama.cpp loader.")
raise ValueError

View file

@@ -3,11 +3,9 @@ from collections import OrderedDict
import gradio as gr
from modules import shared
loaders_and_params = OrderedDict({
'Transformers': [
'gpu_memory',
'gpu_split',
'cpu_memory',
'alpha_value',
'compress_pos_emb',
@@ -17,7 +15,6 @@ loaders_and_params = OrderedDict({
'load_in_4bit',
'torch_compile',
'use_flash_attention_2',
'auto_devices',
'cpu',
'disk',
'use_double_quant',
@@ -346,10 +343,6 @@ def blacklist_samplers(loader, dynamic_temperature):
return output
def get_gpu_memory_keys():
return [k for k in shared.gradio if k.startswith('gpu_memory')]
@functools.cache
def get_all_params():
all_params = set()
@@ -357,11 +350,6 @@ def get_all_params():
for el in loaders_and_params[k]:
all_params.add(el)
if 'gpu_memory' in all_params:
all_params.remove('gpu_memory')
for k in get_gpu_memory_keys():
all_params.add(k)
return sorted(all_params)
@@ -371,8 +359,4 @@ def make_loader_params_visible(loader):
if loader in loaders_and_params:
params = loaders_and_params[loader]
if 'gpu_memory' in params:
params.remove('gpu_memory')
params += get_gpu_memory_keys()
return [gr.update(visible=True) if k in params else gr.update(visible=False) for k in all_params]

View file

@@ -2,11 +2,10 @@ import time
import traceback
import numpy as np
import torch
from modules import models, sampler_hijack, shared
from modules import models, shared
from modules.logging_colors import logger
from modules.models import get_device, load_model
from modules.models import load_model
from modules.text_generation import generate_reply
global_scores = None
@@ -38,18 +37,16 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
logger.error("No model is loaded! Select one in the Model tab.")
return 'Error: No model is loaded! Select one in the Model tab.', previous
is_non_hf_exllamav2 = shared.model.__class__.__name__ == 'Exllamav2Model'
is_llamacpp = shared.model.__class__.__name__ == 'LlamaServer'
if is_llamacpp:
# llama.cpp case
if shared.model.__class__.__name__ == 'LlamaServer':
logprobs = shared.model.get_logits(prompt, state, n_probs=top_logits, use_samplers=use_samplers)
if return_dict:
output = {}
for entry in logprobs:
token = repr(entry['token'])
prob = entry['prob'] if use_samplers else np.exp(entry['logprob'])
output[token] = prob
return output
else:
output = ''
@@ -57,9 +54,17 @@ def _get_next_logits(prompt, state, use_samplers, previous, top_logits=25, retur
token = repr(entry['token'])
prob = entry['prob'] if use_samplers else np.exp(entry['logprob'])
output += f"{prob:.5f} - {token}\n"
return output, previous
# All other model types
else:
import torch
from modules import sampler_hijack
from modules.torch_utils import get_device
is_non_hf_exllamav2 = shared.model.__class__.__name__ == 'Exllamav2Model'
if not use_samplers:
state = {'stream': True}

View file

@@ -1,61 +1,10 @@
import gc
import os
import pprint
import re
import time
from pathlib import Path
import torch
import transformers
from accelerate import infer_auto_device_map, init_empty_weights
from accelerate.utils import (
is_ccl_available,
is_npu_available,
is_xpu_available
)
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
BitsAndBytesConfig,
is_torch_npu_available,
is_torch_xpu_available
)
import modules.shared as shared
from modules.logging_colors import logger
from modules.models_settings import get_model_metadata
transformers.logging.set_verbosity_error()
local_rank = None
if shared.args.deepspeed:
import deepspeed
from transformers.integrations.deepspeed import (
HfDeepSpeedConfig,
is_deepspeed_zero3_enabled
)
from modules.deepspeed_parameters import generate_ds_config
# Distributed setup
local_rank = shared.args.local_rank if shared.args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if is_xpu_available() and is_ccl_available():
torch.xpu.set_device(local_rank)
deepspeed.init_distributed(backend="ccl")
elif is_npu_available():
torch.npu.set_device(local_rank)
deepspeed.init_distributed(dist_backend="hccl")
else:
torch.cuda.set_device(local_rank)
deepspeed.init_distributed()
ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
last_generation_time = time.time()
@@ -66,8 +15,8 @@ def load_model(model_name, loader=None):
shared.is_seq2seq = False
shared.model_name = model_name
load_func_map = {
'Transformers': huggingface_loader,
'llama.cpp': llama_cpp_server_loader,
'Transformers': transformers_loader,
'ExLlamav3_HF': ExLlamav3_HF_loader,
'ExLlamav2_HF': ExLlamav2_HF_loader,
'ExLlamav2': ExLlamav2_loader,
@@ -86,7 +35,6 @@ def load_model(model_name, loader=None):
raise ValueError
shared.args.loader = loader
clear_torch_cache()
output = load_func_map[loader](model_name)
if type(output) is tuple:
model, tokenizer = output
@@ -95,6 +43,7 @@ def load_model(model_name, loader=None):
if model is None:
return None, None
else:
from modules.transformers_loader import load_tokenizer
tokenizer = load_tokenizer(model_name)
shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
@@ -110,163 +59,6 @@ def load_model(model_name, loader=None):
return model, tokenizer
def load_tokenizer(model_name, tokenizer_dir=None):
if tokenizer_dir:
path_to_model = Path(tokenizer_dir)
else:
path_to_model = Path(f"{shared.args.model_dir}/{model_name}/")
tokenizer = None
if path_to_model.exists():
if shared.args.no_use_fast:
logger.info('Loading the tokenizer with use_fast=False.')
tokenizer = AutoTokenizer.from_pretrained(
path_to_model,
trust_remote_code=shared.args.trust_remote_code,
use_fast=not shared.args.no_use_fast
)
return tokenizer
def huggingface_loader(model_name):
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
params = {
'low_cpu_mem_usage': True,
'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16,
}
if shared.args.trust_remote_code:
params['trust_remote_code'] = True
if shared.args.use_flash_attention_2:
params['use_flash_attention_2'] = True
if shared.args.force_safetensors:
params['force_safetensors'] = True
if shared.args.use_eager_attention:
params['attn_implementation'] = 'eager'
config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
if 'chatglm' in model_name.lower():
LoaderClass = AutoModel
else:
if config.to_dict().get('is_encoder_decoder', False):
LoaderClass = AutoModelForSeq2SeqLM
shared.is_seq2seq = True
else:
LoaderClass = AutoModelForCausalLM
# Determine if we should use default loading
should_use_default_loading = not any([
shared.args.cpu,
shared.args.load_in_8bit,
shared.args.load_in_4bit,
shared.args.auto_devices,
shared.args.disk,
shared.args.deepspeed,
shared.args.gpu_memory is not None,
shared.args.cpu_memory is not None,
shared.args.compress_pos_emb > 1,
shared.args.alpha_value > 1,
])
# Load the model without any special settings
if should_use_default_loading:
logger.info("TRANSFORMERS_PARAMS=")
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
print()
model = LoaderClass.from_pretrained(path_to_model, **params)
if not (hasattr(model, 'is_loaded_in_4bit') and model.is_loaded_in_4bit):
device = get_device()
if device:
model = model.to(device)
# DeepSpeed ZeRO-3
elif shared.args.deepspeed:
model = LoaderClass.from_pretrained(
path_to_model,
torch_dtype=params['torch_dtype'],
trust_remote_code=params.get('trust_remote_code')
)
model = deepspeed.initialize(
model=model,
config_params=ds_config,
model_parameters=None,
optimizer=None,
lr_scheduler=None
)[0]
model.module.eval() # Inference
logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}')
# Load with quantization and/or offloading
else:
if not any((shared.args.cpu, torch.cuda.is_available(), is_xpu_available(), torch.backends.mps.is_available())):
logger.warning('torch.cuda.is_available() and is_xpu_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.')
shared.args.cpu = True
if shared.args.cpu:
params['torch_dtype'] = torch.float32
else:
params['device_map'] = 'auto'
if x := get_max_memory_dict():
params['max_memory'] = x
if shared.args.load_in_4bit:
# See https://github.com/huggingface/transformers/pull/23479/files
# and https://huggingface.co/blog/4bit-transformers-bitsandbytes
quantization_config_params = {
'load_in_4bit': True,
'bnb_4bit_compute_dtype': eval(f"torch.{shared.args.compute_dtype}") if shared.args.compute_dtype in ["bfloat16", "float16", "float32"] else None,
'bnb_4bit_quant_type': shared.args.quant_type,
'bnb_4bit_use_double_quant': shared.args.use_double_quant,
'llm_int8_enable_fp32_cpu_offload': True
}
params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params)
elif shared.args.load_in_8bit:
if shared.args.auto_devices or shared.args.gpu_memory:
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
else:
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
if params.get('max_memory') is not None:
with init_empty_weights():
model = LoaderClass.from_config(config, trust_remote_code=params.get('trust_remote_code'))
model.tie_weights()
params['device_map'] = infer_auto_device_map(
model,
dtype=torch.int8,
max_memory=params.get('max_memory'),
no_split_module_classes=model._no_split_modules
)
if shared.args.disk:
params['offload_folder'] = shared.args.disk_cache_dir
if shared.args.compress_pos_emb > 1:
params['rope_scaling'] = {'type': 'linear', 'factor': shared.args.compress_pos_emb}
elif shared.args.alpha_value > 1:
params['rope_scaling'] = {'type': 'dynamic', 'factor': shared.args.alpha_value}
logger.info("TRANSFORMERS_PARAMS=")
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
print()
model = LoaderClass.from_pretrained(path_to_model, **params)
if shared.args.torch_compile:
model = torch.compile(model)
return model
def llama_cpp_server_loader(model_name):
from modules.llama_cpp_server import LlamaServer
@@ -284,6 +76,11 @@ def llama_cpp_server_loader(model_name):
logger.error(f"Error loading the model with llama.cpp: {str(e)}")
def transformers_loader(model_name):
from modules.transformers_loader import load_model_HF
return load_model_HF(model_name)
def ExLlamav3_HF_loader(model_name):
from modules.exllamav3_hf import Exllamav3HF
@@ -328,70 +125,14 @@ def TensorRT_LLM_loader(model_name):
return model
def get_max_memory_dict():
max_memory = {}
max_cpu_memory = shared.args.cpu_memory.strip() if shared.args.cpu_memory is not None else '99GiB'
if shared.args.gpu_memory:
memory_map = list(map(lambda x: x.strip(), shared.args.gpu_memory))
for i in range(len(memory_map)):
max_memory[i] = f'{memory_map[i]}GiB' if not re.match('.*ib$', memory_map[i].lower()) else memory_map[i]
max_memory['cpu'] = f'{max_cpu_memory}GiB' if not re.match('.*ib$', max_cpu_memory.lower()) else max_cpu_memory
# If --auto-devices is provided standalone, try to get a reasonable value
# for the maximum memory of device :0
elif shared.args.auto_devices:
if is_xpu_available():
total_mem = (torch.xpu.get_device_properties(0).total_memory / (1024 * 1024))
else:
total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024 * 1024))
suggestion = round((total_mem - 1000) / 1000) * 1000
if total_mem - suggestion < 800:
suggestion -= 1000
suggestion = int(round(suggestion / 1000))
logger.warning(f"Auto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors. You can manually set other values.")
max_memory[0] = f'{suggestion}GiB'
max_memory['cpu'] = f'{max_cpu_memory}GiB' if not re.match('.*ib$', max_cpu_memory.lower()) else max_cpu_memory
return max_memory if len(max_memory) > 0 else None
def get_device():
if torch.cuda.is_available():
return torch.device('cuda')
elif shared.args.deepspeed:
import deepspeed
return deepspeed.get_accelerator().current_device_name()
elif torch.backends.mps.is_available():
return torch.device('mps')
elif is_torch_xpu_available():
return torch.device('xpu:0')
elif is_torch_npu_available():
return torch.device('npu:0')
else:
return None
def clear_torch_cache():
gc.collect()
if not shared.args.cpu:
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif is_xpu_available():
torch.xpu.empty_cache()
elif is_npu_available():
torch.npu.empty_cache()
elif torch.backends.mps.is_available():
if hasattr(torch.backends.mps, 'empty_cache'):
torch.backends.mps.empty_cache()
def unload_model(keep_model_name=False):
is_llamacpp = (shared.model.__class__.__name__ == 'LlamaServer')
shared.model = shared.tokenizer = None
shared.lora_names = []
shared.model_dirty_from_training = False
if not is_llamacpp:
from modules.torch_utils import clear_torch_cache
clear_torch_cache()
if not keep_model_name:

View file

@@ -188,17 +188,12 @@ def update_model_parameters(state, initial=False):
UI: update the command-line arguments based on the interface values
'''
elements = ui.list_model_elements() # the names of the parameters
gpu_memories = []
for i, element in enumerate(elements):
if element not in state:
continue
value = state[element]
if element.startswith('gpu_memory'):
gpu_memories.append(value)
continue
if initial and element in shared.provided_arguments:
continue
@@ -211,18 +206,6 @@ def update_model_parameters(state, initial=False):
setattr(shared.args, element, value)
found_positive = False
for i in gpu_memories:
if i > 0:
found_positive = True
break
if not (initial and vars(shared.args)['gpu_memory'] != vars(shared.args_defaults)['gpu_memory']):
if found_positive:
shared.args.gpu_memory = [f"{i}MiB" for i in gpu_memories]
else:
shared.args.gpu_memory = None
def apply_model_settings_to_state(model, state):
'''

View file

@@ -13,7 +13,7 @@ from transformers.generation.logits_process import (
from modules import shared
from modules.logging_colors import logger
from modules.models import get_device
from modules.torch_utils import get_device
global_scores = None

View file

@@ -91,9 +91,7 @@ group.add_argument('--loader', type=str, help='Choose the model loader manually,
# Transformers/Accelerate
group = parser.add_argument_group('Transformers/Accelerate')
group.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
group.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
group.add_argument('--gpu-memory', type=str, nargs='+', help='Maximum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB.')
group.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.')
group.add_argument('--cpu-memory', type=float, default=0, help='Maximum CPU memory in GiB. Use this for CPU offloading.')
group.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
group.add_argument('--disk-cache-dir', type=str, default='cache', help='Directory to save the disk cache to. Defaults to "cache".')
group.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision (using bitsandbytes).')
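Note: under the new semantics, CPU offloading for the Transformers loader is requested with a plain GiB number and can be combined with --gpu-split. A usage sketch with a placeholder model folder name:

python server.py --model <model-folder> --loader Transformers --gpu-split 20,7 --cpu-memory 8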

View file

@@ -7,33 +7,18 @@ import time
import traceback
import numpy as np
import torch
import transformers
from transformers import (
LogitsProcessorList,
is_torch_npu_available,
is_torch_xpu_available
)
import modules.shared as shared
from modules import models, sampler_hijack
from modules.callbacks import (
Iteratorize,
Stream,
_StopEverythingStoppingCriteria
)
from modules import models
from modules.callbacks import Iteratorize
from modules.extensions import apply_extensions
from modules.grammar.grammar_utils import initialize_grammar
from modules.grammar.logits_process import GrammarConstrainedLogitsProcessor
from modules.html_generator import generate_basic_html
from modules.logging_colors import logger
from modules.models import clear_torch_cache, get_device, load_model
sampler_hijack.hijack_samplers()
def generate_reply(*args, **kwargs):
if shared.args.idle_timeout > 0 and shared.model is None and shared.model_name not in [None, 'None']:
from modules.models import load_model
shared.model, shared.tokenizer = load_model(shared.model_name)
shared.generation_lock.acquire()
@@ -46,7 +31,6 @@ def generate_reply(*args, **kwargs):
def _generate_reply(question, state, stopping_strings=None, is_chat=False, escape_html=False, for_ui=False):
# Find the appropriate generation function
generate_func = apply_extensions('custom_generate_reply')
if generate_func is None:
@@ -80,7 +64,6 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
all_stop_strings += st
shared.stop_everything = False
seed = set_manual_seed(state['seed'])
last_update = -1
reply = ''
is_stream = state['stream']
@@ -93,7 +76,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
min_update_interval = 1 / state['max_updates_second']
# Generate
for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):
for reply in generate_func(question, original_question, state, stopping_strings, is_chat=is_chat):
reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
if escape_html:
reply = html.escape(reply)
@@ -132,37 +115,48 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
if shared.tokenizer is None:
raise ValueError('No tokenizer is loaded')
if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'TensorRTLLMModel']:
# llama.cpp case
if shared.model.__class__.__name__ == 'LlamaServer':
input_ids = shared.tokenizer.encode(str(prompt), add_bos_token=add_bos_token)
else:
input_ids = shared.tokenizer.encode(str(prompt))
input_ids = np.array(input_ids).reshape(1, len(input_ids))
if shared.model.__class__.__name__ not in ['Exllamav2Model']:
if truncation_length is not None:
input_ids = input_ids[:, -truncation_length:]
return input_ids
# All other model types
else:
import torch
from modules.torch_utils import get_device
if shared.model.__class__.__name__ in ['Exllamav2Model', 'TensorRTLLMModel']:
input_ids = shared.tokenizer.encode(str(prompt))
if shared.model.__class__.__name__ != 'Exllamav2Model':
input_ids = np.array(input_ids).reshape(1, len(input_ids))
else:
input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
if hasattr(shared.tokenizer, 'bos_token_id') and shared.tokenizer.bos_token_id is not None:
if add_bos_token:
# Add BOS token if missing
if (len(input_ids[0]) > 0 and input_ids[0][0] != shared.tokenizer.bos_token_id) or len(input_ids[0]) == 0:
# Add a missing bos token (it may not have been added due to faulty model metadata)
bos_tensor = torch.tensor([[shared.tokenizer.bos_token_id]])
input_ids = torch.cat((bos_tensor, input_ids), 1)
# Prevent double bos token due to jinja templates with <s> somewhere
# Prevent double BOS tokens from jinja templates
while len(input_ids[0]) > 1 and input_ids[0][0] == shared.tokenizer.bos_token_id and input_ids[0][1] == shared.tokenizer.bos_token_id:
input_ids = input_ids[:, 1:]
else:
# Remove any bos token that may have been added
# Remove BOS tokens when not wanted
while len(input_ids[0]) > 0 and input_ids[0][0] == shared.tokenizer.bos_token_id:
input_ids = input_ids[:, 1:]
# Handling truncation
if truncation_length is not None:
input_ids = input_ids[:, -truncation_length:]
if shared.model.__class__.__name__ in ['LlamaServer', 'Exllamav2Model', 'TensorRTLLMModel'] or shared.args.cpu:
if shared.model.__class__.__name__ in ['Exllamav2Model', 'TensorRTLLMModel'] or shared.args.cpu:
return input_ids
else:
device = get_device()
@@ -221,6 +215,9 @@ def formatted_outputs(reply, model_name):
def set_manual_seed(seed):
import torch
from transformers import is_torch_npu_available, is_torch_xpu_available
seed = int(seed)
if seed == -1:
seed = random.randint(1, 2**31)
@@ -285,10 +282,26 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
return reply
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
def generate_reply_HF(question, original_question, state, stopping_strings=None, is_chat=False):
import torch
import transformers
from transformers import LogitsProcessorList
from modules.grammar.grammar_utils import initialize_grammar
from modules.grammar.logits_process import (
GrammarConstrainedLogitsProcessor
)
from modules.torch_utils import clear_torch_cache, get_device
from modules.transformers_loader import (
Stream,
_StopEverythingStoppingCriteria
)
if shared.args.loader == 'Transformers':
clear_torch_cache()
seed = set_manual_seed(state['seed'])
generate_params = {}
for k in [
'temperature',
@@ -458,11 +471,15 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
return
def generate_reply_custom(question, original_question, seed, state, stopping_strings=None, is_chat=False):
def generate_reply_custom(question, original_question, state, stopping_strings=None, is_chat=False):
"""
For models that do not use the transformers library for sampling
"""
seed = set_manual_seed(state['seed'])
seed = state['seed']
if shared.args.loader != 'llama.cpp':
print(shared.args.loader)
seed = set_manual_seed(seed)
t0 = time.time()
reply = ''

modules/torch_utils.py (new file, 37 lines added)
View file

@@ -0,0 +1,37 @@
import gc
import torch
from accelerate.utils import is_npu_available, is_xpu_available
from transformers import is_torch_npu_available, is_torch_xpu_available
from modules import shared
def get_device():
if torch.cuda.is_available():
return torch.device('cuda')
elif shared.args.deepspeed:
import deepspeed
return deepspeed.get_accelerator().current_device_name()
elif torch.backends.mps.is_available():
return torch.device('mps')
elif is_torch_xpu_available():
return torch.device('xpu:0')
elif is_torch_npu_available():
return torch.device('npu:0')
else:
return None
def clear_torch_cache():
gc.collect()
if not shared.args.cpu:
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif is_xpu_available():
torch.xpu.empty_cache()
elif is_npu_available():
torch.npu.empty_cache()
elif torch.backends.mps.is_available():
if hasattr(torch.backends.mps, 'empty_cache'):
torch.backends.mps.empty_cache()
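Note: a small usage sketch (not part of the diff) of the deferred-import pattern this commit adopts around these helpers; the wrapper function name is illustrative:

def place_on_device(tensor):
    # Imported lazily, mirroring how encode() and unload_model() now pull these in only when needed
    from modules.torch_utils import clear_torch_cache, get_device

    clear_torch_cache()              # gc.collect() plus emptying the active backend's cache (skipped with --cpu)
    device = get_device()            # torch.device, or None when no accelerator is available
    return tensor.to(device) if device is not None else tensor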

View file

@@ -15,13 +15,6 @@ from datetime import datetime
from pathlib import Path
import gradio as gr
import torch
import transformers
from datasets import Dataset, load_dataset
from transformers import is_torch_xpu_available
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
)
from modules import shared, ui, utils
from modules.evaluate import (
@@ -33,7 +26,6 @@ from modules.logging_colors import logger
from modules.models import reload_model
from modules.utils import natural_keys
MODEL_CLASSES = {v[1]: v[0] for v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.items()}
PARAMETERS = ["lora_name", "always_override", "q_proj_en", "v_proj_en", "k_proj_en", "o_proj_en", "gate_proj_en", "down_proj_en", "up_proj_en", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to"]
WANT_INTERRUPT = False
@@ -284,6 +276,9 @@ def calc_trainable_parameters(model):
def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en: bool, k_proj_en: bool, o_proj_en: bool, gate_proj_en: bool, down_proj_en: bool, up_proj_en: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str):
import torch
import transformers
from datasets import Dataset, load_dataset
from peft import (
LoraConfig,
get_peft_model,
@@ -293,6 +288,12 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
from peft.utils.other import \
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \
model_to_lora_modules
from transformers import is_torch_xpu_available
from transformers.models.auto.modeling_auto import (
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
)
MODEL_CLASSES = {v[1]: v[0] for v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.items()}
global WANT_INTERRUPT
WANT_INTERRUPT = False

View file

modules/transformers_loader.py (new file)
@@ -0,0 +1,281 @@
import os
import pprint
from pathlib import Path
import torch
import torch.nn.functional as F
import transformers
from accelerate import infer_auto_device_map, init_empty_weights
from accelerate.utils import (
is_ccl_available,
is_npu_available,
is_xpu_available
)
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoTokenizer,
BitsAndBytesConfig,
LogitsProcessor
)
import modules.shared as shared
from modules import sampler_hijack
from modules.logging_colors import logger
from modules.text_generation import get_reply_from_output_ids
from modules.torch_utils import get_device
transformers.logging.set_verbosity_error()
sampler_hijack.hijack_samplers()
local_rank = None
if shared.args.deepspeed:
import deepspeed
from transformers.integrations.deepspeed import (
HfDeepSpeedConfig,
is_deepspeed_zero3_enabled
)
from modules.deepspeed_parameters import generate_ds_config
# Distributed setup
local_rank = shared.args.local_rank if shared.args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
if is_xpu_available() and is_ccl_available():
torch.xpu.set_device(local_rank)
deepspeed.init_distributed(backend="ccl")
elif is_npu_available():
torch.npu.set_device(local_rank)
deepspeed.init_distributed(dist_backend="hccl")
else:
torch.cuda.set_device(local_rank)
deepspeed.init_distributed()
ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
class _StopEverythingStoppingCriteria(transformers.StoppingCriteria):
def __init__(self):
transformers.StoppingCriteria.__init__(self)
def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
return shared.stop_everything
class Stream(transformers.StoppingCriteria):
def __init__(self, callback_func=None):
self.callback_func = callback_func
def __call__(self, input_ids, scores) -> bool:
if self.callback_func is not None:
self.callback_func(input_ids[0])
return False
class LogitsBiasProcessor(LogitsProcessor):
def __init__(self, logit_bias={}):
self.logit_bias = logit_bias
if self.logit_bias:
self.keys = list([int(key) for key in self.logit_bias.keys()])
values = [self.logit_bias[str(key)] for key in self.keys]
self.values = torch.tensor(values, dtype=torch.float, device=shared.model.device)
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
if self.logit_bias:
logits[0, self.keys] += self.values
return logits
def __repr__(self):
return f"<{self.__class__.__name__}(logit_bias={self.logit_bias})>"
class LogprobProcessor(LogitsProcessor):
def __init__(self, logprobs=None):
self.logprobs = logprobs
self.token_alternatives = {}
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
if self.logprobs is not None: # 0-5
log_e_probabilities = F.log_softmax(logits, dim=1)
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]]
top_probs = [float(x) for x in top_values[0]]
self.token_alternatives = dict(zip(top_tokens, top_probs))
return logits
def __repr__(self):
return f"<{self.__class__.__name__}(logprobs={self.logprobs}, token_alternatives={self.token_alternatives})>"
def load_tokenizer(model_name, tokenizer_dir=None):
if tokenizer_dir:
path_to_model = Path(tokenizer_dir)
else:
path_to_model = Path(f"{shared.args.model_dir}/{model_name}/")
tokenizer = None
if path_to_model.exists():
if shared.args.no_use_fast:
logger.info('Loading the tokenizer with use_fast=False.')
tokenizer = AutoTokenizer.from_pretrained(
path_to_model,
trust_remote_code=shared.args.trust_remote_code,
use_fast=not shared.args.no_use_fast
)
return tokenizer
def load_model_HF(model_name):
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
params = {
'low_cpu_mem_usage': True,
'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16,
}
if shared.args.trust_remote_code:
params['trust_remote_code'] = True
if shared.args.use_flash_attention_2:
params['use_flash_attention_2'] = True
if shared.args.force_safetensors:
params['force_safetensors'] = True
if shared.args.use_eager_attention:
params['attn_implementation'] = 'eager'
config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)
if 'chatglm' in model_name.lower():
LoaderClass = AutoModel
else:
if config.to_dict().get('is_encoder_decoder', False):
LoaderClass = AutoModelForSeq2SeqLM
shared.is_seq2seq = True
else:
LoaderClass = AutoModelForCausalLM
# Determine if we should use default loading
should_use_default_loading = not any([
shared.args.cpu,
shared.args.load_in_8bit,
shared.args.load_in_4bit,
shared.args.disk,
shared.args.deepspeed,
shared.args.cpu_memory is not None,
shared.args.compress_pos_emb > 1,
shared.args.alpha_value > 1,
])
# Load the model without any special settings
if should_use_default_loading:
params['device_map'] = 'auto'
logger.info("TRANSFORMERS_PARAMS=")
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
print()
model = LoaderClass.from_pretrained(path_to_model, **params)
if not (hasattr(model, 'is_loaded_in_4bit') and model.is_loaded_in_4bit):
device = get_device()
if device:
model = model.to(device)
# DeepSpeed ZeRO-3
elif shared.args.deepspeed:
model = LoaderClass.from_pretrained(
path_to_model,
torch_dtype=params['torch_dtype'],
trust_remote_code=params.get('trust_remote_code')
)
model = deepspeed.initialize(
model=model,
config_params=ds_config,
model_parameters=None,
optimizer=None,
lr_scheduler=None
)[0]
model.module.eval() # Inference
logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}')
# Load with quantization and/or offloading
else:
if not any((shared.args.cpu, torch.cuda.is_available(), is_xpu_available(), torch.backends.mps.is_available())):
logger.warning('torch.cuda.is_available() and is_xpu_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.')
shared.args.cpu = True
if shared.args.cpu:
params['torch_dtype'] = torch.float32
else:
params['device_map'] = 'auto'
if x := get_max_memory_dict():
params['max_memory'] = x
if shared.args.load_in_4bit:
# See https://github.com/huggingface/transformers/pull/23479/files
# and https://huggingface.co/blog/4bit-transformers-bitsandbytes
quantization_config_params = {
'load_in_4bit': True,
'bnb_4bit_compute_dtype': eval(f"torch.{shared.args.compute_dtype}") if shared.args.compute_dtype in ["bfloat16", "float16", "float32"] else None,
'bnb_4bit_quant_type': shared.args.quant_type,
'bnb_4bit_use_double_quant': shared.args.use_double_quant,
'llm_int8_enable_fp32_cpu_offload': True
}
params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params)
elif shared.args.load_in_8bit:
if shared.args.gpu_split:
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
else:
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
if params.get('max_memory') is not None:
with init_empty_weights():
model = LoaderClass.from_config(config, trust_remote_code=params.get('trust_remote_code'))
model.tie_weights()
params['device_map'] = infer_auto_device_map(
model,
dtype=torch.int8,
max_memory=params.get('max_memory'),
no_split_module_classes=model._no_split_modules
)
if shared.args.disk:
params['offload_folder'] = shared.args.disk_cache_dir
if shared.args.compress_pos_emb > 1:
params['rope_scaling'] = {'type': 'linear', 'factor': shared.args.compress_pos_emb}
elif shared.args.alpha_value > 1:
params['rope_scaling'] = {'type': 'dynamic', 'factor': shared.args.alpha_value}
logger.info("TRANSFORMERS_PARAMS=")
pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
print()
model = LoaderClass.from_pretrained(path_to_model, **params)
if shared.args.torch_compile:
model = torch.compile(model)
return model
def get_max_memory_dict():
max_memory = {}
if shared.args.cpu_memory > 0:
max_memory['cpu'] = f'{shared.args.cpu_memory}GiB'
if shared.args.gpu_split:
for i, memory in enumerate(shared.args.gpu_split.split(',')):
max_memory[i] = f'{memory}GiB'
return max_memory if len(max_memory) > 0 else None
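Note: as an illustration with assumed flag values, --cpu-memory 64 combined with --gpu-split 20,7 makes get_max_memory_dict() return the dictionary below, which load_model_HF() forwards to from_pretrained() as max_memory:

{'cpu': '64.0GiB', 0: '20GiB', 1: '7GiB'}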

View file

@@ -2,9 +2,7 @@ import copy
from pathlib import Path
import gradio as gr
import torch
import yaml
from transformers import is_torch_xpu_available
import extensions
from modules import shared
@@ -128,7 +126,6 @@ def list_model_elements():
'torch_compile',
'flash_attn',
'use_flash_attention_2',
'auto_devices',
'cpu',
'disk',
'row_split',
@@ -150,13 +147,6 @@ def list_model_elements():
'no_use_fast',
]
if is_torch_xpu_available():
for i in range(torch.xpu.device_count()):
elements.append(f'gpu_memory_{i}')
else:
for i in range(torch.cuda.device_count()):
elements.append(f'gpu_memory_{i}')
return elements

View file

@@ -1,14 +1,9 @@
import importlib
import math
import re
import traceback
from functools import partial
from pathlib import Path
import gradio as gr
import psutil
import torch
from transformers import is_torch_npu_available, is_torch_xpu_available
from modules import loaders, shared, ui, utils
from modules.logging_colors import logger
@@ -27,35 +22,6 @@ from modules.utils import gradio
def create_ui():
mu = shared.args.multi_user
# Finding the default values for the GPU and CPU memories
total_mem = []
if is_torch_xpu_available():
for i in range(torch.xpu.device_count()):
total_mem.append(math.floor(torch.xpu.get_device_properties(i).total_memory / (1024 * 1024)))
elif is_torch_npu_available():
for i in range(torch.npu.device_count()):
total_mem.append(math.floor(torch.npu.get_device_properties(i).total_memory / (1024 * 1024)))
else:
for i in range(torch.cuda.device_count()):
total_mem.append(math.floor(torch.cuda.get_device_properties(i).total_memory / (1024 * 1024)))
default_gpu_mem = []
if shared.args.gpu_memory is not None and len(shared.args.gpu_memory) > 0:
for i in shared.args.gpu_memory:
if 'mib' in i.lower():
default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i)))
else:
default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i)) * 1000)
while len(default_gpu_mem) < len(total_mem):
default_gpu_mem.append(0)
total_cpu_mem = math.floor(psutil.virtual_memory().total / (1024 * 1024))
if shared.args.cpu_memory is not None:
default_cpu_mem = re.sub('[a-zA-Z ]', '', shared.args.cpu_memory)
else:
default_cpu_mem = 0
with gr.Tab("Model", elem_id="model-tab"):
with gr.Row():
with gr.Column():
@@ -80,10 +46,6 @@ def create_ui():
with gr.Blocks():
with gr.Row():
with gr.Column():
for i in range(len(total_mem)):
shared.gradio[f'gpu_memory_{i}'] = gr.Slider(label=f"gpu-memory in MiB for device :{i}", maximum=total_mem[i], value=default_gpu_mem[i])
shared.gradio['cpu_memory'] = gr.Slider(label="cpu-memory in MiB", maximum=total_cpu_mem, value=default_cpu_mem)
shared.gradio['n_gpu_layers'] = gr.Slider(label="n-gpu-layers", minimum=0, maximum=256, value=shared.args.n_gpu_layers, info='Must be greater than 0 for the GPU to be used. ⚠️ Lower this value if you can\'t load the model.')
shared.gradio['threads'] = gr.Slider(label="threads", minimum=0, step=1, maximum=256, value=shared.args.threads)
shared.gradio['threads_batch'] = gr.Slider(label="threads_batch", minimum=0, step=1, maximum=256, value=shared.args.threads_batch)
@@ -94,6 +56,7 @@ def create_ui():
shared.gradio['cache_type'] = gr.Dropdown(label="cache_type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q6', 'q4'], value=shared.args.cache_type, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')
shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
shared.gradio['cpu_memory'] = gr.Number(label="Maximum CPU memory in GiB. Use this for CPU offloading.", value=shared.args.cpu_memory)
shared.gradio['alpha_value'] = gr.Number(label='alpha_value', value=shared.args.alpha_value, precision=2, info='Positional embeddings alpha factor for NTK RoPE scaling. Recommended values (NTKv1): 1.75 for 1.5x context, 2.5 for 2x context. Use either this or compress_pos_emb, not both.')
shared.gradio['rope_freq_base'] = gr.Number(label='rope_freq_base', value=shared.args.rope_freq_base, precision=0, info='Positional embeddings frequency base for NTK RoPE scaling. Related to alpha_value by rope_freq_base = 10000 * alpha_value ^ (64 / 63). 0 = from model.')
shared.gradio['compress_pos_emb'] = gr.Number(label='compress_pos_emb', value=shared.args.compress_pos_emb, precision=2, info='Positional embeddings compression factor. Should be set to (context length) / (model\'s original context length). Equal to 1/rope_freq_scale.')
@@ -107,7 +70,6 @@ def create_ui():
shared.gradio['torch_compile'] = gr.Checkbox(label="torch-compile", value=shared.args.torch_compile, info='Compile the model with torch.compile for improved performance.')
shared.gradio['flash_attn'] = gr.Checkbox(label="flash_attn", value=shared.args.flash_attn, info='Use flash-attention.')
shared.gradio['use_flash_attention_2'] = gr.Checkbox(label="use_flash_attention_2", value=shared.args.use_flash_attention_2, info='Set use_flash_attention_2=True while loading the model.')
shared.gradio['auto_devices'] = gr.Checkbox(label="auto-devices", value=shared.args.auto_devices)
shared.gradio['cpu'] = gr.Checkbox(label="cpu", value=shared.args.cpu, info='llama.cpp: Use llama-cpp-python compiled without GPU acceleration. Transformers: use PyTorch in CPU mode.')
shared.gradio['disk'] = gr.Checkbox(label="disk", value=shared.args.disk)
shared.gradio['row_split'] = gr.Checkbox(label="row_split", value=shared.args.row_split, info='Split the model by rows across GPUs. This may improve multi-gpu performance.')

View file

@@ -1,11 +1,8 @@
import os
import warnings
from modules import shared
import accelerate # This early import makes Intel GPUs happy
import modules.one_click_installer_check
from modules import shared
from modules.block_requests import OpenMonkeyPatch, RequestBlocker
from modules.logging_colors import logger
@@ -38,7 +35,6 @@ import yaml
import modules.extensions as extensions_module
from modules import (
chat,
training,
ui,
ui_chat,