File size: 5,540 Bytes
256a159 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import ast
import json
import os

from datasets import Dataset

from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import ICL_EVALUATORS, LOAD_DATASET

from ..base import BaseDataset
from .prompts import edpPrompts
def q2text(q, p=edpPrompts):
    """Render one edit-distance instance as a complete prompt string.

    Args:
        q: Instance mapping providing ``'string_a'`` and ``'string_b'``.
        p: Prompt-template mapping with ``'Intro'``, ``'Initial_question'``,
            ``'Output_content'`` and ``'Output_format'`` entries; defaults
            to ``edpPrompts``.

    Returns:
        The four template sections joined by single newlines, with the two
        instance strings substituted into ``'Initial_question'``.
    """
    first, second = q['string_a'], q['string_b']
    sections = [
        p['Intro'],
        p['Initial_question'].format(string_a=first, string_b=second),
        p['Output_content'],
        p['Output_format'],
    ]
    return '\n'.join(sections)
@LOAD_DATASET.register_module(force=True)
class p_EDP_Dataset(BaseDataset):
    """Dataset loader for edit-distance (EDP) instances."""

    @staticmethod
    def load(path: str):
        """Load EDP instances and turn each one into a prompt record.

        Args:
            path: Directory containing ``edp_instances.json``. A trailing
                path separator is optional.

        Returns:
            A ``datasets.Dataset`` whose rows hold ``'prompt'`` (the rendered
            question), ``'q'`` (``"<level>####\\n<instance JSON>"``, the
            serialized reference) and ``'level'`` (``len(string_a) - 2``).
        """
        # os.path.join works whether or not ``path`` ends with a separator;
        # plain concatenation silently produced "<dir>edp_instances.json"
        # when the trailing slash was missing.
        data_file = os.path.join(path, 'edp_instances.json')
        # JSON text is UTF-8 by spec; don't depend on the platform locale.
        with open(data_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        raw_data = []
        for sample in data:
            # Difficulty level is derived from the length of the first string.
            level = len(sample['string_a']) - 2
            raw_data.append({
                'prompt': q2text(sample),
                'q': str(level) + '####\n' + json.dumps(sample),
                'level': level,
            })
        return Dataset.from_list(raw_data)
@ICL_EVALUATORS.register_module(force=True)
class p_EDP_Evaluator(BaseEvaluator):
    """Score model answers to edit-distance (EDP) questions.

    Each reference is a string ``"<level>####\\n<instance JSON>"``. Each
    prediction is raw model text expected to contain a
    ``<final_answer>{...}</final_answer>`` section whose dict carries the
    reported edit distance under ``'Operations'``.
    """

    def score(self, predictions, references):
        """Return the level-weighted accuracy over all predictions.

        Every item adds its ``level`` to either the pass or the fail total;
        the score is ``pass / (pass + fail) * 100``.

        Args:
            predictions: Raw model outputs, one per reference.
            references: Serialized references, ``"<level>####\\n<json>"``.

        Returns:
            ``{'Weighted Accuracy': score}``. When the total weight is zero
            (e.g. no predictions) the score is reported as ``0.0`` instead
            of raising ``ZeroDivisionError``.
        """
        assert len(predictions) == len(references)
        result = {'pass': 0, 'fail': 0}
        for reference, prediction in zip(references, predictions):
            parts = reference.split('####\n')
            level = int(parts[0])
            instance = json.loads(parts[-1])
            answer, _ = self.parse_xml_to_dict(prediction)
            try:
                correct, _ = self.edp_check(instance, answer)
            except Exception as e:
                # Best-effort: a malformed instance counts as a failure
                # rather than aborting the whole evaluation.
                print(f'Check failed: {e}')
                correct = False
            result['pass' if correct else 'fail'] += level

        total = result['pass'] + result['fail']
        score = result['pass'] / total * 100 if total else 0.0
        return {'Weighted Accuracy': score}

    def compute_min_edit_distance(self, string_a, string_b):
        """Return the minimum edit distance between two strings.

        Insertions, deletions and substitutions each cost 1. This is the
        standard dynamic program, keeping only the previous row so memory
        is O(len(string_b)).
        """
        previous = list(range(len(string_b) + 1))
        for i, char_a in enumerate(string_a, start=1):
            current = [i] + [0] * len(string_b)
            for j, char_b in enumerate(string_b, start=1):
                if char_a == char_b:
                    current[j] = previous[j - 1]
                else:
                    current[j] = 1 + min(previous[j], current[j - 1],
                                         previous[j - 1])
            previous = current
        return previous[-1]

    def edp_check(self, instance, solution):
        """Check whether a reported edit distance is the true minimum.

        :param instance: Instance dict with ``'string_a'`` and ``'string_b'``.
        :param solution: Parsed answer; the reported distance is read from
            its ``'Operations'`` entry. A non-mapping (e.g. ``''`` after a
            parse failure) or a non-integer value counts as no answer.
        :return: A tuple ``(is_correct, message)``.
        """
        string_a = instance['string_a']
        string_b = instance['string_b']
        try:
            reported_distance = int(solution.get('Operations', -1))
        except (AttributeError, TypeError, ValueError, OverflowError):
            # AttributeError: solution is not a dict; the others: the
            # reported value cannot be converted to an int.
            reported_distance = -1
        if reported_distance == -1:
            return False, 'No solution provided.'
        # Only pay for the O(len_a * len_b) DP once there is an answer.
        actual_distance = self.compute_min_edit_distance(string_a, string_b)
        if reported_distance != actual_distance:
            return False, (f'The reported edit distance ({reported_distance}) '
                           f'is incorrect. Actual distance: {actual_distance}.')
        return True, 'The solution is valid.'

    def parse_xml_to_dict(self, xml_string):
        """Extract the answer literal from a ``<final_answer>`` section.

        Returns:
            ``(final_answer, reasoning)``. ``final_answer`` is the
            ``ast.literal_eval`` of the span from the first ``{`` to the
            first ``}`` inside the first ``<final_answer>...</final_answer>``
            pair, or ``''`` if that span is not a valid literal.
            ``reasoning`` is the whole input. If the input is not a string,
            or the tags or braces are missing, both values are ``''``.
        """
        if not isinstance(xml_string, str):
            return '', ''
        open_tag, close_tag = '<final_answer>', '</final_answer>'
        start = xml_string.find(open_tag)
        end = xml_string.find(close_tag)
        if start == -1 or end == -1:
            return '', ''
        body = xml_string[start + len(open_tag):end].strip()
        brace_open = body.find('{')
        brace_close = body.find('}')
        if brace_open == -1 or brace_close == -1:
            return '', ''
        candidate = body[brace_open:brace_close + 1].strip()
        try:
            # literal_eval only builds Python literals (no code execution);
            # it can raise several exception types on arbitrary model text,
            # so this is a deliberate best-effort catch.
            answer = ast.literal_eval(candidate)
        except Exception:
            answer = ''
        return answer, xml_string
|