support older versions of Python

PanQiWei 2023-05-31 22:11:16 +08:00
parent d0769c1a39
commit ec6603d0ab


@@ -227,7 +227,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                         break
             layer_inputs.append(move_to_device(inp, self.data_device))
             attention_masks.append(kwargs["attention_mask"].to(self.data_device))
-            if (pos_ids := kwargs.get("position_ids", None)) is not None:
+            pos_ids = kwargs.get("position_ids", None)
+            if pos_ids is not None:
                 position_ids.append(move_to_device(pos_ids, self.data_device))
             one_kwargs = dict()
             for k, v in kwargs.items():  # make sure other arguments also be captured
@@ -328,10 +329,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                 additional_layer_inputs = {
                     "attention_mask": layer_attention_mask
                 }
-                if (
-                    layer_position_ids := None if not position_ids
-                    else move_to_device(position_ids[j], cur_layer_device)
-                ) is not None:
+                layer_position_ids = None if not position_ids else move_to_device(position_ids[j], cur_layer_device)
+                if layer_position_ids is not None:
                     additional_layer_inputs["position_ids"] = layer_position_ids
                 for k, v in layer_input_kwargs[j].items():
                     if isinstance(v, torch.Tensor):
@@ -363,10 +362,8 @@ class BaseGPTQForCausalLM(nn.Module, PushToHubMixin):
                 additional_layer_inputs = {
                     "attention_mask": layer_attention_mask
                 }
-                if (
-                    layer_position_ids := None if not position_ids
-                    else move_to_device(position_ids[j], cur_layer_device)
-                ) is not None:
+                layer_position_ids = None if not position_ids else move_to_device(position_ids[j], cur_layer_device)
+                if layer_position_ids is not None:
                     additional_layer_inputs["position_ids"] = layer_position_ids
                 for k, v in layer_input_kwargs[j].items():
                     if isinstance(v, torch.Tensor):
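
For context, a minimal standalone sketch of the pattern this commit applies (illustrative only, not part of the diff): the walrus operator `:=` was introduced in Python 3.8 (PEP 572), so binding the value first and testing it in a separate `if` gives the same behavior while remaining valid on Python 3.7 and earlier.

    # Illustrative example, not code from the repository.
    kwargs = {"position_ids": [0, 1, 2]}
    position_ids = []

    # Python >= 3.8 only (PEP 572): bind and test in one expression.
    #     if (pos_ids := kwargs.get("position_ids", None)) is not None:
    #         position_ids.append(pos_ids)

    # Equivalent form used by this commit, valid on Python 3.7 and earlier:
    pos_ids = kwargs.get("position_ids", None)
    if pos_ids is not None:
        position_ids.append(pos_ids)

    print(position_ids)  # [[0, 1, 2]]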