"""Token prediction metric.""" from typing import List, Tuple import datasets import numpy as np from Levenshtein import distance as levenshtein_distance from scipy.optimize import linear_sum_assignment import evaluate _DESCRIPTION = """ Unofficial implementation of the Error Reduction Rate (ERR) metric introduced for lexical normalization. This implementation works on Seq2Seq models by aligning the predictions with the ground truth outputs. """ _KWARGS_DESCRIPTION = """ Args: predictions (`list` of `str`): Predicted labels. references (`list` of `Dict[str, str]`): Ground truth sentences, each with a field `input` and `output`. Returns: `err` (`float` or `int`): Error Reduction Rate. See here: http://noisy-text.github.io/2021/multi-lexnorm.html `err_tp` (`int`): Number of true positives. `err_fn` (`int`): Number of false negatives. `err_tn` (`int`): Number of true negatives. `err_fp` (`int`): Number of false positives. Examples: Example 1-A simple example >>> err = evaluate.load("err") >>> results = err.compute(predictions=[["The", "large", "dog"]], references=[{"input": ["The", "large", "dawg"], "output": ["The", "large", "dog"]}]) >>> print(results) {'err': 1.0, 'err_tp': 2, 'err_fn': 0, 'err_tn': 1, 'err_fp': 0} """ _CITATION = """ @inproceedings{baldwin-etal-2015-shared, title = "Shared Tasks of the 2015 Workshop on Noisy User-generated Text: {T}witter Lexical Normalization and Named Entity Recognition", author = "Baldwin, Timothy and de Marneffe, Marie Catherine and Han, Bo and Kim, Young-Bum and Ritter, Alan and Xu, Wei", booktitle = "Proceedings of the Workshop on Noisy User-generated Text", month = jul, year = "2015", address = "Beijing, China", publisher = "Association for Computational Linguistics", url = "https://aclanthology.org/W15-4319", doi = "10.18653/v1/W15-4319", pages = "126--135", } """ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION) class ErrorReductionRate(evaluate.Metric): def _info(self): return evaluate.MetricInfo( description=_DESCRIPTION, citation=_CITATION, inputs_description=_KWARGS_DESCRIPTION, features=datasets.Features( { "predictions": datasets.Sequence(datasets.Value("string")), "references": { "input": datasets.Sequence(datasets.Value("string")), "output": datasets.Sequence(datasets.Value("string")), }, } ), ) def _compute(self, predictions, references): tp, fn, tn, fp = 0, 0, 0, 0 for pred, ref in zip(predictions, references): inputs, outputs = ref["input"], ref["output"] labels = self._split_expressions_into_tokens(outputs) assert len(pred) == len( labels ), f"Number of predicted words ({len(pred)}) does not match number of target words ({len(labels)})" formatted_preds = self._align_predictions_with_labels(pred, labels) for i in range(len(inputs)): # Normalization was necessary if inputs[i].lower() != outputs[i]: tp += formatted_preds[i] == outputs[i] fn += formatted_preds[i] != outputs[i] else: tn += formatted_preds[i] == outputs[i] fp += formatted_preds[i] != outputs[i] err = (tp - fp) / (tp + fn) return {"err": err, "err_tp": tp, "err_fn": fn, "err_tn": tn, "err_fp": fp} def _align_predictions_with_labels(self, predictions: List[str], labels: List[Tuple[str, int]]) -> List[str]: levenshtein_matrix = np.zeros((len(labels), len(predictions))) for i, (label, _) in enumerate(labels): for j, pred in enumerate(predictions): levenshtein_matrix[i, j] = levenshtein_distance(label, pred) col_alignment, row_alignment = linear_sum_assignment(levenshtein_matrix) alignment = sorted(row_alignment, key=lambda i: col_alignment[i]) 
        # Regroup the aligned predictions by output segment, so that the words
        # of a multi-word normalization are joined back into a single string.
        num_outputs = max(segment for _, segment in labels) + 1
        formatted_preds = [[] for _ in range(num_outputs)]
        for label_idx, pred_idx in zip(row_alignment, col_alignment):
            formatted_preds[labels[label_idx][1]].append(predictions[pred_idx])
        return [" ".join(preds) for preds in formatted_preds]

    def _split_expressions_into_tokens(self, outputs: List[str]) -> List[Tuple[str, int]]:
        # 1-to-n replacements such as "new york" yield one (word, segment) pair
        # per word, all tagged with the index of the originating output
        # expression; empty outputs (deletions) keep a single empty label.
        labels = []
        for segment, normalized in enumerate(outputs):
            if normalized == "":
                labels.append((normalized, segment))
            else:
                for w in normalized.split():
                    labels.append((w, segment))
        return labels
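

# A minimal usage sketch, not part of the metric itself: it instantiates the
# class directly instead of going through `evaluate.load(...)` and runs one
# illustrative sentence with a 1-to-n normalization ("lol" -> "laughing out
# loud") to exercise the alignment step.
if __name__ == "__main__":
    metric = ErrorReductionRate()

    predictions = [["The", "large", "dog", "is", "laughing", "out", "loud"]]
    references = [
        {
            "input": ["The", "large", "dawg", "is", "lol"],
            "output": ["The", "large", "dog", "is", "laughing out loud"],
        }
    ]

    # The segmentation helper tags every target word with the index of the
    # output expression it came from, e.g. ('laughing', 4), ('out', 4), ('loud', 4).
    print(metric._split_expressions_into_tokens(references[0]["output"]))

    # "dawg" and "lol" are normalized correctly, and "The" also counts as a
    # true positive because the input (unlike the output) is lowercased before
    # the comparison; expected:
    # {'err': 1.0, 'err_tp': 3, 'err_fn': 0, 'err_tn': 2, 'err_fp': 0}
    print(metric.compute(predictions=predictions, references=references))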