abdullahmeda committed on
Commit ad7df47
1 Parent(s): f0c73b9

added streaming functionality

Files changed (1)
  1. app.py  +86 -15
app.py CHANGED
@@ -1,21 +1,68 @@
 import gradio as gr
 
+from threading import Thread
+from queue import Queue, Empty
+# from callbacks import StreamingGradioCallbackHandler, job_done
+
+from langchain.schema import SystemMessage
 from langchain.chat_models import ChatOpenAI
 from langchain.chains import ConversationChain
+from langchain.prompts import ChatPromptTemplate
 from langchain.memory import ConversationBufferMemory
+from langchain.callbacks.base import BaseCallbackHandler
+
+# huggingface.co/spaces/huggingface-projects/llama-2-13b-chat
+DEFAULT_SYSTEM_PROMPT = """\
+You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. \
+Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please \
+ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or \
+is not factually coherent, explain why instead of answering something not correct. If you don't know the answer \
+to a question, please don't share false information."""
+
+class QueueCallback(BaseCallbackHandler):
+    """Callback handler for streaming LLM responses to a queue."""
+
+    def __init__(self, q):
+        self.q = q
+
+    def on_llm_new_token(self, token: str, **kwargs) -> None:
+        print(token)
+        self.q.put(token)
+
+    def on_llm_end(self, *args, **kwargs) -> None:
+        print("Done")
+        return self.q.empty()
 
-def respond(openai_api_key, openai_model, message, buffer_memory, chat_history):
+def respond(openai_api_key, openai_model, creativity, max_tokens, message, buffer_memory, chat_history):
+    # print(buffer_memory.buffer)
+    chat_history.append([message, None])
+    q = Queue()
+    job_done = object()
+    callback = QueueCallback(q)
     conversation = ConversationChain(
         llm = ChatOpenAI(
-            temperature=0.0,
-            model=openai_model,
-            openai_api_key=openai_api_key
+            model=openai_model,
+            max_tokens=max_tokens,
+            temperature=creativity,
+            openai_api_key=openai_api_key,
+            streaming=True,
+            callbacks=[callback]
         ),
         memory = buffer_memory
     )
-    response = conversation.predict(input=message)
-    chat_history.append([message, response])
-    return "", buffer_memory, chat_history
+    def task():
+        resp = conversation.predict(input=message)
+        q.put(job_done)
+    thread = Thread(target=task)
+    thread.start()
+    chat_history[-1] = (chat_history[-1][0], "")
+    while True:
+        next_token = q.get(block=True)  # Blocks until an input is available
+        if next_token is job_done:
+            break
+        chat_history[-1] = (chat_history[-1][0], chat_history[-1][1] + next_token)
+        yield "", buffer_memory, chat_history  # Yield the chatbot's response as a string
+    thread.join()
 
 
 with gr.Blocks(css="#component-0 { max-width: 900px; margin: auto; padding-top: 1.5rem; }") as demo:
@@ -24,12 +71,12 @@ with gr.Blocks(css="#component-0 { max-width: 900px; margin: auto; padding-top:
     openai_key = gr.Textbox(
         label="OpenAI Key",
         type="password",
-        placeholder="sk-a83jv6fn3x8ndm78b5W..."
+        placeholder="sk-a83jv6fn3x8ndm78b5W...",
     )
     model = gr.Dropdown(
-        ["gpt-4", "gpt-4-32k",
+        ["gpt-4",
         "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-instruct",
         "text-davinci-002", "text-davinci-003"],
         label="OpenAI Model",
         value="gpt-3.5-turbo",
         interactive=True
@@ -53,9 +100,33 @@ with gr.Blocks(css="#component-0 { max-width: 900px; margin: auto; padding-top:
         scale=1,
         min_width=0)
 
+    with gr.Accordion(label='Advanced options', open=False):
+        system_prompt = gr.Textbox(label='System prompt', value=DEFAULT_SYSTEM_PROMPT, lines=6)
+        max_new_tokens = gr.Slider(
+            label='Max new tokens',
+            minimum=1,
+            maximum=4096,
+            step=1,
+            value=2048,
+        )
+        temperature = gr.Slider(
+            label='Temperature',
+            minimum=0.0,
+            maximum=1.0,
+            step=0.1,
+            value=0.0,
+        )
+        memory_window = gr.Slider(
+            label='Conversation Memory Window',
+            minimum=-1,
+            maximum=10,
+            step=1,
+            value=-1,
+            interactive=True
+        )
+
     # Event Handling
-    query.submit(respond, [openai_key, model, query, memory, chatbot], [query, memory, chatbot])
-    submit.click(respond, [openai_key, model, query, memory, chatbot], [query, memory, chatbot])
+    query.submit(respond, [openai_key, model, temperature, max_new_tokens, query, memory, chatbot], [query, memory, chatbot])
+    submit.click(respond, [openai_key, model, temperature, max_new_tokens, query, memory, chatbot], [query, memory, chatbot])
 
-if __name__ == "__main__":
-    demo.launch()
+demo.queue().launch()
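
The streaming change above follows a common producer/consumer pattern: a LangChain callback pushes each generated token onto a `Queue` from a worker thread, while the Gradio handler drains the queue and yields a progressively longer chat history. Below is a minimal, self-contained sketch of that pattern; the fake `produce_tokens` worker and the `stream_reply` name are illustrative stand-ins for `ChatOpenAI`'s streaming callback and the commit's `respond`, not part of the commit itself.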
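from queue import Queue
from threading import Thread
import time

def stream_reply(message: str):
    q: Queue = Queue()
    job_done = object()  # sentinel object marking that the producer is finished

    def produce_tokens():
        # Stand-in for conversation.predict(...) with a QueueCallback:
        # emit tokens one by one, then signal completion with the sentinel.
        for token in ["Hello", ", ", "world", "!"]:
            time.sleep(0.1)  # simulate generation latency
            q.put(token)
        q.put(job_done)

    thread = Thread(target=produce_tokens)
    thread.start()

    partial = ""
    while True:
        token = q.get(block=True)  # blocks until the producer emits something
        if token is job_done:
            break
        partial += token
        yield partial  # each yield would be one UI update in Gradio
    thread.join()

if __name__ == "__main__":
    for partial in stream_reply("hi"):
        print(partial)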
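Because `respond` is now a generator (it yields partial chat history instead of returning once), the app has to run with Gradio's queue enabled, which is why the final line changes from a plain `demo.launch()` to `demo.queue().launch()`: in the Gradio 3.x releases current at the time of this commit, yielding from an event handler without enabling the queue raises an error.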
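
One loose end in the diff: the new "Conversation Memory Window" slider (`memory_window`) is added to the UI but never passed to `respond`, which still builds the chain from the unbounded `ConversationBufferMemory` held in state. A sketch of how the slider value could select a memory implementation, using LangChain's `ConversationBufferWindowMemory`; the `make_memory` helper is hypothetical, not part of the commit.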
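# Hypothetical wiring for the "Conversation Memory Window" slider.
# ConversationBufferWindowMemory keeps only the last k exchanges, so the
# slider value could cap memory growth; -1 keeps the full history.
from langchain.memory import ConversationBufferMemory, ConversationBufferWindowMemory

def make_memory(memory_window: int):
    if memory_window < 0:
        # slider value -1: retain the entire conversation
        return ConversationBufferMemory()
    return ConversationBufferWindowMemory(k=memory_window)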