"""Chat Inferencer.""" | |
import os | |
import os.path as osp | |
from typing import List, Optional, Union | |
import mmengine | |
from mmengine import is_list_of | |
from tqdm import tqdm | |
from opencompass.models import APITemplateParser as _APITemplateParser | |
from opencompass.models import BaseModel | |
from opencompass.models import LMTemplateParser as _LMTemplateParser | |
from opencompass.registry import ICL_INFERENCERS | |
from opencompass.utils.prompt import PromptList | |
from ..icl_prompt_template import PromptTemplate | |
from ..icl_retriever import BaseRetriever | |
from ..utils.logging import get_logger | |
from .icl_base_inferencer import BaseInferencer, dump_results_dict | |
logger = get_logger(__name__) | |


def promptlist_to_openai(prompt: Union[str, PromptList]):
    output = []
    if isinstance(prompt, str):
        return [dict(role='user', content=prompt)]

    for item in prompt:
        if 'section' in item:
            continue
        if isinstance(item, str) and item:
            output.append(dict(role='user', content=item))
        elif item['role'] == 'SYSTEM':
            output.append(dict(role='system', content=item['prompt']))
        elif item['role'] == 'HUMAN':
            output.append(dict(role='user', content=item['prompt']))
        elif item['role'] == 'BOT':
            output.append(dict(role='assistant', content=item['prompt']))
    return output
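
# Illustrative usage (not part of the original file); role names follow the
# SYSTEM/HUMAN/BOT convention handled above, items with a 'section' key are
# structural markers and get skipped:
#
#     promptlist_to_openai('Hi')
#     # -> [{'role': 'user', 'content': 'Hi'}]
#
#     promptlist_to_openai(PromptList([
#         dict(role='SYSTEM', prompt='Be brief.'),
#         dict(role='HUMAN', prompt='Hi'),
#     ]))
#     # -> [{'role': 'system', 'content': 'Be brief.'},
#     #     {'role': 'user', 'content': 'Hi'}]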


class LMTemplateParser:
    """LMTemplateParser accepts OpenAI format dialog inputs."""

    def __init__(self, meta_template: Optional[dict] = None):
        self.meta_template = meta_template
        self.roles = {}
        role_mapping = {
            'SYSTEM': 'system',
            'HUMAN': 'user',
            'BOT': 'assistant',
        }
        if meta_template:
            for item in meta_template.get('round', []):
                role = role_mapping.get(item['role'], item['role'])
                self.roles[role] = item.copy()
            for item in meta_template.get('reserved_roles', []):
                role = role_mapping.get(item['role'], item['role'])
                self.roles[role] = item.copy()

    def parse_template(self,
                       chat: List[dict],
                       mode='gen') -> Union[str, List[str]]:
        if is_list_of(chat, list):
            # Handle batch inputs
            return [self.parse_template(item) for item in chat]

        assert is_list_of(chat, dict)
        prompt = ''
        if self.roles:
            for dialog in chat:
                role_cfg = self.roles.get(dialog['role'], {})
                prompt += (role_cfg.get('begin') or '')
                prompt += (dialog.get('content') or '')
                prompt += (role_cfg.get('end') or '')
            # Prime the assistant turn so the model continues in its role.
            prompt += (self.roles['assistant'].get('begin') or '')
        else:
            # In case the model does not have any meta template.
            last_sep = ''
            for item in chat:
                prompt += last_sep + (item.get('content') or '')
                last_sep = '\n'
        return prompt
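
    # A minimal sketch (the meta template values below are assumptions, not
    # from this file): with round entries like
    #     dict(role='HUMAN', begin='<|user|>:', end='\n')
    #     dict(role='BOT', begin='<|bot|>:', end='\n', generate=True)
    # parse_template([{'role': 'user', 'content': 'Hi'}]) would produce
    #     '<|user|>:Hi\n<|bot|>:'
    # i.e. the history flattened to one string, ending with the assistant's
    # begin token so the model generates the next reply.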


class APITemplateParser:
    """APITemplateParser accepts OpenAI format dialog inputs."""

    def __init__(self, meta_template: Optional[dict] = None):
        self.meta_template = meta_template
        self.roles = {}
        role_mapping = {
            'SYSTEM': 'system',
            'HUMAN': 'user',
            'BOT': 'assistant',
        }
        if meta_template:
            for item in meta_template.get('round', []):
                role = role_mapping.get(item['role'], item['role'])
                self.roles[role] = item.copy()
            for item in meta_template.get('reserved_roles', []):
                role = role_mapping.get(item['role'], item['role'])
                self.roles[role] = item.copy()
        else:
            self.roles = dict(
                system=dict(api_role='SYSTEM'),
                user=dict(api_role='HUMAN'),
                assistant=dict(api_role='BOT', generate=True),
            )

    def parse_template(
            self,
            chat: List[dict],
            mode='gen') -> Union[PromptList, List[PromptList]]:
        if is_list_of(chat, list):
            # Handle batch inputs
            return [self.parse_template(item) for item in chat]

        assert is_list_of(chat, dict)
        prompt = []
        for dialog in chat:
            if dialog['role'] in self.roles:
                # Map OpenAI role names back to OpenCompass API roles.
                role = self.roles[dialog['role']]['api_role']
            else:
                role = dialog['role']
            prompt.append(dict(role=role, prompt=dialog.get('content') or ''))
        return PromptList(prompt)
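
    # Illustrative round-trip with the default roles (no meta template):
    #     [{'role': 'user', 'content': 'Hi'}]
    #         -> PromptList([{'role': 'HUMAN', 'prompt': 'Hi'}])
    # API models consume this structured PromptList directly rather than a
    # flattened string.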


class ChatOutputHandler:

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the result to a json file."""
        dump_results_dict(self.results_dict, osp.join(save_dir, filename))

    def save_results(self,
                     origin_prompt: list,
                     prediction: str,
                     idx: int,
                     gold: Optional[str] = None):
        result_dict = {}
        if gold:
            result_dict['gold'] = gold
        result_dict.update({
            'prediction': prediction,
            'origin_prompt': origin_prompt,
        })
        self.results_dict[str(idx)] = result_dict

    def save_multiround_results(self,
                                origin_prompt: list,
                                prediction: str,
                                idx: int,
                                gold: Optional[str] = None):
        result_dict = self.results_dict.get(str(idx), {
            'gold': [],
            'prediction': [],
            'origin_prompt': [],
        })
        result_dict['gold'].append(gold)
        result_dict['prediction'].append(prediction)
        result_dict['origin_prompt'].append(origin_prompt)
        self.results_dict[str(idx)] = result_dict
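
# Shape of `results_dict` produced by the two save methods (illustrative):
#
#   save_results (one record per sample, used by 'last' mode):
#       {'0': {'gold': '...', 'prediction': '...', 'origin_prompt': [...]}}
#
#   save_multiround_results (per-round lists accumulated under one index,
#   used by 'every' and 'every_with_gt' modes):
#       {'0': {'gold': [...], 'prediction': [...], 'origin_prompt': [...]}}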


@ICL_INFERENCERS.register_module()
class ChatInferencer(BaseInferencer):
    HandlerType = ChatOutputHandler

    def __init__(
            self,
            model,
            output_json_filepath: Optional[str] = './icl_inference_output',
            output_json_filename: Optional[str] = 'predictions',
            save_every: Optional[int] = 1,
            infer_mode: str = 'last',
            **kwargs) -> None:
        super().__init__(
            model=model,
            output_json_filename=output_json_filename,
            output_json_filepath=output_json_filepath,
            **kwargs,
        )
        assert infer_mode in ['last', 'every', 'every_with_gt']
        self.infer_mode = infer_mode
        self.model: BaseModel
        self._set_meta_template(self.model)

        if self.model.is_api and save_every is None:
            save_every = 1
        self.save_every = save_every
        self.dialogue_mode = False
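
    # The three `infer_mode` values, summarized:
    #   'last'          - generate only the final assistant turn, conditioned
    #                     on the full preceding history.
    #   'every'         - generate every assistant turn, feeding the model's
    #                     own outputs back as history.
    #   'every_with_gt' - generate every assistant turn, conditioning each
    #                     round on the ground-truth history instead.
    #
    # Hypothetical construction (the model object is an assumption):
    #     inferencer = ChatInferencer(model, infer_mode='every', save_every=10)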

    def _set_meta_template(self, model):
        origin = model.template_parser
        if isinstance(origin, _APITemplateParser):
            model.template_parser = APITemplateParser(origin.meta_template)
        if isinstance(origin, _LMTemplateParser):
            model.template_parser = LMTemplateParser(origin.meta_template)

    def inference(self,
                  retriever: BaseRetriever,
                  ice_template: Optional[PromptTemplate] = None,
                  prompt_template: Optional[PromptTemplate] = None,
                  output_json_filepath: Optional[str] = None,
                  output_json_filename: Optional[str] = None) -> dict:
        # 1. Preparation for output logs
        output_handler = self.HandlerType()

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # 2. Get results of retrieval process
        ice_idx_list = retriever.retrieve()

        # 3. Generate prompts for testing input
        chat_list = self.get_chat_list(
            ice_idx_list,
            retriever,
            prompt_template=prompt_template,
        )

        # Create tmp json file for saving intermediate results and future
        # resuming
        index = 0
        tmp_json_filepath = os.path.join(output_json_filepath,
                                         'tmp_' + output_json_filename)
        if osp.exists(tmp_json_filepath):
            # TODO: move resume to output handler
            try:
                tmp_result_dict = mmengine.load(tmp_json_filepath)
            except Exception:
                pass
            else:
                output_handler.results_dict = tmp_result_dict
                index = len(tmp_result_dict)

        # 4. Wrap prompts with Dataloader
        dataloader = self.get_dataloader(chat_list[index:], batch_size=1)

        # 5. Inference for prompts in each batch
        logger.info('Starting inference process...')
        for datum in tqdm(dataloader, disable=not self.is_main_process):
            chat = datum[0]

            if self.infer_mode == 'last':
                self.infer_last(chat, index, output_handler)
            elif self.infer_mode == 'every':
                self.infer_every(chat, index, output_handler)
            elif self.infer_mode == 'every_with_gt':
                self.infer_every_with_gt(chat, index, output_handler)
            index += 1

            # Save intermediate results
            if (self.save_every is not None and index % self.save_every == 0
                    and self.is_main_process):
                output_handler.write_to_json(output_json_filepath,
                                             'tmp_' + output_json_filename)

        # 6. Output
        if self.is_main_process:
            os.makedirs(output_json_filepath, exist_ok=True)
            output_handler.write_to_json(output_json_filepath,
                                         output_json_filename)
            if osp.exists(tmp_json_filepath):
                os.remove(tmp_json_filepath)

        return output_handler.results_dict
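
    # Resuming (behavior of `inference` above): partial results are flushed
    # to 'tmp_' + output_json_filename every `save_every` samples; if that
    # file exists on the next run it is loaded, inference restarts from the
    # first unfinished index, and the tmp file is removed after the final
    # dump.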

    def get_chat_list(self,
                      ice_idx_list: List[List[int]],
                      retriever: BaseRetriever,
                      prompt_template: Optional[PromptTemplate] = None):
        prompt_list = []
        input_columns = retriever.dataset_reader.input_columns
        output_column = retriever.dataset_reader.output_column

        def chat_from_entry(entry):
            if prompt_template is None and len(input_columns) == 1:
                # Directly use the input column as the user input
                user = entry.get(input_columns[0])
                assistant = entry.get(output_column, '')
                return [
                    dict(role='user', content=user),
                    dict(role='assistant', content=assistant),
                ]
            elif prompt_template is not None:
                # Use prompt template to generate chat history
                chat = promptlist_to_openai(
                    prompt_template.generate_item(entry))
                gold = entry.get(output_column, '')
                if chat[-1]['role'] != 'assistant':
                    chat.append(dict(role='assistant', content=gold))
                return chat
            else:
                raise ValueError(
                    'A prompt template is required when the dataset has '
                    'multiple input columns.')

        for idx, ice_idx in enumerate(ice_idx_list):
            # NOTE: The in-context examples won't be used by now.
            item = {
                k: v
                for k, v in retriever.test_ds[idx].items()
                if k in input_columns or k == output_column
            }
            if all(isinstance(value, str) for value in item.values()):
                # Every column is a single string
                chat = chat_from_entry(item)
            elif all(is_list_of(value, str) for value in item.values()):
                # Every column is a list of string for multi-round chat
                entries = [dict(zip(item, v)) for v in zip(*item.values())]
                chat = sum((chat_from_entry(entry) for entry in entries), [])
            elif len(input_columns) == 1 and is_list_of(
                    item[input_columns[0]], dict):
                # Single input column and it's already a chat.
                chat = item[input_columns[0]]
            elif 'dialogue' in input_columns:
                chat = item['dialogue']
                self.dialogue_mode = True
            else:
                raise ValueError('Cannot construct chat from the dataset.')

            prompt_list.append(chat)
        return prompt_list
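
    # Illustrative mapping (the column names 'question'/'answer' are
    # hypothetical; any single input column plus the output column behaves
    # the same):
    #
    #   {'question': 'Hi', 'answer': 'Hello'}
    #       -> [{'role': 'user', 'content': 'Hi'},
    #           {'role': 'assistant', 'content': 'Hello'}]
    #
    # Columns holding equal-length lists of strings are zipped into one
    # multi-round chat, one user/assistant pair per element.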

    def infer_last(self, chat: List[dict], index: int, output_handler):
        assistant_indices = [
            i for i, item in enumerate(chat) if item['role'] == 'assistant'
        ]

        history = chat[:assistant_indices[-1]]
        output = self.model.generate_from_template([history],
                                                   max_out_len=512)[0]
        output_handler.save_results(
            origin_prompt=history,
            prediction=output,
            idx=index,
            gold=chat[assistant_indices[-1]]['content'],
        )
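
    # For a two-round chat [user, assistant, user, assistant], infer_last
    # prompts with the first three messages and saves a single record:
    #     {'gold': <final assistant content>, 'prediction': <model output>,
    #      'origin_prompt': <the three-message history>}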

    def infer_every(self, chat: List[dict], index: int, output_handler):
        assistant_indices = [
            i for i, item in enumerate(chat) if item['role'] == 'assistant'
        ]
        index_copy = index

        for i in assistant_indices:
            history = chat[:i]
            output = self.model.generate_from_template([history],
                                                       max_out_len=512)[0]
            # Capture the reference answer before it is overwritten by the
            # model output, so that 'gold' differs from 'prediction'.
            gold = chat[i]['content']
            chat[i]['content'] = output
            if not self.dialogue_mode:
                output_handler.save_multiround_results(
                    origin_prompt=history[-1]['content'],
                    prediction=output,
                    idx=index,
                    gold=gold,
                )
                index += 1
        if self.dialogue_mode:
            # dialogue mode for subjective evaluation
            assert len(chat) % 2 == 0
            round_num = int(len(chat) / 2)
            preds_list = []
            for i in range(round_num):
                temp_dict = {
                    'round': i + 1,
                    'user': chat[i * 2]['content'],
                    'assistant': chat[i * 2 + 1]['content']
                }
                preds_list.append(temp_dict)
            output_handler.save_results(
                origin_prompt=None,
                prediction=preds_list,
                idx=index_copy,
                gold=None,
            )
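
    # In dialogue mode the per-sample prediction is a list of round dicts:
    #     [{'round': 1, 'user': '...', 'assistant': '<model output>'},
    #      {'round': 2, 'user': '...', 'assistant': '<model output>'}]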

    def infer_every_with_gt(self, chat: List[dict], index: int,
                            output_handler):
        assistant_indices = [
            i for i, item in enumerate(chat) if item['role'] == 'assistant'
        ]

        for i in assistant_indices:
            history = chat[:i]
            output = self.model.generate_from_template([history],
                                                       max_out_len=512)[0]
            output_handler.save_multiround_results(
                origin_prompt=history[-1]['content'],
                prediction=output,
                idx=index,
                gold=chat[i]['content'],
            )
            index += 1
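
    # Unlike infer_every, the history here keeps the dataset's reference
    # answers (teacher forcing): round i is generated from ground-truth
    # rounds 1..i-1, so earlier generation errors do not propagate.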