ttengwang committed on
Commit ff883a7
1 Parent(s): 73cef8f

update prompts for chatting, add duplicate icon

app.py CHANGED
@@ -20,6 +20,7 @@ from segment_anything import sam_model_registry
20
  from text_refiner import build_text_refiner
21
  from segmenter import build_segmenter
22
 
 
23
  def download_checkpoint(url, folder, filename):
24
  os.makedirs(folder, exist_ok=True)
25
  filepath = os.path.join(folder, filename)
@@ -32,16 +33,11 @@ def download_checkpoint(url, folder, filename):
32
  f.write(chunk)
33
 
34
  return filepath
35
- checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
36
- folder = "segmenter"
37
- filename = "sam_vit_h_4b8939.pth"
38
-
39
- download_checkpoint(checkpoint_url, folder, filename)
40
 
41
 
42
- title = """<h1 align="center">Caption-Anything</h1>"""
43
- description = """Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: https://github.com/ttengwang/Caption-Anything
44
  """
 
45
 
46
  examples = [
47
  ["test_img/img35.webp"],
@@ -53,7 +49,26 @@ examples = [
53
  ["test_img/img1.jpg"],
54
  ]
55

56
  args = parse_augment()
57
  # args.device = 'cuda:5'
58
  # args.disable_gpt = True
59
  # args.enable_reduce_tokens = False
@@ -61,7 +76,7 @@ args = parse_augment()
61
  # args.captioner = 'blip'
62
  # args.regular_box = True
63
  shared_captioner = build_captioner(args.captioner, args.device, args)
64
- shared_sam_model = sam_model_registry['vit_h'](checkpoint=args.segmenter_checkpoint).to(args.device)
65
 
66
 
67
  def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, text_refiner=None, session_id=None):
@@ -102,7 +117,7 @@ def get_prompt(chat_input, click_state, click_mode):
102
  click_state[1] = labels
103
  else:
104
  raise NotImplementedError
105
-
106
  prompt = {
107
  "prompt_type":["click"],
108
  "input_point":click_state[0],
@@ -117,21 +132,21 @@ def update_click_state(click_state, caption, click_mode):
117
  elif click_mode == 'Single':
118
  click_state[2] = [caption]
119
  else:
120
- raise NotImplementedError
121
 
122
 
123
- def chat_with_points(chat_input, click_state, chat_state, state, text_refiner):
124
  if text_refiner is None:
125
  response = "Text refiner is not initilzed, please input openai api key."
126
  state = state + [(chat_input, response)]
127
  return state, state, chat_state
128
-
129
  points, labels, captions = click_state
130
  # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting!"
131
  suffix = '\nHuman: {chat_input}\nAI: '
132
  qa_template = '\nHuman: {q}\nAI: {a}'
133
  # # "The image is of width {width} and height {height}."
134
- point_chat_prompt = "I am an AI trained to chat with you about an image based on specific points (w, h) you provide, along with their visual descriptions. Please note that (0, 0) refers to the top-left corner of the image, w refers to the width, and h refers to the height. Here are the points and their descriptions you've given me: {points_with_caps} \n Now, let's chat!"
135
  prev_visual_context = ""
136
  pos_points = []
137
  pos_captions = []
@@ -139,8 +154,8 @@ def chat_with_points(chat_input, click_state, chat_state, state, text_refiner):
139
  if labels[i] == 1:
140
  pos_points.append(f"({points[i][0]}, {points[i][0]})")
141
  pos_captions.append(captions[i])
142
- prev_visual_context = prev_visual_context + '\n' + 'Points: ' +', '.join(pos_points) + '. Description: ' + pos_captions[-1]
143
-
144
  context_length_thres = 500
145
  prev_history = ""
146
  for i in range(len(chat_state)):
@@ -149,26 +164,25 @@ def chat_with_points(chat_input, click_state, chat_state, state, text_refiner):
149
  prev_history = prev_history + qa_template.format(**{"q": q, "a": a})
150
  else:
151
  break
152
-
153
- chat_prompt = point_chat_prompt.format(**{"points_with_caps": prev_visual_context}) + prev_history + suffix.format(**{"chat_input": chat_input})
154
  print('\nchat_prompt: ', chat_prompt)
155
  response = text_refiner.llm(chat_prompt)
156
  state = state + [(chat_input, response)]
157
  chat_state = chat_state + [(chat_input, response)]
158
  return state, state, chat_state
159
 
160
- def inference_seg_cap(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
161
- length, image_embedding, state, click_state, original_size, input_size, text_refiner, evt:gr.SelectData):
162
-
163
  model = build_caption_anything_with_models(
164
- args,
165
  api_key="",
166
  captioner=shared_captioner,
167
  sam_model=shared_sam_model,
168
  text_refiner=text_refiner,
169
  session_id=iface.app_id
170
  )
171
-
172
  model.segmenter.image_embedding = image_embedding
173
  model.segmenter.predictor.original_size = original_size
174
  model.segmenter.predictor.input_size = input_size
@@ -178,11 +192,11 @@ def inference_seg_cap(image_input, point_prompt, click_mode, enable_wiki, langua
178
  coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
179
  else:
180
  coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
181
-
182
  controls = {'length': length,
183
- 'sentiment': sentiment,
184
- 'factuality': factuality,
185
- 'language': language}
186
 
187
  # click_coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
188
  # chat_input = click_coordinate
@@ -217,8 +231,7 @@ def inference_seg_cap(image_input, point_prompt, click_mode, enable_wiki, langua
217
  yield state, state, click_state, chat_input, refined_image_input, wiki
218
 
219
 
220
- def upload_callback(image_input, state):
221
- state = [] + [(None, 'Image size: ' + str(image_input.size))]
222
  chat_state = []
223
  click_state = [[], [], []]
224
  res = 1024
@@ -227,9 +240,9 @@ def upload_callback(image_input, state):
227
  if ratio < 1.0:
228
  image_input = image_input.resize((int(width * ratio), int(height * ratio)))
229
  print('Scaling input image to {}'.format(image_input.size))
230
-
231
  model = build_caption_anything_with_models(
232
- args,
233
  api_key="",
234
  captioner=shared_captioner,
235
  sam_model=shared_sam_model,
@@ -239,10 +252,11 @@ def upload_callback(image_input, state):
239
  image_embedding = model.segmenter.image_embedding
240
  original_size = model.segmenter.predictor.original_size
241
  input_size = model.segmenter.predictor.input_size
242
- return state, state, chat_state, image_input, click_state, image_input, image_embedding, original_size, input_size
 
243
 
244
  with gr.Blocks(
245
- css='''
246
  #image_upload{min-height:400px}
247
  #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 600px}
248
  '''
@@ -255,6 +269,7 @@ with gr.Blocks(
255
  text_refiner = gr.State(None)
256
  original_size = gr.State(None)
257
  input_size = gr.State(None)
 
258
 
259
  gr.Markdown(title)
260
  gr.Markdown(description)
@@ -281,13 +296,13 @@ with gr.Blocks(
281
  clear_button_image = gr.Button(value="Clear Image", interactive=True)
282
  with gr.Column(visible=False) as modules_need_gpt:
283
  with gr.Row(scale=1.0):
284
- language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
285
- sentiment = gr.Radio(
286
- choices=["Positive", "Natural", "Negative"],
287
- value="Natural",
288
- label="Sentiment",
289
- interactive=True,
290
- )
291
  with gr.Row(scale=1.0):
292
  factuality = gr.Radio(
293
  choices=["Factual", "Imagination"],
@@ -304,10 +319,10 @@ with gr.Blocks(
304
  label="Generated Caption Length",
305
  )
306
  enable_wiki = gr.Radio(
307
- choices=["Yes", "No"],
308
- value="No",
309
- label="Enable Wiki",
310
- interactive=True)
311
  with gr.Column(visible=True) as modules_not_need_gpt3:
312
  gr.Examples(
313
  examples=examples,
@@ -332,11 +347,11 @@ with gr.Blocks(
332
  with gr.Row():
333
  clear_button_text = gr.Button(value="Clear Text", interactive=True)
334
  submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
335
-
336
  openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
337
  enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
338
  disable_chatGPT_button.click(init_openai_api_key, outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
339
-
340
  clear_button_clike.click(
341
  lambda x: ([[], [], []], x, ""),
342
  [origin_image],
@@ -345,9 +360,9 @@ with gr.Blocks(
345
  show_progress=False
346
  )
347
  clear_button_image.click(
348
- lambda: (None, [], [], [], [[], [], []], "", ""),
349
  [],
350
- [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image],
351
  queue=False,
352
  show_progress=False
353
  )
@@ -359,37 +374,38 @@ with gr.Blocks(
359
  show_progress=False
360
  )
361
  image_input.clear(
362
- lambda: (None, [], [], [], [[], [], []], "", ""),
363
  [],
364
- [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image],
365
  queue=False,
366
  show_progress=False
367
  )
368
 
369
- image_input.upload(upload_callback,[image_input, state], [chatbot, state, chat_state, origin_image, click_state, image_input, image_embedding, original_size, input_size])
370
- chat_input.submit(chat_with_points, [chat_input, click_state, chat_state, state, text_refiner], [chatbot, state, chat_state])
371
- example_image.change(upload_callback,[example_image, state], [chatbot, state, chat_state, origin_image, click_state, image_input, image_embedding, original_size, input_size])
 
372
 
373
  # select coordinate
374
- image_input.select(inference_seg_cap,
375
- inputs=[
376
- origin_image,
377
- point_prompt,
378
- click_mode,
379
- enable_wiki,
380
- language,
381
- sentiment,
382
- factuality,
383
- length,
384
- image_embedding,
385
- state,
386
- click_state,
387
- original_size,
388
- input_size,
389
- text_refiner
390
- ],
391
- outputs=[chatbot, state, click_state, chat_input, image_input, wiki_output],
392
- show_progress=False, queue=True)
393
-
394
  iface.queue(concurrency_count=5, api_open=False, max_size=10)
395
  iface.launch(server_name="0.0.0.0", enable_queue=True)
 
20
  from text_refiner import build_text_refiner
21
  from segmenter import build_segmenter
22
 
23
+
24
  def download_checkpoint(url, folder, filename):
25
  os.makedirs(folder, exist_ok=True)
26
  filepath = os.path.join(folder, filename)
 
33
  f.write(chunk)
34
 
35
  return filepath
36
 
37
 
38
+ title = """<p><h1 align="center">Caption-Anything</h1></p>
 
39
  """
40
+ description = """<p>Gradio demo for Caption Anything, image to dense captioning generation with various language styles. To use it, simply upload your image, or click one of the examples to load them. Code: https://github.com/ttengwang/Caption-Anything <a href="https://huggingface.co/spaces/TencentARC/Caption-Anything?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a></p>"""
41
 
42
  examples = [
43
  ["test_img/img35.webp"],
 
49
  ["test_img/img1.jpg"],
50
  ]
51
 
52
+ seg_model_map = {
53
+ 'base': 'vit_b',
54
+ 'large': 'vit_l',
55
+ 'huge': 'vit_h'
56
+ }
57
+ ckpt_url_map = {
58
+ 'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
59
+ 'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
60
+ 'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
61
+ }
62
+
63
  args = parse_augment()
64
+
65
+ checkpoint_url = ckpt_url_map[seg_model_map[args.segmenter]]
66
+ folder = "segmenter"
67
+ filename = os.path.basename(checkpoint_url)
68
+ args.segmenter_checkpoint = os.path.join(folder, filename)
69
+
70
+ download_checkpoint(checkpoint_url, folder, filename)
71
+
72
  # args.device = 'cuda:5'
73
  # args.disable_gpt = True
74
  # args.enable_reduce_tokens = False
 
76
  # args.captioner = 'blip'
77
  # args.regular_box = True
78
  shared_captioner = build_captioner(args.captioner, args.device, args)
79
+ shared_sam_model = sam_model_registry[seg_model_map[args.segmenter]](checkpoint=args.segmenter_checkpoint).to(args.device)
80
 
81
 
82
  def build_caption_anything_with_models(args, api_key="", captioner=None, sam_model=None, text_refiner=None, session_id=None):
 
117
  click_state[1] = labels
118
  else:
119
  raise NotImplementedError
120
+
121
  prompt = {
122
  "prompt_type":["click"],
123
  "input_point":click_state[0],
 
132
  elif click_mode == 'Single':
133
  click_state[2] = [caption]
134
  else:
135
+ raise NotImplementedError
136
 
137
 
138
+ def chat_with_points(chat_input, click_state, chat_state, state, text_refiner, img_caption):
139
  if text_refiner is None:
140
  response = "Text refiner is not initilzed, please input openai api key."
141
  state = state + [(chat_input, response)]
142
  return state, state, chat_state
143
+
144
  points, labels, captions = click_state
145
  # point_chat_prompt = "I want you act as a chat bot in terms of image. I will give you some points (w, h) in the image and tell you what happed on the point in natural language. Note that (0, 0) refers to the top-left corner of the image, w refers to the width and h refers the height. You should chat with me based on the fact in the image instead of imagination. Now I tell you the points with their visual description:\n{points_with_caps}\nNow begin chatting!"
146
  suffix = '\nHuman: {chat_input}\nAI: '
147
  qa_template = '\nHuman: {q}\nAI: {a}'
148
  # # "The image is of width {width} and height {height}."
149
+ point_chat_prompt = "I am an AI trained to chat with you about an image. I am greate at what is going on in any image based on the image information your provide. The overall image description is \"{img_caption}\". You will also provide me objects in the image in details, i.e., their location and visual descriptions. Here are the locations and descriptions of events that happen in the image: {points_with_caps} \n Now, let's chat!"
150
  prev_visual_context = ""
151
  pos_points = []
152
  pos_captions = []
 
154
  if labels[i] == 1:
155
  pos_points.append(f"({points[i][0]}, {points[i][0]})")
156
  pos_captions.append(captions[i])
157
+ prev_visual_context = prev_visual_context + '\n' + 'There is an event described as \"{}\" locating at {}'.format(pos_captions[-1], ', '.join(pos_points))
158
+
159
  context_length_thres = 500
160
  prev_history = ""
161
  for i in range(len(chat_state)):
 
164
  prev_history = prev_history + qa_template.format(**{"q": q, "a": a})
165
  else:
166
  break
167
+ chat_prompt = point_chat_prompt.format(**{"img_caption":img_caption,"points_with_caps": prev_visual_context}) + prev_history + suffix.format(**{"chat_input": chat_input})
 
168
  print('\nchat_prompt: ', chat_prompt)
169
  response = text_refiner.llm(chat_prompt)
170
  state = state + [(chat_input, response)]
171
  chat_state = chat_state + [(chat_input, response)]
172
  return state, state, chat_state
173
 
174
+ def inference_seg_cap(image_input, point_prompt, click_mode, enable_wiki, language, sentiment, factuality,
175
+ length, image_embedding, state, click_state, original_size, input_size, text_refiner, evt:gr.SelectData):
176
+
177
  model = build_caption_anything_with_models(
178
+ args,
179
  api_key="",
180
  captioner=shared_captioner,
181
  sam_model=shared_sam_model,
182
  text_refiner=text_refiner,
183
  session_id=iface.app_id
184
  )
185
+
186
  model.segmenter.image_embedding = image_embedding
187
  model.segmenter.predictor.original_size = original_size
188
  model.segmenter.predictor.input_size = input_size
 
192
  coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
193
  else:
194
  coordinate = "[[{}, {}, 0]]".format(str(evt.index[0]), str(evt.index[1]))
195
+
196
  controls = {'length': length,
197
+ 'sentiment': sentiment,
198
+ 'factuality': factuality,
199
+ 'language': language}
200
 
201
  # click_coordinate = "[[{}, {}, 1]]".format(str(evt.index[0]), str(evt.index[1]))
202
  # chat_input = click_coordinate
 
231
  yield state, state, click_state, chat_input, refined_image_input, wiki
232
 
233
 
234
+ def upload_callback(image_input, state):
 
235
  chat_state = []
236
  click_state = [[], [], []]
237
  res = 1024
 
240
  if ratio < 1.0:
241
  image_input = image_input.resize((int(width * ratio), int(height * ratio)))
242
  print('Scaling input image to {}'.format(image_input.size))
243
+ state = [] + [(None, 'Image size: ' + str(image_input.size))]
244
  model = build_caption_anything_with_models(
245
+ args,
246
  api_key="",
247
  captioner=shared_captioner,
248
  sam_model=shared_sam_model,
 
252
  image_embedding = model.segmenter.image_embedding
253
  original_size = model.segmenter.predictor.original_size
254
  input_size = model.segmenter.predictor.input_size
255
+ img_caption, _ = model.captioner.inference_seg(image_input)
256
+ return state, state, chat_state, image_input, click_state, image_input, image_embedding, original_size, input_size, img_caption
257
 
258
  with gr.Blocks(
259
+ css='''
260
  #image_upload{min-height:400px}
261
  #image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 600px}
262
  '''
 
269
  text_refiner = gr.State(None)
270
  original_size = gr.State(None)
271
  input_size = gr.State(None)
272
+ img_caption = gr.State(None)
273
 
274
  gr.Markdown(title)
275
  gr.Markdown(description)
 
296
  clear_button_image = gr.Button(value="Clear Image", interactive=True)
297
  with gr.Column(visible=False) as modules_need_gpt:
298
  with gr.Row(scale=1.0):
299
+ language = gr.Dropdown(['English', 'Chinese', 'French', "Spanish", "Arabic", "Portuguese", "Cantonese"], value="English", label="Language", interactive=True)
300
+ sentiment = gr.Radio(
301
+ choices=["Positive", "Natural", "Negative"],
302
+ value="Natural",
303
+ label="Sentiment",
304
+ interactive=True,
305
+ )
306
  with gr.Row(scale=1.0):
307
  factuality = gr.Radio(
308
  choices=["Factual", "Imagination"],
 
319
  label="Generated Caption Length",
320
  )
321
  enable_wiki = gr.Radio(
322
+ choices=["Yes", "No"],
323
+ value="No",
324
+ label="Enable Wiki",
325
+ interactive=True)
326
  with gr.Column(visible=True) as modules_not_need_gpt3:
327
  gr.Examples(
328
  examples=examples,
 
347
  with gr.Row():
348
  clear_button_text = gr.Button(value="Clear Text", interactive=True)
349
  submit_button_text = gr.Button(value="Submit", interactive=True, variant="primary")
350
+
351
  openai_api_key.submit(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
352
  enable_chatGPT_button.click(init_openai_api_key, inputs=[openai_api_key], outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
353
  disable_chatGPT_button.click(init_openai_api_key, outputs=[modules_need_gpt,modules_need_gpt2, modules_need_gpt3, modules_not_need_gpt, modules_not_need_gpt2, modules_not_need_gpt3, text_refiner])
354
+
355
  clear_button_clike.click(
356
  lambda x: ([[], [], []], x, ""),
357
  [origin_image],
 
360
  show_progress=False
361
  )
362
  clear_button_image.click(
363
+ lambda: (None, [], [], [], [[], [], []], "", "", ""),
364
  [],
365
+ [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image, img_caption],
366
  queue=False,
367
  show_progress=False
368
  )
 
374
  show_progress=False
375
  )
376
  image_input.clear(
377
+ lambda: (None, [], [], [], [[], [], []], "", "", ""),
378
  [],
379
+ [image_input, chatbot, state, chat_state, click_state, wiki_output, origin_image, img_caption],
380
  queue=False,
381
  show_progress=False
382
  )
383
 
384
+ image_input.upload(upload_callback,[image_input, state], [chatbot, state, chat_state, origin_image, click_state, image_input, image_embedding, original_size, input_size, img_caption])
385
+ chat_input.submit(chat_with_points, [chat_input, click_state, chat_state, state, text_refiner, img_caption], [chatbot, state, chat_state])
386
+ chat_input.submit(lambda: "", None, chat_input)
387
+ example_image.change(upload_callback,[example_image, state], [chatbot, state, chat_state, origin_image, click_state, image_input, image_embedding, original_size, input_size, img_caption])
388
 
389
  # select coordinate
390
+ image_input.select(inference_seg_cap,
391
+ inputs=[
392
+ origin_image,
393
+ point_prompt,
394
+ click_mode,
395
+ enable_wiki,
396
+ language,
397
+ sentiment,
398
+ factuality,
399
+ length,
400
+ image_embedding,
401
+ state,
402
+ click_state,
403
+ original_size,
404
+ input_size,
405
+ text_refiner
406
+ ],
407
+ outputs=[chatbot, state, click_state, chat_input, image_input, wiki_output],
408
+ show_progress=False, queue=True)
409
+
410
  iface.queue(concurrency_count=5, api_open=False, max_size=10)
411
  iface.launch(server_name="0.0.0.0", enable_queue=True)
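For reference, a minimal sketch of how the updated chat_with_points assembles the string handed to text_refiner.llm: the whole-image caption, the positively-clicked points with their captions, the accumulated Q&A history, and a Human/AI suffix are concatenated into one prompt. The caption, clicks, and history below are made-up sample values, the template is abbreviated from the one in the diff, and the history-length check (context_length_thres) is omitted.

# Illustration only: sample data is hypothetical and the template is shortened.
point_chat_prompt = ('I am an AI trained to chat with you about an image. '
                     'The overall image description is "{img_caption}". '
                     'Here are the locations and descriptions of events that happen '
                     "in the image: {points_with_caps} \n Now, let's chat!")
qa_template = '\nHuman: {q}\nAI: {a}'
suffix = '\nHuman: {chat_input}\nAI: '

img_caption = 'a dog chasing a ball on a lawn'                       # hypothetical
points, labels, captions = [[120, 80], [400, 300]], [1, 0], ['a brown dog', '']
chat_state = [('what breed is the dog?', 'It looks like a labrador.')]

# keep only positive clicks, as chat_with_points does
points_with_caps = ''
for i in range(len(labels)):
    if labels[i] == 1:
        points_with_caps += '\nThere is an event described as "{}" located at ({}, {})'.format(
            captions[i], points[i][0], points[i][1])

prev_history = ''.join(qa_template.format(q=q, a=a) for q, a in chat_state)
chat_prompt = (point_chat_prompt.format(img_caption=img_caption,
                                        points_with_caps=points_with_caps)
               + prev_history + suffix.format(chat_input='what is it chasing?'))
print(chat_prompt)   # this string is what text_refiner.llm() receives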
caption_anything.py CHANGED
@@ -72,7 +72,7 @@ class CaptionAnything():
72
  def parse_augment():
73
  parser = argparse.ArgumentParser()
74
  parser.add_argument('--captioner', type=str, default="blip2")
75
- parser.add_argument('--segmenter', type=str, default="base")
76
  parser.add_argument('--text_refiner', type=str, default="base")
77
  parser.add_argument('--segmenter_checkpoint', type=str, default="segmenter/sam_vit_h_4b8939.pth")
78
  parser.add_argument('--seg_crop_mode', type=str, default="wo_bg", choices=['wo_bg', 'w_bg'], help="whether to add or remove background of the image when captioning")
 
72
  def parse_augment():
73
  parser = argparse.ArgumentParser()
74
  parser.add_argument('--captioner', type=str, default="blip2")
75
+ parser.add_argument('--segmenter', type=str, default="huge")
76
  parser.add_argument('--text_refiner', type=str, default="base")
77
  parser.add_argument('--segmenter_checkpoint', type=str, default="segmenter/sam_vit_h_4b8939.pth")
78
  parser.add_argument('--seg_crop_mode', type=str, default="wo_bg", choices=['wo_bg', 'w_bg'], help="whether to add or remove background of the image when captioning")
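Together with the new maps in app.py above, the new default `--segmenter huge` resolves to the ViT-H SAM checkpoint before the model is built. A self-contained sketch of that resolution path follows; the body of download_checkpoint is only partially shown in the diff, so the streaming requests call and the existence check here are assumptions.

import os
import requests  # assumed: the diff writes the response in chunks, which suggests requests streaming

seg_model_map = {'base': 'vit_b', 'large': 'vit_l', 'huge': 'vit_h'}
ckpt_url_map = {
    'vit_b': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth',
    'vit_l': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth',
    'vit_h': 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth',
}

def download_checkpoint(url, folder, filename):
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, filename)
    if not os.path.exists(filepath):                       # assumed guard against re-downloading
        with requests.get(url, stream=True) as r, open(filepath, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
    return filepath

segmenter = 'huge'                                          # the new default from parse_augment()
checkpoint_url = ckpt_url_map[seg_model_map[segmenter]]
segmenter_checkpoint = download_checkpoint(checkpoint_url, 'segmenter',
                                           os.path.basename(checkpoint_url))
print(segmenter_checkpoint)                                 # segmenter/sam_vit_h_4b8939.pth
# app.py then loads SAM via sam_model_registry[seg_model_map[args.segmenter]](checkpoint=...)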
captioner/base_captioner.py CHANGED
@@ -130,13 +130,17 @@ class BaseCaptioner:
130
  return caption, crop_save_path
131
 
132
 
133
- def inference_seg(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str], crop_mode="w_bg", filter=False, disable_regular_box = False):
134
  if type(image) == str:
135
  image = Image.open(image)
136
  if type(seg_mask) == str:
137
  seg_mask = Image.open(seg_mask)
138
  elif type(seg_mask) == np.ndarray:
139
  seg_mask = Image.fromarray(seg_mask)
 
140
  seg_mask = seg_mask.resize(image.size)
141
  seg_mask = np.array(seg_mask) > 0
142
 
 
130
  return caption, crop_save_path
131
 
132
 
133
+ def inference_seg(self, image: Union[np.ndarray, str], seg_mask: Union[np.ndarray, Image.Image, str]=None, crop_mode="w_bg", filter=False, disable_regular_box = False):
134
+ if seg_mask is None:
135
+ seg_mask = np.ones(image.size).astype(bool)
136
+
137
  if type(image) == str:
138
  image = Image.open(image)
139
  if type(seg_mask) == str:
140
  seg_mask = Image.open(seg_mask)
141
  elif type(seg_mask) == np.ndarray:
142
  seg_mask = Image.fromarray(seg_mask)
143
+
144
  seg_mask = seg_mask.resize(image.size)
145
  seg_mask = np.array(seg_mask) > 0
146
 
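With seg_mask now optional, a plain inference_seg(image) call captions the whole image, which is what upload_callback in app.py relies on for img_caption. Below is a small sketch of the all-True mask that default constructs; the helper name is illustrative and not part of the repository.

import numpy as np
from PIL import Image

def default_full_mask(image: Image.Image) -> np.ndarray:
    # Mirrors the new branch in BaseCaptioner.inference_seg: an all-ones boolean
    # array built from image.size stands in for a SAM mask, then goes through the
    # same PIL round-trip and thresholding as a real mask.
    seg_mask = np.ones(image.size).astype(bool)   # shape (width, height)
    seg_mask = Image.fromarray(seg_mask)          # bool array -> mode "1" image
    seg_mask = seg_mask.resize(image.size)        # align with the input image
    return np.array(seg_mask) > 0                 # back to a boolean mask

mask = default_full_mask(Image.new('RGB', (640, 480)))
print(mask.shape, mask.all())                     # (480, 640) True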
captioner/blip2.py CHANGED
@@ -6,6 +6,8 @@ import pdb
6
  import cv2
7
  import numpy as np
8
  from typing import Union
 
 
9
  from .base_captioner import BaseCaptioner
10
 
11
  class BLIP2Captioner(BaseCaptioner):
@@ -15,14 +17,18 @@ class BLIP2Captioner(BaseCaptioner):
15
  self.dialogue = dialogue
16
  self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
17
  self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
18
- self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map = 'sequential', load_in_8bit=True)
19
  @torch.no_grad()
20
  def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
21
  if type(image) == str: # input path
22
- image = Image.open(image)
23
 
24
  if not self.dialogue:
25
- text_prompt = 'Context: ignore the white background in this image. Question: describe this image. Answer:'
26
  inputs = self.processor(image, text = text_prompt, return_tensors="pt").to(self.device, self.torch_dtype)
27
  out = self.model.generate(**inputs, max_new_tokens=50)
28
  captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
@@ -42,7 +48,7 @@ class BLIP2Captioner(BaseCaptioner):
42
  out = self.model.generate(**inputs, max_new_tokens=50)
43
  captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
44
  context.append((input_texts, captions))
45
-
46
  return captions
47
 
48
  if __name__ == '__main__':
 
6
  import cv2
7
  import numpy as np
8
  from typing import Union
9
+
10
+ from tools import is_platform_win
11
  from .base_captioner import BaseCaptioner
12
 
13
  class BLIP2Captioner(BaseCaptioner):
 
17
  self.dialogue = dialogue
18
  self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
19
  self.processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
20
+ if is_platform_win():
21
+ self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map="sequential", torch_dtype=self.torch_dtype)
22
+ else:
23
+ self.model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b", device_map='sequential', load_in_8bit=True)
24
+
25
  @torch.no_grad()
26
  def inference(self, image: Union[np.ndarray, Image.Image, str], filter=False):
27
  if type(image) == str: # input path
28
+ image = Image.open(image)
29
 
30
  if not self.dialogue:
31
+ text_prompt = 'Question: what does the image show? Answer:'
32
  inputs = self.processor(image, text = text_prompt, return_tensors="pt").to(self.device, self.torch_dtype)
33
  out = self.model.generate(**inputs, max_new_tokens=50)
34
  captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
 
48
  out = self.model.generate(**inputs, max_new_tokens=50)
49
  captions = self.processor.decode(out[0], skip_special_tokens=True).strip()
50
  context.append((input_texts, captions))
51
+
52
  return captions
53
 
54
  if __name__ == '__main__':
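The BLIP-2 loader now branches on the platform: 8-bit quantised weights everywhere except Windows, where bitsandbytes is typically unavailable and plain fp16 (fp32 on CPU) is used instead. A hedged sketch of that decision as a kwargs builder; the helper is illustrative, only the from_pretrained arguments mirror the diff.

import sys
import torch

def is_platform_win():
    return sys.platform == "win32"

def blip2_load_kwargs(device: str) -> dict:
    # Windows: fall back to half precision (fp32 on CPU); elsewhere keep the
    # memory-saving 8-bit path used in the diff.
    if is_platform_win():
        dtype = torch.float16 if 'cuda' in device else torch.float32
        return {"device_map": "sequential", "torch_dtype": dtype}
    return {"device_map": "sequential", "load_in_8bit": True}

print(blip2_load_kwargs("cuda:0"))
# e.g. Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b",
#                                                    **blip2_load_kwargs("cuda:0"))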
segmenter/__init__.py CHANGED
@@ -2,7 +2,4 @@ from segmenter.base_segmenter import BaseSegmenter
2
 
3
 
4
  def build_segmenter(type, device, args=None, model=None):
5
- if type == 'base':
6
- return BaseSegmenter(device, args.segmenter_checkpoint, reuse_feature=not args.disable_reuse_features, model=model)
7
- else:
8
- raise NotImplementedError()
 
2
 
3
 
4
  def build_segmenter(type, device, args=None, model=None):
5
+ return BaseSegmenter(device, args.segmenter_checkpoint, reuse_feature=not args.disable_reuse_features, model=model)
tools.py CHANGED
@@ -4,6 +4,11 @@ import numpy as np
4
  from PIL import Image
5
  import copy
6
  import time
7
 
8
 
9
  def colormap(rgb=True):
@@ -130,10 +135,10 @@ def vis_add_mask_wo_gaussian(image, background_mask, contour_mask, background_co
130
 
131
  for i in range(3):
132
  image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
133
- + background_color[i] * (background_alpha-background_mask*background_alpha)
134
-
135
  image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
136
- + contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
137
 
138
  return image.astype('uint8')
139
 
@@ -155,7 +160,7 @@ def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_
155
  assert input_image.shape[:2] == input_mask.shape, 'different shape'
156
  assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
157
 
158
-
159
  # 0: background, 1: foreground
160
  input_mask[input_mask>0] = 255
161
 
@@ -170,7 +175,7 @@ def mask_painter(input_image, input_mask, background_alpha=0.7, background_blur_
170
  painted_image = vis_add_mask(painted_image, 255-contour_mask, color_list[contour_color], contour_alpha, contour_width)
171
 
172
  # painted_image = background_dist_map
173
-
174
  return painted_image
175
 
176
 
@@ -257,10 +262,10 @@ def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, back
257
  # downsample input image and mask
258
  width, height = input_image.shape[0], input_image.shape[1]
259
  res = 1024
260
- ratio = min(1.0 * res / max(width, height), 1.0)
261
  input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
262
  input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
263
-
264
  # 0: background, 1: foreground
265
  msk = np.clip(input_mask, 0, 1)
266
 
@@ -271,14 +276,14 @@ def mask_painter_wo_gaussian(input_image, input_mask, background_alpha=0.5, back
271
  background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
272
 
273
  # paint
274
- painted_image = vis_add_mask_wo_gaussian\
275
  (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background
276
 
277
  return painted_image
278
 
279
 
280
  if __name__ == '__main__':
281
-
282
  background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
283
  background_blur_radius = 31 # radius of background blur, must be odd number
284
  contour_width = 11 # contour width, must be odd number
@@ -288,14 +293,14 @@ if __name__ == '__main__':
288
  # load input image and mask
289
  input_image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB'))
290
  input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P'))
291
-
292
  # paint
293
  overall_time_1 = 0
294
  overall_time_2 = 0
295
  overall_time_3 = 0
296
  overall_time_4 = 0
297
  overall_time_5 = 0
298
-
299
  for i in range(50):
300
  t2 = time.time()
301
  painted_image_00 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00')
 
4
  from PIL import Image
5
  import copy
6
  import time
7
+ import sys
8
+
9
+
10
+ def is_platform_win():
11
+ return sys.platform == "win32"
12
 
13
 
14
  def colormap(rgb=True):
 
135
 
136
  for i in range(3):
137
  image[:, :, i] = image[:, :, i] * (1-background_alpha+background_mask*background_alpha) \
138
+ + background_color[i] * (background_alpha-background_mask*background_alpha)
139
+
140
  image[:, :, i] = image[:, :, i] * (1-contour_alpha+contour_mask*contour_alpha) \
141
+ + contour_color[i] * (contour_alpha-contour_mask*contour_alpha)
142
 
143
  return image.astype('uint8')
144
 
 
160
  assert input_image.shape[:2] == input_mask.shape, 'different shape'
161
  assert background_blur_radius % 2 * contour_width % 2 > 0, 'background_blur_radius and contour_width must be ODD'
162
 
163
+
164
  # 0: background, 1: foreground
165
  input_mask[input_mask>0] = 255
166
 
 
175
  painted_image = vis_add_mask(painted_image, 255-contour_mask, color_list[contour_color], contour_alpha, contour_width)
176
 
177
  # painted_image = background_dist_map
178
+
179
  return painted_image
180
 
181
 
 
262
  # downsample input image and mask
263
  width, height = input_image.shape[0], input_image.shape[1]
264
  res = 1024
265
+ ratio = min(1.0 * res / max(width, height), 1.0)
266
  input_image = cv2.resize(input_image, (int(height*ratio), int(width*ratio)))
267
  input_mask = cv2.resize(input_mask, (int(height*ratio), int(width*ratio)))
268
+
269
  # 0: background, 1: foreground
270
  msk = np.clip(input_mask, 0, 1)
271
 
 
276
  background_mask, contour_mask = generator_dict[mode](msk, background_radius, contour_radius)
277
 
278
  # paint
279
+ painted_image = vis_add_mask_wo_gaussian \
280
  (input_image, background_mask, contour_mask, color_list[0], color_list[contour_color], background_alpha, contour_alpha) # black for background
281
 
282
  return painted_image
283
 
284
 
285
  if __name__ == '__main__':
286
+
287
  background_alpha = 0.7 # transparency of background 1: all black, 0: do nothing
288
  background_blur_radius = 31 # radius of background blur, must be odd number
289
  contour_width = 11 # contour width, must be odd number
 
293
  # load input image and mask
294
  input_image = np.array(Image.open('./test_img/painter_input_image.jpg').convert('RGB'))
295
  input_mask = np.array(Image.open('./test_img/painter_input_mask.jpg').convert('P'))
296
+
297
  # paint
298
  overall_time_1 = 0
299
  overall_time_2 = 0
300
  overall_time_3 = 0
301
  overall_time_4 = 0
302
  overall_time_5 = 0
303
+
304
  for i in range(50):
305
  t2 = time.time()
306
  painted_image_00 = mask_painter_wo_gaussian(input_image, input_mask, background_alpha, background_blur_radius, contour_width, contour_color, contour_alpha, mode='00')