|
import argparse |
|
import json |
|
import os |
|
import os.path as osp |
|
import re |
|
|
|
import mmengine |
|
from mmengine import Config |
|
from mmengine.utils import mkdir_or_exist |
|
|
|
from opencompass.datasets.humanevalx import _clean_up_code |
|
from opencompass.utils import (dataset_abbr_from_cfg, get_infer_output_path, |
|
get_logger, model_abbr_from_cfg) |
|
|
|
|
|
def parse_args(): |
|
parser = argparse.ArgumentParser( |
|
description='Collect Humanevalx dataset predictions.') |
|
parser.add_argument('config', help='Config file path') |
|
parser.add_argument('-r', |
|
'--reuse', |
|
nargs='?', |
|
type=str, |
|
const='latest', |
|
help='Reuse previous outputs & results, and run any ' |
|
'missing jobs presented in the config. If its ' |
|
'argument is not specified, the latest results in ' |
|
'the work_dir will be reused. The argument should ' |
|
'also be a specific timestamp, e.g. 20230516_144254'), |
|
args = parser.parse_args() |
|
return args |
|
|
|
|
|
_LANGUAGE_NAME_DICT = { |
|
'cpp': 'CPP', |
|
'go': 'Go', |
|
'java': 'Java', |
|
'js': 'JavaScript', |
|
'python': 'Python', |
|
'rust': 'Rust', |
|
} |
|
FAILED = 0 |
|
SUCCEED = 1 |
|
|
|
|
|
def gpt_python_postprocess(ori_prompt: str, text: str) -> str: |
|
"""Better answer postprocessor for better instruction-aligned models like |
|
GPT.""" |
|
if '```' in text: |
|
blocks = re.findall(r'```(.*?)```', text, re.DOTALL) |
|
if len(blocks) == 0: |
|
text = text.split('```')[1] |
|
else: |
|
text = blocks[0] |
|
if not text.startswith('\n'): |
|
text = text[max(text.find('\n') + 1, 0):] |
|
|
|
match_ori = re.search(r'def(.*?)\(', ori_prompt) |
|
match = re.search(r'def(.*?)\(', text) |
|
if match: |
|
if match.group() == match_ori.group(): |
|
text = re.sub('def(.*?)\n', '', text, count=1) |
|
|
|
for c_index, c in enumerate(text[:5]): |
|
if c != ' ': |
|
text = ' ' * (4 - c_index) + text |
|
break |
|
|
|
text = text.split('\n\n\n')[0] |
|
return text |
|
|
|
|
|
def wizardcoder_postprocess(text: str) -> str: |
|
"""Postprocess for WizardCoder Models.""" |
|
if '```' in text: |
|
blocks = re.findall(r'```(.*?)```', text, re.DOTALL) |
|
if len(blocks) == 0: |
|
text = text.split('```')[1] |
|
else: |
|
text = blocks[0] |
|
if not text.startswith('\n'): |
|
text = text[max(text.find('\n') + 1, 0):] |
|
else: |
|
match = re.search(r'Here(.*?)\n', text) |
|
if match: |
|
text = re.sub('Here(.*?)\n', '', text, count=1) |
|
|
|
return text |
|
|
|
|
|
def collect_preds(filename: str): |
|
|
|
root, ext = osp.splitext(filename) |
|
partial_filename = root + '_0' + ext |
|
|
|
if not osp.exists(osp.realpath(filename)) and not osp.exists( |
|
osp.realpath(partial_filename)): |
|
print(f'No predictions found for {filename}') |
|
return FAILED, None, None |
|
else: |
|
if osp.exists(osp.realpath(filename)): |
|
preds = mmengine.load(filename) |
|
pred_strs = [ |
|
preds[str(i)]['prediction'] for i in range(len(preds)) |
|
] |
|
ori_prompt_strs = [ |
|
preds[str(i)]['origin_prompt'] for i in range(len(preds)) |
|
] |
|
else: |
|
filename = partial_filename |
|
pred_strs = [] |
|
ori_prompt_strs = [] |
|
i = 1 |
|
while osp.exists(osp.realpath(filename)): |
|
preds = mmengine.load(filename) |
|
filename = root + f'_{i}' + ext |
|
i += 1 |
|
pred_strs += [ |
|
preds[str(i)]['prediction'] for i in range(len(preds)) |
|
] |
|
ori_prompt_strs += [ |
|
preds[str(i)]['origin_prompt'] for i in range(len(preds)) |
|
] |
|
return SUCCEED, ori_prompt_strs, pred_strs |
|
|
|
|
|
def main(): |
|
args = parse_args() |
|
|
|
logger = get_logger(log_level='INFO') |
|
cfg = Config.fromfile(args.config) |
|
cfg.setdefault('work_dir', './outputs/default/') |
|
|
|
assert args.reuse, 'Please provide the experienment work dir.' |
|
if args.reuse: |
|
if args.reuse == 'latest': |
|
if not os.path.exists(cfg.work_dir) or not os.listdir( |
|
cfg.work_dir): |
|
logger.warning('No previous results to reuse!') |
|
else: |
|
dirs = os.listdir(cfg.work_dir) |
|
dir_time_str = sorted(dirs)[-1] |
|
else: |
|
dir_time_str = args.reuse |
|
logger.info(f'Reusing experiements from {dir_time_str}') |
|
|
|
cfg['work_dir'] = osp.join(cfg.work_dir, dir_time_str) |
|
|
|
for model in cfg.models: |
|
model_abbr = model_abbr_from_cfg(model) |
|
for dataset in cfg.datasets: |
|
dataset_abbr = dataset_abbr_from_cfg(dataset) |
|
filename = get_infer_output_path( |
|
model, dataset, osp.join(cfg.work_dir, 'predictions')) |
|
|
|
succeed, ori_prompt_strs, pred_strs = collect_preds(filename) |
|
if not succeed: |
|
continue |
|
|
|
|
|
for k, v in _LANGUAGE_NAME_DICT.items(): |
|
if k in dataset_abbr: |
|
lang = k |
|
task = v |
|
break |
|
|
|
|
|
if model_abbr in [ |
|
'WizardCoder-1B-V1.0', |
|
'WizardCoder-3B-V1.0', |
|
'WizardCoder-15B-V1.0', |
|
'WizardCoder-Python-13B-V1.0', |
|
'WizardCoder-Python-34B-V1.0', |
|
]: |
|
predictions = [{ |
|
'task_id': f'{task}/{i}', |
|
'generation': wizardcoder_postprocess(pred), |
|
} for i, pred in enumerate(pred_strs)] |
|
elif 'CodeLlama' not in model_abbr and lang == 'python': |
|
predictions = [{ |
|
'task_id': |
|
f'{task}/{i}', |
|
'generation': |
|
gpt_python_postprocess(ori_prompt, pred), |
|
} for i, (ori_prompt, |
|
pred) in enumerate(zip(ori_prompt_strs, pred_strs))] |
|
else: |
|
predictions = [{ |
|
'task_id': f'{task}/{i}', |
|
'generation': _clean_up_code(pred, lang), |
|
} for i, pred in enumerate(pred_strs)] |
|
|
|
|
|
result_file_path = os.path.join(cfg['work_dir'], 'humanevalx', |
|
model_abbr, |
|
f'humanevalx_{lang}.json') |
|
if osp.exists(result_file_path): |
|
logger.info( |
|
f'File exists for {model_abbr}, skip copy from predictions.' |
|
) |
|
else: |
|
mkdir_or_exist(osp.split(result_file_path)[0]) |
|
with open(result_file_path, 'w') as f: |
|
for pred in predictions: |
|
f.write(json.dumps(pred) + '\n') |
|
|
|
|
|
if __name__ == '__main__': |
|
main() |
|
|