AutoGPTQ/auto_gptq/eval_tasks/_utils/classification_utils.py
2023-04-23 16:11:54 +08:00

37 lines
1.1 KiB
Python

import sys
from typing import List, Sequence
import numpy as np
def levenshtein_distance(seq1: Sequence, seq2: Sequence) -> int:
    """Return the Levenshtein (edit) distance between two sequences.

    The distance is the minimum number of single-element insertions,
    deletions and substitutions needed to turn ``seq1`` into ``seq2``.
    Elements are compared with ``==``, so any indexable sequences (strings,
    lists of token ids, tuples, ...) can be used.

    Args:
        seq1: The first sequence.
        seq2: The second sequence.

    Returns:
        The edit distance as a plain ``int`` (``0`` when the sequences are
        equal).
    """
    # Cheap shortcut that avoids the quadratic work for identical inputs.
    if seq1 == seq2:
        return 0

    # Only the previous DP row is needed to compute the current one, so keep
    # two rows of plain ints instead of a full float matrix. Row 0 is the
    # cost of building each prefix of ``seq2`` from an empty sequence.
    previous_row = list(range(len(seq2) + 1))
    for i, item1 in enumerate(seq1, start=1):
        # Column 0 is the cost of deleting the first ``i`` items of ``seq1``.
        current_row = [i]
        for j, item2 in enumerate(seq2, start=1):
            if item1 == item2:
                # Matching elements cost nothing extra.
                cost = previous_row[j - 1]
            else:
                cost = 1 + min(
                    previous_row[j - 1],  # substitution
                    previous_row[j],      # deletion from seq1
                    current_row[j - 1],   # insertion into seq1
                )
            current_row.append(cost)
        previous_row = current_row

    return previous_row[-1]
def get_closest_label(pred: Sequence, classes: List[Sequence]) -> int:
    """Return the index of the class label closest to ``pred``.

    Closeness is measured with ``levenshtein_distance``. When several labels
    are equally close, the index of the first one is returned.

    Args:
        pred: The predicted sequence.
        classes: Candidate class labels to compare against.

    Returns:
        The position in ``classes`` of the label with the smallest edit
        distance to ``pred``. If ``classes`` is empty, ``sys.maxsize`` is
        returned.
    """
    distances = [levenshtein_distance(pred, label) for label in classes]
    if not distances:
        # No candidates: fall back to the sentinel value.
        return sys.maxsize
    # ``list.index`` yields the first occurrence, so ties go to the earliest label.
    return distances.index(min(distances))
# Public API of this module: the names exported by a wildcard import.
__all__ = ["levenshtein_distance", "get_closest_label"]