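"""View the prompts generated from the datasets (and models) in a config.

This tool builds the dataset, templates and retriever described by an
OpenCompass config and prints the fully assembled prompts without running
any inference.
"""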
import argparse
import fnmatch
from typing import Dict

from mmengine.config import Config, ConfigDict

from opencompass.openicl.icl_inferencer import (AgentInferencer,
                                                ChatInferencer, CLPInferencer,
                                                GenInferencer, LLInferencer,
                                                PPLInferencer,
                                                PPLOnlyInferencer)
from opencompass.registry import ICL_PROMPT_TEMPLATES, ICL_RETRIEVERS
from opencompass.utils import (Menu, build_dataset_from_cfg,
                               build_model_from_cfg, dataset_abbr_from_cfg,
                               model_abbr_from_cfg)


def parse_args():
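    """Parse command-line arguments for the prompt viewer."""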
    parser = argparse.ArgumentParser(
        description='View generated prompts based on datasets (and models)')
    parser.add_argument('config', help='Config file path')
    parser.add_argument('-n',
                        '--non-interactive',
                        action='store_true',
                        help='Skip the selection menu and use the first '
                        'model and dataset in the config')
    parser.add_argument('-a',
                        '--all',
                        action='store_true',
                        help='Print prompts for every model-dataset pair')
    parser.add_argument('-p',
                        '--pattern',
                        type=str,
                        help='Pattern to match the dataset abbr.')
    parser.add_argument('-c',
                        '--count',
                        type=int,
                        default=1,
                        help='Number of prompts to print')
    args = parser.parse_args()
    return args


def parse_model_cfg(model_cfg: ConfigDict) -> Dict[str, ConfigDict]:
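    """Index model configs by their abbreviated names."""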
    model2cfg = {}
    for model in model_cfg:
        model2cfg[model_abbr_from_cfg(model)] = model
    return model2cfg


def parse_dataset_cfg(dataset_cfg: ConfigDict) -> Dict[str, ConfigDict]:
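    """Index dataset configs by their abbreviated names."""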
    dataset2cfg = {}
    for dataset in dataset_cfg:
        dataset2cfg[dataset_abbr_from_cfg(dataset)] = dataset
    return dataset2cfg


def print_prompts(model_cfg, dataset_cfg, count=1):
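    """Print up to ``count`` fully rendered prompts for a dataset.

    If a model config is given, prompt lengths are measured with the model's
    tokenizer and in-context examples are truncated to fit ``max_seq_len``.
    """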
    if model_cfg:
        max_seq_len = model_cfg.max_seq_len
        # Only the tokenizer is needed to measure prompt lengths, so skip
        # loading full weights for local (non-API) models.
        if not model_cfg['type'].is_api:
            model_cfg['tokenizer_only'] = True
        model = build_model_from_cfg(model_cfg)
    else:
        max_seq_len = None
        model = None

    infer_cfg = dataset_cfg.get('infer_cfg')
    dataset = build_dataset_from_cfg(dataset_cfg)
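
    # Build the optional in-context-example (ICE) and prompt templates
    # declared in the dataset's infer_cfg.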
    ice_template = None
    if hasattr(infer_cfg, 'ice_template'):
        ice_template = ICL_PROMPT_TEMPLATES.build(infer_cfg['ice_template'])

    prompt_template = None
    if hasattr(infer_cfg, 'prompt_template'):
        prompt_template = ICL_PROMPT_TEMPLATES.build(
            infer_cfg['prompt_template'])

    infer_cfg['retriever']['dataset'] = dataset
    retriever = ICL_RETRIEVERS.build(infer_cfg['retriever'])
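
    # Each entry of ice_idx_list holds the indices of the in-context
    # examples retrieved for one test instance.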
    ice_idx_list = retriever.retrieve()
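
    # Bail out early for inferencer types whose prompt construction this
    # tool does not know how to reproduce.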
    supported_inferencer = [
        AgentInferencer, PPLInferencer, GenInferencer, CLPInferencer,
        PPLOnlyInferencer, ChatInferencer, LLInferencer
    ]
    if infer_cfg.inferencer.type not in supported_inferencer:
        print(f'Only {supported_inferencer} are supported')
        return
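
    # PPL- and LL-style inferencers score one rendered prompt per candidate
    # label; generation-style inferencers render a single prompt per instance.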
    for idx in range(min(count, len(ice_idx_list))):
        if issubclass(infer_cfg.inferencer.type,
                      (PPLInferencer, LLInferencer)):
            labels = retriever.get_labels(ice_template=ice_template,
                                          prompt_template=prompt_template)
            ice = retriever.generate_ice(ice_idx_list[idx],
                                         ice_template=ice_template)
            print('-' * 100)
            print('ICE Template:')
            print('-' * 100)
            print(ice)
            print('-' * 100)
            for label in labels:
                prompt = retriever.generate_label_prompt(
                    idx,
                    ice,
                    label,
                    ice_template=ice_template,
                    prompt_template=prompt_template,
                    remain_sep=None)
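                # Drop in-context examples from the end until the rendered
                # prompt fits within the model's context window.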
                if max_seq_len is not None:
                    prompt_token_num = model.get_token_len_from_template(
                        prompt)
                    while len(ice_idx_list[idx]
                              ) > 0 and prompt_token_num > max_seq_len:
                        num_ice = len(ice_idx_list[idx])
                        print(f'Truncating ice {num_ice} -> {num_ice - 1}',
                              f'Number of tokens: {prompt_token_num} -> ...')
                        ice_idx_list[idx] = ice_idx_list[idx][:-1]
                        ice = retriever.generate_ice(ice_idx_list[idx],
                                                     ice_template=ice_template)
                        prompt = retriever.generate_label_prompt(
                            idx,
                            ice,
                            label,
                            ice_template=ice_template,
                            prompt_template=prompt_template)
                        prompt_token_num = model.get_token_len_from_template(
                            prompt)
                    print(f'Number of tokens: {prompt_token_num}')
                if model is not None:
                    prompt = model.parse_template(prompt, mode='ppl')
                print('-' * 100)
                print(f'Label: {label}')
                print('Sample prompt:')
                print('-' * 100)
                print(prompt)
                print('-' * 100)
        else:
            ice_idx = ice_idx_list[idx]
            ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
            prompt = retriever.generate_prompt_for_generate_task(
                idx,
                ice,
                gen_field_replace_token=infer_cfg.inferencer.get(
                    'gen_field_replace_token', ''),
                ice_template=ice_template,
                prompt_template=prompt_template)
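            # Same truncation strategy as in the PPL branch: shrink the ICE
            # list until the prompt fits.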
            if max_seq_len is not None:
                prompt_token_num = model.get_token_len_from_template(prompt)
                while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
                    num_ice = len(ice_idx)
                    print(f'Truncating ice {num_ice} -> {num_ice - 1}',
                          f'Number of tokens: {prompt_token_num} -> ...')
                    ice_idx = ice_idx[:-1]
                    ice = retriever.generate_ice(ice_idx,
                                                 ice_template=ice_template)
                    prompt = retriever.generate_prompt_for_generate_task(
                        idx,
                        ice,
                        gen_field_replace_token=infer_cfg.inferencer.get(
                            'gen_field_replace_token', ''),
                        ice_template=ice_template,
                        prompt_template=prompt_template)
                    prompt_token_num = model.get_token_len_from_template(
                        prompt)
                print(f'Number of tokens: {prompt_token_num}')
            if model is not None:
                prompt = model.parse_template(prompt, mode='gen')
            print('-' * 100)
            print('Sample prompt:')
            print('-' * 100)
            print(prompt)
            print('-' * 100)


def main():
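    """Select model and dataset configs per the CLI arguments and print
    their prompts."""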
    args = parse_args()
    cfg = Config.fromfile(args.config)

    model2cfg = parse_model_cfg(cfg.models) if 'models' in cfg else {
        'None': None
    }
    if 'datasets' in cfg:
        dataset2cfg = parse_dataset_cfg(cfg.datasets)
    else:
        # Fall back to collecting every `*_datasets` list in the config.
        dataset2cfg = {}
        for key in cfg.keys():
            if key.endswith('_datasets'):
                dataset2cfg.update(parse_dataset_cfg(cfg[key]))

    if args.pattern is not None:
        matches = fnmatch.filter(dataset2cfg, args.pattern)
        if len(matches) == 0:
            raise ValueError(
                'No dataset matches the pattern. Please select from: \n' +
                '\n'.join(dataset2cfg.keys()))
        dataset2cfg = {k: dataset2cfg[k] for k in matches}

    if not args.all:
        if not args.non_interactive:
            model, dataset = Menu(
                [list(model2cfg.keys()),
                 list(dataset2cfg.keys())], [
                     f'Please select a {s}:' for s in ['model', 'dataset']
                 ]).run()
        else:
            model = list(model2cfg.keys())[0]
            dataset = list(dataset2cfg.keys())[0]
        model_cfg = model2cfg[model]
        dataset_cfg = dataset2cfg[dataset]
        print_prompts(model_cfg, dataset_cfg, args.count)
    else:
        for model_abbr, model_cfg in model2cfg.items():
            for dataset_abbr, dataset_cfg in dataset2cfg.items():
                print('=' * 64, '[BEGIN]', '=' * 64)
                print(f'[MODEL]: {model_abbr}')
                print(f'[DATASET]: {dataset_abbr}')
                print('---')
                print_prompts(model_cfg, dataset_cfg, args.count)
                print('=' * 65, '[END]', '=' * 65)
                print()


if __name__ == '__main__':
    main()