import gradio as gr
import numpy as np
import subprocess
import supervision as sv
import torch
import uuid
from PIL import Image
from tqdm import tqdm
from transformers import pipeline, CLIPModel, CLIPProcessor
from typing import Tuple, List

MARKDOWN = """
# Auto ⚡ ProPainter 🧑‍🎨

This is a demo for automatic removal of objects from videos using a
[Segment Anything Model](https://github.com/facebookresearch/segment-anything),
[MetaCLIP](https://github.com/facebookresearch/MetaCLIP), and
[ProPainter](https://github.com/sczhou/ProPainter) combo.

- [x] Automated object masking using SAM + MetaCLIP
- [x] Automated inpainting using ProPainter
- [ ] Automated ⚡ object masking using FastSAM + MetaCLIP
"""
EXAMPLES = [
    ["https://media.roboflow.com/supervision/video-examples/ball-juggling.mp4", "person", 0.6]
]

# only frames START_FRAME..END_FRAME of the source video are processed
START_FRAME = 0
END_FRAME = 10
TOTAL = END_FRAME - START_FRAME
# masks smaller than this fraction of the frame area are discarded
MINIMUM_AREA = 0.01
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

SAM_GENERATOR = pipeline(
    task="mask-generation",
    model="facebook/sam-vit-large",
    device=DEVICE)
CLIP_MODEL = CLIPModel.from_pretrained("facebook/metaclip-b32-400m").to(DEVICE)
CLIP_PROCESSOR = CLIPProcessor.from_pretrained("facebook/metaclip-b32-400m")
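
# "metaclip-b32-400m" is the ViT-B/32 MetaCLIP checkpoint trained on the 400M
# image-text data pool; any MetaCLIP checkpoint exposed through transformers'
# CLIPModel should be a drop-in replacement here.
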
def run_sam(frame: np.ndarray) -> sv.Detections:
    # convert from NumPy BGR to PIL RGB
    image = Image.fromarray(frame[:, :, ::-1])
    outputs = SAM_GENERATOR(image)
    mask = np.array(outputs['masks'])
    return sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
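
# run_sam notes: the mask-generation pipeline returns a dict whose "masks"
# entry is a list of H x W boolean arrays, so `mask` stacks to shape (N, H, W)
# and sv.mask_to_xyxy derives one (x_min, y_min, x_max, y_max) box per mask.
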
def run_clip(frame: np.ndarray, text: List[str]) -> np.ndarray:
    # convert from NumPy BGR to PIL RGB
    image = Image.fromarray(frame[:, :, ::-1])
    inputs = CLIP_PROCESSOR(text=text, images=image, return_tensors="pt").to(DEVICE)
    outputs = CLIP_MODEL(**inputs)
    probs = outputs.logits_per_image.softmax(dim=1)
    return probs.detach().cpu().numpy()
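
# run_clip example (illustrative values): for a single image and two prompts
# the returned array has shape (1, 2), e.g.
#   run_clip(crop, ["a picture of person", "a picture of background"])
# might yield [[0.93, 0.07]]; the probabilities sum to 1 across the prompts.
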
def gray_background(image: np.ndarray, mask: np.ndarray, gray_value: int = 128) -> np.ndarray:
    gray_color = np.array([gray_value, gray_value, gray_value], dtype=np.uint8)
    return np.where(mask[..., None], image, gray_color)
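
# gray_background fills everything outside the mask with neutral gray, so the
# CLIP score computed in filter_detections_by_prompt reflects the segmented
# object itself rather than its surroundings.
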
def filter_detections_by_area(frame: np.ndarray, detections: sv.Detections, minimum_area: float) -> sv.Detections:
    frame_height, frame_width = frame.shape[:2]
    frame_area = frame_width * frame_height
    return detections[detections.area > minimum_area * frame_area]
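
# Worked example: at 1280x720 the frame area is 921,600 px, so with
# MINIMUM_AREA = 0.01 any mask covering fewer than 9,216 px is discarded.
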
def filter_detections_by_prompt(frame: np.ndarray, detections: sv.Detections, prompt: str, confidence: float) -> sv.Detections:
    text = [f"a picture of {prompt}", "a picture of background"]
    filtering_mask = []
    for xyxy, mask in zip(detections.xyxy, detections.mask):
        crop = gray_background(
            image=sv.crop_image(image=frame, xyxy=xyxy),
            mask=sv.crop_image(image=mask, xyxy=xyxy))
        probs = run_clip(frame=crop, text=text)
        filtering_mask.append(probs[0][0] > confidence)
    return detections[np.array(filtering_mask)]
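
# Each candidate segment is scored as a binary zero-shot choice between the
# user prompt and "a picture of background"; probs[0][0] is the probability
# assigned to the prompt, thresholded by the Confidence slider value.
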
def mask_frame(frame: np.ndarray, prompt: str, confidence: float) -> np.ndarray:
    detections = run_sam(frame)
    detections = filter_detections_by_area(
        frame=frame, detections=detections, minimum_area=MINIMUM_AREA)
    detections = filter_detections_by_prompt(
        frame=frame, detections=detections, prompt=prompt, confidence=confidence)
    # collapse the set of per-object masks into a single mask
    mask = np.any(detections.mask, axis=0).astype(np.uint8) * 255
    # convert the single-channel mask to a 3-channel mask
    return np.repeat(mask[:, :, np.newaxis], 3, axis=2)
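
# mask_frame returns an H x W x 3 uint8 image that is white (255) wherever
# pixels should be removed and black elsewhere; these are the per-frame mask
# images handed to ProPainter via --mask below.
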
def mask_video(source_video: str, prompt: str, confidence: float, frames_dir: str, masked_frames_dir: str) -> Tuple[str, str]:
    frame_iterator = iter(sv.get_video_frames_generator(
        source_path=source_video, start=START_FRAME, end=END_FRAME))
    with sv.ImageSink(masked_frames_dir, image_name_pattern="{:05d}.png") as masked_frames_sink:
        with sv.ImageSink(frames_dir, image_name_pattern="{:05d}.jpg") as frames_sink:
            for _ in tqdm(range(TOTAL), desc="Masking frames"):
                frame = next(frame_iterator)
                frames_sink.save_image(frame)
                masked_frame = mask_frame(frame, prompt, confidence)
                masked_frames_sink.save_image(masked_frame)
    return frames_dir, masked_frames_dir
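
# Frames are written as JPEG while masks are written as PNG; the lossless
# format presumably keeps mask pixels strictly black-and-white instead of
# introducing compression artifacts around the mask edges.
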
def execute_command(command: List[str]) -> None:
    subprocess.run(command, check=True)

def paint_video(frames_dir: str, masked_frames_dir: str, results_dir: str) -> None:
    command = [
        "python",
        "inference_propainter.py",
        f"--video={frames_dir}",
        f"--mask={masked_frames_dir}",
        f"--output={results_dir}",
        "--save_fps=25"
    ]
    execute_command(command)
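
# The ProPainter script writes results into a subdirectory of --output named
# after the input directory ("frames" here), producing masked_in.mp4 (the
# masked input) and inpaint_out.mp4 (the inpainted result); the paths returned
# by process() below depend on that layout.
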
def process(
    source_video: str,
    prompt: str,
    confidence: float,
    progress=gr.Progress(track_tqdm=True)
) -> Tuple[str, str]:
    # each request works in its own directory keyed by a random UUID
    name = str(uuid.uuid4())
    frames_dir = f"{name}/frames"
    masked_frames_dir = f"{name}/masked_frames"
    results_dir = f"{name}/results"
    mask_video(source_video, prompt, confidence, frames_dir, masked_frames_dir)
    paint_video(frames_dir, masked_frames_dir, results_dir)
    return f"{name}/results/frames/masked_in.mp4", f"{name}/results/frames/inpaint_out.mp4"

with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
    with gr.Row():
        with gr.Column():
            source_video_player = gr.Video(
                label="Source video", source="upload", format="mp4")
            prompt_text = gr.Textbox(
                label="Prompt", value="person")
            confidence_slider = gr.Slider(
                label="Confidence", minimum=0.5, maximum=1.0, step=0.05, value=0.6)
            submit_button = gr.Button("Submit")
        with gr.Column():
            masked_video_player = gr.Video(label="Masked video")
            painted_video_player = gr.Video(label="Painted video")
    with gr.Row():
        gr.Examples(
            examples=EXAMPLES,
            fn=process,
            inputs=[source_video_player, prompt_text, confidence_slider],
            outputs=[masked_video_player, painted_video_player],
            cache_examples=False,
            run_on_click=True
        )
    submit_button.click(
        process,
        inputs=[source_video_player, prompt_text, confidence_slider],
        outputs=[masked_video_player, painted_video_player])

demo.queue().launch(debug=False, show_error=True)