rexma commited on
Commit
3d4535a
1 Parent(s): d4bcb75

Work in process

Browse files
Files changed (3) hide show
  1. Dockerfile +3 -3
  2. app.py +104 -329
  3. docker-compose.yml +16 -0
Dockerfile CHANGED
@@ -18,11 +18,11 @@ RUN add-apt-repository -y -r ppa:jonathonf/ffmpeg-4 \
18
  && rm -rf /var/lib/apt/lists/*
19
  WORKDIR /app
20
  RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
21
- RUN pip3 install numpy matplotlib pillow gradio==3.38.0 opencv-python ffmpeg-python
22
  RUN git clone https://github.com/facebookresearch/segment-anything-2.git
23
  WORKDIR /app/segment-anything-2
24
- RUN pip3 install -e .
25
- # RUN pip3 install -e ".[demo]"
26
  WORKDIR /app/segment-anything-2/checkpoints
27
  RUN ./download_ckpts.sh
28
  WORKDIR /app
 
18
  && rm -rf /var/lib/apt/lists/*
19
  WORKDIR /app
20
  RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
21
+ RUN pip3 install numpy matplotlib pillow gradio==3.38.0 opencv-python ffmpeg-python moviepy
22
  RUN git clone https://github.com/facebookresearch/segment-anything-2.git
23
  WORKDIR /app/segment-anything-2
24
+ # RUN pip3 install -e .
25
+ RUN pip3 install -e ".[demo]"
26
  WORKDIR /app/segment-anything-2/checkpoints
27
  RUN ./download_ckpts.sh
28
  WORKDIR /app
app.py CHANGED
@@ -11,245 +11,17 @@
11
  # print("Command failed with return code:", result.returncode)
12
  import gc
13
  import math
 
 
14
  import os
 
15
  os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
16
- import shutil
17
  import ffmpeg
18
- import zipfile
19
- import gradio as gr
20
- import torch
21
- import numpy as np
22
- import matplotlib.pyplot as plt
23
- from PIL import Image
24
- from sam2.build_sam import build_sam2
25
- from sam2.sam2_image_predictor import SAM2ImagePredictor
26
- from sam2.build_sam import build_sam2_video_predictor
27
  import cv2
28
 
29
- def clean(Seg_Tracker):
30
- if Seg_Tracker is not None:
31
- predictor, inference_state, image_predictor = Seg_Tracker
32
- predictor.reset_state(inference_state)
33
- del predictor
34
- del inference_state
35
- del image_predictor
36
- del Seg_Tracker
37
- gc.collect()
38
- torch.cuda.empty_cache()
39
- return None, ({}, {}), None, None, 0, None, None, None, 0
40
 
41
- def change_video(input_video):
42
- if input_video is None:
43
- return 0, 0
44
- cap = cv2.VideoCapture(input_video)
45
- fps = cap.get(cv2.CAP_PROP_FPS)
46
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
47
- cap.release()
48
- scale_slider = gr.Slider.update(minimum=1.0,
49
- maximum=fps,
50
- step=1.0,
51
- value=fps,)
52
- frame_per = gr.Slider.update(minimum= 0.0,
53
- maximum= total_frames / fps,
54
- step=1.0/fps,
55
- value=0.0,)
56
- return scale_slider, frame_per
57
-
58
- def get_meta_from_video(Seg_Tracker, input_video, scale_slider, checkpoint):
59
-
60
- output_dir = '/tmp/output_frames'
61
- output_masks_dir = '/tmp/output_masks'
62
- output_combined_dir = '/tmp/output_combined'
63
- clear_folder(output_dir)
64
- clear_folder(output_masks_dir)
65
- clear_folder(output_combined_dir)
66
- if input_video is None:
67
- return None, ({}, {}), None, None, 0, None, None, None, 0
68
- cap = cv2.VideoCapture(input_video)
69
- fps = cap.get(cv2.CAP_PROP_FPS)
70
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
71
- cap.release()
72
- frame_interval = max(1, int(fps // scale_slider))
73
- print(f"frame_interval: {frame_interval}")
74
- try:
75
- ffmpeg.input(input_video, hwaccel='cuda').output(
76
- os.path.join(output_dir, '%07d.jpg'), q=2, start_number=0,
77
- vf=rf'select=not(mod(n\,{frame_interval}))', vsync='vfr'
78
- ).run()
79
- except:
80
- print(f"ffmpeg cuda err")
81
- ffmpeg.input(input_video).output(
82
- os.path.join(output_dir, '%07d.jpg'), q=2, start_number=0,
83
- vf=rf'select=not(mod(n\,{frame_interval}))', vsync='vfr'
84
- ).run()
85
-
86
- first_frame_path = os.path.join(output_dir, '0000000.jpg')
87
- first_frame = cv2.imread(first_frame_path)
88
- first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
89
-
90
- if Seg_Tracker is not None:
91
- del Seg_Tracker
92
- Seg_Tracker = None
93
- gc.collect()
94
- torch.cuda.empty_cache()
95
- torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
96
- if torch.cuda.get_device_properties(0).major >= 8:
97
- torch.backends.cuda.matmul.allow_tf32 = True
98
- torch.backends.cudnn.allow_tf32 = True
99
-
100
- if checkpoint == "tiny":
101
- sam2_checkpoint = "segment-anything-2/checkpoints/sam2_hiera_tiny.pt"
102
- model_cfg = "sam2_hiera_t.yaml"
103
- elif checkpoint == "samll":
104
- sam2_checkpoint = "segment-anything-2/checkpoints/sam2_hiera_small.pt"
105
- model_cfg = "sam2_hiera_s.yaml"
106
- elif checkpoint == "base-plus":
107
- sam2_checkpoint = "segment-anything-2/checkpoints/sam2_hiera_base_plus.pt"
108
- model_cfg = "sam2_hiera_b+.yaml"
109
- elif checkpoint == "large":
110
- sam2_checkpoint = "segment-anything-2/checkpoints/sam2_hiera_large.pt"
111
- model_cfg = "sam2_hiera_l.yaml"
112
-
113
- predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device="cuda")
114
- sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
115
-
116
- image_predictor = SAM2ImagePredictor(sam2_model)
117
- inference_state = predictor.init_state(video_path=output_dir)
118
- predictor.reset_state(inference_state)
119
- frame_per = gr.Slider.update(minimum= 0.0,
120
- maximum= total_frames / fps,
121
- step=frame_interval / fps,
122
- value=0.0,)
123
- return (predictor, inference_state, image_predictor), ({}, {}), first_frame_rgb, first_frame_rgb, frame_per, None, None, None, 0
124
-
125
- def mask2bbox(mask):
126
- if len(np.where(mask > 0)[0]) == 0:
127
- print(f'not mask')
128
- return np.array([0, 0, 0, 0]).astype(np.int64), False
129
- x_ = np.sum(mask, axis=0)
130
- y_ = np.sum(mask, axis=1)
131
- x0 = np.min(np.nonzero(x_)[0])
132
- x1 = np.max(np.nonzero(x_)[0])
133
- y0 = np.min(np.nonzero(y_)[0])
134
- y1 = np.max(np.nonzero(y_)[0])
135
- return np.array([x0, y0, x1, y1]).astype(np.int64), True
136
-
137
- def sam_stroke(Seg_Tracker, drawing_board, last_draw, frame_num, ann_obj_id):
138
- predictor, inference_state, image_predictor = Seg_Tracker
139
- image_path = f'/tmp/output_frames/{frame_num:07d}.jpg'
140
- image = cv2.imread(image_path)
141
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
142
- display_image = drawing_board["image"]
143
- image_predictor.set_image(image)
144
- input_mask = drawing_board["mask"]
145
- input_mask[input_mask != 0] = 255
146
- if last_draw is not None:
147
- diff_mask = cv2.absdiff(input_mask, last_draw)
148
- input_mask = diff_mask
149
- bbox, hasMask = mask2bbox(input_mask[:, :, 0])
150
- if not hasMask :
151
- return Seg_Tracker, display_image, display_image
152
- masks, scores, logits = image_predictor.predict( point_coords=None, point_labels=None, box=bbox[None, :], multimask_output=False,)
153
- mask = masks > 0.0
154
- masked_frame = show_mask(mask, display_image, ann_obj_id)
155
- masked_with_rect = draw_rect(masked_frame, bbox, ann_obj_id)
156
- frame_idx, object_ids, masks = predictor.add_new_mask(inference_state, frame_idx=frame_num, obj_id=ann_obj_id, mask=mask[0])
157
- last_draw = drawing_board["mask"]
158
- return Seg_Tracker, masked_with_rect, masked_with_rect, last_draw
159
-
160
- def draw_rect(image, bbox, obj_id):
161
- cmap = plt.get_cmap("tab10")
162
- color = np.array(cmap(obj_id)[:3])
163
- rgb_color = tuple(map(int, (color[:3] * 255).astype(np.uint8)))
164
- inv_color = tuple(map(int, (255 - color[:3] * 255).astype(np.uint8)))
165
- x0, y0, x1, y1 = bbox
166
- image_with_rect = cv2.rectangle(image.copy(), (x0, y0), (x1, y1), rgb_color, thickness=2)
167
- return image_with_rect
168
-
169
- def sam_click(Seg_Tracker, frame_num, point_mode, click_stack, ann_obj_id, evt: gr.SelectData):
170
- points_dict, labels_dict = click_stack
171
- predictor, inference_state, image_predictor = Seg_Tracker
172
- ann_frame_idx = frame_num # the frame index we interact with
173
- print(f'ann_frame_idx: {ann_frame_idx}')
174
- point = np.array([[evt.index[0], evt.index[1]]], dtype=np.float32)
175
- if point_mode == "Positive":
176
- label = np.array([1], np.int32)
177
- else:
178
- label = np.array([0], np.int32)
179
-
180
- if ann_frame_idx not in points_dict:
181
- points_dict[ann_frame_idx] = {}
182
- if ann_frame_idx not in labels_dict:
183
- labels_dict[ann_frame_idx] = {}
184
-
185
- if ann_obj_id not in points_dict[ann_frame_idx]:
186
- points_dict[ann_frame_idx][ann_obj_id] = np.empty((0, 2), dtype=np.float32)
187
- if ann_obj_id not in labels_dict[ann_frame_idx]:
188
- labels_dict[ann_frame_idx][ann_obj_id] = np.empty((0,), dtype=np.int32)
189
-
190
- points_dict[ann_frame_idx][ann_obj_id] = np.append(points_dict[ann_frame_idx][ann_obj_id], point, axis=0)
191
- labels_dict[ann_frame_idx][ann_obj_id] = np.append(labels_dict[ann_frame_idx][ann_obj_id], label, axis=0)
192
-
193
- click_stack = (points_dict, labels_dict)
194
-
195
- frame_idx, out_obj_ids, out_mask_logits = predictor.add_new_points(
196
- inference_state=inference_state,
197
- frame_idx=ann_frame_idx,
198
- obj_id=ann_obj_id,
199
- points=points_dict[ann_frame_idx][ann_obj_id],
200
- labels=labels_dict[ann_frame_idx][ann_obj_id],
201
- )
202
-
203
- image_path = f'/tmp/output_frames/{ann_frame_idx:07d}.jpg'
204
- image = cv2.imread(image_path)
205
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
206
-
207
- masked_frame = image.copy()
208
- for i, obj_id in enumerate(out_obj_ids):
209
- mask = (out_mask_logits[i] > 0.0).cpu().numpy()
210
- masked_frame = show_mask(mask, image=masked_frame, obj_id=obj_id)
211
- masked_frame_with_markers = draw_markers(masked_frame, points_dict[ann_frame_idx], labels_dict[ann_frame_idx])
212
-
213
- return Seg_Tracker, masked_frame_with_markers, masked_frame_with_markers, click_stack
214
-
215
- def draw_markers(image, points_dict, labels_dict):
216
- cmap = plt.get_cmap("tab10")
217
- image_h, image_w = image.shape[:2]
218
- marker_size = max(1, int(min(image_h, image_w) * 0.05))
219
-
220
- for obj_id in points_dict:
221
- color = np.array(cmap(obj_id)[:3])
222
- rgb_color = tuple(map(int, (color[:3] * 255).astype(np.uint8)))
223
- inv_color = tuple(map(int, (255 - color[:3] * 255).astype(np.uint8)))
224
- for point, label in zip(points_dict[obj_id], labels_dict[obj_id]):
225
- x, y = int(point[0]), int(point[1])
226
- if label == 1:
227
- cv2.drawMarker(image, (x, y), inv_color, markerType=cv2.MARKER_CROSS, markerSize=marker_size, thickness=2)
228
- else:
229
- cv2.drawMarker(image, (x, y), inv_color, markerType=cv2.MARKER_TILTED_CROSS, markerSize=int(marker_size / np.sqrt(2)), thickness=2)
230
-
231
- return image
232
-
233
- def show_mask(mask, image=None, obj_id=None):
234
- cmap = plt.get_cmap("tab10")
235
- cmap_idx = 0 if obj_id is None else obj_id
236
- color = np.array([*cmap(cmap_idx)[:3], 0.6])
237
-
238
- h, w = mask.shape[-2:]
239
- mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
240
- mask_image = (mask_image * 255).astype(np.uint8)
241
- if image is not None:
242
- image_h, image_w = image.shape[:2]
243
- if (image_h, image_w) != (h, w):
244
- raise ValueError(f"Image dimensions ({image_h}, {image_w}) and mask dimensions ({h}, {w}) do not match")
245
- colored_mask = np.zeros_like(image, dtype=np.uint8)
246
- for c in range(3):
247
- colored_mask[..., c] = mask_image[..., c]
248
- alpha_mask = mask_image[..., 3] / 255.0
249
- for c in range(3):
250
- image[..., c] = np.where(alpha_mask > 0, (1 - alpha_mask) * image[..., c] + alpha_mask * colored_mask[..., c], image[..., c])
251
- return image
252
- return mask_image
253
 
254
  def show_res_by_slider(frame_per, click_stack):
255
  image_path = '/tmp/output_frames'
@@ -274,85 +46,11 @@ def show_res_by_slider(frame_per, click_stack):
274
  print(f"{chosen_frame_path}")
275
  chosen_frame_show = cv2.imread(chosen_frame_path)
276
  chosen_frame_show = cv2.cvtColor(chosen_frame_show, cv2.COLOR_BGR2RGB)
277
- points_dict, labels_dict = click_stack
278
  if frame_num in points_dict and frame_num in labels_dict:
279
  chosen_frame_show = draw_markers(chosen_frame_show, points_dict[frame_num], labels_dict[frame_num])
280
  return chosen_frame_show, chosen_frame_show, frame_num
281
 
282
- def clear_folder(folder_path):
283
- if os.path.exists(folder_path):
284
- shutil.rmtree(folder_path)
285
- os.makedirs(folder_path)
286
-
287
- def zip_folder(folder_path, output_zip_path):
288
- with zipfile.ZipFile(output_zip_path, 'w', zipfile.ZIP_STORED) as zipf:
289
- for root, _, files in os.walk(folder_path):
290
- for file in files:
291
- file_path = os.path.join(root, file)
292
- zipf.write(file_path, os.path.relpath(file_path, folder_path))
293
-
294
- def tracking_objects(Seg_Tracker, frame_num, input_video):
295
- output_dir = '/tmp/output_frames'
296
- output_masks_dir = '/tmp/output_masks'
297
- output_combined_dir = '/tmp/output_combined'
298
- output_video_path = '/tmp/output_video.mp4'
299
- output_zip_path = '/tmp/output_masks.zip'
300
- clear_folder(output_masks_dir)
301
- clear_folder(output_combined_dir)
302
- if os.path.exists(output_video_path):
303
- os.remove(output_video_path)
304
- if os.path.exists(output_zip_path):
305
- os.remove(output_zip_path)
306
- video_segments = {}
307
- predictor, inference_state, image_predictor = Seg_Tracker
308
- for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(inference_state):
309
- video_segments[out_frame_idx] = {
310
- out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
311
- for i, out_obj_id in enumerate(out_obj_ids)
312
- }
313
- frame_files = sorted([f for f in os.listdir(output_dir) if f.endswith('.jpg')])
314
- # for frame_idx in sorted(video_segments.keys()):
315
- for frame_file in frame_files:
316
- frame_idx = int(os.path.splitext(frame_file)[0])
317
- frame_path = os.path.join(output_dir, frame_file)
318
- image = cv2.imread(frame_path)
319
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
320
- masked_frame = image.copy()
321
- if frame_idx in video_segments:
322
- for obj_id, mask in video_segments[frame_idx].items():
323
- masked_frame = show_mask(mask, image=masked_frame, obj_id=obj_id)
324
- mask_output_path = os.path.join(output_masks_dir, f'{obj_id}_{frame_idx:07d}.png')
325
- cv2.imwrite(mask_output_path, show_mask(mask))
326
- combined_output_path = os.path.join(output_combined_dir, f'{frame_idx:07d}.png')
327
- combined_image_bgr = cv2.cvtColor(masked_frame, cv2.COLOR_RGB2BGR)
328
- cv2.imwrite(combined_output_path, combined_image_bgr)
329
- if frame_idx == frame_num:
330
- final_masked_frame = masked_frame
331
-
332
- cap = cv2.VideoCapture(input_video)
333
- fps = cap.get(cv2.CAP_PROP_FPS)
334
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
335
- frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
336
- frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
337
- cap.release()
338
- # output_frames = int(total_frames * scale_slider)
339
- output_frames = len([name for name in os.listdir(output_combined_dir) if os.path.isfile(os.path.join(output_combined_dir, name)) and name.endswith('.png')])
340
- out_fps = fps * output_frames / total_frames
341
- # ffmpeg.input(os.path.join(output_combined_dir, '%07d.png'), framerate=out_fps).output(output_video_path, vcodec='h264_nvenc', pix_fmt='yuv420p').run()
342
- fourcc = cv2.VideoWriter_fourcc(*"mp4v")
343
- out = cv2.VideoWriter(output_video_path, fourcc, out_fps, (frame_width, frame_height))
344
-
345
- for i in range(output_frames):
346
- frame_path = os.path.join(output_combined_dir, f'{i:07d}.png')
347
- frame = cv2.imread(frame_path)
348
- out.write(frame)
349
-
350
- out.release()
351
-
352
- zip_folder(output_masks_dir, output_zip_path)
353
- print("done")
354
- return final_masked_frame, final_masked_frame, output_video_path, output_video_path, output_zip_path
355
-
356
  def increment_ann_obj_id(ann_obj_id):
357
  ann_obj_id += 1
358
  return ann_obj_id
@@ -360,7 +58,87 @@ def increment_ann_obj_id(ann_obj_id):
360
  def drawing_board_get_input_first_frame(input_first_frame):
361
  return input_first_frame
362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  def seg_track_app():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
 
365
  ##########################################################
366
  ###################### Front-end ########################
@@ -438,8 +216,7 @@ def seg_track_app():
438
  '''
439
  )
440
 
441
- click_stack = gr.State(({}, {}))
442
- Seg_Tracker = gr.State(None)
443
  frame_num = gr.State(value=(int(0)))
444
  ann_obj_id = gr.State(value=(int(0)))
445
  last_draw = gr.State(None)
@@ -474,7 +251,7 @@ def seg_track_app():
474
 
475
  tab_click = gr.Tab(label="Point Prompt")
476
  with tab_click:
477
- input_first_frame = gr.Image(label='Segment result of first frame',interactive=True).style(height=550)
478
  with gr.Row():
479
  point_mode = gr.Radio(
480
  choices=["Positive", "Negative"],
@@ -549,18 +326,16 @@ def seg_track_app():
549
  preprocess_button.click(
550
  fn=get_meta_from_video,
551
  inputs=[
552
- Seg_Tracker,
553
  input_video,
554
  scale_slider,
555
- checkpoint
556
  ],
557
  outputs=[
558
- Seg_Tracker, click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id
559
  ]
560
  )
561
 
562
  frame_per.release(
563
- fn=show_res_by_slider,
564
  inputs=[
565
  frame_per, click_stack
566
  ],
@@ -571,20 +346,21 @@ def seg_track_app():
571
 
572
  # Interactively modify the mask acc click
573
  input_first_frame.select(
574
- fn=sam_click,
575
  inputs=[
576
- Seg_Tracker, frame_num, point_mode, click_stack, ann_obj_id
577
  ],
578
  outputs=[
579
- Seg_Tracker, input_first_frame, drawing_board, click_stack
580
  ]
581
  )
582
 
583
  # Track object in video
584
  track_for_video.click(
585
- fn=tracking_objects,
586
  inputs=[
587
- Seg_Tracker,
 
588
  frame_num,
589
  input_video,
590
  ],
@@ -599,11 +375,9 @@ def seg_track_app():
599
 
600
  reset_button.click(
601
  fn=clean,
602
- inputs=[
603
- Seg_Tracker
604
- ],
605
  outputs=[
606
- Seg_Tracker, click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id
607
  ]
608
  )
609
 
@@ -624,12 +398,12 @@ def seg_track_app():
624
  )
625
 
626
  seg_acc_stroke.click(
627
- fn=sam_stroke,
628
  inputs=[
629
- Seg_Tracker, drawing_board, last_draw, frame_num, ann_obj_id
630
  ],
631
  outputs=[
632
- Seg_Tracker, input_first_frame, drawing_board, last_draw
633
  ]
634
  )
635
 
@@ -640,7 +414,8 @@ def seg_track_app():
640
  )
641
 
642
  app.queue(concurrency_count=1)
643
- app.launch(debug=True, enable_queue=True, share=False)
644
 
645
  if __name__ == "__main__":
646
- seg_track_app()
 
 
11
  # print("Command failed with return code:", result.returncode)
12
  import gc
13
  import math
14
+ # import multiprocessing as mp
15
+ import torch.multiprocessing as mp
16
  import os
17
+ from process_wrappers import clear_folder, draw_markers, sam_click_wrapper1, sam_stroke_process, tracking_objects_process
18
  os.environ["TORCH_CUDNN_SDPA_ENABLED"] = "1"
 
19
  import ffmpeg
 
 
 
 
 
 
 
 
 
20
  import cv2
21
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ def clean():
24
+ return ({}, {}, {}), None, None, 0, None, None, None, 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def show_res_by_slider(frame_per, click_stack):
27
  image_path = '/tmp/output_frames'
 
46
  print(f"{chosen_frame_path}")
47
  chosen_frame_show = cv2.imread(chosen_frame_path)
48
  chosen_frame_show = cv2.cvtColor(chosen_frame_show, cv2.COLOR_BGR2RGB)
49
+ points_dict, labels_dict, masks_dict = click_stack
50
  if frame_num in points_dict and frame_num in labels_dict:
51
  chosen_frame_show = draw_markers(chosen_frame_show, points_dict[frame_num], labels_dict[frame_num])
52
  return chosen_frame_show, chosen_frame_show, frame_num
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  def increment_ann_obj_id(ann_obj_id):
55
  ann_obj_id += 1
56
  return ann_obj_id
 
58
  def drawing_board_get_input_first_frame(input_first_frame):
59
  return input_first_frame
60
 
61
+ def sam_stroke_wrapper(click_stack, checkpoint, drawing_board, last_draw, frame_num, ann_obj_id):
62
+ queue = mp.Queue()
63
+ p = mp.Process(target=sam_stroke_process, args=(queue, click_stack, checkpoint, drawing_board, last_draw, frame_num, ann_obj_id))
64
+ p.start()
65
+ error, result = queue.get()
66
+ p.join()
67
+ if error:
68
+ raise Exception(f"Error in sam_stroke_process: {error}")
69
+ return result
70
+
71
+ def tracking_objects_wrapper(click_stack, checkpoint, frame_num, input_video):
72
+ queue = mp.Queue()
73
+ p = mp.Process(target=tracking_objects_process, args=(queue, click_stack, checkpoint, frame_num, input_video))
74
+ p.start()
75
+ error, result = queue.get()
76
+ p.join()
77
+ if error:
78
+ raise Exception(f"Error in sam_stroke_process: {error}")
79
+ return result
80
+
81
  def seg_track_app():
82
+ import gradio as gr
83
+
84
+ def sam_click_wrapper(checkpoint, frame_num, point_mode, click_stack, ann_obj_id, evt: gr.SelectData):
85
+ return sam_click_wrapper1(checkpoint, frame_num, point_mode, click_stack, ann_obj_id, [evt.index[0], evt.index[1]])
86
+
87
+ def change_video(input_video):
88
+ import gradio as gr
89
+ if input_video is None:
90
+ return 0, 0
91
+ cap = cv2.VideoCapture(input_video)
92
+ fps = cap.get(cv2.CAP_PROP_FPS)
93
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
94
+ cap.release()
95
+ scale_slider = gr.Slider.update(minimum=1.0,
96
+ maximum=fps,
97
+ step=1.0,
98
+ value=fps,)
99
+ frame_per = gr.Slider.update(minimum= 0.0,
100
+ maximum= total_frames / fps,
101
+ step=1.0/fps,
102
+ value=0.0,)
103
+ return scale_slider, frame_per
104
+
105
+ def get_meta_from_video(input_video, scale_slider):
106
+ import gradio as gr
107
+ output_dir = '/tmp/output_frames'
108
+ output_masks_dir = '/tmp/output_masks'
109
+ output_combined_dir = '/tmp/`output_combined`'
110
+ clear_folder(output_dir)
111
+ clear_folder(output_masks_dir)
112
+ clear_folder(output_combined_dir)
113
+ if input_video is None:
114
+ return ({}, {}, {}), None, None, 0, None, None, None, 0
115
+ cap = cv2.VideoCapture(input_video)
116
+ fps = cap.get(cv2.CAP_PROP_FPS)
117
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
118
+ cap.release()
119
+ frame_interval = max(1, int(fps // scale_slider))
120
+ print(f"frame_interval: {frame_interval}")
121
+ try:
122
+ ffmpeg.input(input_video, hwaccel='cuda').output(
123
+ os.path.join(output_dir, '%07d.jpg'), q=2, start_number=0,
124
+ vf=rf'select=not(mod(n\,{frame_interval}))', vsync='vfr'
125
+ ).run()
126
+ except:
127
+ print(f"ffmpeg cuda err")
128
+ ffmpeg.input(input_video).output(
129
+ os.path.join(output_dir, '%07d.jpg'), q=2, start_number=0,
130
+ vf=rf'select=not(mod(n\,{frame_interval}))', vsync='vfr'
131
+ ).run()
132
+
133
+ first_frame_path = os.path.join(output_dir, '0000000.jpg')
134
+ first_frame = cv2.imread(first_frame_path)
135
+ first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
136
+
137
+ frame_per = gr.Slider.update(minimum= 0.0,
138
+ maximum= total_frames / fps,
139
+ step=frame_interval / fps,
140
+ value=0.0,)
141
+ return ({}, {}, {}), first_frame_rgb, first_frame_rgb, frame_per, None, None, None, 0
142
 
143
  ##########################################################
144
  ###################### Front-end ########################
 
216
  '''
217
  )
218
 
219
+ click_stack = gr.State(({}, {}, {}))
 
220
  frame_num = gr.State(value=(int(0)))
221
  ann_obj_id = gr.State(value=(int(0)))
222
  last_draw = gr.State(None)
 
251
 
252
  tab_click = gr.Tab(label="Point Prompt")
253
  with tab_click:
254
+ input_first_frame = gr.Image(label='Segment result of first frame',interactive=True, height=550)
255
  with gr.Row():
256
  point_mode = gr.Radio(
257
  choices=["Positive", "Negative"],
 
326
  preprocess_button.click(
327
  fn=get_meta_from_video,
328
  inputs=[
 
329
  input_video,
330
  scale_slider,
 
331
  ],
332
  outputs=[
333
+ click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id
334
  ]
335
  )
336
 
337
  frame_per.release(
338
+ fn= show_res_by_slider,
339
  inputs=[
340
  frame_per, click_stack
341
  ],
 
346
 
347
  # Interactively modify the mask acc click
348
  input_first_frame.select(
349
+ fn=sam_click_wrapper,
350
  inputs=[
351
+ checkpoint, frame_num, point_mode, click_stack, ann_obj_id
352
  ],
353
  outputs=[
354
+ input_first_frame, drawing_board, click_stack
355
  ]
356
  )
357
 
358
  # Track object in video
359
  track_for_video.click(
360
+ fn=tracking_objects_wrapper,
361
  inputs=[
362
+ click_stack,
363
+ checkpoint,
364
  frame_num,
365
  input_video,
366
  ],
 
375
 
376
  reset_button.click(
377
  fn=clean,
378
+ inputs=[],
 
 
379
  outputs=[
380
+ click_stack, input_first_frame, drawing_board, frame_per, output_video, output_mp4, output_mask, ann_obj_id
381
  ]
382
  )
383
 
 
398
  )
399
 
400
  seg_acc_stroke.click(
401
+ fn=sam_stroke_wrapper,
402
  inputs=[
403
+ click_stack, checkpoint, drawing_board, last_draw, frame_num, ann_obj_id
404
  ],
405
  outputs=[
406
+ click_stack, input_first_frame, drawing_board, last_draw
407
  ]
408
  )
409
 
 
414
  )
415
 
416
  app.queue(concurrency_count=1)
417
+ app.launch(debug=True, share=False)
418
 
419
  if __name__ == "__main__":
420
+ mp.set_start_method('spawn', force=True)
421
+ seg_track_app()
docker-compose.yml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ services:
3
+ webuimedsam2:
4
+ init: true
5
+ restart: "always"
6
+ image: webuimedsam2
7
+ deploy:
8
+ resources:
9
+ reservations:
10
+ devices:
11
+ - driver: nvidia
12
+ capabilities: [gpu]
13
+ ports:
14
+ - "7860:7860"
15
+
16
+