# Spaces status: Runtime error
# (status banner captured from the hosting page together with this file; not part of the code)
import hashlib
import json
import os
import string

from actions.duck_search import duckduckgo_search
from processing.text import read_txt_files
from agent.llm_utils import llm_response, llm_stream_response
from config import Config
from agent import prompts
# Module-level configuration shared by ResearchAgent: its ``fast_llm_model``
# attribute selects the model used for every LLM request the agent makes.
CFG = Config()
class ResearchAgent:
    """Research assistant that searches the web for a question and streams a report.

    Raw search results are written as ``research-<id>.txt`` files under
    ``self.dir_path`` (``./outputs/<directory_name>``). If ``read_txt_files``
    returns non-empty text for that directory, ``search_online`` reuses it
    instead of searching again.
    """

    def __init__(self, question, agent, system_prompt):
        """Initialize the research assistant for ``question``.

        Args:
            question (str): The question to research.
            agent: Stored as ``self.agent``; not used by the methods of this class.
            system_prompt (str): System message sent with every LLM request.

        Returns:
            None
        """
        self.question = question
        self.agent = agent
        self.visited_urls = set()
        self.search_summary = ""
        # Keep only ASCII, non-punctuation characters so the question can be used
        # as a directory name (this also removes path separators such as "/").
        self.directory_name = ''.join(
            c for c in question if c.isascii() and c not in string.punctuation
        )[:100]
        if not self.directory_name.strip():
            # With no usable characters (e.g. a fully non-ASCII question) the path
            # would collapse to "./outputs" itself, so unrelated questions would
            # share -- and reuse -- the same cached search results.
            self.directory_name = f"question-{self._stable_id(question)}"
        self.dir_path = os.path.dirname(f"./outputs/{self.directory_name}/")
        self.system_prompt = system_prompt

    @staticmethod
    def _stable_id(text):
        """Return a short hex digest of ``text`` that is identical across runs.

        The built-in ``hash()`` is salted per process for strings, so names
        derived from it change between runs; SHA-256 does not.
        """
        return hashlib.sha256(str(text).encode("utf-8")).hexdigest()[:16]

    def _build_messages(self, action):
        """Return the chat messages (system prompt + user ``action``) for one LLM call."""
        return [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": action},
        ]

    def call_agent(self, action):
        """Send ``action`` to the fast LLM model and return ``llm_response``'s result."""
        return llm_response(
            model=CFG.fast_llm_model,
            messages=self._build_messages(action),
        )

    def call_agent_stream(self, action):
        """Send ``action`` to the fast LLM model and yield the streamed response chunks."""
        yield from llm_stream_response(
            model=CFG.fast_llm_model,
            messages=self._build_messages(action),
        )

    def create_search_queries(self):
        """Ask the LLM for search queries for the question.

        Args:
            None

        Returns:
            The LLM reply parsed with ``json.loads``. Expected to be a list of
            query strings or a dict whose values are query strings; both
            shapes are accepted by ``search_online``.

        Raises:
            json.JSONDecodeError: If the LLM reply is not valid JSON.
        """
        result = self.call_agent(prompts.generate_search_queries_prompt(self.question))
        return json.loads(result)

    def search_single_query(self, query):
        """Run a DuckDuckGo search for ``query``.

        Args:
            query (str): The query to search for.

        Returns:
            Whatever ``duckduckgo_search`` returns for at most 3 results.
        """
        return duckduckgo_search(query, max_search_result=3)

    def run_search_summary(self, query):
        """Search for ``query`` and save the raw results as JSON in ``self.dir_path``.

        The file name is derived from a stable digest of the query, so the same
        query always maps to the same file (and overwrites it).

        Args:
            query (str): The query to search for.

        Returns:
            The search results returned by ``search_single_query``.
        """
        print(f"Searching for {query}")
        responses = self.search_single_query(query)
        file_path = os.path.join(self.dir_path, f"research-{self._stable_id(query)}.txt")
        os.makedirs(self.dir_path, exist_ok=True)
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(responses, f)
        print(f"Saved {query} to {file_path}")
        return responses

    def search_online(self):
        """Collect search results for the question, reusing cached files if present.

        Args:
            None

        Returns:
            str: The search summary, also stored in ``self.search_summary``.
        """
        self.search_summary = read_txt_files(self.dir_path) if os.path.isdir(self.dir_path) else ""
        if not self.search_summary:
            queries = self.create_search_queries()
            # The LLM may reply with a JSON array of queries, a JSON object whose
            # values are the queries, or a single query string; accept all three.
            if isinstance(queries, dict):
                queries = list(queries.values())
            elif isinstance(queries, str):
                queries = [queries]
            sections = []
            for query in queries:
                search_result = self.run_search_summary(query)
                sections.append(
                    f"=Query=:\n{query}\n=Search Result=:\n{search_result}\n================\n"
                )
            self.search_summary = "".join(sections)
        return self.search_summary

    def write_report(self, report_type, extra_prompt):
        """Stream a report for the question built from the online search results.

        Args:
            report_type: Key passed to ``prompts.get_report_by_type`` to choose
                the prompt builder.
            extra_prompt: Additional instructions forwarded to that prompt builder.

        Yields:
            The streamed chunks produced by ``call_agent_stream``.
        """
        report_type_func = prompts.get_report_by_type(report_type)
        yield from self.call_agent_stream(report_type_func(self.question,
                                                           self.search_online(),
                                                           extra_prompt))