SauravMaheshkar committed
Commit 410698b
1 Parent(s): 95190fc

feat: display masks as a single image

Files changed (3):
  1. app.py (+15 -29)
  2. requirements.txt (+1 -1)
  3. src/plot_utils.py (+48 -30)
app.py CHANGED
@@ -1,5 +1,3 @@
-import pathlib
-import zipfile
 from typing import Any, Dict, List
 
 import cv2
@@ -7,24 +5,20 @@ import gradio as gr
 import numpy as np
 import torch
 from gradio_image_annotation import image_annotator
-from sam2.build_sam import build_sam2
+from sam2 import load_model
 from sam2.sam2_image_predictor import SAM2ImagePredictor
 
-from src.plot_utils import render_masks
-
-choice_mapping: Dict[str, List[str]] = {
-    "tiny": ["sam2_hiera_t.yaml", "assets/checkpoints/sam2_hiera_tiny.pt"],
-    "small": ["sam2_hiera_s.yaml", "assets/checkpoints/sam2_hiera_small.pt"],
-    "base_plus": ["sam2_hiera_b+.yaml", "assets/checkpoints/sam2_hiera_base_plus.pt"],
-    "large": ["sam2_hiera_l.yaml", "assets/checkpoints/sam2_hiera_large.pt"],
-}
+from src.plot_utils import export_mask
 
 
 def predict(model_choice, annotations: Dict[str, Any]):
-    config_file, ckpt_path = choice_mapping[str(model_choice)]
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    sam2_model = build_sam2(config_file, ckpt_path, device=device)
-    predictor = SAM2ImagePredictor(sam2_model)
+    sam2_model = load_model(
+        variant=model_choice,
+        ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
+        device=device,
+    )
+    predictor = SAM2ImagePredictor(sam2_model)  # type:ignore
     predictor.set_image(annotations["image"])
     coordinates = []
     for i in range(len(annotations["boxes"])):
@@ -42,19 +36,12 @@ def predict(model_choice, annotations: Dict[str, Any]):
         box=np.array(coordinates),
         multimask_output=False,
     )
-    for count, mask in enumerate(masks):
-        mask = mask.transpose(1, 2, 0)  # type:ignore
-        mask_image = (mask * 255).astype(np.uint8)  # Convert to uint8 format
-        cv2.imwrite(f"assets/mask_{count}.png", mask_image)
-    mask_dir = pathlib.Path("assets/")
-    with zipfile.ZipFile("assets/masks.zip", "w") as archive:
-        for mask_file in mask_dir.glob("mask_*.png"):
-            archive.write(mask_file, arcname=mask_file.relative_to(mask_dir))
 
-    return [
-        render_masks(annotations["image"], masks),
-        gr.DownloadButton("Download Mask(s)", value="assets/masks.zip", visible=True),
-    ]
+    if masks.shape[0] == 1:
+        # handle single mask cases
+        masks = np.expand_dims(masks, axis=0)
+
+    return export_mask(masks)
 
 
 with gr.Blocks(delete_cache=(30, 30)) as demo:
@@ -83,9 +70,8 @@ with gr.Blocks(delete_cache=(30, 30)) as demo:
         label="Draw a bounding box",
    )
     btn = gr.Button("Get Segmentation Mask(s)")
-    download_btn = gr.DownloadButton(
-        "Download Mask(s)", value="assets/masks.zip", visible=False
+    btn.click(
+        fn=predict, inputs=[model, annotator], outputs=[gr.Image(label="Mask(s)")]
     )
-    btn.click(fn=predict, inputs=[model, annotator], outputs=[gr.Plot(), download_btn])
 
 demo.launch()
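
With this change, predict no longer writes per-mask PNGs to assets/ and zips them for a gr.DownloadButton; it returns a single PIL image from export_mask, so the UI output shrinks from a gr.Plot plus download button to one gr.Image. A minimal sketch of the shape guard that precedes export_mask (synthetic arrays; the predictor output shapes are an assumption inferred from the guard, not stated in the diff):

    import numpy as np

    # Assumed samv2 predictor output shapes (inferred from the guard above):
    single = np.zeros((1, 256, 256))      # one box: (1, H, W)
    batched = np.zeros((3, 1, 256, 256))  # several boxes: (num_boxes, 1, H, W)

    if single.shape[0] == 1:
        # Expand to (1, 1, H, W) so export_mask always receives (N, 1, H, W).
        single = np.expand_dims(single, axis=0)

    assert single.shape == (1, 1, 256, 256)
    assert batched.shape == (3, 1, 256, 256)
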
requirements.txt CHANGED
@@ -1,5 +1,5 @@
+git+https://github.com/SauravMaheshkar/samv2.git
 gradio
 gradio_image_annotation
 opencv-python
-samv2
 spaces
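
app.py now imports load_model from the sam2 package, which the PyPI samv2 release presumably did not ship at the time (an assumption; the diff only records the swap), so the requirement points at the fork directly. Running pip install git+https://github.com/SauravMaheshkar/samv2.git installs the same code outside the Space.
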
src/plot_utils.py CHANGED
@@ -1,50 +1,68 @@
 from typing import Optional
 
-import matplotlib.pyplot as plt
 import numpy as np
-from matplotlib.pyplot import Figure
+from PIL import Image
 
 
-def render_masks(
-    image,
-    masks,
+def export_mask(
+    masks: np.ndarray,
     random_color: Optional[bool] = True,
     smoothen_contours: Optional[bool] = True,
-) -> "Figure":
-    h, w = image.shape[:2]
-    fig, ax = plt.subplots(figsize=(w / 100, h / 100), dpi=100)
-    ax.axis("off")
-    ax.imshow(image)
-
-    for mask in masks:
-        if random_color:
-            color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
-        else:
-            color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
+) -> Image:
+    num_masks, _, h, w = masks.shape
+    num_masks = len(masks)
 
+    # Ensure masks are 2D by squeezing channel dimension
+    masks = masks.squeeze(axis=1)
+
+    # Create a single uint8 image with unique values for each mask
+    combined_mask = np.zeros((h, w), dtype=np.uint8)
+
+    for i in range(num_masks):
+        mask = masks[i]
         mask = mask.astype(np.uint8)
-        mask = mask.reshape(h, w)
-        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+        combined_mask[mask > 0] = i + 1
+
+    # Create color map for visualization
+    if random_color:
+        colors = np.random.rand(num_masks, 3)  # Random colors for each mask
+    else:
+        colors = np.array(
+            [[30 / 255, 144 / 255, 255 / 255]] * num_masks
+        )  # Use fixed color
+
+    # Create an RGB image where each mask has its own color
+    color_image = np.zeros((h, w, 3), dtype=np.uint8)
 
-        if smoothen_contours:
-            import cv2
+    for i in range(1, num_masks + 1):
+        mask_color = colors[i - 1] * 255
+        color_image[combined_mask == i] = mask_color
 
+    # Convert the NumPy array to a PIL Image
+    pil_image = Image.fromarray(color_image)
+
+    # Optional: Add contours to the mask image
+    if smoothen_contours:
+        import cv2
+
+        contours_image = np.zeros((h, w, 4), dtype=np.float32)
+
+        for i in range(1, num_masks + 1):
+            mask = (combined_mask == i).astype(np.uint8)
             contours, _ = cv2.findContours(
-                mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
+                mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
             )
             contours = [
                 cv2.approxPolyDP(contour, epsilon=0.01, closed=True)
                 for contour in contours
             ]
-            mask_image = cv2.drawContours(
-                mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2
+            contours_image = cv2.drawContours(
+                contours_image, contours, -1, (0, 0, 0, 0.5), thickness=2
             )
 
-        ax.imshow(mask_image, alpha=0.6)
-
-    # Make image occupy the whole figure
-    ax.set_xlim(0, w)
-    ax.set_ylim(h, 0)
-    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
+        # Convert contours to PIL image and blend with the color image
+        contours_image = (contours_image[:, :, :3] * 255).astype(np.uint8)
+        contours_pil_image = Image.fromarray(contours_image)
+        pil_image = Image.blend(pil_image, contours_pil_image, alpha=0.6)
 
-    return fig
+    return pil_image
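
Where render_masks overlaid each mask on the input image and returned a Matplotlib figure, export_mask flattens all masks into one label image (0 is background, mask i gets value i + 1, with later masks overwriting earlier ones where they overlap), assigns each label a color, and optionally blends in smoothed contours. Two quirks are reproduced as committed: num_masks = len(masks) restates the value already unpacked from masks.shape, and the -> Image annotation names the PIL module rather than Image.Image. A small usage sketch, assuming (N, 1, H, W) binary input masks:

    import numpy as np

    from src.plot_utils import export_mask

    # Two synthetic binary masks in the (N, 1, H, W) layout export_mask expects.
    masks = np.zeros((2, 1, 128, 128), dtype=np.uint8)
    masks[0, 0, 16:64, 16:64] = 1   # a square
    masks[1, 0, 80:112, 8:120] = 1  # a bar

    # Fixed color, contours disabled to avoid the optional cv2 dependency.
    img = export_mask(masks, random_color=False, smoothen_contours=False)
    img.save("combined_mask.png")

With random_color=True (the default) each mask gets its own random color, which is what keeps the masks distinguishable in the single gr.Image output.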