SauravMaheshkar committed
Commit 410698b
Parent(s): 95190fc

feat: display masks as a single image

Files changed:
- app.py (+15 -29)
- requirements.txt (+1 -1)
- src/plot_utils.py (+48 -30)
app.py CHANGED
@@ -1,5 +1,3 @@
-import pathlib
-import zipfile
 from typing import Any, Dict, List
 
 import cv2
@@ -7,24 +5,20 @@ import gradio as gr
 import numpy as np
 import torch
 from gradio_image_annotation import image_annotator
-from sam2 …
+from sam2 import load_model
 from sam2.sam2_image_predictor import SAM2ImagePredictor
 
-from src.plot_utils import …
-
-choice_mapping: Dict[str, List[str]] = {
-    "tiny": ["sam2_hiera_t.yaml", "assets/checkpoints/sam2_hiera_tiny.pt"],
-    "small": ["sam2_hiera_s.yaml", "assets/checkpoints/sam2_hiera_small.pt"],
-    "base_plus": ["sam2_hiera_b+.yaml", "assets/checkpoints/sam2_hiera_base_plus.pt"],
-    "large": ["sam2_hiera_l.yaml", "assets/checkpoints/sam2_hiera_large.pt"],
-}
+from src.plot_utils import export_mask
 
 
 def predict(model_choice, annotations: Dict[str, Any]):
-    config_file, ckpt_path = choice_mapping[str(model_choice)]
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    sam2_model = …
-    …
+    sam2_model = load_model(
+        variant=model_choice,
+        ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
+        device=device,
+    )
+    predictor = SAM2ImagePredictor(sam2_model)  # type:ignore
     predictor.set_image(annotations["image"])
     coordinates = []
     for i in range(len(annotations["boxes"])):
@@ -42,19 +36,12 @@ def predict(model_choice, annotations: Dict[str, Any]):
         box=np.array(coordinates),
         multimask_output=False,
     )
-    for count, mask in enumerate(masks):
-        mask = mask.transpose(1, 2, 0)  # type:ignore
-        mask_image = (mask * 255).astype(np.uint8)  # Convert to uint8 format
-        cv2.imwrite(f"assets/mask_{count}.png", mask_image)
-    mask_dir = pathlib.Path("assets/")
-    with zipfile.ZipFile("assets/masks.zip", "w") as archive:
-        for mask_file in mask_dir.glob("mask_*.png"):
-            archive.write(mask_file, arcname=mask_file.relative_to(mask_dir))
 
-    …
-    …
-    …
-    …
+    if masks.ndim == 3:
+        # handle the single-mask case: (1, H, W) -> (1, 1, H, W)
+        masks = np.expand_dims(masks, axis=0)
+
+    return export_mask(masks)
 
 
 with gr.Blocks(delete_cache=(30, 30)) as demo:
@@ -83,9 +70,8 @@ with gr.Blocks(delete_cache=(30, 30)) as demo:
         label="Draw a bounding box",
     )
     btn = gr.Button("Get Segmentation Mask(s)")
-    …
-        "…
+    btn.click(
+        fn=predict, inputs=[model, annotator], outputs=[gr.Image(label="Mask(s)")]
     )
-    btn.click(fn=predict, inputs=[model, annotator], outputs=[gr.Plot(), download_btn])
 
     demo.launch()
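Note on the dimension check above: with a single drawn box, `predictor.predict(..., multimask_output=False)` returns a 3-D mask array, while several boxes yield the 4-D `(num_masks, 1, H, W)` layout that `export_mask` consumes. A minimal sketch of that normalization with dummy arrays (the helper name and the exact SAM2 output shapes are assumptions inferred from the code above, not checked against the SAM2 API):

import numpy as np


def normalize_masks(masks: np.ndarray) -> np.ndarray:
    # Hypothetical helper: coerce predictor output to (num_masks, 1, H, W)
    if masks.ndim == 3:  # a single mask is assumed to arrive as (1, H, W)
        masks = np.expand_dims(masks, axis=0)
    return masks


assert normalize_masks(np.zeros((1, 64, 64))).shape == (1, 1, 64, 64)
assert normalize_masks(np.zeros((3, 1, 64, 64))).shape == (3, 1, 64, 64)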
requirements.txt CHANGED
@@ -1,5 +1,5 @@
+git+https://github.com/SauravMaheshkar/samv2.git
 gradio
 gradio_image_annotation
 opencv-python
-samv2
 spaces
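The `samv2` dependency moves from the PyPI release to an install straight from the GitHub repository, presumably so the Space picks up `load_model`, which app.py now imports from `sam2`.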
src/plot_utils.py CHANGED
@@ -1,50 +1,68 @@
 from typing import Optional
 
-import matplotlib.pyplot as plt
 import numpy as np
-from …
+from PIL import Image
 
 
-def …
-    …
-    masks,
+def export_mask(
+    masks: np.ndarray,
     random_color: Optional[bool] = True,
     smoothen_contours: Optional[bool] = True,
-) -> …
-    h, w = …
-    …
-    ax.axis("off")
-    ax.imshow(image)
-
-    for mask in masks:
-        if random_color:
-            color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
-        else:
-            color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
+) -> Image.Image:
+    num_masks, _, h, w = masks.shape
 
+    # Ensure masks are 2D by squeezing the channel dimension
+    masks = masks.squeeze(axis=1)
+
+    # Create a single uint8 image with a unique value for each mask
+    combined_mask = np.zeros((h, w), dtype=np.uint8)
+
+    for i in range(num_masks):
+        mask = masks[i]
         mask = mask.astype(np.uint8)
-        mask = …
-        …
+        combined_mask[mask > 0] = i + 1
+
+    # Create a color map for visualization
+    if random_color:
+        colors = np.random.rand(num_masks, 3)  # Random colors for each mask
+    else:
+        colors = np.array(
+            [[30 / 255, 144 / 255, 255 / 255]] * num_masks
+        )  # Use fixed color
+
+    # Create an RGB image where each mask has its own color
+    color_image = np.zeros((h, w, 3), dtype=np.uint8)
 
-        …
-        …
+    for i in range(1, num_masks + 1):
+        mask_color = colors[i - 1] * 255
+        color_image[combined_mask == i] = mask_color
 
+    # Convert the NumPy array to a PIL Image
+    pil_image = Image.fromarray(color_image)
+
+    # Optionally draw smoothed mask contours on top
+    if smoothen_contours:
+        import cv2
+
+        contours_image = np.zeros((h, w, 4), dtype=np.float32)
+
+        for i in range(1, num_masks + 1):
+            mask = (combined_mask == i).astype(np.uint8)
             contours, _ = cv2.findContours(
-                mask, cv2.RETR_EXTERNAL, cv2.…
+                mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )
             contours = [
                 cv2.approxPolyDP(contour, epsilon=0.01, closed=True)
                 for contour in contours
             ]
-            …
-            …
+            contours_image = cv2.drawContours(
+                contours_image, contours, -1, (0, 0, 0, 0.5), thickness=2
             )
 
-    …
-    …
-    …
-    …
-    ax.set_ylim(h, 0)
-    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
+        # Composite the contour layer over the color image via its alpha
+        # channel (Image.blend with this mostly-black layer would darken everything)
+        alpha = contours_image[:, :, 3:]
+        composite = color_image * (1.0 - alpha) + contours_image[:, :, :3] * 255.0 * alpha
+        pil_image = Image.fromarray(composite.astype(np.uint8))
 
-    return …
+    return pil_image
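A quick local sanity check for `export_mask`, using two overlapping synthetic masks in place of real SAM2 output (the `(num_masks, 1, H, W)` input layout is the contract the function above assumes):

import numpy as np

from src.plot_utils import export_mask

h, w = 256, 256
masks = np.zeros((2, 1, h, w), dtype=bool)
masks[0, 0, 32:128, 32:128] = True   # first mask: upper-left square
masks[1, 0, 96:224, 96:224] = True   # second mask: overlapping square

out = export_mask(masks, random_color=True, smoothen_contours=True)
out.save("combined_masks.png")  # single RGB image, one color per mask

Where masks overlap, the later index wins: `combined_mask` is overwritten in index order, so pixel ownership (and hence color) goes to the highest-numbered mask covering that pixel.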