SauravMaheshkar commited on
Commit
6038d30
1 Parent(s): 8a9151b

feat: default to mask generation when no annotations are provided

Browse files
Files changed (2) hide show
  1. app.py +28 -21
  2. src/plot_utils.py +68 -47
app.py CHANGED
@@ -5,6 +5,7 @@ import gradio as gr
5
  import numpy as np
6
  from gradio_image_annotation import image_annotator
7
  from sam2 import load_model
 
8
  from sam2.sam2_image_predictor import SAM2ImagePredictor
9
 
10
  from src.plot_utils import export_mask
@@ -17,30 +18,36 @@ def predict(model_choice, annotations: Dict[str, Any]):
17
  ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
18
  device="cpu",
19
  )
20
- predictor = SAM2ImagePredictor(sam2_model) # type:ignore
21
- predictor.set_image(annotations["image"])
22
- coordinates = []
23
- for i in range(len(annotations["boxes"])):
24
- coordinate = [
25
- int(annotations["boxes"][i]["xmin"]),
26
- int(annotations["boxes"][i]["ymin"]),
27
- int(annotations["boxes"][i]["xmax"]),
28
- int(annotations["boxes"][i]["ymax"]),
29
- ]
30
- coordinates.append(coordinate)
 
31
 
32
- masks, scores, _ = predictor.predict(
33
- point_coords=None,
34
- point_labels=None,
35
- box=np.array(coordinates),
36
- multimask_output=False,
37
- )
 
 
 
 
38
 
39
- if masks.shape[0] == 1:
40
- # handle single mask cases
41
- masks = np.expand_dims(masks, axis=0)
42
 
43
- return export_mask(masks)
 
 
 
44
 
45
 
46
  with gr.Blocks(delete_cache=(30, 30)) as demo:
 
5
  import numpy as np
6
  from gradio_image_annotation import image_annotator
7
  from sam2 import load_model
8
+ from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
9
  from sam2.sam2_image_predictor import SAM2ImagePredictor
10
 
11
  from src.plot_utils import export_mask
 
18
  ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
19
  device="cpu",
20
  )
21
+ if annotations["boxes"]:
22
+ predictor = SAM2ImagePredictor(sam2_model) # type:ignore
23
+ predictor.set_image(annotations["image"])
24
+ coordinates = []
25
+ for i in range(len(annotations["boxes"])):
26
+ coordinate = [
27
+ int(annotations["boxes"][i]["xmin"]),
28
+ int(annotations["boxes"][i]["ymin"]),
29
+ int(annotations["boxes"][i]["xmax"]),
30
+ int(annotations["boxes"][i]["ymax"]),
31
+ ]
32
+ coordinates.append(coordinate)
33
 
34
+ masks, scores, _ = predictor.predict(
35
+ point_coords=None,
36
+ point_labels=None,
37
+ box=np.array(coordinates),
38
+ multimask_output=False,
39
+ )
40
+
41
+ if masks.shape[0] == 1:
42
+ # handle single mask cases
43
+ masks = np.expand_dims(masks, axis=0)
44
 
45
+ return export_mask(masks)
 
 
46
 
47
+ else:
48
+ mask_generator = SAM2AutomaticMaskGenerator(sam2_model) # type: ignore
49
+ masks = mask_generator.generate(annotations["image"])
50
+ return export_mask(masks, autogenerated=True)
51
 
52
 
53
  with gr.Blocks(delete_cache=(30, 30)) as demo:
src/plot_utils.py CHANGED
@@ -1,68 +1,89 @@
1
  from typing import Optional
2
 
 
3
  import numpy as np
4
  from PIL import Image
5
 
6
 
7
  def export_mask(
8
- masks: np.ndarray,
 
9
  random_color: Optional[bool] = True,
10
  smoothen_contours: Optional[bool] = True,
11
  ) -> Image:
12
- num_masks, _, h, w = masks.shape
13
- num_masks = len(masks)
 
14
 
15
- # Ensure masks are 2D by squeezing channel dimension
16
- masks = masks.squeeze(axis=1)
17
 
18
- # Create a single uint8 image with unique values for each mask
19
- combined_mask = np.zeros((h, w), dtype=np.uint8)
20
 
21
- for i in range(num_masks):
22
- mask = masks[i]
23
- mask = mask.astype(np.uint8)
24
- combined_mask[mask > 0] = i + 1
25
 
26
- # Create color map for visualization
27
- if random_color:
28
- colors = np.random.rand(num_masks, 3) # Random colors for each mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  else:
30
- colors = np.array(
31
- [[30 / 255, 144 / 255, 255 / 255]] * num_masks
32
- ) # Use fixed color
 
33
 
34
- # Create an RGB image where each mask has its own color
35
- color_image = np.zeros((h, w, 3), dtype=np.uint8)
 
 
36
 
