File size: 5,863 Bytes
256a159 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import copy
import io
import multiprocessing
import queue
from contextlib import redirect_stdout
from typing import Any, List, Optional

from lagent.actions.base_action import BaseAction
from lagent.schema import ActionReturn, ActionStatusCode
from opencompass.datasets.mbpp import TimeOutException, swallow_io, time_limit
class GenericRuntime:
    """Namespace in which generated Python snippets are executed.

    Every instance owns a shallow copy of ``GLOBAL_DICT``; ``exec_code`` and
    ``eval_code`` both run against that copy, so names bound by one snippet
    stay visible to later snippets on the same instance.
    """

    # Template namespace; each instance works on its own shallow copy.
    GLOBAL_DICT = {}
    # Optional local-namespace template. When truthy it is shallow-copied
    # into ``_local_vars``; otherwise ``_local_vars`` is None.
    # NOTE(review): ``_local_vars`` is never passed to exec/eval below.
    LOCAL_DICT = None
    # Source snippets executed, in order, on every new instance.
    HEADERS = []

    def __init__(self):
        self._global_vars = copy.copy(self.GLOBAL_DICT)
        if self.LOCAL_DICT:
            self._local_vars = copy.copy(self.LOCAL_DICT)
        else:
            self._local_vars = None
        for header_source in self.HEADERS:
            self.exec_code(header_source)

    def exec_code(self, code_piece: str) -> None:
        """Run ``code_piece`` as statements inside this runtime's globals.

        Security: this is a plain ``exec`` with no sandboxing of its own.
        """
        exec(code_piece, self._global_vars)

    def eval_code(self, expr: str) -> Any:
        """Evaluate ``expr`` inside this runtime's globals and return it."""
        return eval(expr, self._global_vars)
