import argparse
import copy
import json
import os.path as osp

import mmengine
from mmengine.config import Config, ConfigDict
from mmengine.utils import mkdir_or_exist
from tqdm import tqdm

from opencompass.registry import TEXT_POSTPROCESSORS
from opencompass.utils import build_dataset_from_cfg, get_infer_output_path


def parse_args():
    parser = argparse.ArgumentParser(description='Run case analyzer')
    parser.add_argument('config', help='Config file path')
    parser.add_argument(
        '-f',
        '--force',
        help='Force to run the task even if the results already exist',
        action='store_true',
        default=False)
    parser.add_argument('-w',
                        '--work-dir',
                        help='Work path, all the outputs will be '
                        'saved in this path, including the slurm logs, '
                        'the evaluation results, the summary results, etc. '
                        'If not specified, the work_dir will be set to '
                        './outputs/default.',
                        default=None,
                        type=str)
    args = parser.parse_args()
    return args
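
# Example invocation (the config path is illustrative, not a file shipped
# with this tool):
#   python tools/case_analyzer.py configs/eval_demo.py -w ./outputs/default -f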


class BadcaseShower:
    """Dump all cases and bad cases of a model on a dataset to JSON files."""

    def __init__(self, cfg: ConfigDict) -> None:
        self.cfg = cfg
        self.model_cfg = copy.deepcopy(self.cfg['model'])
        self.dataset_cfg = copy.deepcopy(self.cfg['dataset'])
        self.work_dir = self.cfg.get('work_dir')

        # Load dataset evaluation settings
        self.eval_cfg = self.dataset_cfg.get('eval_cfg')
        self.ds_split = self.eval_cfg.get('ds_split', None)
        self.ds_column = self.eval_cfg.get('ds_column')
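
    # A minimal sketch of the cfg expected here, inferred from how
    # dispatch_tasks() builds it below (the model/dataset bodies follow the
    # usual OpenCompass config conventions):
    #   cfg = ConfigDict({
    #       'model': {...},    # one entry of cfg['models']
    #       'dataset': {...},  # one entry of cfg['datasets'], with 'eval_cfg'
    #       'work_dir': './outputs/default',
    #   })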

    def run(self):
        filename = get_infer_output_path(
            self.model_cfg, self.dataset_cfg,
            osp.join(self.work_dir, 'predictions'))
        root, ext = osp.splitext(filename)
        partial_filename = root + '_0' + ext
        # Bail out if neither the full nor the first sharded prediction
        # file exists
        if not osp.exists(osp.realpath(filename)) and not osp.exists(
                osp.realpath(partial_filename)):
            print(f'{filename} not found')
            return

        dataset = build_dataset_from_cfg(self.dataset_cfg)
        # Postprocess dataset if necessary
        if 'dataset_postprocessor' in self.eval_cfg:

            def postprocess(sample):
                s = sample[self.ds_column]
                proc = TEXT_POSTPROCESSORS.get(
                    self.eval_cfg['dataset_postprocessor']['type'])
                sample[self.ds_column] = proc(s)
                return sample

            dataset = dataset.map(postprocess)

        # Load predictions: either a single file, or sharded partial files
        # named root_0.ext, root_1.ext, ... merged in order
        if osp.exists(osp.realpath(filename)):
            preds = mmengine.load(filename)
        else:
            filename = partial_filename
            preds, offset = {}, 0
            i = 1
            while osp.exists(osp.realpath(filename)):
                _preds = mmengine.load(filename)
                filename = root + f'_{i}' + ext
                i += 1
                # Re-key each shard's local 0..N-1 indices into a global index
                for _o in range(len(_preds)):
                    preds[str(offset)] = _preds[str(_o)]
                    offset += 1
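        # At this point preds maps stringified indices to per-sample records.
        # Assumed shapes, inferred from the access patterns below:
        #   gen eval: preds['0'] = {'prediction': ..., 'origin_prompt': ...}
        #   ppl eval: preds['0'] = {'prediction': ...,
        #                           'in-context examples': ...,
        #                           'label: A': {'testing input': ...,
        #                                        'PPL': ...}}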
        pred_strs = [preds[str(i)]['prediction'] for i in range(len(preds))]

        # Postprocess predictions if necessary
        if 'pred_postprocessor' in self.eval_cfg:
            proc = TEXT_POSTPROCESSORS.get(
                self.eval_cfg['pred_postprocessor']['type'])
            pred_strs = [proc(s) for s in pred_strs]

        # Fetch references from the configured split/column
        if self.ds_split:
            references = dataset[self.ds_split][self.ds_column]
        else:
            references = dataset[self.ds_column]
        if len(pred_strs) != len(references):
            print('length mismatch')
            return

        # Combine cases
        allcase, badcase = [], []
        if 'in-context examples' in preds['0']:
            # ppl eval: each record stores a prompt and a PPL per candidate
            # label
            for i, (pred_str,
                    reference) in enumerate(zip(tqdm(pred_strs), references)):
                ref_str = str(reference)
                try:
                    pred_prompt = preds[str(i)]['label: ' +
                                                pred_str]['testing input']
                    pred_PPL = preds[str(i)]['label: ' + pred_str]['PPL']
                    ref_prompt = preds[str(i)]['label: ' +
                                               ref_str]['testing input']
                    ref_PPL = preds[str(i)]['label: ' + ref_str]['PPL']
                except KeyError:
                    # Skip samples whose predicted or reference label has no
                    # record
                    continue
                item = {
                    'prediction_prompt': pred_prompt,
                    'prediction': pred_str,
                    'prediction_PPL': pred_PPL,
                    'reference_prompt': ref_prompt,
                    'reference': ref_str,
                    'reference_PPL': ref_PPL
                }
                # A case is "bad" when prediction and reference disagree
                if pred_str != ref_str:
                    badcase.append(item)
                allcase.append(item)
        else:
            # gen eval
            for i, (pred_str,
                    reference) in enumerate(zip(tqdm(pred_strs), references)):
                ref_str = str(reference)
                origin_prompt = preds[str(i)]['origin_prompt']
                item = {
                    'origin_prompt': origin_prompt,
                    'prediction': pred_str,
                    'reference': ref_str
                }
                # FIXME: we now consider all cases as bad cases
                badcase.append(item)
                allcase.append(item)

        # Save results: bad cases and all cases, as pretty-printed JSON
        out_path = get_infer_output_path(
            self.cfg['model'], self.cfg['dataset'],
            osp.join(self.work_dir, 'case_analysis/bad'))
        mkdir_or_exist(osp.split(out_path)[0])
        with open(out_path, 'w', encoding='utf-8') as f:
            json.dump(badcase, f, indent=4, ensure_ascii=False)

        out_path = get_infer_output_path(
            self.cfg['model'], self.cfg['dataset'],
            osp.join(self.work_dir, 'case_analysis/all'))
        mkdir_or_exist(osp.split(out_path)[0])
        with open(out_path, 'w', encoding='utf-8') as f:
            json.dump(allcase, f, indent=4, ensure_ascii=False)


def dispatch_tasks(cfg, force=False):
    # Run one analysis per (model, dataset) pair; skip pairs whose output
    # already exists unless --force is given
    for model in cfg['models']:
        for dataset in cfg['datasets']:
            if force or not osp.exists(
                    get_infer_output_path(
                        model, dataset,
                        osp.join(cfg['work_dir'], 'case_analysis/all'))):
                BadcaseShower({
                    'model': model,
                    'dataset': dataset,
                    'work_dir': cfg['work_dir']
                }).run()
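
# A minimal sketch of programmatic use (assumes a config file that defines
# `models` and `datasets`; `configs/eval_demo.py` is illustrative):
#   from mmengine.config import Config
#   cfg = Config.fromfile('configs/eval_demo.py')
#   cfg.setdefault('work_dir', './outputs/default')
#   dispatch_tasks(cfg, force=True)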


def main():
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # Set work_dir: the CLI flag takes precedence over the config value
    if args.work_dir is not None:
        cfg['work_dir'] = args.work_dir
    else:
        cfg.setdefault('work_dir', './outputs/default')
    dispatch_tasks(cfg, force=args.force)


if __name__ == '__main__':
    main()