# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, Optional

import numpy as np

from ...extras.misc import numpify


if TYPE_CHECKING:
    from transformers import EvalPrediction


@dataclass
class ComputeAccuracy:
    r"""
    Computes reward accuracy and supports `batch_eval_metrics`.
    """

    def _dump(self) -> Optional[Dict[str, float]]:
        # Average the accumulated per-sample results, then reset the accumulator.
        result = None
        if hasattr(self, "score_dict"):
            result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}

        self.score_dict = {"accuracy": []}
        return result

    def __post_init__(self):
        # Initialize the score accumulator.
        self._dump()

    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
        chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
        if not chosen_scores.shape:  # scalar predictions (single sample)
            self.score_dict["accuracy"].append(chosen_scores > rejected_scores)
        else:  # batched predictions: compare each chosen/rejected pair
            for i in range(len(chosen_scores)):
                self.score_dict["accuracy"].append(chosen_scores[i] > rejected_scores[i])

        if compute_result:
            return self._dump()
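

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal sanity check showing how `ComputeAccuracy` accumulates per-batch
# results and reports them when `compute_result=True`, mirroring how a
# `Trainer` with `batch_eval_metrics` enabled calls its `compute_metrics`
# callable once per batch. The synthetic scores and the `SimpleNamespace`
# stand-in for `EvalPrediction` below are assumptions for illustration only.
if __name__ == "__main__":
    from types import SimpleNamespace

    metric = ComputeAccuracy()

    # Two synthetic batches of (chosen, rejected) reward scores; the chosen
    # score wins in 3 of the 4 pairs, so the expected accuracy is 0.75.
    batches = [
        (np.array([1.0, 0.2]), np.array([0.5, 0.8])),
        (np.array([2.0, 1.5]), np.array([1.0, 0.5])),
    ]
    for step, (chosen, rejected) in enumerate(batches):
        eval_preds = SimpleNamespace(predictions=(chosen, rejected))
        result = metric(eval_preds, compute_result=(step == len(batches) - 1))

    print(result)  # {'accuracy': 0.75}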