File size: 7,082 Bytes
256a159 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
# flake8: noqa
import json
import re
from . import dataset_loader
def extract_last_line(string):
    """Return the last line of ``string`` that is not blank.

    Lines are split on '\\n'. The chosen line is returned exactly as it
    appears (it is not stripped). If every line is empty or whitespace-only,
    ``string`` itself is returned unchanged.
    """
    candidates = reversed(string.split('\n'))
    return next((line for line in candidates if line.strip() != ''), string)
def remove_few_shot_prefix(string: str):
    """Drop a few-shot answer marker and everything before it.

    The markers 'The answer is therefore' and '答案是' are handled in that
    order. For each one:
      * if ``string`` starts with it, only that leading marker is removed;
      * otherwise, if it occurs anywhere, only the text after its LAST
        occurrence is kept.
    Whenever a marker is removed, the remainder is whitespace-stripped.
    Strings containing neither marker are returned unchanged.
    """
    for marker in ('The answer is therefore', '答案是'):
        if string.startswith(marker):
            string = string[len(marker):].strip()
            continue
        cut = string.rfind(marker)
        if cut != -1:
            string = string[cut + len(marker):].strip()
    return string
def try_parse_few_shot_qa_single_answer(string, setting_name, language='en'):
    """Try to read a single option letter (A-G) from a few-shot prediction.

    For the 'few-shot-CoT' setting only the last non-blank line is examined.
    The letter is the first A-G that follows 'answer is ' (language 'en') or
    '答案是' (language 'zh').

    Returns:
        The option letter, or None when the pattern is not found.

    Raises:
        ValueError: if ``language`` is neither 'en' nor 'zh'.
    """
    if setting_name == 'few-shot-CoT':
        string = extract_last_line(string)
    answer_patterns = {
        'en': 'answer is .*?([A-G])',
        'zh': '答案是.*?([A-G])',
    }
    if language not in answer_patterns:
        raise ValueError('Unknown language {0}'.format(language))
    found = re.search(answer_patterns[language], string)
    return found.group(1) if found else None
def try_parse_few_shot_pattern(string: str, dataset_name, setting_name):
    """Report whether a prediction contains the answer form its dataset expects.

    For 'few-shot-CoT' only the last non-blank line is checked. Cloze
    datasets must START with their answer marker ('答案是' for Chinese,
    'The answer is therefore' for English). QA datasets must contain the
    marker followed somewhere by an option letter A-G. Datasets that appear
    in none of the ``dataset_loader`` lists yield False.
    """
    if setting_name == 'few-shot-CoT':
        string = extract_last_line(string)

    # Membership is tested in this fixed order; the first matching list wins.
    if dataset_name in dataset_loader.chinese_cloze_datasets:
        return string.startswith('答案是')
    if dataset_name in dataset_loader.english_cloze_datasets:
        return string.startswith('The answer is therefore')
    if dataset_name in dataset_loader.chinese_qa_datasets:
        return re.search('答案是.*?([A-G])', string) is not None
    if dataset_name in dataset_loader.english_qa_datasets:
        return re.search('answer is .*?([A-G])', string) is not None
    return False
def parse_few_shot_qa_single_answer(string, setting_name, language='en'):
    """Return the single option letter chosen in a few-shot QA prediction.

    First tries the 'answer is' / '答案是' pattern via
    ``try_parse_few_shot_qa_single_answer``. If that finds nothing, falls back
    to the first letter A-F anywhere in the full, untrimmed ``string``; that
    fallback returns '' when no such letter exists.
    """
    parsed = try_parse_few_shot_qa_single_answer(string, setting_name, language)
    if parsed is not None:
        return parsed
    return find_first_capital_letter(string)
def find_first_capital_letter(answer):
    """Return the first character of ``answer`` that is one of A-F, else ''.

    NOTE: 'G' is not accepted here, whereas the few-shot answer regexes in
    this module accept A-G.
    """
    option_letters = frozenset('ABCDEF')
    return next((ch for ch in answer if ch in option_letters), '')
