SkalskiP committed
Commit c34b7e0 • 1 Parent(s): b643479

Enhances the existing video processing pipeline with segmentation and mask generation functionality.
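In outline: SAM proposes masks for every frame, masks below a minimum relative area are dropped, MetaCLIP keeps only the masks whose crop matches the text prompt, and the surviving masks are merged into a single binary frame written out both as an .mp4 and as per-frame PNGs. A minimal driver sketch for the mask_video function defined in the diff below; the input path, prompt, and threshold are hypothetical placeholders:

    # Hypothetical usage of mask_video from this commit's app.py.
    import uuid

    name = str(uuid.uuid4())  # unique stem for the output video and PNG directory
    result = mask_video(
        source_video="input.mp4",  # hypothetical local clip
        prompt="dog",              # object to mask
        confidence=0.6,            # placeholder CLIP probability threshold
        name=name)
    print(result)                  # f"{name}.mp4"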

Files changed (1)
app.py  +70 -18
app.py CHANGED
@@ -1,50 +1,101 @@
-import torch
-import time
 import uuid
-from typing import Tuple
+from typing import Tuple, List
 
 import gradio as gr
-import supervision as sv
 import numpy as np
-from tqdm import tqdm
-from transformers import pipeline
+import supervision as sv
+import torch
 from PIL import Image
+from tqdm import tqdm
+from transformers import pipeline, CLIPModel, CLIPProcessor
+
+MARKDOWN = """
+# Auto ProPainter
+This is a demo for automatic removal of objects from videos using
+[Segment Anything Model](https://github.com/facebookresearch/segment-anything),
+[MetaCLIP](https://github.com/facebookresearch/MetaCLIP), and
+[ProPainter](https://github.com/sczhou/ProPainter) combo.
+"""
 
 START_FRAME = 0
 END_FRAME = 10
 TOTAL = END_FRAME - START_FRAME
+MINIMUM_AREA = 0.01
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 SAM_GENERATOR = pipeline(
     task="mask-generation",
-    model="facebook/sam-vit-base",
+    model="facebook/sam-vit-large",
     device=DEVICE)
-MASK_ANNOTATOR = sv.MaskAnnotator(
-    color=sv.Color.red(),
-    color_lookup=sv.ColorLookup.INDEX)
+CLIP_MODEL = CLIPModel.from_pretrained("facebook/metaclip-b32-400m").to(DEVICE)
+CLIP_PROCESSOR = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")
 
 
 def run_sam(frame: np.ndarray) -> sv.Detections:
     # convert from Numpy BGR to PIL RGB
     image = Image.fromarray(frame[:, :, ::-1])
-
     outputs = SAM_GENERATOR(image)
     mask = np.array(outputs['masks'])
     return sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
 
 
+def run_clip(frame: np.ndarray, text: List[str]) -> np.ndarray:
+    # convert from Numpy BGR to PIL RGB
+    image = Image.fromarray(frame[:, :, ::-1])
+    inputs = CLIP_PROCESSOR(text=text, images=image, return_tensors="pt").to(DEVICE)
+    outputs = CLIP_MODEL(**inputs)
+    probs = outputs.logits_per_image.softmax(dim=1)
+    return probs.detach().cpu().numpy()
+
+
+def gray_background(image: np.ndarray, mask: np.ndarray, gray_value=128):
+    gray_color = np.array([gray_value, gray_value, gray_value], dtype=np.uint8)
+    return np.where(mask[..., None], image, gray_color)
+
+
+def filter_detections_by_area(frame: np.ndarray, detections: sv.Detections, minimum_area: float) -> sv.Detections:
+    frame_width, frame_height = frame.shape[1], frame.shape[0]
+    frame_area = frame_width * frame_height
+    return detections[detections.area > minimum_area * frame_area]
+
+
+def filter_detections_by_prompt(frame: np.ndarray, detections: sv.Detections, prompt: str, confidence: float) -> sv.Detections:
+    text = [f"a picture of {prompt}", "a picture of background"]
+    filtering_mask = []
+    for xyxy, mask in zip(detections.xyxy, detections.mask):
+        crop = gray_background(
+            image=sv.crop_image(image=frame, xyxy=xyxy),
+            mask=sv.crop_image(image=mask, xyxy=xyxy))
+        probs = run_clip(frame=crop, text=text)
+        filtering_mask.append(probs[0][0] > confidence)
+
+    return detections[np.array(filtering_mask)]
+
+
+def mask_frame(frame: np.ndarray, prompt: str, confidence: float) -> np.ndarray:
+    detections = run_sam(frame)
+    detections = filter_detections_by_area(
+        frame=frame, detections=detections, minimum_area=MINIMUM_AREA)
+    detections = filter_detections_by_prompt(
+        frame=frame, detections=detections, prompt=prompt, confidence=confidence)
+    # converting set of masks to a single mask
+    mask = np.any(detections.mask, axis=0).astype(np.uint8) * 255
+    # converting single channel mask to 3 channel mask
+    return np.repeat(mask[:, :, np.newaxis], 3, axis=2)
+
+
 def mask_video(source_video: str, prompt: str, confidence: float, name: str) -> str:
     video_info = sv.VideoInfo.from_video_path(source_video)
     frame_iterator = iter(sv.get_video_frames_generator(
         source_path=source_video, start=START_FRAME, end=END_FRAME))
 
-    with sv.VideoSink(f"{name}.mp4", video_info=video_info) as sink:
-        for _ in tqdm(range(TOTAL), desc="Masking frames"):
-            frame = next(frame_iterator)
-            detections = run_sam(frame)
-            annotated_frame = MASK_ANNOTATOR.annotate(
-                scene=frame.copy(), detections=detections)
-            sink.write_frame(annotated_frame)
+    with sv.ImageSink(name, image_name_pattern="{:05d}.png") as image_sink:
+        with sv.VideoSink(f"{name}.mp4", video_info=video_info) as video_sink:
+            for _ in tqdm(range(TOTAL), desc="Masking frames"):
+                frame = next(frame_iterator)
+                annotated_frame = mask_frame(frame, prompt, confidence)
+                video_sink.write_frame(annotated_frame)
+                image_sink.save_image(annotated_frame)
     return f"{name}.mp4"
 
 
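One detail that is easy to miss in the hunk above: mask_frame collapses however many per-object masks survive filtering into one binary image with np.any, then repeats it across three channels so the video sink accepts it. A tiny self-contained check of that merge, with illustrative shapes:

    import numpy as np

    # two illustrative 4x4 boolean masks, shaped (N, H, W) like SAM output
    masks = np.zeros((2, 4, 4), dtype=bool)
    masks[0, :2, :2] = True   # first object, top-left
    masks[1, 2:, 2:] = True   # second object, bottom-right

    # same merge as mask_frame: union across objects, scaled to 0/255
    merged = np.any(masks, axis=0).astype(np.uint8) * 255
    frame = np.repeat(merged[:, :, np.newaxis], 3, axis=2)

    assert frame.shape == (4, 4, 3)
    assert frame[0, 0].tolist() == [255, 255, 255]  # inside the first mask
    assert frame[1, 3].tolist() == [0, 0, 0]        # inside neither mask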
 
@@ -60,6 +111,7 @@ def process(
 
 
 with gr.Blocks() as demo:
+    gr.Markdown(MARKDOWN)
     with gr.Row():
         with gr.Column():
             source_video_player = gr.Video(
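The filtering idea in filter_detections_by_prompt is worth calling out: each candidate mask is cropped from the frame, everything outside the mask is flattened to gray, and MetaCLIP scores the crop against two captions, the user's prompt versus "background". A standalone sketch of that zero-shot test on a single crop, assuming the same checkpoint the commit pins; crop.png, the "dog" caption, and the 0.6 threshold are placeholders:

    import torch
    from PIL import Image
    from transformers import CLIPModel, CLIPProcessor

    model = CLIPModel.from_pretrained("facebook/metaclip-b32-400m")
    processor = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")

    crop = Image.open("crop.png")  # placeholder: a gray-backgrounded SAM crop
    text = ["a picture of dog", "a picture of background"]

    with torch.no_grad():
        inputs = processor(text=text, images=crop, return_tensors="pt")
        probs = model(**inputs).logits_per_image.softmax(dim=1)

    # keep the mask only if the prompt caption outscores "background"
    keep = probs[0, 0].item() > 0.6  # placeholder confidence threshold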