ttengwang committed • commit 5c74464 • Parent(s): 7d57bb5
fix bugs of example images and api keys
Files changed:
- Image/demo1.svg (+0 -0)
- Image/demo2.svg (+0 -0)
- app.py (+78 -59)
- app_old.py (+2 -2)
- caption_anything.py (+18 -14)
- captioner/base_captioner.py (+8 -8)
- captioner/blip.py (+4 -4)
- captioner/blip2.py (+3 -2)
- captioner/git.py (+4 -4)
- image_editing_utils.py (+3 -2)
- segmenter/base_segmenter.py (+1 -1)
- text_refiner/__init__.py (+2 -2)
- text_refiner/text_refiner.py (+2 -5)
- tools.py (+177 -21)
Image/demo1.svg CHANGED
Image/demo2.svg CHANGED
app.py CHANGED
@@ -40,16 +40,16 @@ description = """Gradio demo for Caption Anything, image to dense captioning gen
 """

 examples = [
+    ["test_img/img35.webp"],
     ["test_img/img2.jpg"],
     ["test_img/img5.jpg"],
     ["test_img/img12.jpg"],
     ["test_img/img14.jpg"],
+    ["test_img/img0.png"],
+    ["test_img/img1.jpg"],
 ]

 args = parse_augment()
-args.captioner = 'blip2'
-args.seg_crop_mode = 'wo_bg'
-args.regular_box = True
 # args.device = 'cuda:5'
 # args.disable_gpt = False
 # args.enable_reduce_tokens = True
@@ -57,9 +57,10 @@ args.regular_box = True
 model = CaptionAnything(args)

 def init_openai_api_key(api_key):
-    os.environ['OPENAI_API_KEY'] = api_key
-    model.init_refiner()
-
+    # os.environ['OPENAI_API_KEY'] = api_key
+    model.init_refiner(api_key)
+    openai_available = model.text_refiner is not None
+    return gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = True), gr.update(visible = True)

 def get_prompt(chat_input, click_state):
     points = click_state[0]
@@ -78,7 +79,7 @@ def get_prompt(chat_input, click_state):
     return prompt

 def chat_with_points(chat_input, click_state, state):
-    if
+    if model.text_refiner is None:
         response = "Text refiner is not initilzed, please input openai api key."
         state = state + [(chat_input, response)]
         return state, state
@@ -132,7 +133,7 @@ def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality
     image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]))

     yield state, state, click_state, chat_input, image_input, wiki
-    if not args.disable_gpt and
+    if not args.disable_gpt and model.text_refiner:
         refined_caption = model.text_refiner.inference(query=text, controls=controls, context=out['context_captions'])
         # new_cap = 'Original: ' + text + '. Refined: ' + refined_caption['caption']
         new_cap = refined_caption['caption']
@@ -143,10 +144,16 @@ def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality
 def upload_callback(image_input, state):
     state = [] + [('Image size: ' + str(image_input.size), None)]
     click_state = [[], [], []]
+    res = 1024
+    width, height = image_input.size
+    ratio = min(1.0 * res / max(width, height), 1.0)
+    if ratio < 1.0:
+        image_input = image_input.resize((int(width * ratio), int(height * ratio)))
+        print('Scaling input image to {}'.format(image_input.size))
     model.segmenter.image = None
     model.segmenter.image_embedding = None
     model.segmenter.set_image(image_input)
-    return state, image_input, click_state
+    return state, image_input, click_state, image_input

 with gr.Blocks(
     css='''
@@ -163,55 +170,62 @@ with gr.Blocks(

     with gr.Row():
         with gr.Column(scale=1.0):
+            with gr.Column(visible=False) as modules_not_need_gpt:
+                image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
+                example_image = gr.Image(type="pil", interactive=False, visible=False)
+                with gr.Row(scale=1.0):
+                    point_prompt = gr.Radio(
+                        choices=["Positive", "Negative"],
+                        value="Positive",
+                        label="Point Prompt",
+                        interactive=True)
+                    clear_button_clike = gr.Button(value="Clear Clicks", interactive=True)
+                    clear_button_image = gr.Button(value="Clear Image", interactive=True)
+            with gr.Column(visible=False) as modules_need_gpt:
+                with gr.Row(scale=1.0):
+                    language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
+
+                    sentiment = gr.Radio(
+                        choices=["Positive", "Natural", "Negative"],
+                        value="Natural",
+                        label="Sentiment",
+                        interactive=True,
+                    )
+                with gr.Row(scale=1.0):
+                    factuality = gr.Radio(
+                        choices=["Factual", "Imagination"],
+                        value="Factual",
+                        label="Factuality",
+                        interactive=True,
+                    )
+                    length = gr.Slider(
+                        minimum=10,
+                        maximum=80,
+                        value=10,
+                        step=1,
+                        interactive=True,
+                        label="Length",
+                    )

         with gr.Column(scale=0.5):
             openai_api_key = gr.Textbox(
-                placeholder="Input
+                placeholder="Input openAI API key and press Enter (Input blank will disable GPT)",
                 show_label=False,
                 label = "OpenAI API Key",
                 lines=1,
                 type="password"
             )
+            with gr.Column(visible=False) as modules_need_gpt2:
+                wiki_output = gr.Textbox(lines=6, label="Wiki")
+            with gr.Column(visible=False) as modules_not_need_gpt2:
+                chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=450,scale=0.5)
+            with gr.Column(visible=False) as modules_need_gpt3:
+                chat_input = gr.Textbox(lines=1, label="Chat Input")
+                with gr.Row():
+                    clear_button_text = gr.Button(value="Clear Text", interactive=True)
+                    submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
+
+    openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2])
     clear_button_clike.click(
         lambda x: ([[], [], []], x, ""),
         [origin_image],
@@ -220,9 +234,9 @@ with gr.Blocks(
         show_progress=False
     )
     clear_button_image.click(
-        lambda: (None, [], [], [[], [], []], ""),
+        lambda: (None, [], [], [[], [], []], "", ""),
         [],
-        [image_input, chatbot, state, click_state, wiki_output],
+        [image_input, chatbot, state, click_state, wiki_output, origin_image],
         queue=False,
         show_progress=False
     )
@@ -234,20 +248,25 @@ with gr.Blocks(
         show_progress=False
     )
     image_input.clear(
-        lambda: (None, [], [], [[], [], []], ""),
+        lambda: (None, [], [], [[], [], []], "", ""),
         [],
-        [image_input, chatbot, state, click_state, wiki_output],
+        [image_input, chatbot, state, click_state, wiki_output, origin_image],
         queue=False,
         show_progress=False
     )

+    def example_callback(x):
+        model.image_embedding = None
+        return x
+
+    gr.Examples(
         examples=examples,
-        inputs=[
+        inputs=[example_image],
     )

-    image_input.upload(upload_callback,[image_input, state], [state, origin_image, click_state])
+    image_input.upload(upload_callback,[image_input, state], [state, origin_image, click_state, image_input])
     chat_input.submit(chat_with_points, [chat_input, click_state, state], [chatbot, state])
+    example_image.change(upload_callback,[example_image, state], [state, origin_image, click_state, image_input])

     # select coordinate
     image_input.select(inference_seg_cap,
@@ -264,5 +283,5 @@ with gr.Blocks(
         outputs=[chatbot, state, click_state, chat_input, image_input, wiki_output],
         show_progress=False, queue=True)

-iface.queue(concurrency_count=
-iface.launch(server_name="0.0.0.0", enable_queue=True)
+iface.queue(concurrency_count=1, api_open=False, max_size=10)
+iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=args.gradio_share)
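The app.py change gates the GPT-dependent widgets behind the API-key textbox: the widgets live inside hidden gr.Column containers, and init_openai_api_key reveals them by returning one gr.update(visible=...) per container. A minimal sketch of that pattern, assuming the Gradio 3.x Blocks API; component names here are illustrative, not the app's own:

import gradio as gr

def on_key_submit(api_key):
    # stand-in for model.init_refiner(api_key); here we only check the key is non-empty
    ok = bool(api_key.strip())
    return gr.update(visible=ok), gr.update(visible=True)

with gr.Blocks() as demo:
    key_box = gr.Textbox(type="password", lines=1, show_label=False,
                         placeholder="Input OpenAI API key and press Enter")
    with gr.Column(visible=False) as needs_gpt:      # revealed only with a usable key
        wiki = gr.Textbox(lines=6, label="Wiki")
    with gr.Column(visible=False) as no_gpt_needed:  # revealed after any submission
        chatbot = gr.Chatbot(label="Chat about Selected Object")
    key_box.submit(on_key_submit, inputs=[key_box], outputs=[needs_gpt, no_gpt_needed])

# demo.queue(concurrency_count=1).launch()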
app_old.py CHANGED
@@ -98,9 +98,9 @@ def chat_with_points(chat_input, click_state, state):
     return state, state

 def init_openai_api_key(api_key):
-    os.environ['OPENAI_API_KEY'] = api_key
+    # os.environ['OPENAI_API_KEY'] = api_key
     global model
-    model = CaptionAnything(args)
+    model = CaptionAnything(args, api_key)

 css='''
 #image_upload{min-height:200px}
caption_anything.py CHANGED
@@ -8,18 +8,22 @@ import time
 from PIL import Image

 class CaptionAnything():
-    def __init__(self, args):
+    def __init__(self, args, api_key=""):
         self.args = args
         self.captioner = build_captioner(args.captioner, args.device, args)
         self.segmenter = build_segmenter(args.segmenter, args.device, args)
+        self.text_refiner = None
         if not args.disable_gpt:
-            self.init_refiner()
+            self.init_refiner(api_key)

-        self.text_refiner
+    def init_refiner(self, api_key):
+        try:
+            self.text_refiner = build_text_refiner(self.args.text_refiner, self.args.device, self.args, api_key)
+            self.text_refiner.llm('hi') # test
+        except:
+            self.text_refiner = None
+            print('Openai api key is NOT given')
+
     def inference(self, image, prompt, controls, disable_gpt=False):
         # segment with prompt
         print("CA prompt: ", prompt, "CA controls",controls)
@@ -35,14 +39,14 @@ class CaptionAnything():
         print("seg_mask.shape: ", seg_mask.shape)
         # captioning with mask
         if self.args.enable_reduce_tokens:
-            caption, crop_save_path = self.captioner.inference_with_reduced_tokens(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter,
+            caption, crop_save_path = self.captioner.inference_with_reduced_tokens(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, disable_regular_box = self.args.disable_regular_box)
         else:
-            caption, crop_save_path = self.captioner.inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter,
+            caption, crop_save_path = self.captioner.inference_seg(image, seg_mask, crop_mode=self.args.seg_crop_mode, filter=self.args.clip_filter, disable_regular_box = self.args.disable_regular_box)
         # refining with TextRefiner
         context_captions = []
         if self.args.context_captions:
             context_captions.append(self.captioner.inference(image))
-        if not disable_gpt and
+        if not disable_gpt and self.text_refiner is not None:
             refined_caption = self.text_refiner.inference(query=caption, controls=controls, context=context_captions)
         else:
             refined_caption = {'raw_caption': caption}
@@ -54,14 +58,14 @@

 def parse_augment():
     parser = argparse.ArgumentParser()
-    parser.add_argument('--captioner', type=str, default="
+    parser.add_argument('--captioner', type=str, default="blip2")
     parser.add_argument('--segmenter', type=str, default="base")
     parser.add_argument('--text_refiner', type=str, default="base")
     parser.add_argument('--segmenter_checkpoint', type=str, default="segmenter/sam_vit_h_4b8939.pth")
-    parser.add_argument('--seg_crop_mode', type=str, default="
+    parser.add_argument('--seg_crop_mode', type=str, default="wo_bg", choices=['wo_bg', 'w_bg'], help="whether to add or remove background of the image when captioning")
     parser.add_argument('--clip_filter', action="store_true", help="use clip to filter bad captions")
     parser.add_argument('--context_captions', action="store_true", help="use surrounding captions to enhance current caption (TODO)")
-    parser.add_argument('--
+    parser.add_argument('--disable_regular_box', action="store_true", default = False, help="crop image with a regular box")
     parser.add_argument('--device', type=str, default="cuda:0")
     parser.add_argument('--port', type=int, default=6086, help="only useful when running gradio applications")
     parser.add_argument('--debug', action="store_true")
@@ -101,7 +105,7 @@ if __name__ == "__main__":
         "language": "English",
     }

-    model = CaptionAnything(args)
+    model = CaptionAnything(args, os.environ['OPENAI_API_KEY'])
     for prompt in prompts:
         print('*'*30)
         print('Image path: ', image_path)
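The new init_refiner() treats the user-supplied key as untrusted: it builds the text refiner, fires a one-word test call, and falls back to caption-only mode if anything fails. A compact sketch of that probe-and-fallback pattern; build_refiner is a hypothetical stand-in for build_text_refiner:

def try_init_refiner(build_refiner, api_key):
    # returns a working refiner, or None when the key is missing or invalid
    try:
        refiner = build_refiner(api_key)
        refiner.llm('hi')  # cheap round-trip to verify the key actually works
        return refiner
    except Exception:
        print('OpenAI API key is not given or invalid; GPT refinement disabled')
        return None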
captioner/base_captioner.py CHANGED
@@ -135,7 +135,7 @@ class BaseCaptioner:
         return caption, crop_save_path


-    def inference_seg(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str], crop_mode="w_bg", filter=False,
+    def inference_seg(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str], crop_mode="w_bg", filter=False, disable_regular_box = False):
         if type(image) == str:
             image = Image.open(image)
         if type(seg_mask) == str:
@@ -151,14 +151,14 @@
         else:
             image = np.array(image)

-        if
-            min_area_box = new_seg_to_box(seg_mask)
-        else:
+        if disable_regular_box:
             min_area_box = seg_to_box(seg_mask)
+        else:
+            min_area_box = new_seg_to_box(seg_mask)
         return self.inference_box(image, min_area_box, filter)


-    def generate_seg_cropped_image(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str], crop_mode="w_bg",
+    def generate_seg_cropped_image(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str], crop_mode="w_bg", disable_regular_box = False):
         if type(image) == str:
             image = Image.open(image)
         if type(seg_mask) == str:
@@ -173,10 +173,10 @@
         else:
             image = np.array(image)

-        if
-            box = new_seg_to_box(seg_mask)
-        else:
+        if disable_regular_box:
             box = seg_to_box(seg_mask)
+        else:
+            box = new_seg_to_box(seg_mask)

         if np.array(box).size == 4: # [x0, y0, x1, y1], where (x0, y0), (x1, y1) represent top-left and bottom-right corners
             size = max(image.shape[0], image.shape[1])
captioner/blip.py CHANGED
@@ -25,15 +25,15 @@ class BLIPCaptioner(BaseCaptioner):
             image = Image.open(image)
         inputs = self.processor(image, return_tensors="pt").to(self.device, self.torch_dtype)
         out = self.model.generate(**inputs, max_new_tokens=50)
-        captions = self.processor.decode(out[0], skip_special_tokens=True)
+        captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
         if self.enable_filter and filter:
             captions = self.filter_caption(image, captions)
         print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
         return captions

     @torch.no_grad()
-    def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg", filter=False,
-        crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode,
+    def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg", filter=False, disable_regular_box = False):
+        crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode, disable_regular_box=disable_regular_box)
         if type(image) == str: # input path
             image = Image.open(image)
         inputs = self.processor(image, return_tensors="pt")
@@ -45,7 +45,7 @@ class BLIPCaptioner(BaseCaptioner):
         seg_mask = seg_mask.float()
         pixel_masks = seg_mask.unsqueeze(0).to(self.device)
         out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50)
-        captions = self.processor.decode(out[0], skip_special_tokens=True)
+        captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
         if self.enable_filter and filter:
             captions = self.filter_caption(image, captions)
         print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
captioner/blip2.py CHANGED
@@ -22,9 +22,10 @@ class BLIP2Captioner(BaseCaptioner):
             image = Image.open(image)

         if not self.dialogue:
+            text_prompt = 'Context: ignore the white part in this image. Question: describe this image. Answer:'
+            inputs = self.processor(image, text = text_prompt, return_tensors="pt").to(self.device, self.torch_dtype)
             out = self.model.generate(**inputs, max_new_tokens=50)
-            captions = self.processor.decode(out[0], skip_special_tokens=True)
+            captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
             if self.enable_filter and filter:
                 captions = self.filter_caption(image, captions)
             print(f"\nProcessed ImageCaptioning by BLIP2Captioner, Output Text: {captions}")
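Because the 'wo_bg' crop mode whites out everything outside the segmentation mask, the captioner now conditions BLIP-2 on a text prompt that tells it to ignore the white region. A standalone sketch of the same prompt-conditioned generation with Hugging Face transformers; the checkpoint name and image path are illustrative, not taken from this repo:

import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration

device, dtype = "cuda", torch.float16
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b",
                                                      torch_dtype=dtype).to(device)

image = Image.open("test_img/img2.jpg")
prompt = "Context: ignore the white part in this image. Question: describe this image. Answer:"
inputs = processor(image, text=prompt, return_tensors="pt").to(device, dtype)
out = model.generate(**inputs, max_new_tokens=50)
caption = processor.decode(out[0], skip_special_tokens=True).strip()
print(caption)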
captioner/git.py CHANGED
@@ -22,15 +22,15 @@ class GITCaptioner(BaseCaptioner):
             image = Image.open(image)
         pixel_values = self.processor(images=image, return_tensors="pt").pixel_values.to(self.device, self.torch_dtype)
         generated_ids = self.model.generate(pixel_values=pixel_values, max_new_tokens=50)
-        generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        generated_caption = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
         if self.enable_filter and filter:
             captions = self.filter_caption(image, captions)
         print(f"\nProcessed ImageCaptioning by GITCaptioner, Output Text: {generated_caption}")
         return generated_caption

     @torch.no_grad()
-    def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg", filter=False,
-        crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode,
+    def inference_with_reduced_tokens(self, image: Union[np.ndarray, Image.Image, str], seg_mask, crop_mode="w_bg", filter=False, disable_regular_box = False):
+        crop_save_path = self.generate_seg_cropped_image(image=image, seg_mask=seg_mask, crop_mode=crop_mode, disable_regular_box=disable_regular_box)
         if type(image) == str: # input path
             image = Image.open(image)
         inputs = self.processor(images=image, return_tensors="pt")
@@ -42,7 +42,7 @@ class GITCaptioner(BaseCaptioner):
         seg_mask = seg_mask.float()
         pixel_masks = seg_mask.unsqueeze(0).to(self.device)
         out = self.model.generate(pixel_values=pixel_values, pixel_masks=pixel_masks, max_new_tokens=50)
-        captions = self.processor.decode(out[0], skip_special_tokens=True)
+        captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
         if self.enable_filter and filter:
             captions = self.filter_caption(image, captions)
         print(f"\nProcessed ImageCaptioning by BLIPCaptioner, Output Text: {captions}")
image_editing_utils.py CHANGED
@@ -35,7 +35,8 @@ def create_bubble_frame(image, text, point, font_path='DejaVuSansCondensed-Bold.

     # Wrap the text to fit within the max_text_width
     lines = wrap_text(text, font, max_text_width)
-    text_width
+    text_width = max([font.getsize(line)[0] for line in lines])
+    _, text_height = font.getsize(lines[0])
     text_height = text_height * len(lines)

     # Define bubble frame dimensions
@@ -48,7 +49,7 @@ def create_bubble_frame(image, text, point, font_path='DejaVuSansCondensed-Bold.

     # Draw the bubble frame on the new image
     draw = ImageDraw.Draw(bubble)
-    draw.rectangle([(0, 0), (bubble_width - 1, bubble_height - 1)], fill=(255, 255, 255, 0), outline=(255, 255, 255, 0), width=2)
+    # draw.rectangle([(0, 0), (bubble_width - 1, bubble_height - 1)], fill=(255, 255, 255, 0), outline=(255, 255, 255, 0), width=2)

     # Draw the wrapped text line by line
     y_text = padding
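The bubble-frame fix measures the wrapped text explicitly: the bubble width comes from the widest wrapped line and the height from one line's height times the line count. A small sketch of that measurement, assuming PIL's ImageFont.getsize (available in Pillow < 10; newer Pillow replaces it with getbbox/getlength):

from PIL import ImageFont

def measure_wrapped_text(lines, font):
    # widest line decides the bubble width; one line height times line count decides its height
    text_width = max(font.getsize(line)[0] for line in lines)
    _, line_height = font.getsize(lines[0])
    return text_width, line_height * len(lines)

# font = ImageFont.truetype('DejaVuSansCondensed-Bold.ttf', 24)   # path is illustrative
# width, height = measure_wrapped_text(["a wrapped", "caption"], font)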
segmenter/base_segmenter.py CHANGED
@@ -46,7 +46,7 @@ class BaseSegmenter:
             new_masks = np.concatenate([mask["segmentation"][np.newaxis,:] for mask in masks])
             return new_masks
         else:
-            if not self.reuse_feature:
+            if not self.reuse_feature or self.image_embedding is None:
                 self.set_image(image)
                 self.predictor.set_image(self.image)
             else:
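The segmenter fix makes feature reuse conditional on the cached embedding actually existing, so a cleared image (image_embedding = None) forces a fresh set_image() instead of reusing stale features. A generic sketch of that guard; compute_embedding is a hypothetical stand-in for the segmenter's own embedding call:

class EmbeddingCache:
    def __init__(self, reuse_feature=True):
        self.reuse_feature = reuse_feature
        self.image_embedding = None

    def get(self, image, compute_embedding):
        # recompute when reuse is disabled or nothing is cached yet
        if not self.reuse_feature or self.image_embedding is None:
            self.image_embedding = compute_embedding(image)
        return self.image_embedding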
text_refiner/__init__.py CHANGED
@@ -1,6 +1,6 @@
 from text_refiner.text_refiner import TextRefiner


-def build_text_refiner(type, device, args=None):
+def build_text_refiner(type, device, args=None, api_key=""):
     if type == 'base':
-        return TextRefiner(device)
+        return TextRefiner(device, api_key)
text_refiner/text_refiner.py CHANGED
@@ -5,12 +5,9 @@ from transformers import pipeline, BlipProcessor, BlipForConditionalGeneration,
 import pdb

 class TextRefiner:
-    def __init__(self, device):
+    def __init__(self, device, api_key=""):
         print(f"Initializing TextRefiner to {device}")
-        try:
-            self.llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0)
-        except:
-            print('Openai api key is NOT given')
+        self.llm = OpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=api_key)
         self.prompt_tag = {
             "imagination": {"True": "could",
                             "False": "could not"}
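TextRefiner now receives the key explicitly instead of reading OPENAI_API_KEY from the process environment, so one user's key never leaks into another session on the shared Space. A sketch of the construction, assuming LangChain's OpenAI wrapper of that era:

from langchain.llms import OpenAI

def make_llm(api_key: str):
    # per-instance key instead of a process-wide environment variable
    return OpenAI(model_name="gpt-3.5-turbo", temperature=0, openai_api_key=api_key)

# llm = make_llm("sk-...")
# llm("hi")  # the same cheap connectivity check CaptionAnything.init_refiner performs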
tools.py CHANGED
@@ -1,7 +1,9 @@
 import cv2
+import torch
 import numpy as np
 from PIL import Image
 import copy
+import time


 def colormap(rgb=True):
@@ -100,16 +102,6 @@ color_list = colormap()
 color_list = color_list.astype('uint8').tolist()


-def gauss_filter(kernel_size, sigma):
-    max_idx = kernel_size // 2
-    idx = np.linspace(-max_idx, max_idx, kernel_size)
-    Y, X = np.meshgrid(idx, idx)
-    gauss_filter = np.exp(-(X**2 + Y**2) / (2*sigma**2))
-    gauss_filter /= np.sum(np.sum(gauss_filter))
-
-    return gauss_filter
-
-
 def vis_add_mask(image, mask, color, alpha, kernel_size):
     color = np.array(color)
     mask = mask.astype('float').copy()
@@ -129,6 +121,23 @@ def vis_add_mask_wo_blur(image, mask, color, alpha):
     return image


+def vis_add_mask_wo_gaussian(image, background_mask, contour_mask, background_color, contour_color, background_alpha, contour_alpha):
+    background_color = np.array(background_color)
+    contour_color = np.array(contour_color)
+
+    # background_mask = 1 - background_mask
+    # contour_mask = 1 - contour_mask
+
+    for i in range(3):
+        image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
+            + background_color[i] * (background_alpha-background_mask*background_alpha)
+
+        image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
+            + contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
+
+    return image.astype('uint8')
+
+
 def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1):
     """
     Input:
@@ -146,11 +155,7 @@ def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_
     assert input_image.shape[:2] == input_mask.shape, 'different shape'
     assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'

-    res = 1024
-    ratio = min(1.0 * res / max(width, height), 1.0)
-    input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
-    input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
+
     # 0: background, 1: foreground
     input_mask[input_mask>0] = 255

@@ -163,15 +168,120 @@ def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_
     kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (contour_width, contour_width))
     contour_mask = cv2.dilate(contour_mask, kernel)
     painted_image = vis_add_mask(painted_image, 255-contour_mask, color_list[contour_color], contour_alpha, contour_width)
+
+    # painted_image = background_dist_map
+
+    return painted_image
+
+
+def mask_generator_00(mask, background_radius, contour_radius):
+    # no background width when '00'
+    # distance map
+    dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+    dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+    dist_map = dist_transform_fore - dist_transform_back
+    # ...:::!!!:::...
+    contour_radius += 2
+    contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+    contour_mask = contour_mask / np.max(contour_mask)
+    contour_mask[contour_mask>0.5] = 1.
+
+    return mask, contour_mask
+
+
+def mask_generator_01(mask, background_radius, contour_radius):
+    # no background width when '00'
+    # distance map
+    dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+    dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+    dist_map = dist_transform_fore - dist_transform_back
+    # ...:::!!!:::...
+    contour_radius += 2
+    contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+    contour_mask = contour_mask / np.max(contour_mask)
+    return mask, contour_mask
+
+
+def mask_generator_10(mask, background_radius, contour_radius):
+    # distance map
+    dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+    dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+    dist_map = dist_transform_fore - dist_transform_back
+    # .....:::::!!!!!
+    background_mask = np.clip(dist_map, -background_radius, background_radius)
+    background_mask = (background_mask - np.min(background_mask))
+    background_mask = background_mask / np.max(background_mask)
+    # ...:::!!!:::...
+    contour_radius += 2
+    contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+    contour_mask = contour_mask / np.max(contour_mask)
+    contour_mask[contour_mask>0.5] = 1.
+    return background_mask, contour_mask
+
+
+def mask_generator_11(mask, background_radius, contour_radius):
+    # distance map
+    dist_transform_fore = cv2.distanceTransform(mask, cv2.DIST_L2, 3)
+    dist_transform_back = cv2.distanceTransform(1-mask, cv2.DIST_L2, 3)
+    dist_map = dist_transform_fore - dist_transform_back
+    # .....:::::!!!!!
+    background_mask = np.clip(dist_map, -background_radius, background_radius)
+    background_mask = (background_mask - np.min(background_mask))
+    background_mask = background_mask / np.max(background_mask)
+    # ...:::!!!:::...
+    contour_radius += 2
+    contour_mask = np.abs(np.clip(dist_map, -contour_radius, contour_radius))
+    contour_mask = contour_mask / np.max(contour_mask)
+    return background_mask, contour_mask
+
+
+def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, background_blur_radius=7, contour_width=3, contour_color=3, contour_alpha=1, mode='11'):
+    """
+    Input:
+        input_image: numpy array
+        input_mask: numpy array
+        background_alpha: transparency of background, [0, 1], 1: all black, 0: do nothing
+        background_blur_radius: radius of background blur, must be odd number
+        contour_width: width of mask contour, must be odd number
+        contour_color: color index (in color map) of mask contour, 0: black, 1: white, >1: others
+        contour_alpha: transparency of mask contour, [0, 1], if 0: no contour highlighted
+        mode: painting mode, '00', no blur, '01' only blur contour, '10' only blur background, '11' blur both
+
+    Output:
+        painted_image: numpy array
+    """
+    assert input_image.shape[:2] == input_mask.shape, 'different shape'
+    assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
+    assert mode in ['00', '01', '10', '11'], 'mode should be 00, 01, 10, or 11'
+
+    # downsample input image and mask
+    width, height = input_image.shape[0], input_image.shape[1]
+    res = 1024
+    ratio = min(1.0 * res / max(width, height), 1.0)
+    input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
+    input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
+
+    # 0: background, 1: foreground
+    msk = np.clip(input_mask, 0, 1)
+
+    # generate masks for background and contour pixels
+    background_radius = (background_blur_radius - 1) // 2
+    contour_radius = (contour_width - 1) // 2
+    generator_dict = {'00':mask_generator_00, '01':mask_generator_01, '10':mask_generator_10, '11':mask_generator_11}
+    background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
+
+    # paint
+    painted_image = vis_add_mask_wo_gaussian\
+        (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background
+
     return painted_image


 if __name__ == '__main__':

     background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
-    background_blur_radius =
-    contour_width =
+    background_blur_radius = 31 # radius of background blur, must be odd number
+    contour_width = 11 # contour width, must be odd number
     contour_color = 3 # id in color map, 0: black, 1: white, >1: others
     contour_alpha = 1 # transparency of background, 0: no contour highlighted

@@ -180,8 +290,54 @@ if __name__ == '__main__':
     input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P'))

     # paint
+    overall_time_1 = 0
+    overall_time_2 = 0
+    overall_time_3 = 0
+    overall_time_4 = 0
+    overall_time_5 = 0
+
+    for i in range(50):
+        t2 = time.time()
+        painted_image_00 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00')
+        e2 = time.time()
+
+        t3 = time.time()
+        painted_image_10 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='10')
+        e3 = time.time()
+
+        t1 = time.time()
+        painted_image = mask_painter(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha)
+        e1 = time.time()
+
+        t4 = time.time()
+        painted_image_01 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='01')
+        e4 = time.time()
+
+        t5 = time.time()
+        painted_image_11 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='11')
+        e5 = time.time()
+
+        overall_time_1 += (e1 - t1)
+        overall_time_2 += (e2 - t2)
+        overall_time_3 += (e3 - t3)
+        overall_time_4 += (e4 - t4)
+        overall_time_5 += (e5 - t5)
+
+    print(f'average time w gaussian: {overall_time_1/50}')
+    print(f'average time w/o gaussian00: {overall_time_2/50}')
+    print(f'average time w/o gaussian10: {overall_time_3/50}')
+    print(f'average time w/o gaussian01: {overall_time_4/50}')
+    print(f'average time w/o gaussian11: {overall_time_5/50}')

     # save
+    painted_image_00 = Image.fromarray(painted_image_00)
+    painted_image_00.save('./test_img/painter_output_image_00.png')
+
+    painted_image_10 = Image.fromarray(painted_image_10)
+    painted_image_10.save('./test_img/painter_output_image_10.png')
+
+    painted_image_01 = Image.fromarray(painted_image_01)
+    painted_image_01.save('./test_img/painter_output_image_01.png')
+
+    painted_image_11 = Image.fromarray(painted_image_11)
+    painted_image_11.save('./test_img/painter_output_image_11.png')
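The new mask_generator_* helpers replace Gaussian blurring with distance transforms: cv2.distanceTransform gives a signed distance to the mask boundary, which is clipped and normalized into soft background and contour weight maps that vis_add_mask_wo_gaussian blends per channel, and the __main__ block above benchmarks the four modes against the original Gaussian painter. A minimal usage sketch of the new painter; the file paths are illustrative:

import numpy as np
from PIL import Image
from tools import mask_painter_wo_gaussian

image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB'))
mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P'))

painted = mask_painter_wo_gaussian(
    image, mask,
    background_alpha=0.7,       # 0: leave background alone, 1: fully darken it
    background_blur_radius=31,  # must be odd
    contour_width=11,           # must be odd
    contour_color=3,            # index into the module's color_list
    contour_alpha=1,
    mode='11',                  # '00'/'01'/'10'/'11': which weight maps get softened
)
Image.fromarray(painted).save('./test_img/painter_output_image_11.png')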