Spaces:
Running
on
A10G
Running
on
A10G
Shanshan Wang
committed on
Commit
•
e809d4e
1
Parent(s):
1757eeb
added conversations and parameter options
Browse files
app.py
CHANGED
@@ -3,13 +3,15 @@ from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
|
|
3 |
import torch
|
4 |
import torchvision.transforms as T
|
5 |
from PIL import Image
|
|
|
6 |
|
|
|
7 |
from torchvision.transforms.functional import InterpolationMode
|
8 |
-
# Define the path to your model
|
9 |
import os
|
10 |
from huggingface_hub import login
|
11 |
hf_token = os.environ.get('hf_token', None)
|
12 |
|
|
|
13 |
path = "h2oai/h2o-mississippi-2b"
|
14 |
|
15 |
# image preprocessing
|
@@ -174,48 +176,151 @@ tokenizer.eos_token = "<|end|>"
|
|
174 |
model.generation_config.pad_token_id = tokenizer.pad_token_id
|
175 |
|
176 |
|
177 |
-
def inference(image,
|
178 |
-
#
|
179 |
-
if
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
|
185 |
# Set generation config
|
|
|
|
|
|
|
186 |
generation_config = dict(
|
187 |
num_beams=1,
|
188 |
-
max_new_tokens=
|
189 |
-
do_sample=
|
190 |
-
temperature=temperature,
|
191 |
-
top_p=top_p,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
|
202 |
-
|
|
|
203 |
|
204 |
|
205 |
# Build the Gradio interface
|
206 |
with gr.Blocks() as demo:
|
207 |
-
gr.Markdown("H2O-Mississippi")
|
|
|
|
|
|
|
|
|
208 |
|
209 |
with gr.Row():
|
210 |
-
|
211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
|
213 |
with gr.Accordion('Parameters', open=False):
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
219 |
with gr.Row():
|
220 |
submit_button = gr.Button("Submit")
|
221 |
regenerate_button = gr.Button("Regenerate")
|
@@ -225,24 +330,21 @@ with gr.Blocks() as demo:
|
|
225 |
# When the submit button is clicked, call the inference function
|
226 |
submit_button.click(
|
227 |
fn=inference,
|
228 |
-
inputs=[image_input,
|
229 |
-
outputs=
|
230 |
)
|
231 |
# When the regenerate button is clicked, re-run the last inference
|
232 |
regenerate_button.click(
|
233 |
-
fn=
|
234 |
-
inputs=[
|
235 |
-
outputs=
|
236 |
)
|
237 |
|
238 |
-
# Define the clear button action
|
239 |
-
def clear_all():
|
240 |
-
return None, "", ""
|
241 |
|
242 |
clear_button.click(
|
243 |
fn=clear_all,
|
244 |
inputs=None,
|
245 |
-
outputs=[
|
246 |
)
|
247 |
|
248 |
demo.launch()
|
|
|
3 |
import torch
|
4 |
import torchvision.transforms as T
|
5 |
from PIL import Image
|
6 |
+
import logging
|
7 |
|
8 |
+
logging.basicConfig(level=logging.INFO)
|
9 |
from torchvision.transforms.functional import InterpolationMode
|
|
|
10 |
import os
|
11 |
from huggingface_hub import login
|
12 |
hf_token = os.environ.get('hf_token', None)
|
13 |
|
14 |
+
# Define the path to your model
|
15 |
path = "h2oai/h2o-mississippi-2b"
|
16 |
|
17 |
# image preprocessing
|
|
|
176 |
model.generation_config.pad_token_id = tokenizer.pad_token_id
|
177 |
|
178 |
|
179 |
+
def inference(image, user_message, temperature, top_p, max_new_tokens, chatbot,state, image_state):
|
180 |
+
# if image is provided, store it in image_state:
|
181 |
+
if chatbot is None:
|
182 |
+
chatbot = []
|
183 |
+
|
184 |
+
if image is not None:
|
185 |
+
image_state = load_image_msac(image)
|
186 |
+
else:
|
187 |
+
# If image_state is None, then no image has been provided yet
|
188 |
+
if image_state is None:
|
189 |
+
chatbot.append(("System", "Please provide an image to start the conversation."))
|
190 |
+
return chatbot, state, image_state, ""
|
191 |
+
|
192 |
+
# Initialize history (state) if it's None
|
193 |
+
if state is None:
|
194 |
+
state = None # model.chat function handles None as empty history
|
195 |
+
|
196 |
+
# Append user message to chatbot
|
197 |
+
chatbot.append((user_message, None))
|
198 |
|
199 |
# Set generation config
|
200 |
+
do_sample = (float(temperature) != 0.0)
|
201 |
+
|
202 |
+
|
203 |
generation_config = dict(
|
204 |
num_beams=1,
|
205 |
+
max_new_tokens=int(max_new_tokens),
|
206 |
+
do_sample=do_sample,
|
207 |
+
temperature= float(temperature),
|
208 |
+
top_p= float(top_p),
|
209 |
+
)
|
210 |
+
|
211 |
+
# Call model.chat with history
|
212 |
+
response_text, new_state = model.chat(
|
213 |
+
tokenizer,
|
214 |
+
image_state,
|
215 |
+
user_message,
|
216 |
+
generation_config=generation_config,
|
217 |
+
history=state,
|
218 |
+
return_history=True
|
219 |
)
|
220 |
+
|
221 |
+
# update the state with new_state
|
222 |
+
state = new_state
|
223 |
+
# Update chatbot with the model's response
|
224 |
+
chatbot[-1] = (user_message, response_text)
|
225 |
+
|
226 |
+
return chatbot, state, image_state, ""
|
227 |
+
|
228 |
+
def regenerate_response(chatbot, temperature, top_p, max_new_tokens, state, image_state):
|
229 |
+
|
230 |
+
# Check if there is a previous user message
|
231 |
+
if chatbot is None or len(chatbot) == 0:
|
232 |
+
chatbot = []
|
233 |
+
chatbot.append(("System", "Nothing to regenerate. Please start a conversation first."))
|
234 |
+
return chatbot, state, image_state
|
235 |
+
|
236 |
+
# Check that conversation history and an image are available before regenerating
|
237 |
+
if state is None or image_state is None or len(state) == 0:
|
238 |
+
chatbot.append(("System", "Nothing to regenerate. Please start a conversation first."))
|
239 |
+
return chatbot, state, image_state
|
240 |
+
|
241 |
+
# Get the last user message
|
242 |
+
last_user_message, last_response = chatbot[-1]
|
243 |
+
|
244 |
+
state = state[:-1] # Remove last assistant's response from history
|
245 |
+
|
246 |
+
if len(state) == 0:
|
247 |
+
state = None
|
248 |
+
# Set generation config
|
249 |
+
do_sample = (float(temperature) != 0.0)
|
250 |
|
251 |
+
generation_config = dict(
|
252 |
+
num_beams=1,
|
253 |
+
max_new_tokens=int(max_new_tokens),
|
254 |
+
do_sample=do_sample,
|
255 |
+
temperature= float(temperature),
|
256 |
+
top_p= float(top_p),
|
257 |
)
|
258 |
+
# Regenerate the response
|
259 |
+
response_text, new_state = model.chat(
|
260 |
+
tokenizer,
|
261 |
+
image_state,
|
262 |
+
last_user_message,
|
263 |
+
generation_config=generation_config,
|
264 |
+
history=state, # Exclude last assistant's response
|
265 |
+
return_history=True
|
266 |
+
)
|
267 |
+
|
268 |
+
# Update the state with new_state
|
269 |
+
state = new_state
|
270 |
+
|
271 |
+
# Update chatbot with the regenerated response
|
272 |
+
chatbot.append((last_user_message, response_text))
|
273 |
+
|
274 |
+
return chatbot, state, image_state
|
275 |
+
|
276 |
|
277 |
+
def clear_all():
|
278 |
+
return [], None, None, None # Clear chatbot, state, image_state, image_input
|
279 |
|
280 |
|
281 |
# Build the Gradio interface
|
282 |
with gr.Blocks() as demo:
|
283 |
+
gr.Markdown("# **H2O-Mississippi**")
|
284 |
+
|
285 |
+
state= gr.State()
|
286 |
+
image_state = gr.State()
|
287 |
+
|
288 |
|
289 |
with gr.Row():
|
290 |
+
# First column with image input
|
291 |
+
with gr.Column(scale=1):
|
292 |
+
image_input = gr.Image(type="pil", label="Upload an Image")
|
293 |
+
|
294 |
+
# Second column with chatbot and user input
|
295 |
+
with gr.Column(scale=2):
|
296 |
+
chatbot = gr.Chatbot(label="Conversation")
|
297 |
+
user_input = gr.Textbox(label="What is your question", placeholder="Type your message here")
|
298 |
+
|
299 |
|
300 |
with gr.Accordion('Parameters', open=False):
|
301 |
+
with gr.Row():
|
302 |
+
temperature_input = gr.Slider(
|
303 |
+
minimum=0.0,
|
304 |
+
maximum=1.0,
|
305 |
+
step=0.1,
|
306 |
+
value=0.0,
|
307 |
+
interactive=True,
|
308 |
+
label="Temperature")
|
309 |
+
top_p_input = gr.Slider(
|
310 |
+
minimum=0.0,
|
311 |
+
maximum=1.0,
|
312 |
+
step=0.1,
|
313 |
+
value=0.9,
|
314 |
+
interactive=True,
|
315 |
+
label="Top P")
|
316 |
+
max_new_tokens_input = gr.Slider(
|
317 |
+
minimum=0,
|
318 |
+
maximum=4096,
|
319 |
+
step=64,
|
320 |
+
value=1024,
|
321 |
+
interactive=True,
|
322 |
+
label="Max New Tokens (default: 1024)"
|
323 |
+
)
|
324 |
with gr.Row():
|
325 |
submit_button = gr.Button("Submit")
|
326 |
regenerate_button = gr.Button("Regenerate")
|
|
|
330 |
# When the submit button is clicked, call the inference function
|
331 |
submit_button.click(
|
332 |
fn=inference,
|
333 |
+
inputs=[image_input, user_input, temperature_input, top_p_input, max_new_tokens_input, chatbot, state, image_state],
|
334 |
+
outputs=[chatbot, state, image_state, user_input]
|
335 |
)
|
336 |
# When the regenerate button is clicked, re-run the last inference
|
337 |
regenerate_button.click(
|
338 |
+
fn=regenerate_response,
|
339 |
+
inputs=[chatbot, temperature_input, top_p_input,max_new_tokens_input, state, image_state],
|
340 |
+
outputs=[chatbot, state, image_state]
|
341 |
)
|
342 |
|
|
|
|
|
|
|
343 |
|
344 |
clear_button.click(
|
345 |
fn=clear_all,
|
346 |
inputs=None,
|
347 |
+
outputs=[chatbot, state, image_state, image_input]
|
348 |
)
|
349 |
|
350 |
demo.launch()
|