AIBoy1993 committed on
Commit
c119102
1 Parent(s): 81a1f3d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -56
app.py CHANGED
@@ -14,8 +14,25 @@ models = {
14
  'vit_h': './checkpoints/sam_vit_h_4b8939.pth'
15
  }
16
 
17
- def inference(device, model_type, input_img, points_per_side, pred_iou_thresh, stability_score_thresh, min_mask_region_area,
18
- stability_score_offset, box_nms_thresh, crop_n_layers, crop_nms_thresh):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  sam = sam_model_registry[model_type](checkpoint=models[model_type]).to(device)
20
  mask_generator = SamAutomaticMaskGenerator(
21
  sam,
@@ -33,37 +50,43 @@ def inference(device, model_type, input_img, points_per_side, pred_iou_thresh, s
33
  output_mode='binary_mask'
34
  )
35
 
36
- masks = mask_generator.generate(input_img)
37
- sorted_anns = sorted(masks, key=(lambda x: x['area']), reverse=True)
38
-
39
- mask_all = np.ones((input_img.shape[0], input_img.shape[1], 3))
40
- for ann in sorted_anns:
41
- m = ann['segmentation']
42
- color_mask = np.random.random((1, 3)).tolist()[0]
43
- for i in range(3):
44
- mask_all[m==True, i] = color_mask[i]
45
- result = input_img / 255 * 0.3 + mask_all * 0.7
46
-
47
- return result, mask_all
48
-
 
 
 
 
 
 
49
 
50
 
51
  with gr.Blocks() as demo:
52
  with gr.Row():
53
  gr.Markdown(
54
  '''# Segment Anything!🚀
55
- 分割一切!CV的GPT-3时刻!
56
- [**官方网址**](https://segment-anything.com/)
57
  '''
58
  )
59
  with gr.Row():
60
- # 选择模型类型
61
- model_type = gr.Dropdown(["vit_b", "vit_l", "vit_h"], value='vit_b', label="选择模型")
62
- # 选择device
63
- device = gr.Dropdown(["cpu"], value='cpu', label="选择你的硬件")
64
 
65
  # 参数
66
- with gr.Accordion(label='参数调整', open=False):
67
  with gr.Row():
68
  points_per_side = gr.Number(value=32, label="points_per_side", precision=0,
69
  info='''The number of points to be sampled along one side of the image. The total
@@ -88,43 +111,63 @@ with gr.Blocks() as demo:
88
  info='''The box IoU cutoff used by non-maximal suppression to filter duplicate
89
  masks between different crops.''')
90
 
91
- # 显示图片
92
- with gr.Row().style(equal_height=True):
93
- with gr.Column():
94
- input_image = gr.Image(type="numpy")
95
- with gr.Row():
96
- button = gr.Button("Auto!")
97
- with gr.Tab(label='原图+mask'):
98
- image_output = gr.Image(type='numpy')
99
- with gr.Tab(label='Mask'):
100
- mask_output = gr.Image(type='numpy')
101
-
102
- gr.Examples(
103
- examples=[os.path.join(os.path.dirname(__file__), "./images/53960-scaled.jpg"),
104
- os.path.join(os.path.dirname(__file__), "./images/2388455-scaled.jpg"),
105
- os.path.join(os.path.dirname(__file__), "./images/1.jpg"),
106
- os.path.join(os.path.dirname(__file__), "./images/2.jpg"),
107
- os.path.join(os.path.dirname(__file__), "./images/3.jpg"),
108
- os.path.join(os.path.dirname(__file__), "./images/4.jpg"),
109
- os.path.join(os.path.dirname(__file__), "./images/5.jpg"),
110
- os.path.join(os.path.dirname(__file__), "./images/6.jpg"),
111
- os.path.join(os.path.dirname(__file__), "./images/7.jpg"),
112
- os.path.join(os.path.dirname(__file__), "./images/8.jpg"),
113
- ],
114
- inputs=input_image,
115
- outputs=image_output,
116
- )
117
-
118
-
119
- # 按钮交互
120
- button.click(inference, inputs=[device, model_type, input_image, points_per_side, pred_iou_thresh,
121
- stability_score_thresh, min_mask_region_area, stability_score_offset, box_nms_thresh,
122
- crop_n_layers, crop_nms_thresh],
123
- outputs=[image_output, mask_output])
 
 
 
 
 
 
 
 
 
 
124
 
 
 
 
 
 
 
 
 
 
 
125
 
126
 
127
- demo.launch()
128
 
129
 
130
 
 
14
  'vit_h': './checkpoints/sam_vit_h_4b8939.pth'
15
  }
16
 
17
+
18
def segment_one(img, mask_generator, seed=None):
    """Segment one image with SAM and build a colored mask overlay.

    Args:
        img: HxWx3 image as a numpy array (uint8 expected; divided by 255).
        mask_generator: object exposing ``generate(img)`` that returns a list
            of dicts each carrying ``'area'`` and a boolean ``'segmentation'``
            mask (e.g. ``SamAutomaticMaskGenerator``).
        seed: optional int; when given, re-seeds numpy's GLOBAL RNG so the
            random mask colors are reproducible (used per video frame).

    Returns:
        Tuple ``(result, mask_all)``: ``mask_all`` is an HxWx3 float array of
        per-mask random colors (1.0 where no mask), and ``result`` is the
        0-1 float blend ``img / 255 * 0.3 + mask_all * 0.7``.
    """
    if seed is not None:
        # NOTE: side effect — seeds the global numpy RNG, visible to callers.
        np.random.seed(seed)
    masks = mask_generator.generate(img)
    # Paint larger masks first so smaller ones remain visible on top.
    sorted_anns = sorted(masks, key=lambda x: x['area'], reverse=True)
    mask_all = np.ones((img.shape[0], img.shape[1], 3))
    for ann in sorted_anns:
        # astype(bool) handles bool or 0/1 masks; avoids the `m == True`
        # anti-idiom (PEP8 E712) of the original.
        m = np.asarray(ann['segmentation'], dtype=bool)
        color_mask = np.random.random((1, 3)).tolist()[0]
        # Boolean-mask indexing broadcasts the color across all 3 channels
        # in one assignment (replaces the original per-channel loop).
        mask_all[m] = color_mask
    result = img / 255 * 0.3 + mask_all * 0.7
    return result, mask_all
31
+
32
+
33
+ def inference(device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh, min_mask_region_area,
34
+ stability_score_offset, box_nms_thresh, crop_n_layers, crop_nms_thresh, input_x, progress=gr.Progress()):
35
+ # sam model
36
  sam = sam_model_registry[model_type](checkpoint=models[model_type]).to(device)
37
  mask_generator = SamAutomaticMaskGenerator(
38
  sam,
 
50
  output_mode='binary_mask'
51
  )
52
 
53
+ # input is image, type: numpy
54
+ if type(input_x) == np.ndarray:
55
+ result, mask_all = segment_one(input_x, mask_generator)
56
+ return result, mask_all
57
+ elif isinstance(input_x, str): # input is video, type: path (str)
58
+ cap = cv2.VideoCapture(input_x) # read video
59
+ frames_num = cap.get(cv2.CAP_PROP_FRAME_COUNT)
60
+ W, H = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
61
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
62
+ print(fps)
63
+ out = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc('x', '2', '6', '4'), fps, (W, H), isColor=True)
64
+ for _ in progress.tqdm(range(int(frames_num)), desc='Processing video ({} frames, size {}x{})'.format(int(frames_num), W, H)):
65
+ ret, frame = cap.read() # read a frame
66
+ result, mask_all = segment_one(frame, mask_generator, seed=2023)
67
+ result = (result * 255).astype(np.uint8)
68
+ out.write(result)
69
+ out.release()
70
+ cap.release()
71
+ return 'output.mp4'
72
 
73
 
74
  with gr.Blocks() as demo:
75
  with gr.Row():
76
  gr.Markdown(
77
  '''# Segment Anything!🚀
78
+ The Segment Anything Model (SAM) produces high quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It has been trained on a dataset of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks.
79
+ [**Official Project**](https://segment-anything.com/)
80
  '''
81
  )
82
  with gr.Row():
83
+ # select model
84
+ model_type = gr.Dropdown(["vit_b", "vit_l", "vit_h"], value='vit_b', label="Select Model")
85
+ # select device
86
+ device = gr.Dropdown(["cpu", "cuda"], value='cuda', label="Select Device")
87
 
88
  # 参数
89
+ with gr.Accordion(label='Parameters', open=False):
90
  with gr.Row():
91
  points_per_side = gr.Number(value=32, label="points_per_side", precision=0,
92
  info='''The number of points to be sampled along one side of the image. The total
 
111
  info='''The box IoU cutoff used by non-maximal suppression to filter duplicate
112
  masks between different crops.''')
113
 
114
+ # Show image
115
+ with gr.Tab(label='Image'):
116
+ with gr.Row().style(equal_height=True):
117
+ with gr.Column():
118
+ input_image = gr.Image(type="numpy")
119
+ with gr.Row():
120
+ button = gr.Button("Auto!")
121
+ with gr.Tab(label='Image+Mask'):
122
+ output_image = gr.Image(type='numpy')
123
+ with gr.Tab(label='Mask'):
124
+ output_mask = gr.Image(type='numpy')
125
+
126
+ gr.Examples(
127
+ examples=[os.path.join(os.path.dirname(__file__), "./images/53960-scaled.jpg"),
128
+ os.path.join(os.path.dirname(__file__), "./images/2388455-scaled.jpg"),
129
+ os.path.join(os.path.dirname(__file__), "./images/1.jpg"),
130
+ os.path.join(os.path.dirname(__file__), "./images/2.jpg"),
131
+ os.path.join(os.path.dirname(__file__), "./images/3.jpg"),
132
+ os.path.join(os.path.dirname(__file__), "./images/4.jpg"),
133
+ os.path.join(os.path.dirname(__file__), "./images/5.jpg"),
134
+ os.path.join(os.path.dirname(__file__), "./images/6.jpg"),
135
+ os.path.join(os.path.dirname(__file__), "./images/7.jpg"),
136
+ os.path.join(os.path.dirname(__file__), "./images/8.jpg"),
137
+ ],
138
+ inputs=input_image,
139
+ outputs=output_image,
140
+ )
141
+ # Show video
142
+ with gr.Tab(label='Video'):
143
+ with gr.Row().style(equal_height=True):
144
+ with gr.Column():
145
+ input_video = gr.Video()
146
+ with gr.Row():
147
+ button_video = gr.Button("Auto!")
148
+ output_video = gr.Video(format='mp4')
149
+ gr.Markdown('''
150
+ **Note:** processing video will take a long time, please upload a short video.
151
+ ''')
152
+ gr.Examples(
153
+ examples=[os.path.join(os.path.dirname(__file__), "./images/video1.mp4")],
154
+ inputs=input_video,
155
+ outputs=output_video
156
+ )
157
 
158
+ # button image
159
+ button.click(inference, inputs=[device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh,
160
+ min_mask_region_area, stability_score_offset, box_nms_thresh, crop_n_layers,
161
+ crop_nms_thresh, input_image],
162
+ outputs=[output_image, output_mask])
163
+ # button video
164
+ button_video.click(inference, inputs=[device, model_type, points_per_side, pred_iou_thresh, stability_score_thresh,
165
+ min_mask_region_area, stability_score_offset, box_nms_thresh, crop_n_layers,
166
+ crop_nms_thresh, input_video],
167
+ outputs=[output_video])
168
 
169
 
170
+ demo.queue().launch(debug=True, enable_queue=True)
171
 
172
 
173