File size: 3,948 Bytes
256a159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import re

from opencompass.datasets.humaneval import humaneval_gpt_postprocess
from opencompass.datasets.record import ReCoRD_postprocess
from opencompass.datasets.xsum import Xsum_postprocess
from opencompass.utils.text_postprocessors import first_option_postprocess


def gsm8k_postprocess(text: str) -> str:
    """Pull the digits of the last number-like token out of ``text``.

    The text is split on single spaces and scanned from the end. The first
    token (counting from the end) that contains at least one digit is
    selected, and every non-digit character in it is discarded, so
    ``'1,000.'`` becomes ``'1000'`` and ``'4.5'`` becomes ``'45'``.

    Args:
        text: Raw model output.

    Returns:
        The digits of the selected token, or an empty string when no token
        contains a digit.
    """
    for token in reversed(text.split(' ')):
        if any(ch.isdigit() for ch in token):
            # Keep only the digit characters of the chosen token.
            return ''.join(ch for ch in token if ch.isdigit())
    return ''


def humaneval_postprocess(text: str) -> str:
    """Reduce a model completion to an indented function body.

    Steps, applied in order:

    1. Discard everything up to and including the first newline, then trim
       surrounding whitespace. This happens unconditionally (presumably to
       drop a natural-language preamble line; confirm against callers).
    2. If the text contains a ``` fence, keep the contents of the first
       closed fenced block, minus a language tag on the fence line. When
       the fence is never closed, keep what follows the opening fence.
    3. If the text starts with ``from`` or ``import``, skip ahead to the
       line after the first occurrence of ``def``.
    4. If the text starts with ``def``, drop that line.
    5. Ensure the result starts with four spaces of indentation.

    Args:
        text: Raw model output.

    Returns:
        The extracted code, indented by four spaces.
    """
    # Everything up to and including the first newline is thrown away.
    text = text.partition('\n')[2].strip()

    if '```' in text:
        fenced = re.search(r'```(.*?)```', text, re.DOTALL)
        if fenced is None:
            # Unclosed fence: take whatever follows the opening marker.
            text = text.split('```')[1]
        else:
            text = fenced.group(1)
            if not text.startswith('\n'):
                # The fence line carries extra text (e.g. "python");
                # continue from the next line.
                text = text[text.find('\n') + 1:]

    if text.strip().startswith(('from', 'import')):
        def_pos = text.find('def')
        if def_pos >= 0:
            # Resume on the line following the first "def".
            text = text[text.find('\n', def_pos) + 1:]

    if text.strip().startswith('def'):
        # Remove the signature line, keeping only the body.
        text = text.partition('\n')[2]

    if text.startswith('    '):
        return text
    if text.startswith(' '):
        # Under-indented first line: re-indent the whole text once.
        return '    ' + text.lstrip()
    return '\n'.join('    ' + line for line in text.split('\n'))


def lcsts_postprocess(text: str) -> str:
    """Clean a generated summary by removing list markers and quotes.

    The text is trimmed, a leading ``'1. '`` and then a leading ``'- '``
    marker are removed, and finally the characters in the strip set below
    are removed from both ends.

    Only the leading marker is removed. Using ``str.replace`` here would
    also delete every later occurrence of the marker inside the summary
    (e.g. the ``'1. '`` in ``'... 2021. Profits ...'``).

    Args:
        text: Raw model output.

    Returns:
        The cleaned summary text.
    """
    text = text.strip()
    for marker in ('1. ', '- '):
        if text.startswith(marker):
            # Slice off the prefix only; leave the rest of the text intact.
            text = text[len(marker):]
    return text.strip('“,。!”')


def mbpp_postprocess(text: str) -> str:
    """Extract the code from a model completion for MBPP-style tasks.

    If the text starts with ``'Here'``, its first line is dropped and the
    remainder trimmed. If the text then contains a ``` fence, the contents
    of the first closed fenced block are returned, minus any language tag
    on the fence line. When the fence is never closed, whatever follows
    the opening fence is returned. Text without a fence is returned as is.

    Args:
        text: Raw model output.

    Returns:
        The extracted code.
    """
    if text.startswith('Here'):
        text = text.partition('\n')[2].strip()

    if '```' not in text:
        return text

    fenced = re.findall(r'```(.*?)```', text, re.DOTALL)
    if not fenced:
        # Unclosed fence: fall back to the text after the opening marker.
        return text.split('```')[1]

    code = fenced[0]
    if code.startswith('\n'):
        return code
    # The fence line carries extra text (e.g. "python"); skip that line.
    return code[code.find('\n') + 1:]


