Shanshan Wang committed
Commit 471e971
1 Parent(s): b596e22

updated app

Files changed (2)
  1. app.py +235 -4
  2. requirements.txt +9 -0
app.py CHANGED
@@ -1,7 +1,238 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import torch
+ from transformers import AutoModel, AutoTokenizer, AutoImageProcessor
+ import torchvision.transforms as T
+ from PIL import Image
+ import time
+ import os, sys
+ import json
+ import re
+ from tqdm import tqdm
+
+ import pandas as pd
+
+ from torchvision.transforms.functional import InterpolationMode
+ # Path to the model checkpoint
+ path = 'h2oai/h2o-mississippi-2b'
+
+ # Image preprocessing constants (ImageNet normalization)
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
+ IMAGENET_STD = (0.229, 0.224, 0.225)
+
+ start_pre = time.time()
+
+ def build_transform(input_size):
+     MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
+     transform = T.Compose([
+         T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+         T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
+         T.ToTensor(),
+         T.Normalize(mean=MEAN, std=STD)
+     ])
+     return transform
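+
+ # The two helpers below implement dynamic tiling (description inferred from
+ # the code itself): the image is matched to the closest (columns, rows) grid
+ # of image_size x image_size tiles, resized to fill that grid exactly, and
+ # split into tiles so the vision encoder sees the image near its native
+ # aspect ratio.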
+
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
+     # Pick the candidate grid whose aspect ratio best matches the input image;
+     # on a tie, prefer the grid that covers more of the original image area.
+     best_ratio_diff = float('inf')
+     best_ratio = (1, 1)
+     area = width * height
+     for ratio in target_ratios:
+         target_aspect_ratio = ratio[0] / ratio[1]
+         ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+         if ratio_diff < best_ratio_diff:
+             best_ratio_diff = ratio_diff
+             best_ratio = ratio
+         elif ratio_diff == best_ratio_diff:
+             if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                 best_ratio = ratio
+     return best_ratio
+
+ def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False):
+     orig_width, orig_height = image.size
+     aspect_ratio = orig_width / orig_height
+
+     # enumerate candidate (columns, rows) grids with min_num..max_num tiles
+     target_ratios = set(
+         (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+         i * j <= max_num and i * j >= min_num)
+     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+     # find the closest aspect ratio to the target
+     target_aspect_ratio = find_closest_aspect_ratio(
+         aspect_ratio, target_ratios, orig_width, orig_height, image_size)
+
+     # calculate the target width and height
+     target_width = image_size * target_aspect_ratio[0]
+     target_height = image_size * target_aspect_ratio[1]
+     blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+     # resize the image and split it into image_size x image_size tiles
+     resized_img = image.resize((target_width, target_height))
+     processed_images = []
+     for i in range(blocks):
+         box = (
+             (i % (target_width // image_size)) * image_size,
+             (i // (target_width // image_size)) * image_size,
+             ((i % (target_width // image_size)) + 1) * image_size,
+             ((i // (target_width // image_size)) + 1) * image_size
+         )
+         split_img = resized_img.crop(box)
+         processed_images.append(split_img)
+     assert len(processed_images) == blocks
+     if use_thumbnail and len(processed_images) != 1:
+         thumbnail_img = image.resize((image_size, image_size))
+         processed_images.append(thumbnail_img)
+     return processed_images, target_aspect_ratio
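+
+ # Worked example (illustrative): an 896x448 input with image_size=448 and
+ # max_num=6 has aspect ratio 2.0, so the best-fit grid is (2, 1). The image is
+ # resized to 896x448 and cropped into two 448x448 tiles; with
+ # use_thumbnail=True a third 448x448 thumbnail of the full image is appended.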
+
+ def dynamic_preprocess2(image, min_num=1, max_num=6, image_size=448, use_thumbnail=False, prior_aspect_ratio=None):
+     orig_width, orig_height = image.size
+     aspect_ratio = orig_width / orig_height
+
+     # enumerate candidate (columns, rows) grids with min_num..max_num tiles
+     target_ratios = set(
+         (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
+         i * j <= max_num and i * j >= min_num)
+     target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+     # keep only grids that do not evenly divide the first-pass grid in either
+     # dimension, so this pass produces tiles at genuinely different scales
+     new_target_ratios = []
+     if prior_aspect_ratio is not None:
+         for i in target_ratios:
+             if prior_aspect_ratio[0] % i[0] != 0 and prior_aspect_ratio[1] % i[1] != 0:
+                 new_target_ratios.append(i)
+
+     # find the closest aspect ratio to the target
+     target_aspect_ratio = find_closest_aspect_ratio(
+         aspect_ratio, new_target_ratios, orig_width, orig_height, image_size)
+
+     # calculate the target width and height
+     target_width = image_size * target_aspect_ratio[0]
+     target_height = image_size * target_aspect_ratio[1]
+     blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+     # resize the image and split it into image_size x image_size tiles
+     resized_img = image.resize((target_width, target_height))
+     processed_images = []
+     for i in range(blocks):
+         box = (
+             (i % (target_width // image_size)) * image_size,
+             (i // (target_width // image_size)) * image_size,
+             ((i % (target_width // image_size)) + 1) * image_size,
+             ((i // (target_width // image_size)) + 1) * image_size
+         )
+         split_img = resized_img.crop(box)
+         processed_images.append(split_img)
+     assert len(processed_images) == blocks
+     if use_thumbnail and len(processed_images) != 1:
+         thumbnail_img = image.resize((image_size, image_size))
+         processed_images.append(thumbnail_img)
+     return processed_images
+
+ def load_image1(image_file, input_size=448, min_num=1, max_num=12):
+     if isinstance(image_file, str):
+         image = Image.open(image_file).convert('RGB')
+     else:
+         image = image_file
+     transform = build_transform(input_size=input_size)
+     images, target_aspect_ratio = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num)
+     pixel_values = [transform(image) for image in images]
+     pixel_values = torch.stack(pixel_values)
+     return pixel_values, target_aspect_ratio
+
+ def load_image2(image_file, input_size=448, min_num=1, max_num=12, target_aspect_ratio=None):
+     if isinstance(image_file, str):
+         image = Image.open(image_file).convert('RGB')
+     else:
+         image = image_file
+     transform = build_transform(input_size=input_size)
+     images = dynamic_preprocess2(image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num, prior_aspect_ratio=target_aspect_ratio)
+     pixel_values = [transform(image) for image in images]
+     pixel_values = torch.stack(pixel_values)
+     return pixel_values
+
+ def load_image_msac(file_name):
+     # Two-pass multi-scale load: tile at the best-fit grid, then re-tile at a
+     # non-commensurate grid, and keep a single global thumbnail at the end.
+     pixel_values, target_aspect_ratio = load_image1(file_name, min_num=1, max_num=6)
+     pixel_values = pixel_values.to(torch.bfloat16).cuda()
+     pixel_values2 = load_image2(file_name, min_num=3, max_num=6, target_aspect_ratio=target_aspect_ratio)
+     pixel_values2 = pixel_values2.to(torch.bfloat16).cuda()
+     pixel_values = torch.cat([pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], 0)
+     return pixel_values
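+
+ # Sketch of the idea behind load_image_msac, as read from the code: the first
+ # pass tiles at the best-fit grid (up to 6 tiles), the second pass re-tiles at
+ # a grid not commensurate with the first, and the concatenation keeps the
+ # tiles from both passes plus one global thumbnail as the final entry.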
+
+ # Load the model and tokenizer
+ model = AutoModel.from_pretrained(
+     path,
+     torch_dtype=torch.bfloat16,
+     low_cpu_mem_usage=True,
+     trust_remote_code=True
+ ).eval().cuda()
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     path,
+     trust_remote_code=True,
+     use_fast=False
+ )
+ tokenizer.pad_token = tokenizer.unk_token
+ tokenizer.eos_token = "<|end|>"
+ model.generation_config.pad_token_id = tokenizer.pad_token_id
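+ # Note: padding reuses the unk token (presumably because no dedicated pad
+ # token is defined for this checkpoint), and <|end|> is treated as
+ # end-of-sequence, which appears to match the model's chat template.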
+
+
+ def inference(image, prompt):
+     # Require both an image and a prompt
+     if image is None or prompt.strip() == "":
+         return "Please provide both an image and a prompt."
+
+     # Tile the image into multi-scale crops and move them to the GPU
+     pixel_values = load_image_msac(image)
+
+     # Greedy decoding, up to 2048 new tokens
+     generation_config = dict(
+         num_beams=1,
+         max_new_tokens=2048,
+         do_sample=False,
+     )
+
+     # Generate the response
+     response = model.chat(
+         tokenizer,
+         pixel_values,
+         prompt,
+         generation_config
+     )
+
+     return response
+
+ # Build the Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("H2O-Mississippi")
+
+     with gr.Row():
+         image_input = gr.Image(type="pil", label="Upload an Image")
+         prompt_input = gr.Textbox(label="Enter your prompt here")
+
+     response_output = gr.Textbox(label="Model Response")
+
+     with gr.Row():
+         submit_button = gr.Button("Submit")
+         clear_button = gr.Button("Clear")
+
+     # When the submit button is clicked, run inference
+     submit_button.click(
+         fn=inference,
+         inputs=[image_input, prompt_input],
+         outputs=response_output
+     )
+
+     # Reset the image, prompt, and response fields
+     def clear_all():
+         return None, "", ""
+
+     clear_button.click(
+         fn=clear_all,
+         inputs=None,
+         outputs=[image_input, prompt_input, response_output]
+     )
+
+ demo.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ fastapi
+ opencv-python
+ gradio==3.35.2
+ gradio_client==0.2.9
+ httpx==0.24.0
+ markdown2[all]
+ pydantic
+ requests
+ uvicorn
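+ # Editorial note: app.py also imports torch, torchvision, transformers,
+ # pillow, pandas, and tqdm, which are not pinned here; the Space presumably
+ # provides them in its base image.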