def extract_answer_in_bracket(answer, prefix='【', suffix='】'):
    """Return the text between the first ``prefix`` and the next ``suffix``.

    Args:
        answer: Text to search.
        prefix: Opening delimiter (default '【').
        suffix: Closing delimiter (default '】').

    Returns:
        The enclosed text. Returns '' when ``prefix`` is missing or when no
        ``suffix`` follows it. (The previous guard used ``and``, so a string
        containing only one of the delimiters raised ValueError from
        ``str.index``.)
    """
    start = answer.find(prefix)
    if start < 0:
        return ''
    start += len(prefix)
    # Search for the closing delimiter only after the opening one, so a stray
    # suffix earlier in the text does not produce an empty or inverted slice.
    end = answer.find(suffix, start)
    if end < 0:
        return ''
    return answer[start:end]
def _remove_boxed(s):
    """Return the contents of a ``\\boxed{...}`` string, or None.

    ``s`` must be exactly ``\\boxed{...}`` (as produced by
    ``_last_boxed_only_string``). Anything else, including None, yields None.
    If the contents contain '=', only the text after the last '=' is kept,
    with leading spaces removed.
    """
    left = '\\boxed{'
    # Explicit checks instead of ``assert`` + bare ``except``: asserts are
    # stripped under ``python -O``, which would let malformed input through.
    if not isinstance(s, str) or not s.startswith(left) or not s.endswith('}'):
        return None
    answer = s[len(left):-1]
    if '=' in answer:
        answer = answer.split('=')[-1].lstrip(' ')
    return answer


def _last_boxed_only_string(string):
    """Return the last ``\\boxed...`` (or ``\\fbox...``) substring, or None.

    Scans forward from the last occurrence of ``\\boxed`` (falling back to
    ``\\fbox``) while counting braces. The result ends at the first '}' that
    brings the running count back to zero. Returns None when neither marker
    exists or that never happens.
    """
    idx = string.rfind('\\boxed')
    if idx < 0:
        idx = string.rfind('\\fbox')
        if idx < 0:
            return None
    depth = 0
    for i, ch in enumerate(string[idx:], start=idx):
        if ch == '{':
            depth += 1
        elif ch == '}':
            depth -= 1
            if depth == 0:
                return string[idx:i + 1]
    return None


def _get_answer_with_dollar_sign(s):
    """Return the last ``$...$`` span's contents, or None if there is none.

    ``.`` does not cross newlines and is greedy, so on each line the span runs
    from the first '$' to the last '$'. If the span contains '=', only the
    text after the last '=' is kept, with leading spaces removed.
    """
    matches = re.findall(r'\$(.*)\$', s)
    if not matches:
        return None
    last_match = matches[-1]
    if '=' in last_match:
        last_match = last_match.split('=')[-1].lstrip(' ')
    return last_match


def _get_answer_without_dollar_sign(s):
    """Fallback answer extraction for text without usable ``$...$`` spans.

    If ``s`` contains '=', returns the first line of the text after the last
    '=' (leading spaces and trailing '.' removed). Otherwise returns the last
    number (optionally '$'-prefixed, optionally decimal), or None.
    """
    if '=' in s:
        # Cut at the newline BEFORE stripping trailing dots; stripping first
        # left a sentence-final '.' on multi-line remainders (e.g. '5.\n').
        first_line = s.split('=')[-1].lstrip(' ').split('\n')[0]
        return first_line.rstrip('.')
    numbers = re.findall(r'(?:\$)?\d+(?:\.\d+)?(?![\w\d])', s)
    return numbers[-1] if numbers else None


def parse_math_answer(setting_name, raw_string):
    """Extract the final answer from a cloze-style (free-form) prediction.

    Few-shot settings ('few-shot' and 'few-shot-CoT'): for CoT only the last
    non-blank line is kept. The few-shot answer marker is then stripped with
    ``remove_few_shot_prefix`` and the remainder is returned.

    Any other setting: after stripping the marker,
      * if '\\boxed' appears, the contents of the last ``\\boxed{...}`` are
        returned (None if they cannot be extracted; there is no further
        fallback);
      * otherwise the last ``$...$`` span is used, and if that is empty or
        missing, the text after the last '=' or the last number is used.
    The result may be None.
    """
    if setting_name == 'few-shot-CoT':
        raw_string = extract_last_line(raw_string)
    if setting_name in ('few-shot-CoT', 'few-shot'):
        return remove_few_shot_prefix(raw_string)

    raw_string = remove_few_shot_prefix(raw_string)
    if '\\boxed' in raw_string:
        return _remove_boxed(_last_boxed_only_string(raw_string))
    answer = _get_answer_with_dollar_sign(raw_string)
    if not answer:
        answer = _get_answer_without_dollar_sign(raw_string)
    return answer
# Upper-case domain terms and abbreviations (presumably medical, e.g. 'CT',
# 'TACE', 'NYHA' -- confirm against the dataset) that are deleted from a
# prediction before its capital letters are read as selected options.
# NOTE(review): the list also contains bare option pairs such as 'AB', 'AC'
# and 'BD'. These are deleted too, so an answer written as 'AB' yields no
# letters. Confirm this is intended.
_MULTI_ANSWER_NOISE_TERMS = (
    'CC', 'CA', 'AC', 'POMES', 'AI', 'MIBG', 'CF', 'CTE', 'AD', 'CB', 'BG',
    'BD', 'BE', 'BH', 'CTB', 'BI', 'CE', 'Pugh', 'Child', 'CTI', 'CTA',
    'TACE', 'PPD', 'Castleman', 'BA', 'CH', 'AB', 'CTC', 'CT', 'CTH', 'CD',
    'AH', 'AE', 'AA', 'AF', 'BC', 'CG', 'BB', 'CI', 'BF', 'CTF', 'CTG', 'AG',
    'CTD', '分级C', '分级A', 'I131', '分级B', '分级D', '131I‐MIBG', 'NYHA',
    'IPF', 'DIP', 'Lambert-Eaton', 'Graves', 'IIA期', 'CKD', 'FDA', 'A级',
    'B级', 'C级', 'D级', '维生素D',
)
# Delete longer terms first so that a term containing a shorter one (e.g.
# 'CTH' vs 'CT', 'TACE' vs 'AC', '131I‐MIBG' vs 'MIBG') is removed whole.
# In list order, the leftover letters ('H', 'TE', 'I', ...) leaked out as
# spurious answers. sorted() is stable, so equal-length terms keep their
# original relative order.
_NOISE_TERMS_LONGEST_FIRST = tuple(
    sorted(_MULTI_ANSWER_NOISE_TERMS, key=len, reverse=True))
# Every upper-case ASCII letter, optionally wrapped in parentheses.
_OPTION_LETTER_RE = re.compile(r'\(*([A-Z])\)*')


def parse_qa_multiple_answer(string, setting_name=None):
    """Return every option letter (A-Z) found in a multi-answer prediction.

    The terms in ``_MULTI_ANSWER_NOISE_TERMS`` are deleted first, so their
    capital letters are not mistaken for options. Every remaining upper-case
    ASCII letter is then returned in order of appearance (possibly []).

    Args:
        string: The model prediction.
        setting_name: Accepted for call compatibility with ``post_process``,
            which passes it positionally. It is intentionally unused:
            last-line extraction for CoT is disabled for this parser.
    """
    for term in _NOISE_TERMS_LONGEST_FIRST:
        string = string.replace(term, '')
    return _OPTION_LETTER_RE.findall(string)
def post_process(dataset_name, setting_name, prediction):
    """Turn a raw model prediction into the answer format for its dataset.

    * Cloze datasets (English or Chinese): ``parse_math_answer``.
    * 'jec-qa-kd', 'jec-qa-ca', 'gaokao-physics': list of option letters via
      ``parse_qa_multiple_answer``.
    * Other datasets are single-answer QA. For any setting whose name contains
      'zero-shot', the first letter A-F is returned. Otherwise the few-shot
      parser is used, with 'en' markers for English QA datasets and 'zh'
      markers for Chinese ones.

    Raises:
        ValueError: for a few-shot setting on a dataset that is in neither
            QA dataset list.
    """
    if (dataset_name in dataset_loader.english_cloze_datasets
            or dataset_name in dataset_loader.chinese_cloze_datasets):
        return parse_math_answer(setting_name, prediction)

    if dataset_name in ['jec-qa-kd', 'jec-qa-ca', 'gaokao-physics']:
        # parse_qa_multiple_answer needs only the prediction text. Passing
        # setting_name as well breaks against its one-parameter signature.
        return parse_qa_multiple_answer(prediction)

    # All other datasets are QA problems with a single answer.
    if 'zero-shot' in setting_name:
        return find_first_capital_letter(prediction)

    # Few-shot single-answer QA: choose which answer marker language to parse.
    if dataset_name in dataset_loader.english_qa_datasets:
        language = 'en'
    elif dataset_name in dataset_loader.chinese_qa_datasets:
        language = 'zh'
    else:
        raise ValueError(f'Unsupported dataset name {dataset_name}')
    return parse_few_shot_qa_single_answer(prediction, setting_name, language)
|