# coding=utf-8
# Copyright 2020 The HuggingFace Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" ROUGE metric from Google Research github repo. """ | |
# The dependencies are listed in https://github.com/google-research/google-research/blob/master/rouge/requirements.txt
import absl  # Here to have a nice missing dependency error message early on
import nltk  # Here to have a nice missing dependency error message early on
import numpy  # Here to have a nice missing dependency error message early on
import six  # Here to have a nice missing dependency error message early on
from rouge_score import rouge_scorer, scoring

import datasets
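
# For orientation, a hedged sketch of the `rouge_score` API this metric wraps
# (comments only, not executed; signatures as in the google-research package):
#
#     scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True)
#     scores = scorer.score("the cat sat", "the cat was sat")  # (target, prediction)
#     # -> {"rouge1": Score(precision=..., recall=..., fmeasure=...), ...}
#
#     aggregator = scoring.BootstrapAggregator()
#     aggregator.add_scores(scores)
#     aggregator.aggregate()
#     # -> {"rouge1": AggregateScore(low=Score(...), mid=Score(...), high=Score(...)), ...}
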
_CITATION = """\
@inproceedings{lin-2004-rouge,
    title = "{ROUGE}: A Package for Automatic Evaluation of Summaries",
    author = "Lin, Chin-Yew",
    booktitle = "Text Summarization Branches Out",
    month = jul,
    year = "2004",
    address = "Barcelona, Spain",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/W04-1013",
    pages = "74--81",
}
"""
_DESCRIPTION = """\
ROUGE, or Recall-Oriented Understudy for Gisting Evaluation, is a set of metrics and a software package used for
evaluating automatic summarization and machine translation software in natural language processing.
The metrics compare an automatically produced summary or translation against a reference (or a set of references) produced by humans.
Note that ROUGE is case insensitive, meaning that upper case letters are treated the same way as lower case letters.
This metric is a wrapper around the Google Research reimplementation of ROUGE:
https://github.com/google-research/google-research/tree/master/rouge
"""
_KWARGS_DESCRIPTION = """
Calculates average rouge scores for a list of hypotheses and references
Args:
    predictions: list of predictions to score. Each prediction
        should be a string with tokens separated by spaces.
    references: list of references, one per prediction. Each
        reference should be a string with tokens separated by spaces.
    rouge_types: A list of rouge types to calculate.
        Valid names:
        `"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where {n} is the n-gram order for n-gram based scoring,
        `"rougeL"`: Longest common subsequence based scoring.
        `"rougeLsum"`: rougeLsum splits text using `"\n"`.
        See details in https://github.com/huggingface/datasets/issues/617
    use_stemmer: Bool indicating whether Porter stemmer should be used to strip word suffixes.
    use_aggregator: If True (default), return bootstrap-aggregated scores; otherwise return per-example scores.
Returns:
    rouge1: rouge_1 (precision, recall, f1),
    rouge2: rouge_2 (precision, recall, f1),
    rougeL: rouge_l (precision, recall, f1),
    rougeLsum: rouge_lsum (precision, recall, f1)
Examples:
    >>> rouge = datasets.load_metric('rouge')
    >>> predictions = ["hello there", "general kenobi"]
    >>> references = ["hello there", "general kenobi"]
    >>> results = rouge.compute(predictions=predictions, references=references)
    >>> print(list(results.keys()))
    ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']
    >>> print(results["rouge1"])
    AggregateScore(low=Score(precision=1.0, recall=1.0, fmeasure=1.0), mid=Score(precision=1.0, recall=1.0, fmeasure=1.0), high=Score(precision=1.0, recall=1.0, fmeasure=1.0))
    >>> print(results["rouge1"].mid.fmeasure)
    1.0
"""

class Rouge(datasets.Metric):
    def _info(self):
        return datasets.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Value("string", id="sequence"),
                    "references": datasets.Value("string", id="sequence"),
                }
            ),
            codebase_urls=["https://github.com/google-research/google-research/tree/master/rouge"],
            reference_urls=[
                "https://en.wikipedia.org/wiki/ROUGE_(metric)",
                "https://github.com/google-research/google-research/tree/master/rouge",
            ],
        )
    def _compute(self, predictions, references, rouge_types=None, use_aggregator=True, use_stemmer=False):
        if rouge_types is None:
            rouge_types = ["rouge1", "rouge2", "rougeL", "rougeLsum"]

        scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=use_stemmer)
        if use_aggregator:
            aggregator = scoring.BootstrapAggregator()
        else:
            scores = []

        for ref, pred in zip(references, predictions):
            # `RougeScorer.score` expects (target, prediction) and returns a dict
            # mapping each rouge type to a Score(precision, recall, fmeasure) tuple.
            score = scorer.score(ref, pred)
            if use_aggregator:
                aggregator.add_scores(score)
            else:
                scores.append(score)

        if use_aggregator:
            # Bootstrap resampling produces low/mid/high AggregateScores per rouge type.
            result = aggregator.aggregate()
        else:
            # Re-shape the per-example dicts into one list of Scores per rouge type.
            result = {}
            for key in scores[0]:
                result[key] = [score[key] for score in scores]

        return result
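
# A minimal smoke-test sketch, not part of the metric itself. It assumes the
# `datasets` and `rouge_score` packages are installed and instantiates the
# class above directly instead of going through `datasets.load_metric`.
if __name__ == "__main__":
    rouge = Rouge()
    results = rouge.compute(
        predictions=["hello there", "general kenobi"],
        references=["hello there", "general kenobi"],
    )
    print(results["rouge1"].mid.fmeasure)  # expected: 1.0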