Roman Castagné committed on
Commit bfe5167
1 Parent(s): 9b92850

metric file

Files changed (1)
  1. ERR.py +132 -0
ERR.py ADDED
"""Token prediction metric."""

from typing import List, Tuple

import datasets
import numpy as np
from Levenshtein import distance as levenshtein_distance
from scipy.optimize import linear_sum_assignment

import evaluate


_DESCRIPTION = """
Unofficial implementation of the Error Reduction Rate (ERR) metric introduced for lexical normalization.
ERR measures word-level accuracy relative to the leave-as-is baseline: ERR = (TP - FP) / (TP + FN).
This implementation supports Seq2Seq models by aligning token-level predictions with the ground-truth outputs.
"""


_KWARGS_DESCRIPTION = """
Args:
    predictions (`list` of `list` of `str`): Predicted tokens for each sentence.
    references (`list` of `dict`): Ground truth sentences, each with a field `input` (raw tokens) and `output` (normalized tokens).
Returns:
    `err` (`float`): Error Reduction Rate. See here: http://noisy-text.github.io/2021/multi-lexnorm.html
    `err_tp` (`int`): Number of true positives.
    `err_fn` (`int`): Number of false negatives.
    `err_tn` (`int`): Number of true negatives.
    `err_fp` (`int`): Number of false positives.
Examples:
    Example 1-A simple example
        >>> err = evaluate.load("err")
        >>> results = err.compute(predictions=[["The", "large", "dog"]], references=[{"input": ["The", "large", "dawg"], "output": ["The", "large", "dog"]}])
        >>> print(results)
        {'err': 1.0, 'err_tp': 2, 'err_fn': 0, 'err_tn': 1, 'err_fp': 0}
"""


_CITATION = """
@inproceedings{baldwin-etal-2015-shared,
    title = "Shared Tasks of the 2015 Workshop on Noisy User-generated Text: {T}witter Lexical Normalization and Named Entity Recognition",
    author = "Baldwin, Timothy  and
      de Marneffe, Marie Catherine  and
      Han, Bo  and
      Kim, Young-Bum  and
      Ritter, Alan  and
      Xu, Wei",
    booktitle = "Proceedings of the Workshop on Noisy User-generated Text",
    month = jul,
    year = "2015",
    address = "Beijing, China",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/W15-4319",
    doi = "10.18653/v1/W15-4319",
    pages = "126--135",
}
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class ErrorReductionRate(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(datasets.Value("string")),
                    "references": {
                        "input": datasets.Sequence(datasets.Value("string")),
                        "output": datasets.Sequence(datasets.Value("string")),
                    },
                }
            ),
        )
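
    # Example reference under this schema (hypothetical values): a
    # MultiLexNorm-style many-to-one annotation, where the merged word
    # carries the full normalization and the following source token maps
    # to the empty string:
    #   {"input": ["any", "more"], "output": ["anymore", ""]}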

    def _compute(self, predictions, references):
        tp, fn, tn, fp = 0, 0, 0, 0
        for pred, ref in zip(predictions, references):
            inputs, outputs = ref["input"], ref["output"]

            labels = self._split_expressions_into_tokens(outputs)

            assert len(pred) == len(
                labels
            ), f"Number of predicted words ({len(pred)}) does not match number of target words ({len(labels)})"

            formatted_preds = self._align_predictions_with_labels(pred, labels)

            for i in range(len(inputs)):
                # Normalization was necessary
                if inputs[i].lower() != outputs[i]:
                    tp += formatted_preds[i] == outputs[i]
                    fn += formatted_preds[i] != outputs[i]
                # The word should be copied over unchanged
                else:
                    tn += formatted_preds[i] == outputs[i]
                    fp += formatted_preds[i] != outputs[i]

        # ERR as defined for the W-NUT shared tasks; note that it is
        # undefined (division by zero) when no reference word needs
        # normalization.
        err = (tp - fp) / (tp + fn)

        return {"err": err, "err_tp": tp, "err_fn": fn, "err_tn": tn, "err_fp": fp}
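
    # Worked example of the formula above, using the numbers from the
    # docstring: tp=2, fp=0, fn=0 gives err = (2 - 0) / (2 + 0) = 1.0.
    # err is 1.0 for a perfect system, 0.0 for a system no better than
    # copying the input, and negative when more words are corrupted than
    # repaired.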

    def _align_predictions_with_labels(self, predictions: List[str], labels: List[Tuple[str, int]]) -> List[str]:
        # Cost matrix: Levenshtein distance between every label token and
        # every predicted token.
        levenshtein_matrix = np.zeros((len(labels), len(predictions)))
        for i, (label, _) in enumerate(labels):
            for j, pred in enumerate(predictions):
                levenshtein_matrix[i, j] = levenshtein_distance(label, pred)

        # Hungarian algorithm: the one-to-one assignment of predictions to
        # labels that minimizes the total edit distance. scipy returns
        # (row_indices, col_indices); reorder so that alignment[i] is the
        # prediction matched with labels[i].
        row_alignment, col_alignment = linear_sum_assignment(levenshtein_matrix)
        alignment = col_alignment[np.argsort(row_alignment)]

        # Group the aligned predictions back into one string per source token.
        num_outputs = max(segment for _, segment in labels) + 1
        formatted_preds = [[] for _ in range(num_outputs)]
        for i, aligned_idx in enumerate(alignment):
            formatted_preds[labels[i][1]].append(predictions[aligned_idx])

        return [" ".join(preds) for preds in formatted_preds]
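
    # Illustration (not from the original commit): with labels
    # [("new", 0), ("york", 0)] and predictions ["york", "new"], the cost
    # matrix is [[4, 0], [0, 4]], so the solver pairs label "new" with
    # prediction 1 and label "york" with prediction 0; both tokens are then
    # grouped back into segment 0 as "new york".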

    def _split_expressions_into_tokens(self, outputs: List[str]) -> List[Tuple[str, int]]:
        # Split multi-word normalizations (e.g. "can not") into single
        # tokens, remembering the index of the source token ("segment")
        # each piece came from. Empty normalizations are kept as-is so
        # that every source token yields at least one label.
        labels = []
        for segment, normalized in enumerate(outputs):
            if normalized == "":
                labels.append((normalized, segment))
            else:
                for w in normalized.split():
                    labels.append((w, segment))

        return labels
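
A quick usage sketch, not part of the commit: assuming the script is saved locally as ERR.py, evaluate.load can point at the file directly. The reference below combines a one-to-many normalization ("lol" -> "laughing out loud") with a many-to-one one ("any more" -> "anymore", leaving an empty string) to exercise the alignment step.

import evaluate

err = evaluate.load("ERR.py")  # local path to the script above

results = err.compute(
    predictions=[["laughing", "out", "loud", "i", "can", "not", "anymore", ""]],
    references=[{
        "input": ["lol", "i", "cant", "any", "more"],
        "output": ["laughing out loud", "i", "can not", "anymore", ""],
    }],
)
print(results)
# should yield: {'err': 1.0, 'err_tp': 4, 'err_fn': 0, 'err_tn': 1, 'err_fp': 0}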