fix gptj forward and add torch.no_grad context manager

PanQiWei 2023-04-17 00:15:41 +08:00
parent 941e281e9d
commit 12ae4d024c


@@ -79,6 +79,7 @@ class BaseGPTQForCausalLM:
     def _resize_attention_mask(attention_mask: List[torch.LongTensor]):
         return attention_mask

+    @torch.no_grad()
     def quantize(self, examples: List[Dict[str, torch.LongTensor]]):
         if self.quantized:
             raise EnvironmentError("can't execute quantize because the model is quantized.")
@@ -94,7 +95,12 @@ class BaseGPTQForCausalLM:
                 super().__init__()
                 self.module = m

-            def forward(self, inp, **kwargs):
+            def forward(self, inp=None, **kwargs):
+                if inp is None:  # some models use all key-value arguments in forward pass call
+                    for kwarg_name in ["hidden_states"]:
+                        if kwarg_name in kwargs:
+                            inp = kwargs[kwarg_name]
+                            break
                 bsz = inp.size(0)
                 for i in range(bsz):
                     layer_inputs.append(inp[i].to(CPU))
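
For context, below is a minimal sketch of the layer-input "hijacker" pattern this diff patches, assuming a surrounding layer_inputs list and CPU device constant like the ones used in quantize(); the class name LayerInputCatcher and the standalone setup are illustrative, not the repo's exact API. It shows why inp must become optional: models such as GPT-J call the decoder block with keyword arguments only, so the wrapper falls back to the hidden_states kwarg before caching each example on CPU.

import torch
import torch.nn as nn

CPU = torch.device("cpu")
layer_inputs = []  # calibration inputs are collected here, as in quantize()

class LayerInputCatcher(nn.Module):
    """Temporarily replaces the first decoder block to cache its inputs."""

    def __init__(self, m: nn.Module):
        super().__init__()
        self.module = m

    def forward(self, inp=None, **kwargs):
        # some models (e.g. GPT-J) pass everything as keyword arguments,
        # so fall back to known kwarg names when no positional input is given
        if inp is None:
            for kwarg_name in ["hidden_states"]:
                if kwarg_name in kwargs:
                    inp = kwargs[kwarg_name]
                    break
        bsz = inp.size(0)
        for i in range(bsz):
            layer_inputs.append(inp[i].to(CPU))
        # abort the forward pass: only the cached inputs are needed
        raise ValueError

Decorating quantize() with @torch.no_grad() keeps these calibration forward passes from building an autograd graph, which would otherwise hold activations in memory for a backward pass that never runs.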