# AutoGPTQ/auto_gptq/utils/peft_utils.py


import warnings
import re
from contextlib import contextmanager
from dataclasses import asdict
from enum import Enum
from typing import List, Optional
import torch
from peft import get_peft_model, PeftConfig, PeftModel, PeftType
from peft.peft_model import PEFT_TYPE_TO_MODEL_MAPPING
from peft.tuners.lora import LoraConfig, LoraLayer, LoraModel, Embedding, mark_only_lora_as_trainable, _freeze_adapter
from peft.mapping import PEFT_TYPE_TO_CONFIG_MAPPING
from peft.utils.other import (
    _get_submodules,
    TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
    TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
)
from ..modeling._base import BaseGPTQForCausalLM


class GPTQLoraConfig(LoraConfig):
    injected_fused_attention: bool = False


class GPTQLoraLinear(torch.nn.Linear, LoraLayer):
    def __init__(
        self,
        adapter_name: str,
        linear_module: torch.nn.Linear,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        **kwargs,
    ):
        init_lora_weights = kwargs.pop("init_lora_weights", True)

        torch.nn.Linear.__init__(self, linear_module.in_features, linear_module.out_features)
        LoraLayer.__init__(self, linear_module.in_features, linear_module.out_features)
        self.linear_module = linear_module

        # Freeze the base weight and rebind weight/bias to the wrapped module's parameters.
        self.weight.requires_grad = False
        self.weight = self.linear_module.weight
        self.bias = self.linear_module.bias

        self.fan_in_fan_out = fan_in_fan_out
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
        self.active_adapter = adapter_name

    def reset_lora_parameters(self, adapter_name):
        if adapter_name in self.lora_A.keys():
            torch.nn.init.xavier_uniform_(self.lora_A[adapter_name].weight)
            torch.nn.init.zeros_(self.lora_B[adapter_name].weight)

    def merge(self):
        raise NotImplementedError("gptq model does not support merging the lora adapter into the base weights")

    def unmerge(self):
        raise NotImplementedError("gptq model does not support unmerging the lora adapter from the base weights")

    def forward(self, x: torch.Tensor):
        previous_dtype = x.dtype
        if self.active_adapter not in self.lora_A.keys():
            return self.linear_module(x)
        if self.disable_adapters:
            if self.r[self.active_adapter] > 0 and self.merged:
                self.unmerge()
            result = self.linear_module(x)
        elif self.r[self.active_adapter] > 0 and not self.merged:
            result = self.linear_module(x)

            lora_B = self.lora_B[self.active_adapter]
            lora_A = self.lora_A[self.active_adapter]
            lora_dropout = self.lora_dropout[self.active_adapter]
            scale = self.scaling[self.active_adapter]

            x = x.type_as(lora_A.weight.data)
            adapter_result = (lora_B(lora_A(lora_dropout(x))) * scale).type_as(result)
            result += adapter_result
        else:
            result = self.linear_module(x)

        result = result.to(previous_dtype)
        return result
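
# A minimal, hypothetical usage sketch of GPTQLoraLinear (normally it is created for you by
# GPTQLoraModel._find_and_replace rather than instantiated directly). It wraps an existing
# linear module that exposes `in_features`, `out_features`, `weight` and `bias`, keeps the
# base weights frozen, and adds the scaled low-rank delta lora_B(lora_A(dropout(x))) on top.
# The module path below is made up for illustration only:
#
#   qlinear = base_model.model.layers[0].self_attn.q_proj   # hypothetical target module
#   lora_linear = GPTQLoraLinear("default", qlinear, r=16, lora_alpha=32, lora_dropout=0.05)
#   out = lora_linear(hidden_states)                         # base output + LoRA delta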


class GPTQLoraModel(torch.nn.Module):
    def __init__(self, model, config, adapter_name):
        super().__init__()
        self.model = model
        self.forward = self.model.forward
        self.peft_config = config
        self.add_adapter(adapter_name, self.peft_config[adapter_name])

    def add_adapter(self, adapter_name, config=None):
        if config is not None:
            model_config = self.model.config.to_dict() if hasattr(self.model.config, "to_dict") else self.model.config
            config = self._prepare_lora_config(config, model_config)
            self.peft_config[adapter_name] = config
        self._find_and_replace(adapter_name)
        if len(self.peft_config) > 1 and self.peft_config[adapter_name].bias != "none":
            raise ValueError(
                "LoraModel supports only 1 adapter with bias. When using multiple adapters, "
                "set bias to 'none' for all adapters."
            )
        mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias)
        if self.peft_config[adapter_name].inference_mode:
            _freeze_adapter(self.model, adapter_name)

    def _find_and_replace(self, adapter_name):
        lora_config = self.peft_config[adapter_name]
        is_target_modules_in_base_model = False
        kwargs = {
            "r": lora_config.r,
            "lora_alpha": lora_config.lora_alpha,
            "lora_dropout": lora_config.lora_dropout,
            "fan_in_fan_out": lora_config.fan_in_fan_out,
            "init_lora_weights": lora_config.init_lora_weights,
        }
        key_list = [key for key, _ in self.model.named_modules()]
        for key in key_list:
            if isinstance(lora_config.target_modules, str):
                target_module_found = re.fullmatch(lora_config.target_modules, key)
            else:
                target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules)
            if target_module_found:
                if not is_target_modules_in_base_model:
                    is_target_modules_in_base_model = True
                parent, target, target_name = _get_submodules(self.model, key)

                bias = False
                if hasattr(target, "bias"):
                    bias = target.bias is not None

                if isinstance(target, LoraLayer):
                    target.update_layer(
                        adapter_name,
                        lora_config.r,
                        lora_config.lora_alpha,
                        lora_config.lora_dropout,
                        lora_config.init_lora_weights,
                    )
                else:
                    if isinstance(target, torch.nn.Embedding):
                        embedding_kwargs = kwargs.copy()
                        embedding_kwargs.pop("fan_in_fan_out", None)
                        in_features, out_features = target.num_embeddings, target.embedding_dim
                        new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs)
                    else:
                        if isinstance(target, torch.nn.Linear):
                            if kwargs["fan_in_fan_out"]:
                                warnings.warn(
                                    "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                                    "Setting fan_in_fan_out to False."
                                )
                                kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
                        else:
                            raise ValueError(
                                f"Target module {target} is not supported. "
                                "Currently, only `torch.nn.Linear` and its subclasses are supported."
                            )
                        new_module = GPTQLoraLinear(adapter_name, target, **kwargs)
                    self._replace_module(parent, target_name, new_module, target)
        if not is_target_modules_in_base_model:
            raise ValueError(
                f"Target modules {lora_config.target_modules} not found in the base model. "
                "Please check the target modules and try again."
            )

    def _replace_module(self, parent_module, child_name, new_module, old_module):
        setattr(parent_module, child_name, new_module)
        if not isinstance(new_module, GPTQLoraLinear):
            new_module.weight = old_module.weight
            if hasattr(old_module, "bias"):
                if old_module.bias is not None:
                    new_module.bias = old_module.bias

        if getattr(old_module, "state", None) is not None:
            new_module.state = old_module.state
            new_module.to(old_module.weight.device)

        # dispatch to correct device
        for name, module in new_module.named_modules():
            if "lora_" in name:
                module.to(old_module.weight.device)

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.model, name)

    def get_peft_config_as_dict(self, inference: bool = False):
        config_dict = {}
        for key, value in self.peft_config.items():
            config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
            if inference:
                config["inference_mode"] = True
            config_dict[key] = config
        return config_dict

    def _set_adapter_layers(self, enabled=True):
        for module in self.model.modules():
            if isinstance(module, LoraLayer):
                module.disable_adapters = not enabled

    def enable_adapter_layers(self):
        self._set_adapter_layers(enabled=True)

    def disable_adapter_layers(self):
        self._set_adapter_layers(enabled=False)

    def set_adapter(self, adapter_name):
        for module in self.model.modules():
            if isinstance(module, LoraLayer):
                if module.merged:
                    warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
                    module.unmerge()
                module.active_adapter = adapter_name

    def merge_adapter(self):
        raise NotImplementedError("gptq model does not support merging the lora adapter")

    def unmerge_adapter(self):
        raise NotImplementedError("gptq model does not support unmerging the lora adapter")

    @staticmethod
    def _prepare_lora_config(peft_config, model_config):
        if peft_config.target_modules is None:
            if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING:
                raise ValueError("Please specify `target_modules` in `peft_config`")
            peft_config.target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]]
        if peft_config.inference_mode:
            peft_config.merge_weights = True
        return peft_config

    def merge_and_unload(self):
        raise NotImplementedError("gptq model does not support merging lora adapters and unloading the base model")

    def add_weighted_adapter(self, adapters, weights, adapter_name):
        if len({self.peft_config[adapter].r for adapter in adapters}) != 1:
            raise ValueError("All adapters must have the same r value")
        self.peft_config[adapter_name] = self.peft_config[adapters[0]]
        self.peft_config[adapter_name].lora_alpha = self.peft_config[adapters[0]].r
        self._find_and_replace(adapter_name)
        mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias)
        _freeze_adapter(self.model, adapter_name)
        key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
        for key in key_list:
            _, target, _ = _get_submodules(self.model, key)
            if isinstance(target, LoraLayer):
                if adapter_name in target.lora_A:
                    target.lora_A[adapter_name].weight.data = target.lora_A[adapter_name].weight.data * 0.0
                    target.lora_B[adapter_name].weight.data = target.lora_B[adapter_name].weight.data * 0.0
                    for adapter, weight in zip(adapters, weights):
                        if adapter not in target.lora_A:
                            continue
                        target.lora_A[adapter_name].weight.data += (
                            target.lora_A[adapter].weight.data * weight * target.scaling[adapter]
                        )
                        target.lora_B[adapter_name].weight.data += target.lora_B[adapter].weight.data * weight
                elif adapter_name in target.lora_embedding_A:
                    target.lora_embedding_A[adapter_name].data = target.lora_embedding_A[adapter_name].data * 0.0
                    target.lora_embedding_B[adapter_name].data = target.lora_embedding_B[adapter_name].data * 0.0
                    for adapter, weight in zip(adapters, weights):
                        if adapter not in target.lora_embedding_A:
                            continue
                        target.lora_embedding_A[adapter_name].data += (
                            target.lora_embedding_A[adapter].data * weight * target.scaling[adapter]
                        )
                        target.lora_embedding_B[adapter_name].data += target.lora_embedding_B[adapter].data * weight


def find_all_linear_names(model: BaseGPTQForCausalLM, ignore: Optional[List[str]] = None, ignore_lm_head: bool = True):
    if not ignore:
        ignore = []
    lm_head_name = model.lm_head_name
    if ignore_lm_head and lm_head_name not in ignore:
        ignore.append(lm_head_name)
    results = set()
    for n, m in model.named_modules():
        if isinstance(m, torch.nn.Linear):
            res = n.split('.')[-1]
            if res not in ignore:
                results.add(res)
    return list(results)
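
# A hedged example of how find_all_linear_names can be combined with GPTQLoraConfig: the
# returned module-name suffixes are a natural choice for `target_modules`. `quantized_model`
# is assumed to be an already-loaded BaseGPTQForCausalLM instance (placeholder name):
#
#   target_modules = find_all_linear_names(quantized_model, ignore_lm_head=True)
#   peft_config = GPTQLoraConfig(
#       r=16,
#       lora_alpha=32,
#       lora_dropout=0.05,
#       target_modules=target_modules,
#       task_type="CAUSAL_LM",
#   )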


@contextmanager
def hijack_peft_mappings():
    PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = GPTQLoraConfig
    PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = GPTQLoraModel
    try:
        yield
    finally:
        # Always restore peft's original LORA mappings, whether or not the block raised.
        PEFT_TYPE_TO_CONFIG_MAPPING[PeftType.LORA] = LoraConfig
        PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = LoraModel


def get_gptq_peft_model(
    model: BaseGPTQForCausalLM,
    peft_config: Optional[PeftConfig] = None,
    model_id: Optional[str] = None,
    adapter_name: str = "default",
):
    if model.fused_attn_module_type is not None and not model.injected_fused_attention:
        warnings.warn(
            "You can ignore this warning if the peft type you are using is not lora.\n"
            f"{model.__class__.__name__} supports fused attention injection, but it was not enabled this time. "
            "If you are training lora adapters, you must also disable fused attention injection when loading the "
            "quantized base model at inference time; otherwise the adapters may not be added to the base model "
            "properly. If you are loading lora adapters for inference, check the adapter's config file to see "
            "whether the adapters were trained on a base model loaded without fused attention injection."
        )

    if isinstance(peft_config, LoraConfig) and not isinstance(peft_config, GPTQLoraConfig):
        peft_config = GPTQLoraConfig(**peft_config.to_dict())
        peft_config.injected_fused_attention = model.injected_fused_attention

    with hijack_peft_mappings():
        try:
            if model_id is None:
                if not peft_config:
                    raise ValueError("peft_config can't be None when model_id is None.")
                peft_model = get_peft_model(model.model, peft_config)
            else:
                peft_model = PeftModel.from_pretrained(model.model, model_id, adapter_name)
        except Exception as exc:
            raise NotImplementedError(
                f"auto_gptq does not support the {peft_config.peft_type.value} peft type yet."
            ) from exc

    return peft_model


__all__ = ["get_gptq_peft_model"]
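
# End-to-end usage sketch (illustrative only). The checkpoint path is a placeholder and the
# loader kwargs are assumptions about AutoGPTQForCausalLM.from_quantized; check your installed
# auto_gptq version for the exact signature:
#
#   from auto_gptq import AutoGPTQForCausalLM
#
#   quantized_model = AutoGPTQForCausalLM.from_quantized(
#       "path/to/quantized-model",     # hypothetical local checkpoint
#       inject_fused_attention=False,  # keep fused attention off when training lora adapters
#   )
#   peft_config = GPTQLoraConfig(
#       r=16,
#       lora_alpha=32,
#       target_modules=["q_proj", "v_proj"],  # hypothetical module names for a llama-style model
#       task_type="CAUSAL_LM",
#   )
#   peft_model = get_gptq_peft_model(quantized_model, peft_config=peft_config, adapter_name="default")
#   peft_model.print_trainable_parameters()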