ttengwang committed
Commit ff883a7 • 1 Parent(s): 73cef8f

update prompts for chatting, add duplicate icon
Files changed:
- app.py +88 -72
- caption_anything.py +1 -1
- captioner/base_captioner.py +5 -1
- captioner/blip2.py +10 -4
- segmenter/__init__.py +1 -4
- tools.py +16 -11
app.py
CHANGED
@@ -20,6 +20,7 @@ from segment_anything import sam_model_registry
 from text_refiner import build_text_refiner
 from segmenter import build_segmenter
 
+
 def download_checkpoint(url, folder, filename):
     os.makedirs(folder, exist_ok=True)
     filepath = os.path.join(folder, filename)
@@ -32,16 +33,11 @@ def download_checkpoint(url, folder, filename):
             f.write(chunk)
 
     return filepath
-checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
-folder = "segmenter"
-filename = "sam_vit_h_4b8939.pth"
-
-download_checkpoint(checkpoint_url, folder, filename)
 
 
-title = """<h1 align="center">Caption-Anything</h1>
-description = """Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: https://github.com/ttengwang/Caption-Anything
+title = """<p><h1 align="center">Caption-Anything</h1></p>
 """
+description = """<p>Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: https://github.com/ttengwang/Caption-Anything <a href="https://huggingface.co/spaces/TencentARC/Caption-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
 
 examples = [
     ["test_img/img35.webp"],
@@ -53,7 +49,26 @@ examples = [
     ["test_img/img1.jpg"],
 ]
 
+seg_model_map = {
+    'base': 'vit_b',
+    'large': 'vit_l',
+    'huge': 'vit_h'
+}
+ckpt_url_map = {
+    'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
+    'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
+    'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
+}
+
 args = parse_augment()
+
+checkpoint_url = ckpt_url_map[seg_model_map[args.segmenter]]
+folder = "segmenter"
+filename = os.path.basename(checkpoint_url)
+args.segmenter_checkpoint = os.path.join(folder, filename)
+
+download_checkpoint(checkpoint_url, folder, filename)
+
 # args.device = 'cuda:5'
 # args.disable_gpt = True
 # args.enable_reduce_tokens = False
@@ -61,7 +76,7 @@ args = parse_augment()
 # args.captioner = 'blip'
 # args.regular_box = True
 shared_captioner = build_captioner(args.captioner, args.device, args)
-shared_sam_model = sam_model_registry[
+shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=args.segmenter_checkpoint).to(args.device)
 
 
 def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, text_refiner=None, session_id=None):
@@ -102,7 +117,7 @@ def get_prompt(chat_input, click_state, click_mode):
         click_state[1] = labels
     else:
         raise NotImplementedError
-
+
     prompt = {
         "prompt_type":["click"],
         "input_point":click_state[0],
@@ -117,21 +132,21 @@ def update_click_state(click_state, caption, click_mode):
     elif click_mode == 'Single':
         click_state[2] = [caption]
     else:
-        raise NotImplementedError
+        raise NotImplementedError
 
 
-def chat_with_points(chat_input, click_state, chat_state, state, text_refiner):
+def chat_with_points(chat_input, click_state, chat_state, state, text_refiner, img_caption):
     if text_refiner is None:
         response = "Text refiner is not initilzed, please input openai api key."
         state = state + [(chat_input, response)]
         return state, state, chat_state
-
+
     points, labels, captions = click_state
     # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting!"
     suffix = '\nHuman: {chat_input}\nAI: '
     qa_template = '\nHuman: {q}\nAI: {a}'
     # # "The image is of width {width} and height {height}."
-    point_chat_prompt = "I am an AI trained to chat with you about an image
+    point_chat_prompt = "I am an AI trained to chat with you about an image. I am greate at what is going on in any image based on the image information your provide. The overall image description is \"{img_caption}\". You will also provide me objects in the image in details, i.e., their location and visual descriptions. Here are the locations and descriptions of events that happen in the image: {points_with_caps} \n Now, let's chat!"
     prev_visual_context = ""
     pos_points = []
     pos_captions = []
@@ -139,8 +154,8 @@ def chat_with_points(chat_input, click_state, chat_state, state, text_refiner):
         if labels[i] == 1:
             pos_points.append(f"({points[i][0]}, {points[i][0]})")
             pos_captions.append(captions[i])
-            prev_visual_context = prev_visual_context + '\n' + '
-
+            prev_visual_context = prev_visual_context + '\n' + 'There is an event described as \"{}\" locating at {}'.format(pos_captions[-1], ', '.join(pos_points))
+
     context_length_thres = 500
     prev_history = ""
     for i in range(len(chat_state)):
@@ -149,26 +164,25 @@ def chat_with_points(chat_input, click_state, chat_state, state, text_refiner):
             prev_history = prev_history + qa_template.format(**{"q": q, "a": a})
         else:
             break
-
-    chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context}) + prev_history + suffix.format(**{"chat_input": chat_input})
+    chat_prompt = point_chat_prompt.format(**{"img_caption":img_caption,"points_with_caps": prev_visual_context}) + prev_history + suffix.format(**{"chat_input": chat_input})
     print('\nchat_prompt: ', chat_prompt)
     response = text_refiner.llm(chat_prompt)
     state = state + [(chat_input, response)]
    chat_state = chat_state + [(chat_input, response)]
     return state, state, chat_state
 
-def inference_seg_cap(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
-
-
+def inference_seg_cap(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
+                      length, image_embedding, state, click_state, original_size, input_size, text_refiner, evt:gr.SelectData):
+
     model = build_caption_anything_with_models(
-        args,
+        args,
         api_key="",
         captioner=shared_captioner,
         sam_model=shared_sam_model,
         text_refiner=text_refiner,
         session_id=iface.app_id
     )
-
+
     model.segmenter.image_embedding = image_embedding
     model.segmenter.predictor.original_size = original_size
     model.segmenter.predictor.input_size = input_size
@@ -178,11 +192,11 @@ def inference_seg_cap(image_input, point_prompt, click_mode, enable_wiki, langua
         coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
     else:
         coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
-
+
     controls = {'length': length,
-
-
-
+                'sentiment': sentiment,
+                'factuality': factuality,
+                'language': language}
 
     # click_coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
     # chat_input = click_coordinate
@@ -217,8 +231,7 @@ def inference_seg_cap(image_input, point_prompt, click_mode, enable_wiki, langua
     yield state, state, click_state, chat_input, refined_image_input, wiki
 
 
-def upload_callback(image_input, state):
-    state = [] + [(None, 'Image size: ' + str(image_input.size))]
+def upload_callback(image_input, state):
     chat_state = []
     click_state = [[], [], []]
     res = 1024
@@ -227,9 +240,9 @@ def upload_callback(image_input, state):
     if ratio < 1.0:
         image_input = image_input.resize((int(width * ratio), int(height * ratio)))
         print('Scaling input image to {}'.format(image_input.size))
-
+    state = [] + [(None, 'Image size: ' + str(image_input.size))]
     model = build_caption_anything_with_models(
-        args,
+        args,
         api_key="",
         captioner=shared_captioner,
         sam_model=shared_sam_model,
@@ -239,10 +252,11 @@ def upload_callback(image_input, state):
     image_embedding = model.segmenter.image_embedding
     original_size = model.segmenter.predictor.original_size
     input_size = model.segmenter.predictor.input_size
-
+    img_caption, _ = model.captioner.inference_seg(image_input)
+    return state, state, chat_state, image_input, click_state, image_input, image_embedding, original_size, input_size, img_caption
 
 with gr.Blocks(
-
+    css='''
     #image_upload{min-height:400px}
     #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 600px}
     '''
@@ -255,6 +269,7 @@ with gr.Blocks(
     text_refiner = gr.State(None)
     original_size = gr.State(None)
     input_size = gr.State(None)
+    img_caption = gr.State(None)
 
     gr.Markdown(title)
     gr.Markdown(description)
@@ -281,13 +296,13 @@ with gr.Blocks(
                 clear_button_image = gr.Button(value="Clear Image", interactive=True)
             with gr.Column(visible=False) as modules_need_gpt:
                 with gr.Row(scale=1.0):
-
-
-
-
-
-
-
+                    language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
+                    sentiment = gr.Radio(
+                        choices=["Positive", "Natural", "Negative"],
+                        value="Natural",
+                        label="Sentiment",
+                        interactive=True,
+                    )
                 with gr.Row(scale=1.0):
                     factuality = gr.Radio(
                         choices=["Factual", "Imagination"],
@@ -304,10 +319,10 @@ with gr.Blocks(
                         label="Generated Caption Length",
                     )
                     enable_wiki = gr.Radio(
-
-
-
-
+                        choices=["Yes", "No"],
+                        value="No",
+                        label="Enable Wiki",
+                        interactive=True)
             with gr.Column(visible=True) as modules_not_need_gpt3:
                 gr.Examples(
                     examples=examples,
@@ -332,11 +347,11 @@ with gr.Blocks(
             with gr.Row():
                 clear_button_text = gr.Button(value="Clear Text", interactive=True)
                 submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
-
+
     openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
     enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
     disable_chatGPT_button.click(init_openai_api_key, outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
-
+
     clear_button_clike.click(
         lambda x: ([[], [], []], x, ""),
         [origin_image],
@@ -345,9 +360,9 @@ with gr.Blocks(
         show_progress=False
    )
     clear_button_image.click(
-        lambda: (None, [], [], [], [[], [], []], "", ""),
+        lambda: (None, [], [], [], [[], [], []], "", "", ""),
         [],
-        [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image],
+        [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image, img_caption],
         queue=False,
         show_progress=False
     )
@@ -359,37 +374,38 @@ with gr.Blocks(
         show_progress=False
     )
     image_input.clear(
-        lambda: (None, [], [], [], [[], [], []], "", ""),
+        lambda: (None, [], [], [], [[], [], []], "", "", ""),
         [],
-        [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image],
+        [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image, img_caption],
        queue=False,
         show_progress=False
     )
 
-    image_input.upload(upload_callback,[image_input, state], [chatbot, state, chat_state, origin_image, click_state, image_input, image_embedding, original_size, input_size])
-    chat_input.submit(chat_with_points, [chat_input, click_state, chat_state, state, text_refiner], [chatbot, state, chat_state])
-
+    image_input.upload(upload_callback,[image_input, state], [chatbot, state, chat_state, origin_image, click_state, image_input, image_embedding, original_size, input_size, img_caption])
+    chat_input.submit(chat_with_points, [chat_input, click_state, chat_state, state, text_refiner, img_caption], [chatbot, state, chat_state])
+    chat_input.submit(lambda: "", None, chat_input)
+    example_image.change(upload_callback,[example_image, state], [chatbot, state, chat_state, origin_image, click_state, image_input, image_embedding, original_size, input_size, img_caption])
 
     # select coordinate
-    image_input.select(inference_seg_cap,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    image_input.select(inference_seg_cap,
+                       inputs=[
+                           origin_image,
+                           point_prompt,
+                           click_mode,
+                           enable_wiki,
+                           language,
+                           sentiment,
+                           factuality,
+                           length,
+                           image_embedding,
+                           state,
+                           click_state,
+                           original_size,
+                           input_size,
+                           text_refiner
+                       ],
+                       outputs=[chatbot, state, click_state, chat_input, image_input, wiki_output],
+                       show_progress=False, queue=True)
+
     iface.queue(concurrency_count=5, api_open=False, max_size=10)
     iface.launch(server_name="0.0.0.0", enable_queue=True)
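For reference, a minimal sketch of how the updated chat_with_points assembles its LLM prompt: the overall image caption plus one line per positively clicked point, followed by prior Q/A turns and the new user message. The click data, example caption, and the shortened prompt template below are illustrative, not values from the repo.

suffix = '\nHuman: {chat_input}\nAI: '
qa_template = '\nHuman: {q}\nAI: {a}'
# Condensed stand-in for the committed point_chat_prompt template.
point_chat_prompt = ('I am an AI trained to chat with you about an image. '
                     'The overall image description is "{img_caption}". '
                     'Here are the locations and descriptions of events that happen '
                     "in the image: {points_with_caps}\nNow, let's chat!")

# Hypothetical state: one positive click with its caption, one prior Q/A turn.
img_caption = 'a dog running on a beach'
points, labels, captions = [(120, 80)], [1], ['a brown dog in mid-stride']
chat_state = [('What animal is this?', 'It looks like a dog.')]

points_with_caps = ''
for (x, y), label, cap in zip(points, labels, captions):
    if label == 1:  # only positive clicks contribute visual context
        points_with_caps += '\nThere is an event described as "{}" locating at ({}, {})'.format(cap, x, y)

prev_history = ''.join(qa_template.format(q=q, a=a) for q, a in chat_state)
chat_prompt = (point_chat_prompt.format(img_caption=img_caption, points_with_caps=points_with_caps)
               + prev_history + suffix.format(chat_input='What is it doing?'))
print(chat_prompt)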
caption_anything.py
CHANGED
@@ -72,7 +72,7 @@ class CaptionAnything():
 def parse_augment():
     parser = argparse.ArgumentParser()
     parser.add_argument('--captioner', type=str, default="blip2")
-    parser.add_argument('--segmenter', type=str, default="
+    parser.add_argument('--segmenter', type=str, default="huge")
     parser.add_argument('--text_refiner', type=str, default="base")
     parser.add_argument('--segmenter_checkpoint', type=str, default="segmenter/sam_vit_h_4b8939.pth")
     parser.add_argument('--seg_crop_mode', type=str, default="wo_bg", choices=['wo_bg', 'w_bg'], help="whether to add or remove background of the image when captioning")
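A small sketch of how the new --segmenter default ("huge") feeds the checkpoint selection added in app.py above; the local folder layout and device string are assumptions.

import os

seg_model_map = {'base': 'vit_b', 'large': 'vit_l', 'huge': 'vit_h'}
ckpt_url_map = {
    'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
    'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
    'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
}

segmenter = 'huge'                            # value of --segmenter (new default)
model_type = seg_model_map[segmenter]         # -> 'vit_h'
checkpoint_url = ckpt_url_map[model_type]
checkpoint_path = os.path.join('segmenter', os.path.basename(checkpoint_url))
print(model_type, checkpoint_path)

# In app.py the checkpoint is then downloaded and loaded roughly as:
#   download_checkpoint(checkpoint_url, 'segmenter', os.path.basename(checkpoint_url))
#   shared_sam_model = sam_model_registry[model_type](checkpoint=checkpoint_path).to(args.device)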
captioner/base_captioner.py
CHANGED
@@ -130,13 +130,17 @@ class BaseCaptioner:
         return caption, crop_save_path
 
 
-    def inference_seg(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str], crop_mode="w_bg", filter=False, disable_regular_box = False):
+    def inference_seg(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str]=None, crop_mode="w_bg", filter=False, disable_regular_box = False):
+        if seg_mask is None:
+            seg_mask = np.ones(image.size).astype(bool)
+
         if type(image) == str:
             image = Image.open(image)
         if type(seg_mask) == str:
             seg_mask = Image.open(seg_mask)
         elif type(seg_mask) == np.ndarray:
             seg_mask = Image.fromarray(seg_mask)
+
         seg_mask = seg_mask.resize(image.size)
         seg_mask = np.array(seg_mask) > 0
 
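The change above lets inference_seg run without an explicit mask, which is how upload_callback in app.py obtains a whole-image caption. A standalone sketch of that new default (names here are illustrative, not the repo's API):

import numpy as np
from PIL import Image

def full_image_mask(image: Image.Image) -> np.ndarray:
    # Mirrors the committed default: np.ones(image.size) marks every pixel as foreground.
    # Note that PIL's size is (width, height), so the array is transposed relative to a
    # numpy image of shape (height, width).
    return np.ones(image.size).astype(bool)

img = Image.new('RGB', (640, 480))
mask = full_image_mask(img)
print(mask.shape, mask.dtype)  # (640, 480) bool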
captioner/blip2.py
CHANGED
@@ -6,6 +6,8 @@ import pdb
 import cv2
 import numpy as np
 from typing import Union
+
+from tools import is_platform_win
 from .base_captioner import BaseCaptioner
 
 class BLIP2Captioner(BaseCaptioner):
@@ -15,14 +17,18 @@ class BLIP2Captioner(BaseCaptioner):
         self.dialogue = dialogue
         self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
         self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
-
+        if is_platform_win():
+            self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map="sequential", torch_dtype=self.torch_dtype)
+        else:
+            self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map='sequential', load_in_8bit=True)
+
     @torch.no_grad()
     def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
         if type(image) == str: # input path
-
+            image = Image.open(image)
 
         if not self.dialogue:
-            text_prompt = '
+            text_prompt = 'Question: what does the image show? Answer:'
             inputs = self.processor(image, text = text_prompt, return_tensors="pt").to(self.device, self.torch_dtype)
             out = self.model.generate(**inputs, max_new_tokens=50)
             captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
@@ -42,7 +48,7 @@ class BLIP2Captioner(BaseCaptioner):
             out = self.model.generate(**inputs, max_new_tokens=50)
             captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
             context.append((input_texts, captions))
-
+
         return captions
 
 if __name__ == '__main__':
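A minimal, self-contained sketch of the single-image caption path used above, with the same prompt and generation settings as the committed code. The 8-bit branch in the diff needs bitsandbytes/accelerate, which is why Windows falls back to plain fp16 loading here; the example image path is an assumption taken from the repo's examples list.

import torch
from PIL import Image
from transformers import AutoProcessor, Blip2ForConditionalGeneration

device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float16 if device == 'cuda' else torch.float32

processor = AutoProcessor.from_pretrained('Salesforce/blip2-opt-2.7b')
model = Blip2ForConditionalGeneration.from_pretrained('Salesforce/blip2-opt-2.7b', torch_dtype=dtype).to(device)

image = Image.open('test_img/img1.jpg')  # example image from the repo's examples list
text_prompt = 'Question: what does the image show? Answer:'
inputs = processor(image, text=text_prompt, return_tensors='pt').to(device, dtype)
out = model.generate(**inputs, max_new_tokens=50)
caption = processor.decode(out[0], skip_special_tokens=True).strip()
print(caption)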
segmenter/__init__.py
CHANGED
@@ -2,7 +2,4 @@ from segmenter.base_segmenter import BaseSegmenter
 
 
 def build_segmenter(type, device, args=None, model=None):
-
-        return BaseSegmenter(device, args.segmenter_checkpoint, reuse_feature=not args.disable_reuse_features, model=model)
-    else:
-        raise NotImplementedError()
+    return BaseSegmenter(device, args.segmenter_checkpoint, reuse_feature=not args.disable_reuse_features, model=model)
tools.py
CHANGED
@@ -4,6 +4,11 @@ import numpy as np
 from PIL import Image
 import copy
 import time
+import sys
+
+
+def is_platform_win():
+    return sys.platform == "win32"
 
 
 def colormap(rgb=True):
@@ -130,10 +135,10 @@ def vis_add_mask_wo_gaussian(image, background_mask, contour_mask, background_co
 
     for i in range(3):
         image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
-
-
+            + background_color[i] * (background_alpha-background_mask*background_alpha)
+
         image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
-
+            + contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
 
     return image.astype('uint8')
 
@@ -155,7 +160,7 @@ def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_
     assert input_image.shape[:2] == input_mask.shape, 'different shape'
     assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
 
-
+
     # 0: background, 1: foreground
     input_mask[input_mask>0] = 255
 
@@ -170,7 +175,7 @@ def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_
     painted_image = vis_add_mask(painted_image, 255-contour_mask, color_list[contour_color], contour_alpha, contour_width)
 
     # painted_image = background_dist_map
-
+
     return painted_image
 
 
@@ -257,10 +262,10 @@ def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, back
     # downsample input image and mask
     width, height = input_image.shape[0], input_image.shape[1]
     res = 1024
-    ratio = min(1.0 * res / max(width, height), 1.0)
+    ratio = min(1.0 * res / max(width, height), 1.0)
     input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
     input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
-
+
     # 0: background, 1: foreground
     msk = np.clip(input_mask, 0, 1)
 
@@ -271,14 +276,14 @@ def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, back
     background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
 
     # paint
-    painted_image = vis_add_mask_wo_gaussian\
+    painted_image = vis_add_mask_wo_gaussian \
         (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background
 
     return painted_image
 
 
 if __name__ == '__main__':
-
+
     background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
     background_blur_radius = 31 # radius of background blur, must be odd number
     contour_width = 11 # contour width, must be odd number
@@ -288,14 +293,14 @@ if __name__ == '__main__':
     # load input image and mask
     input_image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB'))
     input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P'))
-
+
     # paint
     overall_time_1 = 0
     overall_time_2 = 0
     overall_time_3 = 0
     overall_time_4 = 0
     overall_time_5 = 0
-
+
     for i in range(50):
         t2 = time.time()
         painted_image_00 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00')
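The two restored continuation lines in vis_add_mask_wo_gaussian implement a per-channel alpha blend: pixels inside the mask keep their value, pixels outside are blended toward the given color with weight alpha. A small numpy check of that formula (array values below are illustrative):

import numpy as np

def blend_channel(channel, mask, color, alpha):
    # Same expression as in vis_add_mask_wo_gaussian:
    # mask == 1 keeps the pixel, mask == 0 mixes (1 - alpha) * pixel + alpha * color.
    return channel * (1 - alpha + mask * alpha) + color * (alpha - mask * alpha)

channel = np.full((2, 2), 200.0)
mask = np.array([[1.0, 0.0], [0.0, 1.0]])
print(blend_channel(channel, mask, color=0.0, alpha=0.5))
# [[200. 100.]
#  [100. 200.]]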