Rohan Kataria commited on
Commit
3193e49
1 Parent(s): c6e70f1

not focusing more on api state

Browse files
Files changed (2) hide show
  1. app.py +6 -7
  2. src/main.py +18 -11
app.py CHANGED
@@ -1,20 +1,19 @@
1
  import gradio as gr
2
  from src.main import ChatWrapper
3
 
4
- agent = ChatWrapper('openai') # default agent_state
5
 
6
  def update_agent(api_key: str, selection: str):
7
  global agent
8
- agent = ChatWrapper(chain_type=selection)
9
  return agent # This is agent state
10
 
11
- def chat(message, api_key):
12
  global agent
13
- agent(message, api_key) # Get a response to the current message
14
  history = agent.history # Access the entire chat history
15
  return history, history # Return the history twice to update both the chatbot and the state
16
 
17
-
18
  block = gr.Blocks(css=".gradio-container {background-color: lightgray}")
19
 
20
  with block:
@@ -53,8 +52,8 @@ with block:
53
  state = gr.State()
54
  agent_state = gr.State()
55
 
56
- submit.click(chat, inputs=[message, api_key_textbox], outputs=[chatbot, state])
57
- message.submit(chat, inputs=[message, api_key_textbox], outputs=[chatbot, state])
58
 
59
  api_key_textbox.change(update_agent, inputs=[api_key_textbox, selection], outputs=[agent_state])
60
  selection.change(update_agent, inputs=[api_key_textbox, selection], outputs=[agent_state])
 
1
  import gradio as gr
2
  from src.main import ChatWrapper
3
 
4
+ agent = ChatWrapper('openai', '') # default agnet_state
5
 
6
  def update_agent(api_key: str, selection: str):
7
  global agent
8
+ agent = ChatWrapper(chain_type=selection, api_key=api_key)
9
  return agent # This is agent state
10
 
11
+ def chat(message):
12
  global agent
13
+ agent(message) # Get a response to the current message
14
  history = agent.history # Access the entire chat history
15
  return history, history # Return the history twice to update both the chatbot and the state
16
 
 
17
  block = gr.Blocks(css=".gradio-container {background-color: lightgray}")
18
 
19
  with block:
 
52
  state = gr.State()
53
  agent_state = gr.State()
54
 
55
+ submit.click(chat, inputs=[message], outputs=[chatbot, state])
56
+ message.submit(chat, inputs=[message], outputs=[chatbot, state])
57
 
58
  api_key_textbox.change(update_agent, inputs=[api_key_textbox, selection], outputs=[agent_state])
59
  selection.change(update_agent, inputs=[api_key_textbox, selection], outputs=[agent_state])
src/main.py CHANGED
@@ -21,23 +21,29 @@ def load_chain_falcon(api_key: str):
21
  return chain
22
 
23
  class ChatWrapper:
24
- def __init__(self, chain_type: str):
 
25
  self.chain_type = chain_type
26
  self.history = []
27
  self.lock = Lock()
28
- self.chain = None
29
 
30
- def __call__(self, inp: str, api_key: str = ''):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  self.lock.acquire()
32
  try:
33
- if api_key:
34
- if self.chain_type == 'openai':
35
- self.chain = load_chain_openai(api_key)
36
- elif self.chain_type == 'falcon':
37
- self.chain = load_chain_falcon(api_key)
38
- else:
39
- raise ValueError(f'Invalid chain_type: {self.chain_type}')
40
-
41
  if self.chain is None:
42
  self.history.append((inp, "Please add your API key to proceed."))
43
  return self.history
@@ -47,6 +53,7 @@ class ChatWrapper:
47
  except Exception as e:
48
  self.history.append((inp, f"An error occurred: {e}"))
49
  finally:
 
50
  self.lock.release()
51
 
52
  return self.history
 
21
  return chain
22
 
23
  class ChatWrapper:
24
+ def __init__(self, chain_type: str, api_key: str = ''):
25
+ self.api_key = api_key
26
  self.chain_type = chain_type
27
  self.history = []
28
  self.lock = Lock()
 
29
 
30
+ if self.api_key:
31
+ if chain_type == 'openai':
32
+ self.chain = load_chain_openai(self.api_key)
33
+ elif chain_type == 'falcon':
34
+ self.chain = load_chain_falcon(self.api_key)
35
+ else:
36
+ raise ValueError(f'Invalid chain_type: {chain_type}')
37
+ else:
38
+ self.chain = None
39
+
40
+ def clear_api_key(self):
41
+ if hasattr(self, 'api_key'):
42
+ del self.api_key
43
+
44
+ def __call__(self, inp: str):
45
  self.lock.acquire()
46
  try:
 
 
 
 
 
 
 
 
47
  if self.chain is None:
48
  self.history.append((inp, "Please add your API key to proceed."))
49
  return self.history
 
53
  except Exception as e:
54
  self.history.append((inp, f"An error occurred: {e}"))
55
  finally:
56
+ self.clear_api_key() # API key is cleared after running each chain in the class
57
  self.lock.release()
58
 
59
  return self.history