omkarenator commited on
Commit
f893655
1 Parent(s): a685d80

Make ActionRunner accept custom LLMs (#13)

Browse files
autoagents/agents/search.py CHANGED
@@ -14,11 +14,10 @@ from langchain.schema import AgentAction, AgentFinish
14
  from langchain.callbacks import get_openai_callback
15
  from langchain.callbacks.base import AsyncCallbackHandler
16
  from langchain.callbacks.manager import AsyncCallbackManager
17
-
18
 
19
  from autoagents.tools.tools import search_tool, note_tool, rewrite_search_query
20
  from autoagents.utils.logger import InteractionsLogger
21
- from autoagents.utils.utils import OpenAICred
22
 
23
 
24
  # Set up the base template
@@ -124,9 +123,8 @@ class CustomOutputParser(AgentOutputParser):
124
  class Config:
125
  arbitrary_types_allowed = True
126
  ialogger: InteractionsLogger
127
- cred: OpenAICred
128
  new_action_input: Optional[str]
129
-
130
  action_history = defaultdict(set)
131
 
132
  def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
@@ -154,7 +152,7 @@ class CustomOutputParser(AgentOutputParser):
154
  if action_input in self.action_history[action]:
155
  new_action_input = rewrite_search_query(action_input,
156
  self.action_history[action],
157
- cred)
158
  self.ialogger.add_message({"query_rewrite": True})
159
  self.new_action_input = new_action_input
160
  self.action_history[action].add(new_action_input)
@@ -168,8 +166,7 @@ class CustomOutputParser(AgentOutputParser):
168
  class ActionRunner:
169
  def __init__(self,
170
  outputq,
171
- cred: OpenAICred,
172
- model_name: str,
173
  persist_logs: bool = False):
174
  self.ialogger = InteractionsLogger(name=f"{uuid.uuid4().hex[:6]}", persist=persist_logs)
175
  tools = [search_tool, note_tool]
@@ -179,7 +176,7 @@ class ActionRunner:
179
  input_variables=["input", "intermediate_steps"],
180
  ialogger=self.ialogger)
181
 
182
- output_parser = CustomOutputParser(ialogger=self.ialogger, cred=cred)
183
 
184
  class MyCustomHandler(AsyncCallbackHandler):
185
  def __init__(self):
@@ -225,10 +222,6 @@ class ActionRunner:
225
 
226
  handler = MyCustomHandler()
227
 
228
- llm = ChatOpenAI(openai_api_key=cred.key,
229
- openai_organization=cred.org,
230
- temperature=0,
231
- model_name=model_name)
232
  llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=[handler])
233
  tool_names = [tool.name for tool in tools]
234
  for tool in tools:
 
14
  from langchain.callbacks import get_openai_callback
15
  from langchain.callbacks.base import AsyncCallbackHandler
16
  from langchain.callbacks.manager import AsyncCallbackManager
17
+ from langchain.base_language import BaseLanguageModel
18
 
19
  from autoagents.tools.tools import search_tool, note_tool, rewrite_search_query
20
  from autoagents.utils.logger import InteractionsLogger
 
21
 
22
 
23
  # Set up the base template
 
123
  class Config:
124
  arbitrary_types_allowed = True
125
  ialogger: InteractionsLogger
126
+ llm: BaseLanguageModel
127
  new_action_input: Optional[str]
 
128
  action_history = defaultdict(set)
129
 
130
  def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:
 
152
  if action_input in self.action_history[action]:
153
  new_action_input = rewrite_search_query(action_input,
154
  self.action_history[action],
155
+ self.llm)
156
  self.ialogger.add_message({"query_rewrite": True})
157
  self.new_action_input = new_action_input
158
  self.action_history[action].add(new_action_input)
 
166
  class ActionRunner:
167
  def __init__(self,
168
  outputq,
169
+ llm: BaseLanguageModel,
 
170
  persist_logs: bool = False):
171
  self.ialogger = InteractionsLogger(name=f"{uuid.uuid4().hex[:6]}", persist=persist_logs)
172
  tools = [search_tool, note_tool]
 
176
  input_variables=["input", "intermediate_steps"],
177
  ialogger=self.ialogger)
178
 
179
+ output_parser = CustomOutputParser(ialogger=self.ialogger, llm=llm)
180
 
181
  class MyCustomHandler(AsyncCallbackHandler):
182
  def __init__(self):
 
