TwT-6's picture
Upload 2667 files
256a159 verified
"""CLP Inferencer."""
import itertools
import os
from typing import List, Optional
import torch.nn.functional as F
from tqdm import trange
from opencompass.models import BaseModel
from opencompass.registry import ICL_INFERENCERS
from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils import get_logger
from .icl_base_inferencer import BaseInferencer, CLPInferencerOutputHandler
# Module-level logger used to report inference progress.
logger = get_logger(__name__)
@ICL_INFERENCERS.register_module()
class CLPInferencer(BaseInferencer):
    """Conditional log probability based In-context Learning Inferencer.

    Calculate the log probability of each choice according to the logits.
    The input is the context with a single choice, e.g.
    ``Q: xx.\\n A: first choice to this question.``
    Starting from the first token of this choice, sum up the log
    probabilities of each token from the logits. Then compare the choices
    with softmax.

    There are two scenarios in this case:
    1. Single token choices. Already supported.
    2. Multiple token choices. TODO: More complicated and needs to be added
       in the future for specific datasets.

    Attributes:
        model (:obj:`BaseModel`, optional): The module to inference.
        max_seq_len (:obj:`int`): Maximum number of tokenized words allowed
            by the LM.
        batch_size (:obj:`int`, optional): Batch size for the
            :obj:`DataLoader`.
        output_json_filepath (:obj:`str`, optional): File path for output
            `JSON` file.
        output_json_filename (:obj:`str`, optional): File name for output
            `JSON` file.
        single_token (:obj:`bool`): If ``True``, choices only have one token
            to calculate. Defaults to True. Currently only support True.
    """

    def __init__(
            self,
            model: BaseModel,
            max_seq_len: Optional[int] = None,
            batch_size: Optional[int] = 1,
            output_json_filepath: Optional[str] = './icl_inference_output',
            output_json_filename: Optional[str] = 'predictions',
            single_token: bool = True,
            **kwargs) -> None:
        super().__init__(
            model=model,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            output_json_filename=output_json_filename,
            output_json_filepath=output_json_filepath,
            **kwargs,
        )

        # TODO: support multiple token
        assert single_token, 'Only support single token choice currently.'
        self.single_token = single_token

    def inference(self,
                  retriever: BaseRetriever,
                  ice_template: Optional[PromptTemplate] = None,
                  prompt_template: Optional[PromptTemplate] = None,
                  output_json_filepath: Optional[str] = None,
                  output_json_filename: Optional[str] = None,
                  normalizing_str: Optional[str] = None) -> List:
        """Score every test prompt against the candidate choices.

        Args:
            retriever (BaseRetriever): Supplies the in-context example
                indices, the in-context examples and the prompts.
            ice_template (PromptTemplate, optional): Template forwarded to
                the retriever when building in-context examples.
            prompt_template (PromptTemplate, optional): Template forwarded
                to the retriever when building prompts.
            output_json_filepath (str, optional): Output directory. Falls
                back to ``self.output_json_filepath`` when ``None``.
            output_json_filename (str, optional): Output file name. Falls
                back to ``self.output_json_filename`` when ``None``.
            normalizing_str (str, optional): Not used by this inferencer.

        Returns:
            List: The ``'prediction'`` entry of every sample recorded by the
            output handler.

        Raises:
            ValueError: If ``self.model.is_api`` is true, because API models
                are not supported for this kind of inference.
        """
        # 1. Preparation for output logs
        output_handler = CLPInferencerOutputHandler()

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # CLP cannot infer with log probability for api models
        # unless model provided such options which needs specific
        # implementation, open an issue if you encounter the case.
        if self.model.is_api:
            # Built before the main-process check so that every process can
            # raise the same error, not only the one writing the file.
            err_msg = 'API model is not supported for conditional log '\
                'probability inference and skip this exp.'
            # Write empty file in case always rerun for this model
            if self.is_main_process:
                os.makedirs(output_json_filepath, exist_ok=True)
                output_handler.results_dict = {'error': err_msg}
                output_handler.write_to_json(output_json_filepath,
                                             output_json_filename)
            raise ValueError(err_msg)

        # 2. Get results of retrieval process
        ice_idx_list = retriever.retrieve()

        # 3. Generate in-context examples for testing inputs
        ice = [
            retriever.generate_ice(ice_idx, ice_template=ice_template)
            for ice_idx in ice_idx_list
        ]
        output_handler.save_ice(ice)

        # 4. Collect prompts and calculate conditional log probs
        if self.single_token:
            index = 0
            prompt_list = []
            target_pos = []
            # TODO: Hard coded temporarily, needs to be modified here
            choices = retriever.test_ds[0]['choices']
            choice_ids = self._get_choice_ids(choices)

            get_token_len = self.model.get_token_len

            # Use the tokenizer's padding side when it declares one
            # (huggingface models); otherwise default to left padding.
            padding_side = getattr(self.model.tokenizer, 'padding_side',
                                   'left')

            # prepare in context for each example and control the length
            for idx in range(len(ice_idx_list)):
                prompt = retriever.generate_prompt_for_generate_task(
                    idx,
                    ice[idx],
                    ice_template=ice_template,
                    prompt_template=prompt_template)
                prompt = self.model.parse_template(prompt, mode='gen')

                # Stays None unless a length is actually needed below.
                prompt_token_num = None
                if self.max_seq_len is not None:
                    prompt_token_num = get_token_len(prompt)
                    # add one because additional token will be added in the
                    # end; drop in-context examples until the prompt fits
                    while (len(ice_idx_list[idx]) > 0
                           and prompt_token_num + 1 > self.max_seq_len):
                        ice_idx_list[idx] = ice_idx_list[idx][:-1]
                        ice[idx] = retriever.generate_ice(
                            ice_idx_list[idx], ice_template=ice_template)
                        prompt = retriever.generate_prompt_for_generate_task(
                            idx,
                            ice[idx],
                            ice_template=ice_template,
                            prompt_template=prompt_template)
                        # Parse the rebuilt prompt the same way as the
                        # initial one so that the measured and stored prompt
                        # is always in parsed form.
                        prompt = self.model.parse_template(prompt,
                                                           mode='gen')
                        prompt_token_num = get_token_len(prompt)
                    # in case prompt token num still reaches max
                    if prompt_token_num + 1 > self.max_seq_len:
                        prompt_token_num = self.max_seq_len - 1
                prompt_list.append(prompt)

                # get the target position index
                if padding_side == 'left':
                    # always the last position
                    target_pos.append(-1)
                else:
                    # the last position of the original prompt; the length
                    # has not been measured yet when max_seq_len is None
                    if prompt_token_num is None:
                        prompt_token_num = get_token_len(prompt)
                    target_pos.append(prompt_token_num - 1)

            # 4.1 Fetch and zip prompt & gold answer if output column exists
            ds_reader = retriever.dataset_reader
            if ds_reader.output_column:
                gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
            else:
                gold_ans = [None] * len(prompt_list)

            # Use the model's batch_padding flag when it has one
            # (huggingface models); otherwise default to False.
            batch_padding = getattr(self.model, 'batch_padding', False)

            logger.info('Calculating conditional log probability for prompts.')
            for idx in trange(0,
                              len(prompt_list),
                              self.batch_size,
                              disable=not self.is_main_process):
                # get batch data
                sub_prompt_list = prompt_list[idx:idx + self.batch_size]
                sub_golds = gold_ans[idx:idx + self.batch_size]
                sub_target_pos = target_pos[idx:idx + self.batch_size]

                # get probability result
                if batch_padding and self.batch_size > 1:
                    sub_res = self._get_cond_prob(sub_prompt_list,
                                                  sub_target_pos, choice_ids)
                else:
                    # score prompts one at a time
                    sub_res = []
                    for prompt, position in zip(sub_prompt_list,
                                                sub_target_pos):
                        sub_res.extend(
                            self._get_cond_prob([prompt], [position],
                                                choice_ids))

                # save all the result
                for res, prompt, gold in zip(sub_res, sub_prompt_list,
                                             sub_golds):
                    # `index` is the absolute sample position, so it selects
                    # the in-context example belonging to this prompt (the
                    # batch start `idx` is only right for the first sample).
                    example_input = prompt.replace(ice[index], '')
                    output_handler.save_prompt_and_condprob(example_input,
                                                            prompt,
                                                            res,
                                                            index,
                                                            choices,
                                                            gold=gold)
                    index = index + 1

        # 5. Output
        if self.is_main_process:
            os.makedirs(output_json_filepath, exist_ok=True)
            output_handler.write_to_json(output_json_filepath,
                                         output_json_filename)

        return [
            sample['prediction']
            for sample in output_handler.results_dict.values()
        ]

    def _get_choice_ids(self, choices: List[str]) -> List[int]:
        """Encode every choice string into token ids as one flat list.

        Args:
            choices (List[str]): Candidate answers to encode.

        Returns:
            List[int]: Token ids of all choices, flattened when the
            tokenizer returns one list per choice.
        """
        tokenizer = self.model.tokenizer
        try:
            # NOTE(review): the two positional ``False`` flags presumably
            # disable bos/eos tokens on internal tokenizers -- confirm.
            choice_ids = [tokenizer.encode(c, False, False) for c in choices]
        except ValueError:
            # Fall back to the default encode call and strip the extra
            # leading/trailing ids afterwards.
            choice_ids = [tokenizer.encode(c) for c in choices]
            if tokenizer.__class__.__name__ == 'ChatGLMTokenizer':  # noqa
                # presumably two special prefix ids -- TODO confirm
                choice_ids = [c[2:] for c in choice_ids]
            elif hasattr(tokenizer, 'add_bos_token'):
                if tokenizer.add_bos_token:
                    choice_ids = [c[1:] for c in choice_ids]
                # Not every tokenizer with `add_bos_token` also defines
                # `add_eos_token`, so read it defensively.
                if getattr(tokenizer, 'add_eos_token', False):
                    choice_ids = [c[:-1] for c in choice_ids]
        if choice_ids and isinstance(choice_ids[0], list):
            # in case tokenizer returns list for single token
            choice_ids = list(itertools.chain.from_iterable(choice_ids))
        return choice_ids

    def _get_cond_prob(self, input_texts: List[str], target_pos: List[int],
                       choice_ids: List[int]):
        """Get the conditional probability of the next token.

        Args:
            input_texts (List[str]): All the input prompts to be tested.
            target_pos (List[int]): Target position of the next token, one
                per input.
            choice_ids (List[int]): Choice ids of target tokens.

        Returns:
            List[List[float]]: For every input, the softmax over the
            log-probabilities of ``choice_ids`` at its target position,
            i.e. probabilities renormalized across the choices.
        """
        # Some models expose `get_logits` on a `generator` attribute.
        if hasattr(self.model, 'generator'):
            get_logits = self.model.generator.get_logits
        else:
            get_logits = self.model.get_logits

        outputs, _ = get_logits(input_texts)

        # we want get the next token probability
        # therefore no shift here
        logits = outputs.contiguous().float()

        logits = F.log_softmax(logits, dim=-1)
        choice_probs = []
        for logit, position in zip(logits, target_pos):
            # Each per-input slice is indexed as [position, token id].
            choice_probs.append(
                F.softmax(logit[position, choice_ids], dim=-1).tolist())
        return choice_probs