""" Official evaluation script for v1.1 of the SQuAD dataset. """ | |
import argparse | |
import json | |
import re | |
import string | |
import sys | |
from collections import Counter | |
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

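# For illustration: normalize_answer("The  Denver Broncos!") -> "denver broncos"
# (lowercased, punctuation dropped, articles removed, whitespace collapsed).
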
def recall_score(prediction, ground_truth):
    """Token-level recall: fraction of ground-truth tokens present in the prediction."""
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    recall = 1.0 * num_same / len(ground_truth_tokens)
    return recall

def precision_score(prediction, ground_truth):
    """Token-level precision: fraction of prediction tokens present in the ground truth."""
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    return precision

# def f1_score(prediction, ground_truth):
#     prediction_tokens = normalize_answer(prediction).split()
#     ground_truth_tokens = normalize_answer(ground_truth).split()
#     common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
#     num_same = sum(common.values())
#     if num_same == 0:
#         return 0
#     precision = 1.0 * num_same / len(prediction_tokens)
#     recall = 1.0 * num_same / len(ground_truth_tokens)
#     f1 = (2 * precision * recall) / (precision + recall)
#     return f1
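# Note: F1 is instead computed inline in compute_score(), from the maximum
# recall and maximum precision over all ground truths for each question.
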
def exact_match_score(prediction, ground_truth):
    """Exact string match after normalization."""
    return normalize_answer(prediction) == normalize_answer(ground_truth)

def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Return the best metric_fn score of the prediction against any ground truth."""
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)

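# For example, with ground truths ["Denver Broncos", "Broncos"], the prediction
# "Broncos" gets exact_match 1, since it matches the second reference exactly.
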
def compute_score(dataset, predictions):
    """Compute exact match, F1, precision and recall (in percent) over the whole dataset."""
    recall = precision = f1 = exact_match = total = 0
    for article in dataset:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                total += 1
                if qa["id"] not in predictions:
                    message = "Unanswered question " + qa["id"] + " will receive score 0."
                    print(message, file=sys.stderr)
                    continue
                ground_truths = list(map(lambda x: x["text"], qa["answers"]))
                prediction = predictions[qa["id"]]
                exact_match += metric_max_over_ground_truths(exact_match_score, prediction, ground_truths)
                recall_temp = metric_max_over_ground_truths(recall_score, prediction, ground_truths)
                precision_temp = metric_max_over_ground_truths(precision_score, prediction, ground_truths)
                if recall_temp + precision_temp == 0:
                    f1_temp = 0
                else:
                    f1_temp = (2 * precision_temp * recall_temp) / (precision_temp + recall_temp)
                f1 += f1_temp
                recall += recall_temp
                precision += precision_temp
    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total
    precision = 100.0 * precision / total
    recall = 100.0 * recall / total
    return {"exact_match": exact_match, "f1": f1, "recall": recall, "precision": precision}

if __name__ == "__main__":
    expected_version = "1.1"
    parser = argparse.ArgumentParser(description="Evaluation for SQuAD " + expected_version)
    parser.add_argument("dataset_file", help="Dataset file")
    parser.add_argument("prediction_file", help="Prediction File")
    args = parser.parse_args()
    with open(args.dataset_file) as dataset_file:
        dataset_json = json.load(dataset_file)
        if dataset_json["version"] != expected_version:
            print(
                "Evaluation expects v-" + expected_version + ", but got dataset with v-" + dataset_json["version"],
                file=sys.stderr,
            )
        dataset = dataset_json["data"]
    with open(args.prediction_file) as prediction_file:
        predictions = json.load(prediction_file)
    print(json.dumps(compute_score(dataset, predictions)))
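
# Example invocation (file names are illustrative):
#   python evaluate-v1.1.py dev-v1.1.json predictions.json
# where predictions.json maps each question id to a single answer string, e.g.
#   {"<question_id>": "Denver Broncos"}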