"""Chat Inferencer."""
import os
import os.path as osp
from typing import List, Optional, Union
import mmengine
from mmengine import is_list_of
from tqdm import tqdm
from opencompass.models import APITemplateParser as _APITemplateParser
from opencompass.models import BaseModel
from opencompass.models import LMTemplateParser as _LMTemplateParser
from opencompass.registry import ICL_INFERENCERS
from opencompass.utils.prompt import PromptList
from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils.logging import get_logger
from .icl_base_inferencer import BaseInferencer, dump_results_dict

logger = get_logger(__name__)


def promptlist_to_openai(prompt: Union[str, PromptList]):
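    """Convert a string or ``PromptList`` prompt into OpenAI-format messages.

    A minimal sketch of the conversion:

        promptlist_to_openai('Hi!')
        # -> [{'role': 'user', 'content': 'Hi!'}]
    """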
output = []
if isinstance(prompt, str):
return [dict(role='user', content=prompt)]
    for item in prompt:
        if isinstance(item, str):
            # Plain strings become user messages; skip empty ones. Check
            # the type first so a string containing 'section' is not
            # mistaken for a section marker.
            if item:
                output.append(dict(role='user', content=item))
        elif 'section' in item:
            # Section markers carry no dialog content.
            continue
        elif item['role'] == 'SYSTEM':
            output.append(dict(role='system', content=item['prompt']))
        elif item['role'] == 'HUMAN':
            output.append(dict(role='user', content=item['prompt']))
        elif item['role'] == 'BOT':
            output.append(dict(role='assistant', content=item['prompt']))
    return output


class LMTemplateParser:
"""LMTemplateParser accepts OpenAI format dialog inputs."""
def __init__(self, meta_template: Optional[dict] = None):
self.meta_template = meta_template
self.roles = {}
role_mapping = {
'SYSTEM': 'system',
'HUMAN': 'user',
'BOT': 'assistant',
}
if meta_template:
for item in meta_template.get('round', []):
role = role_mapping.get(item['role'], item['role'])
self.roles[role] = item.copy()
for item in meta_template.get('reserved_roles', []):
role = role_mapping.get(item['role'], item['role'])
self.roles[role] = item.copy()
    def parse_template(self,
                       chat: List[dict],
                       mode='gen') -> Union[str, List[str]]:
if is_list_of(chat, list):
# Handle batch inputs
return [self.parse_template(item) for item in chat]
assert is_list_of(chat, dict)
prompt = ''
if self.roles:
for dialog in chat:
role_cfg = self.roles.get(dialog['role'], {})
prompt += (role_cfg.get('begin') or '')
prompt += (dialog.get('content') or '')
prompt += (role_cfg.get('end') or '')
            # Cue generation by appending the assistant-begin token, if any.
            prompt += (self.roles.get('assistant', {}).get('begin') or '')
else:
# in case the model does not have any meta template
last_sep = ''
for item in chat:
prompt += last_sep + (item.get('content') or '')
last_sep = '\n'
return prompt
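
# A rendering sketch for LMTemplateParser, assuming a hypothetical meta
# template (not a shipped default):
#
#   parser = LMTemplateParser(meta_template=dict(round=[
#       dict(role='HUMAN', begin='<|user|>\n', end='\n'),
#       dict(role='BOT', begin='<|assistant|>\n', end='\n', generate=True),
#   ]))
#   parser.parse_template([dict(role='user', content='Hi!')])
#   # -> '<|user|>\nHi!\n<|assistant|>\n'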


class APITemplateParser:
"""APITemplateParser accepts OpenAI format dialog inputs."""
def __init__(self, meta_template: Optional[dict] = None):
self.meta_template = meta_template
self.roles = {}
role_mapping = {
'SYSTEM': 'system',
'HUMAN': 'user',
'BOT': 'assistant',
}
if meta_template:
for item in meta_template.get('round', []):
role = role_mapping.get(item['role'], item['role'])
self.roles[role] = item.copy()
for item in meta_template.get('reserved_roles', []):
role = role_mapping.get(item['role'], item['role'])
self.roles[role] = item.copy()
else:
self.roles = dict(
system=dict(api_role='SYSTEM'),
user=dict(api_role='HUMAN'),
assistant=dict(api_role='BOT', generate=True),
)
    def parse_template(self,
                       chat: List[dict],
                       mode='gen') -> Union[PromptList, List[PromptList]]:
if is_list_of(chat, list):
# Handle batch inputs
return [self.parse_template(item) for item in chat]
assert is_list_of(chat, dict)
prompt = []
for dialog in chat:
if dialog['role'] in self.roles:
role = self.roles[dialog['role']]['api_role']
else:
role = dialog['role']
prompt.append(dict(role=role, prompt=dialog.get('content') or ''))
return PromptList(prompt)
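
# A rendering sketch for APITemplateParser with its default role mapping:
#
#   parser = APITemplateParser()
#   parser.parse_template([
#       dict(role='system', content='Be brief.'),
#       dict(role='user', content='Hi!'),
#   ])
#   # -> PromptList([{'role': 'SYSTEM', 'prompt': 'Be brief.'},
#   #                {'role': 'HUMAN', 'prompt': 'Hi!'}])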


class ChatOutputHandler:
def __init__(self) -> None:
self.results_dict = {}
def write_to_json(self, save_dir: str, filename: str):
"""Dump the result to a json file."""
dump_results_dict(self.results_dict, osp.join(save_dir, filename))
    def save_results(self,
                     origin_prompt: list,
                     prediction: str,
                     idx: int,
                     gold: Optional[str] = None):
        result_dict = {}
        if gold is not None:
            result_dict['gold'] = gold
        result_dict.update({
            'prediction': prediction,
            'origin_prompt': origin_prompt,
        })
        self.results_dict[str(idx)] = result_dict
    def save_multiround_results(self,
                                origin_prompt: list,
                                prediction: str,
                                idx: int,
                                gold: Optional[str] = None):
result_dict = self.results_dict.get(str(idx), {
'gold': [],
'prediction': [],
'origin_prompt': [],
})
result_dict['gold'].append(gold)
result_dict['prediction'].append(prediction)
result_dict['origin_prompt'].append(origin_prompt)
self.results_dict[str(idx)] = result_dict
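
# The handler accumulates one entry per sample (illustrative shapes):
#   save_results:            {'0': {'gold': '...', 'prediction': '...',
#                                   'origin_prompt': [...]}}
#   save_multiround_results: {'0': {'gold': [...], 'prediction': [...],
#                                   'origin_prompt': [...]}}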


@ICL_INFERENCERS.register_module()
class ChatInferencer(BaseInferencer):
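    """An inferencer for multi-round, chat-style datasets.

    Args:
        model: The model to perform inference on.
        output_json_filepath (str): Directory to save output files.
        output_json_filename (str): Name of the output file.
        save_every (int): Save intermediate results every ``save_every``
            iterations. Defaults to 1.
        infer_mode (str): One of:
            - 'last': only generate the reply to the last user turn;
            - 'every': generate a reply at every assistant turn, feeding
              earlier predictions back into the history;
            - 'every_with_gt': generate a reply at every assistant turn,
              keeping the ground-truth replies in the history.

    A usage sketch (hypothetical model and retriever objects):

        inferencer = ChatInferencer(model, infer_mode='every')
        results = inferencer.inference(retriever)
    """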
HandlerType = ChatOutputHandler
def __init__(
self,
model,
output_json_filepath: Optional[str] = './icl_inference_output',
output_json_filename: Optional[str] = 'predictions',
save_every: Optional[int] = 1,
infer_mode: str = 'last',
**kwargs) -> None:
super().__init__(
model=model,
output_json_filename=output_json_filename,
output_json_filepath=output_json_filepath,
**kwargs,
)
assert infer_mode in ['last', 'every', 'every_with_gt']
self.infer_mode = infer_mode
self.model: BaseModel
self._set_meta_template(self.model)
if self.model.is_api and save_every is None:
save_every = 1
self.save_every = save_every
self.dialogue_mode = False
def _set_meta_template(self, model):
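        """Swap the model's template parser for the chat-aware variant."""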
origin = model.template_parser
if isinstance(origin, _APITemplateParser):
model.template_parser = APITemplateParser(origin.meta_template)
if isinstance(origin, _LMTemplateParser):
model.template_parser = LMTemplateParser(origin.meta_template)
def inference(self,
retriever: BaseRetriever,
ice_template: Optional[PromptTemplate] = None,
prompt_template: Optional[PromptTemplate] = None,
output_json_filepath: Optional[str] = None,
output_json_filename: Optional[str] = None) -> dict:
# 1. Preparation for output logs
output_handler = self.HandlerType()
if output_json_filepath is None:
output_json_filepath = self.output_json_filepath
if output_json_filename is None:
output_json_filename = self.output_json_filename
# 2. Get results of retrieval process
ice_idx_list = retriever.retrieve()
# 3. Generate prompts for testing input
chat_list = self.get_chat_list(
ice_idx_list,
retriever,
prompt_template=prompt_template,
)
        # Create a tmp json file to save intermediate results and support
        # future resuming
index = 0
tmp_json_filepath = os.path.join(output_json_filepath,
'tmp_' + output_json_filename)
if osp.exists(tmp_json_filepath):
# TODO: move resume to output handler
try:
tmp_result_dict = mmengine.load(tmp_json_filepath)
except Exception:
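                # The tmp file may be corrupt or partially written; fall
                # back to a fresh run.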
pass
else:
output_handler.results_dict = tmp_result_dict
index = len(tmp_result_dict)
# 4. Wrap prompts with Dataloader
dataloader = self.get_dataloader(chat_list[index:], batch_size=1)
# 5. Inference for prompts in each batch
logger.info('Starting inference process...')
for datum in tqdm(dataloader, disable=not self.is_main_process):
chat = datum[0]
if self.infer_mode == 'last':
self.infer_last(chat, index, output_handler)
elif self.infer_mode == 'every':
self.infer_every(chat, index, output_handler)
elif self.infer_mode == 'every_with_gt':
self.infer_every_with_gt(chat, index, output_handler)
index += 1
# Save intermediate results
if (self.save_every is not None and index % self.save_every == 0
and self.is_main_process):
output_handler.write_to_json(output_json_filepath,
'tmp_' + output_json_filename)
        # 6. Output
if self.is_main_process:
os.makedirs(output_json_filepath, exist_ok=True)
output_handler.write_to_json(output_json_filepath,
output_json_filename)
if osp.exists(tmp_json_filepath):
os.remove(tmp_json_filepath)
return output_handler.results_dict
def get_chat_list(self,
ice_idx_list: List[List[int]],
retriever: BaseRetriever,
prompt_template: Optional[PromptTemplate] = None):
prompt_list = []
input_columns = retriever.dataset_reader.input_columns
output_column = retriever.dataset_reader.output_column
def chat_from_entry(entry):
if prompt_template is None and len(input_columns) == 1:
# Directly use the input column as the user input
user = entry.get(input_columns[0])
assistant = entry.get(output_column, '')
return [
dict(role='user', content=user),
dict(role='assistant', content=assistant),
]
elif prompt_template is not None:
# Use prompt template to generate chat history
chat = promptlist_to_openai(
prompt_template.generate_item(entry))
gold = entry.get(output_column, '')
if chat[-1]['role'] != 'assistant':
chat.append(dict(role='assistant', content=gold))
return chat
            else:
                raise ValueError(
                    'Cannot construct a chat from the entry without a '
                    'prompt template when there are multiple input columns.')
        for idx, ice_idx in enumerate(ice_idx_list):
            # NOTE: In-context examples are not used for now.
item = {
k: v
for k, v in retriever.test_ds[idx].items()
if k in input_columns or k == output_column
}
if all(isinstance(value, str) for value in item.values()):
# Every column is a single string
chat = chat_from_entry(item)
elif all(is_list_of(value, str) for value in item.values()):
# Every column is a list of string for multi-round chat
entries = [dict(zip(item, v)) for v in zip(*item.values())]
chat = sum((chat_from_entry(entry) for entry in entries), [])
elif len(input_columns) == 1 and is_list_of(
item[input_columns[0]], dict):
# Single input column and it's already a chat.
chat = item[input_columns[0]]
elif 'dialogue' in input_columns:
chat = item['dialogue']
self.dialogue_mode = True
else:
raise ValueError('Cannot construct chat from the dataset.')
prompt_list.append(chat)
return prompt_list
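
    # ``get_chat_list`` above accepts several dataset layouts (illustrative):
    #   - every column a single string: one user/assistant round per entry;
    #   - every column a list of strings: columns are zipped into
    #     consecutive rounds of one multi-round chat;
    #   - a single input column that already holds OpenAI-format messages;
    #   - a 'dialogue' input column, which also enables dialogue_mode.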
def infer_last(self, chat: List[dict], index: int, output_handler):
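        """Generate a reply only for the last assistant turn in the chat."""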
assistant_indices = [
i for i, item in enumerate(chat) if item['role'] == 'assistant'
]
history = chat[:assistant_indices[-1]]
output = self.model.generate_from_template([history],
max_out_len=512)[0]
output_handler.save_results(
origin_prompt=history,
prediction=output,
idx=index,
gold=chat[assistant_indices[-1]]['content'],
)
def infer_every(self, chat: List[dict], index: int, output_handler):
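        """Generate a reply at every assistant turn, feeding earlier
        predictions back into the history."""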
assistant_indices = [
i for i, item in enumerate(chat) if item['role'] == 'assistant'
]
        index_copy = index

        for i in assistant_indices:
            history = chat[:i]
            output = self.model.generate_from_template([history],
                                                       max_out_len=512)[0]
            # Keep the reference answer before it is overwritten: the
            # prediction is fed back into the history for later rounds.
            gold = chat[i]['content']
            chat[i]['content'] = output
            if not self.dialogue_mode:
                output_handler.save_multiround_results(
                    origin_prompt=history[-1]['content'],
                    prediction=output,
                    idx=index,
                    gold=gold,
                )
            index += 1
if self.dialogue_mode:
# dialogue mode for subjective evaluation
            assert len(chat) % 2 == 0, (
                'dialogue mode expects alternating user/assistant turns')
            round_num = len(chat) // 2
preds_list = []
for i in range(round_num):
temp_dict = {
'round': i + 1,
'user': chat[i * 2]['content'],
'assistant': chat[i * 2 + 1]['content']
}
preds_list.append(temp_dict)
output_handler.save_results(
origin_prompt=None,
prediction=preds_list,
idx=index_copy,
gold=None,
)
def infer_every_with_gt(self, chat: List[dict], index: int,
output_handler):
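        """Generate a reply at every assistant turn, keeping the
        ground-truth replies in the history."""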
assistant_indices = [
i for i, item in enumerate(chat) if item['role'] == 'assistant'
]
for i in assistant_indices:
history = chat[:i]
output = self.model.generate_from_template([history],
max_out_len=512)[0]
output_handler.save_multiround_results(
origin_prompt=history[-1]['content'],
prediction=output,
idx=index,
gold=chat[i]['content'],
)
index += 1