from collections import Counter
from typing import Any, Dict, List, Optional

import numpy as np
from torch import LongTensor
from transformers import PreTrainedTokenizer, GenerationConfig

from ._base import BaseTask
from ._utils.generation_utils import postprocess_generation_ids
from ._utils.classification_utils import get_closest_label


def get_predictions(
    input_ids: LongTensor,
    output_ids: LongTensor,
    num_return_sequences: int,
    tokenizer: PreTrainedTokenizer,
    classes: List[str]
) -> List[int]:
    """Turn generated token ids into one predicted class per input sample.

    The generations are first converted to text with
    ``postprocess_generation_ids``, which is iterated here as one group of
    texts per input sample. Every text in a group is mapped to a class with
    ``get_closest_label`` and the group's most frequent result is kept
    (majority vote over the returned sequences).

    Args:
        input_ids: Prompt token ids that were fed to ``generate``.
        output_ids: Token ids returned by ``generate``.
        num_return_sequences: Number of sequences generated per prompt.
        tokenizer: Tokenizer forwarded to ``postprocess_generation_ids``.
        classes: Candidate class names forwarded to ``get_closest_label``.

    Returns:
        One prediction per group of generated texts.
        NOTE(review): the values are whatever ``get_closest_label`` returns;
        the annotation implies class indices -- confirm against that helper.
    """
    generated_texts = postprocess_generation_ids(
        input_ids=input_ids,
        output_ids=output_ids,
        num_return_sequences=num_return_sequences,
        tokenizer=tokenizer
    )

    predictions = []
    for sub_generated_texts in generated_texts:
        votes = Counter(
            get_closest_label(gen_text, classes) for gen_text in sub_generated_texts
        )
        # most_common(1) keeps the first-seen candidate when counts tie.
        predictions.append(votes.most_common(1)[0][0])
    return predictions


class SequenceClassificationTask(BaseTask):
    """Evaluation task that classifies prompts by generating the label text.

    The model generates a short continuation for every prompt; the
    continuation is matched against the known class names and accuracy
    against the reference labels is reported.
    """

    def __init__(
        self,
        model,
        tokenizer: PreTrainedTokenizer,
        classes: List[str],
        data_name_or_path: str,
        prompt_col_name: str,
        label_col_name: str,
        device: Optional[str] = None,
        **kwargs
    ):
        """Set up the task and derive the generation length from the classes.

        Args:
            model: Object exposing ``generate`` (used in ``_predict``).
            tokenizer: Tokenizer forwarded to ``BaseTask``.
            classes: Class names; stored lower-cased and stripped in
                ``self.classes``.
            data_name_or_path: Forwarded to ``BaseTask``.
            prompt_col_name: Forwarded to ``BaseTask``.
            label_col_name: Forwarded to ``BaseTask``.
            device: Forwarded to ``BaseTask``.
            **kwargs: Extra ``BaseTask`` arguments. ``merge_prompt_label`` is
                always overridden to ``False``.
        """
        # Always forced off for this task, whatever the caller passed.
        kwargs["merge_prompt_label"] = False
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            data_name_or_path=data_name_or_path,
            prompt_col_name=prompt_col_name,
            label_col_name=label_col_name,
            device=device,
            **kwargs
        )
        self.classes = [each.lower().strip() for each in classes]

        # The tokenizer returns a dict-like BatchEncoding. Iterating it directly
        # yields its *keys* ("input_ids", "attention_mask", ...), so the token
        # sequences must be taken from "input_ids" to measure label lengths.
        classes_ids = self.tokenizer(classes)["input_ids"]
        # Generation budget: token count of the longest class name.
        self.max_new_tokens = max(len(each) for each in classes_ids)

    def _predict(self, batch_data: Dict[str, Any], *args, **kwargs) -> List[int]:
        """Generate for one batch and convert the generations to predictions.

        Args:
            batch_data: Batch providing ``"input_ids"`` and ``"attention_mask"``.
            **kwargs: Must contain ``"generation_config"``; a missing key
                raises ``KeyError``.

        Returns:
            The result of ``get_predictions`` for this batch.
        """
        generation_config = kwargs["generation_config"]
        output_ids = self.model.generate(
            input_ids=batch_data["input_ids"],
            attention_mask=batch_data["attention_mask"],
            generation_config=generation_config
        )
        return get_predictions(
            batch_data["input_ids"],
            output_ids,
            generation_config.num_return_sequences,
            self.tokenizer,
            self.classes
        )

    def _parse_labels(self, label_ids: LongTensor) -> List[int]:
        """Decode reference label ids into indices of ``self.classes``.

        Args:
            label_ids: Batch of label token ids in which ``-100`` marks
                positions to ignore.

        Returns:
            For each row, the index of its decoded text in ``self.classes``.

        Raises:
            ValueError: If a decoded label (lower-cased and stripped) is not
                one of ``self.classes``.
        """
        labels = []
        for one_label_ids in label_ids:
            # Drop the -100 placeholders with a mask so the result is correct
            # whether they sit before or after the real label tokens.
            one_label_ids = one_label_ids[one_label_ids != -100]
            label_text = self.tokenizer.decode(one_label_ids).lower().strip()
            labels.append(self.classes.index(label_text))
        return labels

    def _metric(self, pred: List[int], label: List[int]) -> Dict[str, float]:
        """Compute accuracy as the fraction of positions where pred equals label.

        Args:
            pred: Predicted classes.
            label: Reference classes.

        Returns:
            ``{"acc": accuracy}``.
        """
        pred = np.array(pred)
        label = np.array(label)
        acc = (pred == label).mean()
        return {"acc": acc}

    def run(self, generation_config: Optional[GenerationConfig] = None) -> Dict[str, float]:
        """Run the task and return whatever ``BaseTask.run`` returns.

        Args:
            generation_config: Generation settings. When ``None``, a config
                with ``num_beams=1``, ``do_sample=False`` and
                ``num_return_sequences=1`` is created. A supplied config is
                modified in place: its ``max_new_tokens`` is overwritten with
                the value derived from the class names.

        Returns:
            The result of ``BaseTask.run`` called with the final config.
        """
        if generation_config is None:
            generation_config = GenerationConfig(
                num_beams=1,
                do_sample=False,
                num_return_sequences=1
            )
        generation_config.max_new_tokens = self.max_new_tokens
        return super().run(generation_config=generation_config)