# Default ``description`` for PythonInterpreter (passed on to BaseAction).
# English translation of the Chinese text: "Used to execute Python code. The
# code must be a function, and the function name must be 'solution'; the code
# corresponds to your thought process. Example code format:" followed by a
# template that imports dependencies, initialises meaningfully named
# variables, computes intermediate steps and returns the final answer from
# solution(). PythonInterpreter's default ``answer_expr`` is 'solution()'.
# NOTE: the string itself is runtime content; do not edit it casually.
DEFAULT_DESCRIPTION = """用来执行Python代码。代码必须是一个函数,
函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下:
```python
# import 依赖包
import xxx
def solution():
# 初始化一些变量
variable_names_with_real_meaning = xxx
# 步骤一
mid_variable = func(variable_names_with_real_meaning)
# 步骤 x
mid_variable = func(mid_variable)
# 最后结果
final_answer = func(mid_variable)
return final_answer
```"""
class PythonInterpreter(BaseAction):
    """A Python executor that can execute Python scripts.

    Each call runs the script in a fresh child process and a fresh
    :class:`GenericRuntime` namespace, so a runaway script can be killed
    without affecting the caller.

    Args:
        description (str): The description of the action. Defaults to
            DEFAULT_DESCRIPTION.
        answer_symbol (str, Optional): Name of the global variable that
            holds the answer after the script has run. Defaults to None.
        answer_expr (str, Optional): Expression evaluated after the script
            has run to obtain the answer. Defaults to 'solution()'.
        answer_from_stdout (boolean): Whether the answer is the last line the
            script prints to stdout. Defaults to False.
        name (str, optional): The name of the action. If None, the name will
            be the class name. Defaults to None.
        enable (bool, optional): Whether the action is enabled. Defaults to
            True.
        disable_description (str, optional): The description of the action
            when it is disabled. Defaults to None.
        timeout (int): Upper bound, in seconds, of waiting time for Python
            script execution. Defaults to 20.
    """

    def __init__(self,
                 description: str = DEFAULT_DESCRIPTION,
                 answer_symbol: Optional[str] = None,
                 answer_expr: Optional[str] = 'solution()',
                 answer_from_stdout: bool = False,
                 name: Optional[str] = None,
                 enable: bool = True,
                 disable_description: Optional[str] = None,
                 timeout: int = 20) -> None:
        super().__init__(description, name, enable, disable_description)
        self.answer_symbol = answer_symbol
        self.answer_expr = answer_expr
        self.answer_from_stdout = answer_from_stdout
        self.timeout = timeout

    @staticmethod
    def extract_code(command: str) -> List[str]:
        """Extract the script from an LLM reply and split it into lines.

        The body of the first ```python fenced block is used if present,
        otherwise the body of the first bare ``` fenced block, otherwise the
        whole ``command``.

        Args:
            command (str): Raw text that may contain a fenced code block.

        Returns:
            List[str]: The extracted code split on newline characters.
        """
        if '```python' in command:
            command = command.split('```python')[1].split('```')[0]
        elif '```' in command:
            command = command.split('```')[1]
        return command.split('\n')

    def __call__(self, command: str) -> ActionReturn:
        """Execution function for running generation code.

        Args:
            command (str): Python code to be executed, optionally wrapped in
                a Markdown code fence (see :meth:`extract_code`).

        Returns:
            ActionReturn: On success ``state`` is ``SUCCESS`` and ``result``
            is ``{'text': str(answer)}``; otherwise ``state`` is
            ``API_ERROR`` and ``errmsg`` describes the failure or timeout.
        """
        extracted_command = self.extract_code(command)
        tool_return = ActionReturn(url=None,
                                   args=dict(text=command,
                                             extract_code=extracted_command),
                                   type=self.name)

        def _execution(q, command, tool_return):
            # Runs in the child process and reports its outcome through ``q``.
            try:
                with swallow_io():
                    # leave 1s for multiprocess
                    with time_limit(self.timeout - 1):
                        res = self._call(command)
                tool_return.result = dict(text=str(res))
                tool_return.state = ActionStatusCode.SUCCESS
            except TimeOutException:
                tool_return.errmsg = f'Time out after {self.timeout} seconds.'
                tool_return.state = ActionStatusCode.API_ERROR
            except BaseException as e:
                # Deliberately broad: anything the generated code raises
                # (including SystemExit) is reported back, not propagated.
                tool_return.errmsg = f'Failed. {e}.'
                tool_return.state = ActionStatusCode.API_ERROR
            q.put(tool_return)

        # `signal` cannot be used in child thread, therefore, we
        # need to create a process.
        # NOTE(review): ``_execution`` is a local function, which can only be
        # handed to a child process under the 'fork' start method; confirm
        # this never runs with 'spawn'/'forkserver'.
        q = multiprocessing.Queue()
        p = multiprocessing.Process(target=_execution,
                                    args=(q, extracted_command, tool_return))
        p.start()
        result = None
        killed = False
        try:
            # Read *before* joining. A child that has put data on a queue does
            # not terminate until that data is flushed to the pipe, which for
            # large results requires us to read first. Also, a child that dies
            # without reporting must not leave an unbounded get() hanging.
            result = q.get(timeout=self.timeout)
        except queue.Empty:
            result = None
        finally:
            # A child that has reported should exit almost immediately; one
            # that has not is stuck (or already dead). Kill any survivor and
            # reap it so no orphaned or zombie process is left behind.
            p.join(timeout=1 if result is not None else 0)
            if p.is_alive():
                p.kill()
                killed = True
                p.join()
            q.close()

        if result is not None:
            return result

        tool_return.state = ActionStatusCode.API_ERROR
        if killed:
            # return timeout due to some unknown error
            tool_return.errmsg = f'Time out after {self.timeout} seconds.'
        else:
            # The child exited (e.g. os._exit or a crash) without reporting.
            tool_return.errmsg = (f'Failed. Process exited with code '
                                  f'{p.exitcode} without reporting a result.')
        return tool_return

    def _call(self, command: List[str]) -> Any:
        """Run the script in a fresh :class:`GenericRuntime`; return the answer.

        Exactly one answer source is used, in this priority order:
        ``answer_from_stdout`` (the last printed line, newline included),
        ``answer_symbol`` (a global variable of the script), then
        ``answer_expr`` (an expression evaluated after the script). If none
        is configured, the script minus its last line is executed and
        ``True`` is returned.

        Args:
            command (List[str]): Lines of the script, as returned by
                :meth:`extract_code`.

        Returns:
            The answer value (not an ActionReturn).

        Raises:
            ValueError: If ``answer_from_stdout`` is set but the script
                printed nothing.
            Any exception raised by the script itself propagates unchanged.
        """
        self.runtime = GenericRuntime()
        code = '\n'.join(command)
        if self.answer_from_stdout:
            program_io = io.StringIO()
            with redirect_stdout(program_io):
                self.runtime.exec_code(code)
            program_io.seek(0)
            output_lines = program_io.readlines()
            if not output_lines:
                # Fail with a clear message rather than a bare IndexError.
                raise ValueError('The script printed nothing to stdout.')
            res = output_lines[-1]
        elif self.answer_symbol:
            self.runtime.exec_code(code)
            res = self.runtime._global_vars[self.answer_symbol]
        elif self.answer_expr:
            self.runtime.exec_code(code)
            res = self.runtime.eval_code(self.answer_expr)
        else:
            # NOTE(review): the last line is skipped here; presumably it is a
            # trailing call/expression that should not run. Confirm.
            self.runtime.exec_code('\n'.join(command[:-1]))
            res = True
        return res
|