xqt committed
Commit ed1b7ea • 1 Parent(s): 2e4869a

UPD: Added a YOLOv10 assisted Segmentation Mode.

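Below is a minimal sketch of driving the new YOLOv10 assisted path from a script, using the class, method, and argument names introduced in the diffs that follow. The import path and the printed fields are assumptions based on the repository layout; the sample image is the one bundled with the app.

```python
# Sketch only: run YOLOv10 detection, then prompt SAM 2 with each detected box.
# Class, method, and argument names come from this commit; the import path and
# the loop body are illustrative assumptions.
import cv2
import torch

from src import SegmentAnything2Assist

assist = SegmentAnything2Assist.SegmentAnything2Assist(
    sam_model_name="sam2_hiera_tiny", device=torch.device("cuda")
)

image = cv2.imread("assets/cars.jpg")  # sample image used by the new tab
results = assist.generate_mask_from_image_with_yolo(image, YOLOv10ModelName="nano")

for result in results:
    # Each entry pairs a YOLOv10 detection with the SAM 2 mask generated from its box.
    print(result["name"], result["confidence"], result["mask_iou"])
```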
SegmentAnything2AssistApp.py CHANGED
@@ -50,11 +50,11 @@ DEBUG = False
 
 
 segment_anything2assist = SegmentAnything2Assist.SegmentAnything2Assist(
-    model_name="sam2_hiera_tiny", device=torch.device("cuda")
+    sam_model_name="sam2_hiera_tiny", device=torch.device("cuda")
 )
 
 
-def __change_base_model(model_name, device):
+def change_base_model(model_name, device):
     global segment_anything2assist
     gradio.Info(f"Changing model to {model_name} on {device}", duration=3)
     try:
@@ -113,7 +113,7 @@ def __post_process_annotator_inputs(value):
 
 
 @spaces.GPU(duration=60)
-def __generate_mask(
+def generate_image_mask(
     value,
     mask_threshold,
     max_hole_area,
@@ -128,7 +128,7 @@ def __generate_mask(
     )
 
     if VERBOSE:
-        print("SegmentAnything2AssistApp::__generate_mask::Called.")
+        print("SegmentAnything2AssistApp::generate_image_mask::Called.")
     mask_chw, mask_iou = segment_anything2assist.generate_masks_from_image(
         value["image"],
         image_point_coords,
@@ -140,14 +140,16 @@ def __generate_mask(
     )
 
     if VERBOSE:
-        print("SegmentAnything2AssistApp::__generate_mask::Masks generated.")
+        print("SegmentAnything2AssistApp::generate_image_mask::Masks generated.")
 
     __current_mask, __current_segment = segment_anything2assist.apply_mask_to_image(
         value["image"], mask_chw[0]
     )
 
     if VERBOSE:
-        print("SegmentAnything2AssistApp::__generate_mask::Masks and Segments created.")
+        print(
+            "SegmentAnything2AssistApp::generate_image_mask::Masks and Segments created."
+        )
 
     __image_box = gradio.DataFrame(value=[[]])
     __image_point_coords = gradio.DataFrame(value=[[]])
@@ -201,9 +203,9 @@ def __generate_mask(
     )
 
 
-def __change_output_mode(image_input, radio, __current_mask, __current_segment):
+def on_image_output_mode_change(image_input, radio, __current_mask, __current_segment):
     if VERBOSE:
-        print("SegmentAnything2AssistApp::__generate_mask::Called.")
+        print("SegmentAnything2AssistApp::generate_image_mask::Called.")
     if __current_mask is None or __current_segment is None:
         gradio.Warning("Configuration was changed, generate the mask again", duration=5)
         return gradio_imageslider.ImageSlider(render=True)
@@ -216,9 +218,7 @@ def __change_output_mode(image_input, radio, __current_mask, __current_segment):
     return gradio_imageslider.ImageSlider(render=True)
 
 
-def __generate_multi_mask_output(
-    image, auto_list, auto_mode, auto_bbox_mode, masks, bboxes
-):
+def __generate_auto_mask(image, auto_list, auto_mode, auto_bbox_mode, masks, bboxes):
     global segment_anything2assist
 
     # When value from gallery is called, it is a tuple
@@ -235,7 +235,7 @@ def __generate_multi_mask_output(
 
 
 @spaces.GPU(duration=60)
-def __generate_auto_mask(
+def generate_auto_mask(
     image,
     points_per_side,
     points_per_batch,
@@ -255,7 +255,7 @@ def __generate_auto_mask(
 ):
     global segment_anything2assist
     if VERBOSE:
-        print("SegmentAnything2AssistApp::__generate_auto_mask::Called.")
+        print("SegmentAnything2AssistApp::generate_auto_mask::Called.")
 
     __auto_masks, masks, bboxes = segment_anything2assist.generate_automatic_masks(
         image,
@@ -296,7 +296,7 @@ def __generate_auto_mask(
     else:
         choices = [str(i) for i in range(len(__auto_masks))]
 
-    returning_image = __generate_multi_mask_output(
+    returning_image = __generate_auto_mask(
         image, ["0"], output_mode, False, masks, bboxes
     )
     return (
@@ -318,6 +318,74 @@ def __generate_auto_mask(
     )
 
 
+def __generate_yolo_mask(
+    image,
+    yolo_mask,
+    output_mode,
+):
+    global segment_anything2assist
+    if VERBOSE:
+        print("SegmentAnything2AssistApp::generate_yolo_mask::Called.")
+
+    mask = yolo_mask[4]
+
+    if output_mode == "Mask":
+        return [image, mask]
+
+    mask, output_image = segment_anything2assist.apply_mask_to_image(image, mask)
+
+    if output_mode == "Segment":
+        return [image, output_image]
+
+
+@spaces.GPU(duration=60)
+def generate_yolo_mask(
+    image,
+    yolo_model_choice,
+    mask_threshold,
+    max_hole_area,
+    max_sprinkle_area,
+    output_mode,
+):
+    global segment_anything2assist
+    if VERBOSE:
+        print("SegmentAnything2AssistApp::generate_yolo_mask::Called.")
+
+    results = segment_anything2assist.generate_mask_from_image_with_yolo(
+        image,
+        YOLOv10ModelName=yolo_model_choice,
+        mask_threshold=mask_threshold,
+        max_hole_area=max_hole_area,
+        max_sprinkle_area=max_sprinkle_area,
+    )
+
+    if len(results) > 0:
+        if VERBOSE:
+            print("SegmentAnything2AssistApp::generate_yolo_mask::Masks generated.")
+
+        yolo_masks = []
+        for result in results:
+            yolo_mask = [
+                result["name"],
+                result["class"],
+                result["confidence"],
+                [result["box"]],
+                result["mask_chw"],
+                result["mask_iou"][0].item(),
+            ]
+            yolo_masks.append(yolo_mask)
+
+        return __generate_yolo_mask(image, yolo_masks[0], output_mode), gradio.Dataset(
+            label="YOLOv10 Assisted Masks", type="values", samples=yolo_masks
+        )
+
+    else:
+        if VERBOSE:
+            print("SegmentAnything2AssistApp::generate_yolo_mask::No masks generated.")
+
+        return gradio.ImageSlider(), gradio.Dataset()
+
+
 with gradio.Blocks() as base_app:
     gradio.Markdown(
         """
@@ -342,14 +410,14 @@ with gradio.Blocks() as base_app:
         ["cpu", "cuda"], value="cuda", label="Device Choice"
     )
     base_model_choice.change(
-        __change_base_model, inputs=[base_model_choice, base_gpu_choice]
+        change_base_model, inputs=[base_model_choice, base_gpu_choice]
    )
     base_gpu_choice.change(
-        __change_base_model, inputs=[base_model_choice, base_gpu_choice]
+        change_base_model, inputs=[base_model_choice, base_gpu_choice]
    )
 
    # Image Segmentation
-    with gradio.Tab(label="Image Segmentation", id="image_tab") as image_tab:
+    with gradio.Tab(label="🌆 Image Segmentation", id="image_tab") as image_tab:
        gradio.Markdown("Image Segmentation", render=True)
        with gradio.Column():
            with gradio.Accordion("Image Annotation Documentation", open=False):
@@ -419,7 +487,7 @@ with gradio.Blocks() as base_app:
 
        # image_input.change(__post_process_annotator_inputs, inputs = [image_input])
        image_generate_mask_button.click(
-            __generate_mask,
+            generate_image_mask,
            inputs=[
                image_input,
                image_generate_SAM_mask_threshold,
@@ -436,7 +504,7 @@ with gradio.Blocks() as base_app:
            ],
        )
        image_output_mode.change(
-            __change_output_mode,
+            on_image_output_mode_change,
            inputs=[
                image_input,
                image_output_mode,
@@ -447,7 +515,7 @@ with gradio.Blocks() as base_app:
        )
 
    # Auto Segmentation
-    with gradio.Tab(label="Auto Segmentation", id="auto_tab"):
+    with gradio.Tab(label="🤖 Auto Segmentation", id="auto_tab"):
        gradio.Markdown("Auto Segmentation", render=True)
        with gradio.Column():
            with gradio.Accordion("Auto Annotation Documentation", open=False):
@@ -558,7 +626,7 @@ with gradio.Blocks() as base_app:
        )
 
        auto_generate_button.click(
-            __generate_auto_mask,
+            generate_auto_mask,
            inputs=[
                auto_input,
                auto_generate_SAM_points_per_side,
@@ -586,7 +654,7 @@ with gradio.Blocks() as base_app:
            ],
        )
        auto_output_list.change(
-            __generate_multi_mask_output,
+            __generate_auto_mask,
            inputs=[
                auto_input,
                auto_output_list,
@@ -598,7 +666,7 @@ with gradio.Blocks() as base_app:
            outputs=[auto_output],
        )
        auto_output_bbox.change(
-            __generate_multi_mask_output,
+            __generate_auto_mask,
            inputs=[
                auto_input,
                auto_output_list,
@@ -610,7 +678,7 @@ with gradio.Blocks() as base_app:
            outputs=[auto_output],
        )
        auto_output_mode.change(
-            __generate_multi_mask_output,
+            __generate_auto_mask,
            inputs=[
                auto_input,
                auto_output_list,
@@ -622,6 +690,121 @@ with gradio.Blocks() as base_app:
            outputs=[auto_output],
        )
 
+    # YOLOv10 assisted Segmentation.
+    with gradio.Tab("🤙 YOLOv10 assisted Segmentation"):
+        gradio.Markdown("YOLOv10 assisted Segmentation")
+        with gradio.Column():
+            with gradio.Accordion("YOLOv10 Documentation", open=False):
+                gradio.Markdown(
+                    """
+                    ### 🖼️ YOLOv10 Assisted Segmentation Documentation
+
+                    YOLOv10 assisted segmentation allows you to generate masks for an image using the YOLOv10 model.
+                    In this app, you can configure various settings to control the mask generation process.
+
+                    **📝 How to Use YOLOv10 Assisted Segmentation:**
+                    - Upload or select an image.
+                    - Choose the desired YOLOv10 model from the dropdown.
+                    - Adjust the advanced settings to fine-tune the mask generation process.
+                    - Click the 'Generate YOLOv10 Mask' button to generate masks.
+
+                    **⚙️ Advanced Settings:**
+                    - **SAM Mask Threshold:** Threshold for the SAM mask generation.
+                    - **Max Hole Area:** Maximum area for holes in the mask.
+                    - **Max Sprinkle Area:** Maximum area for sprinkled regions in the mask.
+
+                    **🎨 Generating Masks:**
+                    - Once you have configured the settings, click the 'Generate YOLOv10 Mask' button.
+                    - The masks will be generated based on the selected parameters.
+                    - You can view the generated masks and adjust the settings if needed.
+                    """
+                )
+
+            yolo_input = gradio.Image("assets/cars.jpg")
+            yolo_model_choice = gradio.Dropdown(
+                choices=["nano", "small", "medium", "base", "large", "xlarge"],
+                value="nano",
+                label="YOLOv10 Model Choice",
+            )
+            with gradio.Accordion("Advanced Settings", open=False):
+                yolo_generate_SAM_mask_threshold = gradio.Slider(
+                    0.0, 1.0, 0.0, label="SAM Mask Threshold"
+                )
+                yolo_generate_SAM_max_hole_area = gradio.Slider(
+                    0, 1000, 0, label="SAM Max Hole Area"
+                )
+                yolo_generate_SAM_max_sprinkle_area = gradio.Slider(
+                    0, 1000, 0, label="SAM Max Sprinkle Area"
+                )
+
+            yolo_generate_mask_button = gradio.Button("Generate YOLOv10 Mask")
+            with gradio.Row():
+                with gradio.Column():
+                    yolo_output_mode = gradio.Radio(
+                        ["Segment", "Mask"], value="Segment", label="Output Mode"
+                    )
+                with gradio.Column(scale=3):
+                    yolo_output = gradio_imageslider.ImageSlider()
+
+            with gradio.Accordion("Debug 1", open=DEBUG, visible=DEBUG):
+                __yolo_name = gradio.Textbox(
+                    label="Name", interactive=DEBUG, visible=DEBUG
+                )
+                __yolo_class = gradio.Number(
+                    label="Class", interactive=DEBUG, visible=DEBUG
+                )
+                __yolo_confidence = gradio.Number(
+                    label="Confidence", interactive=DEBUG, visible=DEBUG
+                )
+                __yolo_box = gradio.DataFrame(
+                    value=[[1, 2, 3, 4]], label="Box", interactive=DEBUG, visible=DEBUG
+                )
+                __yolo_mask = gradio.Image(
+                    label="Mask", interactive=DEBUG, visible=DEBUG
+                )
+                __yolo_mask_iou = gradio.Number(
+                    label="Mask IOU", interactive=DEBUG, visible=DEBUG
+                )
+
+            with gradio.Row():
+                yolo_masks = gradio.Dataset(
+                    label="YOLOv10 Assisted Masks",
+                    type="values",
+                    components=[
+                        __yolo_name,
+                        __yolo_class,
+                        __yolo_confidence,
+                        __yolo_box,
+                        __yolo_mask,
+                        __yolo_mask_iou,
+                    ],
+                )
+
+        yolo_generate_mask_button.click(
+            generate_yolo_mask,
+            inputs=[
+                yolo_input,
+                yolo_model_choice,
+                yolo_generate_SAM_mask_threshold,
+                yolo_generate_SAM_max_hole_area,
+                yolo_generate_SAM_max_sprinkle_area,
+                yolo_output_mode,
+            ],
+            outputs=[yolo_output, yolo_masks],
+        )
+
+        yolo_masks.click(
+            __generate_yolo_mask,
+            inputs=[yolo_input, yolo_masks, yolo_output_mode],
+            outputs=[yolo_output],
+        )
+
+        yolo_output_mode.change(
+            __generate_yolo_mask,
+            inputs=[yolo_input, yolo_masks, yolo_output_mode],
+            outputs=[yolo_output],
+        )
+
 
 if __name__ == "__main__":
     base_app.launch()
requirements.txt CHANGED
@@ -41,6 +41,7 @@ pandas==2.2.2
 pillow==10.4.0
 portalocker==2.10.1
 psutil==5.9.8
+py-cpuinfo==9.0.0
 pydantic==2.8.2
 pydantic_core==2.20.1
 pydub==0.25.1
@@ -53,7 +54,10 @@ PyYAML==6.0.2
 requests==2.32.3
 rich==13.7.1
 ruff==0.6.2
+safetensors==0.4.5
 SAM-2 @ git+https://github.com/facebookresearch/segment-anything-2.git@7e1596c0b6462eb1d1ba7e1492430fed95023598
+scipy==1.14.1
+seaborn==0.13.2
 semantic-version==2.10.0
 setuptools==73.0.1
 shellingham==1.5.4
@@ -62,11 +66,13 @@ sniffio==1.3.1
 spaces==0.29.3
 starlette==0.38.2
 sympy==1.13.2
+thop==0.1.1.post2209072238
 tomlkit==0.12.0
 tqdm==4.66.5
 typer==0.12.5
 typing_extensions==4.12.2
 tzdata==2024.1
+ultralytics @ git+https://github.com/THU-MIG/yolov10.git@cd2f79c70299c9041fb6d19617ef1296f47575b1
 urllib3==2.2.2
 uvicorn==0.30.6
 websockets==12.0
src/SegmentAnything2Assist.py CHANGED
@@ -10,6 +10,8 @@ import pickle
 import sam2.build_sam
 import sam2.automatic_mask_generator
 
+from . import YOLOv10Plugin
+
 import cv2
 
 SAM2_MODELS = {
@@ -39,7 +41,7 @@ SAM2_MODELS = {
 class SegmentAnything2Assist:
     def __init__(
         self,
-        model_name: (
+        sam_model_name: (
            str
            | typing.Literal[
                "sam2_hiera_tiny",
@@ -56,32 +58,35 @@ class SegmentAnything2Assist:
        download: bool = True,
        device: str | torch.device = torch.device("cpu"),
        verbose: bool = True,
+        YOLOv10Model: YOLOv10Plugin.YOLOv10Plugin | None = None,
    ) -> None:
        assert (
-            model_name in SAM2_MODELS.keys()
-        ), f"`model_name` should be either one of {list(SAM2_MODELS.keys())}"
+            sam_model_name in SAM2_MODELS.keys()
+        ), f"`sam_model_name` should be either one of {list(SAM2_MODELS.keys())}"
        assert configuration in ["Automatic Mask Generator", "Image"]
 
-        self.model_name = model_name
+        self.sam_model_name = sam_model_name
        self.configuration = configuration
-        self.config_file = SAM2_MODELS[model_name]["config_file"]
+        self.config_file = SAM2_MODELS[sam_model_name]["config_file"]
        self.device = device
 
        self.download_url = (
            download_url
            if download_url is not None
-            else SAM2_MODELS[model_name]["download_url"]
+            else SAM2_MODELS[sam_model_name]["download_url"]
        )
        self.model_path = (
            model_path
            if model_path is not None
-            else SAM2_MODELS[model_name]["model_path"]
+            else SAM2_MODELS[sam_model_name]["model_path"]
        )
        os.makedirs(os.path.dirname(self.model_path), exist_ok=True)
        self.verbose = verbose
 
        if self.verbose:
-            print(f"SegmentAnything2Assist::__init__::Model Name: {self.model_name}")
+            print(
+                f"SegmentAnything2Assist::__init__::Model Name: {self.sam_model_name}"
+            )
            print(
                f"SegmentAnything2Assist::__init__::Configuration: {self.configuration}"
            )
@@ -109,6 +114,8 @@ class SegmentAnything2Assist:
        if self.verbose:
            print("SegmentAnything2Assist::__init__::SAM2 is not loaded.")
 
+        self.YOLOv10Model = YOLOv10Model
+
    def is_model_available(self) -> bool:
        ret = os.path.exists(self.model_path)
        if self.verbose:
@@ -264,3 +271,43 @@ class SegmentAnything2Assist:
        all_masks = all_masks.astype(numpy.uint8)
        image_with_segments = cv2.bitwise_and(image, image, mask=all_masks)
        return image_with_bounding_boxes, all_masks, image_with_segments
+
+    def generate_mask_from_image_with_yolo(
+        self,
+        image,
+        YOLOv10Model: YOLOv10Plugin.YOLOv10Plugin | None = None,
+        YOLOv10ModelName: str | None = None,
+        mask_threshold=0.0,
+        max_hole_area=0.0,
+        max_sprinkle_area=0.0,
+    ):
+        if self.YOLOv10Model is None:
+            assert bool(YOLOv10Model) != bool(
+                YOLOv10ModelName
+            ), "Either YOLOv10Model or YOLOv10ModelName should be provided."
+
+            if YOLOv10Model is not None:
+                self.YOLOv10Model = YOLOv10Model
+
+            if YOLOv10ModelName is not None:
+                self.YOLOv10Model = YOLOv10Plugin.YOLOv10Plugin(
+                    yolo_model_name=YOLOv10ModelName
+                )
+
+        results = self.YOLOv10Model.detect(image)
+
+        for _, result in enumerate(results):
+            mask_chw, mask_iou = self.generate_masks_from_image(
+                image,
+                point_coords=None,
+                point_labels=None,
+                box=result["box"],
+                mask_threshold=mask_threshold,
+                max_hole_area=max_hole_area,
+                max_sprinkle_area=max_sprinkle_area,
+            )
+
+            results[_]["mask_chw"] = numpy.squeeze(mask_chw, 0)
+            results[_]["mask_iou"] = mask_iou
+
+        return results
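For reference, each entry returned by `generate_mask_from_image_with_yolo` combines the detection fields produced by `YOLOv10Plugin.detect` (added below) with the SAM 2 mask fields attached above. A rough sketch of one entry, with invented values:

```python
# Illustrative shape of one result entry; the keys come from YOLOv10Plugin.detect
# plus the mask fields added in generate_mask_from_image_with_yolo. Values are made up.
example_entry = {
    "name": "car",               # YOLOv10 class label
    "class": 2,                  # YOLOv10 class id
    "confidence": 0.91,          # detection confidence
    "box": [12, 34, 220, 180],   # [x1, y1, x2, y2] box prompt passed to SAM 2
    "mask_chw": ...,             # SAM 2 mask array, squeezed along axis 0
    "mask_iou": ...,             # predicted IoU score(s) for that mask
}
```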
src/YOLOv10Plugin.py ADDED
@@ -0,0 +1,67 @@
+import typing
+import ultralytics
+
+YOLO_V10_MODELS = {
+    "nano": "jameslahm/yolov10n",
+    "small": "jameslahm/yolov10s",
+    "medium": "jameslahm/yolov10m",
+    "base": "jameslahm/yolov10b",
+    "large": "jameslahm/yolov10l",
+    "xlarge": "jameslahm/yolov10x",
+}
+
+
+class YOLOv10Plugin:
+    def __init__(
+        self,
+        yolo_model_name: (
+            str
+            | typing.Literal[
+                "nano",
+                "small",
+                "medium",
+                "base",
+                "large",
+                "xlarge",
+            ]
+        ) = "nano",
+        verbose: bool = True,
+    ):
+        assert (
+            yolo_model_name in YOLO_V10_MODELS.keys()
+        ), f"`yolo_model_name` should be either one of {list(YOLO_V10_MODELS.keys())}"
+        self.yolo_model_name = yolo_model_name
+        self.model = ultralytics.YOLOv10.from_pretrained(
+            YOLO_V10_MODELS[yolo_model_name]
+        )
+
+        self.verbose = verbose
+        if self.verbose:
+            print(f"YOLOv10Plugin::__init__::Model Name: {self.yolo_model_name}")
+
+    def detect(self, image):
+        results = self.model(image)
+        results = results[0].summary()
+
+        out = []
+        for result in results:
+            out.append(
+                {
+                    "name": result["name"],
+                    "class": result["class"],
+                    "confidence": result["confidence"],
+                    "box": [
+                        int(result["box"]["x1"]),
+                        int(result["box"]["y1"]),
+                        int(result["box"]["x2"]),
+                        int(result["box"]["y2"]),
+                    ],
+                }
+            )
+
+        return out
+
+
+if __name__ == "__main__":
+    yolo = YOLOv10Plugin()
+    yolo.detect("https://ultralytics.com/images/zidane.jpg")
src/__init__.py ADDED
File without changes