AutoGPTQ/auto_gptq/eval_tasks/sequence_classification_task.py

import sys
from collections import Counter
from typing import Any, Dict, List, Optional
import numpy as np
from torch import LongTensor
from transformers import PreTrainedTokenizer, GenerationConfig
from ._base import BaseTask


def levenshtein_distance(str1: str, str2: str):
    """Edit distance between two strings, computed with a dynamic-programming matrix."""
    if str1 == str2:
        return 0
    num_rows = len(str1) + 1
    num_cols = len(str2) + 1
    dp_matrix = np.empty((num_rows, num_cols))
    dp_matrix[0, :] = range(num_cols)
    dp_matrix[:, 0] = range(num_rows)
    for i in range(1, num_rows):
        for j in range(1, num_cols):
            if str1[i - 1] == str2[j - 1]:
                # Characters match: no extra edit needed.
                dp_matrix[i, j] = dp_matrix[i - 1, j - 1]
            else:
                # Cheapest of substitution, deletion, or insertion, plus one edit.
                dp_matrix[i, j] = min(dp_matrix[i - 1, j - 1], dp_matrix[i - 1, j], dp_matrix[i, j - 1]) + 1
    return dp_matrix[num_rows - 1, num_cols - 1]


def get_closest_label(pred: str, classes: List[str]) -> int:
    """Return the index of the class label with the smallest edit distance to `pred`."""
    min_id = sys.maxsize
    min_edit_distance = sys.maxsize
    for i, class_label in enumerate(classes):
        edit_distance = levenshtein_distance(pred, class_label)
        if edit_distance < min_edit_distance:
            min_id = i
            min_edit_distance = edit_distance
    return min_id


def get_predictions(
    input_ids: LongTensor,
    output_ids: LongTensor,
    num_return_sequences: int,
    tokenizer: PreTrainedTokenizer,
    classes: List[str]
) -> List[int]:
    predictions = []
    # output_ids holds num_return_sequences generations per input sample;
    # walk through it in chunks and vote on a class index for each sample.
    for idx, start in enumerate(range(0, len(output_ids), num_return_sequences)):
        sub_output_ids = output_ids[start: start + num_return_sequences]
        # Drop the prompt tokens so only the newly generated tokens are decoded.
        sub_generated_ids = sub_output_ids[..., input_ids[idx].size(0):]
        sub_generated_texts = [
            each.lower().strip() for each in tokenizer.batch_decode(
                sub_generated_ids,
                clean_up_tokenization_spaces=True
            )
        ]
        sub_predictions = []
        for gen_text in sub_generated_texts:
            sub_predictions.append(get_closest_label(gen_text, classes))
        # Majority vote across the returned sequences for this sample.
        predictions.append(Counter(sub_predictions).most_common(1)[0][0])
    return predictions
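

# Note: the edit-distance matching above is what makes free-form generations
# usable as classification outputs. For example (illustrative values), with
# classes = ["negative", "positive"], a generation decoded as " Positive." is
# lower-cased and stripped to "positive.", which get_closest_label maps to
# index 1: its edit distance to "positive" is 1, versus 5 to "negative".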


class SequenceClassificationTask(BaseTask):
    def __init__(
        self,
        model,
        tokenizer: PreTrainedTokenizer,
        classes: List[str],
        data_name_or_path: str,
        prompt_col_name: str,
        label_col_name: str,
        device: Optional[str] = None,
        **kwargs
    ):
        kwargs["merge_prompt_label"] = False
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            data_name_or_path=data_name_or_path,
            prompt_col_name=prompt_col_name,
            label_col_name=label_col_name,
            device=device,
            **kwargs
        )
        self.classes = [each.lower().strip() for each in classes]
        # Budget enough new tokens for the longest class label to be generated in full.
        classes_ids = self.tokenizer(classes)["input_ids"]
        self.max_new_tokens = max([len(each) for each in classes_ids])

    def _predict(self, batch_data: Dict[str, Any], *args, **kwargs) -> List[int]:
        generation_config = kwargs["generation_config"]
        output_ids = self.model.generate(
            input_ids=batch_data["input_ids"],
            attention_mask=batch_data["attention_mask"],
            generation_config=generation_config
        )
        return get_predictions(
            batch_data["input_ids"],
            output_ids,
            generation_config.num_return_sequences,
            self.tokenizer,
            self.classes
        )

    def _parse_labels(self, label_ids: LongTensor) -> List[int]:
        labels = []
        for one_label_ids in label_ids:
            # Skip the leading -100 padding, then decode the label text and look up its class index.
            one_label_ids = one_label_ids[(one_label_ids == -100).sum():]
            label = self.classes.index(self.tokenizer.decode(one_label_ids).lower().strip())
            labels.append(label)
        return labels

    def _metric(self, pred: List[int], label: List[int]) -> Dict[str, float]:
        pred = np.array(pred)
        label = np.array(label)
        acc = (pred == label).mean()
        return {"acc": acc}

    def run(self, generation_config: Optional[GenerationConfig] = None) -> Dict[str, float]:
        if not generation_config:
            # Default to greedy decoding with a single returned sequence.
            generation_config = GenerationConfig(
                num_beams=1,
                do_sample=False,
                num_return_sequences=1
            )
        generation_config.max_new_tokens = self.max_new_tokens
        return super().run(generation_config=generation_config)
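

# A minimal usage sketch (not part of the module). The model/dataset paths and
# column names below are placeholders, and BaseTask may require additional
# dataset-related keyword arguments beyond the ones shown here:
#
#     from transformers import AutoTokenizer
#     from auto_gptq import AutoGPTQForCausalLM
#     from auto_gptq.eval_tasks import SequenceClassificationTask
#
#     tokenizer = AutoTokenizer.from_pretrained("path/to/quantized-model")
#     model = AutoGPTQForCausalLM.from_quantized("path/to/quantized-model", device="cuda:0")
#     task = SequenceClassificationTask(
#         model=model,
#         tokenizer=tokenizer,
#         classes=["negative", "positive"],
#         data_name_or_path="path/or/name/of/dataset",
#         prompt_col_name="prompt",
#         label_col_name="label",
#         device="cuda:0",
#     )
#     metrics = task.run()  # -> {"acc": <accuracy>}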