|
import os |
|
import cv2 |
|
import imghdr |
|
import shutil |
|
import warnings |
|
import numpy as np |
|
import gradio as gr |
|
from dataclasses import dataclass |
|
from mivolo.predictor import Predictor |
|
from utils import is_url, download_file, get_jpg_files, MODEL_DIR |
|
|
|
# NOTE(review): temp downloads are staged in ./__pycache__ — this collides with
# Python's bytecode-cache directory, and infer() deletes it wholesale via
# shutil.rmtree; consider a dedicated tempfile.mkdtemp() directory instead.
TMP_DIR = "./__pycache__"
|
|
|
@dataclass
class Cfg:
    """Configuration object handed to ``mivolo.predictor.Predictor``.

    Field names appear to mirror the CLI arguments MiVOLO's Predictor
    expects — TODO confirm against Predictor's actual signature.
    """

    detector_weights: str  # path to YOLOv8 person/face detector weights (.pt)
    checkpoint: str  # path to the MiVOLO age/gender checkpoint (.pth.tar)
    device: str = "cpu"  # inference device, e.g. "cpu" or "cuda:0"
    with_persons: bool = True  # use person boxes in addition to face boxes
    disable_faces: bool = False  # if True, ignore face detections entirely
    draw: bool = True  # draw annotations onto the returned image
|
|
|
class ValidImgDetector:
    """Runs MiVOLO age/gender detection over single frames and whole videos.

    Loads a YOLOv8 person/face detector and a MiVOLO age/gender model from
    ``MODEL_DIR`` and reports, per frame, whether a child (age < 18), a
    female, or a male was detected.
    """

    # Class-level placeholder; replaced by a Predictor instance in __init__.
    predictor = None

    def __init__(self):
        # Weight files are expected to already exist under MODEL_DIR
        # (see utils.MODEL_DIR) — no download is attempted here.
        detector_path = f"{MODEL_DIR}/yolov8x_person_face.pt"
        age_gender_path = f"{MODEL_DIR}/model_imdb_cross_person_4.22_99.46.pth.tar"
        predictor_cfg = Cfg(detector_path, age_gender_path)
        self.predictor = Predictor(predictor_cfg)

    def _detect(
        self,
        image: np.ndarray,
        score_threshold: float,
        iou_threshold: float,
        mode: str,
        predictor: Predictor,
    ) -> tuple:
        """Run detection on a single image.

        Args:
            image: Input frame (OpenCV BGR layout — assumed from the
                channel reversal on return; TODO confirm against
                ``Predictor.recognize``).
            score_threshold: Detector confidence threshold.
            iou_threshold: NMS IoU threshold.
            mode: One of "Use persons and faces", "Use persons only",
                "Use faces only".
            predictor: The Predictor instance to run.

        Returns:
            Tuple ``(rgb_image, has_child, has_female, has_male)`` — the
            annotated frame with channels reversed plus three booleans.
            (The original annotation claimed ``np.ndarray``; it was wrong.)

        Raises:
            ValueError: If ``mode`` is not one of the recognized strings.
        """
        predictor.detector.detector_kwargs["conf"] = score_threshold
        predictor.detector.detector_kwargs["iou"] = iou_threshold

        if mode == "Use persons and faces":
            use_persons = True
            disable_faces = False
        elif mode == "Use persons only":
            use_persons = True
            disable_faces = True
        elif mode == "Use faces only":
            use_persons = False
            disable_faces = False
        else:
            # Bug fix: the original left use_persons/disable_faces unbound
            # for an unknown mode, raising UnboundLocalError two lines
            # later. Fail fast with a clear message instead.
            raise ValueError(f"Unknown detection mode: {mode!r}")

        predictor.age_gender_model.meta.use_persons = use_persons
        predictor.age_gender_model.meta.disable_faces = disable_faces
        detected_objects, out_im = predictor.recognize(image)

        has_child, has_female, has_male = False, False, False
        if len(detected_objects.ages) > 0:
            has_child = min(detected_objects.ages) < 18
            has_female = "female" in detected_objects.genders
            has_male = "male" in detected_objects.genders

        # Reverse channel order (BGR -> RGB) for display.
        return out_im[:, :, ::-1], has_child, has_female, has_male

    def valid_video(self, video_path):
        """Run detection on every decodable frame of a video.

        Args:
            video_path: Path to a video file readable by OpenCV.

        Returns:
            List of ``(rgb_frame, has_child, has_female, has_male)``
            tuples, one per frame; empty if the file cannot be opened.
        """
        cap = cv2.VideoCapture(video_path)
        results = []
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                results.append(
                    self._detect(frame, 0.4, 0.7, "Use persons and faces", self.predictor)
                )
        finally:
            # Robustness fix: always release the capture handle, even if
            # _detect raises mid-video (the original leaked it on error).
            cap.release()
        return results
|
|
|
def infer(video_path: str):
    """Validate a video (local path or URL) and return the first frame's result.

    Returns a 4-tuple ``(annotated_image, has_child, has_female, has_male)``
    for the first decoded frame, or ``(None, None, None, <error message>)``
    when the video is missing or yields no frames.
    """
    if is_url(video_path):
        # Start from a clean temp dir, then fetch the remote video into it.
        if os.path.exists(TMP_DIR):
            shutil.rmtree(TMP_DIR)
        video_path = download_file(video_path, f"{TMP_DIR}/download.mp4")

    detector = ValidImgDetector()

    if not video_path or not os.path.exists(video_path):
        return None, None, None, "Please input the video correctly"

    frame_results = detector.valid_video(video_path)
    if not frame_results:
        return None, None, None, "No frames detected in video."
    return frame_results[0]
|
|
|
if __name__ == "__main__":
    # Assemble the Gradio UI: a single tab wrapping `infer` behind an
    # upload widget, then serve it.
    with gr.Blocks() as demo:
        warnings.filterwarnings("ignore")
        with gr.Tab("Upload Mode"):
            gr.Interface(
                fn=infer,
                inputs=gr.Video(label="Upload Video"),
                outputs=[
                    gr.Image(label="Detection Result", type="numpy"),
                    gr.Textbox(label="Has Child"),
                    gr.Textbox(label="Has Female"),
                    gr.Textbox(label="Has Male"),
                ],
                allow_flagging="never",
            )

    demo.launch()
|
|