Spaces:
Running
on
A10G
Running
on
A10G
Shanshan Wang
commited on
Commit
•
6c5150b
1
Parent(s):
c65d305
added 0.8b model in the model list
Browse files
app.py
CHANGED
@@ -11,8 +11,15 @@ import os
|
|
11 |
from huggingface_hub import login
|
12 |
hf_token = os.environ.get('hf_token', None)
|
13 |
|
14 |
-
# Define the path to your model
|
15 |
-
path = "h2oai/h2ovl-mississippi-2b"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
# image preprocesing
|
18 |
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
@@ -126,7 +133,7 @@ def dynamic_preprocess2(image, min_num=1, max_num=6, image_size=448, use_thumbna
|
|
126 |
thumbnail_img = image.resize((image_size, image_size))
|
127 |
processed_images.append(thumbnail_img)
|
128 |
return processed_images
|
129 |
-
def load_image1(image_file, input_size=448, min_num=1, max_num=
|
130 |
if isinstance(image_file, str):
|
131 |
image = Image.open(image_file).convert('RGB')
|
132 |
else:
|
@@ -134,7 +141,7 @@ def load_image1(image_file, input_size=448, min_num=1, max_num=12):
|
|
134 |
transform = build_transform(input_size=input_size)
|
135 |
images, target_aspect_ratio = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num)
|
136 |
pixel_values = [transform(image) for image in images]
|
137 |
-
pixel_values = torch.stack(pixel_values)
|
138 |
return pixel_values, target_aspect_ratio
|
139 |
|
140 |
def load_image2(image_file, input_size=448, min_num=1, max_num=12, target_aspect_ratio=None):
|
@@ -146,43 +153,99 @@ def load_image2(image_file, input_size=448, min_num=1, max_num=12, target_aspect
|
|
146 |
transform = build_transform(input_size=input_size)
|
147 |
images = dynamic_preprocess2(image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num, prior_aspect_ratio=target_aspect_ratio)
|
148 |
pixel_values = [transform(image) for image in images]
|
149 |
-
pixel_values = torch.stack(pixel_values)
|
150 |
return pixel_values
|
151 |
|
152 |
def load_image_msac(file_name):
|
153 |
pixel_values, target_aspect_ratio = load_image1(file_name, min_num=1, max_num=6)
|
154 |
-
pixel_values = pixel_values.to(torch.bfloat16).cuda()
|
155 |
pixel_values2 = load_image2(file_name, min_num=3, max_num=6, target_aspect_ratio=target_aspect_ratio)
|
156 |
-
pixel_values2 = pixel_values2.to(torch.bfloat16).cuda()
|
157 |
pixel_values = torch.cat([pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], 0)
|
158 |
return pixel_values
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
tokenizer
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
183 |
|
184 |
if image is not None:
|
185 |
-
image_state =
|
186 |
else:
|
187 |
# If image_state is None, then no image has been provided yet
|
188 |
if image_state is None:
|
@@ -225,8 +288,24 @@ def inference(image, user_message, temperature, top_p, max_new_tokens, chatbot,s
|
|
225 |
|
226 |
return chatbot, state, image_state, ""
|
227 |
|
228 |
-
def regenerate_response(chatbot,
|
229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
# Check if there is a previous user message
|
231 |
if chatbot is None or len(chatbot) == 0:
|
232 |
chatbot = []
|
@@ -284,6 +363,22 @@ with gr.Blocks() as demo:
|
|
284 |
|
285 |
state= gr.State()
|
286 |
image_state = gr.State()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
287 |
|
288 |
with gr.Row(equal_height=True):
|
289 |
# First column with image input
|
@@ -329,13 +424,34 @@ with gr.Blocks() as demo:
|
|
329 |
# When the submit button is clicked, call the inference function
|
330 |
submit_button.click(
|
331 |
fn=inference,
|
332 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
333 |
outputs=[chatbot, state, image_state, user_input]
|
334 |
)
|
335 |
# When the regenerate button is clicked, re-run the last inference
|
336 |
regenerate_button.click(
|
337 |
fn=regenerate_response,
|
338 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
339 |
outputs=[chatbot, state, image_state]
|
340 |
)
|
341 |
|
@@ -347,13 +463,11 @@ with gr.Blocks() as demo:
|
|
347 |
gr.Examples(
|
348 |
examples=[
|
349 |
["assets/driver_license.png", "Extract the text from the image and fill the following json {'license_number':'',\n'full_name':'',\n'date_of_birth':'',\n'address':'',\n'issue_date':'',\n'expiration_date':'',\n}"],
|
350 |
-
|
351 |
["assets/invoice.png", "Please extract the following fields, and return the result in JSON format: supplier_name, supplier_address, customer_name, customer_address, invoice_number, invoice_total_amount, invoice_tax_amount"],
|
352 |
["assets/CBA-1H23-Results-Presentation_wheel.png", "What is the efficiency of H2O.AI in document processing?"],
|
353 |
],
|
354 |
inputs = [image_input, user_input],
|
355 |
-
# outputs = [chatbot, state, image_state, user_input],
|
356 |
-
# fn=inference,
|
357 |
label = "examples",
|
358 |
)
|
359 |
|
|
|
11 |
from huggingface_hub import login
|
12 |
hf_token = os.environ.get('hf_token', None)
|
13 |
|
14 |
+
# # Define the path to your model
|
15 |
+
# path = "h2oai/h2ovl-mississippi-2b"
|
16 |
+
|
17 |
+
# Define the models and their paths
|
18 |
+
model_paths = {
|
19 |
+
"H2OVL-Mississippi-2B":"h2oai/h2ovl-mississippi-2b",
|
20 |
+
"H2OVL-Mississippi-0.8B":"h2oai/h2ovl-mississippi-800m",
|
21 |
+
# Add more models as needed
|
22 |
+
}
|
23 |
|
24 |
# image preprocesing
|
25 |
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
|
|
133 |
thumbnail_img = image.resize((image_size, image_size))
|
134 |
processed_images.append(thumbnail_img)
|
135 |
return processed_images
|
136 |
+
def load_image1(image_file, input_size=448, min_num=1, max_num=6):
|
137 |
if isinstance(image_file, str):
|
138 |
image = Image.open(image_file).convert('RGB')
|
139 |
else:
|
|
|
141 |
transform = build_transform(input_size=input_size)
|
142 |
images, target_aspect_ratio = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num)
|
143 |
pixel_values = [transform(image) for image in images]
|
144 |
+
pixel_values = torch.stack(pixel_values).to(torch.bfloat16).cuda()
|
145 |
return pixel_values, target_aspect_ratio
|
146 |
|
147 |
def load_image2(image_file, input_size=448, min_num=1, max_num=12, target_aspect_ratio=None):
|
|
|
153 |
transform = build_transform(input_size=input_size)
|
154 |
images = dynamic_preprocess2(image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num, prior_aspect_ratio=target_aspect_ratio)
|
155 |
pixel_values = [transform(image) for image in images]
|
156 |
+
pixel_values = torch.stack(pixel_values).to(torch.bfloat16).cuda()
|
157 |
return pixel_values
|
158 |
|
159 |
def load_image_msac(file_name):
|
160 |
pixel_values, target_aspect_ratio = load_image1(file_name, min_num=1, max_num=6)
|
161 |
+
# pixel_values = pixel_values.to(torch.bfloat16).cuda()
|
162 |
pixel_values2 = load_image2(file_name, min_num=3, max_num=6, target_aspect_ratio=target_aspect_ratio)
|
163 |
+
# pixel_values2 = pixel_values2.to(torch.bfloat16).cuda()
|
164 |
pixel_values = torch.cat([pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], 0)
|
165 |
return pixel_values
|
166 |
+
|
167 |
+
|
168 |
+
def load_model_and_set_image_function(model_name):
|
169 |
+
# Get the model path from the model_paths dictionary
|
170 |
+
model_path = model_paths[model_name]
|
171 |
+
|
172 |
+
# Load the model
|
173 |
+
model = AutoModel.from_pretrained(
|
174 |
+
model_path,
|
175 |
+
torch_dtype=torch.bfloat16,
|
176 |
+
low_cpu_mem_usage=True,
|
177 |
+
trust_remote_code=True,
|
178 |
+
use_auth_token=hf_token
|
179 |
+
).eval().cuda()
|
180 |
+
|
181 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
182 |
+
model_path,
|
183 |
+
trust_remote_code=True,
|
184 |
+
use_fast=False,
|
185 |
+
use_auth_token=hf_token
|
186 |
+
)
|
187 |
+
tokenizer.pad_token = tokenizer.unk_token
|
188 |
+
tokenizer.eos_token = "<|end|>"
|
189 |
+
model.generation_config.pad_token_id = tokenizer.pad_token_id
|
190 |
+
|
191 |
+
# Set the appropriate image loading function
|
192 |
+
if "0.8B" in model_name:
|
193 |
+
image_load_function = lambda x: load_image1(x)[0]
|
194 |
+
elif "2B" in model_name:
|
195 |
+
image_load_function = load_image_msac
|
196 |
+
else:
|
197 |
+
image_load_function = load_image1 # Default function
|
198 |
+
|
199 |
+
return model, tokenizer, image_load_function
|
200 |
+
|
201 |
+
|
202 |
+
# # Load the model and tokenizer
|
203 |
+
# model = AutoModel.from_pretrained(
|
204 |
+
# path,
|
205 |
+
# torch_dtype=torch.bfloat16,
|
206 |
+
# low_cpu_mem_usage=True,
|
207 |
+
# trust_remote_code=True,
|
208 |
+
# use_auth_token=hf_token
|
209 |
+
# ).eval().cuda()
|
210 |
+
|
211 |
+
# tokenizer = AutoTokenizer.from_pretrained(
|
212 |
+
# path,
|
213 |
+
# trust_remote_code=True,
|
214 |
+
# use_fast=False,
|
215 |
+
# use_auth_token=hf_token
|
216 |
+
# )
|
217 |
+
# tokenizer.pad_token = tokenizer.unk_token
|
218 |
+
# tokenizer.eos_token = "<|end|>"
|
219 |
+
# model.generation_config.pad_token_id = tokenizer.pad_token_id
|
220 |
+
|
221 |
+
|
222 |
+
def inference(image,
|
223 |
+
user_message,
|
224 |
+
temperature,
|
225 |
+
top_p,
|
226 |
+
max_new_tokens,
|
227 |
+
chatbot,state,
|
228 |
+
image_state,
|
229 |
+
model_state,
|
230 |
+
tokenizer_state,
|
231 |
+
image_load_function_state):
|
232 |
+
|
233 |
+
# Check if model_state is None
|
234 |
+
if model_state is None or tokenizer_state is None:
|
235 |
+
chatbot.append(("System", "Please select a model to start the conversation."))
|
236 |
+
return chatbot, state, image_state, ""
|
237 |
+
|
238 |
+
model = model_state
|
239 |
+
tokenizer = tokenizer_state
|
240 |
+
image_load_function = image_load_function_state
|
241 |
+
|
242 |
+
|
243 |
+
# # if image is provided, store it in image_state:
|
244 |
+
# if chatbot is None:
|
245 |
+
# chatbot = []
|
246 |
|
247 |
if image is not None:
|
248 |
+
image_state = image_load_function(image)
|
249 |
else:
|
250 |
# If image_state is None, then no image has been provided yet
|
251 |
if image_state is None:
|
|
|
288 |
|
289 |
return chatbot, state, image_state, ""
|
290 |
|
291 |
+
def regenerate_response(chatbot,
|
292 |
+
temperature,
|
293 |
+
top_p,
|
294 |
+
max_new_tokens,
|
295 |
+
state,
|
296 |
+
image_state,
|
297 |
+
model_state,
|
298 |
+
tokenizer_state):
|
299 |
+
|
300 |
+
# Check if model_state is None
|
301 |
+
if model_state is None or tokenizer_state is None:
|
302 |
+
chatbot.append(("System", "Please select a model to start the conversation."))
|
303 |
+
return chatbot, state, image_state
|
304 |
+
|
305 |
+
model = model_state
|
306 |
+
tokenizer = tokenizer_state
|
307 |
+
|
308 |
+
|
309 |
# Check if there is a previous user message
|
310 |
if chatbot is None or len(chatbot) == 0:
|
311 |
chatbot = []
|
|
|
363 |
|
364 |
state= gr.State()
|
365 |
image_state = gr.State()
|
366 |
+
model_state = gr.State()
|
367 |
+
tokenizer_state = gr.State()
|
368 |
+
image_load_function_state = gr.State()
|
369 |
+
|
370 |
+
with gr.Row():
|
371 |
+
model_dropdown = gr.Dropdown(
|
372 |
+
choices=list(model_paths.keys()),
|
373 |
+
label="Select Model"
|
374 |
+
)
|
375 |
+
|
376 |
+
# When the model selection changes, load the new model
|
377 |
+
model_dropdown.change(
|
378 |
+
fn=load_model_and_set_image_function,
|
379 |
+
inputs=[model_dropdown],
|
380 |
+
outputs=[model_state, tokenizer_state, image_load_function_state]
|
381 |
+
)
|
382 |
|
383 |
with gr.Row(equal_height=True):
|
384 |
# First column with image input
|
|
|
424 |
# When the submit button is clicked, call the inference function
|
425 |
submit_button.click(
|
426 |
fn=inference,
|
427 |
+
inputs=[
|
428 |
+
image_input,
|
429 |
+
user_input,
|
430 |
+
temperature_input,
|
431 |
+
top_p_input,
|
432 |
+
max_new_tokens_input,
|
433 |
+
chatbot,
|
434 |
+
state,
|
435 |
+
image_state,
|
436 |
+
model_state,
|
437 |
+
tokenizer_state,
|
438 |
+
image_load_function_state
|
439 |
+
],
|
440 |
outputs=[chatbot, state, image_state, user_input]
|
441 |
)
|
442 |
# When the regenerate button is clicked, re-run the last inference
|
443 |
regenerate_button.click(
|
444 |
fn=regenerate_response,
|
445 |
+
inputs=[
|
446 |
+
chatbot,
|
447 |
+
temperature_input,
|
448 |
+
top_p_input,
|
449 |
+
max_new_tokens_input,
|
450 |
+
state,
|
451 |
+
image_state,
|
452 |
+
model_state,
|
453 |
+
tokenizer_state,
|
454 |
+
],
|
455 |
outputs=[chatbot, state, image_state]
|
456 |
)
|
457 |
|
|
|
463 |
gr.Examples(
|
464 |
examples=[
|
465 |
["assets/driver_license.png", "Extract the text from the image and fill the following json {'license_number':'',\n'full_name':'',\n'date_of_birth':'',\n'address':'',\n'issue_date':'',\n'expiration_date':'',\n}"],
|
466 |
+
["assets/receipt.jpg", "Read the text on the image"],
|
467 |
["assets/invoice.png", "Please extract the following fields, and return the result in JSON format: supplier_name, supplier_address, customer_name, customer_address, invoice_number, invoice_total_amount, invoice_tax_amount"],
|
468 |
["assets/CBA-1H23-Results-Presentation_wheel.png", "What is the efficiency of H2O.AI in document processing?"],
|
469 |
],
|
470 |
inputs = [image_input, user_input],
|
|
|
|
|
471 |
label = "examples",
|
472 |
)
|
473 |
|