SauravMaheshkar commited on
Commit
8260e47
1 Parent(s): 630e69b

feat: add multi-masking support

Browse files
Files changed (2) hide show
  1. app.py +30 -22
  2. src/plot_utils.py +45 -85
app.py CHANGED
@@ -1,18 +1,17 @@
 
 
 
 
 
1
  import gradio as gr
2
  import numpy as np
3
- import cv2
4
  import torch
5
-
6
-
7
- from typing import Dict, Any, List
8
-
9
- from src.plot_utils import show_masks
10
  from gradio_image_annotation import image_annotator
11
-
12
-
13
  from sam2.build_sam import build_sam2
14
  from sam2.sam2_image_predictor import SAM2ImagePredictor
15
 
 
 
16
  choice_mapping: Dict[str, List[str]] = {
17
  "tiny": ["sam2_hiera_t.yaml", "assets/checkpoints/sam2_hiera_tiny.pt"],
18
  "small": ["sam2_hiera_s.yaml", "assets/checkpoints/sam2_hiera_small.pt"],
@@ -27,27 +26,34 @@ def predict(model_choice, annotations: Dict[str, Any]):
27
  sam2_model = build_sam2(config_file, ckpt_path, device=device)
28
  predictor = SAM2ImagePredictor(sam2_model)
29
  predictor.set_image(annotations["image"])
30
- coordinates = np.array(
31
- [
32
- int(annotations["boxes"][0]["xmin"]),
33
- int(annotations["boxes"][0]["ymin"]),
34
- int(annotations["boxes"][0]["xmax"]),
35
- int(annotations["boxes"][0]["ymax"]),
 
36
  ]
37
- )
 
38
  masks, scores, _ = predictor.predict(
39
  point_coords=None,
40
  point_labels=None,
41
- box=coordinates[None, :],
42
  multimask_output=False,
43
  )
44
- mask = masks.transpose(1, 2, 0)
45
- mask_image = (mask * 255).astype(np.uint8) # Convert to uint8 format
46
- cv2.imwrite("mask.png", mask_image)
 
 
 
 
 
47
 
48
  return [
49
- show_masks(annotations["image"], masks, scores, box_coords=coordinates),
50
- gr.DownloadButton("Download Mask", value="mask.png", visible=True),
51
  ]
52
 
53
 
@@ -77,7 +83,9 @@ with gr.Blocks(delete_cache=(30, 30)) as demo:
77
  label="Draw a bounding box",
78
  )
79
  btn = gr.Button("Get Segmentation Mask")
80
- download_btn = gr.DownloadButton("Download Mask", value="mask.png", visible=False)
 
 
81
  btn.click(fn=predict, inputs=[model, annotator], outputs=[gr.Plot(), download_btn])
82
 
83
  demo.launch()
 
1
+ import pathlib
2
+ import zipfile
3
+ from typing import Any, Dict, List
4
+
5
+ import cv2
6
  import gradio as gr
7
  import numpy as np
 
8
  import torch
 
 
 
 
 
9
  from gradio_image_annotation import image_annotator
 
 
10
  from sam2.build_sam import build_sam2
11
  from sam2.sam2_image_predictor import SAM2ImagePredictor
12
 
13
+ from src.plot_utils import render_masks
14
+
15
  choice_mapping: Dict[str, List[str]] = {
16
  "tiny": ["sam2_hiera_t.yaml", "assets/checkpoints/sam2_hiera_tiny.pt"],
17
  "small": ["sam2_hiera_s.yaml", "assets/checkpoints/sam2_hiera_small.pt"],
 
26
  sam2_model = build_sam2(config_file, ckpt_path, device=device)
27
  predictor = SAM2ImagePredictor(sam2_model)
28
  predictor.set_image(annotations["image"])
29
+ coordinates = []
30
+ for i in range(len(annotations["boxes"])):
31
+ coordinate = [
32
+ int(annotations["boxes"][i]["xmin"]),
33
+ int(annotations["boxes"][i]["ymin"]),
34
+ int(annotations["boxes"][i]["xmax"]),
35
+ int(annotations["boxes"][i]["ymax"]),
36
  ]
37
+ coordinates.append(coordinate)
38
+
39
  masks, scores, _ = predictor.predict(
40
  point_coords=None,
41
  point_labels=None,
42
+ box=np.array(coordinates),
43
  multimask_output=False,
44
  )
45
+ for count, mask in enumerate(masks):
46
+ mask = mask.transpose(1, 2, 0) # type:ignore
47
+ mask_image = (mask * 255).astype(np.uint8) # Convert to uint8 format
48
+ cv2.imwrite(f"assets/mask_{count}.png", mask_image)
49
+ mask_dir = pathlib.Path("assets/")
50
+ with zipfile.ZipFile("assets/masks.zip", "w") as archive:
51
+ for mask_file in mask_dir.glob("mask_*.png"):
52
+ archive.write(mask_file, arcname=mask_file.relative_to(mask_dir))
53
 
54
  return [
55
+ render_masks(annotations["image"], masks),
56
+ gr.DownloadButton("Download Mask", value="assets/masks.zip", visible=True),
57
  ]
58
 
59
 
 
83
  label="Draw a bounding box",
84
  )
