import configparser
import importlib
import json
import os
import os.path as osp
import pickle
import re
import shutil
import signal
import subprocess
import sys
import tempfile
import threading
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from shutil import copyfile
from subprocess import PIPE, Popen
from typing import Optional, Union

from datasets import Dataset, DatasetDict

from opencompass.openicl.icl_evaluator import BaseEvaluator
from opencompass.registry import LOAD_DATASET, TEXT_POSTPROCESSORS

from .base import BaseDataset

_LIBRARY_NAME_LIST = [
    'Pandas', 'Numpy', 'Tensorflow', 'Scipy', 'Sklearn', 'Pytorch',
    'Matplotlib'
]


@LOAD_DATASET.register_module()
class DS1000Dataset(BaseDataset):
    """DS1000 dataset.

    The DS-1000 dataset can be downloaded from
    https://github.com/HKUNLP/DS-1000/blob/main/ds1000_data.zip
    """

    ALL_LIBS = [
        'Pandas',
        'Numpy',
        'Matplotlib',
        'Tensorflow',
        'Scipy',
        'Sklearn',
        'Pytorch',
    ]

    def get_data(self, problem_path: str) -> dict:
        """Get necessary data from problem path."""

        problem_path = Path(problem_path)
        data = dict()

        # Read the problem metadata from the per-problem ``.cfg`` file.
        problem_config = configparser.RawConfigParser()
        problem_config.read(problem_path / '.cfg')
        for args in [
                ('library', 'lib'),
                ('test', 'test_type'),
                ('test', 'test_case_cnt'),
                ('perturbation', 'perturbation_type'),
                ('perturbation', 'perturbation_origin_id'),
        ]:
            data[args[1]] = problem_config.get(*args)

        # Read the per-problem text files (prompt, reference code, tests).
        for file_name in [
                'reference_code.txt',
                'test_code.py',
                'code_context.txt',
                'test_generate_pickle.py',
                'prompt.txt',
        ]:
            with open(problem_path / file_name, 'r', encoding='UTF-8') as f:
                data[file_name.split('.')[0]] = f.read()

        data['test_column'] = dict(
            problem_path=str(problem_path),
            code_context=data.pop('code_context'),
            test_case_cnt=data.pop('test_case_cnt'),
            test_type=data.pop('test_type'),
        )
        return data

    def load(self,
             path: str,
             libs: Optional[Union[str, list]] = None,
             mode: str = 'Insertion'):
        """Load dataset.

        Args:
            path(str): Data path of the dataset.
            libs(Optional[Union[str, list]]): All the libraries to be
                tested. Defaults to None, which means all.
            mode(str): Dataset mode. Only `Insertion` and `Completion`
                are supported. Defaults to `Insertion`.
        """

        if isinstance(libs, str):
            libs = [libs]

        if libs:
            for lib in libs:
                assert lib in self.ALL_LIBS, \
                    f'Please choose libs in {self.ALL_LIBS}'
        else:
            libs = self.ALL_LIBS

        data_list = []
        for lib in libs:
            source_path = Path(path) / lib / mode
            problems = sorted(os.listdir(source_path),
                              key=lambda x: int(str(x).replace('q', '')))
            for problem in problems:
                new_problem = self.get_data(source_path / problem)
                data_list.append(new_problem)

        dataset = Dataset.from_list(data_list)
        return DatasetDict({'test': dataset})
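
# A hedged usage sketch (not part of the loader): DS1000Dataset is normally
# referenced from an OpenCompass dataset config, roughly as below. The
# ``abbr`` and ``path`` values are illustrative assumptions, and the usual
# ``reader_cfg``/``infer_cfg``/``eval_cfg`` entries are omitted.
#
#     ds1000_pandas = dict(
#         type=DS1000Dataset,
#         abbr='ds1000_Pandas',
#         path='./data/ds1000_data/',
#         libs='Pandas',
#         mode='Insertion',
#     )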


@TEXT_POSTPROCESSORS.register_module('ds1000')
def ds1000_postprocess(text: str) -> str:
    match = re.search('<code>(.*?)</code>', text, re.DOTALL)
    if match:
        text = match.group(1)

    match = re.search('```python(.*?)```', text, re.DOTALL)
    if match:
        text = match.group(1)

    match = re.search('```(.*?)```', text, re.DOTALL)
    if match:
        text = match.group(1)

    match = re.search('BEGIN SOLUTION(.*?)END SOLUTION', text, re.DOTALL)
    if match:
        text = match.group(1)

    return text
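
# Illustrative example (not from the dataset): a completion such as
# 'Sure!\n<code>\nresult = df.dropna()\n</code>' is reduced by the patterns
# above to '\nresult = df.dropna()\n'. Each pattern only narrows the text
# further when it matches, so plain code without any markers passes through
# unchanged.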


@TEXT_POSTPROCESSORS.register_module('ds1000_completion')
def ds1000_completion_postprocess(text: str) -> str:
    text += '</code>'

    match = re.search('(.*?)</code>', text, re.DOTALL)
    if match:
        text = match.group(1)

    return text
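
# Illustrative example: in Completion mode the model continues an open
# <code> block, so 'result = df.dropna()\n</code>\nEND SOLUTION' is cut at
# the first closing tag and becomes 'result = df.dropna()\n'. Appending
# '</code>' first guarantees that the regex always finds a terminator.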


@TEXT_POSTPROCESSORS.register_module('ds1000_matplotlib')
def ds1000_matplotlib_postprocess(text: str) -> str:
    text = ds1000_postprocess(text)
    code_lines = text.split('\n')
    postprocessed_lines = []
    for line in code_lines:
        skip_line_flag = False
        # Drop lines that show, clear, or save the figure, so that the plot
        # state is preserved for the test harness.
        for phrase in ['plt.show()', 'plt.clf()', 'plt.close()', 'savefig']:
            if phrase in line:
                skip_line_flag = True
                break

        if skip_line_flag:
            continue
        else:
            postprocessed_lines.append(line)
    text = '\n'.join(postprocessed_lines)
    return text
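
# Illustrative example: 'plt.plot(x, y)\nplt.show()' becomes 'plt.plot(x, y)'
# after this postprocessor; all other lines are kept verbatim.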


class DS1000Evaluator(BaseEvaluator):
    """Evaluator for ds1000 eval."""

    def __init__(self, num_workers=16) -> None:
        self.num_workers = num_workers

    def score_single(self, pred, refer):
        generated_code = pred

        # Remember the current working directory: the test is executed inside
        # a temporary directory and we must change back afterwards.
        cwd = os.getcwd()

        def chdir_return(cwd, return_value):
            os.chdir(cwd)
            return return_value

        with tempfile.TemporaryDirectory() as tempdir_name:

            tempdir_name = Path(tempdir_name)

            # Copy the problem files (test code, pickled answers, etc.) into
            # the temporary directory.
            shutil.copytree(refer['problem_path'],
                            tempdir_name,
                            dirs_exist_ok=True)

            os.mkdir(tempdir_name / 'result')

            # Splice the generated code into the problem's code context.
            program = refer['code_context'].replace('[insert]',
                                                    generated_code)
            with open(tempdir_name / 'program.py', 'w', encoding='UTF-8') as f:
                f.write(program)

            os.chdir(tempdir_name)

            execution_status = []

            # Run the assembled program once per test case with a time limit.
            test_cnt = max(1, int(refer['test_case_cnt']))
            for i in range(1, test_cnt + 1):
                cmd_text = f'python program.py --test_case {i}'
                time_limit = 60
                cmd = Command(cmd_text)
                exit_code = cmd.run(timeout=time_limit)
                execution_status.append(exit_code)

            # Any non-zero exit code means the program failed to run.
            if sum(execution_status) > 0:
                return chdir_return(cwd, False)

            test_module = import_source_file(tempdir_name / 'test_code.py',
                                             'test_code')
            pass_flag = True

            if int(refer['test_type']) == 3:
                # Some problems (test_type 3) also run a string-level check
                # (``stringTest``) on the generated source itself. Strip
                # print statements first, unless they are commented out;
                # iterate over a copy while removing lines.
                generated_code = generated_code.split('\n')
                for line in generated_code[:]:
                    if 'print' in line and '#' not in line.split('print')[0]:
                        generated_code.remove(line)
                generated_code = '\n'.join(generated_code)
                try:
                    pass_flag = test_module.stringTest(generated_code)
                except Exception:
                    # The string test failed to execute.
                    return chdir_return(cwd, False)

            test_cnt = max(int(refer['test_case_cnt']), 1)
            for i in range(1, test_cnt + 1):
                try:
                    with open(f'ans/ans{i}.pkl', 'rb') as f:
                        ans = pickle.load(f)
                    # The result pickle is written by program.py above.
                    with open(f'result/result_{i}.pkl', 'rb') as f:
                        result = pickle.load(f)
                    pass_flag = test_module.test(result, ans) == 1
                except Exception:
                    # A missing or unreadable result counts as a failure.
                    return chdir_return(cwd, False)

            return chdir_return(cwd, pass_flag)

    def score(self, predictions, references):
        total = len(predictions)

        # Score each prediction in its own worker process.
        with ProcessPoolExecutor(self.num_workers) as executor:
            passed = sum(
                list(executor.map(self.score_single, predictions,
                                  references)))

        return {'accuracy': passed / total}


class Command(object):
    """This object takes in a command and executes it with a timeout."""

    def __init__(self, cmd):
        self.cmd = cmd
        self.process = None

    def run(self, timeout):

        def target():
            # Windows has no process groups, so the process is spawned
            # without a new session there.
            if os.name == 'nt':
                self.process = Popen(self.cmd,
                                     shell=True,
                                     stdout=PIPE,
                                     stderr=PIPE)
            else:
                self.process = Popen(self.cmd,
                                     shell=True,
                                     stdout=PIPE,
                                     stderr=PIPE,
                                     preexec_fn=os.setsid)
            self.process.communicate()

        thread = threading.Thread(target=target)
        thread.start()

        thread.join(timeout)
        if thread.is_alive():
            # Kill the whole process tree if the timeout is exceeded.
            if os.name == 'nt':
                Popen('TASKKILL /F /PID {pid} /T'.format(pid=self.process.pid))
            else:
                os.killpg(self.process.pid, signal.SIGTERM)
            thread.join()
        return self.process.returncode
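
# Illustrative example: Command('python program.py --test_case 1') followed
# by .run(timeout=60) returns the subprocess exit code, or the code observed
# after the process group has been terminated when the timeout fires.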


def import_source_file(fname, modname):
    """Import a Python source file and return the loaded module.

    Args:
        fname: The full path to the source file. It may contain characters
            like `.` or `-`.
        modname: The name for the loaded module. It may contain `.` and even
            characters that would normally not be allowed (e.g., `-`).

    Returns:
        The imported module.

    Raises:
        ImportError: If the file cannot be imported (e.g., if it's not a
            `.py` file or if it does not exist).
        Exception: Any exception that is raised while executing the module
            (e.g. :exc:`SyntaxError`).
            These are errors made by the author of the module!
    """

    spec = importlib.util.spec_from_file_location(modname, fname)
    if spec is None:
        raise ImportError(
            f"Could not load spec for module '{modname}' at: {fname}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[modname] = module
    try:
        spec.loader.exec_module(module)
    except FileNotFoundError as e:
        raise ImportError(f'{e.strerror}: {fname}') from e
    return module
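
# Illustrative example, mirroring the call in ``score_single`` above:
#
#     test_module = import_source_file(tempdir_name / 'test_code.py',
#                                      'test_code')
#     pass_flag = test_module.test(result, ans) == 1
#
# The problem-specific ``test`` (and, for some problems, ``stringTest``)
# functions are defined in each problem's ``test_code.py``.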


class DS1000ServiceEvaluator(BaseEvaluator):
    """Evaluator for ds1000 eval by using a service.

    Before you use this Evaluator, launch a code eval service according to:
    https://opencompass.readthedocs.io/en/latest/advanced_guides/code_eval_service.html

    Args:
        lib (str): The library to be evaluated.
        ip_address (str): The IP address of the DS1000 code evaluation
            service. Defaults to 'localhost'.
        port (int): The port of the DS1000 code evaluation service.
            Defaults to 5000.
        timeout (int): Maximum wait time when accessing the service,
            in seconds. Defaults to 600.
    """

    def __init__(self,
                 lib: str,
                 ip_address='localhost',
                 port=5000,
                 timeout=600) -> None:
        assert lib in _LIBRARY_NAME_LIST, (
            f'lib must be in {_LIBRARY_NAME_LIST}')
        self.lib = lib
        self.ip_address = ip_address
        self.port = port
        self.timeout = timeout
        super().__init__()

    def score(self, predictions, references):
        processed_predictions = {}
        assert len(predictions) == len(references)
        for i, (pred, gold) in enumerate(zip(predictions, references)):
            # Discard overly long predictions to keep the request small.
            if len(pred) > 10000:
                pred = ''
            processed_predictions[str(i)] = {'prediction': pred, 'gold': gold}

        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_out_path = osp.join(tmp_dir, f'ds1000_{self.lib}.json')
            with open(tmp_out_path, 'w', encoding='utf-8') as json_file:
                json.dump(processed_predictions,
                          json_file,
                          indent=4,
                          ensure_ascii=False)

            succeed, output = self._code_eval_service(file_path=tmp_out_path)
            if succeed:
                if isinstance(output, str):
                    return json.loads(output)
                elif isinstance(output, dict):
                    return output
            else:
                result_file_path = os.path.join('outputs',
                                                f'ds1000_{self.lib}.json')
                copyfile(tmp_out_path, result_file_path)
                ref_url = 'https://opencompass.readthedocs.io/en/latest/advanced_guides/code_eval_service.html'
                raise Exception(
                    'Call CodeEvalService Error in `DS1000ServiceEvaluator`. '
                    'The results have been saved in path '
                    f"'{result_file_path}'. Check that your code evaluation "
                    'service is launched and that the network to the '
                    'service is connected; you can also get results '
                    f'directly with a `curl` command, see {ref_url}.'
                    f'\nError Information: {output}')

    def _code_eval_service(self, file_path: str) -> tuple:
        """Access the code eval service.

        Args:
            file_path (str): The file path to the file to be evaluated.

        Returns:
            tuple[bool, str]: Whether the access is successful and the output.
        """
        exec_result = subprocess.run([
            'curl', '-X', 'POST', '-F', f'file=@{file_path}',
            f'{self.ip_address}:{self.port}/evaluate'
        ],
                                     timeout=self.timeout,
                                     capture_output=True)
        if exec_result.returncode == 0 and re.match(
                "\"{.*:.*}\"", exec_result.stdout.decode('utf-8')):
            return True, json.loads(exec_result.stdout.decode('utf-8'))
        else:
            if exec_result.stderr:
                try:
                    err = exec_result.stderr.decode()
                except Exception:
                    err = exec_result.stderr
            else:
                try:
                    err = exec_result.stdout.decode()
                except Exception:
                    err = exec_result.stdout
            return False, err
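
# Illustrative fallback: the request issued above is equivalent to running
#
#     curl -X POST -F file=@ds1000_Pandas.json localhost:5000/evaluate
#
# by hand (substitute your own file name, host and port), which is the manual
# route suggested in the error raised by ``score`` when the service cannot be
# reached.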