import argparse
import fnmatch
from typing import Dict

from mmengine.config import Config, ConfigDict

from opencompass.openicl.icl_inferencer import (AgentInferencer,
                                                ChatInferencer, CLPInferencer,
                                                GenInferencer, LLInferencer,
                                                PPLInferencer,
                                                PPLOnlyInferencer)
from opencompass.registry import ICL_PROMPT_TEMPLATES, ICL_RETRIEVERS
from opencompass.utils import (Menu, build_dataset_from_cfg,
                               build_model_from_cfg, dataset_abbr_from_cfg,
                               model_abbr_from_cfg)


def parse_args():
    parser = argparse.ArgumentParser(
        description='View generated prompts based on datasets (and models)')
    parser.add_argument('config', help='Config file path')
    parser.add_argument('-n',
                        '--non-interactive',
                        action='store_true',
                        help='Skip the interactive menu and pick the first '
                        'model and dataset')
    parser.add_argument('-a',
                        '--all',
                        action='store_true',
                        help='Print prompts for all (model, dataset) pairs')
    parser.add_argument('-p',
                        '--pattern',
                        type=str,
                        help='To match the dataset abbr.')
    parser.add_argument('-c',
                        '--count',
                        type=int,
                        default=1,
                        help='Number of prompts to print')
    args = parser.parse_args()
    return args


def parse_model_cfg(model_cfg: ConfigDict) -> Dict[str, ConfigDict]:
    # Map each model abbreviation to its config.
    model2cfg = {}
    for model in model_cfg:
        model2cfg[model_abbr_from_cfg(model)] = model
    return model2cfg


def parse_dataset_cfg(dataset_cfg: ConfigDict) -> Dict[str, ConfigDict]:
    # Map each dataset abbreviation to its config.
    dataset2cfg = {}
    for dataset in dataset_cfg:
        dataset2cfg[dataset_abbr_from_cfg(dataset)] = dataset
    return dataset2cfg


def print_prompts(model_cfg, dataset_cfg, count=1):
    # TODO: A really dirty method that copies code from PPLInferencer and
    # GenInferencer. In the future, the prompt extraction code should be
    # extracted and generalized as a static method in these Inferencers
    # and reused here.
    if model_cfg:
        max_seq_len = model_cfg.max_seq_len
        if not model_cfg['type'].is_api:
            # Only the tokenizer is needed to count prompt tokens; skip
            # loading the model weights.
            model_cfg['tokenizer_only'] = True
        model = build_model_from_cfg(model_cfg)
    else:
        max_seq_len = None
        model = None

    infer_cfg = dataset_cfg.get('infer_cfg')
    dataset = build_dataset_from_cfg(dataset_cfg)

    ice_template = None
    if hasattr(infer_cfg, 'ice_template'):
        ice_template = ICL_PROMPT_TEMPLATES.build(infer_cfg['ice_template'])

    prompt_template = None
    if hasattr(infer_cfg, 'prompt_template'):
        prompt_template = ICL_PROMPT_TEMPLATES.build(
            infer_cfg['prompt_template'])

    infer_cfg['retriever']['dataset'] = dataset
    retriever = ICL_RETRIEVERS.build(infer_cfg['retriever'])

    ice_idx_list = retriever.retrieve()

    supported_inferencer = [
        AgentInferencer, PPLInferencer, GenInferencer, CLPInferencer,
        PPLOnlyInferencer, ChatInferencer, LLInferencer
    ]
    if infer_cfg.inferencer.type not in supported_inferencer:
        print(f'Only {supported_inferencer} are supported')
        return

    for idx in range(min(count, len(ice_idx_list))):
        if issubclass(infer_cfg.inferencer.type,
                      (PPLInferencer, LLInferencer)):
            # PPL-style inferencers build one prompt per candidate label.
            labels = retriever.get_labels(ice_template=ice_template,
                                          prompt_template=prompt_template)
            ice = retriever.generate_ice(ice_idx_list[idx],
                                         ice_template=ice_template)
            print('-' * 100)
            print('ICE Template:')
            print('-' * 100)
            print(ice)
            print('-' * 100)
            for label in labels:
                prompt = retriever.generate_label_prompt(
                    idx,
                    ice,
                    label,
                    ice_template=ice_template,
                    prompt_template=prompt_template,
                    remain_sep=None)
                if max_seq_len is not None:
                    prompt_token_num = model.get_token_len_from_template(
                        prompt)
                    # Drop in-context examples one at a time until the
                    # prompt fits within the model's max sequence length.
                    while len(ice_idx_list[idx]
                              ) > 0 and prompt_token_num > max_seq_len:
                        num_ice = len(ice_idx_list[idx])
                        print(f'Truncating ice {num_ice} -> {num_ice - 1}',
                              f'Number of tokens: {prompt_token_num} -> ...')
                        ice_idx_list[idx] = ice_idx_list[idx][:-1]
                        ice = retriever.generate_ice(
                            ice_idx_list[idx], ice_template=ice_template)
                        prompt = retriever.generate_label_prompt(
                            idx,
                            ice,
                            label,
                            ice_template=ice_template,
                            prompt_template=prompt_template)
                        prompt_token_num = model.get_token_len_from_template(
                            prompt)
                    print(f'Number of tokens: {prompt_token_num}')
                if model is not None:
                    prompt = model.parse_template(prompt, mode='ppl')
                print('-' * 100)
                print(f'Label: {label}')
                print('Sample prompt:')
                print('-' * 100)
                print(prompt)
                print('-' * 100)
        else:
            # Generation-style inferencers build a single prompt.
            ice_idx = ice_idx_list[idx]
            ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
            prompt = retriever.generate_prompt_for_generate_task(
                idx,
                ice,
                gen_field_replace_token=infer_cfg.inferencer.get(
                    'gen_field_replace_token', ''),
                ice_template=ice_template,
                prompt_template=prompt_template)
            if max_seq_len is not None:
                prompt_token_num = model.get_token_len_from_template(prompt)
                # Same truncation strategy as in the PPL branch above.
                while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
                    num_ice = len(ice_idx)
                    print(f'Truncating ice {num_ice} -> {num_ice - 1}',
                          f'Number of tokens: {prompt_token_num} -> ...')
                    ice_idx = ice_idx[:-1]
                    ice = retriever.generate_ice(ice_idx,
                                                 ice_template=ice_template)
                    prompt = retriever.generate_prompt_for_generate_task(
                        idx,
                        ice,
                        gen_field_replace_token=infer_cfg.inferencer.get(
                            'gen_field_replace_token', ''),
                        ice_template=ice_template,
                        prompt_template=prompt_template)
                    prompt_token_num = model.get_token_len_from_template(
                        prompt)
                print(f'Number of tokens: {prompt_token_num}')
            if model is not None:
                prompt = model.parse_template(prompt, mode='gen')
            print('-' * 100)
            print('Sample prompt:')
            print('-' * 100)
            print(prompt)
            print('-' * 100)


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)

    model2cfg = parse_model_cfg(cfg.models) if 'models' in cfg else {
        'None': None
    }
    if 'datasets' in cfg:
        dataset2cfg = parse_dataset_cfg(cfg.datasets)
    else:
        # Fall back to collecting any `*_datasets` variables in the config.
        dataset2cfg = {}
        for key in cfg.keys():
            if key.endswith('_datasets'):
                dataset2cfg.update(parse_dataset_cfg(cfg[key]))

    if args.pattern is not None:
        matches = fnmatch.filter(dataset2cfg, args.pattern)
        if len(matches) == 0:
            raise ValueError(
                'No dataset matches the pattern. Please select from: \n' +
                '\n'.join(dataset2cfg.keys()))
        dataset2cfg = {k: dataset2cfg[k] for k in matches}

    if not args.all:
        if not args.non_interactive:
            model, dataset = Menu(
                [list(model2cfg.keys()),
                 list(dataset2cfg.keys())],
                [f'Please select a {s}:' for s in ['model', 'dataset']]).run()
        else:
            model = list(model2cfg.keys())[0]
            dataset = list(dataset2cfg.keys())[0]
        model_cfg = model2cfg[model]
        dataset_cfg = dataset2cfg[dataset]
        print_prompts(model_cfg, dataset_cfg, args.count)
    else:
        for model_abbr, model_cfg in model2cfg.items():
            for dataset_abbr, dataset_cfg in dataset2cfg.items():
                print('=' * 64, '[BEGIN]', '=' * 64)
                print(f'[MODEL]: {model_abbr}')
                print(f'[DATASET]: {dataset_abbr}')
                print('---')
                print_prompts(model_cfg, dataset_cfg, args.count)
                print('=' * 65, '[END]', '=' * 65)
                print()


if __name__ == '__main__':
    main()
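
# Usage sketch. The script filename and config path below are hypothetical
# (this file is assumed to be saved as tools/prompt_viewer.py); substitute
# any config that defines `models` and `datasets` / `*_datasets`:
#
#   # Interactively choose a model and dataset, then print one sample prompt:
#   python tools/prompt_viewer.py configs/eval_demo.py
#
#   # Non-interactive: take the first model and dataset in the config:
#   python tools/prompt_viewer.py configs/eval_demo.py -n
#
#   # Print 3 prompts for every (model, dataset) pair whose dataset abbr
#   # matches a glob pattern:
#   python tools/prompt_viewer.py configs/eval_demo.py -a -p '*gsm8k*' -c 3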