# AutoGPTQ/auto_gptq/eval_tasks/sequence_classification_task.py

from collections import Counter
from typing import Any, Dict, List, Optional

import numpy as np
from torch import LongTensor
from transformers import GenerationConfig, PreTrainedTokenizer

from ._base import BaseTask
from ._utils.classification_utils import get_closest_label
from ._utils.generation_utils import postprocess_generation_ids


def get_predictions(
    input_ids: LongTensor,
    output_ids: LongTensor,
    num_return_sequences: int,
    tokenizer: PreTrainedTokenizer,
    classes: List[str]
) -> List[int]:
    """Decode generated ids into texts, map each text to its closest class,
    and majority-vote across the return sequences of every sample."""
    predictions = []
    generated_texts = postprocess_generation_ids(
        input_ids=input_ids,
        output_ids=output_ids,
        num_return_sequences=num_return_sequences,
        tokenizer=tokenizer
    )
    for sub_generated_texts in generated_texts:
        sub_predictions = []
        for gen_text in sub_generated_texts:
            sub_predictions.append(get_closest_label(gen_text.lower().strip(), classes))
        # majority vote over this sample's num_return_sequences candidates
        predictions.append(Counter(sub_predictions).most_common(1)[0][0])
    return predictions
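
# Illustration of the majority vote above (hypothetical generations): with
# num_return_sequences=3 and classes ["negative", "positive"], outputs such as
# ["positive", "positive.", "neg"] map through get_closest_label to
# [1, 1, 0], and Counter(...).most_common(1)[0][0] picks 1 ("positive").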


class SequenceClassificationTask(BaseTask):
    def __init__(
        self,
        model,
        tokenizer: PreTrainedTokenizer,
        classes: List[str],
        data_name_or_path: str,
        prompt_col_name: str,
        label_col_name: str,
        device: Optional[str] = None,
        **kwargs
    ):
        # keep prompts and labels separate so labels can be parsed on their own
        kwargs["merge_prompt_label"] = False
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            data_name_or_path=data_name_or_path,
            prompt_col_name=prompt_col_name,
            label_col_name=label_col_name,
            device=device,
            **kwargs
        )
        self.classes = [each.lower().strip() for each in classes]
        # index into "input_ids": calling the tokenizer on a list returns a
        # BatchEncoding, and iterating it directly would yield its keys
        classes_ids = self.tokenizer(classes)["input_ids"]
        self.max_new_tokens = max(len(each) for each in classes_ids)
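
    # Example (hypothetical lengths): for classes ["negative", "positive"],
    # the tokenizer might yield input_ids of lengths [4, 4], so max_new_tokens
    # becomes 4: just enough budget to spell out the longest class name.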

    def _predict(self, batch_data: Dict[str, Any], *args, **kwargs) -> List[int]:
        generation_config = kwargs["generation_config"]
        output_ids = self.model.generate(
            input_ids=batch_data["input_ids"],
            attention_mask=batch_data["attention_mask"],
            generation_config=generation_config
        )
        return get_predictions(
            batch_data["input_ids"],
            output_ids,
            generation_config.num_return_sequences,
            self.tokenizer,
            self.classes
        )

    def _parse_labels(self, label_ids: LongTensor) -> List[int]:
        labels = []
        for one_label_ids in label_ids:
            # strip the leading -100 (ignore-index) padding before decoding
            one_label_ids = one_label_ids[(one_label_ids == -100).sum():]
            label = self.tokenizer.decode(one_label_ids, clean_up_tokenization_spaces=True).lower().strip()
            label = get_closest_label(label, self.classes)
            labels.append(label)
        return labels

    def _metric(self, pred: List[int], label: List[int]) -> Dict[str, float]:
        pred = np.array(pred)
        label = np.array(label)
        acc = (pred == label).mean()
        return {"acc": acc}

    def run(self, generation_config: Optional[GenerationConfig] = None) -> Dict[str, float]:
        if generation_config is None:
            # default to greedy decoding with a single returned sequence
            generation_config = GenerationConfig(
                num_beams=1,
                do_sample=False,
                num_return_sequences=1
            )
        # always constrain generation to the class-name token budget and the
        # tokenizer's special ids, even for a user-supplied config
        generation_config.max_new_tokens = self.max_new_tokens
        generation_config.eos_token_id = self.tokenizer.eos_token_id
        generation_config.pad_token_id = self.tokenizer.pad_token_id
        return super().run(generation_config=generation_config)


__all__ = ["SequenceClassificationTask"]
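
# A minimal usage sketch, assuming a GPTQ-quantized causal LM and a dataset
# with prompt/label columns; the checkpoint, dataset path, and column names
# below are hypothetical placeholders:
#
#     from transformers import AutoTokenizer
#     from auto_gptq import AutoGPTQForCausalLM
#
#     tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
#     model = AutoGPTQForCausalLM.from_quantized("path/to/quantized-model", device="cuda:0")
#     task = SequenceClassificationTask(
#         model=model,
#         tokenizer=tokenizer,
#         classes=["negative", "positive"],
#         data_name_or_path="path/to/dataset",
#         prompt_col_name="prompt",
#         label_col_name="label",
#         device="cuda:0",
#     )
#     print(task.run())  # -> {"acc": ...}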