def strategyqa_pred_postprocess(text: str) -> str:
    """Extract a ``'yes'`` / ``'no'`` prediction from a model completion.

    If the text starts with ``'Here'``, its first line is dropped. Only the
    part after the last ``'answer is '`` (case-sensitive) is inspected,
    case-insensitively.

    A standalone ``yes`` / ``no`` (not embedded in a longer run of letters)
    is preferred, so words such as "nobody", "know" or "eyes" do not
    produce a false answer. If no standalone token exists, the function
    falls back to the first ``yes`` / ``no`` substring so that outputs like
    "Nope" are still answered.

    Args:
        text: Raw model output.

    Returns:
        ``'yes'``, ``'no'``, or an empty string when neither is found.
    """
    if text.startswith('Here'):
        text = '\n'.join(text.split('\n')[1:]).strip()
    tail = text.split('answer is ')[-1].lower()

    # Letter-based lookarounds instead of \b, so digits, underscores and
    # non-Latin characters next to the answer do not block the match.
    match = re.search(r'(?<![a-z])(yes|no)(?![a-z])', tail)
    if match is None:
        # Best-effort fallback: first substring hit.
        match = re.search(r'(yes|no)', tail)
    return match.group(1) if match else ''


def flores_postprocess(text: str) -> str:
    """Return the last line of ``text`` with surrounding whitespace removed.

    The text is trimmed first, so trailing blank lines are ignored.

    Args:
        text: Raw model output.

    Returns:
        The trimmed final line, or an empty string for blank input.
    """
    last_line = text.strip().rpartition('\n')[2]
    return last_line.strip()


def flores_postprocess_chinese(text: str) -> str:
    """Return the last line of ``text``, segmented with jieba.

    The trimmed final line is taken, runs of whitespace are collapsed to a
    single space, and the result is passed through ``jieba.cut`` with the
    resulting tokens joined by single spaces.

    Args:
        text: Raw model output.

    Returns:
        The space-joined jieba tokens of the final line.
    """
    last_line = text.strip().rpartition('\n')[2].strip()
    # Imported here rather than at module level, so jieba is only required
    # when this function actually runs.
    import jieba
    collapsed = re.sub(r'\s+', ' ', last_line).strip()
    return ' '.join(jieba.cut(collapsed))


def record_postprocess(text: str) -> str:
    """Extract the entity that follows ``'refers to '`` in ``text``.

    The text after the first ``'refers to '`` up to (not including) the
    next period is returned with surrounding whitespace removed. When the
    phrase is absent, the text is handed to ``ReCoRD_postprocess`` (imported
    from ``opencompass.datasets.record``) and its result is returned.

    Args:
        text: Raw model output.

    Returns:
        The extracted entity, or the result of ``ReCoRD_postprocess``.
    """
    found = re.search(r'(?<=refers to )[^.]+', text)
    if found is None:
        return ReCoRD_postprocess(text)
    return found.group().strip()


def humaneval_claude2_postprocess(text: str) -> str:
    """Drop a leading ``'Here ...'`` paragraph, then post-process the rest.

    When the text starts with ``'Here'``, everything up to and including
    the first blank line (two consecutive newlines) is removed. The result
    is passed to ``humaneval_gpt_postprocess`` (imported from
    ``opencompass.datasets.humaneval``).

    Args:
        text: Raw model output.

    Returns:
        Whatever ``humaneval_gpt_postprocess`` returns for the trimmed text.
    """
    body = text.partition('\n\n')[2] if text.startswith('Here') else text
    return humaneval_gpt_postprocess(body)


def xsum_postprocess(text: str) -> str:
    """Drop a leading ``'Here ...'`` paragraph, then post-process the rest.

    When the text starts with ``'Here'``, everything up to and including
    the first blank line (two consecutive newlines) is removed. The result
    is passed to ``Xsum_postprocess`` (imported from
    ``opencompass.datasets.xsum``).

    Args:
        text: Raw model output.

    Returns:
        Whatever ``Xsum_postprocess`` returns for the trimmed text.
    """
    body = text.partition('\n\n')[2] if text.startswith('Here') else text
    return Xsum_postprocess(body)


def yes_no_postprocess(text: str) -> str:
    """Map a yes/no style answer onto the option letters ``'A'`` / ``'B'``.

    A standalone "yes" (any case) yields ``'A'``; otherwise a standalone
    "no" yields ``'B'``. "yes" takes priority when both occur. If neither
    word occurs, the text is handed to ``first_option_postprocess`` with
    the options ``'AB'``.

    Matching is on standalone tokens rather than raw substrings. A plain
    ``in`` test treats "eyes" as "yes" and "not" / "know" / "cannot" as
    "no", which both flips answers and makes the option-letter fallback
    nearly unreachable.

    Args:
        text: Raw model output.

    Returns:
        ``'A'``, ``'B'``, or the result of ``first_option_postprocess``.
    """
    lowered = text.lower()
    # Letter-based lookarounds instead of \b, so digits, underscores and
    # non-Latin characters next to the answer do not block the match.
    if re.search(r'(?<![a-z])yes(?![a-z])', lowered):
        return 'A'
    if re.search(r'(?<![a-z])no(?![a-z])', lowered):
        return 'B'
    return first_option_postprocess(text, 'AB')