TwT-6's picture
Upload 2667 files
256a159 verified
# flake8: noqa
from . import dataset_loader, utils
from .math_equivalence import is_equiv
def convert_to_set(item):
    """Normalize an answer into a set so answers can be compared ignoring order.

    Args:
        item: A list (or tuple/set/frozenset) of answer items, a single string
            answer, or None.

    Returns:
        set: The distinct items of a collection, a one-element set for a
        string, or an empty set for None.

    Raises:
        ValueError: If ``item`` is of any other type.
    """
    if isinstance(item, (list, tuple, set, frozenset)):
        return set(item)
    if isinstance(item, str):
        return {item}
    if item is None:
        # Must be set(), not {}: the literal {} is an empty *dict*, which never
        # compares equal to a set, so None vs [] would wrongly mismatch.
        return set()
    raise ValueError(f"Input can't parse: {item!r}")
def evaluate_single_sample(dataset_name, prediction, label):
    """Return whether ``prediction`` matches ``label`` for the given dataset.

    The comparison rule is chosen by dataset membership:

    * datasets in ``dataset_loader.multi_choice_datasets``: both sides are
      normalized with ``convert_to_set`` and compared as sets, so order and
      duplicates are ignored;
    * datasets in ``dataset_loader.math_output_datasets``: delegated to
      ``is_equiv``;
    * anything else: plain ``==`` equality.
    """
    if dataset_name in dataset_loader.multi_choice_datasets:
        # Left operand is converted before the right, as in the original.
        return convert_to_set(prediction) == convert_to_set(label)
    if dataset_name in dataset_loader.math_output_datasets:
        return is_equiv(prediction, label)
    return prediction == label
# def evaluate(dataset_name, prediction_list, label_list):
# correct = 0
# if dataset_name in multi_choice_datasets:
# for prediction, label in zip(prediction_list, label_list):
# p = convert_to_set(prediction)
# l = convert_to_set(label)
# if p == l:
# correct += 1
# elif dataset_name in math_output_datasets:
# for prediction, label in zip(prediction_list, label_list):
# if is_equiv(prediction, label):
# correct += 1
# else:
# for prediction, label in zip(prediction_list, label_list):
# if prediction == label:
# correct += 1
# return "{0:.2%}".format(correct / len(label_list))