AutoGPTQ/auto_gptq/utils/peft_utils.py

from functools import wraps

import torch
from peft import get_peft_model, PeftConfig, PeftModel, PeftType, TaskType
from peft.import_utils import is_bnb_available
from peft.tuners.lora import Embedding, Linear

from auto_gptq.modeling._base import BaseGPTQForCausalLM

# LoRA layer classes whose forward passes get dtype-casting wrappers below.
lora_embedding = (Embedding,)
lora_linear = (Linear,)
if is_bnb_available():
    from peft.tuners.lora import Linear8bitLt

    lora_linear += (Linear8bitLt,)


def get_gptq_peft_model(
    model: BaseGPTQForCausalLM,
    peft_config: PeftConfig = None,
    model_id: str = None,
    adapter_name: str = "default",
):
    if model_id is None:
        if peft_config is None:
            raise ValueError("peft_config can't be None when model_id is None.")
        try:
            peft_model = get_peft_model(model.model, peft_config)
        except Exception as exc:
            # Chain the original error instead of masking it with a bare `except`.
            raise NotImplementedError(
                f"gptq model does not support {peft_config.peft_type.value} peft type yet."
            ) from exc
    else:
        peft_model = PeftModel.from_pretrained(model.model, model_id, adapter_name)
        # The caller may pass peft_config=None when loading a saved adapter,
        # so pick up the config that from_pretrained restored.
        peft_config = peft_model.peft_config[adapter_name]

    if peft_config.task_type != TaskType.CAUSAL_LM:
        raise TypeError("only the CAUSAL_LM task type is supported.")

    if peft_config.peft_type == PeftType.LORA:
        for _, m in model.named_modules():
            if isinstance(m, lora_embedding + lora_linear):
                # Wrap forward via a factory so each wrapper binds its own
                # module and original forward; closing over the loop variables
                # directly would late-bind them all to the last module.
                def _wrap_forward(module, old_forward):
                    @wraps(old_forward)
                    def new_forward(*args, **kwargs):
                        # Cast tensor arguments to the adapter weight's dtype
                        # (e.g. fp16 GPTQ activations vs. fp32 LoRA weights).
                        args = [arg.type_as(module.weight.data) if isinstance(arg, torch.Tensor) else arg for arg in args]
                        return old_forward(*args, **kwargs)
                    return new_forward

                m.forward = _wrap_forward(m, m.forward)

    # PeftType.ADALORA is accepted here but gets no special handling yet.
    return peft_model
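
# A minimal usage sketch (illustrative, not part of this module): load a
# quantized model and attach a fresh LoRA adapter. The checkpoint path and
# LoRA hyperparameters are assumptions, not values from this repository.
#
#   from auto_gptq import AutoGPTQForCausalLM
#   from peft import LoraConfig, TaskType
#
#   model = AutoGPTQForCausalLM.from_quantized("path/to/quantized-model", device="cuda:0")
#   lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=16, lora_dropout=0.05)
#   peft_model = get_gptq_peft_model(model, peft_config=lora_config)
#   peft_model.print_trainable_parameters()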