|
import base64 |
|
import io |
|
import logging |
|
import os |
|
import queue |
|
import re |
|
import signal |
|
import sys |
|
import traceback |
|
import uuid |
|
from typing import Optional, Tuple |
|
|
|
import json5 |
|
import PIL.Image |
|
from jupyter_client import KernelManager |
|
from lagent.actions.base_action import BaseAction |
|
from lagent.schema import ActionReturn, ActionStatusCode |
|
|
|
# Directory where images produced by executed code are written; override it
# with the CODE_INTERPRETER_WORK_DIR environment variable.
WORK_DIR = os.getenv('CODE_INTERPRETER_WORK_DIR',
                     os.path.abspath('./output_images'))

# Default action description shown to the model
# ("Start a Jupyter kernel for executing Python code.").
DEFAULT_DESCRIPTION = """启动Jupyter Kernel用于执行Python代码。"""

# Prelude executed in the kernel on initialization and after a reset. It sets
# CUDA_VISIBLE_DEVICES to "-1", replaces input() with a function that raises,
# and replaces get_ipython().system so shell escapes only print a notice.
# The ``{}`` placeholder receives the optional user-data setup snippet built
# in IPythonInterpreter.__init__ (or an empty string).
START_CODE = """
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
def input(*args, **kwargs):
    raise NotImplementedError('Python input() function is disabled.')

get_ipython().system = lambda *args: print('Assume we have this package, ! is disabled!')
{}
"""  # noqa: E501
|
|
|
|
|
class TimeoutError(Exception):
    """Signals that code execution ran past its time limit.

    The SIGALRM handler installed while executing code raises this.

    NOTE: the name deliberately shadows the builtin ``TimeoutError`` inside
    this module; this class derives directly from ``Exception`` and is *not*
    an ``OSError`` subclass.
    """
