37 lines
1.1 KiB
Python
37 lines
1.1 KiB
Python
import sys
|
|
from typing import List, Sequence
|
|
|
|
import numpy as np
|
|
|
|
|
|
def levenshtein_distance(seq1: Sequence, seq2: Sequence) -> int:
    """Return the Levenshtein (edit) distance between two sequences.

    The distance is the minimum number of single-element insertions,
    deletions and substitutions needed to turn ``seq1`` into ``seq2``.
    Elements are compared with ``==``, so any indexable, sized sequence
    works (strings, lists, tuples, ...).

    Args:
        seq1: Source sequence.
        seq2: Target sequence.

    Returns:
        The edit distance as a non-negative ``int``.
    """
    # Cheap exit for identical inputs; skips the DP entirely.
    if seq1 == seq2:
        return 0

    # With unit costs the distance is symmetric, so put the shorter sequence
    # along the row: only two rows of length len(shorter) + 1 are kept.
    if len(seq1) < len(seq2):
        seq1, seq2 = seq2, seq1

    # previous[j] = distance between the processed prefix of seq1 and seq2[:j].
    previous = list(range(len(seq2) + 1))
    for i, item1 in enumerate(seq1, start=1):
        current = [i]  # distance from seq1[:i] to the empty prefix
        for j, item2 in enumerate(seq2, start=1):
            if item1 == item2:
                current.append(previous[j - 1])
            else:
                # substitution, deletion, insertion
                current.append(min(previous[j - 1], previous[j], current[j - 1]) + 1)
        previous = current

    return previous[-1]
|
|
|
|
|
|
def get_closest_label(pred: Sequence, classes: List[Sequence]) -> int:
    """Return the index of the class label closest to ``pred`` by edit distance.

    Args:
        pred: Predicted sequence to match.
        classes: Candidate label sequences.

    Returns:
        Index into ``classes`` of the label with the smallest Levenshtein
        distance to ``pred``. On ties, the earliest such label wins.

    Raises:
        ValueError: If ``classes`` is empty.
    """
    if not classes:
        # An empty candidate list has no closest label; fail loudly instead of
        # returning a sentinel that looks like a (bogus) index.
        raise ValueError("classes must contain at least one label")

    # min() returns the first minimal item, matching strict '<' tie-breaking.
    best_index, _ = min(
        enumerate(classes),
        key=lambda pair: levenshtein_distance(pred, pair[1]),
    )
    return best_index
|
|
|
|
|
|
# Public API exported by `from <module> import *`.
__all__ = ["levenshtein_distance", "get_closest_label"]
|