Spaces:
Running
on
Zero
Running
on
Zero
REF: SAM2 AMG and the corresponding test case.
Browse files- SegmentAnything2AssistApp.py +20 -18
- src/SegmentAnything2Assist/SegmentAnything2Assist.py +32 -11
- test/test_module.py +36 -9
SegmentAnything2AssistApp.py
CHANGED
@@ -257,25 +257,27 @@ def generate_auto_mask(
|
|
257 |
if VERBOSE:
|
258 |
print("SegmentAnything2AssistApp::generate_auto_mask::Called.")
|
259 |
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
|
|
|
|
276 |
)
|
277 |
|
278 |
-
if len(
|
279 |
gradio.Warning(
|
280 |
"No masks generated, please tweak the advanced parameters.", duration=5
|
281 |
)
|
@@ -294,7 +296,7 @@ def generate_auto_mask(
|
|
294 |
),
|
295 |
)
|
296 |
else:
|
297 |
-
choices = [str(i) for i in range(len(
|
298 |
|
299 |
returning_image = __generate_auto_mask(
|
300 |
image, ["0"], output_mode, False, masks, bboxes
|
|
|
257 |
if VERBOSE:
|
258 |
print("SegmentAnything2AssistApp::generate_auto_mask::Called.")
|
259 |
|
260 |
+
masks, bboxes, predicted_iou, stability_score = (
|
261 |
+
segment_anything2assist.generate_automatic_masks(
|
262 |
+
image,
|
263 |
+
points_per_side,
|
264 |
+
points_per_batch,
|
265 |
+
pred_iou_thresh,
|
266 |
+
stability_score_thresh,
|
267 |
+
stability_score_offset,
|
268 |
+
mask_threshold,
|
269 |
+
box_nms_thresh,
|
270 |
+
crop_n_layers,
|
271 |
+
crop_nms_thresh,
|
272 |
+
crop_overlay_ratio,
|
273 |
+
crop_n_points_downscale_factor,
|
274 |
+
min_mask_region_area,
|
275 |
+
use_m2m,
|
276 |
+
multimask_output,
|
277 |
+
)
|
278 |
)
|
279 |
|
280 |
+
if len(masks) == 0:
|
281 |
gradio.Warning(
|
282 |
"No masks generated, please tweak the advanced parameters.", duration=5
|
283 |
)
|
|
|
296 |
),
|
297 |
)
|
298 |
else:
|
299 |
+
choices = [str(i) for i in range(len(masks))]
|
300 |
|
301 |
returning_image = __generate_auto_mask(
|
302 |
image, ["0"], output_mode, False, masks, bboxes
|
src/SegmentAnything2Assist/SegmentAnything2Assist.py
CHANGED
@@ -98,7 +98,7 @@ class SegmentAnything2Assist:
|
|
98 |
)
|
99 |
|
100 |
if download:
|
101 |
-
self.
|
102 |
|
103 |
if self.is_model_available():
|
104 |
self.sam2 = sam2.build_sam.build_sam2(
|
@@ -121,14 +121,14 @@ class SegmentAnything2Assist:
|
|
121 |
print(f"SegmentAnything2Assist::is_model_available::{ret}")
|
122 |
return ret
|
123 |
|
124 |
-
def
|
125 |
if self.is_model_available():
|
126 |
self.sam2 = sam2.build_sam(checkpoint=self.model_path)
|
127 |
return True
|
128 |
|
129 |
return False
|
130 |
|
131 |
-
def
|
132 |
if not force and self.is_model_available():
|
133 |
print(f"{self.model_path} already exists. Skipping download.")
|
134 |
return False
|
@@ -162,7 +162,17 @@ class SegmentAnything2Assist:
|
|
162 |
min_mask_region_area=0,
|
163 |
use_m2m=False,
|
164 |
multimask_output=True,
|
165 |
-
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
if self.sam2 is None:
|
167 |
print(
|
168 |
"SegmentAnything2Assist::generate_automatic_masks::SAM2 is not loaded."
|
@@ -196,8 +206,15 @@ class SegmentAnything2Assist:
|
|
196 |
cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) for mask in segmentation_masks
|
197 |
]
|
198 |
bbox_masks = [mask["bbox"] for mask in masks]
|
199 |
-
|
200 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
201 |
|
202 |
def generate_masks_from_image(
|
203 |
self,
|
@@ -208,7 +225,15 @@ class SegmentAnything2Assist:
|
|
208 |
mask_threshold=0.0,
|
209 |
max_hole_area=0.0,
|
210 |
max_sprinkle_area=0.0,
|
211 |
-
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
generator = sam2.sam2_image_predictor.SAM2ImagePredictor(
|
213 |
self.sam2,
|
214 |
mask_threshold=mask_threshold,
|
@@ -240,8 +265,6 @@ class SegmentAnything2Assist:
|
|
240 |
image_with_bounding_boxes = image.copy()
|
241 |
all_masks = None
|
242 |
|
243 |
-
cv2.imwrite(".tmp/mask_2.png", masks[3])
|
244 |
-
|
245 |
for _ in auto_list:
|
246 |
mask = masks[_]
|
247 |
mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
|
@@ -252,8 +275,6 @@ class SegmentAnything2Assist:
|
|
252 |
else:
|
253 |
all_masks = cv2.bitwise_or(all_masks, mask)
|
254 |
|
255 |
-
cv2.imwrite(".tmp/mask_3.png", masks[3])
|
256 |
-
|
257 |
random_color = numpy.random.randint(0, 255, size=3)
|
258 |
image_with_bounding_boxes = cv2.rectangle(
|
259 |
image_with_bounding_boxes,
|
|
|
98 |
)
|
99 |
|
100 |
if download:
|
101 |
+
self.__download_model()
|
102 |
|
103 |
if self.is_model_available():
|
104 |
self.sam2 = sam2.build_sam.build_sam2(
|
|
|
121 |
print(f"SegmentAnything2Assist::is_model_available::{ret}")
|
122 |
return ret
|
123 |
|
124 |
+
def __load_model(self) -> bool:
|
125 |
if self.is_model_available():
|
126 |
self.sam2 = sam2.build_sam(checkpoint=self.model_path)
|
127 |
return True
|
128 |
|
129 |
return False
|
130 |
|
131 |
+
def __download_model(self, force: bool = False) -> bool:
|
132 |
if not force and self.is_model_available():
|
133 |
print(f"{self.model_path} already exists. Skipping download.")
|
134 |
return False
|
|
|
162 |
min_mask_region_area=0,
|
163 |
use_m2m=False,
|
164 |
multimask_output=True,
|
165 |
+
) -> typing.Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray, numpy.ndarray]:
|
166 |
+
"""
|
167 |
+
Generates automatic masks from the given image.
|
168 |
+
|
169 |
+
Returns:
|
170 |
+
typing.Tuple: Four numpy arrays where:
|
171 |
+
- segmentation_masks: Numpy array shape (N, H, W, C) where N is the number of masks, H is the height of the image, W is the width of the image, and C is the number of channels. Each N is a binary mask of the image of shape (H, W, C).
|
172 |
+
- bbox_masks: Numpy array of shape (N, 4) where N is the number of masks and 4 is the bounding box coordinates. Each mask is a bounding box of shape (x, y, w, h).
|
173 |
+
- predicted_iou: Numpy array of shape (N,) where N is the number of masks. Each value is the predicted IOU of the mask.
|
174 |
+
- stability_score: Numpy array of shape (N,) where N is the number of masks. Each value is the stability score of the mask.
|
175 |
+
"""
|
176 |
if self.sam2 is None:
|
177 |
print(
|
178 |
"SegmentAnything2Assist::generate_automatic_masks::SAM2 is not loaded."
|
|
|
206 |
cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) for mask in segmentation_masks
|
207 |
]
|
208 |
bbox_masks = [mask["bbox"] for mask in masks]
|
209 |
+
predicted_iou = [mask["predicted_iou"] for mask in masks]
|
210 |
+
stability_score = [mask["stability_score"] for mask in masks]
|
211 |
+
|
212 |
+
return (
|
213 |
+
numpy.array(segmentation_masks, dtype=numpy.uint8),
|
214 |
+
numpy.array(bbox_masks, dtype=numpy.uint32),
|
215 |
+
numpy.array(predicted_iou, dtype=numpy.float32),
|
216 |
+
numpy.array(stability_score, dtype=numpy.float32),
|
217 |
+
)
|
218 |
|
219 |
def generate_masks_from_image(
|
220 |
self,
|
|
|
225 |
mask_threshold=0.0,
|
226 |
max_hole_area=0.0,
|
227 |
max_sprinkle_area=0.0,
|
228 |
+
) -> typing.Tuple[numpy.ndarray, numpy.ndarray]:
|
229 |
+
"""
|
230 |
+
Generates masks from the given image.
|
231 |
+
|
232 |
+
Returns:
|
233 |
+
typing.Tuple: Two numpy arrays where:
|
234 |
+
- masks_chw: Numpy array shape (1, H, W) for the mask, H is the height of the image, and W is the width of the image.
|
235 |
+
- mask_iou: Numpy array of shape (1,) for IOU of the mask.
|
236 |
+
"""
|
237 |
generator = sam2.sam2_image_predictor.SAM2ImagePredictor(
|
238 |
self.sam2,
|
239 |
mask_threshold=mask_threshold,
|
|
|
265 |
image_with_bounding_boxes = image.copy()
|
266 |
all_masks = None
|
267 |
|
|
|
|
|
268 |
for _ in auto_list:
|
269 |
mask = masks[_]
|
270 |
mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
|
|
|
275 |
else:
|
276 |
all_masks = cv2.bitwise_or(all_masks, mask)
|
277 |
|
|
|
|
|
278 |
random_color = numpy.random.randint(0, 255, size=3)
|
279 |
image_with_bounding_boxes = cv2.rectangle(
|
280 |
image_with_bounding_boxes,
|
test/test_module.py
CHANGED
@@ -2,6 +2,8 @@ import unittest
|
|
2 |
import src.SegmentAnything2Assist.SegmentAnything2Assist as SegmentAnything2Assist
|
3 |
import cv2
|
4 |
|
|
|
|
|
5 |
|
6 |
class TestSegmentAnything2Assist(unittest.TestCase):
|
7 |
def setUp(self) -> None:
|
@@ -39,21 +41,46 @@ class TestSegmentAnything2Assist(unittest.TestCase):
|
|
39 |
device="cpu",
|
40 |
)
|
41 |
|
42 |
-
def
|
43 |
image = cv2.imread("test/assets/liberty.jpg")
|
44 |
|
45 |
sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
|
46 |
sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
|
47 |
)
|
48 |
|
49 |
-
|
|
|
|
|
50 |
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
-
|
56 |
-
|
57 |
|
58 |
-
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import src.SegmentAnything2Assist.SegmentAnything2Assist as SegmentAnything2Assist
|
3 |
import cv2
|
4 |
|
5 |
+
import numpy
|
6 |
+
|
7 |
|
8 |
class TestSegmentAnything2Assist(unittest.TestCase):
|
9 |
def setUp(self) -> None:
|
|
|
41 |
device="cpu",
|
42 |
)
|
43 |
|
44 |
+
def _generate_automatic_mask(self):
|
45 |
image = cv2.imread("test/assets/liberty.jpg")
|
46 |
|
47 |
sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
|
48 |
sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
|
49 |
)
|
50 |
|
51 |
+
segmentation_masks, bboxes, predicted_iou, stability_score = (
|
52 |
+
sam_model.generate_automatic_masks(image)
|
53 |
+
)
|
54 |
|
55 |
+
self.assertEqual(len(segmentation_masks.shape), 4)
|
56 |
+
self.assertEqual(segmentation_masks[0].shape, image.shape)
|
57 |
+
self.assertEqual(segmentation_masks.shape[3], 3)
|
58 |
+
self.assertEqual(type(segmentation_masks[0][0][0][0]), numpy.uint8)
|
59 |
+
self.assertEqual(len(bboxes.shape), 2)
|
60 |
+
self.assertEqual(bboxes[0].shape, (4,))
|
61 |
+
self.assertEqual(type(bboxes[0][0]), numpy.uint32)
|
62 |
+
self.assertEqual(len(predicted_iou.shape), 1)
|
63 |
+
self.assertEqual(type(predicted_iou[0]), numpy.float32)
|
64 |
+
self.assertEqual(len(stability_score.shape), 1)
|
65 |
+
self.assertEqual(type(stability_score[0]), numpy.float32)
|
66 |
|
67 |
+
for segmentation_mask in segmentation_masks:
|
68 |
+
self.assertEqual(segmentation_mask.shape, image.shape)
|
69 |
|
70 |
+
def test_generate_masks_from_image(self):
|
71 |
+
image = cv2.imread("test/assets/liberty.jpg")
|
72 |
+
|
73 |
+
sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
|
74 |
+
sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
|
75 |
+
)
|
76 |
+
|
77 |
+
mask_chw, mask_iou = sam_model.generate_masks_from_image(
|
78 |
+
image, None, None, None
|
79 |
+
)
|
80 |
+
|
81 |
+
self.assertEqual(len(mask_chw.shape), 3)
|
82 |
+
self.assertEqual(mask_chw[0].shape, image.shape)
|
83 |
+
self.assertEqual(mask_chw.shape[0], 1)
|
84 |
+
|
85 |
+
self.assertEqual(len(mask_iou.shape), 1)
|
86 |
+
self.assertEqual(mask_iou.shape[0], 1)
|