Peng Shiya commited on
Commit
50a5f00
1 Parent(s): 7fb3377

feature: enable_segment_all

Browse files
Files changed (2) hide show
  1. app.py +13 -10
  2. app_configs.py +1 -0
app.py CHANGED
@@ -20,7 +20,6 @@ def load_sam_instance():
20
  if sam is None:
21
  gr.Info('Initialising SAM, hang in there...')
22
  if not os.path.exists(configs.model_ckpt_path):
23
- gr.Info('Downloading weights from hugging face hub')
24
  chkpt_path = hf_hub_download("ybelkada/segment-anything", configs.model_ckpt_path)
25
  else:
26
  chkpt_path = configs.model_ckpt_path
@@ -46,12 +45,12 @@ with block:
46
  point_label_radio = gr.Radio(label='Point Label', choices=[1,0], value=1)
47
  reset_btn = gr.Button('Reset')
48
  run_btn = gr.Button('Run', variant = 'primary')
 
49
  with gr.Column():
50
  with gr.Tab('Cutout'):
51
  cutout_gallery = gr.Gallery()
52
  with gr.Tab('Annotation'):
53
  masks_annotated_image = gr.AnnotatedImage(label='Segments')
54
- gr.Examples(examples=[['examples/cat-256.png','examples/cat-256.png']],inputs=[input_image, raw_image])
55
 
56
  # components
57
  components = {point_coords, point_labels, raw_image, input_image, point_label_radio, reset_btn, run_btn, cutout_gallery, masks_annotated_image}
@@ -70,7 +69,7 @@ with block:
70
  x, y = evt.index
71
  color = red if point_label_radio == 0 else blue
72
  img = np.array(input_image)
73
- cv2.circle(img, (x, y), 10, color, -1)
74
  img = Image.fromarray(img)
75
  point_coords.append([x,y])
76
  point_labels.append(point_label_radio)
@@ -78,19 +77,23 @@ with block:
78
  input_image.select(on_input_image_select, [input_image, point_coords, point_labels, point_label_radio], [input_image, point_coords, point_labels], queue=False)
79
 
80
  # event - inference
81
- def on_run_btn_click(data):
82
  sam = load_sam_instance()
83
- image = data[raw_image]
84
- if len(data[point_coords]) == 0:
85
- masks, _ = service.predict_all(sam, image)
 
 
 
 
86
  else:
87
  masks, _ = service.predict_conditioned(sam,
88
  image,
89
- point_coords=np.array(data[point_coords]),
90
- point_labels=np.array(data[point_labels]))
91
  annotated = (image, [(masks[i], f'Mask {i}') for i in range(len(masks))])
92
  cutouts = [service.cutout(image, mask) for mask in masks]
93
- return cutouts, annotated, masks
94
  run_btn.click(on_run_btn_click, components, [cutout_gallery, masks_annotated_image], queue=True)
95
 
96
  if __name__ == '__main__':
 
20
  if sam is None:
21
  gr.Info('Initialising SAM, hang in there...')
22
  if not os.path.exists(configs.model_ckpt_path):
 
23
  chkpt_path = hf_hub_download("ybelkada/segment-anything", configs.model_ckpt_path)
24
  else:
25
  chkpt_path = configs.model_ckpt_path
 
45
  point_label_radio = gr.Radio(label='Point Label', choices=[1,0], value=1)
46
  reset_btn = gr.Button('Reset')
47
  run_btn = gr.Button('Run', variant = 'primary')
48
+ gr.Examples(examples=[['examples/cat-256.png','examples/cat-256.png']],inputs=[input_image, raw_image])
49
  with gr.Column():
50
  with gr.Tab('Cutout'):
51
  cutout_gallery = gr.Gallery()
52
  with gr.Tab('Annotation'):
53
  masks_annotated_image = gr.AnnotatedImage(label='Segments')
 
54
 
55
  # components
56
  components = {point_coords, point_labels, raw_image, input_image, point_label_radio, reset_btn, run_btn, cutout_gallery, masks_annotated_image}
 
69
  x, y = evt.index
70
  color = red if point_label_radio == 0 else blue
71
  img = np.array(input_image)
72
+ cv2.circle(img, (x, y), 5, color, -1)
73
  img = Image.fromarray(img)
74
  point_coords.append([x,y])
75
  point_labels.append(point_label_radio)
 
77
  input_image.select(on_input_image_select, [input_image, point_coords, point_labels, point_label_radio], [input_image, point_coords, point_labels], queue=False)
78
 
79
  # event - inference
80
+ def on_run_btn_click(inputs):
81
  sam = load_sam_instance()
82
+ image = inputs[raw_image]
83
+ if len(inputs[point_coords]) == 0:
84
+ if configs.enable_segment_all:
85
+ masks, _ = service.predict_all(sam, image)
86
+ else:
87
+ gr.Warning('Segment-all disabled, set point label(s) before running')
88
+ return inputs[cutout_gallery], inputs[masks_annotated_image]
89
  else:
90
  masks, _ = service.predict_conditioned(sam,
91
  image,
92
+ point_coords=np.array(inputs[point_coords]),
93
+ point_labels=np.array(inputs[point_labels]))
94
  annotated = (image, [(masks[i], f'Mask {i}') for i in range(len(masks))])
95
  cutouts = [service.cutout(image, mask) for mask in masks]
96
+ return cutouts, annotated
97
  run_btn.click(on_run_btn_click, components, [cutout_gallery, masks_annotated_image], queue=True)
98
 
99
  if __name__ == '__main__':
app_configs.py CHANGED
@@ -2,3 +2,4 @@ model_type = r'vit_b'
2
  # model_ckpt_path = None
3
  model_ckpt_path = "checkpoints/sam_vit_b_01ec64.pth"
4
  device = 'cpu'
 
 
2
  # model_ckpt_path = None
3
  model_ckpt_path = "checkpoints/sam_vit_b_01ec64.pth"
4
  device = 'cpu'
5
+ enable_segment_all = False