xqt committed
Commit e312782 • 1 Parent(s): ed1b7ea

UPD: added setup.py for installation

SegmentAnything2AssistApp.py CHANGED
@@ -4,7 +4,7 @@ import gradio_imageslider
 import spaces
 import torch
 
-import src.SegmentAnything2Assist as SegmentAnything2Assist
+import src.SegmentAnything2Assist.SegmentAnything2Assist as SegmentAnything2Assist
 
 example_image_annotation = {
     "image": "assets/cars.jpg",
setup.py ADDED
@@ -0,0 +1,25 @@
+from setuptools import setup, find_packages
+
+setup(
+    name="SegmentAnything2Assist",
+    version="0.1",
+    packages=find_packages(where="src"),
+    package_dir={"": "src"},
+    install_requires=[
+        "SAM-2 @ git+https://github.com/facebookresearch/segment-anything-2.git@7e1596c0b6462eb1d1ba7e1492430fed95023598",
+        "ultralytics @ git+https://github.com/THU-MIG/yolov10.git@cd2f79c70299c9041fb6d19617ef1296f47575b1",
+        "opencv-python==4.10.0.84",
+    ],
+    author="xqt",
+    author_email="[email protected]",
+    description="A package to segment anything and assist in the process",
+    long_description=open("README.md").read(),
+    long_description_content_type="text/markdown",
+    url="https://huggingface.co/spaces/xqt/Segment-Anything-2-Assist",
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "License :: OSI Approved :: MIT License",
+        "Operating System :: OS Independent",
+    ],
+    python_requires=">=3.8.0",
+)
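
The src layout declared above (packages=find_packages(where="src") with package_dir={"": "src"}) means an installed copy exposes SegmentAnything2Assist as a top-level package. A minimal usage sketch, assuming the package was installed from the repository root (for example with pip install -e .); the constructor arguments mirror the ones used in the tests added below:

# Sketch only: assumes `pip install -e .` was run from the repository root,
# so the package resolves without the `src.` prefix used by the app above.
from SegmentAnything2Assist import SegmentAnything2Assist

segmenter = SegmentAnything2Assist.SegmentAnything2Assist(
    sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
)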
src/{YOLOv10Plugin.py → SegmentAnything2Assist/Plugin/YOLOv10Plugin.py} RENAMED
File without changes
src/{__init__.py → SegmentAnything2Assist/Plugin/__init__.py} RENAMED
File without changes
src/{SegmentAnything2Assist.py → SegmentAnything2Assist/SegmentAnything2Assist.py} RENAMED
@@ -5,12 +5,11 @@ import tqdm
 import requests
 import torch
 import numpy
-import pickle
 
 import sam2.build_sam
 import sam2.automatic_mask_generator
 
-from . import YOLOv10Plugin
+from .Plugin import YOLOv10Plugin
 
 import cv2
 
@@ -122,14 +121,17 @@ class SegmentAnything2Assist:
         print(f"SegmentAnything2Assist::is_model_available::{ret}")
         return ret
 
-    def load_model(self) -> None:
+    def load_model(self) -> bool:
         if self.is_model_available():
             self.sam2 = sam2.build_sam(checkpoint=self.model_path)
+            return True
 
-    def download_model(self, force: bool = False) -> None:
+        return False
+
+    def download_model(self, force: bool = False) -> bool:
         if not force and self.is_model_available():
             print(f"{self.model_path} already exists. Skipping download.")
-            return
+            return False
 
         response = requests.get(self.download_url, stream=True)
         total_size = int(response.headers.get("content-length", 0))
@@ -141,10 +143,12 @@ class SegmentAnything2Assist:
                 file.write(data)
                 progress_bar.update(len(data))
 
+        return True
+
     def generate_automatic_masks(
         self,
-        image,
-        points_per_side=32,
+        image: numpy.ndarray,
+        points_per_side=10,
         points_per_batch=32,
         pred_iou_thresh=0.8,
         stability_score_thresh=0.95,
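
load_model and download_model now return a bool instead of None, so callers can check whether the checkpoint was actually fetched and built. A caller sketch, not part of this commit, using the same constructor arguments as the tests below:

# Sketch only: download_model() returns False when the checkpoint is already
# cached and True after a fresh download; load_model() reports build success.
segmenter = SegmentAnything2Assist.SegmentAnything2Assist(
    sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
)
downloaded = segmenter.download_model(force=False)
if not segmenter.load_model():
    raise RuntimeError("SAM2 checkpoint could not be loaded")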
src/SegmentAnything2Assist/__init__.py ADDED
File without changes
test/assets/liberty.jpg ADDED
test/test_module.py ADDED
@@ -0,0 +1,59 @@
+import unittest
+import src.SegmentAnything2Assist.SegmentAnything2Assist as SegmentAnything2Assist
+import cv2
+
+
+class TestSegmentAnything2Assist(unittest.TestCase):
+    def setUp(self) -> None:
+        return super().setUp()
+
+    def tearDown(self) -> None:
+        return super().tearDown()
+
+    def _loading_all_sam_model_types(self):
+        # Test loading all types of SAM2 models.
+        all_sam_models_type = [
+            "sam2_hiera_tiny",
+            "sam2_hiera_small",
+            "sam2_hiera_base_plus",
+            "sam2_hiera_large",
+        ]
+        for sam_model_type in all_sam_models_type:
+            sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
+                sam_model_name=sam_model_type, download=True, device="cpu"
+            )
+            self.assertEqual(sam_model.is_model_available(), True)
+
+            sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
+                sam_model_name=sam_model_type,
+                download=False,
+                model_path=f".tmp/checkpoints/{sam_model_type}.pth",
+                device="cpu",
+            )
+
+            with self.assertRaises(Exception):
+                sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
+                    sam_model_name=sam_model_type,
+                    download=False,
+                    model_path=".",
+                    device="cpu",
+                )
+
+    def test_generate_automatic_mask(self):
+        image = cv2.imread("test/assets/liberty.jpg")
+
+        sam_model = SegmentAnything2Assist.SegmentAnything2Assist(
+            sam_model_name="sam2_hiera_tiny", download=True, device="cpu"
+        )
+
+        masks, segmentation_masks, bboxes = sam_model.generate_automatic_masks(image)
+
+        print(type(masks[0]))
+        print(type(segmentation_masks[0]))
+        print(type(bboxes[0]))
+
+        self.assertEqual(len(masks), len(segmentation_masks))
+        self.assertEqual(len(masks), len(bboxes))
+
+        # for mask, segmentation_mask, bbox in zip(masks, segmentation_masks, bboxes):
+        self.assertEqual(segmentation_masks[0].shape, image.shape)
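
Assuming the repository root as the working directory (the test reads test/assets/liberty.jpg via a relative path), the new suite can be run with the standard unittest loader, for example:

# Sketch only: load and run the new test module programmatically.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName("test.test_module")
unittest.TextTestRunner(verbosity=2).run(suite)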