Add segmentation and mask-generation functionality to the existing video processing pipeline.
app.py (CHANGED)
```diff
@@ -1,50 +1,101 @@
-import torch
-import time
 import uuid
-from typing import Tuple
+from typing import Tuple, List
 
 import gradio as gr
-import supervision as sv
 import numpy as np
-
-
+import supervision as sv
+import torch
 from PIL import Image
+from tqdm import tqdm
+from transformers import pipeline, CLIPModel, CLIPProcessor
+
+MARKDOWN = """
+# Auto ProPainter
+This is a demo for automatic removal of objects from videos using
+[Segment Anything Model](https://github.com/facebookresearch/segment-anything),
+[MetaCLIP](https://github.com/facebookresearch/MetaCLIP), and
+[ProPainter](https://github.com/sczhou/ProPainter) combo.
+"""
 
 START_FRAME = 0
 END_FRAME = 10
 TOTAL = END_FRAME - START_FRAME
+MINIMUM_AREA = 0.01
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 SAM_GENERATOR = pipeline(
     task="mask-generation",
-    model="facebook/sam-vit-
+    model="facebook/sam-vit-large",
     device=DEVICE)
-
-
-    color_lookup=sv.ColorLookup.INDEX)
+CLIP_MODEL = CLIPModel.from_pretrained("facebook/metaclip-b32-400m").to(DEVICE)
+CLIP_PROCESSOR = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")
 
 
 def run_sam(frame: np.ndarray) -> sv.Detections:
     # convert from Numpy BGR to PIL RGB
     image = Image.fromarray(frame[:, :, ::-1])
-
     outputs = SAM_GENERATOR(image)
     mask = np.array(outputs['masks'])
     return sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
 
 
+def run_clip(frame: np.ndarray, text: List[str]) -> np.ndarray:
+    # convert from Numpy BGR to PIL RGB
+    image = Image.fromarray(frame[:, :, ::-1])
+    inputs = CLIP_PROCESSOR(text=text, images=image, return_tensors="pt").to(DEVICE)
+    outputs = CLIP_MODEL(**inputs)
+    probs = outputs.logits_per_image.softmax(dim=1)
+    return probs.detach().cpu().numpy()
+
+
+def gray_background(image: np.ndarray, mask: np.ndarray, gray_value=128):
+    gray_color = np.array([gray_value, gray_value, gray_value], dtype=np.uint8)
+    return np.where(mask[..., None], image, gray_color)
+
+
+def filter_detections_by_area(frame: np.ndarray, detections: sv.Detections, minimum_area: float) -> sv.Detections:
+    frame_width, frame_height = frame.shape[1], frame.shape[0]
+    frame_area = frame_width * frame_height
+    return detections[detections.area > minimum_area * frame_area]
+
+
+def filter_detections_by_prompt(frame: np.ndarray, detections: sv.Detections, prompt: str, confidence: float) -> sv.Detections:
+    text = [f"a picture of {prompt}", "a picture of background"]
+    filtering_mask = []
+    for xyxy, mask in zip(detections.xyxy, detections.mask):
+        crop = gray_background(
+            image=sv.crop_image(image=frame, xyxy=xyxy),
+            mask=sv.crop_image(image=mask, xyxy=xyxy))
+        probs = run_clip(frame=crop, text=text)
+        filtering_mask.append(probs[0][0] > confidence)
+
+    return detections[np.array(filtering_mask)]
+
+
+def mask_frame(frame: np.ndarray, prompt: str, confidence: float) -> np.ndarray:
+    detections = run_sam(frame)
+    detections = filter_detections_by_area(
+        frame=frame, detections=detections, minimum_area=MINIMUM_AREA)
+    detections = filter_detections_by_prompt(
+        frame=frame, detections=detections, prompt=prompt, confidence=confidence)
+    # converting set of masks to a single mask
+    mask = np.any(detections.mask, axis=0).astype(np.uint8) * 255
+    # converting single channel mask to 3 channel mask
+    return np.repeat(mask[:, :, np.newaxis], 3, axis=2)
+
+
 def mask_video(source_video: str, prompt: str, confidence: float, name: str) -> str:
     video_info = sv.VideoInfo.from_video_path(source_video)
     frame_iterator = iter(sv.get_video_frames_generator(
         source_path=source_video, start=START_FRAME, end=END_FRAME))
 
-    with sv.
-
-
-
-
-
-
+    with sv.ImageSink(name, image_name_pattern="{:05d}.png") as image_sink:
+        with sv.VideoSink(f"{name}.mp4", video_info=video_info) as video_sink:
+            for _ in tqdm(range(TOTAL), desc="Masking frames"):
+                frame = next(frame_iterator)
+                annotated_frame = mask_frame(frame, prompt, confidence)
+                video_sink.write_frame(annotated_frame)
+                image_sink.save_image(annotated_frame)
     return f"{name}.mp4"
 
 
```
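A note on the new `mask_frame` pipeline above: after SAM proposes masks and the two filters prune them, the remaining per-object boolean masks are collapsed into a single white-on-black, 3-channel mask, the frame format that inpainting tools such as ProPainter expect. That reduction is easy to sanity-check without any models; a minimal sketch with made-up toy masks:

```python
import numpy as np

# Two toy boolean masks standing in for detections.mask (shape: N x H x W).
masks = np.zeros((2, 4, 4), dtype=bool)
masks[0, :2, :2] = True   # first object covers the top-left corner
masks[1, 2:, 2:] = True   # second object covers the bottom-right corner

# Union of all object masks -> one H x W mask, scaled to 0/255.
mask = np.any(masks, axis=0).astype(np.uint8) * 255

# Replicate to 3 channels so it can be written out as a video frame.
frame_mask = np.repeat(mask[:, :, np.newaxis], 3, axis=2)
print(frame_mask.shape)  # (4, 4, 3)
```

The second hunk is small by comparison: it only surfaces the new MARKDOWN description at the top of the Gradio UI.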
```diff
@@ -60,6 +111,7 @@ def process(
 
 
 with gr.Blocks() as demo:
+    gr.Markdown(MARKDOWN)
     with gr.Row():
         with gr.Column():
             source_video_player = gr.Video(
```
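For local debugging it can help to exercise the new entry point outside the UI. A hypothetical smoke test, assuming the file above is saved as `app.py` and a short clip exists at `input.mp4`; the module name, video path, prompt, and threshold are all illustrative, and if `app.py` launches the demo at import time, `mask_video` and its helpers should be moved into a separate module first:

```python
from app import mask_video  # assumes the code in the diff is importable

# Masks everything matching the prompt in the first TOTAL frames, writing
# per-frame PNGs into person-mask/ and the mask video to person-mask.mp4.
masked_path = mask_video(
    source_video="input.mp4",   # any short local clip (assumption)
    prompt="person",            # open-vocabulary description of the target
    confidence=0.6,             # CLIP probability threshold for keeping a mask
    name="person-mask")         # output stem for the ImageSink dir and the .mp4
print(masked_path)              # -> "person-mask.mp4"
```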
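The confidence value exposed in the UI maps directly onto the probability MetaCLIP assigns to the prompt caption versus the generic background caption in `filter_detections_by_prompt`. A standalone sketch of that scoring step, mirroring `run_clip` from the diff on a synthetic gray crop (the checkpoint downloads on first use; the dog caption is illustrative):

```python
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("facebook/metaclip-b32-400m").to(device)
processor = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")

# A flat gray image stands in for a gray_background() crop of one detection.
image = Image.new("RGB", (224, 224), (128, 128, 128))
text = ["a picture of a dog", "a picture of background"]

inputs = processor(text=text, images=image, return_tensors="pt").to(device)
with torch.no_grad():
    probs = model(**inputs).logits_per_image.softmax(dim=1)

# probs[0][0] is the value compared against the confidence threshold.
print(probs.cpu().numpy())
```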