222
 
223
  handler = MyCustomHandler()
224
 
 
 
 
 
225
  llm_chain = LLMChain(llm=llm, prompt=prompt, callbacks=[handler])
226
  tool_names = [tool.name for tool in tools]
227
  for tool in tools:
autoagents/models/custom.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+
3
+ from langchain.llms.base import LLM
4
+
5
+
6
+ class CustomLLM(LLM):
7
+ @property
8
+ def _llm_type(self) -> str:
9
+ return "custom"
10
+
11
+ def _call(self, prompt: str, stop=None) -> str:
12
+ r = requests.post(
13
+ "http://localhost:8000/v1/chat/completions",
14
+ json={
15
+ "model": "283-vicuna-7b",
16
+ "messages": [{"role": "user", "content": prompt}],
17
+ "stop": stop
18
+ },
19
+ )
20
+ result = r.json()
21
+ return result["choices"][0]["message"]["content"]
22
+
23
+ async def _acall(self, prompt: str, stop=None) -> str:
24
+ r = requests.post(
25
+ "http://localhost:8000/v1/chat/completions",
26
+ json={
27
+ "model": "283-vicuna-7b",
28
+ "messages": [{"role": "user", "content": prompt}],
29
+ "stop": stop
30
+ },
31
+ )
32
+ result = r.json()
33
+ return result["choices"][0]["message"]["content"]
autoagents/spaces/app.py CHANGED
@@ -9,7 +9,9 @@ import openai
9
 
10
  from autoagents.utils.constants import MAIN_HEADER, MAIN_CAPTION, SAMPLE_QUESTIONS
11
  from autoagents.agents.search import ActionRunner
12
- from autoagents.utils.utils import OpenAICred
 
 
13
 
14
  async def run():
15
  output_acc = ""
@@ -44,10 +46,9 @@ async def run():
44
 
45
  # Ask the user to enter their OpenAI API key
46
  if (api_key := st.sidebar.text_input("OpenAI api-key", type="password")):
47
- cred = OpenAICred(api_key, None)
48
  else:
49
- cred = OpenAICred(os.getenv("OPENAI_API_KEY"),
50
- os.getenv("OPENAI_API_ORG"))
51
  with st.sidebar:
