set xavier_uniform_ as lora_A's init function

This commit is contained in:
PanQiWei 2023-05-26 14:06:53 +08:00
parent 2b532f9453
commit 8bf21a7e4c
2 changed files with 2 additions and 4 deletions

View file

@ -1,2 +1,3 @@
from .modeling import BaseQuantizeConfig
from .modeling import AutoGPTQForCausalLM
from .utils.peft_utils import get_gptq_peft_model

View file

@ -47,7 +47,7 @@ class QuantLoraLinear(torch.nn.Linear, LoraLayer):
def reset_lora_parameters(self, adapter_name):
if adapter_name in self.lora_A.keys():
torch.nn.init.ones_(self.lora_A[adapter_name].weight)
torch.nn.init.xavier_uniform_(self.lora_A[adapter_name].weight)
torch.nn.init.zeros_(self.lora_B[adapter_name].weight)
def merge(self):
@ -277,9 +277,6 @@ def get_gptq_peft_model(
model_id: str = None,
adapter_name: str = "default"
):
if not peft_config.task_type == TaskType.CAUSAL_LM:
raise TypeError("only support CAUSAL_LM task type.")
PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = QuantLoraModel
try: