|
|
|
import json |
|
import re |
|
|
|
from . import dataset_loader |
|
|
|
|
|
def extract_last_line(string):
    """Return the last line of ``string`` that is not blank.

    The text is split on ``'\\n'`` and scanned from the end. The first line
    that still has content after stripping whitespace is returned unchanged
    (its own surrounding whitespace is kept).

    Args:
        string: The text to scan.

    Returns:
        The last non-blank line, or ``string`` itself when every line is
        blank.
    """
    for candidate in reversed(string.split('\n')):
        if candidate.strip():
            return candidate
    # No line had any content: hand the input back untouched.
    return string
|
|
|
|
|
def remove_few_shot_prefix(string: str):
    """Drop the few-shot answer marker and everything before it.

    Two markers are handled in turn: ``'The answer is therefore'`` and
    ``'答案是'``. For each marker:

    * if the text starts with it, that leading marker is removed;
    * otherwise, if it occurs anywhere, everything up to and including its
      last occurrence is removed.

    Whenever something is removed the remainder is whitespace-stripped. Text
    containing neither marker is returned unchanged.

    Args:
        string: The model output to clean.

    Returns:
        The text following the marker, or the original text.
    """
    for marker in ('The answer is therefore', '答案是'):
        if string.startswith(marker):
            cut = len(marker)
        else:
            position = string.rfind(marker)
            if position < 0:
                # Marker absent: leave the text alone for this marker.
                continue
            cut = position + len(marker)
        string = string[cut:].strip()
    return string
|
|
|
|
|
def try_parse_few_shot_qa_single_answer(string, setting_name, language='en'):
    """Try to read a single option letter from a few-shot style answer.

    For the ``'few-shot-CoT'`` setting only the last non-blank line (as
    returned by ``extract_last_line``) is inspected. The text is then searched
    for the language-specific answer phrase followed, as closely as possible,
    by a capital letter in ``A``-``G``.

    Args:
        string: The model output.
        setting_name: Name of the evaluation setting.
        language: ``'en'`` or ``'zh'``; selects the answer phrase.

    Returns:
        The matched option letter, or ``None`` when the phrase is not found.

    Raises:
        ValueError: If ``language`` is neither ``'en'`` nor ``'zh'``.
    """
    if setting_name == 'few-shot-CoT':
        string = extract_last_line(string)

    if language == 'en':
        answer_regex = 'answer is .*?([A-G])'
    elif language == 'zh':
        answer_regex = '答案是.*?([A-G])'
    else:
        raise ValueError('Unknown language {0}'.format(language))

    found = re.search(answer_regex, string)
    if found is None:
        return None
    return found.group(1)
|
|
|
|
|
def try_parse_few_shot_pattern(string: str, dataset_name, setting_name):
    """Report whether ``string`` follows the expected few-shot answer format.

    For the ``'few-shot-CoT'`` setting only the last non-blank line (as
    returned by ``extract_last_line``) is checked. The expected format depends
    on which ``dataset_loader`` collection contains ``dataset_name``:

    * Chinese cloze: the text must start with ``'答案是'``.
    * English cloze: the text must start with ``'The answer is therefore'``.
    * Chinese / English QA: the answer phrase must be followed somewhere by a
      capital letter in ``A``-``G``.

    Args:
        string: The model output.
        dataset_name: Name of the dataset being evaluated.
        setting_name: Name of the evaluation setting.

    Returns:
        ``True`` when the text matches the format for its dataset, ``False``
        otherwise (including when the dataset is in none of the collections).
    """
    if setting_name == 'few-shot-CoT':
        string = extract_last_line(string)

    if dataset_name in dataset_loader.chinese_cloze_datasets:
        return string.startswith('答案是')
    if dataset_name in dataset_loader.english_cloze_datasets:
        return string.startswith('The answer is therefore')
    if dataset_name in dataset_loader.chinese_qa_datasets:
        return re.search('答案是.*?([A-G])', string) is not None
    if dataset_name in dataset_loader.english_qa_datasets:
        return re.search('answer is .*?([A-G])', string) is not None
    return False
|
|
|
|
|
def parse_few_shot_qa_single_answer(string, setting_name, language='en'):
    """Parse a single option letter, with a first-capital-letter fallback.

    The few-shot answer phrase is tried first via
    ``try_parse_few_shot_qa_single_answer``; when that yields ``None`` the
    result of ``find_first_capital_letter`` on the full text is returned
    instead.

    Args:
        string: The model output.
        setting_name: Name of the evaluation setting.
        language: ``'en'`` or ``'zh'``, forwarded to the phrase parser.

    Returns:
        The option letter found by the phrase parser, otherwise whatever the
        fallback returns.
    """
    parsed = try_parse_few_shot_qa_single_answer(string, setting_name,
                                                 language)
    if parsed is not None:
        return parsed
    return find_first_capital_letter(string)
