erwannd commited on
Commit
337fc0b
1 Parent(s): ff6a7c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -1
app.py CHANGED
@@ -4,8 +4,152 @@ import argparse
4
  import time
5
  import subprocess
6
 
 
7
  import llava.serve.gradio_web_server as gws
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  # Execute the pip install command with additional options
10
  subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])
11
 
@@ -89,7 +233,7 @@ Set the environment variable `model` to change the model:
89
 
90
  exit_status = 0
91
  try:
92
- demo = gws.build_demo(embed_mode=False, cur_dir='./', concurrency_count=concurrency_count)
93
  demo.queue(
94
  status_update_rate=10,
95
  api_open=False
 
4
  import time
5
  import subprocess
6
 
7
+ import gradio as gr
8
  import llava.serve.gradio_web_server as gws
9
 
10
+ def build_demo(embed_mode, cur_dir=None, concurrency_count=10):
11
+ textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
12
+ with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=gws.block_css) as demo:
13
+ state = gr.State()
14
+
15
+ if not embed_mode:
16
+ gr.Markdown(gws.title_markdown)
17
+
18
+ with gr.Row():
19
+ with gr.Column(scale=3):
20
+ with gr.Row(elem_id="model_selector_row"):
21
+ model_selector = gr.Dropdown(
22
+ choices=gws.models,
23
+ value=gws.models[0] if len(gws.models) > 0 else "",
24
+ interactive=True,
25
+ show_label=False,
26
+ container=False)
27
+
28
+ imagebox = gr.Image(type="pil")
29
+ image_process_mode = gr.Radio(
30
+ ["Crop", "Resize", "Pad", "Default"],
31
+ value="Default",
32
+ label="Preprocess for non-square image", visible=False)
33
+
34
+ if cur_dir is None:
35
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
36
+
37
+ user_prompt = "Evaluate and explain if this chart is misleading"
38
+ gr.Examples(examples=[
39
+ [f"{cur_dir}/examples/bar_custom_1.png", user_prompt],
40
+ [f"{cur_dir}/examples/fox_news.jpeg", user_prompt],
41
+ [f"{cur_dir}/examples/bar_wiki.png", user_prompt],
42
+ ], inputs=[imagebox, textbox])
43
+
44
+ with gr.Accordion("Parameters", open=False) as parameter_row:
45
+ temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0, step=0.1, interactive=True, label="Temperature",)
46
+ top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
47
+ max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
48
+
49
+ with gr.Column(scale=8):
50
+ chatbot = gr.Chatbot(
51
+ elem_id="chatbot",
52
+ label="LLaVA Chatbot",
53
+ height=650,
54
+ layout="panel",
55
+ )
56
+ with gr.Row():
57
+ with gr.Column(scale=8):
58
+ textbox.render()
59
+ with gr.Column(scale=1, min_width=50):
60
+ submit_btn = gr.Button(value="Send", variant="primary")
61
+ with gr.Row(elem_id="buttons") as button_row:
62
+ upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
63
+ downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
64
+ flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
65
+ #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
66
+ regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
67
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
68
+
69
+ if not embed_mode:
70
+ gr.Markdown(gws.tos_markdown)
71
+ gr.Markdown(gws.learn_more_markdown)
72
+ url_params = gr.JSON(visible=False)
73
+
74
+ # Register listeners
75
+ btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
76
+ upvote_btn.click(
77
+ gws.upvote_last_response,
78
+ [state, model_selector],
79
+ [textbox, upvote_btn, downvote_btn, flag_btn]
80
+ )
81
+ downvote_btn.click(
82
+ gws.downvote_last_response,
83
+ [state, model_selector],
84
+ [textbox, upvote_btn, downvote_btn, flag_btn]
85
+ )
86
+ flag_btn.click(
87
+ gws.flag_last_response,
88
+ [state, model_selector],
89
+ [textbox, upvote_btn, downvote_btn, flag_btn]
90
+ )
91
+
92
+ regenerate_btn.click(
93
+ gws.regenerate,
94
+ [state, image_process_mode],
95
+ [state, chatbot, textbox, imagebox] + btn_list
96
+ ).then(
97
+ gws.http_bot,
98
+ [state, model_selector, temperature, top_p, max_output_tokens],
99
+ [state, chatbot] + btn_list,
100
+ concurrency_limit=concurrency_count
101
+ )
102
+
103
+ clear_btn.click(
104
+ gws.clear_history,
105
+ None,
106
+ [state, chatbot, textbox, imagebox] + btn_list,
107
+ queue=False
108
+ )
109
+
110
+ textbox.submit(
111
+ gws.add_text,
112
+ [state, textbox, imagebox, image_process_mode],
113
+ [state, chatbot, textbox, imagebox] + btn_list,
114
+ queue=False
115
+ ).then(
116
+ gws.http_bot,
117
+ [state, model_selector, temperature, top_p, max_output_tokens],
118
+ [state, chatbot] + btn_list,
119
+ concurrency_limit=concurrency_count
120
+ )
121
+
122
+ submit_btn.click(
123
+ gws.add_text,
124
+ [state, textbox, imagebox, image_process_mode],
125
+ [state, chatbot, textbox, imagebox] + btn_list
126
+ ).then(
127
+ gws.http_bot,
128
+ [state, model_selector, temperature, top_p, max_output_tokens],
129
+ [state, chatbot] + btn_list,
130
+ concurrency_limit=concurrency_count
131
+ )
132
+
133
+ if gws.args.model_list_mode == "once":
134
+ demo.load(
135
+ gws.load_demo,
136
+ [url_params],
137
+ [state, model_selector],
138
+ js=gws.get_window_url_params
139
+ )
140
+ elif gws.args.model_list_mode == "reload":
141
+ demo.load(
142
+ gws.load_demo_refresh_model_list,
143
+ None,
144
+ [state, model_selector],
145
+ queue=False
146
+ )
147
+ else:
148
+ raise ValueError(f"Unknown model list mode: {gws.args.model_list_mode}")
149
+
150
+ return demo
151
+
152
+
153
  # Execute the pip install command with additional options
154
  subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])
155
 
 
233
 
234
  exit_status = 0
235
  try:
236
+ demo = build_demo(embed_mode=False, cur_dir='./', concurrency_count=concurrency_count)
237
  demo.queue(
238
  status_update_rate=10,
239
  api_open=False