diff --git a/auto_gptq/modeling/_base.py b/auto_gptq/modeling/_base.py index c14829e..3d6e75b 100644 --- a/auto_gptq/modeling/_base.py +++ b/auto_gptq/modeling/_base.py @@ -79,6 +79,7 @@ class BaseGPTQForCausalLM: def _resize_attention_mask(attention_mask: List[torch.LongTensor]): return attention_mask + @torch.no_grad() def quantize(self, examples: List[Dict[str, torch.LongTensor]]): if self.quantized: raise EnvironmentError("can't execute quantize because the model is quantized.") @@ -94,7 +95,12 @@ class BaseGPTQForCausalLM: super().__init__() self.module = m - def forward(self, inp, **kwargs): + def forward(self, inp=None, **kwargs): + if inp is None: # some models use all key-value arguments in forward pass call + for kwarg_name in ["hidden_states"]: + if kwarg_name in kwargs: + inp = kwargs[kwarg_name] + break bsz = inp.size(0) for i in range(bsz): layer_inputs.append(inp[i].to(CPU))