"""PPL Inferencer."""

import os
from typing import List, Optional

import torch
from tqdm import trange

from opencompass.models.base import BaseModel
from opencompass.registry import ICL_INFERENCERS

from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils import get_logger
from .icl_base_inferencer import BaseInferencer, PPLInferencerOutputHandler

logger = get_logger(__name__)


@ICL_INFERENCERS.register_module()
class PPLInferencer(BaseInferencer):
    """PPL Inferencer class to evaluate by perplexity.

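    For each test example a prompt is built for every candidate label, and
    the label whose prompt yields the lowest perplexity is returned as the
    prediction.
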
    Attributes:
        model (:obj:`BaseModel`, optional): The model to perform inference
            with.
        max_seq_len (:obj:`int`): Maximum number of tokens allowed by the LM.
        batch_size (:obj:`int`, optional): Batch size for the
            :obj:`DataLoader`.
        output_json_filepath (:obj:`str`, optional): File path for the output
            `JSON` file.
        output_json_filename (:obj:`str`, optional): File name for the output
            `JSON` file.
        labels (:obj:`List`, optional): A list of labels for all classes.
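
    Example (sketch, assuming ``model``, ``retriever`` and ``ice_template``
    have already been built elsewhere)::

        >>> inferencer = PPLInferencer(model=model, batch_size=8)
        >>> predictions = inferencer.inference(retriever,
        ...                                    ice_template=ice_template)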
    """

    def __init__(
            self,
            model: BaseModel,
            max_seq_len: Optional[int] = None,
            batch_size: Optional[int] = 1,
            output_json_filepath: Optional[str] = './icl_inference_output',
            output_json_filename: Optional[str] = 'predictions',
            labels: Optional[List] = None,
            **kwargs) -> None:
        super().__init__(
            model=model,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            output_json_filename=output_json_filename,
            output_json_filepath=output_json_filepath,
            **kwargs,
        )

        self.labels = labels

    def inference(self,
                  retriever: BaseRetriever,
                  ice_template: Optional[PromptTemplate] = None,
                  prompt_template: Optional[PromptTemplate] = None,
                  output_json_filepath: Optional[str] = None,
                  output_json_filename: Optional[str] = None,
                  normalizing_str: Optional[str] = None) -> List:
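        """Run perplexity-based inference for every sample in the test set.

        Args:
            retriever (:obj:`BaseRetriever`): Retriever used to fetch
                in-context examples and candidate labels.
            ice_template (:obj:`PromptTemplate`, optional): Template for the
                in-context examples.
            prompt_template (:obj:`PromptTemplate`, optional): Template for
                the full prompt.
            output_json_filepath (:obj:`str`, optional): Overrides the
                default output file path.
            output_json_filename (:obj:`str`, optional): Overrides the
                default output file name.
            normalizing_str (:obj:`str`, optional): If given, each label is
                scored by the PPL of its answer conditioned on the real
                context minus the PPL of the same answer conditioned on this
                normalizing string.

        Returns:
            List: The predicted label for each test sample.
        """
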
        # 1. Preparation for output logs
        output_handler = PPLInferencerOutputHandler()

        sub_predictions = []
        ppl = []
        ice = []

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # 2. Get results of retrieval process
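        # `ice_idx_list[i]` holds the indices of the in-context examples
        # retrieved for the i-th test sample.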
        ice_idx_list = retriever.retrieve()

        # 3. Get labels of all the classes
        if self.labels is None:
            labels = retriever.get_labels(ice_template=ice_template,
                                          prompt_template=prompt_template)
        else:
            labels = self.labels

        # 4. Generate in-context examples for testing inputs
        for idx in range(len(ice_idx_list)):
            ice.append(
                retriever.generate_ice(ice_idx_list[idx],
                                       ice_template=ice_template))
        output_handler.save_ice(self.model.parse_template(ice, mode='ppl'))

        # 5. Calculating PPL for prompts in each label's class
        for label in labels:
            index = 0
            prompt_list = []
            sub_ppl_list = []
            token_num_list = []
            normalizing_prompt_list = []
            context_length_list = []

            # 5.1 Generate prompts of current label and truncate
            # TODO: Refactor
            for idx in range(len(ice_idx_list)):
                prompt = retriever.generate_label_prompt(
                    idx,
                    ice[idx],
                    label,
                    ice_template=ice_template,
                    prompt_template=prompt_template,
                    remain_sep=normalizing_str is not None)
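                # Count the prompt's tokens and, if it exceeds `max_seq_len`,
                # drop in-context examples from the end until it fits.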
                prompt_token_num = self.model.get_token_len_from_template(
                    prompt, mode='ppl')
                if self.max_seq_len is not None:
                    while len(ice_idx_list[idx]) > 0 and \
                            prompt_token_num > self.max_seq_len:
                        ice_idx_list[idx] = ice_idx_list[idx][:-1]
                        ice[idx] = retriever.generate_ice(
                            ice_idx_list[idx], ice_template=ice_template)
                        prompt = retriever.generate_label_prompt(
                            idx,
                            ice[idx],
                            label,
                            ice_template=ice_template,
                            prompt_template=prompt_template)
                        prompt_token_num = self.model.get_token_len_from_template(  # noqa
                            prompt, mode='ppl')  # noqa

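                # Split the prompt at the separator token into context and
                # answer, then build a second prompt in which the real
                # context is replaced by `normalizing_str`; the two are
                # compared later to normalize the answer's score.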
                if normalizing_str is not None:
                    assert isinstance(prompt, str), \
                         'Prompt must be a string when normalizing_str is set.'
                    prompt_sep = prompt
                    if prompt_template is not None:
                        sep_token = prompt_template.sep_token
                    else:
                        sep_token = ice_template.sep_token
                    sep_pos = prompt_sep.find(sep_token)

                    context = prompt_sep[0:sep_pos]
                    answer = prompt_sep[sep_pos:].replace(sep_token, '')
                    prompt = context + answer
                    normalizing_prompt = normalizing_str + answer

                    context_length_list.append(
                        self.model.get_token_len_from_template(context,
                                                               mode='ppl'))
                    normalizing_prompt_list.append(normalizing_prompt)
                prompt_list.append(prompt)
                token_num_list.append(prompt_token_num)

            if normalizing_str is not None:
                normalizing_str_len = self.model.get_token_len_from_template(
                    normalizing_str, mode='ppl')

            # 5.2 Get PPL
            logger.info(f"Calculating PPL for prompts labeled '{label}'")
            for idx in trange(0,
                              len(prompt_list),
                              self.batch_size,
                              disable=not self.is_main_process):
                sub_prompt_list = prompt_list[idx:idx + self.batch_size]
                if normalizing_str is not None:
                    sub_context_length_list = context_length_list[
                        idx:idx + self.batch_size]
                    sub_normalizing_prompt_list = normalizing_prompt_list[
                        idx:idx + self.batch_size]

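                # With a normalizing string, each candidate is scored by the
                # PPL of the answer given its real context minus the PPL of
                # the answer given the normalizing string; `mask_length`
                # excludes the context tokens from the loss.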
                with torch.no_grad():
                    if normalizing_str is not None:
                        res1 = self.model.get_ppl_from_template(
                            sub_prompt_list,
                            mask_length=sub_context_length_list)
                        res2 = self.model.get_ppl_from_template(
                            sub_normalizing_prompt_list,
                            mask_length=[
                                normalizing_str_len
                                for i in range(len(sub_prompt_list))
                            ])
                        sub_res = res1 - res2
                    else:
                        sub_res = self.model.get_ppl_from_template(
                            sub_prompt_list).tolist()
                for res, prompt in zip(
                        sub_res,
                        self.model.parse_template(sub_prompt_list,
                                                  mode='ppl')):
                    sub_ppl_list.append(res)
                    ice_str = self.model.parse_template(ice[idx], mode='ppl')
                    output_handler.save_prompt_and_ppl(
                        label, prompt.replace(ice_str, ''), prompt, res, index)
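                    # Record a byte-normalised score under 'BPB': the
                    # per-token loss scaled by the token count and divided by
                    # the UTF-8 byte length of the prompt with in-context
                    # examples stripped.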
                    output_handler.results_dict[str(index)][
                        f'label: {str(label)}']['BPB'] = (
                            res * token_num_list[index] /
                            len(prompt.replace(ice_str, '').encode()))
                    index = index + 1
            ppl.append(sub_ppl_list)

        # 6. Get lowest PPL class as predictions
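        # `ppl` is indexed as [label][sample]; transpose it so that, for each
        # sample, the label whose prompt scored the lowest perplexity wins.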
        ppl = list(zip(*ppl))
        for single_ppl in ppl:
            sub_predictions.append(labels[single_ppl.index(min(single_ppl))])
        output_handler.save_predictions(sub_predictions)

        # 7. Fetch gold answers if exist
        ds_reader = retriever.dataset_reader
        if ds_reader.output_column:
            golds = ds_reader.dataset['test'][ds_reader.output_column]
            output_handler.save_golds(golds)

        # 8. Output
        if self.is_main_process:
            os.makedirs(output_json_filepath, exist_ok=True)
            output_handler.write_to_json(output_json_filepath,
                                         output_json_filename)

        return [
            sample['prediction']
            for sample in output_handler.results_dict.values()
        ]