|
|
|
|
|
def find_first_capital_letter(answer):
    """Return the first character of ``answer`` that is one of ``A``-``F``.

    Args:
        answer: The text to scan, character by character.

    Returns:
        The first matching letter, or ``''`` when there is none.
    """
    # NOTE(review): the few-shot regexes in this file accept A-G while this
    # set stops at F -- confirm whether 'G' is deliberately excluded here.
    options = frozenset('ABCDEF')
    return next((char for char in answer if char in options), '')
|
|
|
|
|
def extract_answer_in_bracket(answer, prefix='【', suffix='】'):
    """Return the text between the first ``prefix`` and the next ``suffix``.

    Args:
        answer: The text to search.
        prefix: Opening delimiter.
        suffix: Closing delimiter.

    Returns:
        The text enclosed by the first ``prefix`` and the first ``suffix``
        that follows it, or ``''`` when either delimiter is missing.
    """
    start = answer.find(prefix)
    if start < 0:
        return ''
    start += len(prefix)
    # Look for the closing delimiter only after the opening one. The previous
    # guard used ``and``, so a text containing just one of the two delimiters
    # fell through to ``str.index`` and raised ValueError; a stray suffix
    # ahead of the prefix also produced an empty slice.
    end = answer.find(suffix, start)
    if end < 0:
        return ''
    return answer[start:end]
|
|
|
|
|
def parse_math_answer(setting_name, raw_string):
    """Extract the final answer from a math / cloze model output.

    Few-shot settings: for ``'few-shot-CoT'`` the text is first reduced to
    its last non-blank line (``extract_last_line``); for ``'few-shot-CoT'``
    and ``'few-shot'`` the few-shot answer prefix is then removed
    (``remove_few_shot_prefix``) and the remainder is returned as is.

    Any other setting: after removing the few-shot prefix, the answer is
    taken from, in order of preference,

    1. the contents of the last ``\\boxed{...}`` when ``\\boxed`` occurs,
    2. the last ``$...$`` span,
    3. the text after the last ``=``, or else the last number in the text.

    Args:
        setting_name: Name of the evaluation setting.
        raw_string: The model output.

    Returns:
        The extracted answer string, or ``None`` when nothing could be
        extracted in the non-few-shot path.
    """
    if setting_name == 'few-shot-CoT':
        raw_string = extract_last_line(raw_string)
    if setting_name == 'few-shot-CoT' or setting_name == 'few-shot':
        raw_string = remove_few_shot_prefix(raw_string)
        return raw_string

    def remove_boxed(s):
        """Strip the ``\\boxed{`` ... ``}`` wrapper; ``None`` if malformed."""
        left = '\\boxed{'
        # Explicit validation instead of ``assert`` inside a bare ``except``:
        # asserts vanish under ``python -O`` and the bare except also hid
        # unrelated errors. ``s`` is None when no balanced box was found.
        if s is None or not s.startswith(left) or not s.endswith('}'):
            return None
        answer = s[len(left):-1]
        if '=' in answer:
            # Keep only the right-hand side of an equation.
            answer = answer.split('=')[-1].lstrip(' ')
        return answer

    def last_boxed_only_string(string):
        """Return the last ``\\boxed``/``\\fbox`` group with balanced braces."""
        idx = string.rfind('\\boxed')
        if idx < 0:
            idx = string.rfind('\\fbox')
            if idx < 0:
                return None
        depth = 0
        for i in range(idx, len(string)):
            char = string[i]
            if char == '{':
                depth += 1
            elif char == '}':
                depth -= 1
                if depth == 0:
                    # Brace matching the first '{' after the command.
                    return string[idx:i + 1]
        # Braces never balanced (or there were none).
        return None

    def get_answer_with_dollar_sign(s):
        """Return the last ``$...$`` span (greedy match), or ``None``."""
        matches = re.findall(r'\$(.*)\$', s)
        if not matches:
            return None
        last_match = matches[-1]
        if '=' in last_match:
            last_match = last_match.split('=')[-1].lstrip(' ')
        return last_match

    def get_answer_without_dollar_sign(s):
        """Return the text after the last ``=``, else the last number."""
        if '=' in s:
            last_match = s.split('=')[-1].lstrip(' ').rstrip('.')
            if '\n' in last_match:
                last_match = last_match.split('\n')[0]
            return last_match
        # Optional leading '$', integer or decimal, not followed by a word
        # character.
        matches = re.findall(r'(?:\$)?\d+(?:\.\d+)?(?![\w\d])', s)
        return matches[-1] if matches else None

    raw_string = remove_few_shot_prefix(raw_string)
    if '\\boxed' in raw_string:
        answer = remove_boxed(last_boxed_only_string(raw_string))
    else:
        answer = get_answer_with_dollar_sign(raw_string)
        if not answer:
            answer = get_answer_without_dollar_sign(raw_string)
    return answer
