Use adapter_name for get_gptq_peft_model with train_mode=True

This commit is contained in:
Alexander Pozharskii 2023-09-24 17:11:19 +04:00
parent 06e071e68e
commit 0185095402

View file

@ -402,7 +402,7 @@ def get_gptq_peft_model(
with hijack_peft_mappings():
try:
if train_mode:
peft_model = get_peft_model(model.model, peft_config)
peft_model = get_peft_model(model.model, peft_config, adapter_name=adapter_name)
else:
peft_model = PeftModel.from_pretrained(model.model, model_id, adapter_name)
except: