AIBoy1993 committed
Commit ba0d063
1 parent: 7301e8a

Upload 4 files

Files changed (4)

1. README.md +18 -16
2. app.py +15 -74
3. inference.py +156 -0
4. requirements.txt +1 -2
README.md CHANGED
@@ -1,25 +1,18 @@
----
-title: Segment Anything
-emoji: 🚀
-colorFrom: gray
-colorTo: pink
-sdk: gradio
-sdk_version: 3.24.1
-app_file: app.py
-pinned: false
----
-
 # Segment Anything WebUI
 
-This project is based on **[Segment Anything Model](https://segment-anything.com/) ** by Meta. The UI is based on [Gradio](https://gradio.app/).
+This project is based on **[Segment Anything Model](https://segment-anything.com/)** by Meta. The UI is based on [Gradio](https://gradio.app/).
 
 - Try the demo on HF: [AIBoy1993/segment_anything_webui](https://huggingface.co/spaces/AIBoy1993/segment_anything_webui)
+- [GitHub](https://github.com/5663015/segment_anything_webui)
 
 ![](./images/20230408023615.png)
 
 ## Change Logs
 
-- [2023-4-11] Support video segmentation.
+- [2023-4-11]
+  - Support video segmentation. A short video can be automatically segmented by SAM.
+  - Support text prompt segmentation using the [OWL-ViT](https://huggingface.co/docs/transformers/v4.27.2/en/model_doc/owlvit#overview) (Vision Transformer for Open-World Localization) model.
+
 
 ## **Usage**
@@ -45,16 +38,25 @@ git clone https://github.com/5663015/segment_anything_webui.git
 
 - `vit_b`: [ViT-B SAM model](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth)
 
+- Under `checkpoints`, make a new folder named `models--google--owlvit-base-patch32`, and put the downloaded [OWL-ViT weights](https://huggingface.co/google/owlvit-base-patch32) files into it.
 - Run:
 
 ```
 python app.py
 ```
 
-**Note:** Default model is `vit_b`,the demo can run on CPU. Default device is `cuda`。
+**Note:** Default model is `vit_b`, the demo can run on CPU. Default device is `cpu`.
 
 ## TODO
 
-- Add segmentation prompt (point and box)
-- Add text prompt
+- [x] Video segmentation
+
+- [x] Add text prompt
+
+- [ ] Add segmentation prompt (point and box)
+
+## Reference
+
+- Thanks to the wonderful work of [Segment Anything](https://segment-anything.com/) and [OWL-ViT](https://arxiv.org/abs/2205.06230).
+- Some video processing code references [kadirnar/segment-anything-video](https://github.com/kadirnar/segment-anything-video), and some OWL-ViT code references [ngthanhtin/owlvit_segment_anything](https://github.com/ngthanhtin/owlvit_segment_anything).
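The checkpoint layout described in the README hunk above can also be prepared with a short script. This is a minimal sketch, not part of the commit: it assumes `huggingface_hub` is installed (it is not listed in `requirements.txt`) and that its `snapshot_download(..., local_dir=...)` option is available to place the OWL-ViT files directly inside the folder name that `inference.py` expects.

```
import os
import urllib.request

# Assumed extra dependency for fetching the OWL-ViT files; not in requirements.txt.
from huggingface_hub import snapshot_download

CKPT_DIR = "./checkpoints"
os.makedirs(CKPT_DIR, exist_ok=True)

# SAM ViT-B weights (the default model used by the demo).
sam_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
sam_path = os.path.join(CKPT_DIR, "sam_vit_b_01ec64.pth")
if not os.path.exists(sam_path):
    urllib.request.urlretrieve(sam_url, sam_path)

# OWL-ViT weights, downloaded flat into the folder inference.py points at.
snapshot_download(
    "google/owlvit-base-patch32",
    local_dir=os.path.join(CKPT_DIR, "models--google--owlvit-base-patch32"),
)
```

Alternatively, re-enabling the commented-out `from_pretrained("google/owlvit-base-patch32")` lines in `inference.py` would download the OWL-ViT weights into the default Hugging Face cache on first run instead of using the local folder.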
 
 
app.py CHANGED
@@ -1,73 +1,8 @@
 import os
-import cv2
-import sys
-import numpy as np
 import gradio as gr
-from PIL import Image
-import matplotlib.pyplot as plt
-from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
+from inference import run_inference
 
 
-models = {
-    'vit_b': './checkpoints/sam_vit_b_01ec64.pth',
-    'vit_l': './checkpoints/sam_vit_l_0b3195.pth',
-    'vit_h': './checkpoints/sam_vit_h_4b8939.pth'
-}
-
-
-def segment_one(img, mask_generator, seed=None):
-    if seed is not None:
-        np.random.seed(seed)
-    masks = mask_generator.generate(img)
-    sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True)
-    mask_all = np.ones((img.shape[0], img.shape[1], 3))
-    for ann in sorted_anns:
-        m = ann['segmentation']
-        color_mask = np.random.random((1, 3)).tolist()[0]
-        for i in range(3):
-            mask_all[m == True, i] = color_mask[i]
-    result = img / 255 * 0.3 + mask_all * 0.7
-    return result, mask_all
-
-
-def inference(device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh, min_mask_region_area,
-              stability_score_offset, box_nms_thresh, crop_n_layers, crop_nms_thresh, input_x, progress=gr.Progress()):
-    # sam model
-    sam = sam_model_registry[model_type](checkpoint=models[model_type]).to(device)
-    mask_generator = SamAutomaticMaskGenerator(
-        sam,
-        points_per_side=points_per_side,
-        pred_iou_thresh=pred_iou_thresh,
-        stability_score_thresh=stability_score_thresh,
-        stability_score_offset=stability_score_offset,
-        box_nms_thresh=box_nms_thresh,
-        crop_n_layers=crop_n_layers,
-        crop_nms_thresh=crop_nms_thresh,
-        crop_overlap_ratio=512 / 1500,
-        crop_n_points_downscale_factor=1,
-        point_grids=None,
-        min_mask_region_area=min_mask_region_area,
-        output_mode='binary_mask'
-    )
-
-    # input is image, type: numpy
-    if type(input_x) == np.ndarray:
-        result, mask_all = segment_one(input_x, mask_generator)
-        return result, mask_all
-    elif isinstance(input_x, str):  # input is video, type: path (str)
-        cap = cv2.VideoCapture(input_x)  # read video
-        frames_num = cap.get(cv2.CAP_PROP_FRAME_COUNT)
-        W, H = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-        fps = int(cap.get(cv2.CAP_PROP_FPS))
-        out = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc('x', '2', '6', '4'), fps, (W, H), isColor=True)
-        for _ in progress.tqdm(range(int(frames_num)), desc='Processing video ({} frames, size {}x{})'.format(int(frames_num), W, H)):
-            ret, frame = cap.read()  # read a frame
-            result, mask_all = segment_one(frame, mask_generator, seed=2023)
-            result = (result * 255).astype(np.uint8)
-            out.write(result)
-        out.release()
-        cap.release()
-        return 'output.mp4'
 
 
 with gr.Blocks() as demo:
@@ -82,9 +17,9 @@ with gr.Blocks() as demo:
             # select model
             model_type = gr.Dropdown(["vit_b", "vit_l", "vit_h"], value='vit_b', label="Select Model")
             # select device
-            device = gr.Dropdown(["cpu"], value='cpu', label="Select Device")
+            device = gr.Dropdown(["cpu", "cuda"], value='cpu', label="Select Device")
 
-            # 参数
+            # parameters
             with gr.Accordion(label='Parameters', open=False):
                 with gr.Row():
                     points_per_side = gr.Number(value=32, label="points_per_side", precision=0,
@@ -115,8 +50,14 @@ with gr.Blocks() as demo:
    with gr.Row().style(equal_height=True):
        with gr.Column():
            input_image = gr.Image(type="numpy")
-            with gr.Row():
-                button = gr.Button("Auto!")
+            text = gr.Textbox(label='Text prompt (optional)', info=
+                              'If you type words, the OWL-ViT model will be used to detect the objects in the image, '
+                              'and the boxes will be fed into the SAM model to predict masks. Please use English.',
+                              placeholder='Multiple words are separated by commas')
+            owl_vit_threshold = gr.Slider(value=0.1, minimum=0, maximum=1.0, step=0.01, label="OWL ViT Object Detection threshold",
+                                          info='''A small threshold will generate more objects, but may cause OOM.
+                                          A large threshold may detect no objects, resulting in an error.''')
+            button = gr.Button("Auto!")
        with gr.Tab(label='Image+Mask'):
            output_image = gr.Image(type='numpy')
        with gr.Tab(label='Mask'):
@@ -157,14 +98,14 @@ with gr.Blocks() as demo:
    )
 
    # button image
-    button.click(inference, inputs=[device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh,
+    button.click(run_inference, inputs=[device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh,
                                     min_mask_region_area, stability_score_offset, box_nms_thresh, crop_n_layers,
-                                    crop_nms_thresh, input_image],
+                                    crop_nms_thresh, owl_vit_threshold, input_image, text],
                  outputs=[output_image, output_mask])
    # button video
-    button_video.click(inference, inputs=[device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh,
+    button_video.click(run_inference, inputs=[device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh,
                                           min_mask_region_area, stability_score_offset, box_nms_thresh, crop_n_layers,
-                                          crop_nms_thresh, input_video],
+                                          crop_nms_thresh, owl_vit_threshold, input_video, text],
                        outputs=[output_video])
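The text-prompt path wired up in app.py (type words, OWL-ViT proposes boxes, SAM turns the boxes into masks) can also be driven without the UI. Below is a minimal sketch of calling `run_inference` directly; the prompt string is made up for illustration, and the SAM generator thresholds are assumed to follow `SamAutomaticMaskGenerator`'s usual defaults — they are in any case ignored on the text-prompt branch, which only uses `device`, `model_type`, `owl_vit_threshold`, the image, and the text.

```
import cv2
from inference import run_inference

# Load an RGB numpy array, the same form gr.Image(type="numpy") passes in.
image = cv2.cvtColor(cv2.imread("./images/20230408023615.png"), cv2.COLOR_BGR2RGB)

result, mask = run_inference(
    device="cpu",
    model_type="vit_b",
    points_per_side=32,
    pred_iou_thresh=0.88,        # assumed SAM defaults; unused on the text-prompt branch
    stability_score_thresh=0.95,
    min_mask_region_area=0,
    stability_score_offset=1.0,
    box_nms_thresh=0.7,
    crop_n_layers=0,
    crop_nms_thresh=0.7,
    owl_vit_threshold=0.1,
    input_x=image,
    input_text="dog",            # non-empty text routes to predictor_inference (OWL-ViT + SAM)
)

# On the text-prompt branch the first return value is a PIL image with the detected
# boxes drawn on top of the mask overlay; the second is the mask layer.
result.save("result.png")
```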
 
inference.py ADDED
@@ -0,0 +1,156 @@
+import cv2
+import torch
+import numpy as np
+import gradio as gr
+from PIL import Image, ImageDraw
+from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
+from transformers import OwlViTProcessor, OwlViTForObjectDetection
+import gc
+
+models = {
+    'vit_b': './checkpoints/sam_vit_b_01ec64.pth',
+    'vit_l': './checkpoints/sam_vit_l_0b3195.pth',
+    'vit_h': './checkpoints/sam_vit_h_4b8939.pth'
+}
+
+
+def plot_boxes(img, boxes):
+    img_pil = Image.fromarray(np.uint8(img * 255)).convert('RGB')
+    draw = ImageDraw.Draw(img_pil)
+    for box in boxes:
+        color = tuple(np.random.randint(0, 255, size=3).tolist())
+        x0, y0, x1, y1 = box
+        x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
+        draw.rectangle([x0, y0, x1, y1], outline=color, width=6)
+    return img_pil
+
+
+def segment_one(img, mask_generator, seed=None):
+    if seed is not None:
+        np.random.seed(seed)
+    masks = mask_generator.generate(img)
+    sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True)
+    mask_all = np.ones((img.shape[0], img.shape[1], 3))
+    for ann in sorted_anns:
+        m = ann['segmentation']
+        color_mask = np.random.random((1, 3)).tolist()[0]
+        for i in range(3):
+            mask_all[m == True, i] = color_mask[i]
+    result = img / 255 * 0.3 + mask_all * 0.7
+    return result, mask_all
+
+
+def generator_inference(device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh,
+                        min_mask_region_area, stability_score_offset, box_nms_thresh, crop_n_layers, crop_nms_thresh,
+                        input_x, progress=gr.Progress()):
+    # sam model
+    sam = sam_model_registry[model_type](checkpoint=models[model_type]).to(device)
+    mask_generator = SamAutomaticMaskGenerator(
+        sam,
+        points_per_side=points_per_side,
+        pred_iou_thresh=pred_iou_thresh,
+        stability_score_thresh=stability_score_thresh,
+        stability_score_offset=stability_score_offset,
+        box_nms_thresh=box_nms_thresh,
+        crop_n_layers=crop_n_layers,
+        crop_nms_thresh=crop_nms_thresh,
+        crop_overlap_ratio=512 / 1500,
+        crop_n_points_downscale_factor=1,
+        point_grids=None,
+        min_mask_region_area=min_mask_region_area,
+        output_mode='binary_mask'
+    )
+
+    # input is image, type: numpy
+    if type(input_x) == np.ndarray:
+        result, mask_all = segment_one(input_x, mask_generator)
+        return result, mask_all
+    elif isinstance(input_x, str):  # input is video, type: path (str)
+        cap = cv2.VideoCapture(input_x)  # read video
+        frames_num = cap.get(cv2.CAP_PROP_FRAME_COUNT)
+        W, H = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        fps = int(cap.get(cv2.CAP_PROP_FPS))
+        out = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc('x', '2', '6', '4'), fps, (W, H), isColor=True)
+        for _ in progress.tqdm(range(int(frames_num)),
+                               desc='Processing video ({} frames, size {}x{})'.format(int(frames_num), W, H)):
+            ret, frame = cap.read()  # read a frame
+            result, mask_all = segment_one(frame, mask_generator, seed=2023)
+            result = (result * 255).astype(np.uint8)
+            out.write(result)
+        out.release()
+        cap.release()
+        return 'output.mp4'
+
+
+def predictor_inference(device, model_type, input_x, input_text, owl_vit_threshold=0.1):
+    # sam model
+    sam = sam_model_registry[model_type](checkpoint=models[model_type]).to(device)
+    predictor = SamPredictor(sam)
+    predictor.set_image(input_x)  # Process the image to produce an image embedding
+
+    # split input text
+    input_text = [input_text.split(',')]
+
+    # OWL-ViT model
+    # processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
+    # owlvit_model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32").to(device)
+    processor = OwlViTProcessor.from_pretrained('./checkpoints/models--google--owlvit-base-patch32')
+    owlvit_model = OwlViTForObjectDetection.from_pretrained("./checkpoints/models--google--owlvit-base-patch32").to(device)
+
+    # get outputs
+    input_text = processor(text=input_text, images=input_x, return_tensors="pt").to(device)
+    outputs = owlvit_model(**input_text)
+    target_size = torch.Tensor([input_x.shape[:2]]).to(device)
+    results = processor.post_process_object_detection(outputs=outputs, target_sizes=target_size,
+                                                      threshold=owl_vit_threshold)
+
+    # get the box with best score
+    scores = torch.sigmoid(outputs.logits)
+    # best_scores, best_idxs = torch.topk(scores, k=1, dim=1)
+    # best_idxs = best_idxs.squeeze(1).tolist()
+
+    i = 0  # Retrieve predictions for the first image for the corresponding text queries
+    boxes_tensor = results[i]["boxes"]  # [best_idxs]
+    print(boxes_tensor.size())
+    boxes = boxes_tensor.cpu().detach().numpy()
+    transformed_boxes = predictor.transform.apply_boxes_torch(torch.Tensor(boxes).to(device),
+                                                              input_x.shape[:2])  # apply transform to original boxes
+
+    # predict segmentation according to the boxes
+    masks, scores, logits = predictor.predict_torch(
+        point_coords=None,
+        point_labels=None,
+        boxes=transformed_boxes,  # boxes detected by OWL-ViT
+        multimask_output=False,
+    )
+    masks = masks.cpu().detach().numpy()
+    mask_all = np.ones((input_x.shape[0], input_x.shape[1], 3))
+    for ann in masks:
+        color_mask = np.random.random((1, 3)).tolist()[0]
+        for i in range(3):
+            mask_all[ann[0] == True, i] = color_mask[i]
+    img = input_x / 255 * 0.3 + mask_all * 0.7
+    img = plot_boxes(img, boxes_tensor)  # image + mask + boxes
+
+    # free the memory
+    owlvit_model.cpu()
+    del owlvit_model
+    del input_text
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    return img, mask_all
+
+
+def run_inference(device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh, min_mask_region_area,
+                  stability_score_offset, box_nms_thresh, crop_n_layers, crop_nms_thresh, owl_vit_threshold, input_x,
+                  input_text):
+    print('prompt text: ', input_text)
+    if input_text != '' and not isinstance(input_x, str):  # user input text
+        print('use predictor_inference')
+        return predictor_inference(device, model_type, input_x, input_text, owl_vit_threshold)
+    else:
+        print('use generator_inference')
+        return generator_inference(device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh,
+                                   min_mask_region_area, stability_score_offset, box_nms_thresh, crop_n_layers,
+                                   crop_nms_thresh, input_x)
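The random-colour overlay (`mask_all[m == True, i] = color_mask[i]` followed by `img / 255 * 0.3 + mask_all * 0.7`) is written out twice in this file, once in `segment_one` and once in `predictor_inference`. One way to keep that blend in a single place is a small helper like the sketch below; this is a refactoring suggestion, not code from the commit.

```
import numpy as np


def overlay_masks(image, masks, alpha=0.7, seed=None):
    """Alpha-blend randomly coloured binary masks onto an RGB image.

    image: HxWx3 uint8 array; masks: iterable of HxW boolean arrays.
    Returns (blended, mask_layer) as float arrays, mirroring what
    segment_one and predictor_inference compute inline.
    """
    rng = np.random.default_rng(seed)
    mask_layer = np.ones((image.shape[0], image.shape[1], 3))
    for m in masks:
        mask_layer[m] = rng.random(3)  # one random colour per mask
    blended = image / 255 * (1 - alpha) + mask_layer * alpha
    return blended, mask_layer
```

With such a helper, `segment_one` could call `overlay_masks(img, (ann['segmentation'] for ann in sorted_anns), seed=seed)` and `predictor_inference` could pass `(m[0] for m in masks)`, so the 0.3/0.7 weighting lives in one spot.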
requirements.txt CHANGED
@@ -3,5 +3,4 @@ numpy==1.21.5
 opencv_python==4.6.0.66
 Pillow==9.5.0
 segment_anything==1.0
-torch
-torchvision
+transformers==4.27.4