from typing import List

import numpy as np
from sklearn.metrics import roc_auc_score

from opencompass.registry import ICL_EVALUATORS

from .icl_base_evaluator import BaseEvaluator


@ICL_EVALUATORS.register_module()
class AUCROCEvaluator(BaseEvaluator):
"""Calculate AUC-ROC scores and accuracy according the prediction.
For some dataset, the accuracy cannot reveal the difference between
models because of the saturation. AUC-ROC scores can further exam
model abilities to distinguish different labels. More details can refer to
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.roc_auc_score.html
""" # noqa
def __init__(self) -> None:
super().__init__()
    def score(self, predictions: List, references: List) -> dict:
        """Calculate AUC-ROC score and accuracy.

        Args:
            predictions (List): List of per-class probabilities for each
                sample.
            references (List): List of target labels for each sample.

        Returns:
            dict: Calculated scores.
        """
        if len(predictions) != len(references):
            return {
                'error': 'predictions and references have different lengths.'
            }
        # Binary AUC-ROC expects the probability of the positive class,
        # i.e. column 1 of the (n_samples, 2) prediction matrix.
        auc_score = roc_auc_score(references, np.array(predictions)[:, 1])
        # Accuracy: fraction of samples whose argmax class matches the
        # reference label.
        accuracy = sum(
            np.array(references) == np.argmax(predictions, axis=1)) / len(
                references)
        return dict(auc_score=auc_score * 100, accuracy=accuracy * 100)
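

# A minimal usage sketch, assuming this module is importable as part of
# opencompass. The toy probabilities and labels below are made up for
# illustration and are not from any real dataset:
#
#     evaluator = AUCROCEvaluator()
#     preds = [[0.9, 0.1], [0.2, 0.8], [0.4, 0.6]]  # per-class probabilities
#     refs = [0, 1, 0]                              # target labels
#     print(evaluator.score(preds, refs))
#     # -> auc_score 100.0, accuracy ~66.7 on this toy data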