Spaces: Running on T4

Commit: Rework demo UI.

app.py (CHANGED)
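Review summary: this commit retires the two standalone gr.Interface apps (generative and chat) and the gr.TabbedInterface wrapper, keeping them as commented-out code, and rebuilds the same two tabs explicitly with gr.Blocks. The chat callback is split into a user step and a bot step, the chat log moves into a gr.Chatbot component, the RWKV state rides in gr.State, and the demo now launches with share=False.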
```diff
@@ -110,22 +110,22 @@ Arrange the given numbers in ascending order.
     ["Simply put, the theory of relativity states that", 150, 1.0, 0.5, 0.2, 0.2],
 ]
 
-infer_interface = gr.Interface(
-    fn=infer,
-    description=f'''{desc} <b>Please try examples first (bottom of page)</b> (edit them to use your question). Demo limited to ctxlen {ctx_limit}.''',
-    allow_flagging="never",
-    inputs=[
-        gr.Textbox(lines=10, label="Prompt", value="Here's a short cyberpunk sci-fi adventure story. The story's main character is an artificial human created by a company called OpenBot.\n\nThe Story:\n"), # prompt
-        gr.Slider(10, 200, step=10, value=150), # token_count
-        gr.Slider(0.2, 2.0, step=0.1, value=1.0), # temperature
-        gr.Slider(0.0, 1.0, step=0.05, value=0.7), # top_p
-        gr.Slider(0.0, 1.0, step=0.1, value=0.2), # presencePenalty
-        gr.Slider(0.0, 1.0, step=0.1, value=0.2), # countPenalty
-    ],
-    outputs=gr.Textbox(label="Generated Output", lines=28),
-    examples=examples,
-    cache_examples=False,
-).queue()
+# infer_interface = gr.Interface(
+#     fn=infer,
+#     description=f'''{desc} <b>Please try examples first (bottom of page)</b> (edit them to use your question). Demo limited to ctxlen {ctx_limit}.''',
+#     allow_flagging="never",
+#     inputs=[
+#         gr.Textbox(lines=10, label="Prompt", value="Here's a short cyberpunk sci-fi adventure story. The story's main character is an artificial human created by a company called OpenBot.\n\nThe Story:\n"), # prompt
+#         gr.Slider(10, 200, step=10, value=150), # token_count
+#         gr.Slider(0.2, 2.0, step=0.1, value=1.0), # temperature
+#         gr.Slider(0.0, 1.0, step=0.05, value=0.7), # top_p
+#         gr.Slider(0.0, 1.0, step=0.1, value=0.2), # presencePenalty
+#         gr.Slider(0.0, 1.0, step=0.1, value=0.2), # countPenalty
+#     ],
+#     outputs=gr.Textbox(label="Generated Output", lines=28),
+#     examples=examples,
+#     cache_examples=False,
+# ).queue()
 
 ########################################################################################################
 
```
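The generative gr.Interface is commented out rather than deleted; its controls reappear as explicit components in the gr.Blocks layout in the final hunk. For readers unfamiliar with the migration, here is a minimal, self-contained sketch of the Interface-to-Blocks pattern; the `generate` stub and its two widgets are illustrative placeholders, not the Space's `infer` function:

```python
import gradio as gr

# Toy callback standing in for the Space's `infer`; it only echoes the prompt.
def generate(prompt, token_count):
    return f"{prompt} ... ({int(token_count)} tokens would be generated here)"

# The fn/inputs/outputs triple that gr.Interface bundles implicitly,
# wired explicitly in Blocks, which is what this commit does at scale.
with gr.Blocks() as demo:
    prompt = gr.Textbox(lines=4, label="Prompt")
    token_count = gr.Slider(10, 1000, step=10, value=250, label="Max Token")
    output = gr.Textbox(label="Generated Output")
    submit = gr.Button("Submit")
    submit.click(generate, [prompt, token_count], [output])

if __name__ == "__main__":
    demo.queue(max_size=10)
    demo.launch()
```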
```diff
@@ -159,8 +159,12 @@ She also likes to tell {user} a lot about herself and her opinions, and she usua
 
 _, intro_state = model.forward(pipeline.encode(chat_intro), None)
 
+def user(user_message, chatbot):
+    chatbot = chatbot or []
+    return "", chatbot + [[user_message, None]]
+
 def chat(
-    message,
+    chatbot,
     history,
     token_count=10,
     temperature=1.0,
```
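The new `user` helper is the first half of Gradio's two-step chat pattern: it appends `[user_message, None]` to the chat log and clears the textbox, and the reworked `chat` then fills the `None` slot with the bot's reply (see the `.then()` chains in the final hunk). A runnable toy version of the pattern, with a canned reply standing in for the model:

```python
import gradio as gr

# Step 1: record the user's turn with an empty reply slot, clear the box.
def user(user_message, history):
    history = history or []
    return "", history + [[user_message, None]]

# Step 2: fill the reply slot. A real app would run the model here.
def bot(history):
    history[-1][1] = f"You said: {history[-1][0]}"
    return history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Message")
    # queue=False makes the cheap `user` step run immediately,
    # so the textbox clears before the slow `bot` step starts.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(bot, chatbot, chatbot)

if __name__ == "__main__":
    demo.launch()
```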
```diff
@@ -174,6 +178,7 @@ def chat(
     token_ban=[], # ban the generation of some tokens
     token_stop=[]) # stop generation whenever you see any token here
 
+    message = chatbot[-1][0]
     message = message.strip(' ')
     message = message.replace('\n', '')
     ctx = f"{user}{interface} {message}\n\n{bot}{interface}"
```
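With the message now carried inside the chat log rather than passed as its own argument, `chat` recovers it as `chatbot[-1][0]`, the text of the turn that `user` just appended, before the existing stripping of spaces and newlines.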
```diff
@@ -181,9 +186,9 @@
     gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
     print(f'vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}')
 
-    history = history or [
+    history = history or [intro_state, []] # [chat, state, all_tokens]
 
-    [
+    [state, all_tokens] = history
     out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:], state)
 
     begin = len(all_tokens)
```
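The per-session history kept in gr.State becomes a two-element list seeded with the precomputed intro_state on the first turn. Two reviewer notes: the inline comment still reads `[chat, state, all_tokens]` from the old three-element layout even though the list now carries only `[state, all_tokens]`, and the removed lines in this hunk survive only as opening fragments in the page's diff render. A compact sketch of the state threading, where `model` and `pipeline` are stand-ins assumed to behave like the Space's RWKV objects (encode returns token ids, forward returns logits plus a new state):

```python
# Illustrative only: `model.forward` and `pipeline.encode` are assumptions
# shaped like the Space's RWKV calls, not the real implementation.
def chat_step(ctx, history, intro_state, model, pipeline, ctx_limit=1024):
    history = history or [intro_state, []]  # first turn: start from the intro state
    state, all_tokens = history
    out, state = model.forward(pipeline.encode(ctx)[-ctx_limit:], state)
    return out, [state, all_tokens]  # repacked so gr.State carries it to the next turn
```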
```diff
@@ -230,35 +235,80 @@ def chat(
     gc.collect()
     torch.cuda.empty_cache()
 
-
-    history = [
-    return
-
-chat_interface = gr.Interface(
-    fn=chat,
-    description=f'''You are {user}, bot is {bot}.''',
-    allow_flagging="never",
-    inputs = [
-        gr.Textbox(label="Message"),
-        "state",
-        gr.Slider(10, 1000, step=10, value=250), # token_count
-        gr.Slider(0.2, 2.0, step=0.1, value=1.0), # temperature
-        gr.Slider(0.0, 1.0, step=0.05, value=0.8), # top_p
-        gr.Slider(0.0, 1.0, step=0.1, value=0.2), # presence_penalty
-        gr.Slider(0.0, 1.0, step=0.1, value=0.2), # count_penalty
-    ],
-    outputs=[
-        gr.Chatbot(label="Chat Log", color_map=("blue", "pink")),
-        "state"
-    ]
-).queue()
+    chatbot[-1][1] = out_str.strip()
+    history = [state, all_tokens]
+    return chatbot, history
+
+# chat_interface = gr.Interface(
+#     fn=chat,
+#     description=f'''You are {user}, bot is {bot}.''',
+#     allow_flagging="never",
+#     inputs = [
+#         gr.Textbox(label="Message"),
+#         "state",
+#         gr.Slider(10, 1000, step=10, value=250), # token_count
+#         gr.Slider(0.2, 2.0, step=0.1, value=1.0), # temperature
+#         gr.Slider(0.0, 1.0, step=0.05, value=0.8), # top_p
+#         gr.Slider(0.0, 1.0, step=0.1, value=0.2), # presence_penalty
+#         gr.Slider(0.0, 1.0, step=0.1, value=0.2), # count_penalty
+#     ],
+#     outputs=[
+#         gr.Chatbot(label="Chat Log", color_map=("blue", "pink")),
+#         "state"
+#     ]
+# ).queue()
 
 ########################################################################################################
 
-demo = gr.TabbedInterface(
-    [infer_interface, chat_interface], ["Generative", "Chat"],
-    title=title,
-)
+# demo = gr.TabbedInterface(
+#     [infer_interface, chat_interface], ["Generative", "Chat"],
+#     title=title,
+# )
+
+# demo.queue(max_size=10)
+# demo.launch(share=True)
+
+with gr.Blocks() as demo:
+    with gr.Tab("Generative"):
+        with gr.Row():
+            with gr.Column():
+                prompt = gr.Textbox(lines=10, label="Prompt", value="Here's a short cyberpunk sci-fi adventure story. The story's main character is an artificial human created by a company called OpenBot.\n\nThe Story:\n")
+                token_count = gr.Slider(10, 1000, label="Max Token", step=10, value=250)
+                temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
+                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.8)
+                presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.2)
+                count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.2)
+            with gr.Column():
+                with gr.Row():
+                    submit = gr.Button("Submit")
+                    clear = gr.Button("Clear")
+                output = gr.Textbox(label="Generated Output", lines=28)
+        data = gr.Dataset(components=[prompt, token_count, temperature, top_p, presence_penalty, count_penalty], samples=examples, label="Example Prompts", headers=["Prompt", "Max Tokens", "Temperature", "Top P", "Presence Penalty", "Count Penalty"])
+        submit.click(infer, [prompt, token_count, temperature, top_p, presence_penalty, count_penalty], [output])
+        clear.click(lambda: None, [], [output])
+        data.click(lambda x: x, [data], [prompt, token_count, temperature, top_p, presence_penalty, count_penalty])
+    with gr.Tab("Chat"):
+        with gr.Row():
+            with gr.Column():
+                chatbot = gr.Chatbot()
+                state = gr.State()
+                message = gr.Textbox(label="Message")
+                with gr.Row():
+                    send = gr.Button("Send")
+                    clear = gr.Button("Clear")
+            with gr.Column():
+                token_count = gr.Slider(10, 1000, label="Max Token", step=10, value=250)
+                temperature = gr.Slider(0.2, 2.0, label="Temperature", step=0.1, value=1.0)
+                top_p = gr.Slider(0.0, 1.0, label="Top P", step=0.05, value=0.8)
+                presence_penalty = gr.Slider(0.0, 1.0, label="Presence Penalty", step=0.1, value=0.2)
+                count_penalty = gr.Slider(0.0, 1.0, label="Count Penalty", step=0.1, value=0.2)
+        message.submit(user, [message, chatbot], [message, chatbot], queue=False).then(
+            chat, [chatbot, state, token_count, temperature, top_p, presence_penalty, count_penalty], [chatbot, state]
+        )
+        send.click(user, [message, chatbot], [message, chatbot], queue=False).then(
+            chat, [chatbot, state, token_count, temperature, top_p, presence_penalty, count_penalty], [chatbot, state]
+        )
+        clear.click(lambda: ([], None, ""), [], [chatbot, state, message])
 
 demo.queue(max_size=10)
-demo.launch(share=True)
+demo.launch(share=False)
```
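Three details of the new Blocks wiring are worth calling out. Both the Enter key (`message.submit`) and the Send button run the same `user`-then-`chat` chain, so the two entry points cannot drift apart. The Chat tab's Clear button resets all three pieces of per-session state at once by returning `([], None, "")` into the chat log, the RWKV state, and the textbox. And the example rows load through a `gr.Dataset` whose click handler simply echoes the selected sample back into the input components; a minimal standalone version of that pattern, with toy samples in place of the Space's `examples`:

```python
import gradio as gr

# Toy sample rows: one value per component listed in `components`.
samples = [["Hello there", 0.8], ["Tell me a story", 1.2]]

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    temperature = gr.Slider(0.2, 2.0, value=1.0, label="Temperature")
    data = gr.Dataset(components=[prompt, temperature], samples=samples)
    # Clicking a row returns that row; Gradio unpacks it into the outputs.
    data.click(lambda x: x, [data], [prompt, temperature])

if __name__ == "__main__":
    demo.launch()
```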