File size: 3,888 Bytes
4d7183d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import hashlib
import json
import os
import string

from actions.duck_search import duckduckgo_search
from processing.text import read_txt_files
from agent.llm_utils import llm_response, llm_stream_response
from config import Config
from agent import prompts

# Module-wide configuration instance, created once at import time.
# ResearchAgent reads `CFG.fast_llm_model` from it to pick the LLM for every call.
CFG = Config()


class ResearchAgent:
    """Researches a question with web searches and streams an LLM-written report.

    Raw search results are written as text files under
    ``./outputs/<directory_name>/``. When that directory already exists and
    yields content, `search_online` reuses it instead of searching again.
    """

    def __init__(self, question, agent):
        """Initialize the research agent for the given question.

        Args:
            question (str): The question to research.
            agent: Agent descriptor handed to
                ``prompts.generate_agent_role_prompt`` to build the system prompt.
        """
        self.question = question
        self.agent = agent
        self.visited_urls = set()
        self.search_summary = ""
        self.directory_name = self._make_directory_name(question)
        self.dir_path = os.path.dirname(f"./outputs/{self.directory_name}/")

    @staticmethod
    def _stable_digest(text):
        """Return a short hex digest of ``text`` that is identical across runs."""
        # The built-in hash() is salted per process for str, so it cannot be
        # used for names that must stay the same between runs.
        encoded = text.encode("utf-8", "backslashreplace")
        return hashlib.sha256(encoded).hexdigest()[:16]

    @classmethod
    def _make_directory_name(cls, question):
        """Derive the per-question output directory name.

        Keeps ASCII, non-punctuation characters (at most 100). If nothing
        usable remains (e.g. a question written entirely in a non-Latin
        script), falls back to a digest-based name so that such questions do
        not all collapse onto the shared ``./outputs`` directory.
        """
        name = ''.join(
            c for c in question if c.isascii() and c not in string.punctuation
        )[:100]
        if not name.strip():
            name = f"question-{cls._stable_digest(question)}"
        return name

    def _build_messages(self, action):
        """Build the system + user chat messages for one agent call."""
        return [{
            "role": "system",
            "content": prompts.generate_agent_role_prompt(self.agent),
        }, {
            "role": "user",
            "content": action,
        }]

    def call_agent(self, action):
        """Send ``action`` to the LLM as the user message and return its reply.

        Args:
            action (str): The prompt to send as the user message.
        Returns:
            The value returned by ``llm_response`` for ``CFG.fast_llm_model``.
        """
        return llm_response(
            model=CFG.fast_llm_model,
            messages=self._build_messages(action),
        )

    def call_agent_stream(self, action):
        """Send ``action`` to the LLM and yield the reply as it streams in.

        Args:
            action (str): The prompt to send as the user message.
        Yields:
            Each item produced by ``llm_stream_response``.
        """
        yield from llm_stream_response(
            model=CFG.fast_llm_model,
            messages=self._build_messages(action),
        )

    def create_search_queries(self):
        """Ask the LLM for search queries covering the question.

        Returns:
            The LLM reply parsed with ``json.loads``; the container type (JSON
            object or array) depends on what the prompt makes the model emit.
        Raises:
            json.JSONDecodeError: If the reply is not valid JSON.
        """
        result = self.call_agent(prompts.generate_search_queries_prompt(self.question))
        return json.loads(result)

    def search_single_query(self, query):
        """Run a DuckDuckGo search for the query, limited to 3 results.

        Args:
            query (str): The query to search for.
        Returns:
            Whatever ``duckduckgo_search`` returns for the query.
        """
        return duckduckgo_search(query, max_search_result=3)

    def run_search_summary(self, query):
        """Search for the query and save the raw results to the output directory.

        Args:
            query (str): The query to search for.
        Returns:
            The search results as returned by `search_single_query`; they are
            also dumped as JSON to ``research-<digest>.txt``.
        """
        print(f"Searching for {query}")
        responses = self.search_single_query(query)

        # A stable digest keeps the file name identical across runs for the
        # same query, unlike the per-process salted built-in hash().
        query_id = self._stable_digest(query)
        file_path = f"./outputs/{self.directory_name}/research-{query_id}.txt"
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(responses, f)
        print(f"Saved {query_id} to {file_path}")
        return responses

    def search_online(self):
        """Collect search results for the question, reusing saved ones if present.

        Returns:
            str: The combined search summary, also stored on
            ``self.search_summary``.
        """
        if os.path.isdir(self.dir_path):
            self.search_summary = read_txt_files(self.dir_path)
        else:
            self.search_summary = ""

        if not self.search_summary:
            search_queries = self.create_search_queries()
            # The model may answer with a JSON object ({"1": "query", ...}),
            # a JSON array of queries, or a single JSON string.
            if isinstance(search_queries, dict):
                search_queries = search_queries.values()
            elif isinstance(search_queries, str):
                search_queries = [search_queries]

            sections = []
            for query in search_queries:
                search_result = self.run_search_summary(query)
                sections.append(
                    f"=Query=:\n{query}\n=Search Result=:\n{search_result}\n================\n"
                )
            self.search_summary = "".join(sections)

        return self.search_summary

    def write_report(self, report_type):
        """Stream the report for the question.

        Args:
            report_type: Key passed to ``prompts.get_report_by_type`` to select
                the prompt builder for the report.
        Yields:
            str: First the status message ``"Searching online..."``, then each
            item streamed by `call_agent_stream` for the report prompt.
        """
        yield "Searching online..."

        report_type_func = prompts.get_report_by_type(report_type)

        yield from self.call_agent_stream(report_type_func(self.question, self.search_online()))