|
"""Loglikelihood (LL) Inferencer."""
|
|
|
import os |
|
from typing import List, Optional |
|
|
|
import torch |
|
from tqdm import trange |
|
|
|
from opencompass.models.base import BaseModel |
|
from opencompass.registry import ICL_INFERENCERS |
|
|
|
from ..icl_prompt_template import PromptTemplate |
|
from ..icl_retriever import BaseRetriever |
|
from ..utils import get_logger |
|
from .icl_base_inferencer import BaseInferencer, dump_results_dict |
|
|
|
# Module-level logger; LLInferencer.inference uses it for progress messages.
logger = get_logger(__name__)
|
|
|
|
|
@ICL_INFERENCERS.register_module()
class LLInferencer(BaseInferencer):
    """Loglikelihood Inferencer class to evaluate by loglikelihood.

    For every test sample, one prompt is rendered per candidate label and
    scored with ``model.get_loglikelihood`` together with the sample's
    ``cont`` field. The label whose prompt gets the highest score is the
    prediction for that sample.

    Attributes:
        model (:obj:`BaseModel`, optional): The module to inference.
        max_seq_len (:obj:`int`): Maximum number of tokenized words allowed by
            the LM. When set, trailing in-context examples are dropped from a
            sample until its prompt fits.
        batch_size (:obj:`int`, optional): Number of prompts scored per
            ``get_loglikelihood`` call.
        output_json_filepath (:obj:`str`, optional): File path for output
            `JSON` file.
        output_json_filename (:obj:`str`, optional): File name for output
            `JSON` file.
        labels (:obj:`List`, optional): A list of labels for all classes.
            When ``None``, labels are obtained from ``retriever.get_labels``.
    """

    def __init__(
            self,
            model: BaseModel,
            max_seq_len: Optional[int] = None,
            batch_size: Optional[int] = 1,
            output_json_filepath: Optional[str] = './icl_inference_output',
            output_json_filename: Optional[str] = 'predictions',
            labels: Optional[List] = None,
            **kwargs) -> None:
        super().__init__(
            model=model,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            output_json_filename=output_json_filename,
            output_json_filepath=output_json_filepath,
            **kwargs,
        )

        self.labels = labels

    def inference(self,
                  retriever: BaseRetriever,
                  ice_template: Optional[PromptTemplate] = None,
                  prompt_template: Optional[PromptTemplate] = None,
                  output_json_filepath: Optional[str] = None,
                  output_json_filename: Optional[str] = None) -> List:
        """Predict a label for every test sample by loglikelihood.

        Args:
            retriever: Supplies the in-context example indices, the test set
                (read through ``retriever.test_ds[idx]['cont']``), the dataset
                reader used for gold answers and, when ``self.labels`` is
                ``None``, the candidate labels.
            ice_template: Template used to render in-context examples.
            prompt_template: Template used to render each label prompt.
            output_json_filepath: Output directory; defaults to
                ``self.output_json_filepath``.
            output_json_filename: Output file name; defaults to
                ``self.output_json_filename``.

        Returns:
            List: The ``'prediction'`` of every entry collected by the output
            handler, in insertion order.
        """
        output_handler = LLInferencerOutputHandler()

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # 1. In-context example indices for every test sample.
        ice_idx_list = retriever.retrieve()

        # 2. Candidate labels.
        if self.labels is None:
            labels = retriever.get_labels(ice_template=ice_template,
                                          prompt_template=prompt_template)
        else:
            labels = self.labels

        # 3. Render the in-context examples of every test sample.
        ice = [
            retriever.generate_ice(ice_idx, ice_template=ice_template)
            for ice_idx in ice_idx_list
        ]
        # NOTE(review): saved before any max_seq_len truncation below, so the
        # recorded examples may be longer than the ones actually prompted.
        output_handler.save_ice(self.model.parse_template(ice, mode='ppl'))

        # 4. Score every (sample, label) prompt; one score list per label.
        loglikelihoods = []
        for label in labels:
            prompt_list = []
            cont_list = []
            for idx in range(len(ice_idx_list)):
                prompt_list.append(
                    self._generate_fitting_prompt(retriever, idx,
                                                  ice_idx_list, ice, label,
                                                  ice_template,
                                                  prompt_template))
                cont_list.append(retriever.test_ds[idx]['cont'])

            logger.info(
                f"Calculating Loglikelihood for prompts labeled '{label}'")
            label_scores = []
            for start in trange(0,
                                len(prompt_list),
                                self.batch_size,
                                disable=not self.is_main_process):
                sub_prompt_list = prompt_list[start:start + self.batch_size]
                sub_cont_list = cont_list[start:start + self.batch_size]

                with torch.no_grad():
                    sub_inputs = self.model.parse_template(sub_prompt_list,
                                                           mode='ppl')
                    sub_res = self.model.get_loglikelihood(
                        sub_inputs, sub_cont_list).tolist()

                # `index` is the global sample index, so each sample's own
                # in-context examples are stripped from its saved input
                # (the batch start alone would reuse the first sample's).
                for index, (res, prompt) in enumerate(zip(sub_res,
                                                          sub_inputs),
                                                      start=start):
                    label_scores.append(res)
                    ice_str = self.model.parse_template(ice[index],
                                                        mode='ppl')
                    output_handler.save_prompt_and_loglikelihood(
                        label, prompt.replace(ice_str, ''), prompt, res,
                        index)
            loglikelihoods.append(label_scores)

        # 5. Per sample, pick the label with the highest loglikelihood
        #    (the first such label on ties).
        sub_predictions = [
            labels[scores.index(max(scores))]
            for scores in zip(*loglikelihoods)
        ]
        output_handler.save_predictions(sub_predictions)

        # 6. Gold answers, when the dataset declares an output column.
        ds_reader = retriever.dataset_reader
        if ds_reader.output_column:
            golds = ds_reader.dataset['test'][ds_reader.output_column]
            output_handler.save_golds(golds)

        if self.is_main_process:
            os.makedirs(output_json_filepath, exist_ok=True)
            output_handler.write_to_json(output_json_filepath,
                                         output_json_filename)

        return [
            sample['prediction']
            for sample in output_handler.results_dict.values()
        ]

    def _generate_fitting_prompt(self, retriever, idx, ice_idx_list, ice,
                                 label, ice_template, prompt_template):
        """Build sample ``idx``'s prompt for ``label`` within max_seq_len.

        While ``self.max_seq_len`` is set and the prompt is too long, the last
        in-context example index is dropped and the prompt is rebuilt.
        ``ice_idx_list[idx]`` and ``ice[idx]`` are updated in place, so the
        truncation carries over to the labels scored afterwards.
        """
        prompt = retriever.generate_label_prompt(
            idx,
            ice[idx],
            label,
            ice_template=ice_template,
            prompt_template=prompt_template)
        if self.max_seq_len is None:
            return prompt

        prompt_token_num = self.model.get_token_len_from_template(
            prompt, mode='ppl')
        while len(ice_idx_list[idx]
                  ) > 0 and prompt_token_num > self.max_seq_len:
            ice_idx_list[idx] = ice_idx_list[idx][:-1]
            ice[idx] = retriever.generate_ice(ice_idx_list[idx],
                                              ice_template=ice_template)
            prompt = retriever.generate_label_prompt(
                idx,
                ice[idx],
                label,
                ice_template=ice_template,
                prompt_template=prompt_template)
            prompt_token_num = self.model.get_token_len_from_template(
                prompt, mode='ppl')
        return prompt
