ERR / ERR.py
Roman Castagné
metric file
bfe5167
"""Error Reduction Rate (ERR) metric for lexical normalization, with prediction/reference token alignment."""
from typing import List, Tuple
import datasets
import numpy as np
from Levenshtein import distance as levenshtein_distance
from scipy.optimize import linear_sum_assignment
import evaluate
# Short metric description: passed to MetricInfo(description=...) and to
# add_start_docstrings on the metric class below.
_DESCRIPTION = """
Unofficial implementation of the Error Reduction Rate (ERR) metric introduced for lexical normalization.
This implementation works on Seq2Seq models by aligning the predictions with the ground truth outputs.
"""
_KWARGS_DESCRIPTION = """
Args:
predictions (`list` of `str`): Predicted labels.
references (`list` of `Dict[str, str]`): Ground truth sentences, each with a field `input` and `output`.
Returns:
`err` (`float` or `int`): Error Reduction Rate. See here: http://noisy-text.github.io/2021/multi-lexnorm.html
`err_tp` (`int`): Number of true positives.
`err_fn` (`int`): Number of false negatives.
`err_tn` (`int`): Number of true negatives.
`err_fp` (`int`): Number of false positives.
Examples:
Example 1-A simple example
>>> err = evaluate.load("err")
>>> results = err.compute(predictions=[["The", "large", "dog"]], references=[{"input": ["The", "large", "dawg"], "output": ["The", "large", "dog"]}])
>>> print(results)
{'err': 1.0, 'err_tp': 2, 'err_fn': 0, 'err_tn': 1, 'err_fp': 0}
"""
# BibTeX citation reported with the metric: passed to MetricInfo(citation=...).
_CITATION = """
@inproceedings{baldwin-etal-2015-shared,
title = "Shared Tasks of the 2015 Workshop on Noisy User-generated Text: {T}witter Lexical Normalization and Named Entity Recognition",
author = "Baldwin, Timothy and
de Marneffe, Marie Catherine and
Han, Bo and
Kim, Young-Bum and
Ritter, Alan and
Xu, Wei",
booktitle = "Proceedings of the Workshop on Noisy User-generated Text",
month = jul,
year = "2015",
address = "Beijing, China",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/W15-4319",
doi = "10.18653/v1/W15-4319",
pages = "126--135",
}
"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class ErrorReductionRate(evaluate.Metric):
    """Error Reduction Rate for lexical normalization, aligning predicted tokens to gold tokens."""

    def _info(self):
        """Describe the metric and the feature schema expected for its inputs."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    # One token sequence per sentence.
                    "predictions": datasets.Sequence(datasets.Value("string")),
                    # Raw tokens and their gold normalizations, paired by position.
                    "references": {
                        "input": datasets.Sequence(datasets.Value("string")),
                        "output": datasets.Sequence(datasets.Value("string")),
                    },
                }
            ),
        )

    def _compute(self, predictions, references):
        """Count per-token normalization outcomes over all sentences and derive the ERR.

        A raw token "requires normalization" when its lowercased form differs from its
        gold output. For such tokens a correct prediction is a true positive and anything
        else a false negative; for the remaining tokens a correct prediction is a true
        negative and anything else a false positive.

        Returns:
            A dict with ``err`` = (tp - fp) / (tp + fn) and the four raw counts. ``err``
            is NaN when no token requires normalization, because the rate is undefined.

        Raises:
            AssertionError: if a sentence's ``input`` and ``output`` lengths differ, or if
                the number of predicted tokens differs from the number of gold words.
        """
        tp, fn, tn, fp = 0, 0, 0, 0
        for pred, ref in zip(predictions, references):
            inputs, outputs = ref["input"], ref["output"]
            # Each raw token is paired with exactly one gold output (possibly multi-word or empty).
            assert len(inputs) == len(
                outputs
            ), f"Number of input words ({len(inputs)}) does not match number of output words ({len(outputs)})"
            labels = self._split_expressions_into_tokens(outputs)
            assert len(pred) == len(
                labels
            ), f"Number of predicted words ({len(pred)}) does not match number of target words ({len(labels)})"
            formatted_preds = self._align_predictions_with_labels(pred, labels)
            # A gold output that yields no words (whitespace-only) receives no predicted
            # tokens; pad so a trailing one maps to "" like a middle one does, instead of
            # being missing from the list.
            formatted_preds.extend([""] * (len(outputs) - len(formatted_preds)))
            for raw, gold, guess in zip(inputs, outputs, formatted_preds):
                # Only the raw token is lowercased, so a gold output containing uppercase
                # letters always counts as requiring normalization.
                needs_normalization = raw.lower() != gold
                correct = guess == gold
                if needs_normalization:
                    if correct:
                        tp += 1
                    else:
                        fn += 1
                elif correct:
                    tn += 1
                else:
                    fp += 1
        needed = tp + fn
        # With no token requiring normalization the denominator is zero and ERR is undefined.
        err = (tp - fp) / needed if needed else float("nan")
        return {"err": err, "err_tp": tp, "err_fn": fn, "err_tn": tn, "err_fp": fp}

    def _align_predictions_with_labels(self, predictions: List[str], labels: List[Tuple[str, int]]) -> List[str]:
        """Pair predicted tokens with label tokens and regroup them by output segment.

        Builds a label-by-prediction matrix of Levenshtein distances and solves the
        linear sum assignment problem on it, so that each label token is paired with
        one prediction token and the total edit distance is minimal. The paired
        prediction tokens are then joined with spaces per segment, in label order.

        Args:
            predictions: Predicted tokens (``_compute`` passes exactly ``len(labels)``).
            labels: ``(token, segment_index)`` pairs from ``_split_expressions_into_tokens``.

        Returns:
            One string per segment index from 0 to the largest index in ``labels``
            (an empty list when ``labels`` is empty).
        """
        if not labels:
            return []
        levenshtein_matrix = np.zeros((len(labels), len(predictions)))
        for i, (label, _) in enumerate(labels):
            for j, pred in enumerate(predictions):
                levenshtein_matrix[i, j] = levenshtein_distance(label, pred)
        # Rows are labels, columns are predictions: ``pred_idx[k]`` is the prediction
        # assigned to label ``label_idx[k]``. ``label_idx`` is returned sorted, so the
        # loop below appends tokens to each segment in label order.
        label_idx, pred_idx = linear_sum_assignment(levenshtein_matrix)
        num_outputs = max(segment for _, segment in labels) + 1
        grouped: List[List[str]] = [[] for _ in range(num_outputs)]
        for i, j in zip(label_idx, pred_idx):
            grouped[labels[i][1]].append(predictions[j])
        return [" ".join(tokens) for tokens in grouped]

    def _split_expressions_into_tokens(self, outputs: List[str]) -> List[Tuple[str, int]]:
        """Flatten gold outputs into ``(token, segment_index)`` pairs.

        A multi-word output contributes one pair per whitespace-separated word. An
        empty output (presumably a token the gold normalization deletes) contributes a
        single ``("", index)`` pair, so it still expects one predicted token. An output
        made only of whitespace yields no pairs.
        """
        labels = []
        for segment, normalized in enumerate(outputs):
            if normalized == "":
                labels.append((normalized, segment))
            else:
                for w in normalized.split():
                    labels.append((w, segment))
        return labels