import unittest

import cv2
import numpy

import src.SegmentAnything2Assist.SegmentAnything2Assist as SegmentAnything2Assist

# Sample image used by the mask-generation tests (path relative to the repo root).
TEST_IMAGE_PATH = "test/assets/liberty.jpg"


class TestSegmentAnything2Assist(unittest.TestCase):
    """Tests for the ``SegmentAnything2Assist`` wrapper around SAM2 models."""

    def _load_test_image(self) -> numpy.ndarray:
        """Read the sample image, failing with a clear message if it cannot be loaded.

        ``cv2.imread`` returns ``None`` instead of raising when the file is
        missing or unreadable; without this check the failure would surface
        later as an unrelated ``AttributeError`` on ``image.shape``.
        """
        image = cv2.imread(TEST_IMAGE_PATH)
        self.assertIsNotNone(image, f"could not read test image {TEST_IMAGE_PATH!r}")
        return image

    # NOTE(review): the two `_`-prefixed methods below are not collected by
    # unittest because their names do not start with "test". They are presumably
    # disabled on purpose because they download SAM2 checkpoints; rename them to
    # `test_...` to run them.
    def _loading_all_sam_model_types(self):
        """Check that every SAM2 model type can be downloaded and loaded.

        For each model type:
        * constructing with ``download=True`` must leave the model available;
        * constructing from the local checkpoint path must not raise;
        * constructing from a path that is not a checkpoint (``"."``) must raise.
        """
        all_sam_models_type = [
            "sam2_hiera_tiny",
            "sam2_hiera_small",
            "sam2_hiera_base_plus",
            "sam2_hiera_large",
        ]
        for sam_model_type in all_sam_models_type:
            # subTest reports each model type separately instead of stopping at
            # the first failure.
            with self.subTest(sam_model_type=sam_model_type):
                sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
                    sam_model_name=sam_model_type, download=True, device="cpu"
                )
                self.assertTrue(sam_model.is_model_available())

                # Only checks that construction from a local path does not raise.
                # NOTE(review): assumes download=True stored the checkpoint at this
                # path — confirm before asserting on availability here.
                SegmentAnything2Assist.SegmentAnything2Assist(
                    sam_model_name=sam_model_type,
                    download=False,
                    model_path=f".tmp/checkpoints/{sam_model_type}.pth",
                    device="cpu",
                )

                with self.assertRaises(Exception):
                    SegmentAnything2Assist.SegmentAnything2Assist(
                        sam_model_name=sam_model_type,
                        download=False,
                        model_path=".",
                        device="cpu",
                    )

    def _generate_automatic_mask(self):
        """Check the shapes and dtypes returned by ``generate_automatic_masks``."""
        image = self._load_test_image()
        sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
            sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
        )
        segmentation_masks, bboxes, predicted_iou, stability_score = (
            sam_model.generate_automatic_masks(image)
        )

        # Masks: a non-empty (num_masks, H, W, 3) uint8 array, one image-sized
        # mask per detection. Checking shape[1:] covers every mask at once.
        self.assertEqual(len(segmentation_masks.shape), 4)
        self.assertGreater(segmentation_masks.shape[0], 0)
        self.assertEqual(segmentation_masks.shape[1:], image.shape)
        self.assertEqual(segmentation_masks.shape[3], 3)
        self.assertEqual(segmentation_masks.dtype, numpy.uint8)

        # Bounding boxes: a non-empty (num_boxes, 4) uint32 array.
        self.assertEqual(len(bboxes.shape), 2)
        self.assertGreater(bboxes.shape[0], 0)
        self.assertEqual(bboxes.shape[1], 4)
        self.assertEqual(bboxes.dtype, numpy.uint32)

        # Per-mask scores: non-empty 1-D float32 arrays.
        for scores in (predicted_iou, stability_score):
            self.assertEqual(len(scores.shape), 1)
            self.assertGreater(scores.shape[0], 0)
            self.assertEqual(scores.dtype, numpy.float32)

    def test_generate_masks_from_image(self):
        """Check ``generate_masks_from_image`` with its three optional arguments as ``None``."""
        image = self._load_test_image()
        sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
            sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
        )
        mask_chw, mask_iou = sam_model.generate_masks_from_image(
            image, None, None, None
        )

        # Expect exactly one mask of shape (1, H, W). Each mask is 2-D, so it is
        # compared with the image's spatial dimensions only. cv2.imread returns a
        # 3-channel (H, W, 3) image, so comparing with the full image.shape could
        # never succeed.
        self.assertEqual(len(mask_chw.shape), 3)
        self.assertEqual(mask_chw.shape[0], 1)
        self.assertEqual(tuple(mask_chw[0].shape), image.shape[:2])

        # One IoU score for the single mask.
        self.assertEqual(len(mask_iou.shape), 1)
        self.assertEqual(mask_iou.shape[0], 1)


if __name__ == "__main__":
    unittest.main()