|
|
|
|
|
class LLInferencerOutputHandler:
    """Collect per-sample results of an LL inference run.

    Entries of ``results_dict`` are keyed by ``str(idx)`` of the test sample.
    Each entry may hold the in-context examples, the prediction, the gold
    answer and, for every scored label, a nested record under
    ``'label: <label>'``.
    """

    # Class-level default; every instance rebinds its own dict in __init__.
    results_dict = {}

    def __init__(self) -> None:
        self.results_dict = {}

    def _entry(self, idx) -> dict:
        """Return the record for sample ``idx``, creating an empty one."""
        return self.results_dict.setdefault(str(idx), {})

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the result to a json file."""
        target = os.path.join(save_dir, filename)
        dump_results_dict(self.results_dict, target)

    def save_ice(self, ice):
        """Store the in-context examples of sample ``i`` from ``ice[i]``."""
        for position, rendered in enumerate(ice):
            self._entry(position)['in-context examples'] = rendered

    def save_predictions(self, predictions):
        """Store the predicted label of sample ``i`` from ``predictions[i]``."""
        for position, predicted in enumerate(predictions):
            self._entry(position)['prediction'] = predicted

    def save_prompt_and_loglikelihood(self, label, input, prompt,
                                      loglikelihood, idx):
        """Record one label's input, prompt and score for sample ``idx``.

        Calling again for the same ``(idx, label)`` overwrites the fields.
        """
        slot = self._entry(idx).setdefault('label: ' + str(label), {})
        slot['testing input'] = input
        slot['prompt'] = prompt
        slot['Loglikelihood'] = loglikelihood

    def save_golds(self, golds):
        """Store the gold answer of sample ``i`` from ``golds[i]``."""
        for position, answer in enumerate(golds):
            self._entry(position)['gold'] = answer
|
|