import warnings
import re
from contextlib import contextmanager
from dataclasses import asdict
from enum import Enum
from typing import List, Optional

import torch
from peft import get_peft_model, PeftConfig, PeftModel, PeftType
from peft.peft_model import PEFT_TYPE_TO_MODEL_MAPPING
from peft.tuners.lora import (
    LoraConfig,
    LoraLayer,
    LoraModel,
    Embedding,
    mark_only_lora_as_trainable,
    _freeze_adapter,
)
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
from peft.utils.other import (
    _get_submodules,
    TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
    TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
)

from ..modeling._base import BaseGPTQForCausalLM


class GPTQLoraConfig(LoraConfig):
    injected_fused_attention: bool = False


class GPTQLoraLinear(torch.nn.Linear, LoraLayer):
    def __init__(
        self,
        adapter_name: str,
        linear_module: torch.nn.Linear,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        **kwargs,
    ):
        init_lora_weights = kwargs.pop("init_lora_weights", True)

        torch.nn.Linear.__init__(self, linear_module.in_features, linear_module.out_features)
        LoraLayer.__init__(self, linear_module.in_features, linear_module.out_features)

        self.linear_module = linear_module

        # Share the wrapped module's weight and bias instead of the placeholder created by nn.Linear.
        self.weight.requires_grad = False
        self.weight = self.linear_module.weight
        self.bias = self.linear_module.bias
        self.fan_in_fan_out = fan_in_fan_out
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
        self.active_adapter = adapter_name

    def reset_lora_parameters(self, adapter_name):
        if adapter_name in self.lora_A.keys():
            torch.nn.init.xavier_uniform_(self.lora_A[adapter_name].weight)
            torch.nn.init.zeros_(self.lora_B[adapter_name].weight)

    def merge(self):
        raise NotImplementedError("GPTQ-quantized models do not support merging LoRA adapters.")

    def unmerge(self):
        raise NotImplementedError("GPTQ-quantized models do not support unmerging LoRA adapters.")

    def forward(self, x: torch.Tensor):
        previous_dtype = x.dtype
        if self.active_adapter not in self.lora_A.keys():
            return self.linear_module(x)
        if self.disable_adapters:
            if self.r[self.active_adapter] > 0 and self.merged:
                self.unmerge()
            result = self.linear_module(x)
        elif self.r[self.active_adapter] > 0 and not self.merged:
            # Run the (quantized) base linear, then add the scaled LoRA update.
            result = self.linear_module(x)
            lora_B = self.lora_B[self.active_adapter]
            lora_A = self.lora_A[self.active_adapter]
            lora_dropout = self.lora_dropout[self.active_adapter]
            scale = self.scaling[self.active_adapter]

            x = x.type_as(lora_A.weight.data)
            adapter_result = (lora_B(lora_A(lora_dropout(x))) * scale).type_as(result)
            result += adapter_result
        else:
            result = self.linear_module(x)

        result = result.to(previous_dtype)
        return result


class GPTQLoraModel(torch.nn.Module):
    def __init__(self, model, config, adapter_name):
        super().__init__()
        self.model = model
        self.forward = self.model.forward
        self.peft_config = config
        self.add_adapter(adapter_name, self.peft_config[adapter_name])

    def add_adapter(self, adapter_name, config=None):
        if config is not None:
            model_config = self.model.config.to_dict() if hasattr(self.model.config, "to_dict") else self.model.config
            config = self._prepare_lora_config(config, model_config)
            self.peft_config[adapter_name] = config
        self._find_and_replace(adapter_name)
        if len(self.peft_config) > 1 and self.peft_config[adapter_name].bias != "none":
            raise ValueError(
                "LoraModel supports only 1 adapter with bias. When using multiple adapters, "
                "set bias to 'none' for all adapters."
            )
        mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias)
        if self.peft_config[adapter_name].inference_mode:
            _freeze_adapter(self.model, adapter_name)

    def _find_and_replace(self, adapter_name):
        lora_config = self.peft_config[adapter_name]
        is_target_modules_in_base_model = False
        kwargs = {
            "r": lora_config.r,
            "lora_alpha": lora_config.lora_alpha,
            "lora_dropout": lora_config.lora_dropout,
            "fan_in_fan_out": lora_config.fan_in_fan_out,
            "init_lora_weights": lora_config.init_lora_weights,
        }
        key_list = [key for key, _ in self.model.named_modules()]
        for key in key_list:
            if isinstance(lora_config.target_modules, str):
                target_module_found = re.fullmatch(lora_config.target_modules, key)
            else:
                target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules)
            if target_module_found:
                if not is_target_modules_in_base_model:
                    is_target_modules_in_base_model = True
                parent, target, target_name = _get_submodules(self.model, key)

                bias = False
                if hasattr(target, "bias"):
                    bias = target.bias is not None

                if isinstance(target, LoraLayer):
                    target.update_layer(
                        adapter_name,
                        lora_config.r,
                        lora_config.lora_alpha,
                        lora_config.lora_dropout,
                        lora_config.init_lora_weights,
                    )
                else:
                    if isinstance(target, torch.nn.Embedding):
                        embedding_kwargs = kwargs.copy()
                        embedding_kwargs.pop("fan_in_fan_out", None)
                        in_features, out_features = target.num_embeddings, target.embedding_dim
                        new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs)
                    else:
                        if isinstance(target, torch.nn.Linear):
                            if kwargs["fan_in_fan_out"]:
                                warnings.warn(
                                    "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                                    "Setting fan_in_fan_out to False."
                                )
                                kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
                        else:
                            raise ValueError(
                                f"Target module {target} is not supported. "
                                f"Currently, only `torch.nn.Linear` and its subclasses are supported."
                            )
                        new_module = GPTQLoraLinear(adapter_name, target, **kwargs)
                    self._replace_module(parent, target_name, new_module, target)
        if not is_target_modules_in_base_model:
            raise ValueError(
                f"Target modules {lora_config.target_modules} not found in the base model. "
                f"Please check the target modules and try again."
            )
    def _replace_module(self, parent_module, child_name, new_module, old_module):
        setattr(parent_module, child_name, new_module)
        if not isinstance(new_module, GPTQLoraLinear):
            new_module.weight = old_module.weight
            if hasattr(old_module, "bias"):
                if old_module.bias is not None:
                    new_module.bias = old_module.bias

        if getattr(old_module, "state", None) is not None:
            new_module.state = old_module.state
            new_module.to(old_module.weight.device)

        # dispatch to correct device
        for name, module in new_module.named_modules():
            if "lora_" in name:
                module.to(old_module.weight.device)

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.model, name)

    def get_peft_config_as_dict(self, inference: bool = False):
        config_dict = {}
        for key, value in self.peft_config.items():
            config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
            if inference:
                config["inference_mode"] = True
            config_dict[key] = config
        return config_dict

    def _set_adapter_layers(self, enabled=True):
        for module in self.model.modules():
            if isinstance(module, LoraLayer):
                module.disable_adapters = not enabled

    def enable_adapter_layers(self):
        self._set_adapter_layers(enabled=True)

    def disable_adapter_layers(self):
        self._set_adapter_layers(enabled=False)

    def set_adapter(self, adapter_name):
        for module in self.model.modules():
            if isinstance(module, LoraLayer):
                if module.merged:
                    warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
                    module.unmerge()
                module.active_adapter = adapter_name

    def merge_adapter(self):
        raise NotImplementedError("GPTQ-quantized models do not support merging LoRA adapters.")

    def unmerge_adapter(self):
        raise NotImplementedError("GPTQ-quantized models do not support unmerging LoRA adapters.")

    @staticmethod
    def _prepare_lora_config(peft_config, model_config):
        if peft_config.target_modules is None:
            if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING:
                raise ValueError("Please specify `target_modules` in `peft_config`")
            peft_config.target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]]
        if peft_config.inference_mode:
            peft_config.merge_weights = True
        return peft_config

    def merge_and_unload(self):
        raise NotImplementedError("GPTQ-quantized models do not support merge_and_unload.")

    def add_weighted_adapter(self, adapters, weights, adapter_name):
        if len({self.peft_config[adapter].r for adapter in adapters}) != 1:
            raise ValueError("All adapters must have the same r value")
        self.peft_config[adapter_name] = self.peft_config[adapters[0]]
        self.peft_config[adapter_name].lora_alpha = self.peft_config[adapters[0]].r
        self._find_and_replace(adapter_name)
        mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias)
        _freeze_adapter(self.model, adapter_name)
        key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
        for key in key_list:
            _, target, _ = _get_submodules(self.model, key)
            if isinstance(target, LoraLayer):
                if adapter_name in target.lora_A:
                    target.lora_A[adapter_name].weight.data = target.lora_A[adapter_name].weight.data * 0.0
                    target.lora_B[adapter_name].weight.data = target.lora_B[adapter_name].weight.data * 0.0
                    for adapter, weight in zip(adapters, weights):
                        if adapter not in target.lora_A:
                            continue
                        target.lora_A[adapter_name].weight.data += (
                            target.lora_A[adapter].weight.data * weight * target.scaling[adapter]
                        )
                        target.lora_B[adapter_name].weight.data += target.lora_B[adapter].weight.data * weight
                elif adapter_name in target.lora_embedding_A:
                    target.lora_embedding_A[adapter_name].data = target.lora_embedding_A[adapter_name].data * 0.0
                    target.lora_embedding_B[adapter_name].data = target.lora_embedding_B[adapter_name].data * 0.0
                    for adapter, weight in zip(adapters, weights):
                        if adapter not in target.lora_embedding_A:
                            continue
                        target.lora_embedding_A[adapter_name].data += (
                            target.lora_embedding_A[adapter].data * weight * target.scaling[adapter]
                        )
                        target.lora_embedding_B[adapter_name].data += target.lora_embedding_B[adapter].data * weight


def find_all_linear_names(model: BaseGPTQForCausalLM, ignore: Optional[List[str]] = None, ignore_lm_head: bool = True):
    if not ignore:
        ignore = []
    lm_head_name = model.lm_head_name
    if ignore_lm_head and lm_head_name not in ignore:
        ignore.append(lm_head_name)
    results = set()
    for n, m in model.named_modules():
        if isinstance(m, torch.nn.Linear):
            res = n.split(".")[-1]
            if res not in ignore:
                results.add(res)
    return list(results)


@contextmanager
def hijack_peft_mappings():
    # Temporarily route peft's LORA type to the GPTQ-aware config and model classes,
    # restoring the stock mappings on exit.
    PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = GPTQLoraConfig
    PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel
    try:
        yield
    finally:
        PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = LoraConfig
        PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = LoraModel


def get_gptq_peft_model(
    model: BaseGPTQForCausalLM,
    peft_config: PeftConfig = None,
    model_id: str = None,
    adapter_name: str = "default",
):
    if model.fused_attn_module_type is not None and not model.injected_fused_attention:
        warnings.warn(
            "You can ignore this warning if the peft type you are using is not lora.\n"
            f"{model.__class__.__name__} supports fused attention injection, but it is not enabled for this model. "
            "If you are training lora adapters, you must also disable fused attention injection when loading the "
            "quantized base model at inference time; otherwise the adapters may not be added to the base model "
            "properly. If you are loading lora adapters for inference, check the adapter's config file to verify "
            "that the adapters were trained on a base model loaded without fused attention injection."
        )

    if isinstance(peft_config, LoraConfig) and not isinstance(peft_config, GPTQLoraConfig):
        peft_config = GPTQLoraConfig(**peft_config.to_dict())
        peft_config.injected_fused_attention = model.injected_fused_attention

    with hijack_peft_mappings():
        try:
            if model_id is None:
                if not peft_config:
                    raise ValueError("peft_config can't be None when model_id is None.")
                peft_model = get_peft_model(model.model, peft_config)
            else:
                peft_model = PeftModel.from_pretrained(model.model, model_id, adapter_name)
        except Exception:
            raise NotImplementedError(
                f"auto_gptq does not support the {peft_config.peft_type.value} peft type yet."
            )

    return peft_model


__all__ = ["get_gptq_peft_model"]
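# Usage sketch (illustrative only): `quantized_model` is assumed to be a BaseGPTQForCausalLM
# loaded elsewhere with fused attention injection disabled, and the LoRA hyper-parameters
# below are placeholders rather than recommendations.
#
#     from peft import TaskType
#
#     lora_config = GPTQLoraConfig(
#         task_type=TaskType.CAUSAL_LM,
#         r=16,
#         lora_alpha=32,
#         lora_dropout=0.05,
#         target_modules=find_all_linear_names(quantized_model),
#     )
#     peft_model = get_gptq_peft_model(quantized_model, peft_config=lora_config)
#     peft_model.print_trainable_parameters()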