set xavier_uniform_ as lora_A's init function
parent 2b532f9453
commit 8bf21a7e4c

2 changed files with 2 additions and 4 deletions
@@ -1,2 +1,3 @@
 from .modeling import BaseQuantizeConfig
 from .modeling import AutoGPTQForCausalLM
+from .utils.peft_utils import get_gptq_peft_model
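The extraction dropped the file names, but this first hunk is presumably the package's top-level __init__.py: its practical effect is that the helper becomes importable from the package root, i.e. from auto_gptq import get_gptq_peft_model.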
@@ -47,7 +47,7 @@ class QuantLoraLinear(torch.nn.Linear, LoraLayer):
 
     def reset_lora_parameters(self, adapter_name):
         if adapter_name in self.lora_A.keys():
-            torch.nn.init.ones_(self.lora_A[adapter_name].weight)
+            torch.nn.init.xavier_uniform_(self.lora_A[adapter_name].weight)
             torch.nn.init.zeros_(self.lora_B[adapter_name].weight)
 
     def merge(self):
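A minimal self-contained sketch of the initialization scheme this hunk lands on. The shapes and variable names are illustrative, not from the repository; only the two torch.nn.init calls mirror the hunk. Since lora_B stays zero-initialized, the adapter delta lora_B(lora_A(x)) is exactly zero before training regardless of how lora_A is initialized; the change just replaces the constant ones_ fill on lora_A with a Xavier-uniform spread.

import math
import torch

# Illustrative LoRA pair: rank r=8 adapting a 768->768 linear layer.
r, in_features, out_features = 8, 768, 768
lora_A = torch.nn.Linear(in_features, r, bias=False)
lora_B = torch.nn.Linear(r, out_features, bias=False)

# The scheme from the hunk above: A ~ xavier_uniform_, B = 0.
torch.nn.init.xavier_uniform_(lora_A.weight)
torch.nn.init.zeros_(lora_B.weight)

# xavier_uniform_ samples from U(-a, a) with a = sqrt(6 / (fan_in + fan_out));
# for lora_A.weight of shape (r, in_features), fan_in=768 and fan_out=8.
a = math.sqrt(6 / (in_features + r))
assert lora_A.weight.abs().max() <= a

# With B zeroed, the adapter contributes nothing until training updates it.
x = torch.randn(2, in_features)
assert torch.all(lora_B(lora_A(x)) == 0)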
@@ -277,9 +277,6 @@ def get_gptq_peft_model(
     model_id: str = None,
     adapter_name: str = "default"
 ):
-    if not peft_config.task_type == TaskType.CAUSAL_LM:
-        raise TypeError("only support CAUSAL_LM task type.")
-
     PEFT_TYPE_TO_MODEL_MAPPING[PeftType.LORA] = QuantLoraModel
 
     try:
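The three deleted lines drop the guard that rejected every config other than TaskType.CAUSAL_LM. A hedged usage sketch follows, assuming get_gptq_peft_model also takes the quantized model as its first argument (the hunk shows only the model_id, adapter_name, and peft_config parameters); the checkpoint path and LoRA hyperparameters are placeholders.

from peft import LoraConfig, TaskType
from auto_gptq import AutoGPTQForCausalLM
from auto_gptq.utils.peft_utils import get_gptq_peft_model

# Placeholder path; substitute a real GPTQ-quantized checkpoint.
model = AutoGPTQForCausalLM.from_quantized("path/to/gptq-quantized-model")

peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,  # after this commit, no longer the only accepted task type
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
)

# Assumed call shape: quantized model first, then the parameters seen in the hunk.
model = get_gptq_peft_model(model, peft_config=peft_config, adapter_name="default")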