85
  btn = gr.Button("Get Segmentation Mask")
86
+ download_btn = gr.DownloadButton(
87
+ "Download Mask", value="assets/masks.zip", visible=False
88
+ )
89
  btn.click(fn=predict, inputs=[model, annotator], outputs=[gr.Plot(), download_btn])
90
 
91
  demo.launch()
src/plot_utils.py CHANGED
@@ -1,90 +1,50 @@
1
- import numpy as np
2
- import matplotlib.pyplot as plt
3
-
4
-
5
- def show_mask(mask, ax, random_color=False, borders=True):
6
- if random_color:
7
- color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
8
- else:
9
- color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
10
- h, w = mask.shape[-2:]
11
- mask = mask.astype(np.uint8)
12
- mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
13
- if borders:
14
- import cv2
15
-
16
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
17
- # Try to smooth contours
18
- contours = [
19
- cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours
20
- ]
21
- mask_image = cv2.drawContours(
22
- mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2
23
- )
24
- ax.imshow(mask_image)
25
-
26
-
27
- def show_points(coords, labels, ax, marker_size=375):
28
- pos_points = coords[labels == 1]
29
- neg_points = coords[labels == 0]
30
- ax.scatter(
31
- pos_points[:, 0],
32
- pos_points[:, 1],
33
- color="green",
34
- marker="*",
35
- s=marker_size,
36
- edgecolor="white",
37
- linewidth=1.25,
38
- )
39
- ax.scatter(
40
- neg_points[:, 0],
41
- neg_points[:, 1],
42
- color="red",
43
- marker="*",
44
- s=marker_size,
45
- edgecolor="white",
46
- linewidth=1.25,
47
- )
48
-
49
 
50
- def show_box(box, ax):
51
- x0, y0 = box[0], box[1]
52
- w, h = box[2] - box[0], box[3] - box[1]
53
- ax.add_patch(
54
- plt.Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2)
55
- )
56
 
57
 
58
- def show_masks(
59
  image,
60
  masks,
61
- scores,
62
- point_coords=None,
63
- box_coords=None,
64
- input_labels=None,
65
- borders=True,
66
- ):
67
- num_masks = len(masks)
68
- num_cols = num_masks # Number of columns is equal to the number of masks
69
-
70
- fig, axes = plt.subplots(1, num_cols, figsize=(5 * num_cols, 5))
71
-
72
- if num_masks == 1:
73
- axes = [axes] # Ensure axes is iterable when there's only one mask
74
-
75
- for i, (mask, score) in enumerate(zip(masks, scores)):
76
- ax = axes[i]
77
-
78
- ax.imshow(image)
79
- show_mask(mask, ax, borders=borders)
80
- if point_coords is not None:
81
- assert input_labels is not None
82
- show_points(point_coords, input_labels, ax)
83
- if box_coords is not None:
84
- show_box(box_coords, ax)
85
- if len(scores) > 1:
86
- ax.set_title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
87
- ax.axis("off")
88
-
89
- plt.tight_layout()
90
- return plt
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ from matplotlib.pyplot import Figure
 
 
 
6
 
7
 
8
+ def render_masks(
9
  image,
10
  masks,
11
+ random_color: Optional[bool] = True,
12
+ smoothen_contours: Optional[bool] = True,
13
+ ) -> "Figure":
14
+ h, w = image.shape[:2]
15
+ fig, ax = plt.subplots(figsize=(w / 100, h / 100), dpi=100)
16
+ ax.axis("off")
17
+ ax.imshow(image)
18
+
19
+ for mask in masks:
20
+ if random_color:
21
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
22
+ else:
23
+ color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
24
+
25
+ mask = mask.astype(np.uint8)
26
+ mask = mask.reshape(h, w)
27
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
28
+
29
+ if smoothen_contours:
30
+ import cv2
31
+
32
+ contours, _ = cv2.findContours(
33
+ mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
34
+ )
35
+ contours = [
36
+ cv2.approxPolyDP(contour, epsilon=0.01, closed=True)
37
+ for contour in contours
38
+ ]
39
+ mask_image = cv2.drawContours(
40
+ mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2
41
+ )
42
+
43
+ ax.imshow(mask_image, alpha=0.6)
44
+
45
+ # Make image occupy the whole figure
46
+ ax.set_xlim(0, w)
47
+ ax.set_ylim(h, 0)
48
+ plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
49
+
50
+ return fig