37
- for i in range(1, num_masks + 1):
38
- mask_color = colors[i - 1] * 255
39
- color_image[combined_mask == i] = mask_color
40
 
41
- # Convert the NumPy array to a PIL Image
42
- pil_image = Image.fromarray(color_image)
43
 
44
- # Optional: Add contours to the mask image
45
- if smoothen_contours:
46
- import cv2
47
 
48
- contours_image = np.zeros((h, w, 4), dtype=np.float32)
49
 
50
- for i in range(1, num_masks + 1):
51
- mask = (combined_mask == i).astype(np.uint8)
52
- contours, _ = cv2.findContours(
53
- mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
54
- )
55
- contours = [
56
- cv2.approxPolyDP(contour, epsilon=0.01, closed=True)
57
- for contour in contours
58
- ]
59
- contours_image = cv2.drawContours(
60
- contours_image, contours, -1, (0, 0, 0, 0.5), thickness=2
61
- )
62
-
63
- # Convert contours to PIL image and blend with the color image
64
- contours_image = (contours_image[:, :, :3] * 255).astype(np.uint8)
65
- contours_pil_image = Image.fromarray(contours_image)
66
- pil_image = Image.blend(pil_image, contours_pil_image, alpha=0.6)
67
-
68
- return pil_image
 
1
  from typing import Optional
2
 
3
+ import cv2
4
  import numpy as np
5
  from PIL import Image
6
 
7
 
8
  def export_mask(
9
+ masks,
10
+ autogenerated: Optional[bool] = False,
11
  random_color: Optional[bool] = True,
12
  smoothen_contours: Optional[bool] = True,
13
  ) -> Image:
14
+ if not autogenerated:
15
+ num_masks, _, h, w = masks.shape
16
+ num_masks = len(masks)
17
 
18
+ # Ensure masks are 2D by squeezing channel dimension
19
+ masks = masks.squeeze(axis=1)
20
 
21
+ # Create a single uint8 image with unique values for each mask
22
+ combined_mask = np.zeros((h, w), dtype=np.uint8)
23
 
24
+ for i in range(num_masks):
25
+ mask = masks[i]
26
+ mask = mask.astype(np.uint8)
27
+ combined_mask[mask > 0] = i + 1
28
 
29
+ # Create color map for visualization
30
+ if random_color:
31
+ colors = np.random.rand(num_masks, 3) # Random colors for each mask
32
+ else:
33
+ colors = np.array(
34
+ [[30 / 255, 144 / 255, 255 / 255]] * num_masks
35
+ ) # Use fixed color
36
+
37
+ # Create an RGB image where each mask has its own color
38
+ color_image = np.zeros((h, w, 3), dtype=np.uint8)
39
+
40
+ for i in range(1, num_masks + 1):
41
+ mask_color = colors[i - 1] * 255
42
+ color_image[combined_mask == i] = mask_color
43
+
44
+ # Convert the NumPy array to a PIL Image
45
+ pil_image = Image.fromarray(color_image)
46
+
47
+ # Optional: Add contours to the mask image
48
+ if smoothen_contours:
49
+ contours_image = np.zeros((h, w, 4), dtype=np.float32)
50
+
51
+ for i in range(1, num_masks + 1):
52
+ mask = (combined_mask == i).astype(np.uint8)
53
+ contours_image = smoothen(mask, contours_image)
54
+
55
+ # Convert contours to PIL image and blend with the color image
56
+ contours_image = (contours_image[:, :, :3] * 255).astype(np.uint8)
57
+ contours_pil_image = Image.fromarray(contours_image)
58
+ pil_image = Image.blend(pil_image, contours_pil_image, alpha=0.6)
59
+
60
+ return pil_image
61
  else:
62
+ sorted_anns = sorted(masks, key=(lambda x: x["area"]), reverse=True)
63
+ img_shape = sorted_anns[0]["segmentation"].shape
64
+ img = np.ones((img_shape[0], img_shape[1], 4))
65
+ img[:, :, 3] = 0
66
 
67
+ for ann in sorted_anns:
68
+ m = ann["segmentation"]
69
+ color_mask = np.concatenate([np.random.random(3), [0.5]])
70
+ img[m] = color_mask
71
 
72
+ if smoothen_contours:
73
+ img = smoothen(m, img)
 
74
 
75
+ img = (img * 255).astype(np.uint8)
76
+ pil_image = Image.fromarray(img)
77
 
78
+ return pil_image
 
 
79
 
 
80
 
81
+ def smoothen(mask: np.ndarray, image: np.ndarray) -> np.ndarray:
82
+ contours, _ = cv2.findContours(
83
+ mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
84
+ )
85
+ contours = [
86
+ cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours
87
+ ]
88
+ image = cv2.drawContours(image, contours, -1, (0, 0, 1, 0.4), thickness=1)
89
+ return image