ttengwang committed
Commit · 9a84ec8
1 Parent(s): 863eac9
clean up code, add langchain for chatbox
Browse files (this view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set)
- .gitignore +8 -0
- DejaVuSansCondensed-Bold.ttf +0 -0
- Image/demo1.svg +0 -0
- Image/demo2.svg +0 -0
- Image/title.svg +0 -1
- app.py +441 -311
- app_huggingface.py +0 -268
- app_old.py +0 -261
- app_wo_langchain.py +588 -0
- caas.py +0 -114
- caption_anything/__init__.py +0 -0
- {captioner → caption_anything/captioner}/README.md +0 -0
- {captioner → caption_anything/captioner}/__init__.py +0 -0
- {captioner → caption_anything/captioner}/base_captioner.py +1 -1
- {captioner → caption_anything/captioner}/blip.py +5 -5
- {captioner → caption_anything/captioner}/blip2.py +4 -7
- {captioner → caption_anything/captioner}/git.py +1 -1
- {captioner → caption_anything/captioner}/modeling_blip.py +0 -0
- {captioner → caption_anything/captioner}/modeling_git.py +0 -0
- {captioner → caption_anything/captioner}/vit_pixel_masks_utils.py +0 -0
- caption_anything.py → caption_anything/model.py +78 -63
- caption_anything/segmenter/__init__.py +5 -0
- {segmenter → caption_anything/segmenter}/base_segmenter.py +66 -31
- {segmenter → caption_anything/segmenter}/readme.md +0 -0
- {text_refiner → caption_anything/text_refiner}/README.md +0 -0
- {text_refiner → caption_anything/text_refiner}/__init__.py +1 -1
- {text_refiner → caption_anything/text_refiner}/text_refiner.py +0 -0
- caption_anything/utils/chatbot.py +236 -0
- image_editing_utils.py → caption_anything/utils/image_editing_utils.py +23 -11
- caption_anything/utils/parser.py +29 -0
- caption_anything/utils/utils.py +419 -0
- env.sh +0 -6
- segmenter/__init__.py +0 -5
- segmenter/images/truck.jpg +0 -0
- segmenter/sam_vit_h_4b8939.pth +0 -3
- test_img/img0.png +0 -0
- test_img/img1.jpg +0 -0
- test_img/img1.jpg.raw_mask.png +0 -0
- test_img/img10.jpg +0 -0
- test_img/img10.jpg.raw_mask.png +0 -0
- test_img/img11.jpg +0 -0
- test_img/img12.jpg +0 -0
- test_img/img12.jpg.raw_mask.png +0 -0
- test_img/img13.jpg +0 -0
- test_img/img13.jpg.raw_mask.png +0 -0
- test_img/img14.jpg +0 -0
- test_img/img14.jpg.raw_mask.png +0 -0
- test_img/img15.jpg +0 -0
- test_img/img15.jpg.raw_mask.png +0 -0
- test_img/img16.jpg +0 -0
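Most of the renames above fold the standalone captioner, segmenter and text_refiner packages into a single caption_anything package. As a quick orientation, the import layout after this commit looks roughly like the sketch below (paths taken from the new app.py shown later in this diff; illustrative only, not an exhaustive list):

# Package layout after the restructuring (illustrative; see the app.py diff below)
from caption_anything.model import CaptionAnything
from caption_anything.captioner import build_captioner
from caption_anything.segmenter import build_segmenter
from caption_anything.text_refiner import build_text_refiner
from caption_anything.utils.parser import parse_augment
from caption_anything.utils.utils import mask_painter, seg_model_map, prepare_segmenter
from caption_anything.utils.image_editing_utils import create_bubble_frame
from caption_anything.utils.chatbot import ConversationBot, build_chatbot_tools, get_new_image_name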
.gitignore
CHANGED
@@ -2,6 +2,14 @@ result/
 model_cache/
 *.pth
 teng_grad_start.sh
+*.jpg
+*.jpeg
+*.png
+*.svg
+*.gif
+*.tiff
+*.webp
+
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
DejaVuSansCondensed-Bold.ttf
DELETED
Binary file (632 kB)
Image/demo1.svg
DELETED
Image/demo2.svg
DELETED
Image/title.svg
DELETED
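The app.py rewrite below routes free-form chat through ConversationBot.run_text and primes the agent memory with the uploaded image's caption (see upload_callback in the diff). The new caption_anything/utils/chatbot.py (+236 lines) is not included in this truncated 50-file view, so the following is only a rough, hypothetical sketch of the Visual-ChatGPT-style LangChain wiring such a bot typically wraps; the class layout, prompts and argument names are assumptions, not the actual file:

# Hypothetical sketch of a LangChain-backed ConversationBot (the real
# caption_anything/utils/chatbot.py is not shown in this view).
from langchain.llms import OpenAI
from langchain.agents import initialize_agent, Tool
from langchain.chains.conversation.memory import ConversationBufferMemory

class ConversationBot:
    def __init__(self, tools, api_key=""):
        # `tools` stands in for whatever build_chatbot_tools() returns
        # (assumed to expose name/description/func, as in Visual ChatGPT).
        self.llm = OpenAI(temperature=0, openai_api_key=api_key)
        self.memory = ConversationBufferMemory(memory_key="chat_history", output_key="output")
        self.tools = [Tool(name=t.name, description=t.description, func=t.func) for t in tools]
        self.agent = initialize_agent(
            self.tools,
            self.llm,
            agent="conversational-react-description",
            memory=self.memory,
            return_intermediate_steps=True,
            verbose=True,
        )
        self.current_image = None   # set by upload_callback in app.py
        self.point_prompt = ""      # set by inference_click in app.py

    def run_text(self, text, state, aux_state):
        # app.py expects (chatbot, state, aux_state) back from this call.
        res = self.agent({"input": self.point_prompt + text})
        response = res["output"]
        state = state + [(text, response)]
        return state, state, aux_state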
app.py
CHANGED
@@ -1,85 +1,63 @@
-
-import string
-import gradio as gr
-import requests
-from caption_anything import CaptionAnything
-import torch
+import os
 import json
-import
-import
-from caption_anything import parse_augment
+import PIL
+import gradio as gr
 import numpy as np
-
-
-import
-from
-
-import
-from
+from gradio import processing_utils
+
+from packaging import version
+from PIL import Image, ImageDraw
+
+from caption_anything.model import CaptionAnything
+from caption_anything.utils.image_editing_utils import create_bubble_frame
+from caption_anything.utils.utils import mask_painter, seg_model_map, prepare_segmenter
+from caption_anything.utils.parser import parse_augment
+from caption_anything.captioner import build_captioner
+from caption_anything.text_refiner import build_text_refiner
+from caption_anything.segmenter import build_segmenter
+from caption_anything.utils.chatbot import ConversationBot, build_chatbot_tools, get_new_image_name
 from segment_anything import sam_model_registry
-
-
-
-
-def download_checkpoint(url, folder, filename):
-    os.makedirs(folder, exist_ok=True)
-    filepath = os.path.join(folder, filename)
-
-    if not os.path.exists(filepath):
-        response = requests.get(url, stream=True)
-        with open(filepath, "wb") as f:
-            for chunk in response.iter_content(chunk_size=8192):
-                if chunk:
-                    f.write(chunk)
-
-    return filepath
-
-
-title = """<p><h1 align="center">Caption-Anything</h1></p>
-"""
-description = """<p>Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: https://github.com/ttengwang/Caption-Anything <a href="https://huggingface.co/spaces/TencentARC/Caption-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
-
-examples = [
-    ["test_img/img35.webp"],
-    ["test_img/img2.jpg"],
-    ["test_img/img5.jpg"],
-    ["test_img/img12.jpg"],
-    ["test_img/img14.jpg"],
-    ["test_img/img0.png"],
-    ["test_img/img1.jpg"],
-]
-
-seg_model_map = {
-    'base': 'vit_b',
-    'large': 'vit_l',
-    'huge': 'vit_h'
-}
-ckpt_url_map = {
-    'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
-    'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
-    'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
-}
-os.makedirs('result', exist_ok=True)
+
+
 args = parse_augment()
+if args.segmenter_checkpoint is None:
+    _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
+else:
+    segmenter_checkpoint = args.segmenter_checkpoint
 
-checkpoint_url = ckpt_url_map[seg_model_map[args.segmenter]]
-folder = "segmenter"
-filename = os.path.basename(checkpoint_url)
-args.segmenter_checkpoint = os.path.join(folder, filename)
+
+shared_captioner = build_captioner(args.captioner, args.device, args)
+shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)
+tools_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.chat_tools_dict.split(',')}
+shared_chatbot_tools = build_chatbot_tools(tools_dict)
 
-
-#
- ... (the remaining removed lines of this hunk are truncated in this view)
+
+class ImageSketcher(gr.Image):
+    """
+    Fix the bug of gradio.Image that cannot upload with tool == 'sketch'.
+    """
+
+    is_template = True  # Magic to make this work with gradio.Block, don't remove unless you know what you're doing.
+
+    def __init__(self, **kwargs):
+        super().__init__(tool="sketch", **kwargs)
+
+    def preprocess(self, x):
+        if self.tool == 'sketch' and self.source in ["upload", "webcam"]:
+            assert isinstance(x, dict)
+            if x['mask'] is None:
+                decode_image = processing_utils.decode_base64_to_image(x['image'])
+                width, height = decode_image.size
+                mask = np.zeros((height, width, 4), dtype=np.uint8)
+                mask[..., -1] = 255
+                mask = self.postprocess(mask)
+
+                x['mask'] = mask
+
+        return super().preprocess(x)
+
+
+def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, text_refiner=None,
+                                       session_id=None):
     segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
     captioner = captioner
     if session_id is not None:
@@ -89,17 +67,22 @@ def build_caption_anything_with_models(args, api_key="", captioner=None, sam_mod
 
 def init_openai_api_key(api_key=""):
     text_refiner = None
+    visual_chatgpt = None
     if api_key and len(api_key) > 30:
         try:
             text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
-            text_refiner.llm('hi')
+            text_refiner.llm('hi')  # test
+            visual_chatgpt = ConversationBot(shared_chatbot_tools, api_key)
         except:
             text_refiner = None
+            visual_chatgpt = None
     openai_available = text_refiner is not None
-    return gr.update(visible
+    return gr.update(visible=openai_available), gr.update(visible=openai_available), gr.update(
+        visible=openai_available), gr.update(visible=True), gr.update(visible=True), gr.update(
+        visible=True), text_refiner, visual_chatgpt
 
 
-def
+def get_click_prompt(chat_input, click_state, click_mode):
     inputs = json.loads(chat_input)
     if click_mode == 'Continuous':
         points = click_state[0]
@@ -119,13 +102,14 @@ def get_prompt(chat_input, click_state, click_mode):
         raise NotImplementedError
 
     prompt = {
-        "prompt_type":["click"],
-        "input_point":click_state[0],
-        "input_label":click_state[1],
-        "multimask_output":"True",
+        "prompt_type": ["click"],
+        "input_point": click_state[0],
+        "input_label": click_state[1],
+        "multimask_output": "True",
     }
     return prompt
 
+
 def update_click_state(click_state, caption, click_mode):
     if click_mode == 'Continuous':
         click_state[2].append(caption)
@@ -134,280 +118,426 @@ def update_click_state(click_state, caption, click_mode):
     else:
         raise NotImplementedError
 
-
-
-if
     response = "Text refiner is not initilzed, please input openai api key."
     state = state + [(chat_input, response)]
-    return state, state
-
-
-    # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting!"
-    suffix = '\nHuman: {chat_input}\nAI: '
-    qa_template = '\nHuman: {q}\nAI: {a}'
-    # # "The image is of width {width} and height {height}."
-    point_chat_prompt = "I am an AI trained to chat with you about an image. I am greate at what is going on in any image based on the image information your provide. The overall image description is \"{img_caption}\". You will also provide me objects in the image in details, i.e., their location and visual descriptions. Here are the locations and descriptions of events that happen in the image: {points_with_caps} \n Now, let's chat!"
-    prev_visual_context = ""
-    pos_points = []
-    pos_captions = []
-    for i in range(len(points)):
-        if labels[i] == 1:
-            pos_points.append(f"({points[i][0]}, {points[i][0]})")
-            pos_captions.append(captions[i])
-            prev_visual_context = prev_visual_context + '\n' + 'There is an event described as \"{}\" locating at {}'.format(pos_captions[-1], ', '.join(pos_points))
-
-    context_length_thres = 500
-    prev_history = ""
-    for i in range(len(chat_state)):
-        q, a = chat_state[i]
-        if len(prev_history) < context_length_thres:
-            prev_history = prev_history + qa_template.format(**{"q": q, "a": a})
-        else:
-            break
-    chat_prompt = point_chat_prompt.format(**{"img_caption":img_caption,"points_with_caps": prev_visual_context}) + prev_history + suffix.format(**{"chat_input": chat_input})
-    print('\nchat_prompt: ', chat_prompt)
-    response = text_refiner.llm(chat_prompt)
-    state = state + [(chat_input, response)]
-    chat_state = chat_state + [(chat_input, response)]
-    return state, state, chat_state
-
-def inference_seg_cap(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
-                      length, image_embedding, state, click_state, original_size, input_size, text_refiner, evt:gr.SelectData):
 
     model = build_caption_anything_with_models(
         args,
         api_key="",
         captioner=shared_captioner,
         sam_model=shared_sam_model,
-        text_refiner=text_refiner,
         session_id=iface.app_id
     )
-
- ... (removed lines 186-189 are truncated in this view)
 
     if point_prompt == 'Positive':
-        coordinate = "[[{}, {}, 1]]".format(str(
     else:
-        coordinate = "[[{}, {}, 0]]".format(str(
 
     controls = {'length': length,
                 'sentiment': sentiment,
                 'factuality': factuality,
                 'language': language}
 
- ... (removed lines 201-206 are truncated in this view)
 
     enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
     out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)
     state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
-    # for k, v in out['generated_captions'].items():
-    #     state = state + [(f'{k}: {v}', None)]
     state = state + [(None, "raw_caption: {}".format(out['generated_captions']['raw_caption']))]
     wiki = out['generated_captions'].get('wiki', "")
-
     update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
     text = out['generated_captions']['raw_caption']
-    # draw = ImageDraw.Draw(image_input)
-    # draw.text((evt.index[0], evt.index[1]), text, textcolor=(0,0,255), text_size=120)
     input_mask = np.array(out['mask'].convert('P'))
     image_input = mask_painter(np.array(image_input), input_mask)
     origin_image_input = image_input
-    image_input = create_bubble_frame(image_input, text, (
-
     if not args.disable_gpt and model.text_refiner:
-        refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
         # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
         new_cap = refined_caption['caption']
         wiki = refined_caption['wiki']
         state = state + [(None, f"caption: {new_cap}")]
-        refined_image_input = create_bubble_frame(origin_image_input, new_cap, (
 
 
-def
- ... (removed lines 237-294 are truncated in this view; only fragments such as "image_embedding =", "'''" and ")" survive)
                             interactive=True)
-
-
-
-
-        with gr.Row(scale=1.0):
-            language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
-            sentiment = gr.Radio(
-                choices=["Positive", "Natural", "Negative"],
-                value="Natural",
-                label="Sentiment",
-                interactive=True,
-            )
-        with gr.Row(scale=1.0):
-            factuality = gr.Radio(
-                choices=["Factual", "Imagination"],
-                value="Factual",
-                label="Factuality",
-                interactive=True,
     )
- ... (removed lines 315-413 are truncated in this view)
+def chat_input_callback(*args):
+    visual_chatgpt, chat_input, click_state, state, aux_state = args
+    if visual_chatgpt is not None:
+        return visual_chatgpt.run_text(chat_input, state, aux_state)
+    else:
        response = "Text refiner is not initilzed, please input openai api key."
        state = state + [(chat_input, response)]
+        return state, state
+
+def upload_callback(image_input, state, visual_chatgpt=None):
 
+    if isinstance(image_input, dict):  # if upload from sketcher_input, input contains image and mask
+        image_input, mask = image_input['image'], image_input['mask']
+
+    click_state = [[], [], []]
+    res = 1024
+    width, height = image_input.size
+    ratio = min(1.0 * res / max(width, height), 1.0)
+    if ratio < 1.0:
+        image_input = image_input.resize((int(width * ratio), int(height * ratio)))
+        print('Scaling input image to {}'.format(image_input.size))
+
    model = build_caption_anything_with_models(
        args,
        api_key="",
        captioner=shared_captioner,
        sam_model=shared_sam_model,
        session_id=iface.app_id
    )
+    model.segmenter.set_image(image_input)
+    image_embedding = model.image_embedding
+    original_size = model.original_size
+    input_size = model.input_size
+
+    if visual_chatgpt is not None:
+        new_image_path = get_new_image_name('chat_image', func_name='upload')
+        image_input.save(new_image_path)
+        visual_chatgpt.current_image = new_image_path
+        img_caption, _ = model.captioner.inference_seg(image_input)
+        Human_prompt = f'\nHuman: provide a new figure with path {new_image_path}. The description is: {img_caption}. This information helps you to understand this image, but you should use tools to finish following tasks, rather than directly imagine from my description. If you understand, say \"Received\". \n'
+        AI_prompt = "Received."
+        visual_chatgpt.agent.memory.buffer = visual_chatgpt.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
+    state = [(None, 'Received new image, resize it to width {} and height {}: '.format(image_input.size[0], image_input.size[1]))]
+
+    return state, state, image_input, click_state, image_input, image_input, image_embedding, \
+           original_size, input_size
+
+
+def inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
+                    length, image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt,
+                    evt: gr.SelectData):
+    click_index = evt.index
 
    if point_prompt == 'Positive':
+        coordinate = "[[{}, {}, 1]]".format(str(click_index[0]), str(click_index[1]))
    else:
+        coordinate = "[[{}, {}, 0]]".format(str(click_index[0]), str(click_index[1]))
+
+    prompt = get_click_prompt(coordinate, click_state, click_mode)
+    input_points = prompt['input_point']
+    input_labels = prompt['input_label']
 
    controls = {'length': length,
                'sentiment': sentiment,
                'factuality': factuality,
                'language': language}
 
+    model = build_caption_anything_with_models(
+        args,
+        api_key="",
+        captioner=shared_captioner,
+        sam_model=shared_sam_model,
+        text_refiner=text_refiner,
+        session_id=iface.app_id
+    )
+
+    model.setup(image_embedding, original_size, input_size, is_image_set=True)
 
    enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
    out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)
+
    state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
    state = state + [(None, "raw_caption: {}".format(out['generated_captions']['raw_caption']))]
    wiki = out['generated_captions'].get('wiki', "")
    update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
    text = out['generated_captions']['raw_caption']
    input_mask = np.array(out['mask'].convert('P'))
    image_input = mask_painter(np.array(image_input), input_mask)
    origin_image_input = image_input
+    image_input = create_bubble_frame(image_input, text, (click_index[0], click_index[1]), input_mask,
+                                      input_points=input_points, input_labels=input_labels)
+    x, y = input_points[-1]
+
+    if visual_chatgpt is not None:
+        new_crop_save_path = get_new_image_name('chat_image', func_name='crop')
+        Image.open(out["crop_save_path"]).save(new_crop_save_path)
+        point_prompt = f'You should primarly use tools on the selected regional image (description: {text}, path: {new_crop_save_path}), which is a part of the whole image (path: {visual_chatgpt.current_image}). If human mentioned some objects not in the selected region, you can use tools on the whole image.'
+        visual_chatgpt.point_prompt = point_prompt
+
+    yield state, state, click_state, image_input, wiki
    if not args.disable_gpt and model.text_refiner:
+        refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
+                                                       enable_wiki=enable_wiki)
        # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
        new_cap = refined_caption['caption']
        wiki = refined_caption['wiki']
        state = state + [(None, f"caption: {new_cap}")]
+        refined_image_input = create_bubble_frame(origin_image_input, new_cap, (click_index[0], click_index[1]),
+                                                  input_mask,
+                                                  input_points=input_points, input_labels=input_labels)
+        yield state, state, click_state, refined_image_input, wiki
 
 
+def get_sketch_prompt(mask: PIL.Image.Image):
+    """
+    Get the prompt for the sketcher.
+    TODO: This is a temporary solution. We should cluster the sketch and get the bounding box of each cluster.
+    """
+
+    mask = np.asarray(mask)[..., 0]
+
+    # Get the bounding box of the sketch
+    y, x = np.where(mask != 0)
+    x1, y1 = np.min(x), np.min(y)
+    x2, y2 = np.max(x), np.max(y)
+
+    prompt = {
+        'prompt_type': ['box'],
+        'input_boxes': [
+            [x1, y1, x2, y2]
+        ]
+    }
+
+    return prompt
+
+
+def inference_traject(sketcher_image, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
+                      original_size, input_size, text_refiner):
+    image_input, mask = sketcher_image['image'], sketcher_image['mask']
+
+    prompt = get_sketch_prompt(mask)
+    boxes = prompt['input_boxes']
+
+    controls = {'length': length,
+                'sentiment': sentiment,
+                'factuality': factuality,
+                'language': language}
+
    model = build_caption_anything_with_models(
        args,
        api_key="",
        captioner=shared_captioner,
        sam_model=shared_sam_model,
+        text_refiner=text_refiner,
        session_id=iface.app_id
    )
+
+    model.setup(image_embedding, original_size, input_size, is_image_set=True)
+
+    enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
+    out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)
+
+    # Update components and states
+    state.append((f'Box: {boxes}', None))
+    state.append((None, f'raw_caption: {out["generated_captions"]["raw_caption"]}'))
+    wiki = out['generated_captions'].get('wiki', "")
+    text = out['generated_captions']['raw_caption']
+    input_mask = np.array(out['mask'].convert('P'))
+    image_input = mask_painter(np.array(image_input), input_mask)
+
+    origin_image_input = image_input
+
+    fake_click_index = (int((boxes[0][0] + boxes[0][2]) / 2), int((boxes[0][1] + boxes[0][3]) / 2))
+    image_input = create_bubble_frame(image_input, text, fake_click_index, input_mask)
+
+    yield state, state, image_input, wiki
+
+    if not args.disable_gpt and model.text_refiner:
+        refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
+                                                       enable_wiki=enable_wiki)
+
+        new_cap = refined_caption['caption']
+        wiki = refined_caption['wiki']
+        state = state + [(None, f"caption: {new_cap}")]
+        refined_image_input = create_bubble_frame(origin_image_input, new_cap, fake_click_index, input_mask)
+
+        yield state, state, refined_image_input, wiki
+
+def clear_chat_memory(visual_chatgpt):
+    if visual_chatgpt is not None:
+        visual_chatgpt.memory.clear()
+        visual_chatgpt.current_image = None
+        visual_chatgpt.point_prompt = ""
+
+def get_style():
+    current_version = version.parse(gr.__version__)
+    if current_version <= version.parse('3.24.1'):
+        style = '''
+        #image_sketcher{min-height:500px}
+        #image_sketcher [data-testid="image"], #image_sketcher [data-testid="image"] > div{min-height: 500px}
+        #image_upload{min-height:500px}
+        #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 500px}
+        '''
+    elif current_version <= version.parse('3.27'):
+        style = '''
+        #image_sketcher{min-height:500px}
+        #image_upload{min-height:500px}
+        '''
+    else:
+        style = None
+
+    return style
+
+
+def create_ui():
+    title = """<p><h1 align="center">Caption-Anything</h1></p>
+    """
+    description = """<p>Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: <a href="https://github.com/ttengwang/Caption-Anything">https://github.com/ttengwang/Caption-Anything</a> <a href="https://huggingface.co/spaces/TencentARC/Caption-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
+
+    examples = [
+        ["test_images/img35.webp"],
+        ["test_images/img2.jpg"],
+        ["test_images/img5.jpg"],
+        ["test_images/img12.jpg"],
+        ["test_images/img14.jpg"],
+        ["test_images/qingming3.jpeg"],
+        ["test_images/img1.jpg"],
+    ]
+
+    with gr.Blocks(
+            css=get_style()
+    ) as iface:
+        state = gr.State([])
+        click_state = gr.State([[], [], []])
+        # chat_state = gr.State([])
+        origin_image = gr.State(None)
+        image_embedding = gr.State(None)
+        text_refiner = gr.State(None)
+        visual_chatgpt = gr.State(None)
+        original_size = gr.State(None)
+        input_size = gr.State(None)
+        # img_caption = gr.State(None)
+        aux_state = gr.State([])
+
+        gr.Markdown(title)
+        gr.Markdown(description)
+
+        with gr.Row():
+            with gr.Column(scale=1.0):
+                with gr.Column(visible=False) as modules_not_need_gpt:
+                    with gr.Tab("Click"):
+                        image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
+                        example_image = gr.Image(type="pil", interactive=False, visible=False)
+                        with gr.Row(scale=1.0):
+                            with gr.Row(scale=0.4):
+                                point_prompt = gr.Radio(
+                                    choices=["Positive", "Negative"],
+                                    value="Positive",
+                                    label="Point Prompt",
+                                    interactive=True)
+                                click_mode = gr.Radio(
+                                    choices=["Continuous", "Single"],
+                                    value="Continuous",
+                                    label="Clicking Mode",
+                                    interactive=True)
+                            with gr.Row(scale=0.4):
+                                clear_button_click = gr.Button(value="Clear Clicks", interactive=True)
+                                clear_button_image = gr.Button(value="Clear Image", interactive=True)
+                    with gr.Tab("Trajectory (beta)"):
+                        sketcher_input = ImageSketcher(type="pil", interactive=True, brush_radius=20,
+                                                       elem_id="image_sketcher")
+                        with gr.Row():
+                            submit_button_sketcher = gr.Button(value="Submit", interactive=True)
+
+                with gr.Column(visible=False) as modules_need_gpt:
+                    with gr.Row(scale=1.0):
+                        language = gr.Dropdown(
+                            ['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"],
+                            value="English", label="Language", interactive=True)
+                        sentiment = gr.Radio(
+                            choices=["Positive", "Natural", "Negative"],
+                            value="Natural",
+                            label="Sentiment",
+                            interactive=True,
+                        )
+                    with gr.Row(scale=1.0):
+                        factuality = gr.Radio(
+                            choices=["Factual", "Imagination"],
+                            value="Factual",
+                            label="Factuality",
+                            interactive=True,
+                        )
+                        length = gr.Slider(
+                            minimum=10,
+                            maximum=80,
+                            value=10,
+                            step=1,
+                            interactive=True,
+                            label="Generated Caption Length",
+                        )
+                        enable_wiki = gr.Radio(
+                            choices=["Yes", "No"],
+                            value="No",
+                            label="Enable Wiki",
                            interactive=True)
+                with gr.Column(visible=True) as modules_not_need_gpt3:
+                    gr.Examples(
+                        examples=examples,
+                        inputs=[example_image],
                    )
+            with gr.Column(scale=0.5):
+                openai_api_key = gr.Textbox(
+                    placeholder="Input openAI API key",
+                    show_label=False,
+                    label="OpenAI API Key",
+                    lines=1,
+                    type="password")
+                with gr.Row(scale=0.5):
+                    enable_chatGPT_button = gr.Button(value="Run with ChatGPT", interactive=True, variant='primary')
+                    disable_chatGPT_button = gr.Button(value="Run without ChatGPT (Faster)", interactive=True,
+                                                       variant='primary')
+                with gr.Column(visible=False) as modules_need_gpt2:
+                    wiki_output = gr.Textbox(lines=5, label="Wiki", max_lines=5)
+                with gr.Column(visible=False) as modules_not_need_gpt2:
+                    chatbot = gr.Chatbot(label="Chat about Selected Object", ).style(height=550, scale=0.5)
+                    with gr.Column(visible=False) as modules_need_gpt3:
+                        chat_input = gr.Textbox(show_label=False, placeholder="Enter text and press Enter").style(
+                            container=False)
+                        with gr.Row():
+                            clear_button_text = gr.Button(value="Clear Text", interactive=True)
+                            submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
+
+        openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key],
+                              outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt,
+                                       modules_not_need_gpt2, modules_not_need_gpt3, text_refiner, visual_chatgpt])
+        enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key],
+                                    outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3,
+                                             modules_not_need_gpt,
+                                             modules_not_need_gpt2, modules_not_need_gpt3, text_refiner, visual_chatgpt])
+        disable_chatGPT_button.click(init_openai_api_key,
+                                     outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3,
+                                              modules_not_need_gpt,
+                                              modules_not_need_gpt2, modules_not_need_gpt3, text_refiner, visual_chatgpt])
+
+        clear_button_click.click(
+            lambda x: ([[], [], []], x, ""),
+            [origin_image],
+            [click_state, image_input, wiki_output],
+            queue=False,
+            show_progress=False
+        )
+        clear_button_image.click(
+            lambda: (None, [], [], [[], [], []], "", "", ""),
+            [],
+            [image_input, chatbot, state, click_state, wiki_output, origin_image],
+            queue=False,
+            show_progress=False
+        )
+        clear_button_image.click(clear_chat_memory, inputs=[visual_chatgpt])
+        clear_button_text.click(
+            lambda: ([], [], [[], [], [], []]),
+            [],
+            [chatbot, state, click_state],
+            queue=False,
+            show_progress=False
+        )
+        clear_button_text.click(clear_chat_memory, inputs=[visual_chatgpt])
+
+        image_input.clear(
+            lambda: (None, [], [], [[], [], []], "", "", ""),
+            [],
+            [image_input, chatbot, state, click_state, wiki_output, origin_image],
+            queue=False,
+            show_progress=False
+        )
+
+        image_input.clear(clear_chat_memory, inputs=[visual_chatgpt])
+
+
+        image_input.upload(upload_callback, [image_input, state, visual_chatgpt],
+                           [chatbot, state, origin_image, click_state, image_input, sketcher_input,
+                            image_embedding, original_size, input_size])
+        sketcher_input.upload(upload_callback, [sketcher_input, state, visual_chatgpt],
+                              [chatbot, state, origin_image, click_state, image_input, sketcher_input,
+                               image_embedding, original_size, input_size])
+        chat_input.submit(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state],
+                          [chatbot, state, aux_state])
+        chat_input.submit(lambda: "", None, chat_input)
+        submit_button_text.click(chat_input_callback, [visual_chatgpt, chat_input, click_state, state, aux_state],
+                                 [chatbot, state, aux_state])
+        submit_button_text.click(lambda: "", None, chat_input)
+        example_image.change(upload_callback, [example_image, state, visual_chatgpt],
+                             [chatbot, state, origin_image, click_state, image_input, sketcher_input,
+                              image_embedding, original_size, input_size])
+        example_image.change(clear_chat_memory, inputs=[visual_chatgpt])
+        # select coordinate
+        image_input.select(
+            inference_click,
+            inputs=[
+                origin_image, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length,
+                image_embedding, state, click_state, original_size, input_size, text_refiner, visual_chatgpt
+            ],
+            outputs=[chatbot, state, click_state, image_input, wiki_output],
+            show_progress=False, queue=True
+        )
+
+        submit_button_sketcher.click(
+            inference_traject,
+            inputs=[
+                sketcher_input, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
+                original_size, input_size, text_refiner
+            ],
+            outputs=[chatbot, state, sketcher_input, wiki_output],
+            show_progress=False, queue=True
+        )
+
+    return iface
+
+
+if __name__ == '__main__':
+    iface = create_ui()
+    iface.queue(concurrency_count=5, api_open=False, max_size=10)
543 |
+
iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=args.gradio_share)
|
app_huggingface.py
DELETED
@@ -1,268 +0,0 @@
|
|
1 |
-
from io import BytesIO
|
2 |
-
import string
|
3 |
-
import gradio as gr
|
4 |
-
import requests
|
5 |
-
from caption_anything import CaptionAnything
|
6 |
-
import torch
|
7 |
-
import json
|
8 |
-
import sys
|
9 |
-
import argparse
|
10 |
-
from caption_anything import parse_augment
|
11 |
-
import numpy as np
|
12 |
-
import PIL.ImageDraw as ImageDraw
|
13 |
-
from image_editing_utils import create_bubble_frame
|
14 |
-
import copy
|
15 |
-
from tools import mask_painter
|
16 |
-
from PIL import Image
|
17 |
-
import os
|
18 |
-
|
19 |
-
def download_checkpoint(url, folder, filename):
|
20 |
-
os.makedirs(folder, exist_ok=True)
|
21 |
-
filepath = os.path.join(folder, filename)
|
22 |
-
|
23 |
-
if not os.path.exists(filepath):
|
24 |
-
response = requests.get(url, stream=True)
|
25 |
-
with open(filepath, "wb") as f:
|
26 |
-
for chunk in response.iter_content(chunk_size=8192):
|
27 |
-
if chunk:
|
28 |
-
f.write(chunk)
|
29 |
-
|
30 |
-
return filepath
|
31 |
-
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
32 |
-
folder = "segmenter"
|
33 |
-
filename = "sam_vit_h_4b8939.pth"
|
34 |
-
|
35 |
-
download_checkpoint(checkpoint_url, folder, filename)
|
36 |
-
|
37 |
-
|
38 |
-
title = """<h1 align="center">Caption-Anything</h1>"""
|
39 |
-
description = """Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: https://github.com/ttengwang/Caption-Anything
|
40 |
-
"""
|
41 |
-
|
42 |
-
examples = [
|
43 |
-
["test_img/img2.jpg"],
|
44 |
-
["test_img/img5.jpg"],
|
45 |
-
["test_img/img12.jpg"],
|
46 |
-
["test_img/img14.jpg"],
|
47 |
-
]
|
48 |
-
|
49 |
-
args = parse_augment()
|
50 |
-
args.captioner = 'blip2'
|
51 |
-
args.seg_crop_mode = 'wo_bg'
|
52 |
-
args.regular_box = True
|
53 |
-
# args.device = 'cuda:5'
|
54 |
-
# args.disable_gpt = False
|
55 |
-
# args.enable_reduce_tokens = True
|
56 |
-
# args.port=20322
|
57 |
-
model = CaptionAnything(args)
|
58 |
-
|
59 |
-
def init_openai_api_key(api_key):
|
60 |
-
os.environ['OPENAI_API_KEY'] = api_key
|
61 |
-
model.init_refiner()
|
62 |
-
|
63 |
-
|
64 |
-
def get_prompt(chat_input, click_state):
|
65 |
-
points = click_state[0]
|
66 |
-
labels = click_state[1]
|
67 |
-
inputs = json.loads(chat_input)
|
68 |
-
for input in inputs:
|
69 |
-
points.append(input[:2])
|
70 |
-
labels.append(input[2])
|
71 |
-
|
72 |
-
prompt = {
|
73 |
-
"prompt_type":["click"],
|
74 |
-
"input_point":points,
|
75 |
-
"input_label":labels,
|
76 |
-
"multimask_output":"True",
|
77 |
-
}
|
78 |
-
return prompt
|
79 |
-
|
80 |
-
def chat_with_points(chat_input, click_state, state):
|
81 |
-
if not hasattr(model, "text_refiner"):
|
82 |
-
response = "Text refiner is not initilzed, please input openai api key."
|
83 |
-
state = state + [(chat_input, response)]
|
84 |
-
return state, state
|
85 |
-
|
86 |
-
points, labels, captions = click_state
|
87 |
-
# point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting! Human: {chat_input}\nAI: "
|
88 |
-
# # "The image is of width {width} and height {height}."
|
89 |
-
point_chat_prompt = "a) Revised prompt: I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps}. Now, let's chat! Human: {chat_input} AI:"
|
90 |
-
prev_visual_context = ""
|
91 |
-
pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
|
92 |
-
if len(captions):
|
93 |
-
prev_visual_context = ', '.join(pos_points) + captions[-1] + '\n'
|
94 |
-
else:
|
95 |
-
prev_visual_context = 'no point exists.'
|
96 |
-
chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context, "chat_input": chat_input})
|
97 |
-
response = model.text_refiner.llm(chat_prompt)
|
98 |
-
state = state + [(chat_input, response)]
|
99 |
-
return state, state
|
100 |
-
|
101 |
-
def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality, length, state, click_state, evt:gr.SelectData):
|
102 |
-
|
103 |
-
if point_prompt == 'Positive':
|
104 |
-
coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
|
105 |
-
else:
|
106 |
-
coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
|
107 |
-
|
108 |
-
controls = {'length': length,
|
109 |
-
'sentiment': sentiment,
|
110 |
-
'factuality': factuality,
|
111 |
-
'language': language}
|
112 |
-
|
113 |
-
# click_coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
|
114 |
-
# chat_input = click_coordinate
|
115 |
-
prompt = get_prompt(coordinate, click_state)
|
116 |
-
print('prompt: ', prompt, 'controls: ', controls)
|
117 |
-
|
118 |
-
out = model.inference(image_input, prompt, controls)
|
119 |
-
state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
|
120 |
-
# for k, v in out['generated_captions'].items():
|
121 |
-
# state = state + [(f'{k}: {v}', None)]
|
122 |
-
state = state + [("caption: {}".format(out['generated_captions']['raw_caption']), None)]
|
123 |
-
wiki = out['generated_captions'].get('wiki', "")
|
124 |
-
click_state[2].append(out['generated_captions']['raw_caption'])
|
125 |
-
|
126 |
-
text = out['generated_captions']['raw_caption']
|
127 |
-
# draw = ImageDraw.Draw(image_input)
|
128 |
-
# draw.text((evt.index[0], evt.index[1]), text, textcolor=(0,0,255), text_size=120)
|
129 |
-
input_mask = np.array(Image.open(out['mask_save_path']).convert('P'))
|
130 |
-
image_input = mask_painter(np.array(image_input), input_mask)
|
131 |
-
origin_image_input = image_input
|
132 |
-
image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]))
|
133 |
-
|
134 |
-
yield state, state, click_state, chat_input, image_input, wiki
|
135 |
-
if not args.disable_gpt and hasattr(model, "text_refiner"):
|
136 |
-
refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'])
|
137 |
-
# new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
|
138 |
-
new_cap = refined_caption['caption']
|
139 |
-
refined_image_input = create_bubble_frame(origin_image_input, new_cap, (evt.index[0], evt.index[1]))
|
140 |
-
yield state, state, click_state, chat_input, refined_image_input, wiki
|
141 |
-
|
142 |
-
|
143 |
-
def upload_callback(image_input, state):
|
144 |
-
state = [] + [('Image size: ' + str(image_input.size), None)]
|
145 |
-
click_state = [[], [], []]
|
146 |
-
model.segmenter.image = None
|
147 |
-
model.segmenter.image_embedding = None
|
148 |
-
model.segmenter.set_image(image_input)
|
149 |
-
return state, image_input, click_state
|
150 |
-
|
151 |
-
with gr.Blocks(
|
152 |
-
css='''
|
153 |
-
#image_upload{min-height:400px}
|
154 |
-
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 600px}
|
155 |
-
'''
|
156 |
-
) as iface:
|
157 |
-
state = gr.State([])
|
158 |
-
click_state = gr.State([[],[],[]])
|
159 |
-
origin_image = gr.State(None)
|
160 |
-
|
161 |
-
gr.Markdown(title)
|
162 |
-
gr.Markdown(description)
|
163 |
-
|
164 |
-
with gr.Row():
|
165 |
-
with gr.Column(scale=1.0):
|
166 |
-
image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
|
167 |
-
with gr.Row(scale=1.0):
|
168 |
-
point_prompt = gr.Radio(
|
169 |
-
choices=["Positive", "Negative"],
|
170 |
-
value="Positive",
|
171 |
-
label="Point Prompt",
|
172 |
-
interactive=True)
|
173 |
-
clear_button_clike = gr.Button(value="Clear Clicks", interactive=True)
|
174 |
-
clear_button_image = gr.Button(value="Clear Image", interactive=True)
|
175 |
-
with gr.Row(scale=1.0):
|
176 |
-
language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
|
177 |
-
|
178 |
-
sentiment = gr.Radio(
|
179 |
-
choices=["Positive", "Natural", "Negative"],
|
180 |
-
value="Natural",
|
181 |
-
label="Sentiment",
|
182 |
-
interactive=True,
|
183 |
-
)
|
184 |
-
with gr.Row(scale=1.0):
|
185 |
-
factuality = gr.Radio(
|
186 |
-
choices=["Factual", "Imagination"],
|
187 |
-
value="Factual",
|
188 |
-
label="Factuality",
|
189 |
-
interactive=True,
|
190 |
-
)
|
191 |
-
length = gr.Slider(
|
192 |
-
minimum=10,
|
193 |
-
maximum=80,
|
194 |
-
value=10,
|
195 |
-
step=1,
|
196 |
-
interactive=True,
|
197 |
-
label="Length",
|
198 |
-
)
|
199 |
-
|
200 |
-
with gr.Column(scale=0.5):
|
201 |
-
openai_api_key = gr.Textbox(
|
202 |
-
placeholder="Input your openAI API key and press Enter",
|
203 |
-
show_label=False,
|
204 |
-
label = "OpenAI API Key",
|
205 |
-
lines=1,
|
206 |
-
type="password"
|
207 |
-
)
|
208 |
-
openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key])
|
209 |
-
wiki_output = gr.Textbox(lines=6, label="Wiki")
|
210 |
-
chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=450,scale=0.5)
|
211 |
-
chat_input = gr.Textbox(lines=1, label="Chat Input")
|
212 |
-
with gr.Row():
|
213 |
-
clear_button_text = gr.Button(value="Clear Text", interactive=True)
|
214 |
-
submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
|
215 |
-
clear_button_clike.click(
|
216 |
-
lambda x: ([[], [], []], x, ""),
|
217 |
-
[origin_image],
|
218 |
-
[click_state, image_input, wiki_output],
|
219 |
-
queue=False,
|
220 |
-
show_progress=False
|
221 |
-
)
|
222 |
-
clear_button_image.click(
|
223 |
-
lambda: (None, [], [], [[], [], []], ""),
|
224 |
-
[],
|
225 |
-
[image_input, chatbot, state, click_state, wiki_output],
|
226 |
-
queue=False,
|
227 |
-
show_progress=False
|
228 |
-
)
|
229 |
-
clear_button_text.click(
|
230 |
-
lambda: ([], [], [[], [], []]),
|
231 |
-
[],
|
232 |
-
[chatbot, state, click_state],
|
233 |
-
queue=False,
|
234 |
-
show_progress=False
|
235 |
-
)
|
236 |
-
image_input.clear(
|
237 |
-
lambda: (None, [], [], [[], [], []], ""),
|
238 |
-
[],
|
239 |
-
[image_input, chatbot, state, click_state, wiki_output],
|
240 |
-
queue=False,
|
241 |
-
show_progress=False
|
242 |
-
)
|
243 |
-
|
244 |
-
examples = gr.Examples(
|
245 |
-
examples=examples,
|
246 |
-
inputs=[image_input],
|
247 |
-
)
|
248 |
-
|
249 |
-
image_input.upload(upload_callback,[image_input, state], [state, origin_image, click_state])
|
250 |
-
chat_input.submit(chat_with_points, [chat_input, click_state, state], [chatbot, state])
|
251 |
-
|
252 |
-
# select coordinate
|
253 |
-
image_input.select(inference_seg_cap,
|
254 |
-
inputs=[
|
255 |
-
origin_image,
|
256 |
-
point_prompt,
|
257 |
-
language,
|
258 |
-
sentiment,
|
259 |
-
factuality,
|
260 |
-
length,
|
261 |
-
state,
|
262 |
-
click_state
|
263 |
-
],
|
264 |
-
outputs=[chatbot, state, click_state, chat_input, image_input, wiki_output],
|
265 |
-
show_progress=False, queue=True)
|
266 |
-
|
267 |
-
iface.queue(concurrency_count=1, api_open=False, max_size=10)
|
268 |
-
iface.launch(server_name="0.0.0.0", enable_queue=True)
|
|
|
|
|
|
|
|
|
|
|
app_old.py
DELETED
@@ -1,261 +0,0 @@
|
|
1 |
-
from io import BytesIO
|
2 |
-
import string
|
3 |
-
import gradio as gr
|
4 |
-
import requests
|
5 |
-
from caption_anything import CaptionAnything
|
6 |
-
import torch
|
7 |
-
import json
|
8 |
-
import sys
|
9 |
-
import argparse
|
10 |
-
from caption_anything import parse_augment
|
11 |
-
import os
|
12 |
-
|
13 |
-
# download sam checkpoint if not downloaded
|
14 |
-
def download_checkpoint(url, folder, filename):
|
15 |
-
os.makedirs(folder, exist_ok=True)
|
16 |
-
filepath = os.path.join(folder, filename)
|
17 |
-
|
18 |
-
if not os.path.exists(filepath):
|
19 |
-
response = requests.get(url, stream=True)
|
20 |
-
with open(filepath, "wb") as f:
|
21 |
-
for chunk in response.iter_content(chunk_size=8192):
|
22 |
-
if chunk:
|
23 |
-
f.write(chunk)
|
24 |
-
|
25 |
-
return filepath
|
26 |
-
checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
27 |
-
folder = "segmenter"
|
28 |
-
filename = "sam_vit_h_4b8939.pth"
|
29 |
-
|
30 |
-
title = """<h1 align="center">Caption-Anything</h1>"""
|
31 |
-
description = """Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them.
|
32 |
-
<br> <strong>Code</strong>: GitHub repo: <a href='https://github.com/ttengwang/Caption-Anything' target='_blank'></a>
|
33 |
-
"""
|
34 |
-
|
35 |
-
examples = [
|
36 |
-
["test_img/img2.jpg", "[[1000, 700, 1]]"]
|
37 |
-
]
|
38 |
-
|
39 |
-
args = parse_augment()
|
40 |
-
|
41 |
-
def get_prompt(chat_input, click_state):
|
42 |
-
points = click_state[0]
|
43 |
-
labels = click_state[1]
|
44 |
-
inputs = json.loads(chat_input)
|
45 |
-
for input in inputs:
|
46 |
-
points.append(input[:2])
|
47 |
-
labels.append(input[2])
|
48 |
-
|
49 |
-
prompt = {
|
50 |
-
"prompt_type":["click"],
|
51 |
-
"input_point":points,
|
52 |
-
"input_label":labels,
|
53 |
-
"multimask_output":"True",
|
54 |
-
}
|
55 |
-
return prompt
|
56 |
-
|
57 |
-
def inference_seg_cap(image_input, chat_input, language, sentiment, factuality, length, state, click_state):
|
58 |
-
controls = {'length': length,
|
59 |
-
'sentiment': sentiment,
|
60 |
-
'factuality': factuality,
|
61 |
-
'language': language}
|
62 |
-
prompt = get_prompt(chat_input, click_state)
|
63 |
-
print('prompt: ', prompt, 'controls: ', controls)
|
64 |
-
out = model.inference(image_input, prompt, controls)
|
65 |
-
state = state + [(None, "Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]))]
|
66 |
-
for k, v in out['generated_captions'].items():
|
67 |
-
state = state + [(f'{k}: {v}', None)]
|
68 |
-
click_state[2].append(out['generated_captions']['raw_caption'])
|
69 |
-
image_output_mask = out['mask_save_path']
|
70 |
-
image_output_crop = out['crop_save_path']
|
71 |
-
return state, state, click_state, image_output_mask, image_output_crop
|
72 |
-
|
73 |
-
|
74 |
-
def upload_callback(image_input, state):
|
75 |
-
state = state + [('Image size: ' + str(image_input.size), None)]
|
76 |
-
return state
|
77 |
-
|
78 |
-
# get coordinate in format [[x,y,positive/negative]]
|
79 |
-
def get_select_coords(image_input, point_prompt, language, sentiment, factuality, length, state, click_state, evt: gr.SelectData):
|
80 |
-
print("point_prompt: ", point_prompt)
|
81 |
-
if point_prompt == 'Positive Point':
|
82 |
-
coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
|
83 |
-
else:
|
84 |
-
coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
|
85 |
-
return (coordinate,) + inference_seg_cap(image_input, coordinate, language, sentiment, factuality, length, state, click_state)
|
86 |
-
|
87 |
-
def chat_with_points(chat_input, click_state, state):
|
88 |
-
points, labels, captions = click_state
|
89 |
-
# point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\n. Now begin chatting! Human: {chat_input}\nAI: "
|
90 |
-
# "The image is of width {width} and height {height}."
|
91 |
-
point_chat_prompt = "a) Revised prompt: I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps}. Now, let's chat! Human: {chat_input} AI:"
|
92 |
-
prev_visual_context = ""
|
93 |
-
pos_points = [f"{points[i][0]}, {points[i][1]}" for i in range(len(points)) if labels[i] == 1]
|
94 |
-
prev_visual_context = ', '.join(pos_points) + captions[-1] + '\n'
|
95 |
-
chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context, "chat_input": chat_input})
|
96 |
-
response = model.text_refiner.llm(chat_prompt)
|
97 |
-
state = state + [(chat_input, response)]
|
98 |
-
return state, state
|
99 |
-
|
100 |
-
def init_openai_api_key(api_key):
|
101 |
-
# os.environ['OPENAI_API_KEY'] = api_key
|
102 |
-
global model
|
103 |
-
model = CaptionAnything(args, api_key)
|
104 |
-
|
105 |
-
css='''
|
106 |
-
#image_upload{min-height:200px}
|
107 |
-
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 200px}
|
108 |
-
'''
|
109 |
-
|
110 |
-
with gr.Blocks(css=css) as iface:
|
111 |
-
state = gr.State([])
|
112 |
-
click_state = gr.State([[],[],[]])
|
113 |
-
caption_state = gr.State([[]])
|
114 |
-
gr.Markdown(title)
|
115 |
-
gr.Markdown(description)
|
116 |
-
|
117 |
-
with gr.Column():
|
118 |
-
openai_api_key = gr.Textbox(
|
119 |
-
placeholder="Input your openAI API key and press Enter",
|
120 |
-
show_label=False,
|
121 |
-
lines=1,
|
122 |
-
type="password",
|
123 |
-
)
|
124 |
-
openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key])
|
125 |
-
|
126 |
-
with gr.Row():
|
127 |
-
with gr.Column(scale=0.7):
|
128 |
-
image_input = gr.Image(type="pil", interactive=True, label="Image", elem_id="image_upload").style(height=260,scale=1.0)
|
129 |
-
|
130 |
-
with gr.Row(scale=0.7):
|
131 |
-
point_prompt = gr.Radio(
|
132 |
-
choices=["Positive Point", "Negative Point"],
|
133 |
-
value="Positive Point",
|
134 |
-
label="Points",
|
135 |
-
interactive=True,
|
136 |
-
)
|
137 |
-
|
138 |
-
# with gr.Row():
|
139 |
-
language = gr.Radio(
|
140 |
-
choices=["English", "Chinese", "French", "Spanish", "Arabic", "Portuguese","Cantonese"],
|
141 |
-
value="English",
|
142 |
-
label="Language",
|
143 |
-
interactive=True,
|
144 |
-
)
|
145 |
-
sentiment = gr.Radio(
|
146 |
-
choices=["Positive", "Natural", "Negative"],
|
147 |
-
value="Natural",
|
148 |
-
label="Sentiment",
|
149 |
-
interactive=True,
|
150 |
-
)
|
151 |
-
factuality = gr.Radio(
|
152 |
-
choices=["Factual", "Imagination"],
|
153 |
-
value="Factual",
|
154 |
-
label="Factuality",
|
155 |
-
interactive=True,
|
156 |
-
)
|
157 |
-
length = gr.Slider(
|
158 |
-
minimum=5,
|
159 |
-
maximum=100,
|
160 |
-
value=10,
|
161 |
-
step=1,
|
162 |
-
interactive=True,
|
163 |
-
label="Length",
|
164 |
-
)
|
165 |
-
|
166 |
-
with gr.Column(scale=1.5):
|
167 |
-
with gr.Row():
|
168 |
-
image_output_mask= gr.Image(type="pil", interactive=False, label="Mask").style(height=260,scale=1.0)
|
169 |
-
image_output_crop= gr.Image(type="pil", interactive=False, label="Cropped Image by Mask", show_progress=False).style(height=260,scale=1.0)
|
170 |
-
chatbot = gr.Chatbot(label="Chat Output",).style(height=450,scale=0.5)
|
171 |
-
|
172 |
-
with gr.Row():
|
173 |
-
with gr.Column(scale=0.7):
|
174 |
-
prompt_input = gr.Textbox(lines=1, label="Input Prompt (A list of points like : [[100, 200, 1], [200,300,0]])")
|
175 |
-
prompt_input.submit(
|
176 |
-
inference_seg_cap,
|
177 |
-
[
|
178 |
-
image_input,
|
179 |
-
prompt_input,
|
180 |
-
language,
|
181 |
-
sentiment,
|
182 |
-
factuality,
|
183 |
-
length,
|
184 |
-
state,
|
185 |
-
click_state
|
186 |
-
],
|
187 |
-
[chatbot, state, click_state, image_output_mask, image_output_crop],
|
188 |
-
show_progress=False
|
189 |
-
)
|
190 |
-
|
191 |
-
image_input.upload(
|
192 |
-
upload_callback,
|
193 |
-
[image_input, state],
|
194 |
-
[chatbot]
|
195 |
-
)
|
196 |
-
|
197 |
-
with gr.Row():
|
198 |
-
clear_button = gr.Button(value="Clear Click", interactive=True)
|
199 |
-
clear_button.click(
|
200 |
-
lambda: ("", [[], [], []], None, None),
|
201 |
-
[],
|
202 |
-
[prompt_input, click_state, image_output_mask, image_output_crop],
|
203 |
-
queue=False,
|
204 |
-
show_progress=False
|
205 |
-
)
|
206 |
-
|
207 |
-
clear_button = gr.Button(value="Clear", interactive=True)
|
208 |
-
clear_button.click(
|
209 |
-
lambda: ("", [], [], [[], [], []], None, None),
|
210 |
-
[],
|
211 |
-
[prompt_input, chatbot, state, click_state, image_output_mask, image_output_crop],
|
212 |
-
queue=False,
|
213 |
-
show_progress=False
|
214 |
-
)
|
215 |
-
|
216 |
-
submit_button = gr.Button(
|
217 |
-
value="Submit", interactive=True, variant="primary"
|
218 |
-
)
|
219 |
-
submit_button.click(
|
220 |
-
inference_seg_cap,
|
221 |
-
[
|
222 |
-
image_input,
|
223 |
-
prompt_input,
|
224 |
-
language,
|
225 |
-
sentiment,
|
226 |
-
factuality,
|
227 |
-
length,
|
228 |
-
state,
|
229 |
-
click_state
|
230 |
-
],
|
231 |
-
[chatbot, state, click_state, image_output_mask, image_output_crop],
|
232 |
-
show_progress=False
|
233 |
-
)
|
234 |
-
|
235 |
-
# select coordinate
|
236 |
-
image_input.select(
|
237 |
-
get_select_coords,
|
238 |
-
inputs=[image_input,point_prompt,language,sentiment,factuality,length,state,click_state],
|
239 |
-
outputs=[prompt_input, chatbot, state, click_state, image_output_mask, image_output_crop],
|
240 |
-
show_progress=False
|
241 |
-
)
|
242 |
-
|
243 |
-
image_input.change(
|
244 |
-
lambda: ("", [], [[], [], []]),
|
245 |
-
[],
|
246 |
-
[chatbot, state, click_state],
|
247 |
-
queue=False,
|
248 |
-
)
|
249 |
-
|
250 |
-
with gr.Column(scale=1.5):
|
251 |
-
chat_input = gr.Textbox(lines=1, label="Chat Input")
|
252 |
-
chat_input.submit(chat_with_points, [chat_input, click_state, state], [chatbot, state])
|
253 |
-
|
254 |
-
|
255 |
-
examples = gr.Examples(
|
256 |
-
examples=examples,
|
257 |
-
inputs=[image_input, prompt_input],
|
258 |
-
)
|
259 |
-
|
260 |
-
iface.queue(concurrency_count=1, api_open=False, max_size=10)
|
261 |
-
iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=args.gradio_share)
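For reference, the click prompt that this (now removed) textbox UI parses is a JSON list of [x, y, label] triples, with label 1 for a positive point and 0 for a negative one. A minimal, hypothetical sketch of that conversion (the helper name below is ours, not the repo's):

import json

def clicks_to_sam_prompt(chat_input: str) -> dict:
    # Mirrors what get_prompt() above does: turn "[[x, y, label]]" text into the
    # click-prompt dict consumed by the SAM-based segmenter.
    clicks = json.loads(chat_input)
    return {
        "prompt_type": ["click"],
        "input_point": [c[:2] for c in clicks],
        "input_label": [c[2] for c in clicks],
        "multimask_output": "True",
    }

print(clicks_to_sam_prompt("[[1000, 700, 1], [200, 300, 0]]"))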
app_wo_langchain.py
ADDED
@@ -0,0 +1,588 @@
import os
import json
from typing import List

import PIL
import gradio as gr
import numpy as np
from gradio import processing_utils

from packaging import version
from PIL import Image, ImageDraw

from caption_anything.model import CaptionAnything
from caption_anything.utils.image_editing_utils import create_bubble_frame
from caption_anything.utils.utils import mask_painter, seg_model_map, prepare_segmenter
from caption_anything.utils.parser import parse_augment
from caption_anything.captioner import build_captioner
from caption_anything.text_refiner import build_text_refiner
from caption_anything.segmenter import build_segmenter
from caption_anything.utils.chatbot import ConversationBot, build_chatbot_tools, get_new_image_name
from segment_anything import sam_model_registry


args = parse_augment()

args = parse_augment()
if args.segmenter_checkpoint is None:
    _, segmenter_checkpoint = prepare_segmenter(args.segmenter)
else:
    segmenter_checkpoint = args.segmenter_checkpoint

shared_captioner = build_captioner(args.captioner, args.device, args)
shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=segmenter_checkpoint).to(args.device)


class ImageSketcher(gr.Image):
    """
    Fix the bug of gradio.Image that cannot upload with tool == 'sketch'.
    """

    is_template = True  # Magic to make this work with gradio.Block, don't remove unless you know what you're doing.

    def __init__(self, **kwargs):
        super().__init__(tool="sketch", **kwargs)

    def preprocess(self, x):
        if self.tool == 'sketch' and self.source in ["upload", "webcam"]:
            assert isinstance(x, dict)
            if x['mask'] is None:
                decode_image = processing_utils.decode_base64_to_image(x['image'])
                width, height = decode_image.size
                mask = np.zeros((height, width, 4), dtype=np.uint8)
                mask[..., -1] = 255
                mask = self.postprocess(mask)

                x['mask'] = mask

        return super().preprocess(x)


def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, text_refiner=None,
                                        session_id=None):
    segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
    captioner = captioner
    if session_id is not None:
        print('Init caption anything for session {}'.format(session_id))
    return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, text_refiner=text_refiner)


def init_openai_api_key(api_key=""):
    text_refiner = None
    if api_key and len(api_key) > 30:
        try:
            text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
            text_refiner.llm('hi')  # test
        except:
            text_refiner = None
    openai_available = text_refiner is not None
    return gr.update(visible=openai_available), gr.update(visible=openai_available), gr.update(
        visible=openai_available), gr.update(visible=True), gr.update(visible=True), gr.update(
        visible=True), text_refiner


def get_click_prompt(chat_input, click_state, click_mode):
    inputs = json.loads(chat_input)
    if click_mode == 'Continuous':
        points = click_state[0]
        labels = click_state[1]
        for input in inputs:
            points.append(input[:2])
            labels.append(input[2])
    elif click_mode == 'Single':
        points = []
        labels = []
        for input in inputs:
            points.append(input[:2])
            labels.append(input[2])
        click_state[0] = points
        click_state[1] = labels
    else:
        raise NotImplementedError

    prompt = {
        "prompt_type": ["click"],
        "input_point": click_state[0],
        "input_label": click_state[1],
        "multimask_output": "True",
    }
    return prompt


def update_click_state(click_state, caption, click_mode):
    if click_mode == 'Continuous':
        click_state[2].append(caption)
    elif click_mode == 'Single':
        click_state[2] = [caption]
    else:
        raise NotImplementedError


def chat_with_points(chat_input, click_state, chat_state, state, text_refiner, img_caption):
    if text_refiner is None:
        response = "Text refiner is not initilzed, please input openai api key."
        state = state + [(chat_input, response)]
        return state, state, chat_state

    points, labels, captions = click_state
    # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting!"
    suffix = '\nHuman: {chat_input}\nAI: '
    qa_template = '\nHuman: {q}\nAI: {a}'
    # # "The image is of width {width} and height {height}."
    point_chat_prompt = "I am an AI trained to chat with you about an image. I am greate at what is going on in any image based on the image information your provide. The overall image description is \"{img_caption}\". You will also provide me objects in the image in details, i.e., their location and visual descriptions. Here are the locations and descriptions of events that happen in the image: {points_with_caps} \nYou are required to use language instead of number to describe these positions. Now, let's chat!"
    prev_visual_context = ""
    pos_points = []
    pos_captions = []

    for i in range(len(points)):
        if labels[i] == 1:
            pos_points.append(f"(X:{points[i][0]}, Y:{points[i][1]})")
            pos_captions.append(captions[i])
            prev_visual_context = prev_visual_context + '\n' + 'There is an event described as \"{}\" locating at {}'.format(
                pos_captions[-1], ', '.join(pos_points))

    context_length_thres = 500
    prev_history = ""
    for i in range(len(chat_state)):
        q, a = chat_state[i]
        if len(prev_history) < context_length_thres:
            prev_history = prev_history + qa_template.format(**{"q": q, "a": a})
        else:
            break
    chat_prompt = point_chat_prompt.format(
        **{"img_caption": img_caption, "points_with_caps": prev_visual_context}) + prev_history + suffix.format(
        **{"chat_input": chat_input})
    print('\nchat_prompt: ', chat_prompt)
    response = text_refiner.llm(chat_prompt)
    state = state + [(chat_input, response)]
    chat_state = chat_state + [(chat_input, response)]
    return state, state, chat_state


def upload_callback(image_input, state):
    if isinstance(image_input, dict):  # if upload from sketcher_input, input contains image and mask
        image_input, mask = image_input['image'], image_input['mask']

    chat_state = []
    click_state = [[], [], []]
    res = 1024
    width, height = image_input.size
    ratio = min(1.0 * res / max(width, height), 1.0)
    if ratio < 1.0:
        image_input = image_input.resize((int(width * ratio), int(height * ratio)))
        print('Scaling input image to {}'.format(image_input.size))
    state = [] + [(None, 'Image size: ' + str(image_input.size))]
    model = build_caption_anything_with_models(
        args,
        api_key="",
        captioner=shared_captioner,
        sam_model=shared_sam_model,
        session_id=iface.app_id
    )
    model.segmenter.set_image(image_input)
    image_embedding = model.image_embedding
    original_size = model.original_size
    input_size = model.input_size
    img_caption, _ = model.captioner.inference_seg(image_input)

    return state, state, chat_state, image_input, click_state, image_input, image_input, image_embedding, \
        original_size, input_size, img_caption


def inference_click(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
                    length, image_embedding, state, click_state, original_size, input_size, text_refiner,
                    evt: gr.SelectData):
    click_index = evt.index

    if point_prompt == 'Positive':
        coordinate = "[[{}, {}, 1]]".format(str(click_index[0]), str(click_index[1]))
    else:
        coordinate = "[[{}, {}, 0]]".format(str(click_index[0]), str(click_index[1]))

    prompt = get_click_prompt(coordinate, click_state, click_mode)
    input_points = prompt['input_point']
    input_labels = prompt['input_label']

    controls = {'length': length,
                'sentiment': sentiment,
                'factuality': factuality,
                'language': language}

    model = build_caption_anything_with_models(
        args,
        api_key="",
        captioner=shared_captioner,
        sam_model=shared_sam_model,
        text_refiner=text_refiner,
        session_id=iface.app_id
    )

    model.setup(image_embedding, original_size, input_size, is_image_set=True)

    enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
    out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)

    state = state + [("Image point: {}, Input label: {}".format(prompt["input_point"], prompt["input_label"]), None)]
    state = state + [(None, "raw_caption: {}".format(out['generated_captions']['raw_caption']))]
    wiki = out['generated_captions'].get('wiki', "")
    update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
    text = out['generated_captions']['raw_caption']
    input_mask = np.array(out['mask'].convert('P'))
    image_input = mask_painter(np.array(image_input), input_mask)
    origin_image_input = image_input
    image_input = create_bubble_frame(image_input, text, (click_index[0], click_index[1]), input_mask,
                                      input_points=input_points, input_labels=input_labels)
    yield state, state, click_state, image_input, wiki
    if not args.disable_gpt and model.text_refiner:
        refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
                                                       enable_wiki=enable_wiki)
        # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
        new_cap = refined_caption['caption']
        wiki = refined_caption['wiki']
        state = state + [(None, f"caption: {new_cap}")]
        refined_image_input = create_bubble_frame(origin_image_input, new_cap, (click_index[0], click_index[1]),
                                                  input_mask,
                                                  input_points=input_points, input_labels=input_labels)
        yield state, state, click_state, refined_image_input, wiki


def get_sketch_prompt(mask: PIL.Image.Image, multi_mask=True):
    """
    Get the prompt for the sketcher.
    TODO: This is a temporary solution. We should cluster the sketch and get the bounding box of each cluster.
    """

    mask = np.array(np.asarray(mask)[..., 0])
    mask[mask > 0] = 1  # Refine the mask, let all nonzero values be 1

    if not multi_mask:
        y, x = np.where(mask == 1)
        x1, y1 = np.min(x), np.min(y)
        x2, y2 = np.max(x), np.max(y)

        prompt = {
            'prompt_type': ['box'],
            'input_boxes': [
                [x1, y1, x2, y2]
            ]
        }

        return prompt

    traversed = np.zeros_like(mask)
    groups = np.zeros_like(mask)
    max_group_id = 1

    # Iterate over all pixels
    for x in range(mask.shape[0]):
        for y in range(mask.shape[1]):
            if traversed[x, y] == 1:
                continue

            if mask[x, y] == 0:
                traversed[x, y] = 1
            else:
                # If pixel is part of mask
                groups[x, y] = max_group_id
                stack = [(x, y)]
                while stack:
                    i, j = stack.pop()
                    if traversed[i, j] == 1:
                        continue
                    traversed[i, j] = 1
                    if mask[i, j] == 1:
                        groups[i, j] = max_group_id
                        for di, dj in [(1, 0), (-1, 0), (0, 1), (0, -1), (1, 1), (1, -1), (-1, 1), (-1, -1)]:
                            ni, nj = i + di, j + dj
                            traversed[i, j] = 1
                            if 0 <= nj < mask.shape[1] and mask.shape[0] > ni >= 0 == traversed[ni, nj]:
                                stack.append((i + di, j + dj))
                max_group_id += 1

    # get the bounding box of each group
    boxes = []
    for group in range(1, max_group_id):
        y, x = np.where(groups == group)
        x1, y1 = np.min(x), np.min(y)
        x2, y2 = np.max(x), np.max(y)
        boxes.append([x1, y1, x2, y2])

    prompt = {
        'prompt_type': ['box'],
        'input_boxes': boxes
    }

    return prompt


def inference_traject(sketcher_image, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
                      original_size, input_size, text_refiner):
    image_input, mask = sketcher_image['image'], sketcher_image['mask']

    prompt = get_sketch_prompt(mask, multi_mask=False)
    boxes = prompt['input_boxes']

    controls = {'length': length,
                'sentiment': sentiment,
                'factuality': factuality,
                'language': language}

    model = build_caption_anything_with_models(
        args,
        api_key="",
        captioner=shared_captioner,
        sam_model=shared_sam_model,
        text_refiner=text_refiner,
        session_id=iface.app_id
    )

    model.setup(image_embedding, original_size, input_size, is_image_set=True)

    enable_wiki = True if enable_wiki in ['True', 'TRUE', 'true', True, 'Yes', 'YES', 'yes'] else False
    out = model.inference(image_input, prompt, controls, disable_gpt=True, enable_wiki=enable_wiki)

    # Update components and states
    state.append((f'Box: {boxes}', None))
    state.append((None, f'raw_caption: {out["generated_captions"]["raw_caption"]}'))
    wiki = out['generated_captions'].get('wiki', "")
    text = out['generated_captions']['raw_caption']
    input_mask = np.array(out['mask'].convert('P'))
    image_input = mask_painter(np.array(image_input), input_mask)

    origin_image_input = image_input

    fake_click_index = (int((boxes[0][0] + boxes[0][2]) / 2), int((boxes[0][1] + boxes[0][3]) / 2))
    image_input = create_bubble_frame(image_input, text, fake_click_index, input_mask)

    yield state, state, image_input, wiki

    if not args.disable_gpt and model.text_refiner:
        refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'],
                                                       enable_wiki=enable_wiki)

        new_cap = refined_caption['caption']
        wiki = refined_caption['wiki']
        state = state + [(None, f"caption: {new_cap}")]
        refined_image_input = create_bubble_frame(origin_image_input, new_cap, fake_click_index, input_mask)

        yield state, state, refined_image_input, wiki


def get_style():
    current_version = version.parse(gr.__version__)
    if current_version <= version.parse('3.24.1'):
        style = '''
        #image_sketcher{min-height:500px}
        #image_sketcher [data-testid="image"], #image_sketcher [data-testid="image"] > div{min-height: 500px}
        #image_upload{min-height:500px}
        #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 500px}
        '''
    elif current_version <= version.parse('3.27'):
        style = '''
        #image_sketcher{min-height:500px}
        #image_upload{min-height:500px}
        '''
    else:
        style = None

    return style


def create_ui():
    title = """<p><h1 align="center">Caption-Anything</h1></p>
    """
    description = """<p>Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: <a href="https://github.com/ttengwang/Caption-Anything">https://github.com/ttengwang/Caption-Anything</a> <a href="https://huggingface.co/spaces/TencentARC/Caption-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""

    examples = [
        ["test_images/img35.webp"],
        ["test_images/img2.jpg"],
        ["test_images/img5.jpg"],
        ["test_images/img12.jpg"],
        ["test_images/img14.jpg"],
        ["test_images/qingming3.jpeg"],
        ["test_images/img1.jpg"],
    ]

    with gr.Blocks(
            css=get_style()
    ) as iface:
        state = gr.State([])
        click_state = gr.State([[], [], []])
        chat_state = gr.State([])
        origin_image = gr.State(None)
        image_embedding = gr.State(None)
        text_refiner = gr.State(None)
        original_size = gr.State(None)
        input_size = gr.State(None)
        img_caption = gr.State(None)

        gr.Markdown(title)
        gr.Markdown(description)

        with gr.Row():
            with gr.Column(scale=1.0):
                with gr.Column(visible=False) as modules_not_need_gpt:
                    with gr.Tab("Click"):
                        image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
                        example_image = gr.Image(type="pil", interactive=False, visible=False)
                        with gr.Row(scale=1.0):
                            with gr.Row(scale=0.4):
                                point_prompt = gr.Radio(
                                    choices=["Positive", "Negative"],
                                    value="Positive",
                                    label="Point Prompt",
                                    interactive=True)
                                click_mode = gr.Radio(
                                    choices=["Continuous", "Single"],
                                    value="Continuous",
                                    label="Clicking Mode",
                                    interactive=True)
                            with gr.Row(scale=0.4):
                                clear_button_click = gr.Button(value="Clear Clicks", interactive=True)
                                clear_button_image = gr.Button(value="Clear Image", interactive=True)
                    with gr.Tab("Trajectory (Beta)"):
                        sketcher_input = ImageSketcher(type="pil", interactive=True, brush_radius=20,
                                                       elem_id="image_sketcher")
                        with gr.Row():
                            submit_button_sketcher = gr.Button(value="Submit", interactive=True)

                with gr.Column(visible=False) as modules_need_gpt:
                    with gr.Row(scale=1.0):
                        language = gr.Dropdown(
                            ['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"],
                            value="English", label="Language", interactive=True)
                        sentiment = gr.Radio(
                            choices=["Positive", "Natural", "Negative"],
                            value="Natural",
                            label="Sentiment",
                            interactive=True,
                        )
                    with gr.Row(scale=1.0):
                        factuality = gr.Radio(
                            choices=["Factual", "Imagination"],
                            value="Factual",
                            label="Factuality",
                            interactive=True,
                        )
                        length = gr.Slider(
                            minimum=10,
                            maximum=80,
                            value=10,
                            step=1,
                            interactive=True,
                            label="Generated Caption Length",
                        )
                        enable_wiki = gr.Radio(
                            choices=["Yes", "No"],
                            value="No",
                            label="Enable Wiki",
                            interactive=True)
                with gr.Column(visible=True) as modules_not_need_gpt3:
                    gr.Examples(
                        examples=examples,
                        inputs=[example_image],
                    )
            with gr.Column(scale=0.5):
                openai_api_key = gr.Textbox(
                    placeholder="Input openAI API key",
                    show_label=False,
                    label="OpenAI API Key",
                    lines=1,
                    type="password")
                with gr.Row(scale=0.5):
                    enable_chatGPT_button = gr.Button(value="Run with ChatGPT", interactive=True, variant='primary')
                    disable_chatGPT_button = gr.Button(value="Run without ChatGPT (Faster)", interactive=True,
                                                       variant='primary')
                with gr.Column(visible=False) as modules_need_gpt2:
                    wiki_output = gr.Textbox(lines=5, label="Wiki", max_lines=5)
                with gr.Column(visible=False) as modules_not_need_gpt2:
                    chatbot = gr.Chatbot(label="Chat about Selected Object", ).style(height=550, scale=0.5)
                    with gr.Column(visible=False) as modules_need_gpt3:
                        chat_input = gr.Textbox(show_label=False, placeholder="Enter text and press Enter").style(
                            container=False)
                        with gr.Row():
                            clear_button_text = gr.Button(value="Clear Text", interactive=True)
                            submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")

        openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key],
                              outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt,
                                       modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
        enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key],
                                    outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3,
                                             modules_not_need_gpt,
                                             modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
        disable_chatGPT_button.click(init_openai_api_key,
                                     outputs=[modules_need_gpt, modules_need_gpt2, modules_need_gpt3,
                                              modules_not_need_gpt,
                                              modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])

        clear_button_click.click(
            lambda x: ([[], [], []], x, ""),
            [origin_image],
            [click_state, image_input, wiki_output],
            queue=False,
            show_progress=False
        )
        clear_button_image.click(
            lambda: (None, [], [], [], [[], [], []], "", "", ""),
            [],
            [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image, img_caption],
            queue=False,
            show_progress=False
        )
        clear_button_text.click(
            lambda: ([], [], [[], [], [], []], []),
            [],
            [chatbot, state, click_state, chat_state],
            queue=False,
            show_progress=False
        )
        image_input.clear(
            lambda: (None, [], [], [], [[], [], []], "", "", ""),
            [],
            [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image, img_caption],
            queue=False,
            show_progress=False
        )

        image_input.upload(upload_callback, [image_input, state],
                           [chatbot, state, chat_state, origin_image, click_state, image_input, sketcher_input,
                            image_embedding, original_size, input_size, img_caption])
        sketcher_input.upload(upload_callback, [sketcher_input, state],
                              [chatbot, state, chat_state, origin_image, click_state, image_input, sketcher_input,
                               image_embedding, original_size, input_size, img_caption])
        chat_input.submit(chat_with_points, [chat_input, click_state, chat_state, state, text_refiner, img_caption],
                          [chatbot, state, chat_state])
        chat_input.submit(lambda: "", None, chat_input)
        example_image.change(upload_callback, [example_image, state],
                             [chatbot, state, chat_state, origin_image, click_state, image_input, sketcher_input,
                              image_embedding, original_size, input_size, img_caption])

        # select coordinate
        image_input.select(
            inference_click,
            inputs=[
                origin_image, point_prompt, click_mode, enable_wiki, language, sentiment, factuality, length,
                image_embedding, state, click_state, original_size, input_size, text_refiner
            ],
            outputs=[chatbot, state, click_state, image_input, wiki_output],
            show_progress=False, queue=True
        )

        submit_button_sketcher.click(
            inference_traject,
            inputs=[
                sketcher_input, enable_wiki, language, sentiment, factuality, length, image_embedding, state,
                original_size, input_size, text_refiner
            ],
            outputs=[chatbot, state, sketcher_input, wiki_output],
            show_progress=False, queue=True
        )

    return iface


if __name__ == '__main__':
    iface = create_ui()
    iface.queue(concurrency_count=5, api_open=False, max_size=10)
    iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=args.gradio_share)
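The Trajectory tab above reduces a drawn sketch to box prompts via get_sketch_prompt. A minimal sketch of the single-box case that inference_traject uses (multi_mask=False), assuming only NumPy; the helper name is ours, not the repo's:

import numpy as np

def sketch_to_box_prompt(mask: np.ndarray) -> dict:
    # Bound every nonzero sketch pixel with one box, as get_sketch_prompt does
    # when multi_mask=False.
    ys, xs = np.where(mask > 0)
    x1, y1, x2, y2 = xs.min(), ys.min(), xs.max(), ys.max()
    return {"prompt_type": ["box"], "input_boxes": [[int(x1), int(y1), int(x2), int(y2)]]}

demo = np.zeros((8, 8), dtype=np.uint8)
demo[2:5, 3:7] = 1
print(sketch_to_box_prompt(demo))  # {'prompt_type': ['box'], 'input_boxes': [[3, 2, 6, 4]]}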
caas.py
DELETED
@@ -1,114 +0,0 @@
from captioner import build_captioner, BaseCaptioner
from segmenter import build_segmenter
from text_refiner import build_text_refiner
import os
import argparse
import pdb
import time
from PIL import Image

class CaptionAnything():
    def __init__(self, args):
        self.args = args
        self.captioner = build_captioner(args.captioner, args.device, args)
        self.segmenter = build_segmenter(args.segmenter, args.device, args)
        if not args.disable_gpt:
            self.init_refiner()


    def init_refiner(self):
        if os.environ.get('OPENAI_API_KEY', None):
            self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args)

    def inference(self, image, prompt, controls, disable_gpt=False):
        # segment with prompt
        print("CA prompt: ", prompt, "CA controls",controls)
        seg_mask = self.segmenter.inference(image, prompt)[0, ...]
        mask_save_path = f'result/mask_{time.time()}.png'
        if not os.path.exists(os.path.dirname(mask_save_path)):
            os.makedirs(os.path.dirname(mask_save_path))
        new_p = Image.fromarray(seg_mask.astype('int') * 255.)
        if new_p.mode != 'RGB':
            new_p = new_p.convert('RGB')
        new_p.save(mask_save_path)
        print('seg_mask path: ', mask_save_path)
        print("seg_mask.shape: ", seg_mask.shape)
        # captioning with mask
        if self.args.enable_reduce_tokens:
            caption, crop_save_path = self.captioner.inference_with_reduced_tokens(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, regular_box = self.args.regular_box)
        else:
            caption, crop_save_path = self.captioner.inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, regular_box = self.args.regular_box)
        # refining with TextRefiner
        context_captions = []
        if self.args.context_captions:
            context_captions.append(self.captioner.inference(image))
        if not disable_gpt and hasattr(self, "text_refiner"):
            refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions)
        else:
            refined_caption = {'raw_caption': caption}
        out = {'generated_captions': refined_caption,
               'crop_save_path': crop_save_path,
               'mask_save_path': mask_save_path,
               'context_captions': context_captions}
        return out

def parse_augment():
    parser = argparse.ArgumentParser()
    parser.add_argument('--captioner', type=str, default="blip")
    parser.add_argument('--segmenter', type=str, default="base")
    parser.add_argument('--text_refiner', type=str, default="base")
    parser.add_argument('--segmenter_checkpoint', type=str, default="segmenter/sam_vit_h_4b8939.pth")
    parser.add_argument('--seg_crop_mode', type=str, default="w_bg", choices=['wo_bg', 'w_bg'], help="whether to add or remove background of the image when captioning")
    parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
    parser.add_argument('--context_captions', action="store_true", help="use surrounding captions to enhance current caption")
    parser.add_argument('--regular_box', action="store_true", default = False, help="crop image with a regular box")
    parser.add_argument('--device', type=str, default="cuda:0")
    parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")
    parser.add_argument('--debug', action="store_true")
    parser.add_argument('--gradio_share', action="store_true")
    parser.add_argument('--disable_gpt', action="store_true")
    parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
    parser.add_argument('--disable_reuse_features', action="store_true", default=False)
    args = parser.parse_args()

    if args.debug:
        print(args)
    return args

if __name__ == "__main__":
    args = parse_augment()
    # image_path = 'test_img/img3.jpg'
    image_path = 'test_img/img13.jpg'
    prompts = [
        {
            "prompt_type":["click"],
            "input_point":[[500, 300], [1000, 500]],
            "input_label":[1, 0],
            "multimask_output":"True",
        },
        {
            "prompt_type":["click"],
            "input_point":[[900, 800]],
            "input_label":[1],
            "multimask_output":"True",
        }
    ]
    controls = {
        "length": "30",
        "sentiment": "positive",
        # "imagination": "True",
        "imagination": "False",
        "language": "English",
    }

    model = CaptionAnything(args)
    for prompt in prompts:
        print('*'*30)
        print('Image path: ', image_path)
        image = Image.open(image_path)
        print(image)
        print('Visual controls (SAM prompt):\n', prompt)
        print('Language controls:\n', controls)
        out = model.inference(image_path, prompt, controls)
caption_anything/__init__.py
ADDED
File without changes
{captioner → caption_anything/captioner}/README.md
RENAMED
File without changes
{captioner → caption_anything/captioner}/__init__.py
RENAMED
File without changes
{captioner → caption_anything/captioner}/base_captioner.py
RENAMED
@@ -191,7 +191,7 @@ class BaseCaptioner:
 
 if __name__ == '__main__':
     model = BaseCaptioner(device='cuda:0')
-    image_path = '
+    image_path = 'test_images/img2.jpg'
     seg_mask = np.zeros((15,15))
     seg_mask[5:10, 5:10] = 1
     seg_mask = 'image/SAM/img10.jpg.raw_mask.png'
{captioner → caption_anything/captioner}/blip.py
RENAMED
@@ -54,13 +54,13 @@ class BLIPCaptioner(BaseCaptioner):
 
 if __name__ == '__main__':
     model = BLIPCaptioner(device='cuda:0')
-    # image_path = '
-    image_path = '
+    # image_path = 'test_images/img2.jpg'
+    image_path = 'image/SAM/img10.jpg'
     seg_mask = np.zeros((15,15))
     seg_mask[5:10, 5:10] = 1
-    seg_mask = '
-    image_path = '
-    seg_mask = '
+    seg_mask = 'test_images/img10.jpg.raw_mask.png'
+    image_path = 'test_images/img2.jpg'
+    seg_mask = 'test_images/img2.jpg.raw_mask.png'
     print(f'process image {image_path}')
     print(model.inference_with_reduced_tokens(image_path, seg_mask))
 
{captioner → caption_anything/captioner}/blip2.py
RENAMED
@@ -1,13 +1,10 @@
 import torch
-from PIL import Image
-from transformers import AutoProcessor, Blip2ForConditionalGeneration
-import json
-import pdb
-import cv2
+from PIL import Image
 import numpy as np
 from typing import Union
+from transformers import AutoProcessor, Blip2ForConditionalGeneration
 
-from 
+from caption_anything.utils.utils import is_platform_win
 from .base_captioner import BaseCaptioner
 
 class BLIP2Captioner(BaseCaptioner):
@@ -55,7 +52,7 @@ if __name__ == '__main__':
 
     dialogue = False
     model = BLIP2Captioner(device='cuda:4', dialogue = dialogue, cache_dir = '/nvme-ssd/fjj/Caption-Anything/model_cache')
-    image_path = '
+    image_path = 'test_images/img2.jpg'
     seg_mask = np.zeros((224,224))
     seg_mask[50:200, 50:200] = 1
     print(f'process image {image_path}')
{captioner → caption_anything/captioner}/git.py
RENAMED
@@ -50,7 +50,7 @@ class GITCaptioner(BaseCaptioner):
 
 if __name__ == '__main__':
     model = GITCaptioner(device='cuda:2', enable_filter=False)
-    image_path = '
+    image_path = 'test_images/img2.jpg'
     seg_mask = np.zeros((224,224))
     seg_mask[50:200, 50:200] = 1
     print(f'process image {image_path}')
{captioner → caption_anything/captioner}/modeling_blip.py
RENAMED
File without changes
{captioner → caption_anything/captioner}/modeling_git.py
RENAMED
File without changes
{captioner → caption_anything/captioner}/vit_pixel_masks_utils.py
RENAMED
File without changes
caption_anything.py → caption_anything/model.py
RENAMED
@@ -1,6 +1,3 @@
-from captioner import build_captioner, BaseCaptioner
-from segmenter import build_segmenter
-from text_refiner import build_text_refiner
 import os
 import argparse
 import pdb
@@ -8,13 +5,17 @@ import time
 from PIL import Image
 import cv2
 import numpy as np
+from caption_anything.captioner import build_captioner, BaseCaptioner
+from caption_anything.segmenter import build_segmenter
+from caption_anything.text_refiner import build_text_refiner
 
-
+
+class CaptionAnything:
     def __init__(self, args, api_key="", captioner=None, segmenter=None, text_refiner=None):
         self.args = args
         self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
         self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
-
+
         self.text_refiner = None
         if not args.disable_gpt:
             if text_refiner is not None:
@@ -22,24 +23,54 @@ class CaptionAnything():
             else:
                 self.init_refiner(api_key)
 
+    @property
+    def image_embedding(self):
+        return self.segmenter.image_embedding
+
+    @image_embedding.setter
+    def image_embedding(self, image_embedding):
+        self.segmenter.image_embedding = image_embedding
+
+    @property
+    def original_size(self):
+        return self.segmenter.predictor.original_size
+
+    @original_size.setter
+    def original_size(self, original_size):
+        self.segmenter.predictor.original_size = original_size
+
+    @property
+    def input_size(self):
+        return self.segmenter.predictor.input_size
+
+    @input_size.setter
+    def input_size(self, input_size):
+        self.segmenter.predictor.input_size = input_size
+
+    def setup(self, image_embedding, original_size, input_size, is_image_set):
+        self.image_embedding = image_embedding
+        self.original_size = original_size
+        self.input_size = input_size
+        self.segmenter.predictor.is_image_set = is_image_set
+
     def init_refiner(self, api_key):
         try:
             self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args, api_key)
-            self.text_refiner.llm('hi')
+            self.text_refiner.llm('hi')  # test
         except:
             self.text_refiner = None
             print('OpenAI GPT is not available')
-
+
     def inference(self, image, prompt, controls, disable_gpt=False, enable_wiki=False):
         # segment with prompt
-        print("CA prompt: ", prompt, "CA controls",controls)
+        print("CA prompt: ", prompt, "CA controls", controls)
         seg_mask = self.segmenter.inference(image, prompt)[0, ...]
         if self.args.enable_morphologyex:
             seg_mask = 255 * seg_mask.astype(np.uint8)
-            seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis
-            seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel
-            seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel
-            seg_mask = seg_mask[
+            seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis=-1)
+            seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel=np.ones((6, 6), np.uint8))
+            seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel=np.ones((6, 6), np.uint8))
+            seg_mask = seg_mask[:, :, 0] > 0
         mask_save_path = f'result/mask_{time.time()}.png'
         if not os.path.exists(os.path.dirname(mask_save_path)):
             os.makedirs(os.path.dirname(mask_save_path))
@@ -51,82 +82,66 @@ class CaptionAnything():
         print("seg_mask.shape: ", seg_mask.shape)
         # captioning with mask
         if self.args.enable_reduce_tokens:
-            caption, crop_save_path = self.captioner.
+            caption, crop_save_path = self.captioner. \
+                inference_with_reduced_tokens(image, seg_mask,
+                                              crop_mode=self.args.seg_crop_mode,
+                                              filter=self.args.clip_filter,
+                                              disable_regular_box=self.args.disable_regular_box)
         else:
-            caption, crop_save_path = self.captioner.
+            caption, crop_save_path = self.captioner. \
+                inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode,
+                              filter=self.args.clip_filter,
+                              disable_regular_box=self.args.disable_regular_box)
         # refining with TextRefiner
         context_captions = []
         if self.args.context_captions:
            context_captions.append(self.captioner.inference(image))
        if not disable_gpt and self.text_refiner is not None:
-            refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions,
+            refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions,
+                                                          enable_wiki=enable_wiki)
        else:
-            refined_caption = {'raw_caption': caption}
+            refined_caption = {'raw_caption': caption}
        out = {'generated_captions': refined_caption,
-
-
-
-
+               'crop_save_path': crop_save_path,
+               'mask_save_path': mask_save_path,
+               'mask': seg_mask_img,
+               'context_captions': context_captions}
        return out
-
-def parse_augment():
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--captioner', type=str, default="blip2")
-    parser.add_argument('--segmenter', type=str, default="huge")
-    parser.add_argument('--text_refiner', type=str, default="base")
-    parser.add_argument('--segmenter_checkpoint', type=str, default="segmenter/sam_vit_h_4b8939.pth")
-    parser.add_argument('--seg_crop_mode', type=str, default="wo_bg", choices=['wo_bg', 'w_bg'], help="whether to add or remove background of the image when captioning")
-    parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
-    parser.add_argument('--context_captions', action="store_true", help="use surrounding captions to enhance current caption (TODO)")
-    parser.add_argument('--disable_regular_box', action="store_true", default = False, help="crop image with a regular box")
-    parser.add_argument('--device', type=str, default="cuda:0")
-    parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")
-    parser.add_argument('--debug', action="store_true")
-    parser.add_argument('--gradio_share', action="store_true")
-    parser.add_argument('--disable_gpt', action="store_true")
-    parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
-    parser.add_argument('--disable_reuse_features', action="store_true", default=False)
-    parser.add_argument('--enable_morphologyex', action="store_true", default=False)
-    args = parser.parse_args()
 
-    if args.debug:
-        print(args)
-    return args
 
 if __name__ == "__main__":
+    from caption_anything.utils.parser import parse_augment
     args = parse_augment()
-    # image_path = '
-    image_path = '
+    # image_path = 'test_images/img3.jpg'
+    image_path = 'test_images/img1.jpg'
     prompts = [
         {
-            "prompt_type":["click"],
-            "input_point":[[500, 300], [
-            "input_label":[1, 0],
-            "multimask_output":"True",
+            "prompt_type": ["click"],
+            "input_point": [[500, 300], [200, 500]],
+            "input_label": [1, 0],
+            "multimask_output": "True",
         },
         {
-            "prompt_type":["click"],
-            "input_point":[[
-            "input_label":[1],
-            "multimask_output":"True",
+            "prompt_type": ["click"],
+            "input_point": [[300, 800]],
+            "input_label": [1],
+            "multimask_output": "True",
         }
     ]
     controls = {
-
-
-
-
-
-
-
+        "length": "30",
+        "sentiment": "positive",
+        # "imagination": "True",
+        "imagination": "False",
+        "language": "English",
+    }
+
     model = CaptionAnything(args, os.environ['OPENAI_API_KEY'])
     for prompt in prompts:
-        print('*'*30)
+        print('*' * 30)
         print('Image path: ', image_path)
         image = Image.open(image_path)
         print(image)
         print('Visual controls (SAM prompt):\n', prompt)
         print('Language controls:\n', controls)
         out = model.inference(image_path, prompt, controls)
-
-
caption_anything/segmenter/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .base_segmenter import BaseSegmenter
|
2 |
+
from caption_anything.utils.utils import seg_model_map
|
3 |
+
|
4 |
+
def build_segmenter(model_name, device, args=None, model=None):
|
5 |
+
return BaseSegmenter(device, args.segmenter_checkpoint, model_name, reuse_feature=not args.disable_reuse_features, model=model)
|
{segmenter β caption_anything/segmenter}/base_segmenter.py
RENAMED
@@ -5,19 +5,22 @@ from PIL import Image, ImageDraw, ImageOps
|
|
5 |
import numpy as np
|
6 |
from typing import Union
|
7 |
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
|
|
|
8 |
import matplotlib.pyplot as plt
|
9 |
import PIL
|
10 |
|
|
|
11 |
class BaseSegmenter:
|
12 |
-
def __init__(self, device, checkpoint,
|
13 |
print(f"Initializing BaseSegmenter to {device}")
|
14 |
self.device = device
|
15 |
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
16 |
self.processor = None
|
17 |
-
self.model_type = model_type
|
18 |
if model is None:
|
|
|
|
|
|
|
19 |
self.checkpoint = checkpoint
|
20 |
-
self.model = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
|
21 |
self.model.to(device=self.device)
|
22 |
else:
|
23 |
self.model = model
|
@@ -27,26 +30,57 @@ class BaseSegmenter:
|
|
27 |
self.image_embedding = None
|
28 |
self.image = None
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
def set_image(self, image: Union[np.ndarray, Image.Image, str]):
|
33 |
-
if type(image) == str: # input path
|
34 |
image = Image.open(image)
|
35 |
image = np.array(image)
|
36 |
elif type(image) == Image.Image:
|
37 |
image = np.array(image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
self.image = image
|
39 |
if self.reuse_feature:
|
40 |
self.predictor.set_image(image)
|
41 |
self.image_embedding = self.predictor.get_image_embedding()
|
42 |
print(self.image_embedding.shape)
|
43 |
|
44 |
-
|
45 |
@torch.no_grad()
|
46 |
-
def inference(self, image, control):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
if 'everything' in control['prompt_type']:
|
48 |
masks = self.mask_generator.generate(image)
|
49 |
-
new_masks = np.concatenate([mask["segmentation"][np.newaxis
|
50 |
return new_masks
|
51 |
else:
|
52 |
if not self.reuse_feature or self.image_embedding is None:
|
@@ -55,17 +89,17 @@ class BaseSegmenter:
|
|
55 |
else:
|
56 |
assert self.image_embedding is not None
|
57 |
self.predictor.features = self.image_embedding
|
58 |
-
|
59 |
if 'mutimask_output' in control:
|
60 |
masks, scores, logits = self.predictor.predict(
|
61 |
-
point_coords
|
62 |
-
point_labels
|
63 |
-
multimask_output
|
64 |
)
|
65 |
elif 'input_boxes' in control:
|
66 |
transformed_boxes = self.predictor.transform.apply_boxes_torch(
|
67 |
torch.tensor(control["input_boxes"], device=self.predictor.device),
|
68 |
-
image.shape[
|
69 |
)
|
70 |
masks, _, _ = self.predictor.predict_torch(
|
71 |
point_coords=None,
|
@@ -74,31 +108,32 @@ class BaseSegmenter:
|
|
74 |
multimask_output=False,
|
75 |
)
|
76 |
masks = masks.squeeze(1).cpu().numpy()
|
77 |
-
|
78 |
else:
|
79 |
input_point = np.array(control['input_point']) if 'click' in control['prompt_type'] else None
|
80 |
input_label = np.array(control['input_label']) if 'click' in control['prompt_type'] else None
|
81 |
input_box = np.array(control['input_box']) if 'box' in control['prompt_type'] else None
|
82 |
-
|
83 |
masks, scores, logits = self.predictor.predict(
|
84 |
-
point_coords
|
85 |
-
point_labels
|
86 |
-
box
|
87 |
-
multimask_output
|
88 |
)
|
89 |
-
|
90 |
if 0 in control['input_label']:
|
91 |
mask_input = logits[np.argmax(scores), :, :]
|
92 |
masks, scores, logits = self.predictor.predict(
|
93 |
point_coords=input_point,
|
94 |
point_labels=input_label,
|
95 |
-
box
|
96 |
mask_input=mask_input[None, :, :],
|
97 |
multimask_output=False,
|
98 |
)
|
99 |
-
|
100 |
return masks
|
101 |
|
|
|
102 |
if __name__ == "__main__":
|
103 |
image_path = 'segmenter/images/truck.jpg'
|
104 |
prompts = [
|
@@ -109,9 +144,9 @@ if __name__ == "__main__":
|
|
109 |
# "multimask_output":"True",
|
110 |
# },
|
111 |
{
|
112 |
-
"prompt_type":["click"],
|
113 |
-
"input_point":[[1000, 600], [1325, 625]],
|
114 |
-
"input_label":[1, 0],
|
115 |
},
|
116 |
# {
|
117 |
# "prompt_type":["click", "box"],
|
@@ -132,7 +167,7 @@ if __name__ == "__main__":
|
|
132 |
# "prompt_type":["everything"]
|
133 |
# },
|
134 |
]
|
135 |
-
|
136 |
init_time = time.time()
|
137 |
segmenter = BaseSegmenter(
|
138 |
device='cuda',
|
@@ -142,8 +177,8 @@ if __name__ == "__main__":
|
|
142 |
reuse_feature=True
|
143 |
)
|
144 |
print(f'init time: {time.time() - init_time}')
|
145 |
-
|
146 |
-
image_path = '
|
147 |
infer_time = time.time()
|
148 |
for i, prompt in enumerate(prompts):
|
149 |
print(f'{prompt["prompt_type"]} mode')
|
@@ -152,5 +187,5 @@ if __name__ == "__main__":
|
|
152 |
masks = segmenter.inference(np.array(image), prompt)
|
153 |
Image.fromarray(masks[0]).save('seg.png')
|
154 |
print(masks.shape)
|
155 |
-
|
156 |
print(f'infer time: {time.time() - infer_time}')
|
|
|
5 |
import numpy as np
|
6 |
from typing import Union
|
7 |
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
|
8 |
+
from caption_anything.utils.utils import prepare_segmenter, seg_model_map
|
9 |
import matplotlib.pyplot as plt
|
10 |
import PIL
|
11 |
|
12 |
+
|
13 |
class BaseSegmenter:
|
14 |
+
def __init__(self, device, checkpoint, model_name='huge', reuse_feature=True, model=None):
|
15 |
print(f"Initializing BaseSegmenter to {device}")
|
16 |
self.device = device
|
17 |
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
18 |
self.processor = None
|
|
|
19 |
if model is None:
|
20 |
+
if checkpoint is None:
|
21 |
+
_, checkpoint = prepare_segmenter(model_name)
|
22 |
+
self.model = sam_model_registry[seg_model_map[model_name]](checkpoint=checkpoint)
|
23 |
self.checkpoint = checkpoint
|
|
|
24 |
self.model.to(device=self.device)
|
25 |
else:
|
26 |
self.model = model
|
|
|
30 |
self.image_embedding = None
|
31 |
self.image = None
|
32 |
|
33 |
+
def read_image(self, image: Union[np.ndarray, Image.Image, str]):
|
34 |
+
if type(image) == str: # input path
|
|
|
|
|
35 |
image = Image.open(image)
|
36 |
image = np.array(image)
|
37 |
elif type(image) == Image.Image:
|
38 |
image = np.array(image)
|
39 |
+
elif type(image) == np.ndarray:
|
40 |
+
image = image
|
41 |
+
else:
|
42 |
+
raise TypeError
|
43 |
+
return image
|
44 |
+
|
45 |
+
@torch.no_grad()
|
46 |
+
def set_image(self, image: Union[np.ndarray, Image.Image, str]):
|
47 |
+
image = self.read_image(image)
|
48 |
self.image = image
|
49 |
if self.reuse_feature:
|
50 |
self.predictor.set_image(image)
|
51 |
self.image_embedding = self.predictor.get_image_embedding()
|
52 |
print(self.image_embedding.shape)
|
53 |
|
|
|
54 |
@torch.no_grad()
|
55 |
+
def inference(self, image: Union[np.ndarray, Image.Image, str], control: dict):
|
56 |
+
"""
|
57 |
+
SAM inference of image according to control.
|
58 |
+
Args:
|
59 |
+
image: str or PIL.Image or np.ndarray
|
60 |
+
control:
|
61 |
+
prompt_type:
|
62 |
+
1. {control['prompt_type'] = ['everything']} to segment everything in the image.
|
63 |
+
2. {control['prompt_type'] = ['click', 'box']} to segment according to click and box.
|
64 |
+
3. {control['prompt_type'] = ['click'] to segment according to click.
|
65 |
+
4. {control['prompt_type'] = ['box'] to segment according to box.
|
66 |
+
input_point: list of [x, y] coordinates of click.
|
67 |
+
input_label: List of labels for points accordingly, 0 for negative, 1 for positive.
|
68 |
+
input_box: List of [x1, y1, x2, y2] coordinates of box.
|
69 |
+
multimask_output:
|
70 |
+
If true, the model will return three masks.
|
71 |
+
For ambiguous input prompts (such as a single click), this will often
|
72 |
+
produce better masks than a single prediction. If only a single
|
73 |
+
mask is needed, the model's predicted quality score can be used
|
74 |
+
to select the best mask. For non-ambiguous prompts, such as multiple
|
75 |
+
input prompts, multimask_output=False can give better results.
|
76 |
+
Returns:
|
77 |
+
masks: np.ndarray of shape [num_masks, height, width]
|
78 |
+
|
79 |
+
"""
|
80 |
+
image = self.read_image(image) # Turn image into np.ndarray
|
81 |
if 'everything' in control['prompt_type']:
|
82 |
masks = self.mask_generator.generate(image)
|
83 |
+
new_masks = np.concatenate([mask["segmentation"][np.newaxis, :] for mask in masks])
|
84 |
return new_masks
|
85 |
else:
|
86 |
if not self.reuse_feature or self.image_embedding is None:
|
|
|
89 |
else:
|
90 |
assert self.image_embedding is not None
|
91 |
self.predictor.features = self.image_embedding
|
92 |
+
|
93 |
if 'mutimask_output' in control:
|
94 |
masks, scores, logits = self.predictor.predict(
|
95 |
+
point_coords=np.array(control['input_point']),
|
96 |
+
point_labels=np.array(control['input_label']),
|
97 |
+
multimask_output=True,
|
98 |
)
|
99 |
elif 'input_boxes' in control:
|
100 |
transformed_boxes = self.predictor.transform.apply_boxes_torch(
|
101 |
torch.tensor(control["input_boxes"], device=self.predictor.device),
|
102 |
+
image.shape[1::-1] # Reverse shape because numpy is (W, H) and function need (H, W)
|
103 |
)
|
104 |
masks, _, _ = self.predictor.predict_torch(
|
105 |
point_coords=None,
|
|
|
108 |
multimask_output=False,
|
109 |
)
|
110 |
masks = masks.squeeze(1).cpu().numpy()
|
111 |
+
|
112 |
else:
|
113 |
input_point = np.array(control['input_point']) if 'click' in control['prompt_type'] else None
|
114 |
input_label = np.array(control['input_label']) if 'click' in control['prompt_type'] else None
|
115 |
input_box = np.array(control['input_box']) if 'box' in control['prompt_type'] else None
|
116 |
+
|
117 |
masks, scores, logits = self.predictor.predict(
|
118 |
+
point_coords=input_point,
|
119 |
+
point_labels=input_label,
|
120 |
+
box=input_box,
|
121 |
+
multimask_output=False,
|
122 |
)
|
123 |
+
|
124 |
if 0 in control['input_label']:
|
125 |
mask_input = logits[np.argmax(scores), :, :]
|
126 |
masks, scores, logits = self.predictor.predict(
|
127 |
point_coords=input_point,
|
128 |
point_labels=input_label,
|
129 |
+
box=input_box,
|
130 |
mask_input=mask_input[None, :, :],
|
131 |
multimask_output=False,
|
132 |
)
|
133 |
+
|
134 |
return masks
|
135 |
|
136 |
+
|
137 |
if __name__ == "__main__":
|
138 |
image_path = 'segmenter/images/truck.jpg'
|
139 |
prompts = [
|
|
|
144 |
# "multimask_output":"True",
|
145 |
# },
|
146 |
{
|
147 |
+
"prompt_type": ["click"],
|
148 |
+
"input_point": [[1000, 600], [1325, 625]],
|
149 |
+
"input_label": [1, 0],
|
150 |
},
|
151 |
# {
|
152 |
# "prompt_type":["click", "box"],
|
|
|
167 |
# "prompt_type":["everything"]
|
168 |
# },
|
169 |
]
|
170 |
+
|
171 |
init_time = time.time()
|
172 |
segmenter = BaseSegmenter(
|
173 |
device='cuda',
|
|
|
177 |
reuse_feature=True
|
178 |
)
|
179 |
print(f'init time: {time.time() - init_time}')
|
180 |
+
|
181 |
+
image_path = 'test_images/img2.jpg'
|
182 |
infer_time = time.time()
|
183 |
for i, prompt in enumerate(prompts):
|
184 |
print(f'{prompt["prompt_type"]} mode')
|
|
|
187 |
masks = segmenter.inference(np.array(image), prompt)
|
188 |
Image.fromarray(masks[0]).save('seg.png')
|
189 |
print(masks.shape)
|
190 |
+
|
191 |
print(f'infer time: {time.time() - infer_time}')
|
{segmenter β caption_anything/segmenter}/readme.md
RENAMED
File without changes
|
{text_refiner β caption_anything/text_refiner}/README.md
RENAMED
File without changes
|
{text_refiner β caption_anything/text_refiner}/__init__.py
RENAMED
@@ -1,4 +1,4 @@
|
|
1 |
-
from
|
2 |
|
3 |
|
4 |
def build_text_refiner(type, device, args=None, api_key=""):
|
|
|
1 |
+
from .text_refiner import TextRefiner
|
2 |
|
3 |
|
4 |
def build_text_refiner(type, device, args=None, api_key=""):
|
{text_refiner β caption_anything/text_refiner}/text_refiner.py
RENAMED
File without changes
|
caption_anything/utils/chatbot.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Microsoft
|
2 |
+
# Modified from Visual ChatGPT Project https://github.com/microsoft/TaskMatrix/blob/main/visual_chatgpt.py
|
3 |
+
|
4 |
+
import os
|
5 |
+
import gradio as gr
|
6 |
+
import re
|
7 |
+
import uuid
|
8 |
+
from PIL import Image, ImageDraw, ImageOps
|
9 |
+
import numpy as np
|
10 |
+
import argparse
|
11 |
+
import inspect
|
12 |
+
|
13 |
+
from langchain.agents.initialize import initialize_agent
|
14 |
+
from langchain.agents.tools import Tool
|
15 |
+
from langchain.chains.conversation.memory import ConversationBufferMemory
|
16 |
+
from langchain.llms.openai import OpenAI
|
17 |
+
import torch
|
18 |
+
from PIL import Image, ImageDraw, ImageOps
|
19 |
+
from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
|
20 |
+
|
21 |
+
VISUAL_CHATGPT_PREFIX = """
|
22 |
+
Caption Anything Chatbox (short as CATchat) is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. CATchat is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
|
23 |
+
|
24 |
+
As a language model, CATchat can not directly read images, but it has a list of tools to finish different visual tasks. CATchat can invoke different tools to indirectly understand pictures.
|
25 |
+
|
26 |
+
Visual ChatGPT has access to the following tools:"""
|
27 |
+
|
28 |
+
|
29 |
+
# VISUAL_CHATGPT_PREFIX = """Visual ChatGPT is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. Visual ChatGPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
|
30 |
+
|
31 |
+
# Visual ChatGPT is able to process and understand large amounts of text and images. As a language model, Visual ChatGPT can not directly read images, but it has a list of tools to finish different visual tasks. Each image will have a file name formed as "chat_image/xxx.png", and Visual ChatGPT can invoke different tools to indirectly understand pictures. When talking about images, Visual ChatGPT is very strict to the file name and will never fabricate nonexistent files. Visual ChatGPT is able to use tools in a sequence, and is loyal to the tool observation outputs rather than faking the image content and image file name.
|
32 |
+
|
33 |
+
# Visual ChatGPT is aware of the coordinate of an object in the image, which is represented as a point (X, Y) on the object. Note that (0, 0) represents the bottom-left corner of the image.
|
34 |
+
|
35 |
+
# Human may provide new figures to Visual ChatGPT with a description. The description helps Visual ChatGPT to understand this image, but Visual ChatGPT should use tools to finish following tasks, rather than directly imagine from the description.
|
36 |
+
|
37 |
+
# Overall, Visual ChatGPT is a powerful visual dialogue assistant tool that can help with a wide range of tasks and provide valuable insights and information on a wide range of topics.
|
38 |
+
|
39 |
+
|
40 |
+
# TOOLS:
|
41 |
+
# ------
|
42 |
+
|
43 |
+
# Visual ChatGPT has access to the following tools:"""
|
44 |
+
|
45 |
+
VISUAL_CHATGPT_FORMAT_INSTRUCTIONS = """To use a tool, please use the following format:
|
46 |
+
|
47 |
+
"Thought: Do I need to use a tool? Yes
|
48 |
+
Action: the action to take, should be one of [{tool_names}], remember the action must to be one tool
|
49 |
+
Action Input: the input to the action
|
50 |
+
Observation: the result of the action"
|
51 |
+
|
52 |
+
When you have a response to say to the Human, or if you do not need to use a tool, you MUST use the format:
|
53 |
+
|
54 |
+
"Thought: Do I need to use a tool? No
|
55 |
+
{ai_prefix}: [your response here]"
|
56 |
+
|
57 |
+
"""
|
58 |
+
|
59 |
+
VISUAL_CHATGPT_SUFFIX = """
|
60 |
+
Begin Chatting!
|
61 |
+
|
62 |
+
Previous conversation history:
|
63 |
+
{chat_history}
|
64 |
+
|
65 |
+
New input: {input}
|
66 |
+
Since CATchat is a text language model, CATchat must use tools iteratively to observe images rather than imagination.
|
67 |
+
The thoughts and observations are only visible for CATchat, CATchat should remember to repeat important information in the final response for Human.
|
68 |
+
|
69 |
+
Thought: Do I need to use a tool? {agent_scratchpad} (You are strictly to use the aforementioned "Thought/Action/Action Input/Observation" format as the answer.)"""
|
70 |
+
|
71 |
+
os.makedirs('chat_image', exist_ok=True)
|
72 |
+
|
73 |
+
|
74 |
+
def prompts(name, description):
|
75 |
+
def decorator(func):
|
76 |
+
func.name = name
|
77 |
+
func.description = description
|
78 |
+
return func
|
79 |
+
return decorator
|
80 |
+
|
81 |
+
def cut_dialogue_history(history_memory, keep_last_n_words=500):
|
82 |
+
if history_memory is None or len(history_memory) == 0:
|
83 |
+
return history_memory
|
84 |
+
tokens = history_memory.split()
|
85 |
+
n_tokens = len(tokens)
|
86 |
+
print(f"history_memory:{history_memory}, n_tokens: {n_tokens}")
|
87 |
+
if n_tokens < keep_last_n_words:
|
88 |
+
return history_memory
|
89 |
+
paragraphs = history_memory.split('\n')
|
90 |
+
last_n_tokens = n_tokens
|
91 |
+
while last_n_tokens >= keep_last_n_words:
|
92 |
+
last_n_tokens -= len(paragraphs[0].split(' '))
|
93 |
+
paragraphs = paragraphs[1:]
|
94 |
+
return '\n' + '\n'.join(paragraphs)
|
95 |
+
|
96 |
+
def get_new_image_name(folder='chat_image', func_name="update"):
|
97 |
+
this_new_uuid = str(uuid.uuid4())[:8]
|
98 |
+
new_file_name = f'{func_name}_{this_new_uuid}.png'
|
99 |
+
return os.path.join(folder, new_file_name)
|
100 |
+
|
101 |
+
class VisualQuestionAnswering:
|
102 |
+
def __init__(self, device):
|
103 |
+
print(f"Initializing VisualQuestionAnswering to {device}")
|
104 |
+
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
105 |
+
self.device = device
|
106 |
+
self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
|
107 |
+
self.model = BlipForQuestionAnswering.from_pretrained(
|
108 |
+
"Salesforce/blip-vqa-base", torch_dtype=self.torch_dtype).to(self.device)
|
109 |
+
# self.processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
|
110 |
+
# self.model = BlipForQuestionAnswering.from_pretrained(
|
111 |
+
# "Salesforce/blip-vqa-capfilt-large", torch_dtype=self.torch_dtype).to(self.device)
|
112 |
+
|
113 |
+
@prompts(name="Answer Question About The Image",
|
114 |
+
description="useful when you need an answer for a question based on an image. "
|
115 |
+
"like: what is the background color of the last image, how many cats in this figure, what is in this figure. "
|
116 |
+
"The input to this tool should be a comma separated string of two, representing the image_path and the question")
|
117 |
+
def inference(self, inputs):
|
118 |
+
image_path, question = inputs.split(",")[0], ','.join(inputs.split(',')[1:])
|
119 |
+
raw_image = Image.open(image_path).convert('RGB')
|
120 |
+
inputs = self.processor(raw_image, question, return_tensors="pt").to(self.device, self.torch_dtype)
|
121 |
+
out = self.model.generate(**inputs)
|
122 |
+
answer = self.processor.decode(out[0], skip_special_tokens=True)
|
123 |
+
print(f"\nProcessed VisualQuestionAnswering, Input Image: {image_path}, Input Question: {question}, "
|
124 |
+
f"Output Answer: {answer}")
|
125 |
+
return answer
|
126 |
+
|
127 |
+
def build_chatbot_tools(load_dict):
|
128 |
+
print(f"Initializing ChatBot, load_dict={load_dict}")
|
129 |
+
models = {}
|
130 |
+
# Load Basic Foundation Models
|
131 |
+
for class_name, device in load_dict.items():
|
132 |
+
models[class_name] = globals()[class_name](device=device)
|
133 |
+
|
134 |
+
# Load Template Foundation Models
|
135 |
+
for class_name, module in globals().items():
|
136 |
+
if getattr(module, 'template_model', False):
|
137 |
+
template_required_names = {k for k in inspect.signature(module.__init__).parameters.keys() if k!='self'}
|
138 |
+
loaded_names = set([type(e).__name__ for e in models.values()])
|
139 |
+
if template_required_names.issubset(loaded_names):
|
140 |
+
models[class_name] = globals()[class_name](
|
141 |
+
**{name: models[name] for name in template_required_names})
|
142 |
+
|
143 |
+
tools = []
|
144 |
+
for instance in models.values():
|
145 |
+
for e in dir(instance):
|
146 |
+
if e.startswith('inference'):
|
147 |
+
func = getattr(instance, e)
|
148 |
+
tools.append(Tool(name=func.name, description=func.description, func=func))
|
149 |
+
return tools
|
150 |
+
|
151 |
+
class ConversationBot:
|
152 |
+
def __init__(self, tools, api_key=""):
|
153 |
+
# load_dict = {'VisualQuestionAnswering':'cuda:0', 'ImageCaptioning':'cuda:1',...}
|
154 |
+
llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=api_key)
|
155 |
+
self.llm = llm
|
156 |
+
self.memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
|
157 |
+
self.tools = tools
|
158 |
+
self.current_image = None
|
159 |
+
self.point_prompt = ""
|
160 |
+
self.agent = initialize_agent(
|
161 |
+
self.tools,
|
162 |
+
self.llm,
|
163 |
+
agent="conversational-react-description",
|
164 |
+
verbose=True,
|
165 |
+
memory=self.memory,
|
166 |
+
return_intermediate_steps=True,
|
167 |
+
agent_kwargs={'prefix': VISUAL_CHATGPT_PREFIX, 'format_instructions': VISUAL_CHATGPT_FORMAT_INSTRUCTIONS,
|
168 |
+
'suffix': VISUAL_CHATGPT_SUFFIX}, )
|
169 |
+
|
170 |
+
def constructe_intermediate_steps(self, agent_res):
|
171 |
+
ans = []
|
172 |
+
for action, output in agent_res:
|
173 |
+
if hasattr(action, "tool_input"):
|
174 |
+
use_tool = "Yes"
|
175 |
+
act = (f"Thought: Do I need to use a tool? {use_tool}\nAction: {action.tool}\nAction Input: {action.tool_input}", f"Observation: {output}")
|
176 |
+
else:
|
177 |
+
use_tool = "No"
|
178 |
+
act = (f"Thought: Do I need to use a tool? {use_tool}", f"AI: {output}")
|
179 |
+
act= list(map(lambda x: x.replace('\n', '<br>'), act))
|
180 |
+
ans.append(act)
|
181 |
+
return ans
|
182 |
+
|
183 |
+
def run_text(self, text, state, aux_state):
|
184 |
+
self.agent.memory.buffer = cut_dialogue_history(self.agent.memory.buffer, keep_last_n_words=500)
|
185 |
+
if self.point_prompt != "":
|
186 |
+
Human_prompt = f'\nHuman: {self.point_prompt}\n'
|
187 |
+
AI_prompt = 'Ok'
|
188 |
+
self.agent.memory.buffer = self.agent.memory.buffer + Human_prompt + 'AI: ' + AI_prompt
|
189 |
+
self.point_prompt = ""
|
190 |
+
res = self.agent({"input": text})
|
191 |
+
res['output'] = res['output'].replace("\\", "/")
|
192 |
+
response = re.sub('(chat_image/\S*png)', lambda m: f'![](/file={m.group(0)})*{m.group(0)}*', res['output'])
|
193 |
+
state = state + [(text, response)]
|
194 |
+
|
195 |
+
aux_state = aux_state + [(f"User Input: {text}", None)]
|
196 |
+
aux_state = aux_state + self.constructe_intermediate_steps(res['intermediate_steps'])
|
197 |
+
print(f"\nProcessed run_text, Input text: {text}\nCurrent state: {state}\n"
|
198 |
+
f"Current Memory: {self.agent.memory.buffer}\n"
|
199 |
+
f"Aux state: {aux_state}\n"
|
200 |
+
)
|
201 |
+
return state, state, aux_state, aux_state
|
202 |
+
|
203 |
+
|
204 |
+
if __name__ == '__main__':
|
205 |
+
parser = argparse.ArgumentParser()
|
206 |
+
parser.add_argument('--load', type=str, default="VisualQuestionAnswering_cuda:0")
|
207 |
+
parser.add_argument('--port', type=int, default=1015)
|
208 |
+
|
209 |
+
args = parser.parse_args()
|
210 |
+
load_dict = {e.split('_')[0].strip(): e.split('_')[1].strip() for e in args.load.split(',')}
|
211 |
+
tools = build_chatbot_tools(load_dict)
|
212 |
+
bot = ConversationBot(tools)
|
213 |
+
with gr.Blocks(css="#chatbot .overflow-y-auto{height:500px}") as demo:
|
214 |
+
with gr.Row():
|
215 |
+
chatbot = gr.Chatbot(elem_id="chatbot", label="Visual ChatGPT").style(height=1000,scale=0.5)
|
216 |
+
auxwindow = gr.Chatbot(elem_id="chatbot", label="Aux Window").style(height=1000,scale=0.5)
|
217 |
+
state = gr.State([])
|
218 |
+
aux_state = gr.State([])
|
219 |
+
with gr.Row():
|
220 |
+
with gr.Column(scale=0.7):
|
221 |
+
txt = gr.Textbox(show_label=False, placeholder="Enter text and press enter, or upload an image").style(
|
222 |
+
container=False)
|
223 |
+
with gr.Column(scale=0.15, min_width=0):
|
224 |
+
clear = gr.Button("Clear")
|
225 |
+
with gr.Column(scale=0.15, min_width=0):
|
226 |
+
btn = gr.UploadButton("Upload", file_types=["image"])
|
227 |
+
|
228 |
+
txt.submit(bot.run_text, [txt, state, aux_state], [chatbot, state, aux_state, auxwindow])
|
229 |
+
txt.submit(lambda: "", None, txt)
|
230 |
+
btn.upload(bot.run_image, [btn, state, txt, aux_state], [chatbot, state, txt, aux_state, auxwindow])
|
231 |
+
clear.click(bot.memory.clear)
|
232 |
+
clear.click(lambda: [], None, chatbot)
|
233 |
+
clear.click(lambda: [], None, auxwindow)
|
234 |
+
clear.click(lambda: [], None, state)
|
235 |
+
clear.click(lambda: [], None, aux_state)
|
236 |
+
demo.launch(server_name="0.0.0.0", server_port=args.port, share=True)
|
image_editing_utils.py β caption_anything/utils/image_editing_utils.py
RENAMED
@@ -1,7 +1,8 @@
|
|
1 |
from PIL import Image, ImageDraw, ImageFont
|
2 |
import copy
|
3 |
import numpy as np
|
4 |
-
import cv2
|
|
|
5 |
|
6 |
def wrap_text(text, font, max_width):
|
7 |
lines = []
|
@@ -18,11 +19,18 @@ def wrap_text(text, font, max_width):
|
|
18 |
lines.append(current_line)
|
19 |
return lines
|
20 |
|
21 |
-
|
|
|
|
|
22 |
# Load the image
|
|
|
|
|
|
|
|
|
|
|
23 |
if type(image) == np.ndarray:
|
24 |
image = Image.fromarray(image)
|
25 |
-
|
26 |
image = copy.deepcopy(image)
|
27 |
width, height = image.size
|
28 |
|
@@ -47,19 +55,19 @@ def create_bubble_frame(image, text, point, segmask, input_points, input_labels,
|
|
47 |
bubble_height = text_height + 2 * padding
|
48 |
|
49 |
# Create a new image for the bubble frame
|
50 |
-
bubble = Image.new('RGBA', (bubble_width, bubble_height), (255,248, 220, 0))
|
51 |
|
52 |
# Draw the bubble frame on the new image
|
53 |
draw = ImageDraw.Draw(bubble)
|
54 |
# draw.rectangle([(0, 0), (bubble_width - 1, bubble_height - 1)], fill=(255, 255, 255, 0), outline=(255, 255, 255, 0), width=2)
|
55 |
-
draw_rounded_rectangle(draw, (0, 0, bubble_width - 1, bubble_height - 1), point_size * 2,
|
56 |
-
|
57 |
# Draw the wrapped text line by line
|
58 |
y_text = padding
|
59 |
for line in lines:
|
60 |
draw.text((padding, y_text), line, font=font, fill=(0, 0, 0, 255))
|
61 |
y_text += font.getsize(line)[1]
|
62 |
-
|
63 |
# Determine the point by the min area rect of mask
|
64 |
try:
|
65 |
ret, thresh = cv2.threshold(segmask, 127, 255, 0)
|
@@ -109,7 +117,11 @@ def draw_rounded_rectangle(draw, xy, corner_radius, fill=None, outline=None, wid
|
|
109 |
width=width
|
110 |
)
|
111 |
|
112 |
-
draw.pieslice((x1, y1, x1 + corner_radius * 2, y1 + corner_radius * 2), 180, 270, fill=fill, outline=outline,
|
113 |
-
|
114 |
-
draw.pieslice((x2 - corner_radius * 2,
|
115 |
-
|
|
|
|
|
|
|
|
|
|
1 |
from PIL import Image, ImageDraw, ImageFont
|
2 |
import copy
|
3 |
import numpy as np
|
4 |
+
import cv2
|
5 |
+
|
6 |
|
7 |
def wrap_text(text, font, max_width):
|
8 |
lines = []
|
|
|
19 |
lines.append(current_line)
|
20 |
return lines
|
21 |
|
22 |
+
|
23 |
+
def create_bubble_frame(image, text, point, segmask, input_points=(), input_labels=(),
|
24 |
+
font_path='assets/times_with_simsun.ttf', font_size_ratio=0.033, point_size_ratio=0.01):
|
25 |
# Load the image
|
26 |
+
if input_points is None:
|
27 |
+
input_points = []
|
28 |
+
if input_labels is None:
|
29 |
+
input_labels = []
|
30 |
+
|
31 |
if type(image) == np.ndarray:
|
32 |
image = Image.fromarray(image)
|
33 |
+
|
34 |
image = copy.deepcopy(image)
|
35 |
width, height = image.size
|
36 |
|
|
|
55 |
bubble_height = text_height + 2 * padding
|
56 |
|
57 |
# Create a new image for the bubble frame
|
58 |
+
bubble = Image.new('RGBA', (bubble_width, bubble_height), (255, 248, 220, 0))
|
59 |
|
60 |
# Draw the bubble frame on the new image
|
61 |
draw = ImageDraw.Draw(bubble)
|
62 |
# draw.rectangle([(0, 0), (bubble_width - 1, bubble_height - 1)], fill=(255, 255, 255, 0), outline=(255, 255, 255, 0), width=2)
|
63 |
+
draw_rounded_rectangle(draw, (0, 0, bubble_width - 1, bubble_height - 1), point_size * 2,
|
64 |
+
fill=(255, 248, 220, 120), outline=None, width=2)
|
65 |
# Draw the wrapped text line by line
|
66 |
y_text = padding
|
67 |
for line in lines:
|
68 |
draw.text((padding, y_text), line, font=font, fill=(0, 0, 0, 255))
|
69 |
y_text += font.getsize(line)[1]
|
70 |
+
|
71 |
# Determine the point by the min area rect of mask
|
72 |
try:
|
73 |
ret, thresh = cv2.threshold(segmask, 127, 255, 0)
|
|
|
117 |
width=width
|
118 |
)
|
119 |
|
120 |
+
draw.pieslice((x1, y1, x1 + corner_radius * 2, y1 + corner_radius * 2), 180, 270, fill=fill, outline=outline,
|
121 |
+
width=width)
|
122 |
+
draw.pieslice((x2 - corner_radius * 2, y1, x2, y1 + corner_radius * 2), 270, 360, fill=fill, outline=outline,
|
123 |
+
width=width)
|
124 |
+
draw.pieslice((x2 - corner_radius * 2, y2 - corner_radius * 2, x2, y2), 0, 90, fill=fill, outline=outline,
|
125 |
+
width=width)
|
126 |
+
draw.pieslice((x1, y2 - corner_radius * 2, x1 + corner_radius * 2, y2), 90, 180, fill=fill, outline=outline,
|
127 |
+
width=width)
|
caption_anything/utils/parser.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
def parse_augment():
|
4 |
+
parser = argparse.ArgumentParser()
|
5 |
+
parser.add_argument('--captioner', type=str, default="blip2")
|
6 |
+
parser.add_argument('--segmenter', type=str, default="huge")
|
7 |
+
parser.add_argument('--text_refiner', type=str, default="base")
|
8 |
+
parser.add_argument('--segmenter_checkpoint', type=str, default=None, help="SAM checkpoint path")
|
9 |
+
parser.add_argument('--seg_crop_mode', type=str, default="wo_bg", choices=['wo_bg', 'w_bg'],
|
10 |
+
help="whether to add or remove background of the image when captioning")
|
11 |
+
parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
|
12 |
+
parser.add_argument('--context_captions', action="store_true",
|
13 |
+
help="use surrounding captions to enhance current caption (TODO)")
|
14 |
+
parser.add_argument('--disable_regular_box', action="store_true", default=False,
|
15 |
+
help="crop image with a regular box")
|
16 |
+
parser.add_argument('--device', type=str, default="cuda:0")
|
17 |
+
parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")
|
18 |
+
parser.add_argument('--debug', action="store_true")
|
19 |
+
parser.add_argument('--gradio_share', action="store_true")
|
20 |
+
parser.add_argument('--disable_gpt', action="store_true")
|
21 |
+
parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
|
22 |
+
parser.add_argument('--disable_reuse_features', action="store_true", default=False)
|
23 |
+
parser.add_argument('--enable_morphologyex', action="store_true", default=False)
|
24 |
+
parser.add_argument('--chat_tools_dict', type=str, default='VisualQuestionAnswering_cuda:0', help='Visual ChatGPT tools, only useful when running gradio applications')
|
25 |
+
args = parser.parse_args()
|
26 |
+
|
27 |
+
if args.debug:
|
28 |
+
print(args)
|
29 |
+
return args
|
caption_anything/utils/utils.py
ADDED
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import requests
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import time
|
7 |
+
import sys
|
8 |
+
import urllib
|
9 |
+
from tqdm import tqdm
|
10 |
+
import hashlib
|
11 |
+
|
12 |
+
def is_platform_win():
|
13 |
+
return sys.platform == "win32"
|
14 |
+
|
15 |
+
|
16 |
+
def colormap(rgb=True):
|
17 |
+
color_list = np.array(
|
18 |
+
[
|
19 |
+
0.000, 0.000, 0.000,
|
20 |
+
1.000, 1.000, 1.000,
|
21 |
+
1.000, 0.498, 0.313,
|
22 |
+
0.392, 0.581, 0.929,
|
23 |
+
0.000, 0.447, 0.741,
|
24 |
+
0.850, 0.325, 0.098,
|
25 |
+
0.929, 0.694, 0.125,
|
26 |
+
0.494, 0.184, 0.556,
|
27 |
+
0.466, 0.674, 0.188,
|
28 |
+
0.301, 0.745, 0.933,
|
29 |
+
0.635, 0.078, 0.184,
|
30 |
+
0.300, 0.300, 0.300,
|
31 |
+
0.600, 0.600, 0.600,
|
32 |
+
1.000, 0.000, 0.000,
|
33 |
+
1.000, 0.500, 0.000,
|
34 |
+
0.749, 0.749, 0.000,
|
35 |
+
0.000, 1.000, 0.000,
|
36 |
+
0.000, 0.000, 1.000,
|
37 |
+
0.667, 0.000, 1.000,
|
38 |
+
0.333, 0.333, 0.000,
|
39 |
+
0.333, 0.667, 0.000,
|
40 |
+
0.333, 1.000, 0.000,
|
41 |
+
0.667, 0.333, 0.000,
|
42 |
+
0.667, 0.667, 0.000,
|
43 |
+
0.667, 1.000, 0.000,
|
44 |
+
1.000, 0.333, 0.000,
|
45 |
+
1.000, 0.667, 0.000,
|
46 |
+
1.000, 1.000, 0.000,
|
47 |
+
0.000, 0.333, 0.500,
|
48 |
+
0.000, 0.667, 0.500,
|
49 |
+
0.000, 1.000, 0.500,
|
50 |
+
0.333, 0.000, 0.500,
|
51 |
+
0.333, 0.333, 0.500,
|
52 |
+
0.333, 0.667, 0.500,
|
53 |
+
0.333, 1.000, 0.500,
|
54 |
+
0.667, 0.000, 0.500,
|
55 |
+
0.667, 0.333, 0.500,
|
56 |
+
0.667, 0.667, 0.500,
|
57 |
+
0.667, 1.000, 0.500,
|
58 |
+
1.000, 0.000, 0.500,
|
59 |
+
1.000, 0.333, 0.500,
|
60 |
+
1.000, 0.667, 0.500,
|
61 |
+
1.000, 1.000, 0.500,
|
62 |
+
0.000, 0.333, 1.000,
|
63 |
+
0.000, 0.667, 1.000,
|
64 |
+
0.000, 1.000, 1.000,
|
65 |
+
0.333, 0.000, 1.000,
|
66 |
+
0.333, 0.333, 1.000,
|
67 |
+
0.333, 0.667, 1.000,
|
68 |
+
0.333, 1.000, 1.000,
|
69 |
+
0.667, 0.000, 1.000,
|
70 |
+
0.667, 0.333, 1.000,
|
71 |
+
0.667, 0.667, 1.000,
|
72 |
+
0.667, 1.000, 1.000,
|
73 |
+
1.000, 0.000, 1.000,
|
74 |
+
1.000, 0.333, 1.000,
|
75 |
+
1.000, 0.667, 1.000,
|
76 |
+
0.167, 0.000, 0.000,
|
77 |
+
0.333, 0.000, 0.000,
|
78 |
+
0.500, 0.000, 0.000,
|
79 |
+
0.667, 0.000, 0.000,
|
80 |
+
0.833, 0.000, 0.000,
|
81 |
+
1.000, 0.000, 0.000,
|
82 |
+
0.000, 0.167, 0.000,
|
83 |
+
0.000, 0.333, 0.000,
|
84 |
+
0.000, 0.500, 0.000,
|
85 |
+
0.000, 0.667, 0.000,
|
86 |
+
0.000, 0.833, 0.000,
|
87 |
+
0.000, 1.000, 0.000,
|
88 |
+
0.000, 0.000, 0.167,
|
89 |
+
0.000, 0.000, 0.333,
|
90 |
+
0.000, 0.000, 0.500,
|
91 |
+
0.000, 0.000, 0.667,
|
92 |
+
0.000, 0.000, 0.833,
|
93 |
+
0.000, 0.000, 1.000,
|
94 |
+
0.143, 0.143, 0.143,
|
95 |
+
0.286, 0.286, 0.286,
|
96 |
+
0.429, 0.429, 0.429,
|
97 |
+
0.571, 0.571, 0.571,
|
98 |
+
0.714, 0.714, 0.714,
|
99 |
+
0.857, 0.857, 0.857
|
100 |
+
]
|
101 |
+
).astype(np.float32)
|
102 |
+
color_list = color_list.reshape((-1, 3)) * 255
|
103 |
+
if not rgb:
|
104 |
+
color_list = color_list[:, ::-1]
|
105 |
+
return color_list
|
106 |
+
|
107 |
+
|
108 |
+
color_list = colormap()
|
109 |
+
color_list = color_list.astype('uint8').tolist()
|
110 |
+
|
111 |
+
|
112 |
+
def vis_add_mask(image, mask, color, alpha, kernel_size):
|
113 |
+
color = np.array(color)
|
114 |
+
mask = mask.astype('float').copy()
|
115 |
+
mask = (cv2.GaussianBlur(mask, (kernel_size, kernel_size), kernel_size) / 255.) * (alpha)
|
116 |
+
for i in range(3):
|
117 |
+
image[:, :, i] = image[:, :, i] * (1-alpha+mask) + color[i] * (alpha-mask)
|
118 |
+
return image
|
119 |
+
|
120 |
+
|
121 |
+
def vis_add_mask_wo_blur(image, mask, color, alpha):
|
122 |
+
color = np.array(color)
|
123 |
+
mask = mask.astype('float').copy()
|
124 |
+
for i in range(3):
|
125 |
+
image[:, :, i] = image[:, :, i] * (1-alpha+mask) + color[i] * (alpha-mask)
|
126 |
+
return image
|
127 |
+
|
128 |
+
|
129 |
+
def vis_add_mask_wo_gaussian(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha):
|
130 |
+
background_color = np.array(background_color)
|
131 |
+
contour_color = np.array(contour_color)
|
132 |
+
|
133 |
+
# background_mask = 1 - background_mask
|
134 |
+
# contour_mask = 1 - contour_mask
|
135 |
+
|
136 |
+
for i in range(3):
|
137 |
+
image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
|
138 |
+
+ background_color[i] * (background_alpha-background_mask*background_alpha)
|
139 |
+
|
140 |
+
image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
|
141 |
+
+ contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
|
142 |
+
|
143 |
+
return image.astype('uint8')
|
144 |
+
|
145 |
+
|
146 |
+
def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, background_color=0, paint_foreground=False):
|
147 |
+
"""
|
148 |
+
add color mask to the background/foreground area
|
149 |
+
input_image: numpy array (w, h, C)
|
150 |
+
input_mask: numpy array (w, h)
|
151 |
+
background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
|
152 |
+
background_blur_radius: radius of background blur, must be odd number
|
153 |
+
contour_width: width of mask contour, must be odd number
|
154 |
+
contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
|
155 |
+
background_color: color index of the background (area with input_mask == False)
|
156 |
+
contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
|
157 |
+
paint_foreground: True for paint on foreground, False for background. Default: Flase
|
158 |
+
|
159 |
+
Output:
|
160 |
+
painted_image: numpy array
|
161 |
+
"""
|
162 |
+
assert input_image.shape[:2] == input_mask.shape, 'different shape'
|
163 |
+
assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
|
164 |
+
|
165 |
+
# 0: background, 1: foreground
|
166 |
+
input_mask[input_mask>0] = 255
|
167 |
+
if paint_foreground:
|
168 |
+
painted_image = vis_add_mask(input_image, 255 - input_mask, color_list[background_color], background_alpha, background_blur_radius) # black for background
|
169 |
+
else:
|
170 |
+
# mask background
|
171 |
+
painted_image = vis_add_mask(input_image, input_mask, color_list[background_color], background_alpha, background_blur_radius) # black for background
|
172 |
+
# mask contour
|
173 |
+
contour_mask = input_mask.copy()
|
174 |
+
contour_mask = cv2.Canny(contour_mask, 100, 200) # contour extraction
|
175 |
+
# widden contour
|
176 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (contour_width, contour_width))
|
177 |
+
contour_mask = cv2.dilate(contour_mask, kernel)
|
178 |
+
painted_image = vis_add_mask(painted_image, 255-contour_mask, color_list[contour_color], contour_alpha, contour_width)
|
179 |
+
return painted_image
|
180 |
+
|
181 |
+
|
182 |
+
def mask_painter_foreground_all(input_image, input_masks, background_alpha=0.7, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1):
|
183 |
+
"""
|
184 |
+
paint color mask on the all foreground area
|
185 |
+
input_image: numpy array with shape (w, h, C)
|
186 |
+
input_mask: list of masks, each mask is a numpy array with shape (w,h)
|
187 |
+
background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
|
188 |
+
background_blur_radius: radius of background blur, must be odd number
|
189 |
+
contour_width: width of mask contour, must be odd number
|
190 |
+
contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
|
191 |
+
background_color: color index of the background (area with input_mask == False)
|
192 |
+
contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
|
193 |
+
|
194 |
+
Output:
|
195 |
+
painted_image: numpy array
|
196 |
+
"""
|
197 |
+
|
198 |
+
for i, input_mask in enumerate(input_masks):
|
199 |
+
input_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, background_color=i + 2, paint_foreground=True)
|
200 |
+
return input_image
|
201 |
+
|
202 |
+
def mask_generator_00(mask, background_radius, contour_radius):
|
203 |
+
# no background width when '00'
|
204 |
+
# distance map
|
205 |
+
dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
|
206 |
+
dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
|
207 |
+
dist_map = dist_transform_fore - dist_transform_back
|
208 |
+
# ...:::!!!:::...
|
209 |
+
contour_radius += 2
|
210 |
+
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
|
211 |
+
contour_mask = contour_mask / np.max(contour_mask)
|
212 |
+
contour_mask[contour_mask>0.5] = 1.
|
213 |
+
|
214 |
+
return mask, contour_mask
|
215 |
+
|
216 |
+
|
217 |
+
def mask_generator_01(mask, background_radius, contour_radius):
|
218 |
+
# no background width when '00'
|
219 |
+
# distance map
|
220 |
+
dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
|
221 |
+
dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
|
222 |
+
dist_map = dist_transform_fore - dist_transform_back
|
223 |
+
# ...:::!!!:::...
|
224 |
+
contour_radius += 2
|
225 |
+
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
|
226 |
+
contour_mask = contour_mask / np.max(contour_mask)
|
227 |
+
return mask, contour_mask
|
228 |
+
|
229 |
+
|
230 |
+
def mask_generator_10(mask, background_radius, contour_radius):
|
231 |
+
# distance map
|
232 |
+
dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
|
233 |
+
dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
|
234 |
+
dist_map = dist_transform_fore - dist_transform_back
|
235 |
+
# .....:::::!!!!!
|
236 |
+
background_mask = np.clip(dist_map, -background_radius, background_radius)
|
237 |
+
background_mask = (background_mask - np.min(background_mask))
|
238 |
+
background_mask = background_mask / np.max(background_mask)
|
239 |
+
# ...:::!!!:::...
|
240 |
+
contour_radius += 2
|
241 |
+
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
|
242 |
+
contour_mask = contour_mask / np.max(contour_mask)
|
243 |
+
contour_mask[contour_mask>0.5] = 1.
|
244 |
+
return background_mask, contour_mask
|
245 |
+
|
246 |
+
|
247 |
+
def mask_generator_11(mask, background_radius, contour_radius):
|
248 |
+
# distance map
|
249 |
+
dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
|
250 |
+
dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
|
251 |
+
dist_map = dist_transform_fore - dist_transform_back
|
252 |
+
# .....:::::!!!!!
|
253 |
+
background_mask = np.clip(dist_map, -background_radius, background_radius)
|
254 |
+
background_mask = (background_mask - np.min(background_mask))
|
255 |
+
background_mask = background_mask / np.max(background_mask)
|
256 |
+
# ...:::!!!:::...
|
257 |
+
contour_radius += 2
|
258 |
+
contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
|
259 |
+
contour_mask = contour_mask / np.max(contour_mask)
|
260 |
+
return background_mask, contour_mask
|
261 |
+
|
262 |
+
|
263 |
+
def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'):
|
264 |
+
"""
|
265 |
+
Input:
|
266 |
+
input_image: numpy array
|
267 |
+
input_mask: numpy array
|
268 |
+
background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
|
269 |
+
background_blur_radius: radius of background blur, must be odd number
|
270 |
+
contour_width: width of mask contour, must be odd number
|
271 |
+
contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
|
272 |
+
contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
|
273 |
+
mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both
|
274 |
+
|
275 |
+
Output:
|
276 |
+
painted_image: numpy array
|
277 |
+
"""
|
278 |
+
assert input_image.shape[:2] == input_mask.shape, 'different shape'
|
279 |
+
assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
|
280 |
+
assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'
|
281 |
+
|
282 |
+
# downsample input image and mask
|
283 |
+
width, height = input_image.shape[0], input_image.shape[1]
|
284 |
+
res = 1024
|
285 |
+
ratio = min(1.0 * res / max(width, height), 1.0)
|
286 |
+
input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
|
287 |
+
input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
|
288 |
+
|
289 |
+
# 0: background, 1: foreground
|
290 |
+
msk = np.clip(input_mask, 0, 1)
|
291 |
+
|
292 |
+
# generate masks for background and contour pixels
|
293 |
+
background_radius = (background_blur_radius - 1) // 2
|
294 |
+
contour_radius = (contour_width - 1) // 2
|
295 |
+
generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11}
|
296 |
+
background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
|
297 |
+
|
298 |
+
# paint
|
299 |
+
painted_image = vis_add_mask_wo_gaussian \
|
300 |
+
(input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background
|
301 |
+
|
302 |
+
return painted_image
|
303 |
+
|
304 |
+
|
305 |
+
if __name__ == '__main__':
|
306 |
+
|
307 |
+
background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
|
308 |
+
background_blur_radius = 31 # radius of background blur, must be odd number
|
309 |
+
contour_width = 11 # contour width, must be odd number
|
310 |
+
contour_color = 3 # id in color map, 0: black, 1: white, >1: others
|
311 |
+
contour_alpha = 1 # transparency of background, 0: no contour highlighted
|
312 |
+
|
313 |
+
# load input image and mask
|
314 |
+
input_image = np.array(Image.open('./test_images/painter_input_image.jpg').convert('RGB'))
|
315 |
+
input_mask = np.array(Image.open('./test_images/painter_input_mask.jpg').convert('P'))
|
316 |
+
|
317 |
+
# paint
|
318 |
+
overall_time_1 = 0
|
319 |
+
overall_time_2 = 0
|
320 |
+
overall_time_3 = 0
|
321 |
+
overall_time_4 = 0
|
322 |
+
overall_time_5 = 0
|
323 |
+
|
324 |
+
for i in range(50):
|
325 |
+
t2 = time.time()
|
326 |
+
painted_image_00 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00')
|
327 |
+
e2 = time.time()
|
328 |
+
|
329 |
+
t3 = time.time()
|
330 |
+
painted_image_10 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10')
|
331 |
+
e3 = time.time()
|
332 |
+
|
333 |
+
t1 = time.time()
|
334 |
+
painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha)
|
335 |
+
e1 = time.time()
|
336 |
+
|
337 |
+
t4 = time.time()
|
338 |
+
painted_image_01 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01')
|
339 |
+
e4 = time.time()
|
340 |
+
|
341 |
+
t5 = time.time()
|
342 |
+
painted_image_11 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11')
|
343 |
+
e5 = time.time()
|
344 |
+
|
345 |
+
overall_time_1 += (e1 - t1)
|
346 |
+
overall_time_2 += (e2 - t2)
|
347 |
+
overall_time_3 += (e3 - t3)
|
348 |
+
overall_time_4 += (e4 - t4)
|
349 |
+
overall_time_5 += (e5 - t5)
|
350 |
+
|
351 |
+
print(f'average time w gaussian: {overall_time_1/50}')
|
352 |
+
print(f'average time w/o gaussian00: {overall_time_2/50}')
|
353 |
+
print(f'average time w/o gaussian10: {overall_time_3/50}')
|
354 |
+
print(f'average time w/o gaussian01: {overall_time_4/50}')
|
355 |
+
print(f'average time w/o gaussian11: {overall_time_5/50}')
|
356 |
+
|
357 |
+
# save
|
358 |
+
painted_image_00 = Image.fromarray(painted_image_00)
|
359 |
+
painted_image_00.save('./test_images/painter_output_image_00.png')
|
360 |
+
|
361 |
+
painted_image_10 = Image.fromarray(painted_image_10)
|
362 |
+
painted_image_10.save('./test_images/painter_output_image_10.png')
|
363 |
+
|
364 |
+
painted_image_01 = Image.fromarray(painted_image_01)
|
365 |
+
painted_image_01.save('./test_images/painter_output_image_01.png')
|
366 |
+
|
367 |
+
painted_image_11 = Image.fromarray(painted_image_11)
|
368 |
+
painted_image_11.save('./test_images/painter_output_image_11.png')
|
369 |
+
|
370 |
+
|
371 |
+
seg_model_map = {
|
372 |
+
'base': 'vit_b',
|
373 |
+
'large': 'vit_l',
|
374 |
+
'huge': 'vit_h'
|
375 |
+
}
|
376 |
+
ckpt_url_map = {
|
377 |
+
'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
|
378 |
+
'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
|
379 |
+
'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
|
380 |
+
}
|
381 |
+
expected_sha256_map = {
|
382 |
+
'vit_b': 'ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912',
|
383 |
+
'vit_l': '3adcc4315b642a4d2101128f611684e8734c41232a17c648ed1693702a49a622',
|
384 |
+
'vit_h': 'a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e'
|
385 |
+
}
|
386 |
+
def prepare_segmenter(segmenter = "huge", download_root: str = None):
|
387 |
+
"""
|
388 |
+
Prepare segmenter model and download checkpoint if necessary.
|
389 |
+
|
390 |
+
Returns: segmenter model name from 'vit_b', 'vit_l', 'vit_h'.
|
391 |
+
|
392 |
+
"""
|
393 |
+
|
394 |
+
os.makedirs('result', exist_ok=True)
|
395 |
+
seg_model_name = seg_model_map[segmenter]
|
396 |
+
checkpoint_url = ckpt_url_map[seg_model_name]
|
397 |
+
folder = download_root or os.path.expanduser("~/.cache/SAM")
|
398 |
+
filename = os.path.basename(checkpoint_url)
|
399 |
+
segmenter_checkpoint = download_checkpoint(checkpoint_url, folder, filename, expected_sha256_map[seg_model_name])
|
400 |
+
|
401 |
+
return seg_model_name, segmenter_checkpoint
|
402 |
+
|
403 |
+
|
404 |
+
def download_checkpoint(url, folder, filename, expected_sha256):
|
405 |
+
os.makedirs(folder, exist_ok=True)
|
406 |
+
download_target = os.path.join(folder, filename)
|
407 |
+
if os.path.isfile(download_target):
|
408 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
|
409 |
+
return download_target
|
410 |
+
|
411 |
+
print(f'Download SAM checkpoint {url}, saving to {download_target} ...')
|
412 |
+
with requests.get(url, stream=True) as response, open(download_target, "wb") as output:
|
413 |
+
progress = tqdm(total=int(response.headers.get('content-length', 0)), unit='B', unit_scale=True)
|
414 |
+
for data in response.iter_content(chunk_size=1024):
|
415 |
+
size = output.write(data)
|
416 |
+
progress.update(size)
|
417 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
|
418 |
+
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
|
419 |
+
return download_target
|
env.sh
DELETED
@@ -1,6 +0,0 @@
|
|
1 |
-
conda create -n caption_anything python=3.8 -y
|
2 |
-
source activate caption_anything
|
3 |
-
pip install -r requirements.txt
|
4 |
-
# cd [email protected]
|
5 |
-
# wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
segmenter/__init__.py
DELETED
@@ -1,5 +0,0 @@
|
|
1 |
-
from segmenter.base_segmenter import BaseSegmenter
|
2 |
-
|
3 |
-
|
4 |
-
def build_segmenter(type, device, args=None, model=None):
|
5 |
-
return BaseSegmenter(device, args.segmenter_checkpoint, reuse_feature=not args.disable_reuse_features, model=model)
|
|
|
|
|
|
|
|
|
|
|
|
segmenter/images/truck.jpg
DELETED
Binary file (271 kB)
|
|
segmenter/sam_vit_h_4b8939.pth
DELETED
@@ -1,3 +0,0 @@
|
|
1 |
-
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
|
3 |
-
size 2564550879
|
|
|
|
|
|
|
|
test_img/img0.png
DELETED
Binary file (185 kB)
|
|
test_img/img1.jpg
DELETED
Binary file (501 kB)
|
|
test_img/img1.jpg.raw_mask.png
DELETED
Binary file (114 kB)
|
|
test_img/img10.jpg
DELETED
Binary file (376 kB)
|
|
test_img/img10.jpg.raw_mask.png
DELETED
Binary file (24.3 kB)
|
|
test_img/img11.jpg
DELETED
Binary file (616 kB)
|
|
test_img/img12.jpg
DELETED
Binary file (277 kB)
|
|
test_img/img12.jpg.raw_mask.png
DELETED
Binary file (29.1 kB)
|
|
test_img/img13.jpg
DELETED
Binary file (335 kB)
|
|
test_img/img13.jpg.raw_mask.png
DELETED
Binary file (22.9 kB)
|
|
test_img/img14.jpg
DELETED
Binary file (741 kB)
|
|
test_img/img14.jpg.raw_mask.png
DELETED
Binary file (26.9 kB)
|
|
test_img/img15.jpg
DELETED
Binary file (376 kB)
|
|
test_img/img15.jpg.raw_mask.png
DELETED
Binary file (114 kB)
|
|
test_img/img16.jpg
DELETED
Binary file (337 kB)
|
|