|
|
|
|
|
class IPythonInterpreter(BaseAction):
    """An IPython executor that can execute Python scripts in a jupyter manner.

    Args:
        description (str): The description of the action. Defaults to
            DEFAULT_DESCRIPTION.
        name (str, optional): The name of the action. If None, the name will
            be the class name. Defaults to None.
        enable (bool, optional): Whether the action is enabled. Defaults to
            True.
        disable_description (str, optional): The description of the action
            when it is disabled. Defaults to None.
        timeout (int): Upper bound, in seconds, of waiting time for Python
            script execution. Used by ``__call__`` whenever no per-call
            timeout is given; 0 disables the alarm. Defaults to 20.
        trim_output (int, optional): Max characters restriction of ipython
            outputs. If None, do not perform any trim.
            Notice that, this is not token length but string length.
            Trim strategies might be added later if needed. Defaults to 1024.
        user_data_dir (str): Specified the user data directory for files
            loading. If set to `ENV`, use `USER_DATA_DIR` environment variable.
            Defaults to `ENV`.
        force_user_data (bool): Whether to force use user data.
            Defaults to True.

    Raises:
        ValueError: If ``user_data_dir`` does not exist, or if it is empty
            while ``force_user_data`` is True.
    """

    # (KernelManager, client) pairs keyed by process id and shared by every
    # instance in that process; a process with a different pid (e.g. a
    # forked worker) starts its own kernel.
    _KERNEL_CLIENTS = {}

    # Seconds to wait for a single iopub message before the kernel is treated
    # as unresponsive and interrupted.
    _IOPUB_WAIT = 10

    def __init__(self,
                 description: str = DEFAULT_DESCRIPTION,
                 name: Optional[str] = None,
                 enable: bool = True,
                 disable_description: Optional[str] = None,
                 timeout: int = 20,
                 trim_output: Optional[int] = 1024,
                 user_data_dir: str = 'ENV',
                 force_user_data: bool = True) -> None:
        super().__init__(description, name, enable, disable_description)

        self.timeout = timeout
        if user_data_dir == 'ENV':
            user_data_dir = os.environ.get('USER_DATA_DIR', '')

        if user_data_dir:
            # An explicit raise, unlike `assert`, survives `python -O`.
            if not os.path.exists(user_data_dir):
                raise ValueError(f'{user_data_dir} does not exist.')
            user_data_dir = os.path.abspath(user_data_dir)
            # repr() produces a valid Python literal even for paths holding
            # quotes or backslashes (e.g. Windows paths such as C:\Users).
            user_data_dir = f'import os\nos.chdir({user_data_dir!r})'
        elif force_user_data:
            raise ValueError('user_data_dir is not set. Please '
                             'set force_user_data to False if '
                             'no extra data needed.')
        # Setup snippet substituted into START_CODE ('' when unused).
        self.user_data_dir = user_data_dir
        self._initialized = False
        self.trim_output = trim_output
        # makedirs also creates missing parents and tolerates a concurrent
        # creation of the directory.
        os.makedirs(WORK_DIR, exist_ok=True)

    @staticmethod
    def start_kernel():
        """Start a new Jupyter kernel.

        Returns:
            tuple: ``(KernelManager, client)`` for the new kernel.
        """
        km = KernelManager()
        km.start_kernel()
        kc = km.client()
        return km, kc

    def initialize(self):
        """Attach to this process's kernel (starting it if needed) and run
        START_CODE. Does nothing if this instance is already initialized."""
        if self._initialized:
            return
        pid = os.getpid()
        if pid not in self._KERNEL_CLIENTS:
            self._KERNEL_CLIENTS[pid] = self.start_kernel()
        self.kernel_manager, self.kernel_client = self._KERNEL_CLIENTS[pid]
        # Set before _call, because _call itself calls initialize().
        self._initialized = True
        self._call(START_CODE.format(self.user_data_dir), None)

    def reset(self):
        """Clear the kernel namespace (``%reset -f``) and re-run START_CODE,
        or initialize if that has not happened yet."""
        if not self._initialized:
            self.initialize()
        else:
            code = "get_ipython().run_line_magic('reset', '-f')\n" + \
                START_CODE.format(self.user_data_dir)
            self._call(code, None)

    def _drain_iopub(self):
        """Discard pending iopub messages until the kernel reports idle or
        nothing arrives for one second."""
        while True:
            try:
                msg = self.kernel_client.get_iopub_msg(timeout=1)
            except queue.Empty:
                return
            if (msg['msg_type'] == 'status' and
                    msg['content'].get('execution_state') == 'idle'):
                return

    def _trim(self, result: str) -> str:
        """Shorten ``result`` to about ``self.trim_output`` characters by
        keeping its head and tail around an ellipsis."""
        if not self.trim_output or len(result) <= self.trim_output:
            return result
        ellip = '......'
        # Clamp at 0. With a tiny limit, int((limit - 6) / 2) used to be <= 0,
        # and result[-0:] or negative slices returned (almost) the whole
        # string instead of shortening it.
        half_len = max((self.trim_output - len(ellip)) // 2, 0)
        return result[:half_len] + ellip + result[len(result) - half_len:]

    def _collect_output(self, msg_id: str,
                        timeout: Optional[int]) -> Tuple[bool, str]:
        """Gather the iopub outputs of execute request ``msg_id``.

        Reads messages until the kernel reports idle for this request, then
        renders them as markdown sections.

        Args:
            msg_id (str): Id returned by ``kernel_client.execute``.
            timeout (int, optional): Only used in timeout messages.

        Returns:
            Tuple[bool, str]: ``(succeed, result)``. ``result`` is trimmed
            according to ``self.trim_output``.
        """
        parts = []
        succeed = True
        image_idx = 0

        while True:
            text = ''
            image = ''
            finished = False
            msg_type = 'error'
            try:
                msg = self.kernel_client.get_iopub_msg(
                    timeout=self._IOPUB_WAIT)
                if msg.get('parent_header', {}).get('msg_id') != msg_id:
                    # Output left over from an earlier request (e.g. one that
                    # timed out); it must not leak into this result.
                    continue
                msg_type = msg['msg_type']
                content = msg['content']
                if msg_type == 'status':
                    finished = content.get('execution_state') == 'idle'
                elif msg_type in ('execute_result', 'display_data'):
                    data = content['data']
                    has_png = 'image/png' in data
                    if has_png:
                        image_url = publish_image_to_local(data['image/png'])
                        image_idx += 1
                        image = '![fig-%03d](%s)' % (image_idx, image_url)
                    # execute_result always reports its text repr; for
                    # display_data the text is only used when no image exists.
                    if msg_type == 'execute_result' or not has_png:
                        text = data.get('text/plain', '')
                elif msg_type == 'stream':
                    # Report as 'stdout' / 'stderr'.
                    msg_type = content['name']
                    text = content['text']
                elif msg_type == 'error':
                    succeed = False
                    text = escape_ansi('\n'.join(content['traceback']))
                    if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
                        text = f'Timeout. No response after {timeout} seconds.'
            except queue.Empty:
                # Kernel silent for _IOPUB_WAIT seconds: stop the execution.
                self.kernel_manager.interrupt_kernel()
                succeed = False
                text = ('Timeout. No response after '
                        f'{self._IOPUB_WAIT} seconds.')
                finished = True
            except TimeoutError:
                # Raised by the SIGALRM handler installed in _call. It must be
                # caught before `except Exception` so it is reported as a
                # timeout and the running code is interrupted.
                self.kernel_manager.interrupt_kernel()
                succeed = False
                text = f'Timeout. No response after {timeout} seconds.'
                finished = True
            except Exception:
                succeed = False
                text = 'The code interpreter encountered an unexpected error.'
                logging.warning(traceback.format_exc())
                finished = True
            if text:
                parts.append(f'\n\n{msg_type}:\n\n```\n{text}\n```')
            if image:
                parts.append(f'\n\n{image}')
            if finished:
                break

        return succeed, self._trim(''.join(parts))

    def _call(self,
              command: str,
              timeout: Optional[int] = None) -> Tuple[bool, str]:
        """Execute ``command`` verbatim in the kernel and collect its output.

        Args:
            command (str): Python source to run. It is sent as-is. Callers
                holding raw model output must apply ``extract_code`` first,
                as ``__call__`` does. Extracting a second time would re-match
                backticks inside the code itself.
            timeout (int, optional): Seconds before SIGALRM interrupts the
                execution. Falsy values disable the alarm. The alarm is also
                skipped where SIGALRM is unavailable or outside the main
                thread.

        Returns:
            Tuple[bool, str]: ``(succeed, result)``. ``result`` is the
            markdown-formatted output with leading newlines stripped.
        """
        self.initialize()

        # Drop anything still queued from a previous execution.
        self._drain_iopub()

        msg_id = self.kernel_client.execute(command)

        def handler(signum, frame):
            raise TimeoutError()

        use_alarm = bool(timeout) and hasattr(signal, 'SIGALRM')
        previous_handler = None
        if use_alarm:
            try:
                previous_handler = signal.signal(signal.SIGALRM, handler)
            except ValueError:
                # Signal handlers can only be installed from the main thread;
                # rely on the per-message wait in _collect_output instead.
                use_alarm = False
            else:
                signal.alarm(timeout)

        try:
            succeed, result = self._collect_output(msg_id, timeout)
        except TimeoutError:
            # The alarm fired outside _collect_output's own handling.
            self.kernel_manager.interrupt_kernel()
            succeed = False
            text = f'Timeout. No response after {timeout} seconds.'
            result = f'\n\nerror:\n\n```\n{text}\n```'
        finally:
            if use_alarm:
                signal.alarm(0)
                # signal.signal returns None when the old handler was not set
                # from Python; fall back to the default disposition then.
                signal.signal(
                    signal.SIGALRM, signal.SIG_DFL
                    if previous_handler is None else previous_handler)

        return succeed, result.lstrip('\n')

    def __call__(self,
                 command: str,
                 timeout: Optional[int] = None) -> ActionReturn:
        """Extract code from ``command``, execute it and wrap the outcome.

        Args:
            command (str): Raw text containing the code; see
                ``extract_code`` for the accepted wrappings.
            timeout (int, optional): Execution time limit in seconds. None
                means ``self.timeout``; 0 disables the alarm.

        Returns:
            ActionReturn: On success, ``state`` is SUCCESS and
            ``result['text']`` holds the output. Otherwise ``state`` is
            API_ERROR and ``errmsg`` describes the failure.
        """
        if timeout is None:
            # Honour the constructor's documented default. self.timeout used
            # to be stored but never applied.
            timeout = self.timeout
        tool_return = ActionReturn(url=None, args=None, type=self.name)
        extracted_command = extract_code(command)
        tool_return.args = dict(text=command, extract_code=extracted_command)
        if extracted_command:
            succeed, result = self._call(extracted_command, timeout)
            if succeed:
                if not result:
                    result = 'The code is succeed without any outputs.'
                tool_return.result = dict(text=result)
                tool_return.state = ActionStatusCode.SUCCESS
            else:
                tool_return.errmsg = repr(result)
                tool_return.state = ActionStatusCode.API_ERROR
        else:
            tool_return.errmsg = 'The input code is empty. Please follow the format.'
            tool_return.state = ActionStatusCode.API_ERROR
        return tool_return
|
|
|
|
|
def extract_code(text):
    """Pull the Python source out of a (model-generated) text.

    The wrappings are tried in order:

    1. the body of the first fenced block (```lang\\n ... ```);
    2. a fence written on a single line (```code```);
    3. the content of the first inline `code` span;
    4. the ``code`` field of a JSON/JSON5 object, if it is a string.

    If none of them applies, ``text`` is returned unchanged.

    Args:
        text (str): Raw text that may wrap the code.

    Returns:
        str: The extracted code, or ``text`` itself.
    """
    triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
    if triple_match:
        return triple_match.group(1)

    # Without this step, ```print(1)``` fell through to the inline-span
    # pattern, which matched the empty string between the first two
    # backticks and silently discarded the code.
    one_line_fence = re.search(r'```([^`\n]+)```', text)
    if one_line_fence and one_line_fence.group(1).strip():
        return one_line_fence.group(1)

    single_match = re.search(r'`([^`]*)`', text)
    if single_match:
        return single_match.group(1)

    # Best effort: the text may be a JSON5 object such as {"code": "..."}.
    # Parse errors, non-object payloads and a missing key all fall back to
    # returning the text unchanged.
    try:
        code = json5.loads(text)['code']
    except Exception:
        return text
    # A non-string "code" value cannot be executed; treat it as missing.
    return code if isinstance(code, str) else text
|
|
|
|
|
def escape_ansi(line):
    """Strip ANSI escape / control sequences (e.g. terminal colour codes).

    Args:
        line (str): Text possibly containing escape sequences, such as a
            colourised traceback.

    Returns:
        str: ``line`` with every matched sequence removed.
    """
    # An ESC-prefixed or C1 introducer, optional parameter and intermediate
    # bytes, then a single final byte.
    pattern = r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]'
    return re.sub(pattern, '', line)
