Spaces:
Sleeping
Sleeping
import logging | |
import pandas as pd | |
import os | |
import csv | |
import src.envs as envs | |
from src.backend.model_operations import SummaryGenerator, EvaluationModel | |
import src.backend.util as util | |
# Configure the root logger at import time: records at INFO and above,
# each prefixed with a timestamp and the level name. (basicConfig is a
# no-op if the root logger already has handlers.)
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
class Evaluator:
    """Generate responses with a language model and evaluate how human-like they are.

    Attributes:
        model (str): The name or path of the model.
        revision (str): The model revision.
        precision (str): The precision setting of the model.
        batch_size (int): Batch size for processing.
        device (str): The device to run the model on.
        no_cache (bool): Flag to disable caching.
        limit (int): Limit on the number of items to process.
        write_out (bool): Whether to write results to a file.
        output_base_path (str): Base path for output files.
        summary_generator (SummaryGenerator): Instance for generating responses.
        eval_model (EvaluationModel): Instance for evaluating responses.
        eval_results: Per-item scores written by ``write_results``; ``None``
            until populated (``evaluate`` does not currently populate it).
        generated_summaries_df: Generated responses; set by ``evaluate``.
        humanlike: Score returned by ``eval_model.evaluate_humanlike``; set
            by ``evaluate``.
    """

    def __init__(self, model, revision, precision, batch_size,
                 device, no_cache, limit, write_out=True,
                 output_base_path='logs'):
        """Initialize the Evaluator with the given model and settings.

        Args:
            model (str): The name or path of the model.
            revision (str): The model revision.
            precision (str): The precision setting of the model.
            batch_size (int): Batch size for processing.
            device (str): The device to run the model on.
            no_cache (bool): Flag to disable caching.
            limit (int): Limit on the number of items to process.
            write_out (bool): Whether to write results to a file.
            output_base_path (str): Base path for output files.

        Raises:
            Exception: Any error raised while constructing the
                SummaryGenerator or EvaluationModel is logged and re-raised.
        """
        self.model = model
        self.revision = revision
        self.precision = precision
        self.batch_size = batch_size
        self.device = device
        self.no_cache = no_cache
        self.limit = limit
        self.write_out = write_out
        self.output_base_path = output_base_path
        # Per-item scores consumed by write_results(). The hallucination
        # scoring step that used to fill this is disabled in evaluate().
        self.eval_results = None
        try:
            self.summary_generator = SummaryGenerator(model, revision)
            self.eval_model = EvaluationModel(envs.HEM_PATH)
        except Exception as e:
            logging.error(f"Error initializing Evaluator: {e}")
            raise

    def evaluate(self):
        """Generate responses for the dataset and score their human-likeness.

        Returns:
            dict: The result of ``util.format_results``. The human-likeness
            score is passed through its ``hallucination_rate`` field. The
            factual-consistency, answer-rate and average-length fields are
            placeholders set to 0.

        Raises:
            FileNotFoundError: If a required file is missing (logged, re-raised).
            Exception: Any other failure is logged and re-raised.
        """
        try:
            # Local import: openpyxl is only needed when evaluate() runs.
            from openpyxl import load_workbook

            save_path = f"generation_results/{self.model}.csv"
            # The prompt file is loaded as an openpyxl Workbook (not a
            # DataFrame); the dataset is handed to the generator as a path.
            prompt_workbook = load_workbook(filename=envs.PROMPT_PATH)
            self.generated_summaries_df = self.summary_generator.generate_summaries(
                envs.DATASET_PATH, prompt_workbook, save_path=save_path)

            # Evaluate the generated responses against the human reference data.
            self.humanlike = self.eval_model.evaluate_humanlike(
                self.generated_summaries_df, envs.HUMAN_DATA, save_path)

            # NOTE: the original hallucination / factual-consistency metrics
            # are disabled, so these fields are reported as placeholders.
            factual_consistency_rate = 0
            answer_rate = 0
            avg_summary_len = 0

            return util.format_results(
                model_name=self.model, revision=self.revision,
                precision=self.precision,
                factual_consistency_rate=factual_consistency_rate,
                # The human-likeness score is reported through this field.
                hallucination_rate=self.humanlike,
                answer_rate=answer_rate,
                avg_summary_len=avg_summary_len)
        except FileNotFoundError as e:
            # Report the file that was actually missing: it may be the prompt
            # file or the human reference data, not only the dataset.
            logging.error(f"File not found: {e.filename or e}")
            raise
        except Exception as e:
            logging.error(f"Error during evaluation: {e}")
            raise

    def write_results(self):
        """Append this model's generations (and scores, if any) to the leaderboard CSVs.

        Must be called after ``evaluate()``, which sets
        ``generated_summaries_df``. Files live in
        ``<cwd>/Humanlike Leaderboard Results``:

        * ``leaderboard_summaries.csv``: ``user_prompt``, ``response`` and
          ``model`` columns, appended without a header.
        * ``leaderboard_summaries_with_scores.csv``: the rows of
          ``self.eval_results`` with a ``model`` column inserted at position
          3, appended without a header. Skipped with a warning when
          ``eval_results`` has not been populated.

        NOTE: earlier rows for the same model are not removed first, so
        re-running for a model appends duplicate rows.

        Raises:
            FileNotFoundError: If the results folder does not exist.
        """
        print('Updating result files')
        leaderboard_path = os.getcwd()  # the path of leaderboard folder
        print(leaderboard_path)
        working_path = os.path.join(leaderboard_path, 'Humanlike Leaderboard Results')
        if not os.path.exists(working_path):
            message = (f"Results folder not found: {working_path}. Need to first "
                       "download the results from google drive to the leaderboard folder")
            logging.error(message)
            # Raise explicitly: there is no active exception to re-raise here.
            raise FileNotFoundError(message)

        # Copy the column slice so inserting "model" never touches the
        # evaluation DataFrame (and avoids pandas chained-assignment warnings).
        summaries_df = self.generated_summaries_df[["user_prompt", "response"]].copy()
        summaries_df.insert(2, "model", [self.model] * summaries_df.shape[0])
        summaries_df.to_csv(os.path.join(working_path, 'leaderboard_summaries.csv'),
                            mode='a', index=False, header=False)
        print('leaderboard_summaries.csv has been updated')

        # Tolerate instances where eval_results was never assigned.
        eval_results = getattr(self, "eval_results", None)
        if eval_results is None:
            logging.warning("No per-item evaluation scores available; "
                            "skipping leaderboard_summaries_with_scores.csv")
            return
        scores_df = pd.DataFrame.from_dict(eval_results)
        scores_df.insert(3, "model", [self.model] * scores_df.shape[0])
        scores_df.to_csv(os.path.join(working_path, 'leaderboard_summaries_with_scores.csv'),
                         mode='a', index=False, header=False)
        print('leaderboard_summaries_with_scores.csv has been updated')