Spaces: Runtime error
ttengwang committed on
Commit • 13c1c2e
1 Parent(s): eeb5fe8

assign api key and img embed from different users to different sessions

Files changed:
- app.py +141 -60
- caption_anything.py +21 -10
- segmenter/__init__.py +4 -2
- segmenter/base_segmenter.py +7 -4
app.py
CHANGED

@@ -15,6 +15,10 @@ import copy
 from tools import mask_painter
 from PIL import Image
 import os
+from captioner import build_captioner
+from segment_anything import sam_model_registry
+from text_refiner import build_text_refiner
+from segmenter import build_segmenter

 def download_checkpoint(url, folder, filename):
     os.makedirs(folder, exist_ok=True)

@@ -50,37 +54,74 @@ examples = [
 ]

 args = parse_augment()
-args.disable_reuse_features = True
 # args.device = 'cuda:5'
-# args.disable_gpt =
-# args.enable_reduce_tokens =
+# args.disable_gpt = True
+# args.enable_reduce_tokens = False
 # args.port=20322
-
+# args.captioner = 'blip'
+# args.regular_box = True
+shared_captioner = build_captioner(args.captioner, args.device, args)
+shared_sam_model = sam_model_registry['vit_h'](checkpoint=args.segmenter_checkpoint).to(args.device)

-def init_openai_api_key(api_key):
-    # os.environ['OPENAI_API_KEY'] = api_key
-    model.init_refiner(api_key)
-    openai_available = model.text_refiner is not None
-    return gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = True), gr.update(visible = True)

-def
-
-
+def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, text_refiner=None, session_id=None):
+    segmenter = build_segmenter(args.segmenter, args.device, args, model=sam_model)
+    captioner = captioner
+    if session_id is not None:
+        print('Init caption anything for session {}'.format(session_id))
+    return CaptionAnything(args, api_key, captioner=captioner, segmenter=segmenter, text_refiner=text_refiner)
+
+
+def init_openai_api_key(api_key=""):
+    text_refiner = None
+    if api_key and len(api_key) > 30:
+        try:
+            text_refiner = build_text_refiner(args.text_refiner, args.device, args, api_key)
+            text_refiner.llm('hi') # test
+        except:
+            text_refiner = None
+    openai_available = text_refiner is not None
+    return gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = openai_available), gr.update(visible = True), gr.update(visible = True), gr.update(visible = True), text_refiner
+
+
+def get_prompt(chat_input, click_state, click_mode):
     inputs = json.loads(chat_input)
-
-    points
-    labels
+    if click_mode == 'Continuous':
+        points = click_state[0]
+        labels = click_state[1]
+        for input in inputs:
+            points.append(input[:2])
+            labels.append(input[2])
+    elif click_mode == 'Single':
+        points = []
+        labels = []
+        for input in inputs:
+            points.append(input[:2])
+            labels.append(input[2])
+        click_state[0] = points
+        click_state[1] = labels
+    else:
+        raise NotImplementedError

     prompt = {
         "prompt_type":["click"],
-        "input_point":
-        "input_label":
+        "input_point":click_state[0],
+        "input_label":click_state[1],
         "multimask_output":"True",
     }
     return prompt

-def
-    if
+def update_click_state(click_state, caption, click_mode):
+    if click_mode == 'Continuous':
+        click_state[2].append(caption)
+    elif click_mode == 'Single':
+        click_state[2] = [caption]
+    else:
+        raise NotImplementedError
+
+
+def chat_with_points(chat_input, click_state, state, text_refiner):
+    if text_refiner is None:
         response = "Text refiner is not initilzed, please input openai api key."
         state = state + [(chat_input, response)]
         return state, state

@@ -96,11 +137,26 @@ def chat_with_points(chat_input, click_state, state):
     else:
         prev_visual_context = 'no point exists.'
     chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context, "chat_input": chat_input})
-    response =
+    response = text_refiner.llm(chat_prompt)
     state = state + [(chat_input, response)]
     return state, state

-def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality,
+def inference_seg_cap(image_input, point_prompt, click_mode, language, sentiment, factuality,
+                      length, image_embedding, state, click_state, original_size, input_size, text_refiner, evt:gr.SelectData):
+
+    model = build_caption_anything_with_models(
+        args,
+        api_key="",
+        captioner=shared_captioner,
+        sam_model=shared_sam_model,
+        text_refiner=text_refiner,
+        session_id=iface.app_id
+    )
+
+    model.segmenter.image_embedding = image_embedding
+    model.segmenter.predictor.original_size = original_size
+    model.segmenter.predictor.input_size = input_size
+    model.segmenter.predictor.is_image_set = True

     if point_prompt == 'Positive':
         coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))

@@ -114,7 +170,7 @@ def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality

     # click_coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
     # chat_input = click_coordinate
-    prompt = get_prompt(coordinate, click_state)
+    prompt = get_prompt(coordinate, click_state, click_mode)
     print('prompt: ', prompt, 'controls: ', controls)

     out = model.inference(image_input, prompt, controls)

@@ -123,12 +179,12 @@ def inference_seg_cap(image_input, point_prompt, language, sentiment, factuality
         # state = state + [(f'{k}: {v}', None)]
     state = state + [("caption: {}".format(out['generated_captions']['raw_caption']), None)]
     wiki = out['generated_captions'].get('wiki', "")
-
-
+
+    update_click_state(click_state, out['generated_captions']['raw_caption'], click_mode)
     text = out['generated_captions']['raw_caption']
     # draw = ImageDraw.Draw(image_input)
     # draw.text((evt.index[0], evt.index[1]), text, textcolor=(0,0,255), text_size=120)
-    input_mask = np.array(
+    input_mask = np.array(out['mask'].convert('P'))
     image_input = mask_painter(np.array(image_input), input_mask)
     origin_image_input = image_input
     image_input = create_bubble_frame(image_input, text, (evt.index[0], evt.index[1]))

@@ -151,10 +207,19 @@ def upload_callback(image_input, state):
     if ratio < 1.0:
         image_input = image_input.resize((int(width * ratio), int(height * ratio)))
         print('Scaling input image to {}'.format(image_input.size))
-
-    model
+
+    model = build_caption_anything_with_models(
+        args,
+        api_key="",
+        captioner=shared_captioner,
+        sam_model=shared_sam_model,
+        session_id=iface.app_id
+    )
     model.segmenter.set_image(image_input)
-
+    image_embedding = model.segmenter.image_embedding
+    original_size = model.segmenter.predictor.original_size
+    input_size = model.segmenter.predictor.input_size
+    return state, state, image_input, click_state, image_input, image_embedding, original_size, input_size

 with gr.Blocks(
     css='''

@@ -165,6 +230,10 @@ with gr.Blocks(
     state = gr.State([])
     click_state = gr.State([[],[],[]])
     origin_image = gr.State(None)
+    image_embedding = gr.State(None)
+    text_refiner = gr.State(None)
+    original_size = gr.State(None)
+    input_size = gr.State(None)

     gr.Markdown(title)
     gr.Markdown(description)

@@ -175,17 +244,24 @@ with gr.Blocks(
             image_input = gr.Image(type="pil", interactive=True, elem_id="image_upload")
             example_image = gr.Image(type="pil", interactive=False, visible=False)
             with gr.Row(scale=1.0):
-
-
-
-
-
-
-
+                with gr.Row(scale=0.4):
+                    point_prompt = gr.Radio(
+                        choices=["Positive", "Negative"],
+                        value="Positive",
+                        label="Point Prompt",
+                        interactive=True)
+                    click_mode = gr.Radio(
+                        choices=["Continuous", "Single"],
+                        value="Continuous",
+                        label="Clicking Mode",
+                        interactive=True)
+                with gr.Row(scale=0.4):
+                    clear_button_clike = gr.Button(value="Clear Clicks", interactive=True)
+                    clear_button_image = gr.Button(value="Clear Image", interactive=True)
             with gr.Column(visible=False) as modules_need_gpt:
                 with gr.Row(scale=1.0):
                     language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
-
+
                     sentiment = gr.Radio(
                         choices=["Positive", "Natural", "Negative"],
                         value="Natural",

@@ -206,27 +282,36 @@ with gr.Blocks(
                         step=1,
                         interactive=True,
                         label="Length",
-                    )
-
+                    )
+            with gr.Column(visible=True) as modules_not_need_gpt3:
+                gr.Examples(
+                    examples=examples,
+                    inputs=[example_image],
+                )
         with gr.Column(scale=0.5):
             openai_api_key = gr.Textbox(
-                placeholder="Input openAI API key
+                placeholder="Input openAI API key",
                 show_label=False,
                 label = "OpenAI API Key",
                 lines=1,
-                type="password"
-
+                type="password")
+            with gr.Row(scale=0.5):
+                enable_chatGPT_button = gr.Button(value="Run with ChatGPT", interactive=True, variant='primary')
+                disable_chatGPT_button = gr.Button(value="Run without ChatGPT (Faster)", interactive=True, variant='primary')
            with gr.Column(visible=False) as modules_need_gpt2:
-                wiki_output = gr.Textbox(lines=
+                wiki_output = gr.Textbox(lines=5, label="Wiki", max_lines=5)
            with gr.Column(visible=False) as modules_not_need_gpt2:
-                chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=
+                chatbot = gr.Chatbot(label="Chat about Selected Object",).style(height=550,scale=0.5)
            with gr.Column(visible=False) as modules_need_gpt3:
                chat_input = gr.Textbox(lines=1, label="Chat Input")
                with gr.Row():
                    clear_button_text = gr.Button(value="Clear Text", interactive=True)
                    submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
-
-    openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2])
+
+    openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
+    enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
+    disable_chatGPT_button.click(init_openai_api_key, outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
+
    clear_button_clike.click(
        lambda x: ([[], [], []], x, ""),
        [origin_image],

@@ -256,33 +341,29 @@
        show_progress=False
    )

-
-
-
-
-    gr.Examples(
-        examples=examples,
-        inputs=[example_image],
-    )
-
-    image_input.upload(upload_callback,[image_input, state], [state, origin_image, click_state, image_input])
-    chat_input.submit(chat_with_points, [chat_input, click_state, state], [chatbot, state])
-    example_image.change(upload_callback,[example_image, state], [state, origin_image, click_state, image_input])
+    image_input.upload(upload_callback,[image_input, state], [chatbot, state, origin_image, click_state, image_input, image_embedding, original_size, input_size])
+    chat_input.submit(chat_with_points, [chat_input, click_state, state, text_refiner], [chatbot, state])
+    example_image.change(upload_callback,[example_image, state], [state, state, origin_image, click_state, image_input, image_embedding, original_size, input_size])

    # select coordinate
    image_input.select(inference_seg_cap,
        inputs=[
        origin_image,
        point_prompt,
+        click_mode,
        language,
        sentiment,
        factuality,
        length,
+        image_embedding,
        state,
-        click_state
+        click_state,
+        original_size,
+        input_size,
+        text_refiner
        ],
        outputs=[chatbot, state, click_state, chat_input, image_input, wiki_output],
        show_progress=False, queue=True)

-iface.queue(concurrency_count=1, api_open=False)
-iface.launch(server_name="0.0.0.0", enable_queue=True)
+iface.queue(concurrency_count=1, api_open=False, max_size=10)
+iface.launch(server_name="0.0.0.0", enable_queue=True, server_port=args.port, share=args.gradio_share)
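What the app.py change amounts to: the heavy models (the captioner and the SAM network) are now built once at module level and shared by every visitor, while everything user-specific — the text refiner created from the OpenAI key, the image embedding, and the predictor sizes — moves into gr.State values that are threaded through each event handler. Below is a minimal, self-contained sketch of that pattern, assuming a Gradio 3.x runtime like the one this Space targets; SharedModel and its methods are illustrative stand-ins, not code from this repository.

import gradio as gr

# One heavy object per process, shared read-only by every session.
class SharedModel:
    def encode(self, img):
        return img.size                      # stand-in for a real image embedding
    def caption(self, img, embedding, api_key=None):
        suffix = " (refined)" if api_key else ""
        return f"object at {embedding}{suffix}"

shared_model = SharedModel()                 # built once, like shared_captioner / shared_sam_model

with gr.Blocks() as demo:
    per_user_key = gr.State(None)            # per-session values live in gr.State,
    per_user_embedding = gr.State(None)      # never in module-level globals

    api_key_box = gr.Textbox(type="password", label="API key")
    image = gr.Image(type="pil")
    output = gr.Textbox(label="Caption")
    run_button = gr.Button("Caption")

    # store the key and the embedding only in this session's state
    api_key_box.submit(lambda k: k, [api_key_box], [per_user_key])
    image.upload(lambda img: shared_model.encode(img), [image], [per_user_embedding])
    run_button.click(lambda img, k, emb: shared_model.caption(img, emb, api_key=k),
                     [image, per_user_key, per_user_embedding], [output])

demo.queue(concurrency_count=1).launch()

Because gr.State is scoped to a browser session, two users working at the same time no longer overwrite each other's API key or image embedding, which is what the commit message describes.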
caption_anything.py
CHANGED

@@ -6,14 +6,17 @@ import argparse
 import pdb
 import time
 from PIL import Image
+import cv2
+import numpy as np

 class CaptionAnything():
-    def __init__(self, args, api_key=""):
+    def __init__(self, args, api_key="", captioner=None, segmenter=None, text_refiner=None):
         self.args = args
-        self.captioner = build_captioner(args.captioner, args.device, args)
-        self.segmenter = build_segmenter(args.segmenter, args.device, args)
+        self.captioner = build_captioner(args.captioner, args.device, args) if captioner is None else captioner
+        self.segmenter = build_segmenter(args.segmenter, args.device, args) if segmenter is None else segmenter
+
         self.text_refiner = None
-        if not args.disable_gpt:
+        if not args.disable_gpt and text_refiner is not None:
             self.init_refiner(api_key)

     def init_refiner(self, api_key):

@@ -22,19 +25,25 @@ class CaptionAnything():
             self.text_refiner.llm('hi') # test
         except:
             self.text_refiner = None
-            print('
+            print('OpenAI GPT is not available')

     def inference(self, image, prompt, controls, disable_gpt=False):
         # segment with prompt
         print("CA prompt: ", prompt, "CA controls",controls)
         seg_mask = self.segmenter.inference(image, prompt)[0, ...]
+        if self.args.enable_morphologyex:
+            seg_mask = 255 * seg_mask.astype(np.uint8)
+            seg_mask = np.stack([seg_mask, seg_mask, seg_mask], axis = -1)
+            seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_OPEN, kernel = np.ones((6, 6), np.uint8))
+            seg_mask = cv2.morphologyEx(seg_mask, cv2.MORPH_CLOSE, kernel = np.ones((6, 6), np.uint8))
+            seg_mask = seg_mask[:,:,0] > 0
         mask_save_path = f'result/mask_{time.time()}.png'
         if not os.path.exists(os.path.dirname(mask_save_path)):
             os.makedirs(os.path.dirname(mask_save_path))
-
-        if
-
-
+        seg_mask_img = Image.fromarray(seg_mask.astype('int') * 255.)
+        if seg_mask_img.mode != 'RGB':
+            seg_mask_img = seg_mask_img.convert('RGB')
+        seg_mask_img.save(mask_save_path)
         print('seg_mask path: ', mask_save_path)
         print("seg_mask.shape: ", seg_mask.shape)
         # captioning with mask

@@ -53,6 +62,7 @@ class CaptionAnything():
         out = {'generated_captions': refined_caption,
                'crop_save_path': crop_save_path,
                'mask_save_path': mask_save_path,
+               'mask': seg_mask_img,
                'context_captions': context_captions}
         return out

@@ -73,6 +83,7 @@ def parse_augment():
     parser.add_argument('--disable_gpt', action="store_true")
     parser.add_argument('--enable_reduce_tokens', action="store_true", default=False)
     parser.add_argument('--disable_reuse_features', action="store_true", default=False)
+    parser.add_argument('--enable_morphologyex', action="store_true", default=False)
     args = parser.parse_args()

     if args.debug:

@@ -115,4 +126,4 @@ if __name__ == "__main__":
     print('Language controls:\n', controls)
     out = model.inference(image_path, prompt, controls)

-
+
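The other substantive change in caption_anything.py is the optional mask clean-up behind the new --enable_morphologyex flag: the SAM mask is opened to drop small speckles and closed to fill small holes before being saved and painted. A self-contained illustration of that OpenCV pass on a toy mask (the toy data and the print are mine; the cv2 calls mirror the diff):

import cv2
import numpy as np

mask = np.zeros((64, 64), dtype=bool)
mask[16:48, 16:48] = True          # main object
mask[30:33, 30:33] = False         # a small hole inside it
mask[2:4, 2:4] = True              # an isolated speckle

m = 255 * mask.astype(np.uint8)
m = np.stack([m, m, m], axis=-1)   # 3-channel, as in the diff
kernel = np.ones((6, 6), np.uint8)
m = cv2.morphologyEx(m, cv2.MORPH_OPEN, kernel=kernel)    # opening removes the speckle
m = cv2.morphologyEx(m, cv2.MORPH_CLOSE, kernel=kernel)   # closing fills the hole
clean = m[:, :, 0] > 0

print(mask.sum(), clean.sum())     # pixel counts before and after clean-up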
segmenter/__init__.py
CHANGED

@@ -1,6 +1,8 @@
 from segmenter.base_segmenter import BaseSegmenter


-def build_segmenter(type, device, args=None):
+def build_segmenter(type, device, args=None, model=None):
     if type == 'base':
-        return BaseSegmenter(device, args.segmenter_checkpoint, reuse_feature=not args.disable_reuse_features)
+        return BaseSegmenter(device, args.segmenter_checkpoint, reuse_feature=not args.disable_reuse_features, model=model)
+    else:
+        raise NotImplementedError()
segmenter/base_segmenter.py
CHANGED

@@ -9,15 +9,18 @@ import matplotlib.pyplot as plt
 import PIL

 class BaseSegmenter:
-    def __init__(self, device, checkpoint, model_type='vit_h', reuse_feature = True):
+    def __init__(self, device, checkpoint, model_type='vit_h', reuse_feature = True, model=None):
         print(f"Initializing BaseSegmenter to {device}")
         self.device = device
         self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
         self.processor = None
         self.model_type = model_type
-
-
-
+        if model is None:
+            self.checkpoint = checkpoint
+            self.model = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
+            self.model.to(device=self.device)
+        else:
+            self.model = model
         self.reuse_feature = reuse_feature
         self.predictor = SamPredictor(self.model)
         self.mask_generator = SamAutomaticMaskGenerator(self.model)
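The two segmenter edits exist so that the SAM network can be loaded once per process and injected into each per-session segmenter, instead of being re-read from the checkpoint on every request. A usage sketch, assuming the ViT-H checkpoint has already been downloaded (the path below is a placeholder, and the Args class only mimics the fields of parse_augment() that are used here):

import torch
from segment_anything import sam_model_registry
from segmenter import build_segmenter

class Args:                                   # stand-in for parse_augment()
    segmenter = 'base'
    segmenter_checkpoint = './checkpoints/sam_vit_h_4b8939.pth'   # placeholder path
    disable_reuse_features = False

args = Args()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load SAM once per process ...
shared_sam = sam_model_registry['vit_h'](checkpoint=args.segmenter_checkpoint).to(device)

# ... then hand the same instance to every per-session BaseSegmenter.
seg_a = build_segmenter(args.segmenter, device, args, model=shared_sam)
seg_b = build_segmenter(args.segmenter, device, args, model=shared_sam)
assert seg_a.model is seg_b.model             # one network, two lightweight predictors

When model is omitted, build_segmenter still loads its own copy from the checkpoint, so the standalone CLI path in caption_anything.py keeps working unchanged.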