code_eval_octopack / execute.py
Muennighoff's picture
Update execute.py
d57e254
raw
history blame
13 kB
import contextlib
import faulthandler
import io
import logging
import multiprocessing
import os
import platform
import shutil
import signal
import subprocess
import tempfile
def check_correctness(check_program, timeout, task_id, completion_id, language, cargo_string=""):
"""
Evaluates the functional correctness of a completion by running the test
suite provided in the problem.
:param completion_id: an optional completion ID so we can match
the results later even if execution finishes asynchronously.
"""
manager = multiprocessing.Manager()
result = manager.list()
if language == "python":
p = multiprocessing.Process(target=unsafe_execute, args=(check_program, result, timeout))
elif language == "cpp":
p = multiprocessing.Process(target=unsafe_execute_cpp, args=(check_program, result, timeout))
elif language == "go":
p = multiprocessing.Process(target=unsafe_execute_go, args=(check_program, result, timeout))
elif language == "java":
p = multiprocessing.Process(target=unsafe_execute_java, args=(check_program, result, timeout))
elif language == "javascript":
p = multiprocessing.Process(target=unsafe_execute_js, args=(check_program, result, timeout))
elif language == "rust":
p = multiprocessing.Process(target=unsafe_execute_rust, args=(check_program, result, timeout, cargo_string))
else:
raise ValueError(f"Language {language} not supported. Feel free to add it :)")
p.start()
p.join(timeout=timeout + 1)
if p.is_alive():
p.kill()
if not result:
result.append("timed out")
return dict(
task_id=task_id,
passed=result[0] == "passed",
result=result[0],
completion_id=completion_id,
)
def unsafe_execute(check_program, result, timeout):
with create_tempdir():
# These system calls are needed when cleaning up tempdir.
import os
import shutil
rmtree = shutil.rmtree
rmdir = os.rmdir
chdir = os.chdir
# Disable functionalities that can make destructive changes to the test.
reliability_guard()
# Run program.
try:
exec_globals = {}
with swallow_io():
with time_limit(timeout):
exec(check_program, exec_globals)
result.append("passed")
except TimeoutException:
result.append("timed out")
except BaseException as e:
result.append(f"failed: {e}")
# Needed for cleaning up.
shutil.rmtree = rmtree
os.rmdir = rmdir
os.chdir = chdir
def unsafe_execute_cpp(check_program, result, timeout):
with create_tempdir():
open(f"test.cpp", 'w').write(check_program)
# -lcrypto & -lssl are needed for the crypto library required by HumanEval-X-Bugs Task 162
compilation_result = subprocess.run(
["/usr/bin/g++", "-std=c++11", "test.cpp", "-lcrypto", "-lssl"],
timeout=timeout,
capture_output=True,
)
if compilation_result.returncode != 0:
if compilation_result.stderr:
err = compilation_result.stderr.decode()
else:
err = compilation_result.stdout.decode()
result.append(f"failed: compilation error: {err}")
else:
try:
exec_result = subprocess.run(["./a.out"], timeout=timeout, capture_output=True)
if exec_result.returncode == 0:
result.append("passed")
else:
if exec_result.stderr:
try:
err = exec_result.stderr.decode()
except:
err = exec_result.stderr
else:
try:
err = exec_result.stdout.decode()
except:
err = exec_result.stdout
result.append(f"failed: {err}")
except subprocess.TimeoutExpired as e:
result.append("timed out")
def unsafe_execute_go(check_program, result, timeout):
with create_tempdir():
open(f"main_test.go", 'w').write(check_program)
try:
exec_result = subprocess.run(["go", "test", f"-timeout={timeout}s", "main_test.go"], timeout=timeout, capture_output=True)
if exec_result.returncode == 0:
result.append("passed")
else:
if exec_result.stderr:
try:
err = exec_result.stderr.decode()
except:
err = exec_result.stderr
else:
try:
err = exec_result.stdout.decode()
except:
err = exec_result.stdout
result.append(f"failed: {err}")
except subprocess.TimeoutExpired as e:
result.append("timed out")
def unsafe_execute_java(check_program, result, timeout):
with create_tempdir():
open(f"Main.java", 'w').write(check_program)
res = "failed: unknown error"
compile_returncode, compilation_result = -1, None
for _ in range(5):
try:
compilation_result = subprocess.run(
['javac', "Main.java"], timeout=5, capture_output=True
)
compile_returncode = compilation_result.returncode
break
except subprocess.TimeoutExpired as e:
continue
if compile_returncode != 0:
err = "unknown error"
if compilation_result is not None:
err = compilation_result.stderr.decode() if compilation_result.stderr else compilation_result.stdout.decode()
res = "failed: compilation error: " + err
else:
try:
exec_result = subprocess.run([f'java', '-cp', '.', 'Main'], timeout=timeout, capture_output=True)
if exec_result.returncode == 0:
res = "passed"
elif exec_result.returncode == 1:
if "AssertionError" in exec_result.stderr.decode('unicode-escape'):
res = "failed: wrong answer"
else:
res = f"failed: {exec_result.stderr.decode()}"
except subprocess.TimeoutExpired as e:
res = "timed out"
except BaseException as e:
res = f"failed: {e}"
result.append(res)
def unsafe_execute_js(check_program, result, timeout):
with create_tempdir():
open(f"test.js", 'w').write(check_program)
# Run program.
try:
exec_result = subprocess.run(["node", "test.js"], timeout=timeout, capture_output=True)
if exec_result.stderr.decode():
err = exec_result.stderr.decode()
result.append(f"failed: {err}")
elif exec_result.stdout.decode():
err = exec_result.stdout.decode()
result.append(f"failed: {err}")
else:
result.append("passed")
except subprocess.TimeoutExpired as e:
result.append("timed out")
def unsafe_execute_rust(check_program, result, timeout, cargo_string):
# To make rust faster, do not use a tempdir
# with create_tempdir():
WD: str = os.getcwd()
RUST_DIR: str = os.path.join(WD, "rust")
RUST_SRC: str = os.path.join(RUST_DIR, "src")
RUST_BIN: str = os.path.join(RUST_SRC, "bin")
RUST_TMP_DIR: str = os.path.join(RUST_DIR, "tmp")
RUST_LOGS: str = os.path.join(RUST_TMP_DIR, "logs")
RUST_EXT: str = ".rs"
# Create mandatory tmp directories
os.makedirs(RUST_TMP_DIR, exist_ok=True)
os.makedirs(RUST_LOGS, exist_ok=True)
os.makedirs(RUST_SRC, exist_ok=True)
os.makedirs(RUST_BIN, exist_ok=True)
# Create Cargo.toml file
if not os.path.exists(f"{RUST_DIR}/Cargo.toml"):
open(f"{RUST_DIR}/Cargo.toml", 'w').write(cargo_string)
with tempfile.NamedTemporaryFile(dir = RUST_BIN, delete=False) as f:
file_name: str = "test" + RUST_EXT
os.rename(f.name, os.path.join(RUST_BIN, file_name))
# Dump the rust source code in the target temporal file
f.write(check_program.encode('utf-8'))
# Proceed towards Rust binaries compilation. Therefore move to Rust module root dir.
os.chdir(RUST_DIR)
compilation_result = subprocess.run(["cargo", "check", "--bin", "test", "--message-format", "json"], timeout=timeout, capture_output=True)
if compilation_result.returncode == 0:
exec_result = subprocess.run(["cargo", "test", "--bin", "test", "--message-format", "json"], timeout=timeout, capture_output=True)
if exec_result.returncode == 0:
result.append("passed")
else:
err = compilation_result.stderr.decode() if compilation_result.stderr else compilation_result.stdout.decode()
result.append("failed: execution error: " + err)
else:
err = compilation_result.stderr.decode() if compilation_result.stderr else compilation_result.stdout.decode()
result.append("failed: compilation error: " + err)
# Move back to the original working directory
os.chdir(WD)
@contextlib.contextmanager
def time_limit(seconds):
def signal_handler(signum, frame):
raise TimeoutException("Timed out!")
signal.setitimer(signal.ITIMER_REAL, seconds)
signal.signal(signal.SIGALRM, signal_handler)
try:
yield
finally:
signal.setitimer(signal.ITIMER_REAL, 0)
@contextlib.contextmanager
def swallow_io():
stream = WriteOnlyStringIO()
with contextlib.redirect_stdout(stream):
with contextlib.redirect_stderr(stream):
with redirect_stdin(stream):
yield
@contextlib.contextmanager
def create_tempdir():
with tempfile.TemporaryDirectory() as dirname:
with chdir(dirname):
yield dirname
class TimeoutException(Exception):
pass
class WriteOnlyStringIO(io.StringIO):
"""StringIO that throws an exception when it's read from"""
def read(self, *args, **kwargs):
raise OSError
def readline(self, *args, **kwargs):
raise OSError
def readlines(self, *args, **kwargs):
raise OSError
def readable(self, *args, **kwargs):
"""Returns True if the IO object can be read."""
return False
class redirect_stdin(contextlib._RedirectStream): # type: ignore
_stream = "stdin"
@contextlib.contextmanager
def chdir(root):
if root == ".":
yield
return
cwd = os.getcwd()
os.chdir(root)
try:
yield
except BaseException as exc:
raise exc
finally:
os.chdir(cwd)
def reliability_guard(maximum_memory_bytes=None):
"""
This disables various destructive functions and prevents the generated code
from interfering with the test (e.g. fork bomb, killing other processes,
removing filesystem files, etc.)
WARNING
This function is NOT a security sandbox. Untrusted code, including, model-
generated code, should not be blindly executed outside of one. See the
Codex paper for more information about OpenAI's code sandbox, and proceed
with caution.
"""
if maximum_memory_bytes is not None:
import resource
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
if not platform.uname().system == "Darwin":
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
faulthandler.disable()
import builtins
builtins.exit = None
builtins.quit = None
import os
os.environ["OMP_NUM_THREADS"] = "1"
os.kill = None
os.system = None
os.putenv = None
os.remove = None
os.removedirs = None
os.rmdir = None
os.fchdir = None
os.setuid = None
os.fork = None
os.forkpty = None
os.killpg = None
os.rename = None
os.renames = None
os.truncate = None
os.replace = None
os.unlink = None
os.fchmod = None
os.fchown = None
os.chmod = None
os.chown = None
os.chroot = None
os.fchdir = None
os.lchflags = None
os.lchmod = None
os.lchown = None
os.getcwd = None
os.chdir = None
import shutil
shutil.rmtree = None
shutil.move = None
shutil.chown = None
import subprocess
subprocess.Popen = None # type: ignore
__builtins__["help"] = None
import sys
sys.modules["ipdb"] = None
sys.modules["joblib"] = None
sys.modules["resource"] = None
sys.modules["psutil"] = None
sys.modules["tkinter"] = None