|
|
|
|
|
def parse_qa_multiple_answer(string):
    """Collect every option letter mentioned in a multiple-answer output.

    A fixed list of substrings is first deleted from the text, one after
    another in list order (presumably abbreviations and grade labels whose
    capital letters would otherwise be mistaken for options -- confirm with
    the dataset owners). Note that the list also contains two-letter
    combinations such as ``'AB'``, so options written without a separator can
    be removed as well. Every remaining uppercase ``A``-``Z`` letter,
    optionally wrapped in parentheses, is then returned.

    Args:
        string: The model output.

    Returns:
        A list of single-letter strings in order of appearance; empty when no
        letter is found.
    """
    # Order matters: each replacement runs on the result of the previous one.
    ignored_tokens = [
        'CC', 'CA', 'AC', 'POMES', 'AI', 'MIBG', 'CF', 'CTE', 'AD', 'CB',
        'BG', 'BD', 'BE', 'BH', 'CTB', 'BI', 'CE', 'Pugh', 'Child', 'CTI',
        'CTA', 'TACE', 'PPD', 'Castleman', 'BA', 'CH', 'AB', 'CTC', 'CT',
        'CTH', 'CD', 'AH', 'AE', 'AA', 'AF', 'BC', 'CG', 'BB', 'CI', 'BF',
        'CTF', 'CTG', 'AG', 'CTD', '分级C', '分级A', 'I131', '分级B', '分级D',
        '131I‐MIBG', 'NYHA', 'IPF', 'DIP', 'Lambert-Eaton', 'Graves', 'IIA期',
        'CKD', 'FDA', 'A级', 'B级', 'C级', 'D级', '维生素D',
    ]
    for token in ignored_tokens:
        string = string.replace(token, '')
    # Raw string: '\(' in a plain literal is an invalid escape sequence and
    # triggers a SyntaxWarning on recent Python versions. The regex itself is
    # unchanged. ``re.findall`` already returns [] when nothing matches.
    return re.findall(r'\(*([A-Z])\)*', string)
|
|
|
|
|
def post_process(dataset_name, setting_name, prediction):
    """Turn a raw model prediction into the answer format of its dataset.

    Dispatch order:

    1. Cloze datasets (English or Chinese, per ``dataset_loader``) go through
       ``parse_math_answer``.
    2. ``'jec-qa-kd'``, ``'jec-qa-ca'`` and ``'gaokao-physics'`` go through
       ``parse_qa_multiple_answer`` and yield a list of letters.
    3. Any setting whose name contains ``'zero-shot'`` yields the first
       capital option letter found by ``find_first_capital_letter``.
    4. Remaining English / Chinese QA datasets go through
       ``parse_few_shot_qa_single_answer``.

    Args:
        dataset_name: Name of the dataset being evaluated.
        setting_name: Name of the evaluation setting.
        prediction: The raw model output.

    Returns:
        The parsed answer; its type depends on the branch taken (string,
        list of strings, or ``None`` from the math parser).

    Raises:
        ValueError: If no earlier branch applies and the dataset is neither
            an English nor a Chinese QA dataset.
    """
    if (dataset_name in dataset_loader.english_cloze_datasets
            or dataset_name in dataset_loader.chinese_cloze_datasets):
        return parse_math_answer(setting_name, prediction)

    if dataset_name in ['jec-qa-kd', 'jec-qa-ca', 'gaokao-physics']:
        # parse_qa_multiple_answer takes the prediction only; passing
        # setting_name as a second argument raised TypeError on every call.
        return parse_qa_multiple_answer(prediction)

    if 'zero-shot' in setting_name:
        return find_first_capital_letter(prediction)

    if dataset_name in dataset_loader.english_qa_datasets:
        language = 'en'
    elif dataset_name in dataset_loader.chinese_qa_datasets:
        language = 'zh'
    else:
        raise ValueError(f'Unsupported dataset name {dataset_name}')
    return parse_few_shot_qa_single_answer(prediction, setting_name,
                                           language)
|
|