"""TODO: Add a description here.""" |
|
|
|
from operator import eq |
|
from typing import Callable, Iterable, Union |
|
|
|
import evaluate |
|
import datasets |
|
import numpy as np |
|
import logging |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
_CITATION = """\
@InProceedings{huggingface:module,
title = {A great new module},
authors={huggingface, Inc.},
year={2020}
}
"""

_DESCRIPTION = """\
Computes precision, recall, and F1 scores for the joint entity-relation extraction task.
"""

_KWARGS_DESCRIPTION = """
Scores predicted triplets against reference triplets and reports the mean precision,
recall, and F1 over all examples.
Args:
    predictions: list of predictions to score. Each prediction is a list of extracted
        triplets, each given as a string such as "head | relation | tail".
    references: list of references, one per prediction. Each reference is a list of
        gold triplets in the same string format.
    eq_fn: function used to decide whether two triplets match. Defaults to the equality operator.
Returns:
    mean_precision: fraction of predicted triplets found in the reference, averaged over examples.
    mean_recall: fraction of reference triplets found in the prediction, averaged over examples.
    mean_f1: harmonic mean of precision and recall, averaged over examples.
Examples:
    >>> jer = evaluate.load("jer")
    >>> results = jer.compute(references=[["Baris | play | tennis", "Deniz | travel | London"]], predictions=[["Baris | play | tennis"]])
    >>> print(results)
    {'mean_precision': 1.0, 'mean_recall': 0.5, 'mean_f1': 0.6666666666666666}
"""

# A triplet may be represented as a raw string (e.g. "head | relation | tail"), a tuple, or an int.
Triplet = Union[str, tuple, int]


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class jer(evaluate.Metric):
    """Precision/recall/F1 metric for joint entity-relation extraction."""

    def _info(self):
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features({
                'predictions': datasets.features.Sequence(datasets.Value('string')),
                'references': datasets.features.Sequence(datasets.Value('string')),
            }),
            homepage="http://module.homepage",
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
            reference_urls=["http://path.to.reference.url/new_module"]
        )

    def _download_and_prepare(self, dl_manager):
        """Optional: download external resources useful to compute the scores."""
        pass

    def _compute(self, predictions, references, eq_fn=eq):
        """Returns the mean precision, recall, and F1 over all prediction/reference pairs."""
        score_dicts = [
            self._compute_single(prediction=prediction, reference=reference, eq_fn=eq_fn)
            for prediction, reference in zip(predictions, references)
        ]
        return {
            'mean_' + key: np.mean([scores[key] for scores in score_dicts])
            for key in score_dicts[0].keys()
        }

    def _compute_single(
        self,
        *,
        prediction: Iterable[Triplet],
        reference: Iterable[Triplet],
        eq_fn: Callable[[Triplet, Triplet], bool],
    ):
        reference_set = set(reference)
        if len(reference) != len(reference_set):
            logger.warning(f"Duplicates found in the reference list {reference}")
        prediction_set = set(prediction)

        # Count each distinct reference triplet at most once so duplicates cannot
        # inflate the true-positive count.
        tp = sum(int(is_in(item, prediction_set, eq_fn=eq_fn)) for item in reference_set)
        fp = len(prediction_set) - tp
        fn = len(reference_set) - tp

        precision = tp / (tp + fp) if tp + fp > 0 else 0
        recall = tp / (tp + fn) if tp + fn > 0 else 0
        f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0

        return {
            'precision': precision,
            'recall': recall,
            'f1': f1_score
        }


def is_in(target, collection: Iterable, eq_fn: Callable = eq) -> bool:
    """Return True if `target` matches any item in `collection` according to `eq_fn`."""
    for item in collection:
        if eq_fn(item, target):
            return True
    return False


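if __name__ == "__main__":
    # Minimal local smoke test (a sketch, not part of the metric interface): instantiate
    # the class directly instead of going through `evaluate.load("jer")`, and pass a
    # hypothetical `case_insensitive_eq` comparison as `eq_fn`. This assumes `compute`
    # forwards extra keyword arguments to `_compute`, as other evaluate metrics do for
    # options such as `average`.
    def case_insensitive_eq(a: Triplet, b: Triplet) -> bool:
        # Compare string triplets ignoring case and surrounding whitespace;
        # fall back to plain equality for any other representation.
        if isinstance(a, str) and isinstance(b, str):
            return a.strip().lower() == b.strip().lower()
        return a == b

    metric = jer()
    results = metric.compute(
        predictions=[["Baris | play | Tennis"]],
        references=[["baris | play | tennis", "Deniz | travel | London"]],
        eq_fn=case_insensitive_eq,
    )
    # Expected: mean_precision 1.0, mean_recall 0.5, mean_f1 ~0.667
    print(results)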