# flake8: noqa from . import dataset_loader, utils from .math_equivalence import is_equiv def convert_to_set(item): if isinstance(item, list): return set(item) if isinstance(item, str): return {item} if item is None: return {} raise ValueError("Input can't parse:", item) def evaluate_single_sample(dataset_name, prediction, label): if dataset_name in dataset_loader.multi_choice_datasets: p = convert_to_set(prediction) l = convert_to_set(label) return p == l elif dataset_name in dataset_loader.math_output_datasets: return is_equiv(prediction, label) else: return prediction == label # def evaluate(dataset_name, prediction_list, label_list): # correct = 0 # if dataset_name in multi_choice_datasets: # for prediction, label in zip(prediction_list, label_list): # p = convert_to_set(prediction) # l = convert_to_set(label) # if p == l: # correct += 1 # elif dataset_name in math_output_datasets: # for prediction, label in zip(prediction_list, label_list): # if is_equiv(prediction, label): # correct += 1 # else: # for prediction, label in zip(prediction_list, label_list): # if prediction == label: # correct += 1 # return "{0:.2%}".format(correct / len(label_list))