from typing import List, Optional, Union
|
|
|
|
from torch import LongTensor
|
|
from transformers import PreTrainedTokenizer
|
|
|
|
|
|
def postprocess_generation_ids(
    input_ids: LongTensor,
    output_ids: LongTensor,
    num_return_sequences: int,
    tokenizer: Optional["PreTrainedTokenizer"] = None,
    pad_token_ids: Optional[int] = None,
) -> List[List[Union[str, List[int]]]]:
    """Group generation outputs by prompt and strip the echoed prompt tokens.

    ``output_ids`` is assumed to hold ``num_return_sequences`` consecutive rows
    per prompt, each row starting with the prompt's own tokens (the usual
    decoder-only ``generate`` layout) -- TODO confirm against callers.

    Args:
        input_ids: Prompt token ids, one row per prompt.
        output_ids: Generated sequences; ``len(output_ids)`` must equal
            ``len(input_ids) * num_return_sequences``.
        num_return_sequences: Number of generated sequences per prompt.
        tokenizer: When given, the generated suffixes are decoded to strings.
        pad_token_ids: When given (and no tokenizer), each id list is truncated
            at the first occurrence of this padding id.

    Returns:
        One list per prompt, each containing ``num_return_sequences`` decoded
        strings (tokenizer given) or token-id lists (no tokenizer).
    """
    outputs: List[List[Union[str, List[int]]]] = []
    for idx, start in enumerate(range(0, len(output_ids), num_return_sequences)):
        sub_output_ids = output_ids[start: start + num_return_sequences]
        # Drop the echoed prompt: keep only tokens past the prompt's length.
        sub_generated_ids = sub_output_ids[..., input_ids[idx].size(0):]
        if tokenizer is not None:
            # batch_decode already returns a list of strings; wrapping it in a
            # comprehension only copied it.
            outputs.append(
                tokenizer.batch_decode(
                    sub_generated_ids,
                    clean_up_tokenization_spaces=True,
                )
            )
        else:
            # BUG FIX: previously converted sub_output_ids (full rows, prompt
            # included), leaving the slice above dead and returning
            # prompt-contaminated ids; use the generated suffix instead, to
            # match the tokenizer branch.
            token_lists = sub_generated_ids.cpu().numpy().tolist()
            trimmed = []
            for ids in token_lists:
                if pad_token_ids is not None and pad_token_ids in ids:
                    # Cut at the first pad token.
                    ids = ids[: ids.index(pad_token_ids)]
                trimmed.append(ids)
            outputs.append(trimmed)
    return outputs
|
|
|
|
|
|
# Public API of this module.
__all__ = ["postprocess_generation_ids"]
|