52
  model_dict = {
53
  "gpt-3.5-turbo": "GPT-3.5-turbo",
@@ -67,18 +68,22 @@ async def run():
67
  for q in SAMPLE_QUESTIONS:
68
  st.markdown(f"*{q}*")
69
 
70
- if not cred.key:
71
  st.warning(
72
  "API key required to try this app. The API key is not stored in any form. [This](https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key) might help."
73
  )
 
 
 
 
74
  else:
75
  outputq = asyncio.Queue()
76
- runner = ActionRunner(
77
- outputq,
78
- cred=cred,
79
- model_name=st.session_state.model_name,
80
- persist_logs=True,
81
- ) # log to HF-dataset
82
 
83
  async def cleanup(e):
84
  st.error(e)
 
9
 
10
  from autoagents.utils.constants import MAIN_HEADER, MAIN_CAPTION, SAMPLE_QUESTIONS
11
  from autoagents.agents.search import ActionRunner
12
+
13
+ from langchain.chat_models import ChatOpenAI
14
+
15
 
16
  async def run():
17
  output_acc = ""
 
46
 
47
  # Ask the user to enter their OpenAI API key
48
  if (api_key := st.sidebar.text_input("OpenAI api-key", type="password")):
49
+ api_org = None
50
  else:
51
+ api_key, api_org = os.getenv("OPENAI_API_KEY"), os.getenv("OPENAI_API_ORG")
 
52
  with st.sidebar:
53
  model_dict = {
54
  "gpt-3.5-turbo": "GPT-3.5-turbo",
 
68
  for q in SAMPLE_QUESTIONS:
69
  st.markdown(f"*{q}*")
70
 
71
+ if not api_key:
72
  st.warning(
73
  "API key required to try this app. The API key is not stored in any form. [This](https://help.openai.com/en/articles/4936850-where-do-i-find-my-secret-api-key) might help."
74
  )
75
+ elif api_org and st.session_state.model_name == "gpt-4":
76
+ st.warning(
77
+ "The free API key does not support GPT-4. Please switch to GPT-3.5-turbo or input your own API key."
78
+ )
79
  else:
80
  outputq = asyncio.Queue()
81
+ runner = ActionRunner(outputq,
82
+ ChatOpenAI(openai_api_key=api_key,
83
+ openai_organization=api_org,
84
+ temperature=0,
85
+ model_name=st.session_state.model_name),
86
+ persist_logs=True) # log to HF-dataset
87
 
88
  async def cleanup(e):
89
  st.error(e)
autoagents/tools/tools.py CHANGED
@@ -3,9 +3,7 @@ import os
3
  from duckpy import Client
4
  from langchain import PromptTemplate, OpenAI, LLMChain
5
  from langchain.agents import Tool
6
- from langchain.chat_models import ChatOpenAI
7
-
8
- from autoagents.utils.utils import OpenAICred
9
 
10
 
11
  MAX_SEARCH_RESULTS = 20 # Number of search results to observe at a time
@@ -55,15 +53,12 @@ note_tool = Tool(name="Notepad",
55
  description=notepad_description)
56
 
57
 
58
- def rewrite_search_query(q: str, search_history, cred: OpenAICred) -> str:
59
  history_string = '\n'.join(search_history)
60
  template ="""We are using the Search tool.
61
  # Previous queries:
62
  {history_string}. \n\n Rewrite query {action_input} to be
63
  different from the previous ones."""
64
- llm = ChatOpenAI(temperature=0,
65
- openai_api_key=cred.key,
66
- openai_organization=cred.org)
67
  prompt = PromptTemplate(template=template,
68
  input_variables=["action_input", "history_string"])
69
  llm_chain = LLMChain(prompt=prompt, llm=llm)
 
3
  from duckpy import Client
4
  from langchain import PromptTemplate, OpenAI, LLMChain
5
  from langchain.agents import Tool
6
+ from langchain.base_language import BaseLanguageModel
 
 
7
 
8
 
9
  MAX_SEARCH_RESULTS = 20 # Number of search results to observe at a time
 
53
  description=notepad_description)
54
 
55
 
56
+ def rewrite_search_query(q: str, search_history, llm: BaseLanguageModel) -> str:
57
  history_string = '\n'.join(search_history)
58
  template ="""We are using the Search tool.
59
  # Previous queries:
60
  {history_string}. \n\n Rewrite query {action_input} to be
61
  different from the previous ones."""
 
 
 
62
  prompt = PromptTemplate(template=template,
63
  input_variables=["action_input", "history_string"])
64
  llm_chain = LLMChain(prompt=prompt, llm=llm)
autoagents/utils/utils.py DELETED
@@ -1,8 +0,0 @@
1
- from dataclasses import dataclass
2
- from typing import Optional
3
-
4
-
5
- @dataclass
6
- class OpenAICred:
7
- key: str
8
- org: Optional[str]
 
 
 
 
 
 
 
 
 
test.py CHANGED
@@ -1,17 +1,22 @@
1
  import os
2
  import asyncio
3
- from autoagents.agents.search import ActionRunner
4
- from langchain.callbacks import get_openai_callback
5
  from pprint import pprint
6
- import pdb
7
  from ast import literal_eval
8
  from multiprocessing import Pool, TimeoutError
9
 
 
 
 
 
 
10
  async def work(user_input):
11
  outputq = asyncio.Queue()
12
-
13
- API_O = os.getenv("OPENAI_API_KEY")
14
- runner = ActionRunner(outputq, api_key=API_O, model_name="gpt-3.5-turbo")
 
 
15
  task = asyncio.create_task(runner.run(user_input, outputq))
16
 
17
  while True:
 
1
  import os
2
  import asyncio
3
+
 
4
  from pprint import pprint
 
5
  from ast import literal_eval
6
  from multiprocessing import Pool, TimeoutError
7
 
8
+ from autoagents.agents.search import ActionRunner
9
+ from langchain.callbacks import get_openai_callback
10
+ from langchain.chat_models import ChatOpenAI
11
+
12
+
13
  async def work(user_input):
14
  outputq = asyncio.Queue()
15
+ llm = ChatOpenAI(openai_api_key=os.getenv("OPENAI_API_KEY"),
16
+ openai_organization=os.getenv("OPENAI_API_ORG"),
17
+ temperature=0,
18
+ model_name="gpt-3.5-turbo")
19
+ runner = ActionRunner(outputq, llm=llm)
20
  task = asyncio.create_task(runner.run(user_input, outputq))
21
 
22
  while True: