# opencompass/openicl/icl_inferencer/icl_base_inferencer.py
"""Basic Inferencer.""" | |
import json | |
import os | |
from pathlib import Path | |
from typing import List, Optional | |
import numpy as np | |
from mmengine.dist import is_main_process | |
from torch.utils.data import DataLoader | |
from ..icl_prompt_template import PromptTemplate | |
from ..icl_retriever import BaseRetriever | |


class BaseInferencer:
    """Base Inferencer class for all evaluation Inferencers.

    Attributes:
        model (:obj:`BaseModel`, optional): The model to run inference with.
        max_seq_len (:obj:`int`, optional): Maximum number of tokenized
            words allowed by the LM.
        batch_size (:obj:`int`, optional): Batch size for the
            :obj:`DataLoader`.
        output_json_filepath (:obj:`str`, optional): Directory path for the
            output `JSON` file.
        output_json_filename (:obj:`str`, optional): File name for the
            output `JSON` file.
    """
    model = None

    def __init__(
        self,
        model,
        max_seq_len: Optional[int] = None,
        batch_size: Optional[int] = 1,
        output_json_filepath: Optional[str] = './icl_inference_output',
        output_json_filename: Optional[str] = 'predictions',
        fix_id_list: Optional[List[int]] = None,
        **kwargs,
    ) -> None:
        if fix_id_list:
            raise ValueError('Passing fix_id_list to Inferencer is no longer '
                             'allowed. Please pass it to FixKRetriever '
                             'instead.')

        self.model = model
        self.max_seq_len = max_seq_len
        self.batch_size = batch_size
        self.output_json_filepath = output_json_filepath
        self.output_json_filename = output_json_filename
        self.is_main_process = is_main_process()
        os.makedirs(self.output_json_filepath, exist_ok=True)

    def inference(self,
                  retriever: BaseRetriever,
                  ice_template: Optional[PromptTemplate] = None,
                  prompt_template: Optional[PromptTemplate] = None,
                  output_json_filepath: Optional[str] = None,
                  output_json_filename: Optional[str] = None) -> List:
        """Perform in-context inference given a retriever and optional
        templates.

        Args:
            retriever (:obj:`BaseRetriever`): An instance of a Retriever
                class that will be used to retrieve in-context examples.
            ice_template (:obj:`PromptTemplate`, optional): A template for
                generating the in-context example prompts. Defaults to None.
            prompt_template (:obj:`PromptTemplate`, optional): A template
                for generating the final prompt. Defaults to None.
            output_json_filepath (:obj:`str`, optional): The file path to
                save the results as a `JSON` file. Defaults to None.
            output_json_filename (:obj:`str`, optional): The file name to
                save the results as a `JSON` file. Defaults to None.

        Raises:
            NotImplementedError: If the method is not implemented in the
                subclass.

        Returns:
            :obj:`List`: A list of strings, each representing the result of
            one inference.
        """
        raise NotImplementedError("Method hasn't been implemented yet")
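
    # Usage sketch (illustrative, not part of the module): a minimal
    # subclass only has to override ``inference``. ``DummyInferencer`` is a
    # hypothetical toy, assuming ``retriever.retrieve()`` yields one entry
    # per test item, as the repo's retrievers do.
    #
    #     class DummyInferencer(BaseInferencer):
    #         def inference(self, retriever, ice_template=None,
    #                       prompt_template=None, output_json_filepath=None,
    #                       output_json_filename=None):
    #             # Return a constant answer for every test item.
    #             n_items = len(retriever.retrieve())
    #             return ['dummy answer'] * n_items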

    @staticmethod
    def get_dataloader(datalist: List[List], batch_size: int) -> DataLoader:
        """Return a dataloader over the input data list."""
        # The identity collate_fn keeps each batch as a plain list instead
        # of letting PyTorch try to stack the (string) prompts into tensors.
        dataloader = DataLoader(datalist,
                                batch_size=batch_size,
                                collate_fn=lambda x: x)
        return dataloader
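
    # Usage sketch (illustrative): batch a list of prompt strings. Thanks
    # to the identity collate_fn, each batch stays a plain list of strings.
    #
    #     loader = BaseInferencer.get_dataloader(['p1', 'p2', 'p3'],
    #                                            batch_size=2)
    #     list(loader)  # -> [['p1', 'p2'], ['p3']]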


def dump_results_dict(results_dict, filename):
    """Serialize a results dict to a human-readable JSON file."""
    with open(filename, 'w', encoding='utf-8') as json_file:
        json.dump(results_dict, json_file, indent=4, ensure_ascii=False)


class GenInferencerOutputHandler:
    """Accumulates generation results and writes them to a JSON file."""

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the results to a json file."""
        dump_results_dict(self.results_dict, Path(save_dir) / filename)

    def save_results(self, origin_prompt, prediction, idx, gold=None):
        self.results_dict[str(idx)] = {
            'origin_prompt': origin_prompt,
            'prediction': prediction,
        }
        # Use an explicit None check so falsy golds (e.g. 0 or '') are kept.
        if gold is not None:
            self.results_dict[str(idx)]['gold'] = gold
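
# Usage sketch (illustrative): record one generation result per test index,
# then dump everything to ``<save_dir>/<filename>``.
#
#     handler = GenInferencerOutputHandler()
#     handler.save_results(origin_prompt='Q: 1+1=?',
#                          prediction='2', idx=0, gold='2')
#     handler.write_to_json('./icl_inference_output', 'predictions')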


class PPLInferencerOutputHandler:
    """Accumulates per-label perplexity results and writes them to JSON."""

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the results to a json file."""
        dump_results_dict(self.results_dict, Path(save_dir) / filename)

    def save_ice(self, ice):
        for idx, example in enumerate(ice):
            self.results_dict.setdefault(str(idx), {})
            self.results_dict[str(idx)]['in-context examples'] = example

    def save_predictions(self, predictions):
        for idx, prediction in enumerate(predictions):
            self.results_dict.setdefault(str(idx), {})
            self.results_dict[str(idx)]['prediction'] = prediction

    def save_prompt_and_ppl(self, label, input, prompt, ppl, idx):
        entry = self.results_dict.setdefault(str(idx), {})
        label_entry = entry.setdefault('label: ' + str(label), {})
        label_entry['testing input'] = input
        label_entry['prompt'] = prompt
        label_entry['PPL'] = ppl

    def save_golds(self, golds):
        for idx, gold in enumerate(golds):
            self.results_dict.setdefault(str(idx), {})
            self.results_dict[str(idx)]['gold'] = gold
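
# Usage sketch (illustrative): in PPL-based evaluation, each candidate label
# gets its own prompt and perplexity, and the inferencer predicts the label
# with the lowest PPL. The values below are made up.
#
#     handler = PPLInferencerOutputHandler()
#     handler.save_prompt_and_ppl(label='A', input='...', prompt='...',
#                                 ppl=3.2, idx=0)
#     handler.save_prompt_and_ppl(label='B', input='...', prompt='...',
#                                 ppl=5.7, idx=0)
#     handler.save_predictions(['A'])  # argmin over the stored PPLs
#     handler.write_to_json('./icl_inference_output', 'predictions')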


class CLPInferencerOutputHandler:
    """Accumulates conditional-log-probability results and writes them to
    JSON."""

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the results to a json file."""
        dump_results_dict(self.results_dict, Path(save_dir) / filename)

    def save_ice(self, ice):
        for idx, example in enumerate(ice):
            self.results_dict.setdefault(str(idx), {})
            self.results_dict[str(idx)]['in-context examples'] = example

    def save_prompt_and_condprob(self,
                                 input,
                                 prompt,
                                 cond_prob,
                                 idx,
                                 choices,
                                 gold=None):
        entry = self.results_dict.setdefault(str(idx), {})
        # TODO: in the single-token situation, the input is currently
        # always 'yes'.
        entry['testing input'] = input
        entry['prompt'] = prompt
        # TODO: choices are hard-coded here.
        entry['choices'] = choices
        # Store the scores as the prediction so AUC can be computed
        # downstream.
        entry['prediction'] = cond_prob
        # Also store the argmax label in case a hard prediction is needed.
        entry['pred_label'] = int(np.argmax(cond_prob))
        entry['gold'] = gold
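
# Usage sketch (illustrative): store per-choice conditional probabilities;
# ``pred_label`` is the argmax over them. The values below are made up.
#
#     handler = CLPInferencerOutputHandler()
#     handler.save_prompt_and_condprob(input='...', prompt='...',
#                                      cond_prob=[0.1, 0.7, 0.2], idx=0,
#                                      choices=['A', 'B', 'C'], gold='B')
#     handler.results_dict['0']['pred_label']  # -> 1, i.e. choice 'B'
#     handler.write_to_json('./icl_inference_output', 'predictions')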