Shanshan Wang committed
Commit
6c5150b
1 Parent(s): c65d305

added 0.8b model in the model list

Files changed (1):
    app.py  +153 -39
app.py CHANGED
@@ -11,8 +11,15 @@ import os
 from huggingface_hub import login
 hf_token = os.environ.get('hf_token', None)
 
-# Define the path to your model
-path = "h2oai/h2ovl-mississippi-2b"
+# # Define the path to your model
+# path = "h2oai/h2ovl-mississippi-2b"
+
+# Define the models and their paths
+model_paths = {
+    "H2OVL-Mississippi-2B":"h2oai/h2ovl-mississippi-2b",
+    "H2OVL-Mississippi-0.8B":"h2oai/h2ovl-mississippi-800m",
+    # Add more models as needed
+}
 
 # image preprocesing
 IMAGENET_MEAN = (0.485, 0.456, 0.406)
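The hunk above swaps the single hard-coded checkpoint path for a name-to-path registry, which is what lets the UI offer both the 2B model and the new 0.8B model. A minimal sketch of the lookup this enables (resolve_model_path is a hypothetical helper for illustration, not part of app.py):

    model_paths = {
        "H2OVL-Mississippi-2B": "h2oai/h2ovl-mississippi-2b",
        "H2OVL-Mississippi-0.8B": "h2oai/h2ovl-mississippi-800m",
    }

    def resolve_model_path(model_name: str) -> str:
        # Raises KeyError for a name the registry does not know about.
        return model_paths[model_name]

    print(resolve_model_path("H2OVL-Mississippi-0.8B"))  # h2oai/h2ovl-mississippi-800m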
@@ -126,7 +133,7 @@ def dynamic_preprocess2(image, min_num=1, max_num=6, image_size=448, use_thumbna
         thumbnail_img = image.resize((image_size, image_size))
         processed_images.append(thumbnail_img)
     return processed_images
-def load_image1(image_file, input_size=448, min_num=1, max_num=12):
+def load_image1(image_file, input_size=448, min_num=1, max_num=6):
     if isinstance(image_file, str):
         image = Image.open(image_file).convert('RGB')
     else:
@@ -134,7 +141,7 @@ def load_image1(image_file, input_size=448, min_num=1, max_num=12):
     transform = build_transform(input_size=input_size)
     images, target_aspect_ratio = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num)
     pixel_values = [transform(image) for image in images]
-    pixel_values = torch.stack(pixel_values)
+    pixel_values = torch.stack(pixel_values).to(torch.bfloat16).cuda()
     return pixel_values, target_aspect_ratio
 
 def load_image2(image_file, input_size=448, min_num=1, max_num=12, target_aspect_ratio=None):
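This hunk, and the matching one in load_image2 below, folds the bfloat16 cast and the GPU transfer into the loaders themselves, so callers such as load_image_msac no longer convert the tensors afterwards. A self-contained sketch of the pattern with dummy tensors (the CUDA guard is added for the sketch; app.py assumes a GPU is present):

    import torch

    tiles = [torch.rand(3, 448, 448) for _ in range(4)]   # stand-ins for transformed crops
    pixel_values = torch.stack(tiles).to(torch.bfloat16)  # list of CHW tensors -> (N, 3, H, W)
    if torch.cuda.is_available():
        pixel_values = pixel_values.cuda()
    print(pixel_values.shape, pixel_values.dtype)  # torch.Size([4, 3, 448, 448]) torch.bfloat16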
@@ -146,43 +153,99 @@ def load_image2(image_file, input_size=448, min_num=1, max_num=12, target_aspect
     transform = build_transform(input_size=input_size)
     images = dynamic_preprocess2(image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num, prior_aspect_ratio=target_aspect_ratio)
     pixel_values = [transform(image) for image in images]
-    pixel_values = torch.stack(pixel_values)
+    pixel_values = torch.stack(pixel_values).to(torch.bfloat16).cuda()
     return pixel_values
 
 def load_image_msac(file_name):
     pixel_values, target_aspect_ratio = load_image1(file_name, min_num=1, max_num=6)
-    pixel_values = pixel_values.to(torch.bfloat16).cuda()
+    # pixel_values = pixel_values.to(torch.bfloat16).cuda()
     pixel_values2 = load_image2(file_name, min_num=3, max_num=6, target_aspect_ratio=target_aspect_ratio)
-    pixel_values2 = pixel_values2.to(torch.bfloat16).cuda()
+    # pixel_values2 = pixel_values2.to(torch.bfloat16).cuda()
     pixel_values = torch.cat([pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], 0)
     return pixel_values
-# Load the model and tokenizer
-model = AutoModel.from_pretrained(
-    path,
-    torch_dtype=torch.bfloat16,
-    low_cpu_mem_usage=True,
-    trust_remote_code=True,
-    use_auth_token=hf_token
-).eval().cuda()
-
-tokenizer = AutoTokenizer.from_pretrained(
-    path,
-    trust_remote_code=True,
-    use_fast=False,
-    use_auth_token=hf_token
-)
-tokenizer.pad_token = tokenizer.unk_token
-tokenizer.eos_token = "<|end|>"
-model.generation_config.pad_token_id = tokenizer.pad_token_id
-
-
-def inference(image, user_message, temperature, top_p, max_new_tokens, chatbot,state, image_state):
-    # if image is provided, store it in image_state:
-    if chatbot is None:
-        chatbot = []
+
+
+def load_model_and_set_image_function(model_name):
+    # Get the model path from the model_paths dictionary
+    model_path = model_paths[model_name]
+
+    # Load the model
+    model = AutoModel.from_pretrained(
+        model_path,
+        torch_dtype=torch.bfloat16,
+        low_cpu_mem_usage=True,
+        trust_remote_code=True,
+        use_auth_token=hf_token
+    ).eval().cuda()
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_path,
+        trust_remote_code=True,
+        use_fast=False,
+        use_auth_token=hf_token
+    )
+    tokenizer.pad_token = tokenizer.unk_token
+    tokenizer.eos_token = "<|end|>"
+    model.generation_config.pad_token_id = tokenizer.pad_token_id
+
+    # Set the appropriate image loading function
+    if "0.8B" in model_name:
+        image_load_function = lambda x: load_image1(x)[0]
+    elif "2B" in model_name:
+        image_load_function = load_image_msac
+    else:
+        image_load_function = load_image1  # Default function
+
+    return model, tokenizer, image_load_function
+
+
+# # Load the model and tokenizer
+# model = AutoModel.from_pretrained(
+#     path,
+#     torch_dtype=torch.bfloat16,
+#     low_cpu_mem_usage=True,
+#     trust_remote_code=True,
+#     use_auth_token=hf_token
+# ).eval().cuda()
+
+# tokenizer = AutoTokenizer.from_pretrained(
+#     path,
+#     trust_remote_code=True,
+#     use_fast=False,
+#     use_auth_token=hf_token
+# )
+# tokenizer.pad_token = tokenizer.unk_token
+# tokenizer.eos_token = "<|end|>"
+# model.generation_config.pad_token_id = tokenizer.pad_token_id
+
+
+def inference(image,
+              user_message,
+              temperature,
+              top_p,
+              max_new_tokens,
+              chatbot,state,
+              image_state,
+              model_state,
+              tokenizer_state,
+              image_load_function_state):
+
+    # Check if model_state is None
+    if model_state is None or tokenizer_state is None:
+        chatbot.append(("System", "Please select a model to start the conversation."))
+        return chatbot, state, image_state, ""
+
+    model = model_state
+    tokenizer = tokenizer_state
+    image_load_function = image_load_function_state
+
+
+    # # if image is provided, store it in image_state:
+    # if chatbot is None:
+    #     chatbot = []
 
     if image is not None:
-        image_state = load_image_msac(image)
+        image_state = image_load_function(image)
     else:
         # If image_state is None, then no image has been provided yet
         if image_state is None:
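load_image_msac runs two preprocessing passes at different tile budgets and merges them. Each pass ends with a thumbnail tile, so the torch.cat drops both thumbnails ([:-1]) and re-appends the second pass's thumbnail as the final tile. A sketch with dummy tensors standing in for the two passes:

    import torch

    pass1 = torch.zeros(5, 3, 448, 448)  # tiles + thumbnail from load_image1
    pass2 = torch.ones(4, 3, 448, 448)   # tiles + thumbnail from load_image2
    merged = torch.cat([pass2[:-1], pass1[:-1], pass2[-1:]], 0)
    print(merged.shape)  # torch.Size([8, 3, 448, 448]): 3 + 4 tiles, then 1 thumbnail

Note that load_model_and_set_image_function picks the loader by substring match on the display name: the 0.8B model gets plain single-pass tiling (load_image1, discarding the returned aspect ratio), while the 2B model keeps the two-pass MSAC path.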
@@ -225,8 +288,24 @@ def inference(image, user_message, temperature, top_p, max_new_tokens, chatbot,s
 
     return chatbot, state, image_state, ""
 
-def regenerate_response(chatbot, temperature, top_p, max_new_tokens, state, image_state):
-
+def regenerate_response(chatbot,
+                        temperature,
+                        top_p,
+                        max_new_tokens,
+                        state,
+                        image_state,
+                        model_state,
+                        tokenizer_state):
+
+    # Check if model_state is None
+    if model_state is None or tokenizer_state is None:
+        chatbot.append(("System", "Please select a model to start the conversation."))
+        return chatbot, state, image_state
+
+    model = model_state
+    tokenizer = tokenizer_state
+
+
     # Check if there is a previous user message
     if chatbot is None or len(chatbot) == 0:
         chatbot = []
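Both handlers now refuse to run until the dropdown has populated the model and tokenizer states. A sketch of that shared guard, factored into a hypothetical helper for clarity (app.py inlines it in each handler); chatbot here is gr.Chatbot's list of (speaker, message) tuples:

    def model_selected(chatbot, model_state, tokenizer_state):
        # Returns (chatbot, ok); ok stays False until a model has been loaded.
        if model_state is None or tokenizer_state is None:
            chatbot = chatbot or []
            chatbot.append(("System", "Please select a model to start the conversation."))
            return chatbot, False
        return chatbot, True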
@@ -284,6 +363,22 @@ with gr.Blocks() as demo:
 
     state= gr.State()
     image_state = gr.State()
+    model_state = gr.State()
+    tokenizer_state = gr.State()
+    image_load_function_state = gr.State()
+
+    with gr.Row():
+        model_dropdown = gr.Dropdown(
+            choices=list(model_paths.keys()),
+            label="Select Model"
+        )
+
+    # When the model selection changes, load the new model
+    model_dropdown.change(
+        fn=load_model_and_set_image_function,
+        inputs=[model_dropdown],
+        outputs=[model_state, tokenizer_state, image_load_function_state]
+    )
 
     with gr.Row(equal_height=True):
         # First column with image input
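The loaded model, tokenizer, and image-loading function live in per-session gr.State slots, and the dropdown's change event fills them. A minimal runnable sketch of that wiring (fake_load is a stand-in for load_model_and_set_image_function):

    import gradio as gr

    def fake_load(name):
        # Stand-in returning two opaque per-session objects, like (model, tokenizer).
        return f"model:{name}", f"tokenizer:{name}"

    with gr.Blocks() as demo:
        model_state = gr.State()
        tokenizer_state = gr.State()
        dropdown = gr.Dropdown(
            choices=["H2OVL-Mississippi-2B", "H2OVL-Mississippi-0.8B"],
            label="Select Model",
        )
        dropdown.change(fn=fake_load, inputs=[dropdown],
                        outputs=[model_state, tokenizer_state])

    # demo.launch()  # uncomment to serve the sketch locally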
@@ -329,13 +424,34 @@ with gr.Blocks() as demo:
     # When the submit button is clicked, call the inference function
     submit_button.click(
         fn=inference,
-        inputs=[image_input, user_input, temperature_input, top_p_input, max_new_tokens_input, chatbot, state, image_state],
+        inputs=[
+            image_input,
+            user_input,
+            temperature_input,
+            top_p_input,
+            max_new_tokens_input,
+            chatbot,
+            state,
+            image_state,
+            model_state,
+            tokenizer_state,
+            image_load_function_state
+        ],
         outputs=[chatbot, state, image_state, user_input]
     )
     # When the regenerate button is clicked, re-run the last inference
     regenerate_button.click(
         fn=regenerate_response,
-        inputs=[chatbot, temperature_input, top_p_input,max_new_tokens_input, state, image_state],
+        inputs=[
+            chatbot,
+            temperature_input,
+            top_p_input,
+            max_new_tokens_input,
+            state,
+            image_state,
+            model_state,
+            tokenizer_state,
+        ],
         outputs=[chatbot, state, image_state]
     )
 
@@ -347,13 +463,11 @@ with gr.Blocks() as demo:
     gr.Examples(
         examples=[
             ["assets/driver_license.png", "Extract the text from the image and fill the following json {'license_number':'',\n'full_name':'',\n'date_of_birth':'',\n'address':'',\n'issue_date':'',\n'expiration_date':'',\n}"],
-            # ["assets/receipt.jpg", "Read the text on the image"],
+            ["assets/receipt.jpg", "Read the text on the image"],
             ["assets/invoice.png", "Please extract the following fields, and return the result in JSON format: supplier_name, supplier_address, customer_name, customer_address, invoice_number, invoice_total_amount, invoice_tax_amount"],
             ["assets/CBA-1H23-Results-Presentation_wheel.png", "What is the efficiency of H2O.AI in document processing?"],
         ],
         inputs = [image_input, user_input],
-        # outputs = [chatbot, state, image_state, user_input],
-        # fn=inference,
         label = "examples",
     )
473