# text-generation-webui-mirror/modules/transformers_loader.py
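"""
Transformers loader: loads models and tokenizers through the Hugging Face
Transformers library, with optional bitsandbytes quantization, accelerate
CPU/disk offloading, and DeepSpeed ZeRO-3.
"""
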
import os
import pprint
from pathlib import Path

import torch
import torch.nn.functional as F
import transformers
from accelerate import infer_auto_device_map, init_empty_weights
from accelerate.utils import (
    is_ccl_available,
    is_npu_available,
    is_xpu_available
)
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    LogitsProcessor
)

import modules.shared as shared
from modules.logging_colors import logger
from modules.text_generation import get_reply_from_output_ids
from modules.torch_utils import get_device

transformers.logging.set_verbosity_error()
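
# Optional DeepSpeed support: when shared.args.deepspeed is set, initialize the
# distributed backend (CCL on XPU, HCCL on NPU, the default backend otherwise)
# and build the ZeRO-3 config up front.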
local_rank = None
if shared.args.deepspeed:
    import deepspeed
    from transformers.integrations.deepspeed import (
        HfDeepSpeedConfig,
        is_deepspeed_zero3_enabled
    )

    from modules.deepspeed_parameters import generate_ds_config

    # Distributed setup
    local_rank = shared.args.local_rank if shared.args.local_rank is not None else int(os.getenv("LOCAL_RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    if is_xpu_available() and is_ccl_available():
        torch.xpu.set_device(local_rank)
        deepspeed.init_distributed(backend="ccl")
    elif is_npu_available():
        torch.npu.set_device(local_rank)
        deepspeed.init_distributed(dist_backend="hccl")
    else:
        torch.cuda.set_device(local_rank)
        deepspeed.init_distributed()

    ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
    dschf = HfDeepSpeedConfig(ds_config)  # Keep this object alive for the Transformers integration
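

# Stopping criterion that aborts generation as soon as shared.stop_everything is set.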
class _StopEverythingStoppingCriteria(transformers.StoppingCriteria):
    def __init__(self):
        transformers.StoppingCriteria.__init__(self)

    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
        return shared.stop_everything
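

# Streaming hook: calls callback_func with the running token sequence after every
# generation step and never halts generation itself.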
class Stream(transformers.StoppingCriteria):
    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids, scores) -> bool:
        if self.callback_func is not None:
            self.callback_func(input_ids[0])

        return False
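

# Adds a fixed bias to selected token logits, given an OpenAI-style logit_bias
# mapping of token id -> bias.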
class LogitsBiasProcessor(LogitsProcessor):
    def __init__(self, logit_bias={}):
        self.logit_bias = logit_bias
        if self.logit_bias:
            self.keys = [int(key) for key in self.logit_bias.keys()]
            values = [self.logit_bias[str(key)] for key in self.keys]
            self.values = torch.tensor(values, dtype=torch.float, device=shared.model.device)

    def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
        if self.logit_bias:
            logits[0, self.keys] += self.values

        return logits

    def __repr__(self):
        return f"<{self.__class__.__name__}(logit_bias={self.logit_bias})>"
class LogprobProcessor(LogitsProcessor):
    def __init__(self, logprobs=None):
        self.logprobs = logprobs
        self.token_alternatives = {}

    def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
        if self.logprobs is not None:  # 0-5
            log_e_probabilities = F.log_softmax(logits, dim=1)
            top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
            top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]]
            top_probs = [float(x) for x in top_values[0]]
            self.token_alternatives = dict(zip(top_tokens, top_probs))

        return logits

    def __repr__(self):
        return f"<{self.__class__.__name__}(logprobs={self.logprobs}, token_alternatives={self.token_alternatives})>"
def load_tokenizer(model_name, tokenizer_dir=None):
    if tokenizer_dir:
        path_to_model = Path(tokenizer_dir)
    else:
        path_to_model = Path(f"{shared.args.model_dir}/{model_name}/")

    tokenizer = None
    if path_to_model.exists():
        if shared.args.no_use_fast:
            logger.info('Loading the tokenizer with use_fast=False.')

        tokenizer = AutoTokenizer.from_pretrained(
            path_to_model,
            trust_remote_code=shared.args.trust_remote_code,
            use_fast=not shared.args.no_use_fast
        )

    return tokenizer
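

# Loads a model with Transformers. Depending on the flags, this takes one of three
# paths: plain loading with device_map='auto', DeepSpeed ZeRO-3, or quantized and/or
# offloaded loading via bitsandbytes and accelerate.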
def load_model_HF(model_name):
    path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
    params = {
        'low_cpu_mem_usage': True,
        'torch_dtype': torch.bfloat16 if shared.args.bf16 else torch.float16,
    }

    if shared.args.trust_remote_code:
        params['trust_remote_code'] = True
    if shared.args.use_flash_attention_2:
        params['use_flash_attention_2'] = True
    if shared.args.force_safetensors:
        params['force_safetensors'] = True
    if shared.args.use_eager_attention:
        params['attn_implementation'] = 'eager'

    config = AutoConfig.from_pretrained(path_to_model, trust_remote_code=shared.args.trust_remote_code)

    if 'chatglm' in model_name.lower():
        LoaderClass = AutoModel
    else:
        if config.to_dict().get('is_encoder_decoder', False):
            LoaderClass = AutoModelForSeq2SeqLM
            shared.is_seq2seq = True
        else:
            LoaderClass = AutoModelForCausalLM

    # Determine if we should use default loading
    should_use_default_loading = not any([
        shared.args.cpu,
        shared.args.load_in_8bit,
        shared.args.load_in_4bit,
        shared.args.disk,
        shared.args.deepspeed,
        shared.args.cpu_memory is not None,
        shared.args.compress_pos_emb > 1,
        shared.args.alpha_value > 1,
    ])

    # Load the model without any special settings
    if should_use_default_loading:
        params['device_map'] = 'auto'
        logger.info("TRANSFORMERS_PARAMS=")
        pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
        print()
        model = LoaderClass.from_pretrained(path_to_model, **params)
        if not (hasattr(model, 'is_loaded_in_4bit') and model.is_loaded_in_4bit):
            device = get_device()
            if device:
                model = model.to(device)

    # DeepSpeed ZeRO-3
    elif shared.args.deepspeed:
        model = LoaderClass.from_pretrained(
            path_to_model,
            torch_dtype=params['torch_dtype'],
            trust_remote_code=params.get('trust_remote_code')
        )
        model = deepspeed.initialize(
            model=model,
            config_params=ds_config,
            model_parameters=None,
            optimizer=None,
            lr_scheduler=None
        )[0]
        model.module.eval()  # Inference
        logger.info(f'DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}')

    # Load with quantization and/or offloading
    else:
        if not any((shared.args.cpu, torch.cuda.is_available(), is_xpu_available(), torch.backends.mps.is_available())):
            logger.warning('torch.cuda.is_available() and is_xpu_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.')
            shared.args.cpu = True

        if shared.args.cpu:
            params['torch_dtype'] = torch.float32
        else:
            params['device_map'] = 'auto'
            if x := get_max_memory_dict():
                params['max_memory'] = x

            if shared.args.load_in_4bit:
                # See https://github.com/huggingface/transformers/pull/23479/files
                # and https://huggingface.co/blog/4bit-transformers-bitsandbytes
                quantization_config_params = {
                    'load_in_4bit': True,
                    'bnb_4bit_compute_dtype': getattr(torch, shared.args.compute_dtype) if shared.args.compute_dtype in ["bfloat16", "float16", "float32"] else None,
                    'bnb_4bit_quant_type': shared.args.quant_type,
                    'bnb_4bit_use_double_quant': shared.args.use_double_quant,
                    'llm_int8_enable_fp32_cpu_offload': True
                }
                params['quantization_config'] = BitsAndBytesConfig(**quantization_config_params)
            elif shared.args.load_in_8bit:
                if shared.args.gpu_split:
                    params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
                else:
                    params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)

                if params.get('max_memory') is not None:
                    with init_empty_weights():
                        model = LoaderClass.from_config(config, trust_remote_code=params.get('trust_remote_code'))

                    model.tie_weights()
                    params['device_map'] = infer_auto_device_map(
                        model,
                        dtype=torch.int8,
                        max_memory=params.get('max_memory'),
                        no_split_module_classes=model._no_split_modules
                    )

            if shared.args.disk:
                params['offload_folder'] = str(Path(shared.args.disk_cache_dir))

        if shared.args.compress_pos_emb > 1:
            params['rope_scaling'] = {'type': 'linear', 'factor': shared.args.compress_pos_emb}
        elif shared.args.alpha_value > 1:
            params['rope_scaling'] = {'type': 'dynamic', 'factor': shared.args.alpha_value}

        logger.info("TRANSFORMERS_PARAMS=")
        pprint.PrettyPrinter(indent=4, sort_dicts=False).pprint(params)
        print()
        model = LoaderClass.from_pretrained(path_to_model, **params)

    if shared.args.torch_compile:
        model = torch.compile(model)

    return model
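

# Builds the max_memory mapping that accelerate uses for device placement, from
# shared.args.cpu_memory and shared.args.gpu_split; returns None when neither is set.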
def get_max_memory_dict():
    max_memory = {}
    if shared.args.cpu_memory > 0:
        max_memory['cpu'] = f'{shared.args.cpu_memory}GiB'

    if shared.args.gpu_split:
        for i, memory in enumerate(shared.args.gpu_split.split(',')):
            max_memory[i] = f'{memory}GiB'

    return max_memory if len(max_memory) > 0 else None