# opencompass/openicl/icl_inferencer/icl_ll_inferencer.py
"""PPL Inferencer.""" | |
import os
from typing import List, Optional

import torch
from tqdm import trange

from opencompass.models.base import BaseModel
from opencompass.registry import ICL_INFERENCERS

from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils import get_logger
from .icl_base_inferencer import BaseInferencer, dump_results_dict

logger = get_logger(__name__)


@ICL_INFERENCERS.register_module()
class LLInferencer(BaseInferencer):
"""Loglikelihood Inferencer class to evaluate by loglikelihood. | |
Attributes: | |
model (:obj:`BaseModel`, optional): The module to inference. | |
max_seq_len (:obj:`int`): Maximum number of tokenized words allowed by | |
the LM. | |
batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader` | |
output_json_filepath (:obj:`str`, optional): File path for output | |
`JSON` file. | |
output_json_filename (:obj:`str`, optional): File name for output | |
`JSON` file. | |
labels (:obj:`List`, optional): A list of labels for all classes. | |
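
    Example:
        A minimal usage sketch; ``my_model`` and ``retriever`` are
        placeholders for an already-constructed model and
        :obj:`BaseRetriever` (their setup depends on the dataset and prompt
        configuration, which is omitted here)::

            inferencer = LLInferencer(model=my_model,
                                      batch_size=8,
                                      labels=['A', 'B', 'C', 'D'])
            predictions = inferencer.inference(retriever)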
""" | |
def __init__( | |
self, | |
model: BaseModel, | |
max_seq_len: Optional[int] = None, | |
batch_size: Optional[int] = 1, | |
output_json_filepath: Optional[str] = './icl_inference_output', | |
output_json_filename: Optional[str] = 'predictions', | |
labels: Optional[List] = None, | |
**kwargs) -> None: | |
super().__init__( | |
model=model, | |
max_seq_len=max_seq_len, | |
batch_size=batch_size, | |
output_json_filename=output_json_filename, | |
output_json_filepath=output_json_filepath, | |
**kwargs, | |
) | |
self.labels = labels | |

    def inference(self,
                  retriever: BaseRetriever,
                  ice_template: Optional[PromptTemplate] = None,
                  prompt_template: Optional[PromptTemplate] = None,
                  output_json_filepath: Optional[str] = None,
                  output_json_filename: Optional[str] = None) -> List:
        # 1. Preparation for output logs
        output_handler = LLInferencerOutputHandler()

        sub_predictions = []
        ppl = []
        ice = []

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # 2. Get results of retrieval process
        ice_idx_list = retriever.retrieve()

        # 3. Get labels of all the classes
        if self.labels is None:
            labels = retriever.get_labels(ice_template=ice_template,
                                          prompt_template=prompt_template)
        else:
            labels = self.labels

        # 4. Generate in-context examples for testing inputs
        for idx in range(len(ice_idx_list)):
            ice.append(
                retriever.generate_ice(ice_idx_list[idx],
                                       ice_template=ice_template))
        output_handler.save_ice(self.model.parse_template(ice, mode='ppl'))
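
        # Each candidate label yields one prompt per test example; the model
        # scores the loglikelihood of that label's continuation ('cont'), and
        # the best-scoring label becomes the prediction in step 6.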

        # 5. Calculate loglikelihood for prompts in each label's class
        for label in labels:
            index = 0
            prompt_list = []
            sub_ppl_list = []
            token_num_list = []
            cont_list = []

            # 5.1 Generate prompts of current label and truncate
            # TODO: Refactor
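            # If a prompt is too long, in-context examples are dropped from
            # the end, one at a time, until it fits into max_seq_len.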
            for idx in range(len(ice_idx_list)):
                prompt = retriever.generate_label_prompt(
                    idx,
                    ice[idx],
                    label,
                    ice_template=ice_template,
                    prompt_template=prompt_template)
                # Computed up front so it is defined even when max_seq_len is
                # None and no truncation is performed.
                prompt_token_num = self.model.get_token_len_from_template(
                    prompt, mode='ppl')
                if self.max_seq_len is not None:
                    while len(ice_idx_list[idx]
                              ) > 0 and prompt_token_num > self.max_seq_len:
                        ice_idx_list[idx] = ice_idx_list[idx][:-1]
                        ice[idx] = retriever.generate_ice(
                            ice_idx_list[idx], ice_template=ice_template)
                        prompt = retriever.generate_label_prompt(
                            idx,
                            ice[idx],
                            label,
                            ice_template=ice_template,
                            prompt_template=prompt_template)
                        prompt_token_num = self.model.get_token_len_from_template(  # noqa
                            prompt, mode='ppl')  # noqa
                prompt_list.append(prompt)
                token_num_list.append(prompt_token_num)
                cont_list.append(retriever.test_ds[idx]['cont'])

            # 5.2 Get loglikelihood
            logger.info(
                f"Calculating Loglikelihood for prompts labeled '{label}'")
            for idx in trange(0,
                              len(prompt_list),
                              self.batch_size,
                              disable=not self.is_main_process):
                sub_prompt_list = prompt_list[idx:idx + self.batch_size]
                sub_cont_list = cont_list[idx:idx + self.batch_size]

                with torch.no_grad():
                    # Main modification compared to PPLInferencer: score the
                    # loglikelihood of each continuation given its prompt.
                    sub_inputs = self.model.parse_template(sub_prompt_list,
                                                           mode='ppl')
                    sub_res = self.model.get_loglikelihood(
                        sub_inputs, sub_cont_list).tolist()
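                # Pair each score with its parsed prompt so that both the full
                # prompt and the prompt minus in-context examples are logged.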
                for res, prompt in zip(
                        sub_res,
                        self.model.parse_template(sub_prompt_list,
                                                  mode='ppl')):
                    sub_ppl_list.append(res)
                    ice_str = self.model.parse_template(ice[idx], mode='ppl')
                    output_handler.save_prompt_and_loglikelihood(
                        label, prompt.replace(ice_str, ''), prompt, res, index)
                    index = index + 1
            ppl.append(sub_ppl_list)

        # 6. Get the class with the highest loglikelihood (i.e. the lowest
        # PPL) as the prediction for each test example
        # Transpose from per-label score lists to per-example score tuples
        ppl = list(zip(*ppl))
        for single_ppl in ppl:
            sub_predictions.append(labels[single_ppl.index(max(single_ppl))])
        output_handler.save_predictions(sub_predictions)

        # 7. Fetch gold answers if they exist
        ds_reader = retriever.dataset_reader
        if ds_reader.output_column:
            golds = ds_reader.dataset['test'][ds_reader.output_column]
            output_handler.save_golds(golds)

        # 8. Output
        if self.is_main_process:
            os.makedirs(output_json_filepath, exist_ok=True)
            output_handler.write_to_json(output_json_filepath,
                                         output_json_filename)
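
        # Predictions are returned in test-set order (results_dict holds one
        # entry per example index, created in order during inference).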
        return [
            sample['prediction']
            for sample in output_handler.results_dict.values()
        ]


class LLInferencerOutputHandler:
    results_dict = {}
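
    # results_dict maps each example index (as a string) to a record with its
    # in-context examples, per-label prompts and loglikelihoods, the predicted
    # label and, when available, the gold answer. Illustrative shape:
    # {
    #     "0": {
    #         "in-context examples": "...",
    #         "label: A": {"testing input": "...", "prompt": "...",
    #                      "Loglikelihood": -1.23},
    #         "prediction": "A",
    #         "gold": "A"
    #     }
    # }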

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the result to a json file."""
        dump_results_dict(self.results_dict, os.path.join(save_dir, filename))

    def save_ice(self, ice):
        for idx, example in enumerate(ice):
            if str(idx) not in self.results_dict.keys():
                self.results_dict[str(idx)] = {}
            self.results_dict[str(idx)]['in-context examples'] = example

    def save_predictions(self, predictions):
        for idx, prediction in enumerate(predictions):
            if str(idx) not in self.results_dict.keys():
                self.results_dict[str(idx)] = {}
            self.results_dict[str(idx)]['prediction'] = prediction

    def save_prompt_and_loglikelihood(self, label, input, prompt,
                                      loglikelihood, idx):
        if str(idx) not in self.results_dict.keys():
            self.results_dict[str(idx)] = {}
        if 'label: ' + str(label) not in self.results_dict[str(idx)].keys():
            self.results_dict[str(idx)]['label: ' + str(label)] = {}
        self.results_dict[str(idx)]['label: ' +
                                    str(label)]['testing input'] = input
        self.results_dict[str(idx)]['label: ' + str(label)]['prompt'] = prompt
        self.results_dict[str(idx)][
            'label: ' + str(label)]['Loglikelihood'] = loglikelihood

    def save_golds(self, golds):
        for idx, gold in enumerate(golds):
            if str(idx) not in self.results_dict.keys():
                self.results_dict[str(idx)] = {}
            self.results_dict[str(idx)]['gold'] = gold