Spaces:
Running
Running
omkarenator
commited on
Commit
•
f893655
1
Parent(s):
a685d80
Make ActionRunner accept custom LLMs (#13)
Browse files- autoagents/agents/search.py +5 -12
- autoagents/models/custom.py +33 -0
- autoagents/spaces/app.py +16 -11
- autoagents/tools/tools.py +2 -7
- autoagents/utils/utils.py +0 -8
- test.py +11 -6
autoagents/agents/search.py
CHANGED
@@ -14,11 +14,10 @@ from langchain.schema import AgentAction, AgentFinish
|
|
14 |
from langchain.callbacks import get_openai_callback
|
15 |
from langchain.callbacks.base import AsyncCallbackHandler
|
16 |
from langchain.callbacks.manager import AsyncCallbackManager
|
17 |
-
|
18 |
|
19 |
from autoagents.tools.tools import search_tool, note_tool, rewrite_search_query
|
20 |
from autoagents.utils.logger import InteractionsLogger
|
21 |
-
from autoagents.utils.utils import OpenAICred
|
22 |
|
23 |
|
24 |
# Set up the base template
|
@@ -124,9 +123,8 @@ class CustomOutputParser(AgentOutputParser):
|
|
124 |
class Config:
|
125 |
arbitrary_types_allowed = True
|
126 |
ialogger: InteractionsLogger
|
127 |
-
|
128 |
new_action_input: Optional[str]
|
129 |
-
|
130 |
action_history = defaultdict(set)
|
131 |
|
132 |
def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
|
@@ -154,7 +152,7 @@ class CustomOutputParser(AgentOutputParser):
|
|
154 |
if action_input in self.action_history[action]:
|
155 |
new_action_input = rewrite_search_query(action_input,
|
156 |
self.action_history[action],
|
157 |
-
|
158 |
self.ialogger.add_message({"query_rewrite": True})
|
159 |
self.new_action_input = new_action_input
|
160 |
self.action_history[action].add(new_action_input)
|
@@ -168,8 +166,7 @@ class CustomOutputParser(AgentOutputParser):
|
|
168 |
class ActionRunner:
|
169 |
def __init__(self,
|
170 |
outputq,
|
171 |
-
|
172 |
-
model_name: str,
|
173 |
persist_logs: bool = False):
|
174 |
self.ialogger = InteractionsLogger(name=f"{uuid.uuid4().hex[:6]}", persist=persist_logs)
|
175 |
tools = [search_tool, note_tool]
|
@@ -179,7 +176,7 @@ class ActionRunner:
|
|
179 |
input_variables=["input", "intermediate_steps"],
|
180 |
ialogger=self.ialogger)
|
181 |
|
182 |
-
output_parser = CustomOutputParser(ialogger=self.ialogger,
|
183 |
|
184 |
class MyCustomHandler(AsyncCallbackHandler):
|
185 |
def __init__(self):
|
@@ -225,10 +222,6 @@ class ActionRunner:
|
|
225 |
|
226 |
handler = MyCustomHandler()
|
227 |
|
228 |
-
llm = ChatOpenAI(openai_api_key=cred.key,
|
229 |
-
openai_organization=cred.org,
|
230 |
-
temperature=0,
|
231 |
-
model_name=model_name)
|
232 |
llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=[handler])
|
233 |
tool_names = [tool.name for tool in tools]
|
234 |
for tool in tools:
|
|
|
14 |
from langchain.callbacks import get_openai_callback
|
15 |
from langchain.callbacks.base import AsyncCallbackHandler
|
16 |
from langchain.callbacks.manager import AsyncCallbackManager
|
17 |
+
from langchain.base_language import BaseLanguageModel
|
18 |
|
19 |
from autoagents.tools.tools import search_tool, note_tool, rewrite_search_query
|
20 |
from autoagents.utils.logger import InteractionsLogger
|
|
|
21 |
|
22 |
|
23 |
# Set up the base template
|
|
|
123 |
class Config:
|
124 |
arbitrary_types_allowed = True
|
125 |
ialogger: InteractionsLogger
|
126 |
+
llm: BaseLanguageModel
|
127 |
new_action_input: Optional[str]
|
|
|
128 |
action_history = defaultdict(set)
|
129 |
|
130 |
def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
|
|
|
152 |
if action_input in self.action_history[action]:
|
153 |
new_action_input = rewrite_search_query(action_input,
|
154 |
self.action_history[action],
|
155 |
+
self.llm)
|
156 |
self.ialogger.add_message({"query_rewrite": True})
|
157 |
self.new_action_input = new_action_input
|
158 |
self.action_history[action].add(new_action_input)
|
|
|
166 |
class ActionRunner:
|
167 |
def __init__(self,
|
168 |
outputq,
|
169 |
+
llm: BaseLanguageModel,
|
|
|
170 |
persist_logs: bool = False):
|
171 |
self.ialogger = InteractionsLogger(name=f"{uuid.uuid4().hex[:6]}", persist=persist_logs)
|
172 |
tools = [search_tool, note_tool]
|
|
|
176 |
input_variables=["input", "intermediate_steps"],
|
177 |
ialogger=self.ialogger)
|
178 |
|
179 |
+
output_parser = CustomOutputParser(ialogger=self.ialogger, llm=llm)
|
180 |
|
181 |
class MyCustomHandler(AsyncCallbackHandler):
|
182 |
def __init__(self):
|
|
|
222 |
|
223 |
handler = MyCustomHandler()
|
224 |
|
|
|
|
|
|
|
|
|
225 |
llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=[handler])
|
226 |
tool_names = [tool.name for tool in tools]
|
227 |
for tool in tools:
|
autoagents/models/custom.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
|
3 |
+
from langchain.llms.base import LLM
|
4 |
+
|
5 |
+
|
6 |
+
class CustomLLM(LLM):
|
7 |
+
@property
|
8 |
+
def _llm_type(self) -> str:
|
9 |
+
return "custom"
|
10 |
+
|
11 |
+
def _call(self, prompt: str, stop=None) -> str:
|
12 |
+
r = requests.post(
|
13 |
+
"http://localhost:8000/v1/chat/completions",
|
14 |
+
json={
|
15 |
+
"model": "283-vicuna-7b",
|
16 |
+
"messages": [{"role": "user", "content": prompt}],
|
17 |
+
"stop": stop
|
18 |
+
},
|
19 |
+
)
|
20 |
+
result = r.json()
|
21 |
+
return result["choices"][0]["message"]["content"]
|
22 |
+
|
23 |
+
async def _acall(self, prompt: str, stop=None) -> str:
|
24 |
+
r = requests.post(
|
25 |
+
"http://localhost:8000/v1/chat/completions",
|
26 |
+
json={
|
27 |
+
"model": "283-vicuna-7b",
|
28 |
+
"messages": [{"role": "user", "content": prompt}],
|
29 |
+
"stop": stop
|
30 |
+
},
|
31 |
+
)
|
32 |
+
result = r.json()
|
33 |
+
return result["choices"][0]["message"]["content"]
|
autoagents/spaces/app.py
CHANGED
@@ -9,7 +9,9 @@ import openai
|
|
9 |
|
10 |
from autoagents.utils.constants import MAIN_HEADER, MAIN_CAPTION, SAMPLE_QUESTIONS
|
11 |
from autoagents.agents.search import ActionRunner
|
12 |
-
|
|
|
|
|
13 |
|
14 |
async def run():
|
15 |
output_acc = ""
|
@@ -44,10 +46,9 @@ async def run():
|
|
44 |
|
45 |
# Ask the user to enter their OpenAI API key
|
46 |
if (api_key := st.sidebar.text_input("OpenAI api-key", type="password")):
|
47 |
-
|
48 |
else:
|
49 |
-
|
50 |
-
os.getenv("OPENAI_API_ORG"))
|
51 |
with st.sidebar:
|
52 |
model_dict = {
|
53 |
"gpt-3.5-turbo": "GPT-3.5-turbo",
|
@@ -67,18 +68,22 @@ async def run():
|
|
67 |
for q in SAMPLE_QUESTIONS:
|
68 |
st.markdown(f"*{q}*")
|
69 |
|
70 |
-
if not
|
71 |
st.warning(
|
72 |
"API key required to try this app. The API key is not stored in any form. [This](https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key) might help."
|
73 |
)
|
|
|
|
|
|
|
|
|
74 |
else:
|
75 |
outputq = asyncio.Queue()
|
76 |
-
runner = ActionRunner(
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
|
83 |
async def cleanup(e):
|
84 |
st.error(e)
|
|
|
9 |
|
10 |
from autoagents.utils.constants import MAIN_HEADER, MAIN_CAPTION, SAMPLE_QUESTIONS
|
11 |
from autoagents.agents.search import ActionRunner
|
12 |
+
|
13 |
+
from langchain.chat_models import ChatOpenAI
|
14 |
+
|
15 |
|
16 |
async def run():
|
17 |
output_acc = ""
|
|
|
46 |
|
47 |
# Ask the user to enter their OpenAI API key
|
48 |
if (api_key := st.sidebar.text_input("OpenAI api-key", type="password")):
|
49 |
+
api_org = None
|
50 |
else:
|
51 |
+
api_key, api_org = os.getenv("OPENAI_API_KEY"), os.getenv("OPENAI_API_ORG")
|
|
|
52 |
with st.sidebar:
|
53 |
model_dict = {
|
54 |
"gpt-3.5-turbo": "GPT-3.5-turbo",
|
|
|
68 |
for q in SAMPLE_QUESTIONS:
|
69 |
st.markdown(f"*{q}*")
|
70 |
|
71 |
+
if not api_key:
|
72 |
st.warning(
|
73 |
"API key required to try this app. The API key is not stored in any form. [This](https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key) might help."
|
74 |
)
|
75 |
+
elif api_org and st.session_state.model_name == "gpt-4":
|
76 |
+
st.warning(
|
77 |
+
"The free API key does not support GPT-4. Please switch to GPT-3.5-turbo or input your own API key."
|
78 |
+
)
|
79 |
else:
|
80 |
outputq = asyncio.Queue()
|
81 |
+
runner = ActionRunner(outputq,
|
82 |
+
ChatOpenAI(openai_api_key=api_key,
|
83 |
+
openai_organization=api_org,
|
84 |
+
temperature=0,
|
85 |
+
model_name=st.session_state.model_name),
|
86 |
+
persist_logs=True) # log to HF-dataset
|
87 |
|
88 |
async def cleanup(e):
|
89 |
st.error(e)
|
autoagents/tools/tools.py
CHANGED
@@ -3,9 +3,7 @@ import os
|
|
3 |
from duckpy import Client
|
4 |
from langchain import PromptTemplate, OpenAI, LLMChain
|
5 |
from langchain.agents import Tool
|
6 |
-
from langchain.
|
7 |
-
|
8 |
-
from autoagents.utils.utils import OpenAICred
|
9 |
|
10 |
|
11 |
MAX_SEARCH_RESULTS = 20 # Number of search results to observe at a time
|
@@ -55,15 +53,12 @@ note_tool = Tool(name="Notepad",
|
|
55 |
description=notepad_description)
|
56 |
|
57 |
|
58 |
-
def rewrite_search_query(q: str, search_history,
|
59 |
history_string = '\n'.join(search_history)
|
60 |
template ="""We are using the Search tool.
|
61 |
# Previous queries:
|
62 |
{history_string}. \n\n Rewrite query {action_input} to be
|
63 |
different from the previous ones."""
|
64 |
-
llm = ChatOpenAI(temperature=0,
|
65 |
-
openai_api_key=cred.key,
|
66 |
-
openai_organization=cred.org)
|
67 |
prompt = PromptTemplate(template=template,
|
68 |
input_variables=["action_input", "history_string"])
|
69 |
llm_chain = LLMChain(prompt=prompt, llm=llm)
|
|
|
3 |
from duckpy import Client
|
4 |
from langchain import PromptTemplate, OpenAI, LLMChain
|
5 |
from langchain.agents import Tool
|
6 |
+
from langchain.base_language import BaseLanguageModel
|
|
|
|
|
7 |
|
8 |
|
9 |
MAX_SEARCH_RESULTS = 20 # Number of search results to observe at a time
|
|
|
53 |
description=notepad_description)
|
54 |
|
55 |
|
56 |
+
def rewrite_search_query(q: str, search_history, llm: BaseLanguageModel) -> str:
|
57 |
history_string = '\n'.join(search_history)
|
58 |
template ="""We are using the Search tool.
|
59 |
# Previous queries:
|
60 |
{history_string}. \n\n Rewrite query {action_input} to be
|
61 |
different from the previous ones."""
|
|
|
|
|
|
|
62 |
prompt = PromptTemplate(template=template,
|
63 |
input_variables=["action_input", "history_string"])
|
64 |
llm_chain = LLMChain(prompt=prompt, llm=llm)
|
autoagents/utils/utils.py
DELETED
@@ -1,8 +0,0 @@
|
|
1 |
-
from dataclasses import dataclass
|
2 |
-
from typing import Optional
|
3 |
-
|
4 |
-
|
5 |
-
@dataclass
|
6 |
-
class OpenAICred:
|
7 |
-
key: str
|
8 |
-
org: Optional[str]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test.py
CHANGED
@@ -1,17 +1,22 @@
|
|
1 |
import os
|
2 |
import asyncio
|
3 |
-
|
4 |
-
from langchain.callbacks import get_openai_callback
|
5 |
from pprint import pprint
|
6 |
-
import pdb
|
7 |
from ast import literal_eval
|
8 |
from multiprocessing import Pool, TimeoutError
|
9 |
|
|
|
|
|
|
|
|
|
|
|
10 |
async def work(user_input):
|
11 |
outputq = asyncio.Queue()
|
12 |
-
|
13 |
-
|
14 |
-
|
|
|
|
|
15 |
task = asyncio.create_task(runner.run(user_input, outputq))
|
16 |
|
17 |
while True:
|
|
|
1 |
import os
|
2 |
import asyncio
|
3 |
+
|
|
|
4 |
from pprint import pprint
|
|
|
5 |
from ast import literal_eval
|
6 |
from multiprocessing import Pool, TimeoutError
|
7 |
|
8 |
+
from autoagents.agents.search import ActionRunner
|
9 |
+
from langchain.callbacks import get_openai_callback
|
10 |
+
from langchain.chat_models import ChatOpenAI
|
11 |
+
|
12 |
+
|
13 |
async def work(user_input):
|
14 |
outputq = asyncio.Queue()
|
15 |
+
llm = ChatOpenAI(openai_api_key=os.getenv("OPENAI_API_KEY"),
|
16 |
+
openai_organization=os.getenv("OPENAI_API_ORG"),
|
17 |
+
temperature=0,
|
18 |
+
model_name="gpt-3.5-turbo")
|
19 |
+
runner = ActionRunner(outputq, llm=llm)
|
20 |
task = asyncio.create_task(runner.run(user_input, outputq))
|
21 |
|
22 |
while True:
|