xqt committed on
Commit f3d3559 • 1 Parent(s): 00f618e

REF: Uses internal variable for auto mask and image segmentation.

Files changed (3)
  1. .gitignore +2 -1
  2. SegmentAnything2AssistApp.py +434 -166
  3. src/SegmentAnything2Assist.py +241 -192
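
The commit replaces module-level globals (`__image_point_coords`, `__current_mask`, and friends) with values that handlers return and hidden Gradio components carry between events. A minimal sketch of that pattern, separate from the app itself; all names here are illustrative, and `gradio.State` would work equally well as the carrier:

import gradio

def compute(value):
    # Return derived results instead of assigning to module-level globals.
    mask = f"mask({value})"
    segment = f"segment({value})"
    return mask, segment

def render(mode, mask, segment):
    # Later handlers receive the stored values as ordinary inputs.
    return mask if mode == "Mask" else segment

with gradio.Blocks() as demo:
    inp = gradio.Textbox(label="Input")
    mode = gradio.Radio(["Mask", "Segment"], value="Mask", label="Output Mode")
    # Hidden components hold per-session state, much as the commit does with
    # its DataFrame/Image components inside a hidden "Debug" accordion.
    mask_state = gradio.Textbox(visible=False)
    segment_state = gradio.Textbox(visible=False)
    out = gradio.Textbox(label="Output")
    btn = gradio.Button("Generate")

    btn.click(compute, inputs=[inp], outputs=[mask_state, segment_state])
    mode.change(render, inputs=[mode, mask_state, segment_state], outputs=[out])

if __name__ == "__main__":
    demo.launch()
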
.gitignore CHANGED
@@ -1 +1,2 @@
- .tmp/
+ .tmp/
+ .venv/
SegmentAnything2AssistApp.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio
2
- import gradio_image_annotation
3
  import gradio_imageslider
4
  import spaces
5
  import torch
@@ -8,29 +8,65 @@ import src.SegmentAnything2Assist as SegmentAnything2Assist
8
 
9
  example_image_annotation = {
10
  "image": "assets/cars.jpg",
11
- "boxes": [{'label': '+', 'color': (0, 255, 0), 'xmin': 886, 'ymin': 551, 'xmax': 886, 'ymax': 551}, {'label': '-', 'color': (255, 0, 0), 'xmin': 1239, 'ymin': 576, 'xmax': 1239, 'ymax': 576}, {'label': '-', 'color': (255, 0, 0), 'xmin': 610, 'ymin': 574, 'xmax': 610, 'ymax': 574}, {'label': '', 'color': (0, 0, 255), 'xmin': 254, 'ymin': 466, 'xmax': 1347, 'ymax': 1047}]
12
  }
13
 
 
14
  VERBOSE = True
15
 
16
- segment_anything2assist = SegmentAnything2Assist.SegmentAnything2Assist(model_name = "sam2_hiera_tiny", device = torch.device("cuda"))
17
- __image_point_coords = None
18
- __image_point_labels = None
19
- __image_box = None
20
- __current_mask = None
21
- __current_segment = None
22
 
23
  def __change_base_model(model_name, device):
24
  global segment_anything2assist
 
25
  try:
26
- segment_anything2assist = SegmentAnything2Assist.SegmentAnything2Assist(model_name = model_name, device = torch.device(device))
27
- gradio.Info(f"Model changed to {model_name} on {device}", duration = 5)
28
  except:
29
- gradio.Error(f"Model could not be changed", duration = 5)
 
30
 
31
  def __post_process_annotator_inputs(value):
32
- global __image_point_coords, __image_point_labels, __image_box
33
- global __current_mask, __current_segment
34
  if VERBOSE:
35
  print("SegmentAnything2AssistApp::____post_process_annotator_inputs::Called.")
36
  __current_mask, __current_segment = None, None
@@ -38,111 +74,167 @@ def __post_process_annotator_inputs(value):
38
  __image_point_coords = []
39
  __image_point_labels = []
40
  __image_box = []
41
-
42
  b_has_box = False
43
  for box in value["boxes"]:
44
- if box['label'] == '':
45
  if not b_has_box:
46
  new_box = box.copy()
47
- new_box['color'] = (0, 0, 255)
48
  new_boxes.append(new_box)
49
  b_has_box = True
50
- __image_box = [
51
- box['xmin'],
52
- box['ymin'],
53
- box['xmax'],
54
- box['ymax']
55
- ]
56
-
57
-
58
- elif box['label'] == '+' or box['label'] == '-':
59
  new_box = box.copy()
60
- new_box['color'] = (0, 255, 0) if box['label'] == '+' else (255, 0, 0)
61
- new_box['xmin'] = int((box['xmin'] + box['xmax']) / 2)
62
- new_box['ymin'] = int((box['ymin'] + box['ymax']) / 2)
63
- new_box['xmax'] = new_box['xmin']
64
- new_box['ymax'] = new_box['ymin']
65
  new_boxes.append(new_box)
66
-
67
- __image_point_coords.append([new_box['xmin'], new_box['ymin']])
68
- __image_point_labels.append(1 if box['label'] == '+' else 0)
69
-
70
- if len(__image_box) == 0:
71
- __image_box = None
72
-
73
- if len(__image_point_coords) == 0:
74
- __image_point_coords = None
75
-
76
- if len(__image_point_labels) == 0:
77
- __image_point_labels = None
78
 
79
  if VERBOSE:
80
  print("SegmentAnything2AssistApp::____post_process_annotator_inputs::Done.")
81
 
 
82
 
83
 
84
- @spaces.GPU(duration = 60)
85
- def __generate_mask(value, mask_threshold, max_hole_area, max_sprinkle_area, image_output_mode):
86
- global __current_mask, __current_segment
87
- global __image_point_coords, __image_point_labels, __image_box
88
  global segment_anything2assist
89
 
90
  # Force post processing of annotated image
91
- __post_process_annotator_inputs(value)
92
 
93
  if VERBOSE:
94
  print("SegmentAnything2AssistApp::__generate_mask::Called.")
95
  mask_chw, mask_iou = segment_anything2assist.generate_masks_from_image(
96
  value["image"],
97
- __image_point_coords,
98
- __image_point_labels,
99
- __image_box,
100
  mask_threshold,
101
  max_hole_area,
102
- max_sprinkle_area
103
  )
104
 
105
  if VERBOSE:
106
  print("SegmentAnything2AssistApp::__generate_mask::Masks generated.")
107
 
108
- __current_mask, __current_segment = segment_anything2assist.apply_mask_to_image(value["image"], mask_chw[0])
109
 
110
  if VERBOSE:
111
  print("SegmentAnything2AssistApp::__generate_mask::Masks and Segments created.")
112
 
113
  if image_output_mode == "Mask":
114
- return [value["image"], __current_mask]
115
  elif image_output_mode == "Segment":
116
- return [value["image"], __current_segment]
117
  else:
118
  gradio.Warning("This is an issue, please report the problem!", duration=5)
119
- return gradio_imageslider.ImageSlider(render = True)
120
 
121
- def __change_output_mode(image_input, radio):
122
- global __current_mask, __current_segment
123
- global __image_point_coords, __image_point_labels, __image_box
124
  if VERBOSE:
125
  print("SegmentAnything2AssistApp::__generate_mask::Called.")
126
  if __current_mask is None or __current_segment is None:
127
  gradio.Warning("Configuration was changed, generate the mask again", duration=5)
128
- return gradio_imageslider.ImageSlider(render = True)
129
  if radio == "Mask":
130
  return [image_input["image"], __current_mask]
131
  elif radio == "Segment":
132
  return [image_input["image"], __current_segment]
133
  else:
134
  gradio.Warning("This is an issue, please report the problem!", duration=5)
135
- return gradio_imageslider.ImageSlider(render = True)
136
-
137
- def __generate_multi_mask_output(image, auto_list, auto_mode, auto_bbox_mode):
138
  global segment_anything2assist
139
- image_with_bbox, mask, segment = segment_anything2assist.apply_auto_mask_to_image(image, [int(i) - 1 for i in auto_list])
140
-
141
  output_1 = image_with_bbox if auto_bbox_mode else image
142
  output_2 = mask if auto_mode == "Mask" else segment
143
  return [output_1, output_2]
144
-
145
- @spaces.GPU(duration = 60)
 
146
  def __generate_auto_mask(
147
  image,
148
  points_per_side,
@@ -159,13 +251,13 @@ def __generate_auto_mask(
159
  min_mask_region_area,
160
  use_m2m,
161
  multimask_output,
162
- output_mode
163
- ):
164
  global segment_anything2assist
165
  if VERBOSE:
166
- print("SegmentAnything2AssistApp::__generate_auto_mask::Called.")
167
-
168
- __auto_masks = segment_anything2assist.generate_automatic_masks(
169
  image,
170
  points_per_side,
171
  points_per_batch,
@@ -180,43 +272,84 @@ def __generate_auto_mask(
180
  crop_n_points_downscale_factor,
181
  min_mask_region_area,
182
  use_m2m,
183
- multimask_output
184
  )
185
-
186
  if len(__auto_masks) == 0:
187
- gradio.Warning("No masks generated, please tweak the advanced parameters.", duration = 5)
188
- return gradio_imageslider.ImageSlider(), \
189
- gradio.CheckboxGroup([], value = [], label = "Mask List", interactive = False), \
190
- gradio.Checkbox(value = False, label = "Show Bounding Box", interactive = False)
191
  else:
192
  choices = [str(i) for i in range(len(__auto_masks))]
193
- returning_image = __generate_multi_mask_output(image, ["0"], output_mode, False)
194
- return returning_image, \
195
- gradio.CheckboxGroup(choices, value = ["0"], label = "Mask List", interactive = True), \
196
- gradio.Checkbox(value = False, label = "Show Bounding Box", interactive = True)
197
-
198
  with gradio.Blocks() as base_app:
199
  gradio.Markdown("# SegmentAnything2Assist")
200
  with gradio.Row():
201
  with gradio.Column():
202
  base_model_choice = gradio.Dropdown(
203
- ['sam2_hiera_large', 'sam2_hiera_small', 'sam2_hiera_base_plus','sam2_hiera_tiny'],
204
- value = 'sam2_hiera_tiny',
205
- label = "Model Choice"
206
- )
207
  with gradio.Column():
208
  base_gpu_choice = gradio.Dropdown(
209
- ['cpu', 'cuda'],
210
- value = 'cuda',
211
- label = "Device Choice"
212
  )
213
- base_model_choice.change(__change_base_model, inputs = [base_model_choice, base_gpu_choice])
214
- base_gpu_choice.change(__change_base_model, inputs = [base_model_choice, base_gpu_choice])
215
- with gradio.Tab(label = "Image Segmentation", id = "image_tab") as image_tab:
216
- gradio.Markdown("Image Segmentation", render = True)
217
  with gradio.Column():
218
- with gradio.Accordion("Image Annotation Documentation", open = False):
219
- gradio.Markdown("""
 
220
  Image annotation allows you to mark specific regions of an image with labels.
221
  In this app, you can annotate an image by drawing boxes and assigning labels to them.
222
  The labels can be either '+' or '-'.
@@ -229,88 +362,223 @@ with gradio.Blocks() as base_app:
229
  Note that the advanced options allow you to adjust the SAM mask threshold, maximum hole area, and maximum sprinkle area.
230
  These options control the sensitivity and accuracy of the segmentation process.
231
  Experiment with different settings to achieve the desired results.
232
- """)
233
- image_input = gradio_image_annotation.image_annotator(example_image_annotation)
234
- with gradio.Accordion("Advanced Options", open = False):
235
- image_generate_SAM_mask_threshold = gradio.Slider(0.0, 1.0, 0.0, label = "SAM Mask Threshold")
236
- image_generate_SAM_max_hole_area = gradio.Slider(0, 1000, 0, label = "SAM Max Hole Area")
237
- image_generate_SAM_max_sprinkle_area = gradio.Slider(0, 1000, 0, label = "SAM Max Sprinkle Area")
238
  image_generate_mask_button = gradio.Button("Generate Mask")
239
- image_output = gradio_imageslider.ImageSlider()
240
- image_output_mode = gradio.Radio(["Segment", "Mask"], value = "Segment", label = "Output Mode")
241
-
242
- image_input.change(__post_process_annotator_inputs, inputs = [image_input])
243
- image_generate_mask_button.click(__generate_mask, inputs = [
244
- image_input,
245
- image_generate_SAM_mask_threshold,
246
- image_generate_SAM_max_hole_area,
247
- image_generate_SAM_max_sprinkle_area,
248
- image_output_mode
249
- ],
250
- outputs = [image_output])
251
- image_output_mode.change(__change_output_mode, inputs = [image_input, image_output_mode], outputs = [image_output])
252
- with gradio.Tab(label = "Auto Segmentation", id = "auto_tab"):
253
- gradio.Markdown("Auto Segmentation", render = True)
254
  with gradio.Column():
255
- with gradio.Accordion("Auto Annotation Documentation", open = False):
256
- gradio.Markdown("""
257
- """)
258
  auto_input = gradio.Image("assets/cars.jpg")
259
- with gradio.Accordion("Advanced Options", open = False):
260
- auto_generate_SAM_points_per_side = gradio.Slider(1, 64, 32, 1, label = "Points Per Side", interactive = True)
261
- auto_generate_SAM_points_per_batch = gradio.Slider(1, 64, 32, 1, label = "Points Per Batch", interactive = True)
262
- auto_generate_SAM_pred_iou_thresh = gradio.Slider(0.0, 1.0, 0.8, 1, label = "Pred IOU Threshold", interactive = True)
263
- auto_generate_SAM_stability_score_thresh = gradio.Slider(0.0, 1.0, 0.95, label = "Stability Score Threshold", interactive = True)
264
- auto_generate_SAM_stability_score_offset = gradio.Slider(0.0, 1.0, 1.0, label = "Stability Score Offset", interactive = True)
265
- auto_generate_SAM_mask_threshold = gradio.Slider(0.0, 1.0, 0.0, label = "Mask Threshold", interactive = True)
266
- auto_generate_SAM_box_nms_thresh = gradio.Slider(0.0, 1.0, 0.7, label = "Box NMS Threshold", interactive = True)
267
- auto_generate_SAM_crop_n_layers = gradio.Slider(0, 10, 0, 1, label = "Crop N Layers", interactive = True)
268
- auto_generate_SAM_crop_nms_thresh = gradio.Slider(0.0, 1.0, 0.7, label = "Crop NMS Threshold", interactive = True)
269
- auto_generate_SAM_crop_overlay_ratio = gradio.Slider(0.0, 1.0, 512 / 1500, label = "Crop Overlay Ratio", interactive = True)
270
- auto_generate_SAM_crop_n_points_downscale_factor = gradio.Slider(1, 10, 1, label = "Crop N Points Downscale Factor", interactive = True)
271
- auto_generate_SAM_min_mask_region_area = gradio.Slider(0, 1000, 0, label = "Min Mask Region Area", interactive = True)
272
- auto_generate_SAM_use_m2m = gradio.Checkbox(label = "Use M2M", interactive = True)
273
- auto_generate_SAM_multimask_output = gradio.Checkbox(value = True, label = "Multi Mask Output", interactive = True)
274
  auto_generate_button = gradio.Button("Generate Auto Mask")
275
  with gradio.Row():
276
  with gradio.Column():
277
- auto_output_mode = gradio.Radio(["Segment", "Mask"], value = "Segment", label = "Output Mode", interactive = True)
278
- auto_output_list = gradio.CheckboxGroup([], value = [], label = "Mask List", interactive = False)
279
- auto_output_bbox = gradio.Checkbox(value = False, label = "Show Bounding Box", interactive = False)
280
- with gradio.Column(scale = 3):
281
  auto_output = gradio_imageslider.ImageSlider()
282
-
283
  auto_generate_button.click(
284
- __generate_auto_mask,
285
- inputs = [
286
- auto_input,
287
- auto_generate_SAM_points_per_side,
288
- auto_generate_SAM_points_per_batch,
289
- auto_generate_SAM_pred_iou_thresh,
290
- auto_generate_SAM_stability_score_thresh,
291
- auto_generate_SAM_stability_score_offset,
292
- auto_generate_SAM_mask_threshold,
293
- auto_generate_SAM_box_nms_thresh,
294
- auto_generate_SAM_crop_n_layers,
295
- auto_generate_SAM_crop_nms_thresh,
296
- auto_generate_SAM_crop_overlay_ratio,
297
- auto_generate_SAM_crop_n_points_downscale_factor,
298
- auto_generate_SAM_min_mask_region_area,
299
- auto_generate_SAM_use_m2m,
300
- auto_generate_SAM_multimask_output,
301
- auto_output_mode
302
  ],
303
- outputs = [
304
  auto_output,
305
  auto_output_list,
306
- auto_output_bbox
307
- ]
308
  )
309
- auto_output_list.change(__generate_multi_mask_output, inputs = [auto_input, auto_output_list, auto_output_mode, auto_output_bbox], outputs = [auto_output])
310
- auto_output_bbox.change(__generate_multi_mask_output, inputs = [auto_input, auto_output_list, auto_output_mode, auto_output_bbox], outputs = [auto_output])
311
- auto_output_mode.change(__generate_multi_mask_output, inputs = [auto_input, auto_output_list, auto_output_mode, auto_output_bbox], outputs = [auto_output])
312
-
313
-
314
  if __name__ == "__main__":
315
  base_app.launch()
316
-
 
1
  import gradio
2
+ import gradio_image_annotation
3
  import gradio_imageslider
4
  import spaces
5
  import torch
 
8
 
9
  example_image_annotation = {
10
  "image": "assets/cars.jpg",
11
+ "boxes": [
12
+ {
13
+ "label": "+",
14
+ "color": (0, 255, 0),
15
+ "xmin": 886,
16
+ "ymin": 551,
17
+ "xmax": 886,
18
+ "ymax": 551,
19
+ },
20
+ {
21
+ "label": "-",
22
+ "color": (255, 0, 0),
23
+ "xmin": 1239,
24
+ "ymin": 576,
25
+ "xmax": 1239,
26
+ "ymax": 576,
27
+ },
28
+ {
29
+ "label": "-",
30
+ "color": (255, 0, 0),
31
+ "xmin": 610,
32
+ "ymin": 574,
33
+ "xmax": 610,
34
+ "ymax": 574,
35
+ },
36
+ {
37
+ "label": "",
38
+ "color": (0, 0, 255),
39
+ "xmin": 254,
40
+ "ymin": 466,
41
+ "xmax": 1347,
42
+ "ymax": 1047,
43
+ },
44
+ ],
45
  }
46
 
47
+
48
  VERBOSE = True
49
+ DEBUG = False
50
+
51
+
52
+ segment_anything2assist = SegmentAnything2Assist.SegmentAnything2Assist(
53
+ model_name="sam2_hiera_tiny", device=torch.device("cpu")
54
+ )
55
 
56
 
57
  def __change_base_model(model_name, device):
58
  global segment_anything2assist
59
+ gradio.Info(f"Changing model to {model_name} on {device}", duration=3)
60
  try:
61
+ segment_anything2assist = SegmentAnything2Assist.SegmentAnything2Assist(
62
+ model_name=model_name, device=torch.device(device)
63
+ )
64
+ gradio.Info(f"Model has been changed to {model_name} on {device}", duration=5)
65
  except:
66
+ gradio.Error(f"Model could not be changed", duration=5)
67
+
68
 
69
  def __post_process_annotator_inputs(value):
70
  if VERBOSE:
71
  print("SegmentAnything2AssistApp::____post_process_annotator_inputs::Called.")
72
  __current_mask, __current_segment = None, None
 
74
  __image_point_coords = []
75
  __image_point_labels = []
76
  __image_box = []
77
+
78
  b_has_box = False
79
  for box in value["boxes"]:
80
+ if box["label"] == "":
81
  if not b_has_box:
82
  new_box = box.copy()
83
+ new_box["color"] = (0, 0, 255)
84
  new_boxes.append(new_box)
85
  b_has_box = True
86
+ __image_box = [box["xmin"], box["ymin"], box["xmax"], box["ymax"]]
87
+
88
+ elif box["label"] == "+" or box["label"] == "-":
89
  new_box = box.copy()
90
+ new_box["color"] = (0, 255, 0) if box["label"] == "+" else (255, 0, 0)
91
+ new_box["xmin"] = int((box["xmin"] + box["xmax"]) / 2)
92
+ new_box["ymin"] = int((box["ymin"] + box["ymax"]) / 2)
93
+ new_box["xmax"] = new_box["xmin"]
94
+ new_box["ymax"] = new_box["ymin"]
95
  new_boxes.append(new_box)
96
+
97
+ __image_point_coords.append([new_box["xmin"], new_box["ymin"]])
98
+ __image_point_labels.append(1 if box["label"] == "+" else 0)
99
+
100
+ if len(__image_box) == 0:
101
+ __image_box = None
102
+
103
+ if len(__image_point_coords) == 0:
104
+ __image_point_coords = None
105
+
106
+ if len(__image_point_labels) == 0:
107
+ __image_point_labels = None
108
 
109
  if VERBOSE:
110
  print("SegmentAnything2AssistApp::____post_process_annotator_inputs::Done.")
111
 
112
+ return __image_point_coords, __image_point_labels, __image_box
113
 
114
 
115
+ @spaces.GPU(duration=60)
116
+ def __generate_mask(
117
+ value,
118
+ mask_threshold,
119
+ max_hole_area,
120
+ max_sprinkle_area,
121
+ image_output_mode,
122
+ ):
123
  global segment_anything2assist
124
 
125
  # Force post processing of annotated image
126
+ image_point_coords, image_point_labels, image_box = __post_process_annotator_inputs(
127
+ value
128
+ )
129
 
130
  if VERBOSE:
131
  print("SegmentAnything2AssistApp::__generate_mask::Called.")
132
  mask_chw, mask_iou = segment_anything2assist.generate_masks_from_image(
133
  value["image"],
134
+ image_point_coords,
135
+ image_point_labels,
136
+ image_box,
137
  mask_threshold,
138
  max_hole_area,
139
+ max_sprinkle_area,
140
  )
141
 
142
  if VERBOSE:
143
  print("SegmentAnything2AssistApp::__generate_mask::Masks generated.")
144
 
145
+ __current_mask, __current_segment = segment_anything2assist.apply_mask_to_image(
146
+ value["image"], mask_chw[0]
147
+ )
148
 
149
  if VERBOSE:
150
  print("SegmentAnything2AssistApp::__generate_mask::Masks and Segments created.")
151
 
152
+ __image_box = gradio.DataFrame(value=[[]])
153
+ __image_point_coords = gradio.DataFrame(value=[[]])
154
+ if DEBUG:
155
+ __image_box = gradio.DataFrame(
156
+ value=[image_box],
157
+ label="Box",
158
+ interactive=False,
159
+ headers=["XMin", "YMin", "XMax", "YMax"],
160
+ )
161
+ x = []
162
+ for i, _ in enumerate(image_point_coords):
163
+ x.append(
164
+ [
165
+ image_point_labels[i],
166
+ image_point_coords[i][0],
167
+ image_point_coords[i][1],
168
+ ]
169
+ )
170
+ __image_point_coords = gradio.DataFrame(
171
+ value=x,
172
+ label="Point Coords",
173
+ interactive=False,
174
+ headers=["Label", "X", "Y"],
175
+ )
176
+
177
  if image_output_mode == "Mask":
178
+ return (
179
+ [value["image"], __current_mask],
180
+ __image_point_coords,
181
+ __image_box,
182
+ __current_mask,
183
+ __current_segment,
184
+ )
185
  elif image_output_mode == "Segment":
186
+ return (
187
+ [value["image"], __current_segment],
188
+ __image_point_coords,
189
+ __image_box,
190
+ __current_mask,
191
+ __current_segment,
192
+ )
193
  else:
194
  gradio.Warning("This is an issue, please report the problem!", duration=5)
195
+ return (
196
+ gradio_imageslider.ImageSlider(render=True),
197
+ __image_point_coords,
198
+ __image_box,
199
+ __current_mask,
200
+ __current_segment,
201
+ )
202
+
203
 
204
+ def __change_output_mode(image_input, radio, __current_mask, __current_segment):
205
  if VERBOSE:
206
  print("SegmentAnything2AssistApp::__generate_mask::Called.")
207
  if __current_mask is None or __current_segment is None:
208
  gradio.Warning("Configuration was changed, generate the mask again", duration=5)
209
+ return gradio_imageslider.ImageSlider(render=True)
210
  if radio == "Mask":
211
  return [image_input["image"], __current_mask]
212
  elif radio == "Segment":
213
  return [image_input["image"], __current_segment]
214
  else:
215
  gradio.Warning("This is an issue, please report the problem!", duration=5)
216
+ return gradio_imageslider.ImageSlider(render=True)
217
+
218
+
219
+ def __generate_multi_mask_output(
220
+ image, auto_list, auto_mode, auto_bbox_mode, masks, bboxes
221
+ ):
222
  global segment_anything2assist
223
+
224
+ # When value from gallery is called, it is a tuple
225
+ if type(masks[0]) == tuple:
226
+ masks = [mask[0] for mask in masks]
227
+
228
+ image_with_bbox, mask, segment = segment_anything2assist.apply_auto_mask_to_image(
229
+ image, [int(i) - 1 for i in auto_list], masks, bboxes
230
+ )
231
+
232
  output_1 = image_with_bbox if auto_bbox_mode else image
233
  output_2 = mask if auto_mode == "Mask" else segment
234
  return [output_1, output_2]
235
+
236
+
237
+ @spaces.GPU(duration=60)
238
  def __generate_auto_mask(
239
  image,
240
  points_per_side,
 
251
  min_mask_region_area,
252
  use_m2m,
253
  multimask_output,
254
+ output_mode,
255
+ ):
256
  global segment_anything2assist
257
  if VERBOSE:
258
+ print("SegmentAnything2AssistApp::__generate_auto_mask::Called.")
259
+
260
+ __auto_masks, masks, bboxes = segment_anything2assist.generate_automatic_masks(
261
  image,
262
  points_per_side,
263
  points_per_batch,
 
272
  crop_n_points_downscale_factor,
273
  min_mask_region_area,
274
  use_m2m,
275
+ multimask_output,
276
  )
277
+
278
  if len(__auto_masks) == 0:
279
+ gradio.Warning(
280
+ "No masks generated, please tweak the advanced parameters.", duration=5
281
+ )
282
+ return (
283
+ gradio_imageslider.ImageSlider(),
284
+ gradio.CheckboxGroup([], value=[], label="Mask List", interactive=False),
285
+ gradio.Checkbox(value=False, label="Show Bounding Box", interactive=False),
286
+ gradio.Gallery(
287
+ None, label="Output Gallery", interactive=False, type="numpy"
288
+ ),
289
+ gradio.DataFrame(
290
+ value=[[]],
291
+ label="Box",
292
+ interactive=False,
293
+ headers=["XMin", "YMin", "XMax", "YMax"],
294
+ ),
295
+ )
296
  else:
297
  choices = [str(i) for i in range(len(__auto_masks))]
298
+
299
+ returning_image = __generate_multi_mask_output(
300
+ image, ["0"], output_mode, False, masks, bboxes
301
+ )
302
+ return (
303
+ returning_image,
304
+ gradio.CheckboxGroup(
305
+ choices, value=["0"], label="Mask List", interactive=True
306
+ ),
307
+ gradio.Checkbox(value=False, label="Show Bounding Box", interactive=True),
308
+ gradio.Gallery(
309
+ masks, label="Output Gallery", interactive=True, type="numpy"
310
+ ),
311
+ gradio.DataFrame(
312
+ value=bboxes,
313
+ label="Box",
314
+ interactive=False,
315
+ headers=["XMin", "YMin", "XMax", "YMax"],
316
+ type="array",
317
+ ),
318
+ )
319
+
320
+
321
  with gradio.Blocks() as base_app:
322
  gradio.Markdown("# SegmentAnything2Assist")
323
  with gradio.Row():
324
  with gradio.Column():
325
  base_model_choice = gradio.Dropdown(
326
+ [
327
+ "sam2_hiera_large",
328
+ "sam2_hiera_small",
329
+ "sam2_hiera_base_plus",
330
+ "sam2_hiera_tiny",
331
+ ],
332
+ value="sam2_hiera_tiny",
333
+ label="Model Choice",
334
+ )
335
  with gradio.Column():
336
  base_gpu_choice = gradio.Dropdown(
337
+ ["cpu", "cuda"], value="cuda", label="Device Choice"
338
  )
339
+ base_model_choice.change(
340
+ __change_base_model, inputs=[base_model_choice, base_gpu_choice]
341
+ )
342
+ base_gpu_choice.change(
343
+ __change_base_model, inputs=[base_model_choice, base_gpu_choice]
344
+ )
345
+
346
+ # Image Segmentation
347
+ with gradio.Tab(label="Image Segmentation", id="image_tab") as image_tab:
348
+ gradio.Markdown("Image Segmentation", render=True)
349
  with gradio.Column():
350
+ with gradio.Accordion("Image Annotation Documentation", open=False):
351
+ gradio.Markdown(
352
+ """
353
  Image annotation allows you to mark specific regions of an image with labels.
354
  In this app, you can annotate an image by drawing boxes and assigning labels to them.
355
  The labels can be either '+' or '-'.
 
362
  Note that the advanced options allow you to adjust the SAM mask threshold, maximum hole area, and maximum sprinkle area.
363
  These options control the sensitivity and accuracy of the segmentation process.
364
  Experiment with different settings to achieve the desired results.
365
+ """
366
+ )
367
+ image_input = gradio_image_annotation.image_annotator(
368
+ example_image_annotation
369
+ )
370
+ with gradio.Accordion("Advanced Options", open=False):
371
+ image_generate_SAM_mask_threshold = gradio.Slider(
372
+ 0.0, 1.0, 0.0, label="SAM Mask Threshold"
373
+ )
374
+ image_generate_SAM_max_hole_area = gradio.Slider(
375
+ 0, 1000, 0, label="SAM Max Hole Area"
376
+ )
377
+ image_generate_SAM_max_sprinkle_area = gradio.Slider(
378
+ 0, 1000, 0, label="SAM Max Sprinkle Area"
379
+ )
380
  image_generate_mask_button = gradio.Button("Generate Mask")
381
+ with gradio.Row():
382
+ with gradio.Column():
383
+ image_output_mode = gradio.Radio(
384
+ ["Segment", "Mask"], value="Segment", label="Output Mode"
385
+ )
386
+ with gradio.Column(scale=3):
387
+ image_output = gradio_imageslider.ImageSlider()
388
+
389
+ with gradio.Accordion("Debug", open=DEBUG, visible=DEBUG):
390
+ __image_point_coords = gradio.DataFrame(
391
+ value=[["+", 886, 551], ["-", 1239, 576]],
392
+ label="Point Coords",
393
+ interactive=False,
394
+ headers=["Label", "X", "Y"],
395
+ )
396
+ __image_box = gradio.DataFrame(
397
+ value=[[254, 466, 1347, 1047]],
398
+ label="Box",
399
+ interactive=False,
400
+ headers=["XMin", "YMin", "XMax", "YMax"],
401
+ )
402
+ __current_mask = gradio.Image(label="Current Mask", interactive=False)
403
+ __current_segment = gradio.Image(
404
+ label="Current Segment", interactive=False
405
+ )
406
+
407
+ # image_input.change(__post_process_annotator_inputs, inputs = [image_input])
408
+ image_generate_mask_button.click(
409
+ __generate_mask,
410
+ inputs=[
411
+ image_input,
412
+ image_generate_SAM_mask_threshold,
413
+ image_generate_SAM_max_hole_area,
414
+ image_generate_SAM_max_sprinkle_area,
415
+ image_output_mode,
416
+ ],
417
+ outputs=[
418
+ image_output,
419
+ __image_point_coords,
420
+ __image_box,
421
+ __current_mask,
422
+ __current_segment,
423
+ ],
424
+ )
425
+ image_output_mode.change(
426
+ __change_output_mode,
427
+ inputs=[
428
+ image_input,
429
+ image_output_mode,
430
+ __current_mask,
431
+ __current_segment,
432
+ ],
433
+ outputs=[image_output],
434
+ )
435
+
436
+ # Auto Segmentation
437
+ with gradio.Tab(label="Auto Segmentation", id="auto_tab"):
438
+ gradio.Markdown("Auto Segmentation", render=True)
439
  with gradio.Column():
440
+ with gradio.Accordion("Auto Annotation Documentation", open=False):
441
+ gradio.Markdown(
442
+ """
443
+ """
444
+ )
445
  auto_input = gradio.Image("assets/cars.jpg")
446
+ with gradio.Accordion("Advanced Options", open=False):
447
+ auto_generate_SAM_points_per_side = gradio.Slider(
448
+ 1, 64, 12, 1, label="Points Per Side", interactive=True
449
+ )
450
+ auto_generate_SAM_points_per_batch = gradio.Slider(
451
+ 1, 64, 32, 1, label="Points Per Batch", interactive=True
452
+ )
453
+ auto_generate_SAM_pred_iou_thresh = gradio.Slider(
454
+ 0.0, 1.0, 0.8, 1, label="Pred IOU Threshold", interactive=True
455
+ )
456
+ auto_generate_SAM_stability_score_thresh = gradio.Slider(
457
+ 0.0, 1.0, 0.95, label="Stability Score Threshold", interactive=True
458
+ )
459
+ auto_generate_SAM_stability_score_offset = gradio.Slider(
460
+ 0.0, 1.0, 1.0, label="Stability Score Offset", interactive=True
461
+ )
462
+ auto_generate_SAM_mask_threshold = gradio.Slider(
463
+ 0.0, 1.0, 0.0, label="Mask Threshold", interactive=True
464
+ )
465
+ auto_generate_SAM_box_nms_thresh = gradio.Slider(
466
+ 0.0, 1.0, 0.7, label="Box NMS Threshold", interactive=True
467
+ )
468
+ auto_generate_SAM_crop_n_layers = gradio.Slider(
469
+ 0, 10, 0, 1, label="Crop N Layers", interactive=True
470
+ )
471
+ auto_generate_SAM_crop_nms_thresh = gradio.Slider(
472
+ 0.0, 1.0, 0.7, label="Crop NMS Threshold", interactive=True
473
+ )
474
+ auto_generate_SAM_crop_overlay_ratio = gradio.Slider(
475
+ 0.0, 1.0, 512 / 1500, label="Crop Overlay Ratio", interactive=True
476
+ )
477
+ auto_generate_SAM_crop_n_points_downscale_factor = gradio.Slider(
478
+ 1, 10, 1, label="Crop N Points Downscale Factor", interactive=True
479
+ )
480
+ auto_generate_SAM_min_mask_region_area = gradio.Slider(
481
+ 0, 1000, 0, label="Min Mask Region Area", interactive=True
482
+ )
483
+ auto_generate_SAM_use_m2m = gradio.Checkbox(
484
+ label="Use M2M", interactive=True
485
+ )
486
+ auto_generate_SAM_multimask_output = gradio.Checkbox(
487
+ value=True, label="Multi Mask Output", interactive=True
488
+ )
489
  auto_generate_button = gradio.Button("Generate Auto Mask")
490
  with gradio.Row():
491
  with gradio.Column():
492
+ auto_output_mode = gradio.Radio(
493
+ ["Segment", "Mask"],
494
+ value="Segment",
495
+ label="Output Mode",
496
+ interactive=True,
497
+ )
498
+ auto_output_list = gradio.CheckboxGroup(
499
+ [], value=[], label="Mask List", interactive=False
500
+ )
501
+ auto_output_bbox = gradio.Checkbox(
502
+ value=False, label="Show Bounding Box", interactive=False
503
+ )
504
+ with gradio.Column(scale=3):
505
  auto_output = gradio_imageslider.ImageSlider()
506
+ with gradio.Accordion("Debug", open=DEBUG, visible=DEBUG):
507
+ __auto_output_gallery = gradio.Gallery(
508
+ None, label="Output Gallery", interactive=False, type="numpy"
509
+ )
510
+ __auto_bbox = gradio.DataFrame(
511
+ value=[[]],
512
+ label="Box",
513
+ interactive=False,
514
+ headers=["XMin", "YMin", "XMax", "YMax"],
515
+ )
516
+
517
  auto_generate_button.click(
518
+ __generate_auto_mask,
519
+ inputs=[
520
+ auto_input,
521
+ auto_generate_SAM_points_per_side,
522
+ auto_generate_SAM_points_per_batch,
523
+ auto_generate_SAM_pred_iou_thresh,
524
+ auto_generate_SAM_stability_score_thresh,
525
+ auto_generate_SAM_stability_score_offset,
526
+ auto_generate_SAM_mask_threshold,
527
+ auto_generate_SAM_box_nms_thresh,
528
+ auto_generate_SAM_crop_n_layers,
529
+ auto_generate_SAM_crop_nms_thresh,
530
+ auto_generate_SAM_crop_overlay_ratio,
531
+ auto_generate_SAM_crop_n_points_downscale_factor,
532
+ auto_generate_SAM_min_mask_region_area,
533
+ auto_generate_SAM_use_m2m,
534
+ auto_generate_SAM_multimask_output,
535
+ auto_output_mode,
536
  ],
537
+ outputs=[
538
  auto_output,
539
  auto_output_list,
540
+ auto_output_bbox,
541
+ __auto_output_gallery,
542
+ __auto_bbox,
543
+ ],
544
  )
545
+ auto_output_list.change(
546
+ __generate_multi_mask_output,
547
+ inputs=[
548
+ auto_input,
549
+ auto_output_list,
550
+ auto_output_mode,
551
+ auto_output_bbox,
552
+ __auto_output_gallery,
553
+ __auto_bbox,
554
+ ],
555
+ outputs=[auto_output],
556
+ )
557
+ auto_output_bbox.change(
558
+ __generate_multi_mask_output,
559
+ inputs=[
560
+ auto_input,
561
+ auto_output_list,
562
+ auto_output_mode,
563
+ auto_output_bbox,
564
+ __auto_output_gallery,
565
+ __auto_bbox,
566
+ ],
567
+ outputs=[auto_output],
568
+ )
569
+ auto_output_mode.change(
570
+ __generate_multi_mask_output,
571
+ inputs=[
572
+ auto_input,
573
+ auto_output_list,
574
+ auto_output_mode,
575
+ auto_output_bbox,
576
+ __auto_output_gallery,
577
+ __auto_bbox,
578
+ ],
579
+ outputs=[auto_output],
580
+ )
581
+
582
+
583
  if __name__ == "__main__":
584
  base_app.launch()
 
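A note on `__generate_multi_mask_output` above: because the mask list now round-trips through a `gradio.Gallery`, the value read back per item can be an `(image, caption)` tuple rather than a bare array, hence the tuple check at the top of the handler. A standalone sketch of that normalization, with made-up data:

import numpy

def normalize_gallery(masks):
    # Values read back from a gradio.Gallery may arrive as (image, caption)
    # tuples; unwrap them so downstream code sees plain arrays.
    if masks and isinstance(masks[0], tuple):
        masks = [mask[0] for mask in masks]
    return masks

fake_mask = numpy.zeros((4, 4, 3), dtype=numpy.uint8)
assert normalize_gallery([(fake_mask, "mask 0")])[0] is fake_mask
assert normalize_gallery([fake_mask])[0] is fake_mask
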
src/SegmentAnything2Assist.py CHANGED
@@ -14,204 +14,253 @@ import cv2
14
 
15
  SAM2_MODELS = {
16
  "sam2_hiera_tiny": {
17
- "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt",
18
- "model_path": ".tmp/checkpoints/sam2_hiera_tiny.pt",
19
- "config_file": "sam2_hiera_t.yaml"
20
  },
21
  "sam2_hiera_small": {
22
- "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
23
- "model_path": ".tmp/checkpoints/sam2_hiera_small.pt",
24
- "config_file": "sam2_hiera_s.yaml"
25
  },
26
  "sam2_hiera_base_plus": {
27
- "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
28
- "model_path": ".tmp/checkpoints/sam2_hiera_base_plus.pt",
29
- "config_file": "sam2_hiera_b+.yaml"
30
  },
31
  "sam2_hiera_large": {
32
- "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
33
- "model_path": ".tmp/checkpoints/sam2_hiera_large.pt",
34
- "config_file": "sam2_hiera_l.yaml"
35
  },
36
  }
37
-
 
38
  class SegmentAnything2Assist:
39
- def __init__(
40
- self,
41
- model_name: str | typing.Literal["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_base_plus", "sam2_hiera_large"] = "sam2_hiera_small",
42
- configuration: str |typing.Literal["Automatic Mask Generator", "Image"] = "Automatic Mask Generator",
43
- download_url: str | None = None,
44
- model_path: str | None = None,
45
- download: bool = True,
46
- device: str | torch.device = torch.device("cpu"),
47
- verbose: bool = True
48
- ) -> None:
49
- assert model_name in SAM2_MODELS.keys(), f"`model_name` should be either one of {list(SAM2_MODELS.keys())}"
50
- assert configuration in ["Automatic Mask Generator", "Image"]
51
-
52
- self.model_name = model_name
53
- self.configuration = configuration
54
- self.config_file = SAM2_MODELS[model_name]["config_file"]
55
- self.device = device
56
-
57
- self.download_url = download_url if download_url is not None else SAM2_MODELS[model_name]["download_url"]
58
- self.model_path = model_path if model_path is not None else SAM2_MODELS[model_name]["model_path"]
59
- os.makedirs(os.path.dirname(self.model_path), exist_ok = True)
60
- self.verbose = verbose
61
-
62
- if self.verbose:
63
- print(f"SegmentAnything2Assist::__init__::Model Name: {self.model_name}")
64
- print(f"SegmentAnything2Assist::__init__::Configuration: {self.configuration}")
65
- print(f"SegmentAnything2Assist::__init__::Download URL: {self.download_url}")
66
- print(f"SegmentAnything2Assist::__init__::Default Path: {self.model_path}")
67
- print(f"SegmentAnything2Assist::__init__::Configuration File: {self.config_file}")
68
-
69
- if download:
70
- self.download_model()
71
-
72
- if self.is_model_available():
73
- self.sam2 = sam2.build_sam.build_sam2(config_file = self.config_file, ckpt_path = self.model_path, device = self.device)
74
- if self.verbose:
75
- print("SegmentAnything2Assist::__init__::SAM2 is loaded.")
76
- else:
77
- self.sam2 = None
78
- if self.verbose:
79
- print("SegmentAnything2Assist::__init__::SAM2 is not loaded.")
80
-
81
-
82
- def is_model_available(self) -> bool:
83
- ret = os.path.exists(self.model_path)
84
- if self.verbose:
85
- print(f"SegmentAnything2Assist::is_model_available::{ret}")
86
- return ret
87
-
88
- def load_model(self) -> None:
89
- if self.is_model_available():
90
- self.sam2 = sam2.build_sam(checkpoint = self.model_path)
91
-
92
- def download_model(
93
- self,
94
- force: bool = False
95
- ) -> None:
96
- if not force and self.is_model_available():
97
- print(f"{self.model_path} already exists. Skipping download.")
98
- return
99
-
100
- response = requests.get(self.download_url, stream=True)
101
- total_size = int(response.headers.get('content-length', 0))
102
-
103
- with open(self.model_path, 'wb') as file, tqdm.tqdm(total = total_size, unit = 'B', unit_scale = True) as progress_bar:
104
- for data in response.iter_content(chunk_size = 1024):
105
- file.write(data)
106
- progress_bar.update(len(data))
107
-
108
- def generate_automatic_masks(
109
- self,
110
- image,
111
- points_per_side = 32,
112
- points_per_batch = 32,
113
- pred_iou_thresh = 0.8,
114
- stability_score_thresh = 0.95,
115
- stability_score_offset = 1.0,
116
- mask_threshold = 0.0,
117
- box_nms_thresh = 0.7,
118
- crop_n_layers = 0,
119
- crop_nms_thresh = 0.7,
120
- crop_overlay_ratio = 512 / 1500,
121
- crop_n_points_downscale_factor = 1,
122
- min_mask_region_area = 0,
123
- use_m2m = False,
124
- multimask_output = True
125
- ):
126
- if self.sam2 is None:
127
- print("SegmentAnything2Assist::generate_automatic_masks::SAM2 is not loaded.")
128
- return None
129
-
130
- generator = sam2.automatic_mask_generator.SAM2AutomaticMaskGenerator(
131
- model = self.sam2,
132
- points_per_side = points_per_side,
133
- points_per_batch = points_per_batch,
134
- pred_iou_thresh = pred_iou_thresh,
135
- stability_score_thresh = stability_score_thresh,
136
- stability_score_offset = stability_score_offset,
137
- mask_threshold = mask_threshold,
138
- box_nms_thresh = box_nms_thresh,
139
- crop_n_layers = crop_n_layers,
140
- crop_nms_thresh = crop_nms_thresh,
141
- crop_overlay_ratio = crop_overlay_ratio,
142
- crop_n_points_downscale_factor = crop_n_points_downscale_factor,
143
- min_mask_region_area = min_mask_region_area,
144
- use_m2m = use_m2m,
145
- multimask_output = multimask_output
146
- )
147
- masks = generator.generate(image)
148
-
149
- pickle.dump(masks, open(".tmp/auto_masks.pkl", "wb"))
150
-
151
- return masks
152
-
153
- def generate_masks_from_image(
154
- self,
155
- image,
156
- point_coords,
157
- point_labels,
158
- box,
159
- mask_threshold = 0.0,
160
- max_hole_area = 0.0,
161
- max_sprinkle_area = 0.0
162
- ):
163
- generator = sam2.sam2_image_predictor.SAM2ImagePredictor(
164
- self.sam2,
165
- mask_threshold = mask_threshold,
166
- max_hole_area = max_hole_area,
167
- max_sprinkle_area = max_sprinkle_area
168
- )
169
- generator.set_image(image)
170
-
171
- masks_chw, mask_iou, mask_low_logits = generator.predict(
172
- point_coords = numpy.array(point_coords) if point_coords is not None else None,
173
- point_labels = numpy.array(point_labels) if point_labels is not None else None,
174
- box = numpy.array(box) if box is not None else None,
175
- multimask_output = False
176
- )
177
-
178
- return masks_chw, mask_iou
179
-
180
- def apply_mask_to_image(
181
- self,
182
- image,
183
- mask
184
- ):
185
- mask = numpy.array(mask)
186
- mask = numpy.where(mask > 0, 255, 0).astype(numpy.uint8)
187
- segment = cv2.bitwise_and(image, image, mask = mask)
188
- return mask, segment
189
-
190
- def apply_auto_mask_to_image(
191
- self,
192
- image,
193
- auto_list
194
- ):
195
- if not os.path.exists(".tmp/auto_masks.pkl"):
196
- return
197
-
198
- masks = pickle.load(open(".tmp/auto_masks.pkl", "rb"))
199
-
200
- image_with_bounding_boxes = image.copy()
201
- all_masks = None
202
- for _ in auto_list:
203
- mask = numpy.array(masks[_]['segmentation'])
204
- mask = numpy.where(mask == True, 255, 0).astype(numpy.uint8)
205
- bbox = masks[_]["bbox"]
206
- if all_masks is None:
207
- all_masks = mask
208
- else:
209
- all_masks = cv2.bitwise_or(all_masks, mask)
210
-
211
- random_color = numpy.random.randint(0, 255, size = 3)
212
- image_with_bounding_boxes = cv2.rectangle(image_with_bounding_boxes, (int(bbox[0]), int(bbox[1])), (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3])), random_color.tolist(), 2)
213
- image_with_bounding_boxes = cv2.putText(image_with_bounding_boxes, f"{_ + 1}", (int(bbox[0]), int(bbox[1]) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, random_color.tolist(), 2)
214
-
215
- all_masks = numpy.where(all_masks > 0, 255, 0).astype(numpy.uint8)
216
- image_with_segments = cv2.bitwise_and(image, image, mask = all_masks)
217
- return image_with_bounding_boxes, all_masks, image_with_segments
14
 
15
  SAM2_MODELS = {
16
  "sam2_hiera_tiny": {
17
+ "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt",
18
+ "model_path": ".tmp/checkpoints/sam2_hiera_tiny.pt",
19
+ "config_file": "sam2_hiera_t.yaml",
20
  },
21
  "sam2_hiera_small": {
22
+ "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
23
+ "model_path": ".tmp/checkpoints/sam2_hiera_small.pt",
24
+ "config_file": "sam2_hiera_s.yaml",
25
  },
26
  "sam2_hiera_base_plus": {
27
+ "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
28
+ "model_path": ".tmp/checkpoints/sam2_hiera_base_plus.pt",
29
+ "config_file": "sam2_hiera_b+.yaml",
30
  },
31
  "sam2_hiera_large": {
32
+ "download_url": "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
33
+ "model_path": ".tmp/checkpoints/sam2_hiera_large.pt",
34
+ "config_file": "sam2_hiera_l.yaml",
35
  },
36
  }
37
+
38
+
39
  class SegmentAnything2Assist:
40
+ def __init__(
41
+ self,
42
+ model_name: (
43
+ str
44
+ | typing.Literal[
45
+ "sam2_hiera_tiny",
46
+ "sam2_hiera_small",
47
+ "sam2_hiera_base_plus",
48
+ "sam2_hiera_large",
49
+ ]
50
+ ) = "sam2_hiera_small",
51
+ configuration: (
52
+ str | typing.Literal["Automatic Mask Generator", "Image"]
53
+ ) = "Automatic Mask Generator",
54
+ download_url: str | None = None,
55
+ model_path: str | None = None,
56
+ download: bool = True,
57
+ device: str | torch.device = torch.device("cpu"),
58
+ verbose: bool = True,
59
+ ) -> None:
60
+ assert (
61
+ model_name in SAM2_MODELS.keys()
62
+ ), f"`model_name` should be either one of {list(SAM2_MODELS.keys())}"
63
+ assert configuration in ["Automatic Mask Generator", "Image"]
64
+
65
+ self.model_name = model_name
66
+ self.configuration = configuration
67
+ self.config_file = SAM2_MODELS[model_name]["config_file"]
68
+ self.device = device
69
+
70
+ self.download_url = (
71
+ download_url
72
+ if download_url is not None
73
+ else SAM2_MODELS[model_name]["download_url"]
74
+ )
75
+ self.model_path = (
76
+ model_path
77
+ if model_path is not None
78
+ else SAM2_MODELS[model_name]["model_path"]
79
+ )
80
+ os.makedirs(os.path.dirname(self.model_path), exist_ok=True)
81
+ self.verbose = verbose
82
+
83
+ if self.verbose:
84
+ print(f"SegmentAnything2Assist::__init__::Model Name: {self.model_name}")
85
+ print(
86
+ f"SegmentAnything2Assist::__init__::Configuration: {self.configuration}"
87
+ )
88
+ print(
89
+ f"SegmentAnything2Assist::__init__::Download URL: {self.download_url}"
90
+ )
91
+ print(f"SegmentAnything2Assist::__init__::Default Path: {self.model_path}")
92
+ print(
93
+ f"SegmentAnything2Assist::__init__::Configuration File: {self.config_file}"
94
+ )
95
+
96
+ if download:
97
+ self.download_model()
98
+
99
+ if self.is_model_available():
100
+ self.sam2 = sam2.build_sam.build_sam2(
101
+ config_file=self.config_file,
102
+ ckpt_path=self.model_path,
103
+ device=self.device,
104
+ )
105
+ if self.verbose:
106
+ print("SegmentAnything2Assist::__init__::SAM2 is loaded.")
107
+ else:
108
+ self.sam2 = None
109
+ if self.verbose:
110
+ print("SegmentAnything2Assist::__init__::SAM2 is not loaded.")
111
+
112
+ def is_model_available(self) -> bool:
113
+ ret = os.path.exists(self.model_path)
114
+ if self.verbose:
115
+ print(f"SegmentAnything2Assist::is_model_available::{ret}")
116
+ return ret
117
+
118
+ def load_model(self) -> None:
119
+ if self.is_model_available():
120
+ self.sam2 = sam2.build_sam(checkpoint=self.model_path)
121
+
122
+ def download_model(self, force: bool = False) -> None:
123
+ if not force and self.is_model_available():
124
+ print(f"{self.model_path} already exists. Skipping download.")
125
+ return
126
+
127
+ response = requests.get(self.download_url, stream=True)
128
+ total_size = int(response.headers.get("content-length", 0))
129
+
130
+ with open(self.model_path, "wb") as file, tqdm.tqdm(
131
+ total=total_size, unit="B", unit_scale=True
132
+ ) as progress_bar:
133
+ for data in response.iter_content(chunk_size=1024):
134
+ file.write(data)
135
+ progress_bar.update(len(data))
136
+
137
+ def generate_automatic_masks(
138
+ self,
139
+ image,
140
+ points_per_side=32,
141
+ points_per_batch=32,
142
+ pred_iou_thresh=0.8,
143
+ stability_score_thresh=0.95,
144
+ stability_score_offset=1.0,
145
+ mask_threshold=0.0,
146
+ box_nms_thresh=0.7,
147
+ crop_n_layers=0,
148
+ crop_nms_thresh=0.7,
149
+ crop_overlay_ratio=512 / 1500,
150
+ crop_n_points_downscale_factor=1,
151
+ min_mask_region_area=0,
152
+ use_m2m=False,
153
+ multimask_output=True,
154
+ ):
155
+ if self.sam2 is None:
156
+ print(
157
+ "SegmentAnything2Assist::generate_automatic_masks::SAM2 is not loaded."
158
+ )
159
+ return None
160
+
161
+ generator = sam2.automatic_mask_generator.SAM2AutomaticMaskGenerator(
162
+ model=self.sam2,
163
+ points_per_side=points_per_side,
164
+ points_per_batch=points_per_batch,
165
+ pred_iou_thresh=pred_iou_thresh,
166
+ stability_score_thresh=stability_score_thresh,
167
+ stability_score_offset=stability_score_offset,
168
+ mask_threshold=mask_threshold,
169
+ box_nms_thresh=box_nms_thresh,
170
+ crop_n_layers=crop_n_layers,
171
+ crop_nms_thresh=crop_nms_thresh,
172
+ crop_overlay_ratio=crop_overlay_ratio,
173
+ crop_n_points_downscale_factor=crop_n_points_downscale_factor,
174
+ min_mask_region_area=min_mask_region_area,
175
+ use_m2m=use_m2m,
176
+ multimask_output=multimask_output,
177
+ )
178
+ masks = generator.generate(image)
179
+ segmentation_masks = [mask for mask in masks]
180
+ segmentation_masks = [
181
+ numpy.where(mask["segmentation"] == True, 255, 0).astype(numpy.uint8)
182
+ for mask in segmentation_masks
183
+ ]
184
+ segmentation_masks = [
185
+ cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) for mask in segmentation_masks
186
+ ]
187
+ bbox_masks = [mask["bbox"] for mask in masks]
188
+
189
+ return masks, segmentation_masks, bbox_masks
190
+
191
+ def generate_masks_from_image(
192
+ self,
193
+ image,
194
+ point_coords,
195
+ point_labels,
196
+ box,
197
+ mask_threshold=0.0,
198
+ max_hole_area=0.0,
199
+ max_sprinkle_area=0.0,
200
+ ):
201
+ generator = sam2.sam2_image_predictor.SAM2ImagePredictor(
202
+ self.sam2,
203
+ mask_threshold=mask_threshold,
204
+ max_hole_area=max_hole_area,
205
+ max_sprinkle_area=max_sprinkle_area,
206
+ )
207
+ generator.set_image(image)
208
+
209
+ masks_chw, mask_iou, mask_low_logits = generator.predict(
210
+ point_coords=(
211
+ numpy.array(point_coords) if point_coords is not None else None
212
+ ),
213
+ point_labels=(
214
+ numpy.array(point_labels) if point_labels is not None else None
215
+ ),
216
+ box=numpy.array(box) if box is not None else None,
217
+ multimask_output=False,
218
+ )
219
+
220
+ return masks_chw, mask_iou
221
+
222
+ def apply_mask_to_image(self, image, mask):
223
+ mask = numpy.array(mask)
224
+ mask = numpy.where(mask > 0, 255, 0).astype(numpy.uint8)
225
+ segment = cv2.bitwise_and(image, image, mask=mask)
226
+ return mask, segment
227
+
228
+ def apply_auto_mask_to_image(self, image, auto_list, masks, bboxes):
229
+ image_with_bounding_boxes = image.copy()
230
+ all_masks = None
231
+
232
+ cv2.imwrite(".tmp/mask_2.png", masks[3])
233
+
234
+ for _ in auto_list:
235
+ mask = masks[_]
236
+ mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
237
+
238
+ bbox = bboxes[_]
239
+ if all_masks is None:
240
+ all_masks = mask
241
+ else:
242
+ all_masks = cv2.bitwise_or(all_masks, mask)
243
+
244
+ cv2.imwrite(".tmp/mask_3.png", masks[3])
245
+
246
+ random_color = numpy.random.randint(0, 255, size=3)
247
+ image_with_bounding_boxes = cv2.rectangle(
248
+ image_with_bounding_boxes,
249
+ (int(bbox[0]), int(bbox[1])),
250
+ (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3])),
251
+ random_color.tolist(),
252
+ 2,
253
+ )
254
+ image_with_bounding_boxes = cv2.putText(
255
+ image_with_bounding_boxes,
256
+ f"{_ + 1}",
257
+ (int(bbox[0]), int(bbox[1]) - 10),
258
+ cv2.FONT_HERSHEY_SIMPLEX,
259
+ 0.5,
260
+ random_color.tolist(),
261
+ 2,
262
+ )
263
+
264
+ all_masks = all_masks.astype(numpy.uint8)
265
+ image_with_segments = cv2.bitwise_and(image, image, mask=all_masks)
266
+ return image_with_bounding_boxes, all_masks, image_with_segments
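
With the refactor, the automatic-mask state flows through return values instead of the old `.tmp/auto_masks.pkl` pickle: `generate_automatic_masks` now hands back the raw SAM2 records plus ready-to-display segmentation masks and bounding boxes, and `apply_auto_mask_to_image` consumes them explicitly. A hedged usage sketch; the input path is a placeholder, and it assumes the checkpoint download succeeds, a `.tmp/` directory exists, and at least four masks are found (the method still writes `masks[3]` to `.tmp/` as a debug artifact):

import cv2
import torch
import src.SegmentAnything2Assist as SegmentAnything2Assist

assist = SegmentAnything2Assist.SegmentAnything2Assist(
    model_name="sam2_hiera_tiny", device=torch.device("cpu")
)

image = cv2.imread("assets/cars.jpg")  # placeholder input image

# Raw SAM2 records, BGR segmentation masks, and per-mask [x, y, w, h] boxes.
masks, segmentation_masks, bbox_masks = assist.generate_automatic_masks(image)

# Masks and boxes are passed in explicitly; indices pick which masks to merge.
image_with_bboxes, merged_mask, segments = assist.apply_auto_mask_to_image(
    image, [0], segmentation_masks, bbox_masks
)
cv2.imwrite("auto_mask_overview.png", image_with_bboxes)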