Shanshan Wang committed on
Commit e809d4e
1 Parent(s): 1757eeb

Added conversation support and parameter options

Files changed (1)
  app.py +138 -36
app.py CHANGED
@@ -3,13 +3,15 @@ from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
 import torch
 import torchvision.transforms as T
 from PIL import Image
+import logging
 
+logging.basicConfig(level=logging.INFO)
 from torchvision.transforms.functional import InterpolationMode
-# Define the path to your model
 import os
 from huggingface_hub import login
 hf_token = os.environ.get('hf_token', None)
 
+# Define the path to your model
 path = "h2oai/h2o-mississippi-2b"
 
 # image preprocessing
@@ -174,48 +176,151 @@ tokenizer.eos_token = "<|end|>"
 model.generation_config.pad_token_id = tokenizer.pad_token_id
 
 
-def inference(image, prompt, temperature, top_p):
-    # Check if both image and prompt are provided
-    if image is None or prompt.strip() == "":
-        return "Please provide both an image and a prompt."
-
-    # Process the image and get pixel_values
-    pixel_values = load_image_msac(image)
+def inference(image, user_message, temperature, top_p, max_new_tokens, chatbot, state, image_state):
+    # If an image is provided, store it in image_state
+    if chatbot is None:
+        chatbot = []
+
+    if image is not None:
+        image_state = load_image_msac(image)
+    else:
+        # If image_state is None, then no image has been provided yet
+        if image_state is None:
+            chatbot.append(("System", "Please provide an image to start the conversation."))
+            return chatbot, state, image_state, ""
+
+    # Initialize the history (state) if it's None
+    if state is None:
+        state = None  # the model.chat function treats None as an empty history
+
+    # Append the user message to the chatbot
+    chatbot.append((user_message, None))
 
     # Set generation config
+    do_sample = (float(temperature) != 0.0)
+
+
     generation_config = dict(
         num_beams=1,
-        max_new_tokens=2048,
-        do_sample=False,
-        temperature=temperature,
-        top_p=top_p,
+        max_new_tokens=int(max_new_tokens),
+        do_sample=do_sample,
+        temperature=float(temperature),
+        top_p=float(top_p),
+    )
+
+    # Call model.chat with the conversation history
+    response_text, new_state = model.chat(
+        tokenizer,
+        image_state,
+        user_message,
+        generation_config=generation_config,
+        history=state,
+        return_history=True
     )
+
+    # Update the state with new_state
+    state = new_state
+    # Update the chatbot with the model's response
+    chatbot[-1] = (user_message, response_text)
+
+    return chatbot, state, image_state, ""
+
+def regenerate_response(chatbot, temperature, top_p, max_new_tokens, state, image_state):
+
+    # Check that there is a previous exchange to regenerate
+    if chatbot is None or len(chatbot) == 0:
+        chatbot = []
+        chatbot.append(("System", "Nothing to regenerate. Please start a conversation first."))
+        return chatbot, state, image_state
+
+    # Check that there is prior history and an image
+    if state is None or image_state is None or len(state) == 0:
+        chatbot.append(("System", "Nothing to regenerate. Please start a conversation first."))
+        return chatbot, state, image_state
+
+    # Get the last user message
+    last_user_message, last_response = chatbot[-1]
+
+    state = state[:-1]  # Remove the last assistant response from the history
+
+    if len(state) == 0:
+        state = None
+    # Set generation config
+    do_sample = (float(temperature) != 0.0)
 
-    # Generate the response
-    response = model.chat(
-        tokenizer,
-        pixel_values,
-        prompt,
-        generation_config
+    generation_config = dict(
+        num_beams=1,
+        max_new_tokens=int(max_new_tokens),
+        do_sample=do_sample,
+        temperature=float(temperature),
+        top_p=float(top_p),
     )
+    # Regenerate the response
+    response_text, new_state = model.chat(
+        tokenizer,
+        image_state,
+        last_user_message,
+        generation_config=generation_config,
+        history=state,  # Exclude the last assistant response
+        return_history=True
+    )
+
+    # Update the state with new_state
+    state = new_state
+
+    # Update the chatbot with the regenerated response
+    chatbot.append((last_user_message, response_text))
+
+    return chatbot, state, image_state
+
 
-    return response
+def clear_all():
+    return [], None, None, None  # Clear chatbot, state, image_state, image_input
 
 
 # Build the Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("H2O-Mississippi")
+    gr.Markdown("# **H2O-Mississippi**")
+
+    state = gr.State()
+    image_state = gr.State()
+
 
     with gr.Row():
-        image_input = gr.Image(type="pil", label="Upload an Image")
-        prompt_input = gr.Textbox(label="Enter your prompt here")
+        # First column with the image input
+        with gr.Column(scale=1):
+            image_input = gr.Image(type="pil", label="Upload an Image")
+
+        # Second column with the chatbot and user input
+        with gr.Column(scale=2):
+            chatbot = gr.Chatbot(label="Conversation")
+            user_input = gr.Textbox(label="What is your question", placeholder="Type your message here")
+
 
     with gr.Accordion('Parameters', open=False):
-        temperature_input = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, default=1.0, label="Temperature")
-        top_p_input = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, default=0.9, label="Top-p")
-
-    response_output = gr.Textbox(label="Model Response")
-
+        with gr.Row():
+            temperature_input = gr.Slider(
+                minimum=0.0,
+                maximum=1.0,
+                step=0.1,
+                value=0.0,
+                interactive=True,
+                label="Temperature")
+            top_p_input = gr.Slider(
+                minimum=0.0,
+                maximum=1.0,
+                step=0.1,
+                value=0.9,
+                interactive=True,
+                label="Top P")
+            max_new_tokens_input = gr.Slider(
+                minimum=0,
+                maximum=4096,
+                step=64,
+                value=1024,
+                interactive=True,
+                label="Max New Tokens (default: 1024)"
+            )
     with gr.Row():
         submit_button = gr.Button("Submit")
         regenerate_button = gr.Button("Regenerate")
@@ -225,24 +330,21 @@ with gr.Blocks() as demo:
     # When the submit button is clicked, call the inference function
     submit_button.click(
         fn=inference,
-        inputs=[image_input, prompt_input, temperature_input, top_p_input],
-        outputs=response_output
+        inputs=[image_input, user_input, temperature_input, top_p_input, max_new_tokens_input, chatbot, state, image_state],
+        outputs=[chatbot, state, image_state, user_input]
     )
     # When the regenerate button is clicked, re-run the last inference
     regenerate_button.click(
-        fn=inference,
-        inputs=[image_input, prompt_input, temperature_input, top_p_input],
-        outputs=response_output
+        fn=regenerate_response,
+        inputs=[chatbot, temperature_input, top_p_input, max_new_tokens_input, state, image_state],
+        outputs=[chatbot, state, image_state]
     )
 
-    # Define the clear button action
-    def clear_all():
-        return None, "", ""
 
     clear_button.click(
         fn=clear_all,
         inputs=None,
-        outputs=[image_input, prompt_input, response_output]
+        outputs=[chatbot, state, image_state, image_input]
     )
 
 demo.launch()
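
The conversation support above hinges on model.chat accepting a history argument and handing the updated history back when return_history=True, which is how the new inference and regenerate_response functions thread state between turns. A minimal sketch of that call pattern outside Gradio (the image path and questions are placeholder assumptions; model, tokenizer, and load_image_msac are set up as in app.py):

    pixel_values = load_image_msac("example.jpg")
    generation_config = dict(num_beams=1, max_new_tokens=1024, do_sample=False)

    # First turn: history=None starts a fresh conversation.
    answer, history = model.chat(
        tokenizer, pixel_values, "Describe this image.",
        generation_config=generation_config,
        history=None, return_history=True)

    # Follow-up turn: passing the returned history preserves the context.
    answer, history = model.chat(
        tokenizer, pixel_values, "What colors stand out?",
        generation_config=generation_config,
        history=history, return_history=True)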
 
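On the Gradio side, state and image_state are gr.State() holders: because they appear in both the inputs and outputs of each .click handler, whatever the handler returns for them is stored per session and fed back in on the next click. A self-contained sketch of that round-trip, reduced to a click counter (component names here are illustrative):

    import gradio as gr

    def count(clicks):
        # A gr.State value arrives as a plain Python object; returning a
        # new value here persists it for the next click.
        clicks = (clicks or 0) + 1
        return f"Clicked {clicks} times", clicks

    with gr.Blocks() as demo:
        clicks = gr.State()  # per-session storage, never rendered
        label = gr.Textbox(label="Count")
        button = gr.Button("Click me")
        # The state appears in both inputs and outputs, mirroring app.py's wiring.
        button.click(fn=count, inputs=[clicks], outputs=[label, clicks])

    demo.launch()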