Tobias Cornille committed
Commit 17d77a8
1 Parent(s): 391271a

Make more robust + fix segments annotations

Files changed (1)
  1. app.py +105 -80
app.py CHANGED
@@ -110,7 +110,7 @@ def dino_detection(
         visualization = Image.fromarray(annotated_frame)
         return boxes, category_ids, visualization
     else:
-        return boxes, category_ids
+        return boxes, category_ids, phrases
 
 
 def sam_masks_from_dino_boxes(predictor, image_array, boxes, device):
@@ -156,13 +156,16 @@ def clipseg_segmentation(
     ).to(device)
     with torch.no_grad():
         outputs = model(**inputs)
+    logits = outputs.logits
+    if len(logits.shape) == 2:
+        logits = logits.unsqueeze(0)
     # resize the outputs
-    logits = nn.functional.interpolate(
-        outputs.logits.unsqueeze(1),
+    upscaled_logits = nn.functional.interpolate(
+        logits.unsqueeze(1),
         size=(image.size[1], image.size[0]),
         mode="bilinear",
     )
-    preds = torch.sigmoid(logits.squeeze())
+    preds = torch.sigmoid(upscaled_logits.squeeze(dim=1))
     semantic_inds = preds_to_semantic_inds(preds, background_threshold)
     return preds, semantic_inds
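The shape guard added above handles the single-category case: judging from this fix, CLIPSeg returns logits of shape (num_prompts, H, W) when several text prompts are passed but a bare (H, W) tensor when there is only one, which previously broke the interpolate/squeeze chain. A minimal sketch with dummy tensors standing in for a real CLIPSeg forward pass (resize_logits and the 352x352 shape are illustrative assumptions, not part of app.py) shows the patched logic in isolation:

import torch
from torch import nn

def resize_logits(logits, out_size):
    # mirror of the patched code path: guarantee a leading prompt dimension,
    # upscale to the image size, then drop only the channel dimension
    if len(logits.shape) == 2:          # single prompt -> (H, W)
        logits = logits.unsqueeze(0)    # -> (1, H, W)
    upscaled = nn.functional.interpolate(
        logits.unsqueeze(1),            # -> (N, 1, H, W), as interpolate expects
        size=out_size,
        mode="bilinear",
    )
    # squeeze(dim=1) keeps the prompt dimension even when N == 1,
    # unlike the old bare squeeze()
    return torch.sigmoid(upscaled.squeeze(dim=1))

multi = resize_logits(torch.randn(3, 352, 352), (480, 640))   # three categories
single = resize_logits(torch.randn(352, 352), (480, 640))     # one category
print(multi.shape, single.shape)  # torch.Size([3, 480, 640]) torch.Size([1, 480, 640])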
 
@@ -195,7 +198,7 @@ def clip_and_shrink_preds(semantic_inds, preds, shrink_kernel_size, num_categori
         torch.sum(bool_masks[i].int()).item() for i in range(1, bool_masks.size(0))
     ]
     max_size = max(sizes)
-    relative_sizes = [size / max_size for size in sizes]
+    relative_sizes = [size / max_size for size in sizes] if max_size > 0 else sizes
 
     # use bool masks to clip preds
     clipped_preds = torch.zeros_like(preds)
@@ -240,7 +243,7 @@ def upsample_pred(pred, image_source):
     else:
         target_height = int(upsampled_tensor.shape[2] * aspect_ratio)
         upsampled_tensor = upsampled_tensor[:, :, :target_height, :]
-    return upsampled_tensor.squeeze()
+    return upsampled_tensor.squeeze(dim=1)
 
 
 def sam_mask_from_points(predictor, image_array, points):
@@ -262,26 +265,30 @@ def sam_mask_from_points(predictor, image_array, points):
 
 
 def inds_to_segments_format(
-    panoptic_inds, thing_category_ids, stuff_category_ids, output_file_path
+    panoptic_inds, thing_category_ids, stuff_category_names, category_name_to_id
 ):
     panoptic_inds_array = panoptic_inds.numpy().astype(np.uint32)
     bitmap_file = bitmap2file(panoptic_inds_array, is_segmentation_bitmap=True)
-    with open(output_file_path, "wb") as output_file:
-        output_file.write(bitmap_file.read())
+    segmentation_bitmap = Image.open(bitmap_file)
+
+    stuff_category_ids = [
+        category_name_to_id[stuff_category_name]
+        for stuff_category_name in stuff_category_names
+    ]
 
     unique_inds = np.unique(panoptic_inds_array)
     stuff_annotations = [
-        {"id": i + 1, "category_id": stuff_category_id}
-        for i, stuff_category_id in enumerate(stuff_category_ids)
+        {"id": i, "category_id": stuff_category_ids[i - 1]}
+        for i in range(1, len(stuff_category_names) + 1)
         if i in unique_inds
     ]
     thing_annotations = [
-        {"id": len(stuff_category_ids) + 1 + i, "category_id": thing_category_id}
+        {"id": len(stuff_category_names) + 1 + i, "category_id": thing_category_id}
        for i, thing_category_id in enumerate(thing_category_ids)
     ]
     annotations = stuff_annotations + thing_annotations
 
-    return annotations
+    return segmentation_bitmap, annotations
 
 
 def generate_panoptic_mask(
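With the reworked inds_to_segments_format, stuff segments keep ids 1 through len(stuff_category_names) (ids that never appear in the bitmap are dropped) and thing segments continue at len(stuff_category_names) + 1, while the bitmap itself is now returned as a PIL image instead of being written to disk. A small self-contained sketch of just the id/category pairing, using made-up category names and a toy index array (the real function also builds the bitmap via bitmap2file), illustrates the resulting annotations:

import numpy as np

# hypothetical task setup
stuff_category_names = ["sky", "road"]
category_name_to_id = {"sky": 0, "road": 1, "car": 2}
thing_category_ids = [2, 2]  # two detected cars
# toy panoptic index map: 0 = unlabeled, 1 = sky, 3 and 4 = the two cars
panoptic_inds_array = np.array([[0, 1, 1], [3, 3, 4]], dtype=np.uint32)

stuff_category_ids = [category_name_to_id[name] for name in stuff_category_names]
unique_inds = np.unique(panoptic_inds_array)

# ids 1..len(stuff) are reserved for stuff; keep only ids that actually occur
stuff_annotations = [
    {"id": i, "category_id": stuff_category_ids[i - 1]}
    for i in range(1, len(stuff_category_names) + 1)
    if i in unique_inds
]
# thing ids start right after the stuff id range
thing_annotations = [
    {"id": len(stuff_category_names) + 1 + i, "category_id": thing_category_id}
    for i, thing_category_id in enumerate(thing_category_ids)
]
print(stuff_annotations + thing_annotations)
# [{'id': 1, 'category_id': 0}, {'id': 3, 'category_id': 2}, {'id': 4, 'category_id': 2}]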
@@ -295,7 +302,7 @@ def generate_panoptic_mask(
     num_samples_factor=1000,
     task_attributes_json="",
 ):
-    if task_attributes_json is not "":
+    if task_attributes_json != "":
         task_attributes = json.loads(task_attributes_json)
         categories = task_attributes["categories"]
         category_name_to_id = {
@@ -334,67 +341,89 @@ def generate_panoptic_mask(
     image = image.convert("RGB")
     image_array = np.asarray(image)
 
-    # detect boxes for "thing" categories using Grounding DINO
-    thing_boxes, thing_category_ids = dino_detection(
-        dino_model,
-        image,
-        image_array,
-        thing_category_names,
-        category_name_to_id,
-        dino_box_threshold,
-        dino_text_threshold,
-        device,
-    )
     # compute SAM image embedding
     sam_predictor.set_image(image_array)
-    # get segmentation masks for the thing boxes
-    thing_masks = sam_masks_from_dino_boxes(
-        sam_predictor, image_array, thing_boxes, device
-    )
-    # get rough segmentation masks for "stuff" categories using CLIPSeg
-    clipseg_preds, clipseg_semantic_inds = clipseg_segmentation(
-        clipseg_processor,
-        clipseg_model,
-        image,
-        stuff_category_names,
-        segmentation_background_threshold,
-        device,
-    )
-    # remove things from stuff masks
-    combined_things_mask = torch.any(thing_masks, dim=0)
-    clipseg_semantic_inds_without_things = clipseg_semantic_inds.clone()
-    clipseg_semantic_inds_without_things[combined_things_mask[0]] = 0
-    # clip CLIPSeg preds based on non-overlapping semantic segmentation inds (+ optionally shrink the mask of each category)
-    # also returns the relative size of each category
-    clipsed_clipped_preds, relative_sizes = clip_and_shrink_preds(
-        clipseg_semantic_inds_without_things,
-        clipseg_preds,
-        shrink_kernel_size,
-        len(stuff_category_names) + 1,
-    )
-    # get finer segmentation masks for the "stuff" categories using SAM
-    sam_preds = torch.zeros_like(clipsed_clipped_preds)
-    for i in range(clipsed_clipped_preds.shape[0]):
-        clipseg_pred = clipsed_clipped_preds[i]
-        # for each "stuff" category, sample points in the rough segmentation mask
-        num_samples = int(relative_sizes[i] * num_samples_factor)
-        if num_samples == 0:
-            continue
-        points = sample_points_based_on_preds(clipseg_pred.cpu().numpy(), num_samples)
-        if len(points) == 0:
-            continue
-        # use SAM to get mask for points
-        pred = sam_mask_from_points(sam_predictor, image_array, points)
-        sam_preds[i] = pred
-    sam_semantic_inds = preds_to_semantic_inds(
-        sam_preds, segmentation_background_threshold
-    )
+
+    # detect boxes for "thing" categories using Grounding DINO
+    thing_category_ids = []
+    thing_masks = []
+    thing_boxes = []
+    detected_thing_category_names = []
+    if len(thing_category_names) > 0:
+        thing_boxes, thing_category_ids, detected_thing_category_names = dino_detection(
+            dino_model,
+            image,
+            image_array,
+            thing_category_names,
+            category_name_to_id,
+            dino_box_threshold,
+            dino_text_threshold,
+            device,
+        )
+        if len(thing_boxes) > 0:
+            # get segmentation masks for the thing boxes
+            thing_masks = sam_masks_from_dino_boxes(
+                sam_predictor, image_array, thing_boxes, device
+            )
+    detected_stuff_category_names = []
+    if len(stuff_category_names) > 0:
+        # get rough segmentation masks for "stuff" categories using CLIPSeg
+        clipseg_preds, clipseg_semantic_inds = clipseg_segmentation(
+            clipseg_processor,
+            clipseg_model,
+            image,
+            stuff_category_names,
+            segmentation_background_threshold,
+            device,
+        )
+        # remove things from stuff masks
+        clipseg_semantic_inds_without_things = clipseg_semantic_inds.clone()
+        if len(thing_boxes) > 0:
+            combined_things_mask = torch.any(thing_masks, dim=0)
+            clipseg_semantic_inds_without_things[combined_things_mask[0]] = 0
+        # clip CLIPSeg preds based on non-overlapping semantic segmentation inds (+ optionally shrink the mask of each category)
+        # also returns the relative size of each category
+        clipsed_clipped_preds, relative_sizes = clip_and_shrink_preds(
+            clipseg_semantic_inds_without_things,
+            clipseg_preds,
+            shrink_kernel_size,
+            len(stuff_category_names) + 1,
+        )
+        # get finer segmentation masks for the "stuff" categories using SAM
+        sam_preds = torch.zeros_like(clipsed_clipped_preds)
+        for i in range(clipsed_clipped_preds.shape[0]):
+            clipseg_pred = clipsed_clipped_preds[i]
+            # for each "stuff" category, sample points in the rough segmentation mask
+            num_samples = int(relative_sizes[i] * num_samples_factor)
+            if num_samples == 0:
+                continue
+            points = sample_points_based_on_preds(
+                clipseg_pred.cpu().numpy(), num_samples
+            )
+            if len(points) == 0:
+                continue
+            # use SAM to get mask for points
+            pred = sam_mask_from_points(sam_predictor, image_array, points)
+            sam_preds[i] = pred
+        sam_semantic_inds = preds_to_semantic_inds(
+            sam_preds, segmentation_background_threshold
+        )
+        detected_stuff_category_names = [
+            category_name
+            for i, category_name in enumerate(category_names)
+            if i + 1 in np.unique(sam_semantic_inds.numpy())
+        ]
+
     # combine the thing inds and the stuff inds into panoptic inds
-    panoptic_inds = sam_semantic_inds.clone()
+    panoptic_inds = (
+        sam_semantic_inds.clone()
+        if len(stuff_category_names) > 0
+        else torch.zeros(image_array.shape[0], image_array.shape[1], dtype=torch.long)
+    )
     ind = len(stuff_category_names) + 1
     for thing_mask in thing_masks:
         # overlay thing mask on panoptic inds
-        panoptic_inds[thing_mask.squeeze()] = ind
+        panoptic_inds[thing_mask.squeeze(dim=0)] = ind
         ind += 1
 
     panoptic_bool_masks = (
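The rewritten block also covers the case where one of the two category lists is empty: the stuff indices come from SAM's semantic map (or an all-zero map when no stuff categories were requested), and each thing mask then overwrites its pixels with the next free index. A toy sketch with hand-made tensors, no models involved, shows the overlay order:

import torch

stuff_category_names = ["sky"]                 # stuff occupies index 1
sam_semantic_inds = torch.tensor([[1, 1, 0],
                                  [0, 0, 0]])  # rough "stuff" labeling
thing_masks = [torch.tensor([[[False, False, False],
                              [True,  True,  False]]])]  # one SAM mask, shape (1, H, W)

panoptic_inds = sam_semantic_inds.clone()
ind = len(stuff_category_names) + 1            # thing ids start after the stuff ids
for thing_mask in thing_masks:
    panoptic_inds[thing_mask.squeeze(dim=0)] = ind  # things overwrite stuff/unlabeled
    ind += 1
print(panoptic_inds)
# tensor([[1, 1, 0],
#         [2, 2, 0]])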
@@ -403,23 +432,19 @@ def generate_panoptic_mask(
         .astype(int)
     )
     panoptic_names = (
-        ["background"]
-        + stuff_category_names
-        + [category_names[category_id] for category_id in thing_category_ids]
+        ["unlabeled"] + detected_stuff_category_names + detected_thing_category_names
     )
     subsection_label_pairs = [
         (panoptic_bool_masks[i], panoptic_name)
         for i, panoptic_name in enumerate(panoptic_names)
     ]
 
-    output_file_path = "output_segmentation_bitmap.png"
-    stuff_category_ids = [category_name_to_id[name] for name in stuff_category_names]
-    annotations = inds_to_segments_format(
-        panoptic_inds, thing_category_ids, stuff_category_ids, output_file_path
+    segmentation_bitmap, annotations = inds_to_segments_format(
+        panoptic_inds, thing_category_ids, stuff_category_names, category_name_to_id
     )
     annotations_json = json.dumps(annotations)
 
-    return (image_array, subsection_label_pairs), output_file_path, annotations_json
+    return (image_array, subsection_label_pairs), segmentation_bitmap, annotations_json
 
 
 config_file = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
@@ -497,7 +522,7 @@ if __name__ == "__main__":
                     step=0.001,
                 )
                 segmentation_background_threshold = gr.Slider(
-                    label="Segmentation background threshold (under this threshold, a pixel is considered background)",
+                    label="Segmentation background threshold (under this threshold, a pixel is considered background/unlabeled)",
                     minimum=0.0,
                     maximum=1.0,
                     value=0.1,
@@ -529,11 +554,11 @@ if __name__ == "__main__":
                     The segmentation bitmap is a 32-bit RGBA png image which contains the segmentation masks.
                     The alpha channel is set to 255, and the remaining 24-bit values in the RGB channels correspond to the object ids in the annotations list.
                     Unlabeled regions have a value of 0.
-                    Because of the large dynamic range, these png images may appear black in an image viewer.
+                    Because of the large dynamic range, the segmentation bitmap appears black in the image viewer.
                     """
                 )
                 segmentation_bitmap = gr.Image(
-                    type="filepath", label="Segmentation bitmap"
+                    type="pil", label="Segmentation bitmap"
                 )
                 annotations_json = gr.Textbox(
                     label="Annotations JSON",
 