"""CLP Inferencer."""

import itertools
import os
from typing import List, Optional

import torch.nn.functional as F
from tqdm import trange

from opencompass.models import BaseModel
from opencompass.registry import ICL_INFERENCERS

from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils import get_logger
from .icl_base_inferencer import BaseInferencer, CLPInferencerOutputHandler

logger = get_logger(__name__)


@ICL_INFERENCERS.register_module()
class CLPInferencer(BaseInferencer):
    """Conditional log probability based In-context Learning Inferencer.

    Calculate the log probability of each choices according the logits.
    The input is the context with single choice, e.g. Q: xx.\n A: first choice
    to this question.
    And starting from the first token of this choice, sum up all the log
    probabilities of each
    tokens from logits. Then, compare each choice with softmax.

    There are two scenarios in this case:
    1. Single token choices. Already supported.
    2. Muiltple token choices. TODO: More complicated and needs to be added in
       the future for specific dataset.

    Attributes:
        model (:obj:`BaseModel`, optional): The module to inference.
        max_seq_len (:obj:`int`): Maximum number of tokenized words allowed by
            the LM.
        batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`
        output_json_filepath (:obj:`str`, optional): File path for output
            `JSON` file.
        output_json_filename (:obj:`str`, optional): File name for output
            `JSON` file.
        single_token (:obj:`bool`): If ``True``, choices only have one token to
            calculate. Defaults to True. Currently only support True.
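
    Example:
        A minimal usage sketch. The ``model``, ``retriever`` and template
        objects below are assumed to be constructed elsewhere (e.g. from an
        OpenCompass config); the names are illustrative only::

            inferencer = CLPInferencer(model=model,
                                       max_seq_len=2048,
                                       batch_size=8)
            predictions = inferencer.inference(
                retriever,
                ice_template=ice_template,
                prompt_template=prompt_template)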
    """

    def __init__(
            self,
            model: BaseModel,
            max_seq_len: Optional[int] = None,
            batch_size: Optional[int] = 1,
            output_json_filepath: Optional[str] = './icl_inference_output',
            output_json_filename: Optional[str] = 'predictions',
            single_token: bool = True,
            **kwargs) -> None:
        super().__init__(
            model=model,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            output_json_filename=output_json_filename,
            output_json_filepath=output_json_filepath,
            **kwargs,
        )

        # TODO: support multi-token choices
        assert single_token, 'Only single-token choices are supported.'
        self.single_token = single_token

    def inference(self,
                  retriever: BaseRetriever,
                  ice_template: Optional[PromptTemplate] = None,
                  prompt_template: Optional[PromptTemplate] = None,
                  output_json_filepath: Optional[str] = None,
                  output_json_filename: Optional[str] = None,
                  normalizing_str: Optional[str] = None) -> List:
        # 1. Preparation for output logs
        output_handler = CLPInferencerOutputHandler()

        ice = []

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # CLP cannot run log-probability inference for API models unless the
        # model provides such an option, which needs a specific
        # implementation; open an issue if you encounter this case.
        if self.model.is_api:
            err_msg = ('API models are not supported for conditional log '
                       'probability inference; skipping this experiment.')
            # Write an error file so this experiment is not rerun endlessly
            # for this model.
            if self.is_main_process:
                os.makedirs(output_json_filepath, exist_ok=True)
                output_handler.results_dict = {'error': err_msg}
                output_handler.write_to_json(output_json_filepath,
                                             output_json_filename)
            raise ValueError(err_msg)

        # 2. Get results of retrieval process
        ice_idx_list = retriever.retrieve()

        # 3. Generate in-context examples for testing inputs
        for idx in range(len(ice_idx_list)):
            ice.append(
                retriever.generate_ice(ice_idx_list[idx],
                                       ice_template=ice_template))
        output_handler.save_ice(ice)

        # 4. Collect prompts and calculate conditional log probs
        if self.single_token:
            index = 0
            prompt_list = []
            target_pos = []
            # TODO: temporarily hard-coded (assumes all test examples share
            # the same choices); needs to be modified here
            choices = retriever.test_ds[0]['choices']
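            # Encode each choice to its token id(s). Tokenizer APIs differ:
            # some accept explicit bos/eos flags, while others add special
            # tokens by default, so the fallback branches below strip them.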
            try:
                choice_ids = [
                    self.model.tokenizer.encode(c, False, False)
                    for c in choices
                ]
            except ValueError:
                choice_ids = [self.model.tokenizer.encode(c) for c in choices]
                if self.model.tokenizer.__class__.__name__ == 'ChatGLMTokenizer':  # noqa
                    choice_ids = [c[2:] for c in choice_ids]
                elif hasattr(self.model.tokenizer, 'add_bos_token'):
                    if getattr(self.model.tokenizer, 'add_bos_token', False):
                        choice_ids = [c[1:] for c in choice_ids]
                    # use getattr: a tokenizer that defines add_bos_token
                    # does not necessarily define add_eos_token
                    if getattr(self.model.tokenizer, 'add_eos_token', False):
                        choice_ids = [c[:-1] for c in choice_ids]
            if isinstance(choice_ids[0], list):
                # in case tokenizer returns list for single token
                choice_ids = list(itertools.chain(*choice_ids))

            get_token_len = self.model.get_token_len

            if hasattr(self.model.tokenizer, 'padding_side'):
                # get padding_side for huggingface model
                padding_side = self.model.tokenizer.padding_side
            else:
                # defaults to left for internal model
                padding_side = 'left'

            # prepare the in-context examples for each test input and control
            # the prompt length
            for idx in range(len(ice_idx_list)):
                prompt = retriever.generate_prompt_for_generate_task(
                    idx,
                    ice[idx],
                    ice_template=ice_template,
                    prompt_template=prompt_template)
                prompt = self.model.parse_template(prompt, mode='gen')
                if self.max_seq_len is not None:
                    prompt_token_num = get_token_len(prompt)
                    # add one because an additional token will be appended at
                    # the end
                    while len(
                            ice_idx_list[idx]
                    ) > 0 and prompt_token_num + 1 > self.max_seq_len:
                        ice_idx_list[idx] = ice_idx_list[idx][:-1]
                        ice[idx] = retriever.generate_ice(
                            ice_idx_list[idx], ice_template=ice_template)
                        prompt = retriever.generate_prompt_for_generate_task(
                            idx,
                            ice[idx],
                            ice_template=ice_template,
                            prompt_template=prompt_template)
                        prompt_token_num = get_token_len(prompt)
                prompt_list.append(prompt)
                # in case prompt token num reaches max
                if self.max_seq_len is not None and \
                        prompt_token_num + 1 > self.max_seq_len:
                    prompt_token_num = self.max_seq_len - 1

                # get the target position index
                if padding_side == 'left':
                    # always the last position
                    target_pos.append(-1)
                else:
                    # the last position of the original prompt
                    target_pos.append(prompt_token_num - 1)

            # 4.1 Fetch and zip prompt & gold answer if output column exists
            ds_reader = retriever.dataset_reader
            if ds_reader.output_column:
                gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
            else:
                gold_ans = [None] * len(prompt_list)

            if hasattr(self.model, 'batch_padding'):
                # get batch padding for huggingface model
                batch_padding = self.model.batch_padding
            else:
                # defaults to False for internal model
                batch_padding = False

            logger.info('Calculating conditional log probability for prompts.')
            for idx in trange(0,
                              len(prompt_list),
                              self.batch_size,
                              disable=not self.is_main_process):
                # get batch data
                sub_prompt_list = prompt_list[idx:idx + self.batch_size]
                sub_golds = gold_ans[idx:idx + self.batch_size]
                sub_target_pos = target_pos[idx:idx + self.batch_size]

                # get probability result
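                # if the model supports batch padding, score the whole
                # mini-batch in a single forward pass; otherwise score each
                # prompt individually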
                if batch_padding and self.batch_size > 1:
                    sub_res = self._get_cond_prob(sub_prompt_list,
                                                  sub_target_pos, choice_ids)
                else:
                    sub_res = []
                    for prompt, position in zip(sub_prompt_list,
                                                sub_target_pos):
                        sub_res.extend(
                            self._get_cond_prob([prompt], [position],
                                                choice_ids))

                # save all the results
                for res, prompt, gold in zip(sub_res, sub_prompt_list,
                                             sub_golds):
                    # strip this example's in-context examples to recover the
                    # bare input; use the per-example counter `index`, not the
                    # batch offset `idx`
                    example_input = prompt.replace(ice[index], '')
                    output_handler.save_prompt_and_condprob(example_input,
                                                            prompt,
                                                            res,
                                                            index,
                                                            choices,
                                                            gold=gold)
                    index = index + 1

        # 5. Output
        if self.is_main_process:
            os.makedirs(output_json_filepath, exist_ok=True)
            output_handler.write_to_json(output_json_filepath,
                                         output_json_filename)

        return [
            sample['prediction']
            for sample in output_handler.results_dict.values()
        ]

    def _get_cond_prob(self, input_texts: List[str], target_pos: List[int],
                       choice_ids: List[int]):
        """Get the condition probability of next token.

        Args:
            input_texts (List[str]): All the input prompt to be tested.
            target_pos (List[int]): Target position of next token.
            choice_ids (List[int]): Choice ids of target tokens.
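
        Returns:
            List[List[float]]: For each input text, the probabilities of the
                candidate choices, renormalised with a softmax over the
                choice logits at the target position.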
        """
        if hasattr(self.model, 'generator'):
            get_logits = self.model.generator.get_logits
        else:
            get_logits = self.model.get_logits

        outputs, _ = get_logits(input_texts)

        # we want the next-token probability,
        # therefore no shift here
        logits = outputs.contiguous().float()

        logits = F.log_softmax(logits, dim=-1)
        log_probs = []
        for logit, position in zip(logits, target_pos):
            # keep only the logits of the candidate choices at the target
            # position and renormalise over the choices
            log_probs.append(
                F.softmax(logit[position, choice_ids], dim=-1).tolist())
        return log_probs