TwT-6's picture
Upload 2667 files
256a159 verified
from opencompass.registry import ICL_EVALUATORS
from opencompass.utils.text_postprocessors import general_postprocess
from .icl_base_evaluator import BaseEvaluator
@ICL_EVALUATORS.register_module()
class EMEvaluator(BaseEvaluator):
"""Exact match evaluator."""
def __init__(self) -> None:
super().__init__()
def score(self, predictions, references):
if len(predictions) != len(references):
return {
'error': 'predictions and references have different '
'length'
}
predictions = [
general_postprocess(prediction) for prediction in predictions
]
processed_answers = [[general_postprocess(j) for j in i]
for i in references]
cnt = 0
details = []
for pred, ans, origin_ans in zip(predictions, processed_answers,
references):
answers = list(set(ans + origin_ans))
detail = {'pred': pred, 'answer': answers}
if pred in ans or pred in origin_ans:
cnt += 1
detail['correct'] = True
else:
detail['correct'] = False
details.append(detail)
score = cnt / len(predictions) * 100
return {'score': score, 'details': details}