"""Collect HumanEvalX predictions from a previous run, apply model-specific
postprocessing, and dump them as ``task_id``/``generation`` records (one JSON
object per line) for each language."""
import argparse
import json
import os
import os.path as osp
import re

import mmengine
from mmengine import Config
from mmengine.utils import mkdir_or_exist

from opencompass.datasets.humanevalx import _clean_up_code
from opencompass.utils import (dataset_abbr_from_cfg, get_infer_output_path,
                               get_logger, model_abbr_from_cfg)


def parse_args():
    parser = argparse.ArgumentParser(
        description='Collect Humanevalx dataset predictions.')
    parser.add_argument('config', help='Config file path')
    parser.add_argument('-r',
                        '--reuse',
                        nargs='?',
                        type=str,
                        const='latest',
                        help='Reuse previous outputs & results, and run any '
                        'missing jobs presented in the config. If its '
                        'argument is not specified, the latest results in '
                        'the work_dir will be reused. The argument should '
                        'also be a specific timestamp, e.g. 20230516_144254')
    args = parser.parse_args()
    return args


_LANGUAGE_NAME_DICT = {
    'cpp': 'CPP',
    'go': 'Go',
    'java': 'Java',
    'js': 'JavaScript',
    'python': 'Python',
    'rust': 'Rust',
}
FAILED = 0
SUCCEED = 1


def gpt_python_postprocess(ori_prompt: str, text: str) -> str:
    """Answer postprocessor for instruction-aligned models like GPT."""
    if '```' in text:
        blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
        if len(blocks) == 0:
            text = text.split('```')[1]  # fall back to default strategy
        else:
            text = blocks[0]  # fetch the first code block
            if not text.startswith('\n'):  # in case starting with ```python
                text = text[max(text.find('\n') + 1, 0):]

    # if the model repeated the original function signature, drop that line
    match_ori = re.search(r'def(.*?)\(', ori_prompt)
    match = re.search(r'def(.*?)\(', text)
    if match and match_ori and match.group() == match_ori.group():
        text = re.sub(r'def(.*?)\n', '', text, count=1)

    # re-indent the body to 4 spaces if indentation was dropped
    for c_index, c in enumerate(text[:5]):
        if c != ' ':
            text = ' ' * (4 - c_index) + text
            break

    text = text.split('\n\n\n')[0]
    return text


def wizardcoder_postprocess(text: str) -> str:
    """Postprocess for WizardCoder models."""
    if '```' in text:
        blocks = re.findall(r'```(.*?)```', text, re.DOTALL)
        if len(blocks) == 0:
            text = text.split('```')[1]  # fall back to default strategy
        else:
            text = blocks[0]  # fetch the first code block
            if not text.startswith('\n'):  # in case starting with ```python
                text = text[max(text.find('\n') + 1, 0):]
    else:
        match = re.search(r'Here(.*?)\n', text)
        if match:
            text = re.sub(r'Here(.*?)\n', '', text, count=1)

    return text


def collect_preds(filename: str):
    # in case the prediction is partial
    root, ext = osp.splitext(filename)
    partial_filename = root + '_0' + ext
    # collect all the prediction results
    if not osp.exists(osp.realpath(filename)) and not osp.exists(
            osp.realpath(partial_filename)):
        print(f'No predictions found for {filename}')
        return FAILED, None, None
    else:
        if osp.exists(osp.realpath(filename)):
            preds = mmengine.load(filename)
            pred_strs = [
                preds[str(i)]['prediction'] for i in range(len(preds))
            ]
            ori_prompt_strs = [
                preds[str(i)]['origin_prompt'] for i in range(len(preds))
            ]
        else:
            # merge sharded prediction files: <root>_0<ext>, <root>_1<ext>, ...
            filename = partial_filename
            pred_strs = []
            ori_prompt_strs = []
            i = 1
            while osp.exists(osp.realpath(filename)):
                preds = mmengine.load(filename)
                filename = root + f'_{i}' + ext
                i += 1
                pred_strs += [
                    preds[str(i)]['prediction'] for i in range(len(preds))
                ]
                ori_prompt_strs += [
                    preds[str(i)]['origin_prompt'] for i in range(len(preds))
                ]
        return SUCCEED, ori_prompt_strs, pred_strs


def main():
    args = parse_args()
    # initialize logger
    logger = get_logger(log_level='INFO')
    cfg = Config.fromfile(args.config)
    cfg.setdefault('work_dir', './outputs/default/')
    assert args.reuse, 'Please provide the experiment work dir.'
    if args.reuse:
        if args.reuse == 'latest':
            if not os.path.exists(cfg.work_dir) or not os.listdir(
                    cfg.work_dir):
                logger.warning('No previous results to reuse!')
                return  # nothing to reuse, so nothing to collect
            else:
                dirs = os.listdir(cfg.work_dir)
                dir_time_str = sorted(dirs)[-1]
        else:
            dir_time_str = args.reuse
        logger.info(f'Reusing experiments from {dir_time_str}')
    # update "actual" work_dir
    cfg['work_dir'] = osp.join(cfg.work_dir, dir_time_str)

    for model in cfg.models:
        model_abbr = model_abbr_from_cfg(model)
        for dataset in cfg.datasets:
            dataset_abbr = dataset_abbr_from_cfg(dataset)
            filename = get_infer_output_path(
                model, dataset, osp.join(cfg.work_dir, 'predictions'))

            succeed, ori_prompt_strs, pred_strs = collect_preds(filename)
            if not succeed:
                continue

            # infer the language type from the dataset abbreviation
            for k, v in _LANGUAGE_NAME_DICT.items():
                if k in dataset_abbr:
                    lang = k
                    task = v
                    break

            # apply model-specific postprocessing
            if model_abbr in [
                    'WizardCoder-1B-V1.0',
                    'WizardCoder-3B-V1.0',
                    'WizardCoder-15B-V1.0',
                    'WizardCoder-Python-13B-V1.0',
                    'WizardCoder-Python-34B-V1.0',
            ]:
                predictions = [{
                    'task_id': f'{task}/{i}',
                    'generation': wizardcoder_postprocess(pred),
                } for i, pred in enumerate(pred_strs)]
            elif 'CodeLlama' not in model_abbr and lang == 'python':
                predictions = [{
                    'task_id': f'{task}/{i}',
                    'generation': gpt_python_postprocess(ori_prompt, pred),
                } for i, (ori_prompt,
                          pred) in enumerate(zip(ori_prompt_strs, pred_strs))]
            else:
                predictions = [{
                    'task_id': f'{task}/{i}',
                    'generation': _clean_up_code(pred, lang),
                } for i, pred in enumerate(pred_strs)]

            # save processed results if they do not exist yet
            result_file_path = os.path.join(cfg['work_dir'], 'humanevalx',
                                            model_abbr,
                                            f'humanevalx_{lang}.json')
            if osp.exists(result_file_path):
                logger.info(
                    f'File exists for {model_abbr}, skip copy from predictions.'  # noqa
                )
            else:
                mkdir_or_exist(osp.split(result_file_path)[0])
                with open(result_file_path, 'w') as f:
                    for pred in predictions:
                        f.write(json.dumps(pred) + '\n')


if __name__ == '__main__':
    main()
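
# Example invocation (the script path and config name below are illustrative,
# not taken from this file; substitute your own config and the timestamp of
# the run whose predictions you want to collect):
#   python tools/collect_code_preds.py configs/eval_humanevalx.py -r 20230516_144254
# Processed files are written to
#   <work_dir>/<timestamp>/humanevalx/<model_abbr>/humanevalx_<lang>.json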