# File: api-demo/opencompass-my-api/opencompass/datasets/lawbench/evaluation_functions/ljp_imprison.py
import math
import re

import cn2an
# Legal judgment prediction (LJP): prison-term prediction scorer.

# Log-distance charged for an unparseable prediction (abstention); also the
# normalizer that maps the mean log distance onto the returned score.
# NOTE(review): the meaning of 216 (months) is not documented here — confirm
# against the benchmark definition before changing it.
_IMPRISON_MAX_LOG_DISTANCE = math.log(216)

# Compiled once; capture the digit run that immediately precedes the unit.
_MONTHS_WITH_GE_PATTERN = re.compile(r"(\d+)个月")
_MONTHS_PATTERN = re.compile(r"(\d+)月")
_YEARS_PATTERN = re.compile(r"(\d+)年")


def _extract_prediction_months(prediction):
    """Return the prison term in months stated in ``prediction``, or ``None``.

    ``prediction`` should already have had Chinese numerals converted to digits.
    Patterns are tried in priority order; the leftmost match of the first
    pattern that matches anywhere wins:

    1. ``<digits>个月`` -> that many months
    2. ``<digits>月``   -> that many months
    3. ``<digits>年``   -> that many years, converted to months

    Returns ``None`` when nothing matches (treated as an abstention).
    """
    # NOTE(review): a combined term such as "1年6个月" is scored as 6 months,
    # because month matches take priority and years are not added to them.
    for pattern in (_MONTHS_WITH_GE_PATTERN, _MONTHS_PATTERN):
        match = pattern.search(prediction)
        if match:
            return int(match.group(1))
    match = _YEARS_PATTERN.search(prediction)
    if match:
        return int(match.group(1)) * 12
    return None


def compute_ljp_imprison(data_dict):
    """Score prison-term predictions by normalized log distance to the reference.

    Args:
        data_dict: a sized iterable (``len`` is taken) of example dicts with keys
            ``"origin_prompt"`` (question), ``"prediction"`` (model output) and
            ``"refr"`` (reference, formatted as ``"刑期:<N>个月"``).

    Returns:
        A dict with:

        - ``"score"``: ``(log(216) - mean_distance) / log(216)``, where each
          scored example's distance is ``|log(answer + 1) - log(pred + 1)|``
          (both in months) and an abstention contributes ``log(216)``. Exact
          predictions give 1.0; the value can drop below 0 when predictions are
          off by more than the normalizer. ``0.0`` when no example is scorable.
        - ``"abstention_rate"``: abstentions divided by ``len(data_dict)``
          (``0.0`` for empty input).

    Raises:
        ValueError: if a reference is neither a death/life sentence nor of the
            form ``"刑期:<N>个月"``, or ``<N>`` is not an integer.
    """
    score_list = []
    abstentions = 0
    for example in data_dict:
        question = example["origin_prompt"]
        prediction = example["prediction"]
        answer = example["refr"]
        # Death ("死刑") and life ("无期") sentences have no month count. They are
        # skipped (TODO: data imperfection): excluded from the score average but
        # still counted in the abstention-rate denominator.
        if "死刑" in answer or "无期" in answer:
            continue
        # Explicit raise instead of assert so the check survives `python -O`.
        if not (answer.startswith("刑期:") and answer.endswith("个月")):
            raise ValueError(f"answer: {answer}, question: {question}")
        answer_digit = int(answer.replace("刑期:", "").replace("个月", ""))
        # cn2an "cn2an" mode rewrites Chinese numerals (e.g. "十二") as digits so
        # the digit regexes can match them.
        prediction_month = _extract_prediction_months(
            cn2an.transform(prediction, "cn2an")
        )
        if prediction_month is None:
            abstentions += 1
            score_list.append(_IMPRISON_MAX_LOG_DISTANCE)
        else:
            score_list.append(
                abs(math.log(answer_digit + 1) - math.log(prediction_month + 1))
            )

    total = len(data_dict)
    abstention_rate = abstentions / total if total else 0.0
    if not score_list:
        # Empty input, or every example was a skipped death/life sentence.
        return {"score": 0.0, "abstention_rate": abstention_rate}
    mean_log_distance = sum(score_list) / len(score_list)
    score = (_IMPRISON_MAX_LOG_DISTANCE - mean_log_distance) / _IMPRISON_MAX_LOG_DISTANCE
    return {"score": score, "abstention_rate": abstention_rate}