from functools import wraps

import torch
from auto_gptq.modeling._base import BaseGPTQForCausalLM
from peft import PeftConfig, PeftModel, PeftType, TaskType, get_peft_model
from peft.import_utils import is_bnb_available
from peft.tuners.lora import Embedding, Linear

# LoRA layer types whose tensor inputs must be cast to the weight dtype below.
lora_embedding = (Embedding,)
lora_linear = (Linear,)
if is_bnb_available():
    from peft.tuners.lora import Linear8bitLt

    lora_linear += (Linear8bitLt,)


def get_gptq_peft_model(
    model: BaseGPTQForCausalLM,
    peft_config: PeftConfig = None,
    model_id: str = None,
    adapter_name: str = "default",
):
    # peft_config may legitimately be None when loading a pretrained adapter
    # from model_id, so only validate the task type when a config is given.
    if peft_config is not None and peft_config.task_type != TaskType.CAUSAL_LM:
        raise TypeError("only CAUSAL_LM task type is supported.")

    if model_id is None and peft_config is None:
        raise ValueError("peft_config can't be None when model_id is None.")

    try:
        if model_id is None:
            peft_model = get_peft_model(model.model, peft_config)
        else:
            peft_model = PeftModel.from_pretrained(model.model, model_id, adapter_name)
            # Use the loaded adapter's config for the peft_type checks below,
            # since no peft_config may have been passed in on this path.
            peft_config = peft_model.peft_config[adapter_name]
    except Exception as exc:
        peft_type = peft_config.peft_type.value if peft_config is not None else "unknown"
        raise NotImplementedError(f"gptq model does not support {peft_type} peft type yet.") from exc

    if peft_config.peft_type == PeftType.LORA:
        # GPTQ modules compute in the dtype of their stored weights; cast any
        # tensor arguments to that dtype before calling the wrapped forward.
        for _, m in model.named_modules():
            if isinstance(m, lora_embedding + lora_linear):

                def wrap_forward(module, old_forward):
                    # Bind module and old_forward per layer; a plain closure
                    # here would late-bind to the last layer visited by the loop.
                    @wraps(old_forward)
                    def new_forward(*args, **kwargs):
                        args = [
                            arg.type_as(module.weight.data) if isinstance(arg, torch.Tensor) else arg
                            for arg in args
                        ]
                        return old_forward(*args, **kwargs)

                    return new_forward

                m.forward = wrap_forward(m, m.forward)

    if peft_config.peft_type == PeftType.ADALORA:
        pass

    return peft_model
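
# --- Usage sketch (illustrative only, not part of the module) ---
# A minimal example of how this helper might be called, assuming a GPTQ
# checkpoint at the hypothetical path "./llama-7b-4bit" and the public
# auto_gptq / peft APIs (AutoGPTQForCausalLM.from_quantized, LoraConfig):
#
#     from auto_gptq import AutoGPTQForCausalLM
#     from peft import LoraConfig, TaskType
#
#     quantized = AutoGPTQForCausalLM.from_quantized("./llama-7b-4bit", device="cuda:0")
#     lora_config = LoraConfig(
#         task_type=TaskType.CAUSAL_LM,
#         r=16,
#         lora_alpha=32,
#         lora_dropout=0.05,
#     )
#     peft_model = get_gptq_peft_model(quantized, peft_config=lora_config)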