Spaces:
Sleeping
Sleeping
upload
Browse files- mindsearch/agent/__init__.py +60 -0
- mindsearch/agent/mindsearch_agent.py +422 -0
- mindsearch/agent/mindsearch_prompt.py +326 -0
- mindsearch/agent/models.py +77 -0
- mindsearch/app.py +136 -0
- mindsearch/terminal.py +50 -0
mindsearch/agent/__init__.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from datetime import datetime
|
3 |
+
|
4 |
+
from lagent.actions import ActionExecutor, BingBrowser
|
5 |
+
|
6 |
+
import mindsearch.agent.models as llm_factory
|
7 |
+
from mindsearch.agent.mindsearch_agent import (MindSearchAgent,
|
8 |
+
MindSearchProtocol)
|
9 |
+
from mindsearch.agent.mindsearch_prompt import (
|
10 |
+
FINAL_RESPONSE_CN, FINAL_RESPONSE_EN, GRAPH_PROMPT_CN, GRAPH_PROMPT_EN,
|
11 |
+
fewshot_example_cn, fewshot_example_en, graph_fewshot_example_cn,
|
12 |
+
graph_fewshot_example_en, searcher_context_template_cn,
|
13 |
+
searcher_context_template_en, searcher_input_template_cn,
|
14 |
+
searcher_input_template_en, searcher_system_prompt_cn,
|
15 |
+
searcher_system_prompt_en)
|
16 |
+
|
17 |
+
LLM = {}
|
18 |
+
|
19 |
+
|
20 |
+
def init_agent(lang='cn', model_format='internlm_server',search_engine='DuckDuckGoSearch'):
|
21 |
+
llm = LLM.get(model_format, None)
|
22 |
+
if llm is None:
|
23 |
+
llm_cfg = getattr(llm_factory, model_format)
|
24 |
+
if llm_cfg is None:
|
25 |
+
raise NotImplementedError
|
26 |
+
llm_cfg = llm_cfg.copy()
|
27 |
+
llm = llm_cfg.pop('type')(**llm_cfg)
|
28 |
+
LLM[model_format] = llm
|
29 |
+
|
30 |
+
interpreter_prompt = GRAPH_PROMPT_CN if lang == 'cn' else GRAPH_PROMPT_EN
|
31 |
+
plugin_prompt = searcher_system_prompt_cn if lang == 'cn' else searcher_system_prompt_en
|
32 |
+
if not model_format.lower().startswith('internlm'):
|
33 |
+
interpreter_prompt += graph_fewshot_example_cn if lang == 'cn' else graph_fewshot_example_en
|
34 |
+
plugin_prompt += fewshot_example_cn if lang == 'cn' else fewshot_example_en
|
35 |
+
|
36 |
+
agent = MindSearchAgent(
|
37 |
+
llm=llm,
|
38 |
+
protocol=MindSearchProtocol(meta_prompt=datetime.now().strftime(
|
39 |
+
'The current date is %Y-%m-%d.'),
|
40 |
+
interpreter_prompt=interpreter_prompt,
|
41 |
+
response_prompt=FINAL_RESPONSE_CN
|
42 |
+
if lang == 'cn' else FINAL_RESPONSE_EN),
|
43 |
+
searcher_cfg=dict(
|
44 |
+
llm=llm,
|
45 |
+
plugin_executor=ActionExecutor(
|
46 |
+
BingBrowser(searcher_type=search_engine,
|
47 |
+
topk=6,
|
48 |
+
api_key=os.environ.get('BING_API_KEY',
|
49 |
+
'YOUR BING API'))),
|
50 |
+
protocol=MindSearchProtocol(
|
51 |
+
meta_prompt=datetime.now().strftime(
|
52 |
+
'The current date is %Y-%m-%d.'),
|
53 |
+
plugin_prompt=plugin_prompt,
|
54 |
+
),
|
55 |
+
template=dict(input=searcher_input_template_cn
|
56 |
+
if lang == 'cn' else searcher_input_template_en,
|
57 |
+
context=searcher_context_template_cn
|
58 |
+
if lang == 'cn' else searcher_context_template_en)),
|
59 |
+
max_turn=10)
|
60 |
+
return agent
|
mindsearch/agent/mindsearch_agent.py
ADDED
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
import queue
|
4 |
+
import random
|
5 |
+
import re
|
6 |
+
import threading
|
7 |
+
import uuid
|
8 |
+
from collections import defaultdict
|
9 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
10 |
+
from copy import deepcopy
|
11 |
+
from dataclasses import asdict
|
12 |
+
from typing import Dict, List, Optional
|
13 |
+
|
14 |
+
from lagent.actions import ActionExecutor
|
15 |
+
from lagent.agents import BaseAgent, Internlm2Agent
|
16 |
+
from lagent.agents.internlm2_agent import Internlm2Protocol
|
17 |
+
from lagent.schema import AgentReturn, AgentStatusCode, ModelStatusCode
|
18 |
+
from termcolor import colored
|
19 |
+
|
20 |
+
# 初始化日志记录
|
21 |
+
logging.basicConfig(level=logging.INFO)
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
class SearcherAgent(Internlm2Agent):
|
26 |
+
|
27 |
+
def __init__(self, template='{query}', **kwargs) -> None:
|
28 |
+
super().__init__(**kwargs)
|
29 |
+
self.template = template
|
30 |
+
|
31 |
+
def stream_chat(self,
|
32 |
+
question: str,
|
33 |
+
root_question: str = None,
|
34 |
+
parent_response: List[dict] = None,
|
35 |
+
**kwargs) -> AgentReturn:
|
36 |
+
message = self.template['input'].format(question=question,
|
37 |
+
topic=root_question)
|
38 |
+
if parent_response:
|
39 |
+
if 'context' in self.template:
|
40 |
+
parent_response = [
|
41 |
+
self.template['context'].format(**item)
|
42 |
+
for item in parent_response
|
43 |
+
]
|
44 |
+
message = '\n'.join(parent_response + [message])
|
45 |
+
print(colored(f'current query: {message}', 'green'))
|
46 |
+
for agent_return in super().stream_chat(message,
|
47 |
+
session_id=random.randint(
|
48 |
+
0, 999999),
|
49 |
+
**kwargs):
|
50 |
+
agent_return.type = 'searcher'
|
51 |
+
agent_return.content = question
|
52 |
+
yield deepcopy(agent_return)
|
53 |
+
|
54 |
+
|
55 |
+
class MindSearchProtocol(Internlm2Protocol):
|
56 |
+
|
57 |
+
def __init__(
|
58 |
+
self,
|
59 |
+
meta_prompt: str = None,
|
60 |
+
interpreter_prompt: str = None,
|
61 |
+
plugin_prompt: str = None,
|
62 |
+
few_shot: Optional[List] = None,
|
63 |
+
response_prompt: str = None,
|
64 |
+
language: Dict = dict(
|
65 |
+
begin='',
|
66 |
+
end='',
|
67 |
+
belong='assistant',
|
68 |
+
),
|
69 |
+
tool: Dict = dict(
|
70 |
+
begin='{start_token}{name}\n',
|
71 |
+
start_token='<|action_start|>',
|
72 |
+
name_map=dict(plugin='<|plugin|>', interpreter='<|interpreter|>'),
|
73 |
+
belong='assistant',
|
74 |
+
end='<|action_end|>\n',
|
75 |
+
),
|
76 |
+
execute: Dict = dict(role='execute',
|
77 |
+
begin='',
|
78 |
+
end='',
|
79 |
+
fallback_role='environment'),
|
80 |
+
) -> None:
|
81 |
+
self.response_prompt = response_prompt
|
82 |
+
super().__init__(meta_prompt=meta_prompt,
|
83 |
+
interpreter_prompt=interpreter_prompt,
|
84 |
+
plugin_prompt=plugin_prompt,
|
85 |
+
few_shot=few_shot,
|
86 |
+
language=language,
|
87 |
+
tool=tool,
|
88 |
+
execute=execute)
|
89 |
+
|
90 |
+
def format(self,
|
91 |
+
inner_step: List[Dict],
|
92 |
+
plugin_executor: ActionExecutor = None,
|
93 |
+
**kwargs) -> list:
|
94 |
+
formatted = []
|
95 |
+
if self.meta_prompt:
|
96 |
+
formatted.append(dict(role='system', content=self.meta_prompt))
|
97 |
+
if self.plugin_prompt:
|
98 |
+
plugin_prompt = self.plugin_prompt.format(tool_info=json.dumps(
|
99 |
+
plugin_executor.get_actions_info(), ensure_ascii=False))
|
100 |
+
formatted.append(
|
101 |
+
dict(role='system', content=plugin_prompt, name='plugin'))
|
102 |
+
if self.interpreter_prompt:
|
103 |
+
formatted.append(
|
104 |
+
dict(role='system',
|
105 |
+
content=self.interpreter_prompt,
|
106 |
+
name='interpreter'))
|
107 |
+
if self.few_shot:
|
108 |
+
for few_shot in self.few_shot:
|
109 |
+
formatted += self.format_sub_role(few_shot)
|
110 |
+
formatted += self.format_sub_role(inner_step)
|
111 |
+
return formatted
|
112 |
+
|
113 |
+
|
114 |
+
class WebSearchGraph:
|
115 |
+
end_signal = 'end'
|
116 |
+
searcher_cfg = dict()
|
117 |
+
|
118 |
+
def __init__(self):
|
119 |
+
self.nodes = {}
|
120 |
+
self.adjacency_list = defaultdict(list)
|
121 |
+
self.executor = ThreadPoolExecutor(max_workers=10)
|
122 |
+
self.future_to_query = dict()
|
123 |
+
self.searcher_resp_queue = queue.Queue()
|
124 |
+
|
125 |
+
def add_root_node(self, node_content, node_name='root'):
|
126 |
+
self.nodes[node_name] = dict(content=node_content, type='root')
|
127 |
+
self.adjacency_list[node_name] = []
|
128 |
+
self.searcher_resp_queue.put((node_name, self.nodes[node_name], []))
|
129 |
+
|
130 |
+
def add_node(self, node_name, node_content):
|
131 |
+
self.nodes[node_name] = dict(content=node_content, type='searcher')
|
132 |
+
self.adjacency_list[node_name] = []
|
133 |
+
|
134 |
+
def model_stream_thread():
|
135 |
+
agent = SearcherAgent(**self.searcher_cfg)
|
136 |
+
try:
|
137 |
+
parent_nodes = []
|
138 |
+
for start_node, adj in self.adjacency_list.items():
|
139 |
+
for neighbor in adj:
|
140 |
+
if node_name == neighbor[
|
141 |
+
'name'] and start_node in self.nodes and 'response' in self.nodes[
|
142 |
+
start_node]:
|
143 |
+
parent_nodes.append(self.nodes[start_node])
|
144 |
+
parent_response = [
|
145 |
+
dict(question=node['content'], answer=node['response'])
|
146 |
+
for node in parent_nodes
|
147 |
+
]
|
148 |
+
for answer in agent.stream_chat(
|
149 |
+
node_content,
|
150 |
+
self.nodes['root']['content'],
|
151 |
+
parent_response=parent_response):
|
152 |
+
self.searcher_resp_queue.put(
|
153 |
+
deepcopy((node_name,
|
154 |
+
dict(response=answer.response,
|
155 |
+
detail=answer), [])))
|
156 |
+
self.nodes[node_name]['response'] = answer.response
|
157 |
+
self.nodes[node_name]['detail'] = answer
|
158 |
+
except Exception as e:
|
159 |
+
logger.exception(f'Error in model_stream_thread: {e}')
|
160 |
+
|
161 |
+
self.future_to_query[self.executor.submit(
|
162 |
+
model_stream_thread)] = f'{node_name}-{node_content}'
|
163 |
+
|
164 |
+
def add_response_node(self, node_name='response'):
|
165 |
+
self.nodes[node_name] = dict(type='end')
|
166 |
+
self.searcher_resp_queue.put((node_name, self.nodes[node_name], []))
|
167 |
+
|
168 |
+
def add_edge(self, start_node, end_node):
|
169 |
+
self.adjacency_list[start_node].append(
|
170 |
+
dict(id=str(uuid.uuid4()), name=end_node, state=2))
|
171 |
+
self.searcher_resp_queue.put((start_node, self.nodes[start_node],
|
172 |
+
self.adjacency_list[start_node]))
|
173 |
+
|
174 |
+
def reset(self):
|
175 |
+
self.nodes = {}
|
176 |
+
self.adjacency_list = defaultdict(list)
|
177 |
+
|
178 |
+
def node(self, node_name):
|
179 |
+
return self.nodes[node_name].copy()
|
180 |
+
|
181 |
+
|
182 |
+
class MindSearchAgent(BaseAgent):
|
183 |
+
|
184 |
+
def __init__(self,
|
185 |
+
llm,
|
186 |
+
searcher_cfg,
|
187 |
+
protocol=MindSearchProtocol(),
|
188 |
+
max_turn=10):
|
189 |
+
self.local_dict = {}
|
190 |
+
self.ptr = 0
|
191 |
+
self.llm = llm
|
192 |
+
self.max_turn = max_turn
|
193 |
+
WebSearchGraph.searcher_cfg = searcher_cfg
|
194 |
+
super().__init__(llm=llm, action_executor=None, protocol=protocol)
|
195 |
+
|
196 |
+
def stream_chat(self, message, **kwargs):
|
197 |
+
if isinstance(message, str):
|
198 |
+
message = [{'role': 'user', 'content': message}]
|
199 |
+
elif isinstance(message, dict):
|
200 |
+
message = [message]
|
201 |
+
as_dict = kwargs.pop('as_dict', False)
|
202 |
+
return_early = kwargs.pop('return_early', False)
|
203 |
+
self.local_dict.clear()
|
204 |
+
self.ptr = 0
|
205 |
+
inner_history = message[:]
|
206 |
+
agent_return = AgentReturn()
|
207 |
+
agent_return.type = 'planner'
|
208 |
+
agent_return.nodes = {}
|
209 |
+
agent_return.adjacency_list = {}
|
210 |
+
agent_return.inner_steps = deepcopy(inner_history)
|
211 |
+
for _ in range(self.max_turn):
|
212 |
+
prompt = self._protocol.format(inner_step=inner_history)
|
213 |
+
code = None
|
214 |
+
for model_state, response, _ in self.llm.stream_chat(
|
215 |
+
prompt, session_id=random.randint(0, 999999), **kwargs):
|
216 |
+
if model_state.value < 0:
|
217 |
+
agent_return.state = getattr(AgentStatusCode,
|
218 |
+
model_state.name)
|
219 |
+
yield deepcopy(agent_return)
|
220 |
+
return
|
221 |
+
response = response.replace('<|plugin|>', '<|interpreter|>')
|
222 |
+
_, language, action = self._protocol.parse(response)
|
223 |
+
if not language and not action:
|
224 |
+
continue
|
225 |
+
code = action['parameters']['command'] if action else ''
|
226 |
+
agent_return.state = self._determine_agent_state(
|
227 |
+
model_state, code, agent_return)
|
228 |
+
agent_return.response = language if not code else code
|
229 |
+
|
230 |
+
# if agent_return.state == AgentStatusCode.STREAM_ING:
|
231 |
+
yield deepcopy(agent_return)
|
232 |
+
|
233 |
+
inner_history.append({'role': 'language', 'content': language})
|
234 |
+
print(colored(response, 'blue'))
|
235 |
+
|
236 |
+
if code:
|
237 |
+
yield from self._process_code(agent_return, inner_history,
|
238 |
+
code, as_dict, return_early)
|
239 |
+
else:
|
240 |
+
agent_return.state = AgentStatusCode.END
|
241 |
+
yield deepcopy(agent_return)
|
242 |
+
return
|
243 |
+
|
244 |
+
agent_return.state = AgentStatusCode.END
|
245 |
+
yield deepcopy(agent_return)
|
246 |
+
|
247 |
+
def _determine_agent_state(self, model_state, code, agent_return):
|
248 |
+
if code:
|
249 |
+
return (AgentStatusCode.PLUGIN_START if model_state
|
250 |
+
== ModelStatusCode.END else AgentStatusCode.PLUGIN_START)
|
251 |
+
return (AgentStatusCode.ANSWER_ING
|
252 |
+
if agent_return.nodes and 'response' in agent_return.nodes else
|
253 |
+
AgentStatusCode.STREAM_ING)
|
254 |
+
|
255 |
+
def _process_code(self,
|
256 |
+
agent_return,
|
257 |
+
inner_history,
|
258 |
+
code,
|
259 |
+
as_dict=False,
|
260 |
+
return_early=False):
|
261 |
+
for node_name, node, adj in self.execute_code(
|
262 |
+
code, return_early=return_early):
|
263 |
+
if as_dict and 'detail' in node:
|
264 |
+
node['detail'] = asdict(node['detail'])
|
265 |
+
if not adj:
|
266 |
+
agent_return.nodes[node_name] = node
|
267 |
+
else:
|
268 |
+
agent_return.adjacency_list[node_name] = adj
|
269 |
+
# state 1进行中,2未开始,3已结束
|
270 |
+
for start_node, neighbors in agent_return.adjacency_list.items():
|
271 |
+
for neighbor in neighbors:
|
272 |
+
if neighbor['name'] not in agent_return.nodes:
|
273 |
+
state = 2
|
274 |
+
elif 'detail' not in agent_return.nodes[neighbor['name']]:
|
275 |
+
state = 2
|
276 |
+
elif agent_return.nodes[neighbor['name']][
|
277 |
+
'detail'].state == AgentStatusCode.END:
|
278 |
+
state = 3
|
279 |
+
else:
|
280 |
+
state = 1
|
281 |
+
neighbor['state'] = state
|
282 |
+
if not adj:
|
283 |
+
yield deepcopy((agent_return, node_name))
|
284 |
+
reference, references_url = self._generate_reference(
|
285 |
+
agent_return, code, as_dict)
|
286 |
+
inner_history.append({
|
287 |
+
'role': 'tool',
|
288 |
+
'content': code,
|
289 |
+
'name': 'plugin'
|
290 |
+
})
|
291 |
+
inner_history.append({
|
292 |
+
'role': 'environment',
|
293 |
+
'content': reference,
|
294 |
+
'name': 'plugin'
|
295 |
+
})
|
296 |
+
agent_return.inner_steps = deepcopy(inner_history)
|
297 |
+
agent_return.state = AgentStatusCode.PLUGIN_RETURN
|
298 |
+
agent_return.references.update(references_url)
|
299 |
+
yield deepcopy(agent_return)
|
300 |
+
|
301 |
+
def _generate_reference(self, agent_return, code, as_dict):
|
302 |
+
node_list = [
|
303 |
+
node.strip().strip('\"') for node in re.findall(
|
304 |
+
r'graph\.node\("((?:[^"\\]|\\.)*?)"\)', code)
|
305 |
+
]
|
306 |
+
if 'add_response_node' in code:
|
307 |
+
return self._protocol.response_prompt, dict()
|
308 |
+
references = []
|
309 |
+
references_url = dict()
|
310 |
+
for node_name in node_list:
|
311 |
+
ref_results = None
|
312 |
+
ref2url = None
|
313 |
+
if as_dict:
|
314 |
+
actions = agent_return.nodes[node_name]['detail']['actions']
|
315 |
+
else:
|
316 |
+
actions = agent_return.nodes[node_name]['detail'].actions
|
317 |
+
if actions:
|
318 |
+
ref_results = actions[0]['result'][0][
|
319 |
+
'content'] if as_dict else actions[0].result[0]['content']
|
320 |
+
if ref_results:
|
321 |
+
ref_results = json.loads(ref_results)
|
322 |
+
ref2url = {
|
323 |
+
idx: item['url']
|
324 |
+
for idx, item in ref_results.items()
|
325 |
+
}
|
326 |
+
|
327 |
+
ref = f"## {node_name}\n\n{agent_return.nodes[node_name]['response']}\n"
|
328 |
+
updated_ref = re.sub(
|
329 |
+
r'\[\[(\d+)\]\]',
|
330 |
+
lambda match: f'[[{int(match.group(1)) + self.ptr}]]', ref)
|
331 |
+
numbers = [int(n) for n in re.findall(r'\[\[(\d+)\]\]', ref)]
|
332 |
+
if numbers:
|
333 |
+
try:
|
334 |
+
assert all(str(elem) in ref2url for elem in numbers)
|
335 |
+
except Exception as exc:
|
336 |
+
logger.info(f'Illegal reference id: {str(exc)}')
|
337 |
+
if ref2url:
|
338 |
+
references_url.update({
|
339 |
+
str(idx + self.ptr): ref2url[str(idx)]
|
340 |
+
for idx in set(numbers) if str(idx) in ref2url
|
341 |
+
})
|
342 |
+
self.ptr += max(numbers) + 1
|
343 |
+
references.append(updated_ref)
|
344 |
+
return '\n'.join(references), references_url
|
345 |
+
|
346 |
+
def execute_code(self, command: str, return_early=False):
|
347 |
+
|
348 |
+
def extract_code(text: str) -> str:
|
349 |
+
text = re.sub(r'from ([\w.]+) import WebSearchGraph', '', text)
|
350 |
+
triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
|
351 |
+
single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
|
352 |
+
if triple_match:
|
353 |
+
return triple_match.group(1)
|
354 |
+
elif single_match:
|
355 |
+
return single_match.group(1)
|
356 |
+
return text
|
357 |
+
|
358 |
+
def run_command(cmd):
|
359 |
+
try:
|
360 |
+
exec(cmd, globals(), self.local_dict)
|
361 |
+
plan_graph = self.local_dict.get('graph')
|
362 |
+
assert plan_graph is not None
|
363 |
+
for future in as_completed(plan_graph.future_to_query):
|
364 |
+
future.result()
|
365 |
+
plan_graph.future_to_query.clear()
|
366 |
+
plan_graph.searcher_resp_queue.put(plan_graph.end_signal)
|
367 |
+
except Exception as e:
|
368 |
+
logger.exception(f'Error executing code: {e}')
|
369 |
+
raise
|
370 |
+
|
371 |
+
command = extract_code(command)
|
372 |
+
producer_thread = threading.Thread(target=run_command,
|
373 |
+
args=(command, ))
|
374 |
+
producer_thread.start()
|
375 |
+
|
376 |
+
responses = defaultdict(list)
|
377 |
+
ordered_nodes = []
|
378 |
+
active_node = None
|
379 |
+
|
380 |
+
while True:
|
381 |
+
try:
|
382 |
+
item = self.local_dict.get('graph').searcher_resp_queue.get(
|
383 |
+
timeout=60)
|
384 |
+
if item is WebSearchGraph.end_signal:
|
385 |
+
for node_name in ordered_nodes:
|
386 |
+
# resp = None
|
387 |
+
for resp in responses[node_name]:
|
388 |
+
yield deepcopy(resp)
|
389 |
+
# if resp:
|
390 |
+
# assert resp[1][
|
391 |
+
# 'detail'].state == AgentStatusCode.END
|
392 |
+
break
|
393 |
+
node_name, node, adj = item
|
394 |
+
if node_name in ['root', 'response']:
|
395 |
+
yield deepcopy((node_name, node, adj))
|
396 |
+
else:
|
397 |
+
if node_name not in ordered_nodes:
|
398 |
+
ordered_nodes.append(node_name)
|
399 |
+
responses[node_name].append((node_name, node, adj))
|
400 |
+
if not active_node and ordered_nodes:
|
401 |
+
active_node = ordered_nodes[0]
|
402 |
+
while active_node and responses[active_node]:
|
403 |
+
if return_early:
|
404 |
+
if 'detail' in responses[active_node][-1][
|
405 |
+
1] and responses[active_node][-1][1][
|
406 |
+
'detail'].state == AgentStatusCode.END:
|
407 |
+
item = responses[active_node][-1]
|
408 |
+
else:
|
409 |
+
item = responses[active_node].pop(0)
|
410 |
+
else:
|
411 |
+
item = responses[active_node].pop(0)
|
412 |
+
if 'detail' in item[1] and item[1][
|
413 |
+
'detail'].state == AgentStatusCode.END:
|
414 |
+
ordered_nodes.pop(0)
|
415 |
+
responses[active_node].clear()
|
416 |
+
active_node = None
|
417 |
+
yield deepcopy(item)
|
418 |
+
except queue.Empty:
|
419 |
+
if not producer_thread.is_alive():
|
420 |
+
break
|
421 |
+
producer_thread.join()
|
422 |
+
return
|
mindsearch/agent/mindsearch_prompt.py
ADDED
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# flake8: noqa
|
2 |
+
|
3 |
+
searcher_system_prompt_cn = """## 人物简介
|
4 |
+
你是一个可以调用网络搜索工具的智能助手。请根据"当前问题",调用搜索工具收集信息并回复问题。你能够调用如下工具:
|
5 |
+
{tool_info}
|
6 |
+
## 回复格式
|
7 |
+
|
8 |
+
调用工具时,请按照以下格式:
|
9 |
+
```
|
10 |
+
你的思考过程...<|action_start|><|plugin|>{{"name": "tool_name", "parameters": {{"param1": "value1"}}}}<|action_end|>
|
11 |
+
```
|
12 |
+
|
13 |
+
## 要求
|
14 |
+
|
15 |
+
- 回答中每个关键点需标注引用的搜索结果来源,以确保信息的可信度。给出索引的形式为`[[int]]`,如果有多个索引,则用多个[[]]表示,如`[[id_1]][[id_2]]`。
|
16 |
+
- 基于"当前问题"的搜索结果,撰写详细完备的回复,优先回答"当前问题"。
|
17 |
+
|
18 |
+
"""
|
19 |
+
|
20 |
+
searcher_system_prompt_en = """## Character Introduction
|
21 |
+
You are an intelligent assistant that can call web search tools. Please collect information and reply to the question based on the current problem. You can use the following tools:
|
22 |
+
{tool_info}
|
23 |
+
## Reply Format
|
24 |
+
|
25 |
+
When calling the tool, please follow the format below:
|
26 |
+
```
|
27 |
+
Your thought process...<|action_start|><|plugin|>{{"name": "tool_name", "parameters": {{"param1": "value1"}}}}<|action_end|>
|
28 |
+
```
|
29 |
+
|
30 |
+
## Requirements
|
31 |
+
|
32 |
+
- Each key point in the response should be marked with the source of the search results to ensure the credibility of the information. The citation format is `[[int]]`. If there are multiple citations, use multiple [[]] to provide the index, such as `[[id_1]][[id_2]]`.
|
33 |
+
- Based on the search results of the "current problem", write a detailed and complete reply to answer the "current problem".
|
34 |
+
"""
|
35 |
+
|
36 |
+
fewshot_example_cn = """
|
37 |
+
## 样例
|
38 |
+
|
39 |
+
### search
|
40 |
+
当我希望搜索"王者荣耀现在是什么赛季"时,我会按照以下格式进行操作:
|
41 |
+
现在是2024年,因此我应该搜索王者荣耀赛季关键词<|action_start|><|plugin|>{{"name": "FastWebBrowser.search", "parameters": {{"query": ["王者荣耀 赛季", "2024年王者荣耀赛季"]}}}}<|action_end|>
|
42 |
+
|
43 |
+
### select
|
44 |
+
为了找到王者荣耀s36赛季最强射手,我需要寻找提及王者荣耀s36射手的网页。初步浏览网页后,发现网页0提到王者荣耀s36赛季的信息,但没有具体提及射手的相关信息。网页3提到“s36最强射手出现?”,有可能包含最强射手信息。网页13提到“四大T0英雄崛起,射手荣耀降临”,可能包含最强射手的信息。因此,我选择了网页3和网页13进行进一步阅读。<|action_start|><|plugin|>{{"name": "FastWebBrowser.select", "parameters": {{"index": [3, 13]}}}}<|action_end|>
|
45 |
+
"""
|
46 |
+
|
47 |
+
fewshot_example_en = """
|
48 |
+
## Example
|
49 |
+
|
50 |
+
### search
|
51 |
+
When I want to search for "What season is Honor of Kings now", I will operate in the following format:
|
52 |
+
Now it is 2024, so I should search for the keyword of the Honor of Kings<|action_start|><|plugin|>{{"name": "FastWebBrowser.search", "parameters": {{"query": ["Honor of Kings Season", "season for Honor of Kings in 2024"]}}}}<|action_end|>
|
53 |
+
|
54 |
+
### select
|
55 |
+
In order to find the strongest shooters in Honor of Kings in season s36, I needed to look for web pages that mentioned shooters in Honor of Kings in season s36. After an initial browse of the web pages, I found that web page 0 mentions information about Honor of Kings in s36 season, but there is no specific mention of information about the shooter. Webpage 3 mentions that “the strongest shooter in s36 has appeared?”, which may contain information about the strongest shooter. Webpage 13 mentions “Four T0 heroes rise, archer's glory”, which may contain information about the strongest archer. Therefore, I chose webpages 3 and 13 for further reading.<|action_start|><|plugin|>{{"name": "FastWebBrowser.select", "parameters": {{"index": [3, 13]}}}}<|action_end|>
|
56 |
+
"""
|
57 |
+
|
58 |
+
searcher_input_template_en = """## Final Problem
|
59 |
+
{topic}
|
60 |
+
## Current Problem
|
61 |
+
{question}
|
62 |
+
"""
|
63 |
+
|
64 |
+
searcher_input_template_cn = """## 主问题
|
65 |
+
{topic}
|
66 |
+
## 当前问题
|
67 |
+
{question}
|
68 |
+
"""
|
69 |
+
|
70 |
+
searcher_context_template_en = """## Historical Problem
|
71 |
+
{question}
|
72 |
+
Answer: {answer}
|
73 |
+
"""
|
74 |
+
|
75 |
+
searcher_context_template_cn = """## 历史问题
|
76 |
+
{question}
|
77 |
+
回答:{answer}
|
78 |
+
"""
|
79 |
+
|
80 |
+
search_template_cn = '## {query}\n\n{result}\n'
|
81 |
+
search_template_en = '## {query}\n\n{result}\n'
|
82 |
+
|
83 |
+
GRAPH_PROMPT_CN = """## 人物简介
|
84 |
+
你是一个可以利用 Jupyter 环境 Python 编程的程序员。你可以利用提供的 API 来构建 Web 搜索图,最终生成代码并执行。
|
85 |
+
|
86 |
+
## API 介绍
|
87 |
+
|
88 |
+
下面是包含属性详细说明的 `WebSearchGraph` 类的 API 文档:
|
89 |
+
|
90 |
+
### 类:`WebSearchGraph`
|
91 |
+
|
92 |
+
此类用于管理网络搜索图的节点和边,并通过网络代理进行搜索。
|
93 |
+
|
94 |
+
#### 初始化方法
|
95 |
+
|
96 |
+
初始化 `WebSearchGraph` 实例。
|
97 |
+
|
98 |
+
**属性:**
|
99 |
+
|
100 |
+
- `nodes` (Dict[str, Dict[str, str]]): 存储图中所有节点的字典。每个节点由其名称索引,并包含内容、类型以及其他相关信息。
|
101 |
+
- `adjacency_list` (Dict[str, List[str]]): 存储��中所有节点之间连接关系的邻接表。每个节点由其名称索引,并包含一个相邻节点名称的列表。
|
102 |
+
|
103 |
+
|
104 |
+
#### 方法:`add_root_node`
|
105 |
+
|
106 |
+
添加原始问题作为根节点。
|
107 |
+
**参数:**
|
108 |
+
|
109 |
+
- `node_content` (str): 用户提出的问题。
|
110 |
+
- `node_name` (str, 可选): 节点名称,默认为 'root'。
|
111 |
+
|
112 |
+
|
113 |
+
#### 方法:`add_node`
|
114 |
+
|
115 |
+
添加搜索子问题节点并返回搜索结果。
|
116 |
+
**参数:
|
117 |
+
|
118 |
+
- `node_name` (str): 节点名称。
|
119 |
+
- `node_content` (str): 子问题内容。
|
120 |
+
|
121 |
+
**返回:**
|
122 |
+
|
123 |
+
- `str`: 返回搜索结果。
|
124 |
+
|
125 |
+
|
126 |
+
#### 方法:`add_response_node`
|
127 |
+
|
128 |
+
当前获取的信息已经满足问题需求,添加回复节点。
|
129 |
+
|
130 |
+
**参数:**
|
131 |
+
|
132 |
+
- `node_name` (str, 可选): 节点名称,默认为 'response'。
|
133 |
+
|
134 |
+
|
135 |
+
#### 方法:`add_edge`
|
136 |
+
|
137 |
+
添加边。
|
138 |
+
|
139 |
+
**参数:**
|
140 |
+
|
141 |
+
- `start_node` (str): 起始节点名称。
|
142 |
+
- `end_node` (str): 结束节点名称。
|
143 |
+
|
144 |
+
|
145 |
+
#### 方法:`reset`
|
146 |
+
|
147 |
+
重置节点和边。
|
148 |
+
|
149 |
+
|
150 |
+
#### 方法:`node`
|
151 |
+
|
152 |
+
获取节点信息。
|
153 |
+
|
154 |
+
```python
|
155 |
+
def node(self, node_name: str) -> str
|
156 |
+
```
|
157 |
+
|
158 |
+
**参数:**
|
159 |
+
|
160 |
+
- `node_name` (str): 节点名称。
|
161 |
+
|
162 |
+
**返回:**
|
163 |
+
|
164 |
+
- `str`: 返回包含节点信息的字典,包含节点的内容、类型、思考过程(如果有)和前驱节点列表。
|
165 |
+
|
166 |
+
## 任务介绍
|
167 |
+
通过将一个问题拆分成能够通过搜索回答的子问题(没有关联的问题可以同步并列搜索),每个搜索的问题应该是一个单一问题,即单个具体人、事、物、具体时间点、地点或知识点的问题,不是一个复合问题(比如某个时间段), 一步步构建搜索图,最终回答问题。
|
168 |
+
|
169 |
+
## 注意事项
|
170 |
+
|
171 |
+
1. 注意,每个搜索节点的内容必须单个问题,不要包含多个问题(比如同时问多个知识点的问题或者多个事物的比较加筛选,类似 A, B, C 有什么区别,那个价格在哪个区间 -> 分别查询)
|
172 |
+
2. 不要杜撰搜索结果,要等待代码返回结果
|
173 |
+
3. 同样的问题不要重复提问,可以在已有问题的基础上继续提问
|
174 |
+
4. 添加 response 节点的时候,要单独添加,不要和其他节点一起添加,不能同时添加 response 节点和其他节点
|
175 |
+
5. 一次输出中,不要包含多个代码块,每次只能有一个代码块
|
176 |
+
6. 每个代码块应该放置在一个代码块标记中,同时生成完代码后添加一个<|action_end|>标志,如下所示:
|
177 |
+
<|action_start|><|interpreter|>```python
|
178 |
+
# 你的代码块
|
179 |
+
```<|action_end|>
|
180 |
+
7. 最后一次回复应该是添加node_name为'response'的 response 节点,必须添加 response 节点,不要添加其他节点
|
181 |
+
"""
|
182 |
+
|
183 |
+
GRAPH_PROMPT_EN = """## Character Profile
|
184 |
+
You are a programmer capable of Python programming in a Jupyter environment. You can utilize the provided API to construct a Web Search Graph, ultimately generating and executing code.
|
185 |
+
|
186 |
+
## API Description
|
187 |
+
|
188 |
+
Below is the API documentation for the WebSearchGraph class, including detailed attribute descriptions:
|
189 |
+
|
190 |
+
### Class: WebSearchGraph
|
191 |
+
|
192 |
+
This class manages nodes and edges of a web search graph and conducts searches via a web proxy.
|
193 |
+
|
194 |
+
#### Initialization Method
|
195 |
+
|
196 |
+
Initializes an instance of WebSearchGraph.
|
197 |
+
|
198 |
+
**Attributes:**
|
199 |
+
|
200 |
+
- nodes (Dict[str, Dict[str, str]]): A dictionary storing all nodes in the graph. Each node is indexed by its name and contains content, type, and other related information.
|
201 |
+
- adjacency_list (Dict[str, List[str]]): An adjacency list storing the connections between all nodes in the graph. Each node is indexed by its name and contains a list of adjacent node names.
|
202 |
+
|
203 |
+
#### Method: add_root_node
|
204 |
+
|
205 |
+
Adds the initial question as the root node.
|
206 |
+
**Parameters:**
|
207 |
+
|
208 |
+
- node_content (str): The user's question.
|
209 |
+
- node_name (str, optional): The node name, default is 'root'.
|
210 |
+
|
211 |
+
#### Method: add_node
|
212 |
+
|
213 |
+
Adds a sub-question node and returns search results.
|
214 |
+
**Parameters:**
|
215 |
+
|
216 |
+
- node_name (str): The node name.
|
217 |
+
- node_content (str): The sub-question content.
|
218 |
+
|
219 |
+
**Returns:**
|
220 |
+
|
221 |
+
- str: Returns the search results.
|
222 |
+
|
223 |
+
#### Method: add_response_node
|
224 |
+
|
225 |
+
Adds a response node when the current information satisfies the question's requirements.
|
226 |
+
|
227 |
+
**Parameters:**
|
228 |
+
|
229 |
+
- node_name (str, optional): The node name, default is 'response'.
|
230 |
+
|
231 |
+
#### Method: add_edge
|
232 |
+
|
233 |
+
Adds an edge.
|
234 |
+
|
235 |
+
**Parameters:**
|
236 |
+
|
237 |
+
- start_node (str): The starting node name.
|
238 |
+
- end_node (str): The ending node name.
|
239 |
+
|
240 |
+
#### Method: reset
|
241 |
+
|
242 |
+
Resets nodes and edges.
|
243 |
+
|
244 |
+
#### Method: node
|
245 |
+
|
246 |
+
Get node information.
|
247 |
+
|
248 |
+
python
|
249 |
+
def node(self, node_name: str) -> str
|
250 |
+
|
251 |
+
**Parameters:**
|
252 |
+
|
253 |
+
- node_name (str): The node name.
|
254 |
+
|
255 |
+
**Returns:**
|
256 |
+
|
257 |
+
- str: Returns a dictionary containing the node's information, including content, type, thought process (if any), and list of predecessor nodes.
|
258 |
+
|
259 |
+
## Task Description
|
260 |
+
By breaking down a question into sub-questions that can be answered through searches (unrelated questions can be searched concurrently), each search query should be a single question focusing on a specific person, event, object, specific time point, location, or knowledge point. It should not be a compound question (e.g., a time period). Step by step, build the search graph to finally answer the question.
|
261 |
+
|
262 |
+
## Considerations
|
263 |
+
|
264 |
+
1. Each search node's content must be a single question; do not include multiple questions (e.g., do not ask multiple knowledge points or compare and filter multiple things simultaneously, like asking for differences between A, B, and C, or price ranges -> query each separately).
|
265 |
+
2. Do not fabricate search results; wait for the code to return results.
|
266 |
+
3. Do not repeat the same question; continue asking based on existing questions.
|
267 |
+
4. When adding a response node, add it separately; do not add a response node and other nodes simultaneously.
|
268 |
+
5. In a single output, do not include multiple code blocks; only one code block per output.
|
269 |
+
6. Each code block should be placed within a code block marker, and after generating the code, add an <|action_end|> tag as shown below:
|
270 |
+
<|action_start|><|interpreter|>
|
271 |
+
```python
|
272 |
+
# Your code block (Note that the 'Get new added node information' logic must be added at the end of the code block, such as 'graph.node('...')')
|
273 |
+
```<|action_end|>
|
274 |
+
7. The final response should add a response node with node_name 'response', and no other nodes should be added.
|
275 |
+
"""
|
276 |
+
|
277 |
+
graph_fewshot_example_cn = """
|
278 |
+
## 返回格式示例
|
279 |
+
<|action_start|><|interpreter|>```python
|
280 |
+
graph = WebSearchGraph()
|
281 |
+
graph.add_root_node(node_content="哪家大模型API最便宜?", node_name="root") # 添加原始问题作为根节点
|
282 |
+
graph.add_node(
|
283 |
+
node_name="大模型API提供商", # 节点名称最好有意义
|
284 |
+
node_content="目前有哪些主要的大模型API提供商?")
|
285 |
+
graph.add_node(
|
286 |
+
node_name="sub_name_2", # 节点名称最好有意义
|
287 |
+
node_content="content of sub_name_2")
|
288 |
+
...
|
289 |
+
graph.add_edge(start_node="root", end_node="sub_name_1")
|
290 |
+
...
|
291 |
+
graph.node("大模型API提供商"), graph.node("sub_name_2"), ...
|
292 |
+
```<|action_end|>
|
293 |
+
"""
|
294 |
+
|
295 |
+
graph_fewshot_example_en = """
|
296 |
+
## Response Format
|
297 |
+
<|action_start|><|interpreter|>```python
|
298 |
+
graph = WebSearchGraph()
|
299 |
+
graph.add_root_node(node_content="Which large model API is the cheapest?", node_name="root") # Add the original question as the root node
|
300 |
+
graph.add_node(
|
301 |
+
node_name="Large Model API Providers", # The node name should be meaningful
|
302 |
+
node_content="Who are the main large model API providers currently?")
|
303 |
+
graph.add_node(
|
304 |
+
node_name="sub_name_2", # The node name should be meaningful
|
305 |
+
node_content="content of sub_name_2")
|
306 |
+
...
|
307 |
+
graph.add_edge(start_node="root", end_node="sub_name_1")
|
308 |
+
...
|
309 |
+
# Get node info
|
310 |
+
graph.node("Large Model API Providers"), graph.node("sub_name_2"), ...
|
311 |
+
```<|action_end|>
|
312 |
+
"""
|
313 |
+
|
314 |
+
FINAL_RESPONSE_CN = """基于提供的问答对,撰写一篇详细完备的最终回答。
|
315 |
+
- 回答内容需要逻辑清晰,层次分明,确保读者易于理解。
|
316 |
+
- 回答中每个关键点需标注引用的搜索结果来源(保持跟问答对中的索引一致),以确保信息的可信度。给出索引的形式为`[[int]]`,如果有多个索引,则用多个[[]]表示,如`[[id_1]][[id_2]]`。
|
317 |
+
- 回答部分需要全面且完备,不要出现"基于上述内容"等模糊表达,最终呈现的回答不包括提供给你的问答对。
|
318 |
+
- 语言风格需要专业、严谨,避免口语化表达。
|
319 |
+
- 保持统一的语法和词汇使用,确保整体文档的一致性和连贯性。"""
|
320 |
+
|
321 |
+
FINAL_RESPONSE_EN = """Based on the provided Q&A pairs, write a detailed and comprehensive final response.
|
322 |
+
- The response content should be logically clear and well-structured to ensure reader understanding.
|
323 |
+
- Each key point in the response should be marked with the source of the search results (consistent with the indices in the Q&A pairs) to ensure information credibility. The index is in the form of `[[int]]`, and if there are multiple indices, use multiple `[[]]`, such as `[[id_1]][[id_2]]`.
|
324 |
+
- The response should be comprehensive and complete, without vague expressions like "based on the above content". The final response should not include the Q&A pairs provided to you.
|
325 |
+
- The language style should be professional and rigorous, avoiding colloquial expressions.
|
326 |
+
- Maintain consistent grammar and vocabulary usage to ensure overall document consistency and coherence."""
|
mindsearch/agent/models.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
from lagent.llms import (GPTAPI, INTERNLM2_META, HFTransformerCasualLM,
|
4 |
+
LMDeployClient, LMDeployServer)
|
5 |
+
|
6 |
+
internlm_server = dict(type=LMDeployServer,
|
7 |
+
path='internlm/internlm2_5-7b-chat',
|
8 |
+
model_name='internlm2',
|
9 |
+
meta_template=INTERNLM2_META,
|
10 |
+
top_p=0.8,
|
11 |
+
top_k=1,
|
12 |
+
temperature=0,
|
13 |
+
max_new_tokens=8192,
|
14 |
+
repetition_penalty=1.02,
|
15 |
+
stop_words=['<|im_end|>'])
|
16 |
+
|
17 |
+
internlm_client = dict(type=LMDeployClient,
|
18 |
+
model_name='internlm2_5-7b-chat',
|
19 |
+
url='http://127.0.0.1:23333',
|
20 |
+
meta_template=INTERNLM2_META,
|
21 |
+
top_p=0.8,
|
22 |
+
top_k=1,
|
23 |
+
temperature=0,
|
24 |
+
max_new_tokens=8192,
|
25 |
+
repetition_penalty=1.02,
|
26 |
+
stop_words=['<|im_end|>'])
|
27 |
+
|
28 |
+
internlm_hf = dict(type=HFTransformerCasualLM,
|
29 |
+
path='internlm/internlm2_5-7b-chat',
|
30 |
+
meta_template=INTERNLM2_META,
|
31 |
+
top_p=0.8,
|
32 |
+
top_k=None,
|
33 |
+
temperature=1e-6,
|
34 |
+
max_new_tokens=8192,
|
35 |
+
repetition_penalty=1.02,
|
36 |
+
stop_words=['<|im_end|>'])
|
37 |
+
# openai_api_base needs to fill in the complete chat api address, such as: https://api.openai.com/v1/chat/completions
|
38 |
+
gpt4 = dict(type=GPTAPI,
|
39 |
+
model_type='gpt-4-turbo',
|
40 |
+
key=os.environ.get('OPENAI_API_KEY', 'YOUR OPENAI API KEY'),
|
41 |
+
openai_api_base=os.environ.get('OPENAI_API_BASE', 'https://api.openai.com/v1/chat/completions'),
|
42 |
+
)
|
43 |
+
|
44 |
+
url = 'https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation'
|
45 |
+
qwen = dict(type=GPTAPI,
|
46 |
+
model_type='qwen-max-longcontext',
|
47 |
+
key=os.environ.get('QWEN_API_KEY', 'YOUR QWEN API KEY'),
|
48 |
+
openai_api_base=url,
|
49 |
+
meta_template=[
|
50 |
+
dict(role='system', api_role='system'),
|
51 |
+
dict(role='user', api_role='user'),
|
52 |
+
dict(role='assistant', api_role='assistant'),
|
53 |
+
dict(role='environment', api_role='system')
|
54 |
+
],
|
55 |
+
top_p=0.8,
|
56 |
+
top_k=1,
|
57 |
+
temperature=0,
|
58 |
+
max_new_tokens=4096,
|
59 |
+
repetition_penalty=1.02,
|
60 |
+
stop_words=['<|im_end|>'])
|
61 |
+
|
62 |
+
internlm_silicon = dict(type=GPTAPI,
|
63 |
+
model_type='internlm/internlm2_5-7b-chat',
|
64 |
+
key=os.environ.get('SILICON_API_KEY', 'YOUR SILICON API KEY'),
|
65 |
+
openai_api_base='https://api.siliconflow.cn/v1/chat/completions',
|
66 |
+
meta_template=[
|
67 |
+
dict(role='system', api_role='system'),
|
68 |
+
dict(role='user', api_role='user'),
|
69 |
+
dict(role='assistant', api_role='assistant'),
|
70 |
+
dict(role='environment', api_role='system')
|
71 |
+
],
|
72 |
+
top_p=0.8,
|
73 |
+
top_k=1,
|
74 |
+
temperature=0,
|
75 |
+
max_new_tokens=8192,
|
76 |
+
repetition_penalty=1.02,
|
77 |
+
stop_words=['<|im_end|>'])
|
mindsearch/app.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
from copy import deepcopy
|
5 |
+
from dataclasses import asdict
|
6 |
+
from typing import Dict, List, Union
|
7 |
+
|
8 |
+
import janus
|
9 |
+
from fastapi import FastAPI
|
10 |
+
from fastapi.middleware.cors import CORSMiddleware
|
11 |
+
from lagent.schema import AgentStatusCode
|
12 |
+
from pydantic import BaseModel
|
13 |
+
from sse_starlette.sse import EventSourceResponse
|
14 |
+
|
15 |
+
from mindsearch.agent import init_agent
|
16 |
+
|
17 |
+
|
18 |
+
def parse_arguments():
|
19 |
+
import argparse
|
20 |
+
parser = argparse.ArgumentParser(description='MindSearch API')
|
21 |
+
parser.add_argument('--lang', default='cn', type=str, help='Language')
|
22 |
+
parser.add_argument('--model_format',
|
23 |
+
default='internlm_server',
|
24 |
+
type=str,
|
25 |
+
help='Model format')
|
26 |
+
parser.add_argument('--search_engine',
|
27 |
+
default='DuckDuckGoSearch',
|
28 |
+
type=str,
|
29 |
+
help='Search engine')
|
30 |
+
return parser.parse_args()
|
31 |
+
|
32 |
+
|
33 |
+
args = parse_arguments()
|
34 |
+
app = FastAPI(docs_url='/')
|
35 |
+
|
36 |
+
app.add_middleware(CORSMiddleware,
|
37 |
+
allow_origins=['*'],
|
38 |
+
allow_credentials=True,
|
39 |
+
allow_methods=['*'],
|
40 |
+
allow_headers=['*'])
|
41 |
+
|
42 |
+
|
43 |
+
class GenerationParams(BaseModel):
|
44 |
+
inputs: Union[str, List[Dict]]
|
45 |
+
agent_cfg: Dict = dict()
|
46 |
+
|
47 |
+
|
48 |
+
@app.post('/solve')
|
49 |
+
async def run(request: GenerationParams):
|
50 |
+
|
51 |
+
def convert_adjacency_to_tree(adjacency_input, root_name):
|
52 |
+
|
53 |
+
def build_tree(node_name):
|
54 |
+
node = {'name': node_name, 'children': []}
|
55 |
+
if node_name in adjacency_input:
|
56 |
+
for child in adjacency_input[node_name]:
|
57 |
+
child_node = build_tree(child['name'])
|
58 |
+
child_node['state'] = child['state']
|
59 |
+
child_node['id'] = child['id']
|
60 |
+
node['children'].append(child_node)
|
61 |
+
return node
|
62 |
+
|
63 |
+
return build_tree(root_name)
|
64 |
+
|
65 |
+
async def generate():
|
66 |
+
try:
|
67 |
+
queue = janus.Queue()
|
68 |
+
stop_event = asyncio.Event()
|
69 |
+
|
70 |
+
# Wrapping a sync generator as an async generator using run_in_executor
|
71 |
+
def sync_generator_wrapper():
|
72 |
+
try:
|
73 |
+
for response in agent.stream_chat(inputs):
|
74 |
+
queue.sync_q.put(response)
|
75 |
+
except Exception as e:
|
76 |
+
logging.exception(
|
77 |
+
f'Exception in sync_generator_wrapper: {e}')
|
78 |
+
finally:
|
79 |
+
# Notify async_generator_wrapper that the data generation is complete.
|
80 |
+
queue.sync_q.put(None)
|
81 |
+
|
82 |
+
async def async_generator_wrapper():
|
83 |
+
loop = asyncio.get_event_loop()
|
84 |
+
loop.run_in_executor(None, sync_generator_wrapper)
|
85 |
+
while True:
|
86 |
+
response = await queue.async_q.get()
|
87 |
+
if response is None: # Ensure that all elements are consumed
|
88 |
+
break
|
89 |
+
yield response
|
90 |
+
if not isinstance(
|
91 |
+
response,
|
92 |
+
tuple) and response.state == AgentStatusCode.END:
|
93 |
+
break
|
94 |
+
stop_event.set() # Inform sync_generator_wrapper to stop
|
95 |
+
|
96 |
+
async for response in async_generator_wrapper():
|
97 |
+
if isinstance(response, tuple):
|
98 |
+
agent_return, node_name = response
|
99 |
+
else:
|
100 |
+
agent_return = response
|
101 |
+
node_name = None
|
102 |
+
origin_adj = deepcopy(agent_return.adjacency_list)
|
103 |
+
adjacency_list = convert_adjacency_to_tree(
|
104 |
+
agent_return.adjacency_list, 'root')
|
105 |
+
assert adjacency_list[
|
106 |
+
'name'] == 'root' and 'children' in adjacency_list
|
107 |
+
agent_return.adjacency_list = adjacency_list['children']
|
108 |
+
agent_return = asdict(agent_return)
|
109 |
+
agent_return['adj'] = origin_adj
|
110 |
+
response_json = json.dumps(dict(response=agent_return,
|
111 |
+
current_node=node_name),
|
112 |
+
ensure_ascii=False)
|
113 |
+
yield {'data': response_json}
|
114 |
+
# yield f'data: {response_json}\n\n'
|
115 |
+
except Exception as exc:
|
116 |
+
msg = 'An error occurred while generating the response.'
|
117 |
+
logging.exception(msg)
|
118 |
+
response_json = json.dumps(
|
119 |
+
dict(error=dict(msg=msg, details=str(exc))),
|
120 |
+
ensure_ascii=False)
|
121 |
+
yield {'data': response_json}
|
122 |
+
# yield f'data: {response_json}\n\n'
|
123 |
+
finally:
|
124 |
+
await stop_event.wait(
|
125 |
+
) # Waiting for async_generator_wrapper to stop
|
126 |
+
queue.close()
|
127 |
+
await queue.wait_closed()
|
128 |
+
|
129 |
+
inputs = request.inputs
|
130 |
+
agent = init_agent(lang=args.lang, model_format=args.model_format,search_engine=args.search_engine)
|
131 |
+
return EventSourceResponse(generate())
|
132 |
+
|
133 |
+
|
134 |
+
if __name__ == '__main__':
|
135 |
+
import uvicorn
|
136 |
+
uvicorn.run(app, host='0.0.0.0', port=8002, log_level='info')
|
mindsearch/terminal.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
|
3 |
+
from lagent.actions import ActionExecutor, BingBrowser
|
4 |
+
from lagent.llms import INTERNLM2_META, LMDeployServer
|
5 |
+
|
6 |
+
from mindsearch.agent.mindsearch_agent import (MindSearchAgent,
|
7 |
+
MindSearchProtocol)
|
8 |
+
from mindsearch.agent.mindsearch_prompt import (
|
9 |
+
FINAL_RESPONSE_CN, FINAL_RESPONSE_EN, GRAPH_PROMPT_CN, GRAPH_PROMPT_EN,
|
10 |
+
searcher_context_template_cn, searcher_context_template_en,
|
11 |
+
searcher_input_template_cn, searcher_input_template_en,
|
12 |
+
searcher_system_prompt_cn, searcher_system_prompt_en)
|
13 |
+
|
14 |
+
lang = 'cn'
|
15 |
+
llm = LMDeployServer(path='internlm/internlm2_5-7b-chat',
|
16 |
+
model_name='internlm2',
|
17 |
+
meta_template=INTERNLM2_META,
|
18 |
+
top_p=0.8,
|
19 |
+
top_k=1,
|
20 |
+
temperature=0,
|
21 |
+
max_new_tokens=8192,
|
22 |
+
repetition_penalty=1.02,
|
23 |
+
stop_words=['<|im_end|>'])
|
24 |
+
|
25 |
+
agent = MindSearchAgent(
|
26 |
+
llm=llm,
|
27 |
+
protocol=MindSearchProtocol(
|
28 |
+
meta_prompt=datetime.now().strftime('The current date is %Y-%m-%d.'),
|
29 |
+
interpreter_prompt=GRAPH_PROMPT_CN
|
30 |
+
if lang == 'cn' else GRAPH_PROMPT_EN,
|
31 |
+
response_prompt=FINAL_RESPONSE_CN
|
32 |
+
if lang == 'cn' else FINAL_RESPONSE_EN),
|
33 |
+
searcher_cfg=dict(
|
34 |
+
llm=llm,
|
35 |
+
plugin_executor=ActionExecutor(
|
36 |
+
BingBrowser(searcher_type='DuckDuckGoSearch', topk=6)),
|
37 |
+
protocol=MindSearchProtocol(
|
38 |
+
meta_prompt=datetime.now().strftime(
|
39 |
+
'The current date is %Y-%m-%d.'),
|
40 |
+
plugin_prompt=searcher_system_prompt_cn
|
41 |
+
if lang == 'cn' else searcher_system_prompt_en,
|
42 |
+
),
|
43 |
+
template=dict(input=searcher_input_template_cn
|
44 |
+
if lang == 'cn' else searcher_input_template_en,
|
45 |
+
context=searcher_context_template_cn
|
46 |
+
if lang == 'cn' else searcher_context_template_en)),
|
47 |
+
max_turn=10)
|
48 |
+
|
49 |
+
for agent_return in agent.stream_chat('上海今天适合穿什么衣服'):
|
50 |
+
pass
|