Spaces:
Running
on
Zero
Running
on
Zero
import unittest | |
import src.SegmentAnything2Assist.SegmentAnything2Assist as SegmentAnything2Assist | |
import cv2 | |
import numpy | |
class TestSegmentAnything2Assist(unittest.TestCase):
    """Integration tests for ``SegmentAnything2Assist``.

    The tests build real models on the CPU (``download=True``) and read the
    test image from a path relative to the current working directory
    (presumably the repository root — confirm how the suite is launched).
    """

    # Every SAM2 model name exercised by the model-loading test.
    SAM_MODEL_TYPES = (
        "sam2_hiera_tiny",
        "sam2_hiera_small",
        "sam2_hiera_base_plus",
        "sam2_hiera_large",
    )
    # Resolved relative to the current working directory.
    TEST_IMAGE_PATH = "test/assets/liberty.jpg"

    def _read_test_image(self) -> numpy.ndarray:
        """Load the test image with OpenCV, failing clearly if it is missing."""
        image = cv2.imread(self.TEST_IMAGE_PATH)
        # cv2.imread returns None (instead of raising) when the file is
        # missing or unreadable; without this check the test would die later
        # with an unhelpful AttributeError on ``image.shape``.
        self.assertIsNotNone(image, f"could not read {self.TEST_IMAGE_PATH}")
        return image

    def _load_tiny_model(self):
        """Build the ``sam2_hiera_tiny`` model on the CPU with ``download=True``."""
        return SegmentAnything2Assist.SegmentAnything2Assist(
            sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
        )

    # NOTE(review): unittest's default loader only collects methods whose
    # names start with "test", so this method never runs. Presumably it is
    # disabled on purpose because it downloads every checkpoint — confirm,
    # and consider @unittest.skip("...") to make the intent explicit.
    def _loading_all_sam_model_types(self):
        """Check that every SAM2 model type can be downloaded and reloaded."""
        for sam_model_type in self.SAM_MODEL_TYPES:
            # subTest reports each model type separately instead of stopping
            # at the first failure.
            with self.subTest(sam_model_type=sam_model_type):
                sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
                    sam_model_name=sam_model_type, download=True, device="cpu"
                )
                self.assertTrue(sam_model.is_model_available())

                # Loading from an explicit checkpoint path must not raise.
                # Assumes the download above stores checkpoints under
                # .tmp/checkpoints/<name>.pth — TODO confirm.
                SegmentAnything2Assist.SegmentAnything2Assist(
                    sam_model_name=sam_model_type,
                    download=False,
                    model_path=f".tmp/checkpoints/{sam_model_type}.pth",
                    device="cpu",
                )

                # "." is a directory, not a checkpoint file, so construction
                # is expected to fail.
                with self.assertRaises(Exception):
                    SegmentAnything2Assist.SegmentAnything2Assist(
                        sam_model_name=sam_model_type,
                        download=False,
                        model_path=".",
                        device="cpu",
                    )

    # NOTE(review): not collected by unittest (no "test" prefix); presumably
    # disabled on purpose — confirm.
    def _generate_automatic_mask(self):
        """Check shapes and dtypes returned by ``generate_automatic_masks``."""
        image = self._read_test_image()
        sam_model = self._load_tiny_model()

        segmentation_masks, bboxes, predicted_iou, stability_score = (
            sam_model.generate_automatic_masks(image)
        )

        # Masks: a 4-D uint8 array whose entries each have the image's
        # (H, W, 3) shape; at least one mask is expected.
        self.assertEqual(segmentation_masks.ndim, 4)
        self.assertGreater(segmentation_masks.shape[0], 0)
        self.assertEqual(segmentation_masks.shape[3], 3)
        self.assertEqual(segmentation_masks.dtype, numpy.uint8)
        for segmentation_mask in segmentation_masks:
            self.assertEqual(segmentation_mask.shape, image.shape)

        # Boxes: a 2-D uint32 array with 4 values per row.
        self.assertEqual(bboxes.ndim, 2)
        self.assertEqual(bboxes.shape[1], 4)
        self.assertEqual(bboxes.dtype, numpy.uint32)

        # Scores: 1-D float32 arrays.
        self.assertEqual(predicted_iou.ndim, 1)
        self.assertEqual(predicted_iou.dtype, numpy.float32)
        self.assertEqual(stability_score.ndim, 1)
        self.assertEqual(stability_score.dtype, numpy.float32)

    def test_generate_masks_from_image(self):
        """Check the output of ``generate_masks_from_image`` with None prompts."""
        image = self._read_test_image()
        sam_model = self._load_tiny_model()

        # NOTE(review): the three ``None`` arguments presumably mean "no
        # point / label / box prompt" — confirm against the method signature.
        mask_chw, mask_iou = sam_model.generate_masks_from_image(
            image, None, None, None
        )

        # Exactly one mask of shape (1, H, W). The mask has no colour axis,
        # so compare with the image's (H, W) only: comparing with the full
        # (H, W, 3) image shape can never pass for a 3-D mask array.
        self.assertEqual(mask_chw.ndim, 3)
        self.assertEqual(mask_chw.shape[0], 1)
        self.assertEqual(mask_chw.shape[1:], image.shape[:2])

        # mask_iou is a length-1 vector, matching the single mask.
        self.assertEqual(mask_iou.ndim, 1)
        self.assertEqual(mask_iou.shape[0], 1)