Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -4,8 +4,152 @@ import argparse
|
|
4 |
import time
|
5 |
import subprocess
|
6 |
|
|
|
7 |
import llava.serve.gradio_web_server as gws
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
# Execute the pip install command with additional options
|
10 |
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])
|
11 |
|
@@ -89,7 +233,7 @@ Set the environment variable `model` to change the model:
|
|
89 |
|
90 |
exit_status = 0
|
91 |
try:
|
92 |
-
demo =
|
93 |
demo.queue(
|
94 |
status_update_rate=10,
|
95 |
api_open=False
|
|
|
4 |
import time
|
5 |
import subprocess
|
6 |
|
7 |
+
import gradio as gr
|
8 |
import llava.serve.gradio_web_server as gws
|
9 |
|
10 |
+
def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
|
11 |
+
textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
|
12 |
+
with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=gws.block_css) as demo:
|
13 |
+
state = gr.State()
|
14 |
+
|
15 |
+
if not embed_mode:
|
16 |
+
gr.Markdown(gws.title_markdown)
|
17 |
+
|
18 |
+
with gr.Row():
|
19 |
+
with gr.Column(scale=3):
|
20 |
+
with gr.Row(elem_id="model_selector_row"):
|
21 |
+
model_selector = gr.Dropdown(
|
22 |
+
choices=gws.models,
|
23 |
+
value=gws.models[0] if len(gws.models) > 0 else "",
|
24 |
+
interactive=True,
|
25 |
+
show_label=False,
|
26 |
+
container=False)
|
27 |
+
|
28 |
+
imagebox = gr.Image(type="pil")
|
29 |
+
image_process_mode = gr.Radio(
|
30 |
+
["Crop", "Resize", "Pad", "Default"],
|
31 |
+
value="Default",
|
32 |
+
label="Preprocess for non-square image", visible=False)
|
33 |
+
|
34 |
+
if cur_dir is None:
|
35 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
36 |
+
|
37 |
+
user_prompt = "Evaluate and explain if this chart is misleading"
|
38 |
+
gr.Examples(examples=[
|
39 |
+
[f"{cur_dir}/examples/bar_custom_1.png", user_prompt],
|
40 |
+
[f"{cur_dir}/examples/fox_news.jpeg", user_prompt],
|
41 |
+
[f"{cur_dir}/examples/bar_wiki.png", user_prompt],
|
42 |
+
], inputs=[imagebox, textbox])
|
43 |
+
|
44 |
+
with gr.Accordion("Parameters", open=False) as parameter_row:
|
45 |
+
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0, step=0.1, interactive=True, label="Temperature",)
|
46 |
+
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
|
47 |
+
max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
|
48 |
+
|
49 |
+
with gr.Column(scale=8):
|
50 |
+
chatbot = gr.Chatbot(
|
51 |
+
elem_id="chatbot",
|
52 |
+
label="LLaVA Chatbot",
|
53 |
+
height=650,
|
54 |
+
layout="panel",
|
55 |
+
)
|
56 |
+
with gr.Row():
|
57 |
+
with gr.Column(scale=8):
|
58 |
+
textbox.render()
|
59 |
+
with gr.Column(scale=1, min_width=50):
|
60 |
+
submit_btn = gr.Button(value="Send", variant="primary")
|
61 |
+
with gr.Row(elem_id="buttons") as button_row:
|
62 |
+
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
|
63 |
+
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
|
64 |
+
flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
|
65 |
+
#stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
|
66 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
67 |
+
clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
|
68 |
+
|
69 |
+
if not embed_mode:
|
70 |
+
gr.Markdown(gws.tos_markdown)
|
71 |
+
gr.Markdown(gws.learn_more_markdown)
|
72 |
+
url_params = gr.JSON(visible=False)
|
73 |
+
|
74 |
+
# Register listeners
|
75 |
+
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
|
76 |
+
upvote_btn.click(
|
77 |
+
gws.upvote_last_response,
|
78 |
+
[state, model_selector],
|
79 |
+
[textbox, upvote_btn, downvote_btn, flag_btn]
|
80 |
+
)
|
81 |
+
downvote_btn.click(
|
82 |
+
gws.downvote_last_response,
|
83 |
+
[state, model_selector],
|
84 |
+
[textbox, upvote_btn, downvote_btn, flag_btn]
|
85 |
+
)
|
86 |
+
flag_btn.click(
|
87 |
+
gws.flag_last_response,
|
88 |
+
[state, model_selector],
|
89 |
+
[textbox, upvote_btn, downvote_btn, flag_btn]
|
90 |
+
)
|
91 |
+
|
92 |
+
regenerate_btn.click(
|
93 |
+
gws.regenerate,
|
94 |
+
[state, image_process_mode],
|
95 |
+
[state, chatbot, textbox, imagebox] + btn_list
|
96 |
+
).then(
|
97 |
+
gws.http_bot,
|
98 |
+
[state, model_selector, temperature, top_p, max_output_tokens],
|
99 |
+
[state, chatbot] + btn_list,
|
100 |
+
concurrency_limit=concurrency_count
|
101 |
+
)
|
102 |
+
|
103 |
+
clear_btn.click(
|
104 |
+
gws.clear_history,
|
105 |
+
None,
|
106 |
+
[state, chatbot, textbox, imagebox] + btn_list,
|
107 |
+
queue=False
|
108 |
+
)
|
109 |
+
|
110 |
+
textbox.submit(
|
111 |
+
gws.add_text,
|
112 |
+
[state, textbox, imagebox, image_process_mode],
|
113 |
+
[state, chatbot, textbox, imagebox] + btn_list,
|
114 |
+
queue=False
|
115 |
+
).then(
|
116 |
+
gws.http_bot,
|
117 |
+
[state, model_selector, temperature, top_p, max_output_tokens],
|
118 |
+
[state, chatbot] + btn_list,
|
119 |
+
concurrency_limit=concurrency_count
|
120 |
+
)
|
121 |
+
|
122 |
+
submit_btn.click(
|
123 |
+
gws.add_text,
|
124 |
+
[state, textbox, imagebox, image_process_mode],
|
125 |
+
[state, chatbot, textbox, imagebox] + btn_list
|
126 |
+
).then(
|
127 |
+
gws.http_bot,
|
128 |
+
[state, model_selector, temperature, top_p, max_output_tokens],
|
129 |
+
[state, chatbot] + btn_list,
|
130 |
+
concurrency_limit=concurrency_count
|
131 |
+
)
|
132 |
+
|
133 |
+
if gws.args.model_list_mode == "once":
|
134 |
+
demo.load(
|
135 |
+
gws.load_demo,
|
136 |
+
[url_params],
|
137 |
+
[state, model_selector],
|
138 |
+
js=gws.get_window_url_params
|
139 |
+
)
|
140 |
+
elif gws.args.model_list_mode == "reload":
|
141 |
+
demo.load(
|
142 |
+
gws.load_demo_refresh_model_list,
|
143 |
+
None,
|
144 |
+
[state, model_selector],
|
145 |
+
queue=False
|
146 |
+
)
|
147 |
+
else:
|
148 |
+
raise ValueError(f"Unknown model list mode: {gws.args.model_list_mode}")
|
149 |
+
|
150 |
+
return demo
|
151 |
+
|
152 |
+
|
153 |
# Execute the pip install command with additional options
|
154 |
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])
|
155 |
|
|
|
233 |
|
234 |
exit_status = 0
|
235 |
try:
|
236 |
+
demo = build_demo(embed_mode=False, cur_dir='./', concurrency_count=concurrency_count)
|
237 |
demo.queue(
|
238 |
status_update_rate=10,
|
239 |
api_open=False
|