Spaces · Running on A10G
Shanshan Wang committed · Commit 7826ae6
1 Parent(s): 73b2bf3
clean up image_state
app.py CHANGED
@@ -11,8 +11,6 @@ import os
 from huggingface_hub import login
 hf_token = os.environ.get('hf_token', None)
 
-# # Define the path to your model
-# path = "h2oai/h2ovl-mississippi-2b"
 
 # Define the models and their paths
 model_paths = {
@@ -45,21 +43,22 @@ def load_model_and_set_image_function(model_name):
     return model, tokenizer
 
 
-def inference(image,
+def inference(image_input,
               user_message,
               temperature,
               top_p,
               max_new_tokens,
               tile_num,
-              chatbot,
-              image_state,
+              chatbot,
+              state,
+              # image_state,
               model_state,
               tokenizer_state):
 
     # Check if model_state is None
     if model_state is None or tokenizer_state is None:
         chatbot.append(("System", "Please select a model to start the conversation."))
-        return chatbot, state,
+        return chatbot, state, ""
 
     model = model_state
     tokenizer = tokenizer_state
@@ -69,13 +68,9 @@ def inference(image,
     if chatbot is None:
         chatbot = []
 
-    if
-
-
-    # If image_state is None, then no image has been provided yet
-    if image_state is None:
-        chatbot.append(("System", "Please provide an image to start the conversation."))
-        return chatbot, state, image_state, ""
+    if image_input is None:
+        chatbot.append(("System", "Please provide an image to start the conversation."))
+        return chatbot, state, ""
 
     # Initialize history (state) if it's None
     if state is None:
@@ -99,7 +94,7 @@ def inference(image,
     # Call model.chat with history
     response_text, new_state = model.chat(
         tokenizer,
-        image_state,
+        image_input,
         user_message,
         max_tiles = int(tile_num),
         generation_config=generation_config,
@@ -112,7 +107,7 @@ def inference(image,
     # Update chatbot with the model's response
     chatbot[-1] = (user_message, response_text)
 
-    return chatbot, state,
+    return chatbot, state, ""
 
 def regenerate_response(chatbot,
                         temperature,
@@ -120,14 +115,14 @@ def regenerate_response(chatbot,
                         max_new_tokens,
                         tile_num,
                         state,
-                        image_state,
+                        image_input,
                         model_state,
                         tokenizer_state):
 
     # Check if model_state is None
     if model_state is None or tokenizer_state is None:
         chatbot.append(("System", "Please select a model to start the conversation."))
-        return chatbot, state
+        return chatbot, state
 
     model = model_state
     tokenizer = tokenizer_state
@@ -137,19 +132,19 @@ def regenerate_response(chatbot,
     if chatbot is None or len(chatbot) == 0:
         chatbot = []
         chatbot.append(("System", "Nothing to regenerate. Please start a conversation first."))
-        return chatbot, state,
+        return chatbot, state,
 
     # Check if there is a previous user message
-    if state is None or
+    if state is None or len(state) == 0:
         chatbot.append(("System", "Nothing to regenerate. Please start a conversation first."))
-        return chatbot, state
+        return chatbot, state
 
     # Get the last user message
-    last_user_message,
+    last_user_message, _ = chatbot[-1]
 
     state = state[:-1]  # Remove last assistant's response from history
 
-    if len(state) == 0:
+    if len(state) == 0 or not state:
         state = None
     # Set generation config
     do_sample = (float(temperature) != 0.0)
@@ -164,7 +159,7 @@ def regenerate_response(chatbot,
     # Regenerate the response
     response_text, new_state = model.chat(
         tokenizer,
-        image_state,
+        image_input,
         last_user_message,
         max_tiles = int(tile_num),
         generation_config=generation_config,
@@ -178,19 +173,17 @@ def regenerate_response(chatbot,
     # Update chatbot with the regenerated response
     chatbot.append((last_user_message, response_text))
 
-    return chatbot, state
+    return chatbot, state
 
 
 def clear_all():
-    return [], None, None,
-
+    return [], None, None, ""  # Clear chatbot, state, reset image_input
 
 # Build the Gradio interface
 with gr.Blocks() as demo:
     gr.Markdown("# **H2OVL-Mississippi**")
 
     state= gr.State()
-    image_state = gr.State()
     model_state = gr.State()
     tokenizer_state = gr.State()
     image_load_function_state = gr.State()
@@ -212,12 +205,12 @@ with gr.Blocks() as demo:
         # First column with image input
         with gr.Column(scale=1):
             image_input = gr.Image(type="filepath", label="Upload an Image")
+
 
         # Second column with chatbot and user input
         with gr.Column(scale=2):
             chatbot = gr.Chatbot(label="Conversation")
             user_input = gr.Textbox(label="What is your question", placeholder="Type your message here")
-
 
         with gr.Accordion('Parameters', open=False):
             with gr.Row():
@@ -268,11 +261,10 @@ with gr.Blocks() as demo:
             tile_num,
             chatbot,
             state,
-            image_state,
             model_state,
             tokenizer_state
         ],
-        outputs=[chatbot, state,
+        outputs=[chatbot, state, user_input]
    )
    # When the regenerate button is clicked, re-run the last inference
    regenerate_button.click(
@@ -283,18 +275,18 @@ with gr.Blocks() as demo:
            top_p_input,
            max_new_tokens_input,
            tile_num,
-           state,
-           image_state,
+           state,
+           image_input,
            model_state,
            tokenizer_state,
        ],
-       outputs=[chatbot, state
+       outputs=[chatbot, state]
    )
 
    clear_button.click(
        fn=clear_all,
        inputs=None,
-       outputs=[chatbot, state,
+       outputs=[chatbot, state, image_input, user_input]
    )
    gr.Examples(
        examples=[
@@ -307,4 +299,4 @@ with gr.Blocks() as demo:
        label = "examples",
    )
 
-demo.launch()
+demo.launch()
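
For context, the pattern this commit lands on — wiring the gr.Image component straight into the event handlers instead of mirroring the upload into a separate image_state — can be sketched in isolation. This is a minimal sketch, not code from the commit: echo_inference is a hypothetical stand-in for model.chat, and only the input/output wiring mirrors the diff.

# Minimal sketch (hypothetical handler; only the wiring mirrors the commit)
import gradio as gr

def echo_inference(image_input, user_message, chatbot, state):
    chatbot = chatbot or []
    # The image path arrives directly from the gr.Image component,
    # so there is no separate image_state to keep in sync.
    if image_input is None:
        chatbot.append(("System", "Please provide an image to start the conversation."))
        return chatbot, state, ""
    state = (state or []) + [user_message]
    chatbot.append((user_message, f"(echo) {user_message} [image: {image_input}]"))
    # The third return value clears the textbox, matching
    # outputs=[chatbot, state, user_input] below.
    return chatbot, state, ""

with gr.Blocks() as demo:
    state = gr.State()  # chat history only; no image_state
    image_input = gr.Image(type="filepath", label="Upload an Image")
    chatbot = gr.Chatbot(label="Conversation")
    user_input = gr.Textbox(label="What is your question")
    send_button = gr.Button("Send")
    send_button.click(
        echo_inference,
        inputs=[image_input, user_input, chatbot, state],
        outputs=[chatbot, state, user_input],
    )

demo.launch()

The same idea drives the regenerate and clear paths in the diff: regenerate_response reads image_input directly, and clear_all only has to reset the components themselves, since no shadow copy of the image exists.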