|
|
|
|
|
def publish_image_to_local(image_base64: str):
    """Decode a base64-encoded PNG and write it into ``WORK_DIR``.

    The bytes are loaded through PIL and re-saved as PNG under a random
    ``<uuid4>.png`` file name.

    Args:
        image_base64 (str): Base64 payload, e.g. the ``image/png`` entry of a
            Jupyter display message.

    Returns:
        str: Path of the written file.
    """
    target_path = os.path.join(WORK_DIR, f'{uuid.uuid4()}.png')
    raw_png = base64.b64decode(image_base64)
    assert isinstance(raw_png, bytes)
    PIL.Image.open(io.BytesIO(raw_png)).save(target_path, 'png')
    return target_path
|
|
|
|
|
|
|
def get_multiline_input(hint):
    """Print ``hint`` and read lines from stdin until EOF (CTRL-D).

    Args:
        hint (str): Prompt printed before reading.

    Returns:
        str: The entered lines joined with newlines; an empty string when
        nothing was entered.
    """
    print(hint)
    print('// Press ENTER to make a new line. Press CTRL-D to end input.')
    collected = []
    while True:
        try:
            collected.append(input())
        except EOFError:
            break
    print('// Input received.')
    # Joining an empty list already yields ''.
    return '\n'.join(collected)
|
|
|
|
|
if __name__ == '__main__':
    # Interactive demo: read a snippet from stdin, execute it and print the
    # resulting ActionReturn; loops until the process is interrupted.
    interpreter = IPythonInterpreter()
    while True:
        snippet = get_multiline_input('Enter python code:')
        print(interpreter(snippet))
|
|