fix gptj forward and add torch.no_grad context manager
This commit is contained in:
parent 941e281e9d
commit 12ae4d024c

1 changed file with 7 additions and 1 deletion
@@ -79,6 +79,7 @@ class BaseGPTQForCausalLM:
     def _resize_attention_mask(attention_mask: List[torch.LongTensor]):
         return attention_mask
 
+    @torch.no_grad()
     def quantize(self, examples: List[Dict[str, torch.LongTensor]]):
         if self.quantized:
             raise EnvironmentError("can't execute quantize because the model is quantized.")
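The first hunk wraps the whole calibration pass in a no-grad context: quantization only needs forward activations, so skipping autograd bookkeeping avoids holding gradient state for every calibration example. A minimal standalone sketch of what the decorator does (the toy layer and input are illustrative, not the repository's code):

    import torch

    layer = torch.nn.Linear(4, 4)  # stand-in for a model under calibration

    @torch.no_grad()
    def calibrate(x):
        # Everything inside runs with autograd disabled, so the output
        # carries no gradient history and no backward buffers are kept.
        return layer(x)

    print(calibrate(torch.randn(2, 4)).requires_grad)  # False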
@@ -94,7 +95,12 @@ class BaseGPTQForCausalLM:
                 super().__init__()
                 self.module = m
 
-            def forward(self, inp, **kwargs):
+            def forward(self, inp=None, **kwargs):
+                if inp is None:  # some models use all key-value arguments in forward pass call
+                    for kwarg_name in ["hidden_states"]:
+                        if kwarg_name in kwargs:
+                            inp = kwargs[kwarg_name]
+                            break
                 bsz = inp.size(0)
                 for i in range(bsz):
                     layer_inputs.append(inp[i].to(CPU))
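The second hunk is the GPT-J fix: the hijacker module wraps a decoder layer to record the hidden states flowing into it, but per the commit, GPT-J calls its blocks with everything passed by keyword, so a required positional inp was never bound. A minimal sketch of the capture pattern under that assumption (the wrapped nn.Linear, the call site, and the simplified return are illustrative, not the repository's code):

    import torch
    import torch.nn as nn

    CPU = torch.device("cpu")
    layer_inputs = []  # calibration inputs captured from the wrapped layer

    class LayerHijacker(nn.Module):
        """Wraps a block so each input batch row is recorded before the call."""

        def __init__(self, m):
            super().__init__()
            self.module = m

        def forward(self, inp=None, **kwargs):
            if inp is None:  # some models pass every argument by keyword
                for kwarg_name in ["hidden_states"]:
                    if kwarg_name in kwargs:
                        inp = kwargs[kwarg_name]
                        break
            for i in range(inp.size(0)):
                layer_inputs.append(inp[i].to(CPU))  # stash each row on CPU
            return self.module(inp)  # simplified: extra kwargs dropped here

    # GPT-J-style call site: hidden states arrive as a keyword argument.
    hijacked = LayerHijacker(nn.Linear(8, 8))
    hijacked(hidden_states=torch.randn(2, 8))
    print(len(layer_inputs))  # 2, one entry per batch element

Looping kwarg_name over a list rather than hard-coding "hidden_states" leaves room for other models whose blocks name their first argument differently.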