File size: 9,262 Bytes
256a159 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 |
import argparse
import fnmatch
from typing import Dict
from mmengine.config import Config, ConfigDict
from opencompass.openicl.icl_inferencer import (AgentInferencer,
ChatInferencer, CLPInferencer,
GenInferencer, LLInferencer,
PPLInferencer,
PPLOnlyInferencer)
from opencompass.registry import ICL_PROMPT_TEMPLATES, ICL_RETRIEVERS
from opencompass.utils import (Menu, build_dataset_from_cfg,
build_model_from_cfg, dataset_abbr_from_cfg,
model_abbr_from_cfg)
def parse_args():
    """Parse the command line of the prompt viewer.

    Positional:
        config: path of the config file to load.

    Options:
        -n/--non-interactive: skip the selection menu and use the first
            model and the first dataset.
        -a/--all: print prompts for every (model, dataset) pair.
        -p/--pattern: fnmatch-style pattern used to filter dataset abbrs.
        -c/--count: number of prompts to print (default 1).

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    ap = argparse.ArgumentParser(
        description='View generated prompts based on datasets (and models)')
    ap.add_argument('config', help='Train config file path')
    # Both boolean switches share the same shape, so register them together.
    for short_flag, long_flag in (('-n', '--non-interactive'),
                                  ('-a', '--all')):
        ap.add_argument(short_flag, long_flag, action='store_true')
    ap.add_argument('-p', '--pattern', type=str,
                    help='To match the dataset abbr.')
    ap.add_argument('-c', '--count', type=int, default=1,
                    help='Number of prompts to print')
    return ap.parse_args()
def parse_model_cfg(model_cfg: ConfigDict) -> Dict[str, ConfigDict]:
    """Index model configs by their abbreviation.

    Although annotated as a single ``ConfigDict``, ``model_cfg`` is iterated
    here, so it is expected to be a sequence of model configs (such as
    ``cfg.models``). When two entries share an abbreviation, the later one
    wins.

    Returns:
        Mapping from ``model_abbr_from_cfg(cfg)`` to the config itself.
    """
    return {model_abbr_from_cfg(cfg): cfg for cfg in model_cfg}
def parse_dataset_cfg(dataset_cfg: ConfigDict) -> Dict[str, ConfigDict]:
    """Index dataset configs by their abbreviation.

    Although annotated as a single ``ConfigDict``, ``dataset_cfg`` is
    iterated here, so it is expected to be a sequence of dataset configs
    (such as ``cfg.datasets``). When two entries share an abbreviation, the
    later one wins.

    Returns:
        Mapping from ``dataset_abbr_from_cfg(cfg)`` to the config itself.
    """
    return {dataset_abbr_from_cfg(cfg): cfg for cfg in dataset_cfg}
def _print_framed(header_lines, body):
    """Print ``header_lines`` then ``body``, framed by 100-dash rules."""
    rule = '-' * 100
    print(rule)
    for line in header_lines:
        print(line)
    print(rule)
    print(body)
    print(rule)


def _fit_prompt(model, retriever, ice_idx, ice, make_prompt, max_seq_len,
                ice_template):
    """Build a prompt, dropping trailing in-context examples until it fits.

    The prompt is built with ``make_prompt(ice)``. While its token count
    (``model.get_token_len_from_template``) exceeds ``max_seq_len`` and
    in-context examples remain, the last example index is dropped and the
    ICE string and prompt are regenerated.

    Args:
        model: Built model used for token counting. It is only used when
            ``max_seq_len`` is not None.
        retriever: Retriever used to regenerate the ICE string.
        ice_idx: In-context example indices for the current sample.
        ice: ICE string generated from ``ice_idx``.
        make_prompt: Callable mapping an ICE string to a prompt.
        max_seq_len: Token budget, or None to skip truncation entirely.
        ice_template: Template passed to ``retriever.generate_ice``.

    Returns:
        tuple: ``(ice_idx, ice, prompt)`` after any truncation.
    """
    prompt = make_prompt(ice)
    if max_seq_len is None:
        return ice_idx, ice, prompt
    prompt_token_num = model.get_token_len_from_template(prompt)
    while len(ice_idx) > 0 and prompt_token_num > max_seq_len:
        num_ice = len(ice_idx)
        print(f'Truncating ice {num_ice} -> {num_ice - 1}',
              f'Number of tokens: {prompt_token_num} -> ...')
        ice_idx = ice_idx[:-1]
        ice = retriever.generate_ice(ice_idx, ice_template=ice_template)
        prompt = make_prompt(ice)
        prompt_token_num = model.get_token_len_from_template(prompt)
    print(f'Number of tokens: {prompt_token_num}')
    return ice_idx, ice, prompt


def print_prompts(model_cfg, dataset_cfg, count=1):
    """Print the first ``count`` prompts generated for a dataset config.

    PPL-style inferencers (``PPLInferencer``/``LLInferencer`` and their
    subclasses) print the ICE string once, then one prompt per label. Every
    other supported inferencer prints a single generation prompt.

    Args:
        model_cfg: Model config, or a falsy value to render prompts without
            a model. With no model there is no truncation and no
            ``parse_template`` step.
        dataset_cfg: Dataset config. Its ``infer_cfg`` supplies the
            retriever, the templates and the inferencer type.
        count: Maximum number of samples to print.

    Note:
        This function mutates its arguments. It sets
        ``model_cfg['tokenizer_only'] = True`` for non-API models and stores
        the built dataset in ``infer_cfg['retriever']['dataset']``.
    """
    # TODO: A really dirty method that copies code from PPLInferencer and
    # GenInferencer. In the future, the prompt extraction code should be
    # extracted and generalized as a static method in these Inferencers
    # and reused here.
    infer_cfg = dataset_cfg.get('infer_cfg')

    # Validate the inferencer first so unsupported configs are rejected
    # before the (potentially slow) model and dataset are built.
    supported_inferencer = [
        AgentInferencer, PPLInferencer, GenInferencer, CLPInferencer,
        PPLOnlyInferencer, ChatInferencer, LLInferencer
    ]
    inferencer_type = infer_cfg.inferencer.type
    # The dispatch below uses issubclass(), so subclasses are accepted here
    # too. The isinstance() guard makes a non-class ``type`` report
    # "unsupported" instead of raising TypeError.
    if not (isinstance(inferencer_type, type)
            and issubclass(inferencer_type, tuple(supported_inferencer))):
        print(f'Only {supported_inferencer} are supported')
        return

    if model_cfg:
        # ConfigDict attribute access raises on missing keys. Using .get()
        # lets models without ``max_seq_len`` be shown untruncated.
        max_seq_len = model_cfg.get('max_seq_len')
        if not model_cfg['type'].is_api:
            model_cfg['tokenizer_only'] = True
        model = build_model_from_cfg(model_cfg)
    else:
        max_seq_len = None
        model = None

    dataset = build_dataset_from_cfg(dataset_cfg)

    ice_template = None
    if hasattr(infer_cfg, 'ice_template'):
        ice_template = ICL_PROMPT_TEMPLATES.build(infer_cfg['ice_template'])

    prompt_template = None
    if hasattr(infer_cfg, 'prompt_template'):
        prompt_template = ICL_PROMPT_TEMPLATES.build(
            infer_cfg['prompt_template'])

    infer_cfg['retriever']['dataset'] = dataset
    retriever = ICL_RETRIEVERS.build(infer_cfg['retriever'])
    ice_idx_list = retriever.retrieve()

    is_ppl = issubclass(inferencer_type, (PPLInferencer, LLInferencer))
    gen_field_replace_token = infer_cfg.inferencer.get(
        'gen_field_replace_token', '')

    for idx in range(min(count, len(ice_idx_list))):
        ice_idx = ice_idx_list[idx]
        ice = retriever.generate_ice(ice_idx, ice_template=ice_template)

        if is_ppl:
            labels = retriever.get_labels(ice_template=ice_template,
                                          prompt_template=prompt_template)
            _print_framed(['ICE Template:'], ice)
            for label in labels:

                def make_label_prompt(cur_ice, _label=label):
                    # Same arguments for the initial and the truncated
                    # prompts (remain_sep used to be dropped on retries).
                    return retriever.generate_label_prompt(
                        idx,
                        cur_ice,
                        _label,
                        ice_template=ice_template,
                        prompt_template=prompt_template,
                        remain_sep=None)

                # Truncation carries over to later labels of this sample.
                ice_idx, ice, prompt = _fit_prompt(model, retriever, ice_idx,
                                                   ice, make_label_prompt,
                                                   max_seq_len, ice_template)
                if model is not None:
                    prompt = model.parse_template(prompt, mode='ppl')
                _print_framed([f'Label: {label}', 'Sample prompt:'], prompt)
        else:

            def make_gen_prompt(cur_ice):
                return retriever.generate_prompt_for_generate_task(
                    idx,
                    cur_ice,
                    gen_field_replace_token=gen_field_replace_token,
                    ice_template=ice_template,
                    prompt_template=prompt_template)

            _, _, prompt = _fit_prompt(model, retriever, ice_idx, ice,
                                       make_gen_prompt, max_seq_len,
                                       ice_template)
            if model is not None:
                prompt = model.parse_template(prompt, mode='gen')
            _print_framed(['Sample prompt:'], prompt)
def main():
    """Load the config from the command line and print prompts.

    Models come from ``cfg.models``. When that is absent or empty, a single
    ``'None'`` entry renders prompts without a model. Datasets come from
    ``cfg.datasets``, or else from every config variable whose name ends in
    ``_datasets``, optionally filtered by ``--pattern``. The
    (model, dataset) pair is chosen interactively, or the first pair is
    taken with ``--non-interactive``, or every pair is printed with
    ``--all``.

    Raises:
        ValueError: if the config defines no datasets, or if ``--pattern``
            matches none of them.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)

    # An empty model list is treated like a missing one, so the viewer still
    # works from dataset templates alone instead of failing on an empty menu.
    if 'models' in cfg and cfg.models:
        model2cfg = parse_model_cfg(cfg.models)
    else:
        model2cfg = {'None': None}

    if 'datasets' in cfg:
        dataset2cfg = parse_dataset_cfg(cfg.datasets)
    else:
        # Fall back to collecting every ``*_datasets`` variable.
        dataset2cfg = {}
        for key in cfg.keys():
            if key.endswith('_datasets'):
                dataset2cfg.update(parse_dataset_cfg(cfg[key]))

    # Fail clearly here rather than with an IndexError (or a silent no-op
    # under --all) further down.
    if not dataset2cfg:
        raise ValueError(
            f'No datasets found in {args.config}: define `datasets` or '
            'variables whose names end with `_datasets`.')

    if args.pattern is not None:
        matches = fnmatch.filter(dataset2cfg, args.pattern)
        if len(matches) == 0:
            raise ValueError(
                'No dataset match the pattern. Please select from: \n' +
                '\n'.join(dataset2cfg.keys()))
        dataset2cfg = {k: dataset2cfg[k] for k in matches}

    if not args.all:
        if not args.non_interactive:
            model, dataset = Menu(
                [list(model2cfg.keys()),
                 list(dataset2cfg.keys())], [
                     f'Please make a selection of {s}:'
                     for s in ['model', 'dataset']
                 ]).run()
        else:
            model = list(model2cfg.keys())[0]
            dataset = list(dataset2cfg.keys())[0]
        model_cfg = model2cfg[model]
        dataset_cfg = dataset2cfg[dataset]
        print_prompts(model_cfg, dataset_cfg, args.count)
    else:
        for model_abbr, model_cfg in model2cfg.items():
            for dataset_abbr, dataset_cfg in dataset2cfg.items():
                print('=' * 64, '[BEGIN]', '=' * 64)
                print(f'[MODEL]: {model_abbr}')
                print(f'[DATASET]: {dataset_abbr}')
                print('---')
                print_prompts(model_cfg, dataset_cfg, args.count)
                print('=' * 65, '[END]', '=' * 65)
                print()
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
|