"""Chat Inferencer.""" import os import os.path as osp from typing import List, Optional, Union import mmengine from mmengine import is_list_of from tqdm import tqdm from opencompass.models import APITemplateParser as _APITemplateParser from opencompass.models import BaseModel from opencompass.models import LMTemplateParser as _LMTemplateParser from opencompass.registry import ICL_INFERENCERS from opencompass.utils.prompt import PromptList from ..icl_prompt_template import PromptTemplate from ..icl_retriever import BaseRetriever from ..utils.logging import get_logger from .icl_base_inferencer import BaseInferencer, dump_results_dict logger = get_logger(__name__) def promptlist_to_openai(prompt: Union[str, PromptList]): output = [] if isinstance(prompt, str): return [dict(role='user', content=prompt)] for item in prompt: if 'section' in item: continue if isinstance(item, str) and item: output.append(dict(role='user', content=item)) elif item['role'] == 'SYSTEM': output.append(dict(role='system', content=item['prompt'])) elif item['role'] == 'HUMAN': output.append(dict(role='user', content=item['prompt'])) elif item['role'] == 'BOT': output.append(dict(role='assistant', content=item['prompt'])) return output class LMTemplateParser: """LMTemplateParser accepts OpenAI format dialog inputs.""" def __init__(self, meta_template: Optional[dict] = None): self.meta_template = meta_template self.roles = {} role_mapping = { 'SYSTEM': 'system', 'HUMAN': 'user', 'BOT': 'assistant', } if meta_template: for item in meta_template.get('round', []): role = role_mapping.get(item['role'], item['role']) self.roles[role] = item.copy() for item in meta_template.get('reserved_roles', []): role = role_mapping.get(item['role'], item['role']) self.roles[role] = item.copy() def parse_template(self, chat: List[dict], mode='gen') -> str: if is_list_of(chat, list): # Handle batch inputs return [self.parse_template(item) for item in chat] assert is_list_of(chat, dict) prompt = '' if self.roles: for dialog in chat: role_cfg = self.roles.get(dialog['role'], {}) prompt += (role_cfg.get('begin') or '') prompt += (dialog.get('content') or '') prompt += (role_cfg.get('end') or '') prompt += (self.roles['assistant'].get('begin') or '') else: # in case the model does not have any meta template last_sep = '' for item in chat: prompt += last_sep + (item.get('content') or '') last_sep = '\n' return prompt class APITemplateParser: """APITemplateParser accepts OpenAI format dialog inputs.""" def __init__(self, meta_template: Optional[dict] = None): self.meta_template = meta_template self.roles = {} role_mapping = { 'SYSTEM': 'system', 'HUMAN': 'user', 'BOT': 'assistant', } if meta_template: for item in meta_template.get('round', []): role = role_mapping.get(item['role'], item['role']) self.roles[role] = item.copy() for item in meta_template.get('reserved_roles', []): role = role_mapping.get(item['role'], item['role']) self.roles[role] = item.copy() else: self.roles = dict( system=dict(api_role='SYSTEM'), user=dict(api_role='HUMAN'), assistant=dict(api_role='BOT', generate=True), ) def parse_template(self, chat: List[dict], mode='gen') -> str: if is_list_of(chat, list): # Handle batch inputs return [self.parse_template(item) for item in chat] assert is_list_of(chat, dict) prompt = [] for dialog in chat: if dialog['role'] in self.roles: role = self.roles[dialog['role']]['api_role'] else: role = dialog['role'] prompt.append(dict(role=role, prompt=dialog.get('content') or '')) return PromptList(prompt) class ChatOutputHandler: def 


class ChatOutputHandler:

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the result to a json file."""
        dump_results_dict(self.results_dict, osp.join(save_dir, filename))

    def save_results(self,
                     origin_prompt: list,
                     prediction: str,
                     idx: int,
                     gold: str = None):
        result_dict = {}
        if gold:
            result_dict['gold'] = gold
        result_dict.update({
            'prediction': prediction,
            'origin_prompt': origin_prompt,
        })
        self.results_dict[str(idx)] = result_dict

    def save_multiround_results(self,
                                origin_prompt: list,
                                prediction: str,
                                idx: int,
                                gold: str = None):
        result_dict = self.results_dict.get(str(idx), {
            'gold': [],
            'prediction': [],
            'origin_prompt': [],
        })
        result_dict['gold'].append(gold)
        result_dict['prediction'].append(prediction)
        result_dict['origin_prompt'].append(origin_prompt)
        self.results_dict[str(idx)] = result_dict


@ICL_INFERENCERS.register_module()
class ChatInferencer(BaseInferencer):
    HandlerType = ChatOutputHandler

    def __init__(
            self,
            model,
            output_json_filepath: Optional[str] = './icl_inference_output',
            output_json_filename: Optional[str] = 'predictions',
            save_every: Optional[int] = 1,
            infer_mode: str = 'last',
            **kwargs) -> None:
        super().__init__(
            model=model,
            output_json_filename=output_json_filename,
            output_json_filepath=output_json_filepath,
            **kwargs,
        )
        assert infer_mode in ['last', 'every', 'every_with_gt']
        self.infer_mode = infer_mode
        self.model: BaseModel
        self._set_meta_template(self.model)

        if self.model.is_api and save_every is None:
            save_every = 1
        self.save_every = save_every
        self.dialogue_mode = False

    def _set_meta_template(self, model):
        origin = model.template_parser
        if isinstance(origin, _APITemplateParser):
            model.template_parser = APITemplateParser(origin.meta_template)
        if isinstance(origin, _LMTemplateParser):
            model.template_parser = LMTemplateParser(origin.meta_template)
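
    # How the three infer modes treat a multi-round chat, sketched on a
    # made-up two-round example:
    #
    #   chat = [
    #       dict(role='user', content='Q1'),
    #       dict(role='assistant', content='A1'),
    #       dict(role='user', content='Q2'),
    #       dict(role='assistant', content='A2'),
    #   ]
    #   # 'last':          a single generation conditioned on all turns
    #   #                  before the final assistant turn; 'A2' is the gold.
    #   # 'every':         one generation per assistant turn; each output
    #   #                  overwrites the reference answer, so later rounds
    #   #                  are conditioned on the model's own replies.
    #   # 'every_with_gt': one generation per assistant turn, keeping the
    #   #                  reference answers 'A1'/'A2' in the history.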

    def inference(self,
                  retriever: BaseRetriever,
                  ice_template: Optional[PromptTemplate] = None,
                  prompt_template: Optional[PromptTemplate] = None,
                  output_json_filepath: Optional[str] = None,
                  output_json_filename: Optional[str] = None) -> dict:
        # 1. Preparation for output logs
        output_handler = self.HandlerType()

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename

        # 2. Get results of retrieval process
        ice_idx_list = retriever.retrieve()

        # 3. Generate prompts for testing input
        chat_list = self.get_chat_list(
            ice_idx_list,
            retriever,
            prompt_template=prompt_template,
        )

        # Create tmp json file for saving intermediate results and future
        # resuming
        index = 0
        tmp_json_filepath = os.path.join(output_json_filepath,
                                         'tmp_' + output_json_filename)
        if osp.exists(tmp_json_filepath):
            # TODO: move resume to output handler
            try:
                tmp_result_dict = mmengine.load(tmp_json_filepath)
            except Exception:
                pass
            else:
                output_handler.results_dict = tmp_result_dict
                index = len(tmp_result_dict)

        # 4. Wrap prompts with Dataloader
        dataloader = self.get_dataloader(chat_list[index:], batch_size=1)

        # 5. Inference for prompts in each batch
        logger.info('Starting inference process...')
        for datum in tqdm(dataloader, disable=not self.is_main_process):
            chat = datum[0]

            if self.infer_mode == 'last':
                self.infer_last(chat, index, output_handler)
            elif self.infer_mode == 'every':
                self.infer_every(chat, index, output_handler)
            elif self.infer_mode == 'every_with_gt':
                self.infer_every_with_gt(chat, index, output_handler)
            index += 1

            # Save intermediate results
            if (self.save_every is not None and index % self.save_every == 0
                    and self.is_main_process):
                output_handler.write_to_json(output_json_filepath,
                                             'tmp_' + output_json_filename)

        # 6. Output the final results and clean up the tmp file
        if self.is_main_process:
            os.makedirs(output_json_filepath, exist_ok=True)
            output_handler.write_to_json(output_json_filepath,
                                         output_json_filename)
            if osp.exists(tmp_json_filepath):
                os.remove(tmp_json_filepath)

        return output_handler.results_dict
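
    # Dataset layouts that get_chat_list below can handle; the column names
    # ('question', 'answer', 'conv') are placeholders, only 'dialogue' is
    # special-cased:
    #
    #   {'question': 'Q', 'answer': 'A'}          # single round of strings
    #   {'question': ['Q1', 'Q2'],
    #    'answer': ['A1', 'A2']}                  # one list entry per round
    #   {'conv': [{'role': 'user', ...}, ...]}    # a ready-made chat, used
    #                                             # as-is (single input col)
    #   {'dialogue': [...]}                       # enables dialogue_mode for
    #                                             # subjective evaluation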

    def get_chat_list(self,
                      ice_idx_list: List[List[int]],
                      retriever: BaseRetriever,
                      prompt_template: Optional[PromptTemplate] = None):
        prompt_list = []
        input_columns = retriever.dataset_reader.input_columns
        output_column = retriever.dataset_reader.output_column

        def chat_from_entry(entry):
            if prompt_template is None and len(input_columns) == 1:
                # Directly use the input column as the user input
                user = entry.get(input_columns[0])
                assistant = entry.get(output_column, '')
                return [
                    dict(role='user', content=user),
                    dict(role='assistant', content=assistant),
                ]
            elif prompt_template is not None:
                # Use the prompt template to generate a chat history
                chat = promptlist_to_openai(
                    prompt_template.generate_item(entry))
                gold = entry.get(output_column, '')
                if chat[-1]['role'] != 'assistant':
                    chat.append(dict(role='assistant', content=gold))
                return chat
            else:
                raise ValueError('A prompt template is required when there '
                                 'are multiple input columns.')

        for idx, ice_idx in enumerate(ice_idx_list):
            # NOTE: The in-context examples won't be used by now.
            item = {
                k: v
                for k, v in retriever.test_ds[idx].items()
                if k in input_columns or k == output_column
            }
            if all(isinstance(value, str) for value in item.values()):
                # Every column is a single string
                chat = chat_from_entry(item)
            elif all(is_list_of(value, str) for value in item.values()):
                # Every column is a list of strings for multi-round chat
                entries = [dict(zip(item, v)) for v in zip(*item.values())]
                chat = sum((chat_from_entry(entry) for entry in entries), [])
            elif len(input_columns) == 1 and is_list_of(
                    item[input_columns[0]], dict):
                # Single input column and it's already a chat.
                chat = item[input_columns[0]]
            elif 'dialogue' in input_columns:
                chat = item['dialogue']
                self.dialogue_mode = True
            else:
                raise ValueError('Cannot construct chat from the dataset.')

            prompt_list.append(chat)
        return prompt_list

    def infer_last(self, chat: List[dict], index: int, output_handler):
        assistant_indices = [
            i for i, item in enumerate(chat) if item['role'] == 'assistant'
        ]

        history = chat[:assistant_indices[-1]]
        output = self.model.generate_from_template([history],
                                                   max_out_len=512)[0]
        output_handler.save_results(
            origin_prompt=history,
            prediction=output,
            idx=index,
            gold=chat[assistant_indices[-1]]['content'],
        )

    def infer_every(self, chat: List[dict], index: int, output_handler):
        assistant_indices = [
            i for i, item in enumerate(chat) if item['role'] == 'assistant'
        ]
        index_copy = index

        for i in assistant_indices:
            history = chat[:i]
            output = self.model.generate_from_template([history],
                                                       max_out_len=512)[0]
            # Keep the reference answer before overwriting it; otherwise the
            # saved gold would just duplicate the prediction.
            gold = chat[i]['content']
            # Condition later rounds on the model's own reply.
            chat[i]['content'] = output
            if not self.dialogue_mode:
                output_handler.save_multiround_results(
                    origin_prompt=history[-1]['content'],
                    prediction=output,
                    idx=index,
                    gold=gold,
                )
            index += 1
        if self.dialogue_mode:
            # dialogue mode for subjective evaluation
            assert len(chat) % 2 == 0
            round_num = len(chat) // 2
            preds_list = []
            for i in range(round_num):
                preds_list.append({
                    'round': i + 1,
                    'user': chat[i * 2]['content'],
                    'assistant': chat[i * 2 + 1]['content'],
                })
            output_handler.save_results(
                origin_prompt=None,
                prediction=preds_list,
                idx=index_copy,
                gold=None,
            )

    def infer_every_with_gt(self, chat: List[dict], index: int,
                            output_handler):
        assistant_indices = [
            i for i, item in enumerate(chat) if item['role'] == 'assistant'
        ]

        for i in assistant_indices:
            history = chat[:i]
            output = self.model.generate_from_template([history],
                                                       max_out_len=512)[0]
            output_handler.save_multiround_results(
                origin_prompt=history[-1]['content'],
                prediction=output,
                idx=index,
                gold=chat[i]['content'],
            )
            index += 1
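

# A rough end-to-end sketch; `model` and `retriever` are assumed to be built
# elsewhere (e.g. by the OpenCompass config system) and are not defined in
# this module:
#
#   inferencer = ChatInferencer(model, infer_mode='every')
#   results = inferencer.inference(retriever)
#   # results maps str(index) to a dict with 'gold', 'prediction' and
#   # 'origin_prompt' entries (lists of per-round values in 'every' mode).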