support older version python
This commit is contained in:
parent
d0769c1a39
commit
ec6603d0ab
1 changed files with 6 additions and 9 deletions
|
@ -227,7 +227,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
break
|
||||
layer_inputs.append(move_to_device(inp, self.data_device))
|
||||
attention_masks.append(kwargs["attention_mask"].to(self.data_device))
|
||||
if (pos_ids := kwargs.get("position_ids", None)) is not None:
|
||||
pos_ids = kwargs.get("position_ids", None)
|
||||
if pos_ids is not None:
|
||||
position_ids.append(move_to_device(pos_ids, self.data_device))
|
||||
one_kwargs = dict()
|
||||
for k, v in kwargs.items(): # make sure other arguments also be captured
|
||||
|
@ -328,10 +329,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
additional_layer_inputs = {
|
||||
"attention_mask": layer_attention_mask
|
||||
}
|
||||
if (
|
||||
layer_position_ids := None if not position_ids
|
||||
else move_to_device(position_ids[j], cur_layer_device)
|
||||
) is not None:
|
||||
layer_position_ids = None if not position_ids else move_to_device(position_ids[j], cur_layer_device)
|
||||
if layer_position_ids is not None:
|
||||
additional_layer_inputs["position_ids"] = layer_position_ids
|
||||
for k, v in layer_input_kwargs[j].items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
|
@ -363,10 +362,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
|
|||
additional_layer_inputs = {
|
||||
"attention_mask": layer_attention_mask
|
||||
}
|
||||
if (
|
||||
layer_position_ids := None if not position_ids
|
||||
else move_to_device(position_ids[j], cur_layer_device)
|
||||
) is not None:
|
||||
layer_position_ids = None if not position_ids else move_to_device(position_ids[j], cur_layer_device)
|
||||
if layer_position_ids is not None:
|
||||
additional_layer_inputs["position_ids"] = layer_position_ids
|
||||
for k, v in layer_input_kwargs[j].items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
|
|
Loading…
Add table
Reference in a new issue