|
import os |
|
import cv2 |
|
import imghdr |
|
import shutil |
|
import warnings |
|
import numpy as np |
|
import gradio as gr |
|
from dataclasses import dataclass |
|
from mivolo.predictor import Predictor |
|
from utils import is_url, download_file, get_jpg_files, MODEL_DIR |
|
|
|
# Scratch directory for images fetched from URLs: infer() deletes it (if it
# exists) before each download and saves the file as "<TMP_DIR>/download.jpg".
TMP_DIR = "./__pycache__"
|
|
|
|
|
@dataclass
class Cfg:
    """Settings bundle handed to ``mivolo.predictor.Predictor``.

    The first two fields are required and positional; the rest have
    defaults. How each flag is interpreted is decided inside mivolo, so the
    per-field notes below are hedged where this file does not show it.
    """

    # Path of the person/face detector weights file.
    detector_weights: str
    # Path of the age/gender model checkpoint file.
    checkpoint: str
    # Device string for inference; presumably a torch device name -- confirm
    # against mivolo.
    device: str = "cpu"
    # Presumably enables use of person (body) crops -- confirm in mivolo.
    with_persons: bool = True
    # Presumably turns face crops off when True -- confirm in mivolo.
    disable_faces: bool = False
    # Presumably asks the predictor to draw results on the returned image --
    # confirm in mivolo.
    draw: bool = True
|
|
|
|
|
class ValidImgDetector:
    """Run a MiVOLO ``Predictor`` on an image and summarise who is in it.

    The public entry point is :meth:`valid_img`; :meth:`_detect` does the
    actual work and can be driven with any predictor-like object.
    """

    # Overwritten with a real Predictor instance in __init__.
    predictor = None

    # Mode label -> (use_persons, disable_faces) flags that are written to
    # the age/gender model's ``meta`` before inference.
    _MODE_FLAGS = {
        "Use persons and faces": (True, False),
        "Use persons only": (True, True),
        "Use faces only": (False, False),
    }

    def __init__(self):
        """Build the predictor from the weight files under ``MODEL_DIR``."""
        detector_path = f"{MODEL_DIR}/yolov8x_person_face.pt"
        age_gender_path = f"{MODEL_DIR}/model_imdb_cross_person_4.22_99.46.pth.tar"
        predictor_cfg = Cfg(detector_path, age_gender_path)
        self.predictor = Predictor(predictor_cfg)

    def _detect(
        self,
        image: np.ndarray,
        score_threshold: float,
        iou_threshold: float,
        mode: str,
        predictor: "Predictor",
    ) -> "tuple[np.ndarray, bool, bool, bool]":
        """Configure ``predictor``, run it on ``image`` and summarise results.

        Args:
            image: Image array passed straight to ``predictor.recognize``.
            score_threshold: Stored as the detector's ``conf`` keyword.
            iou_threshold: Stored as the detector's ``iou`` keyword.
            mode: One of the keys of ``_MODE_FLAGS``.
            predictor: Object exposing ``detector.detector_kwargs``,
                ``age_gender_model.meta`` and ``recognize(image)``.

        Returns:
            ``(annotated_image, has_child, has_female, has_male)`` where the
            image is the predictor's output with its last axis reversed.

        Raises:
            ValueError: If ``mode`` is not a known mode label.
        """
        # Validate the mode first so an unknown label fails with a clear
        # error before any predictor state is modified.
        try:
            use_persons, disable_faces = self._MODE_FLAGS[mode]
        except KeyError:
            known = ", ".join(repr(label) for label in self._MODE_FLAGS)
            raise ValueError(
                f"Unknown mode {mode!r}; expected one of: {known}"
            ) from None

        predictor.detector.detector_kwargs["conf"] = score_threshold
        predictor.detector.detector_kwargs["iou"] = iou_threshold
        predictor.age_gender_model.meta.use_persons = use_persons
        predictor.age_gender_model.meta.disable_faces = disable_faces

        detected_objects, out_im = predictor.recognize(image)

        has_child, has_female, has_male = False, False, False
        if len(detected_objects.ages) > 0:
            # Entries may be None when no age was estimated for a detection
            # (TODO confirm against mivolo's result type); comparing None
            # with a number would raise TypeError, so skip them.
            has_child = any(
                age < 18 for age in detected_objects.ages if age is not None
            )
            has_female = "female" in detected_objects.genders
            has_male = "male" in detected_objects.genders

        # Reverse the channel axis; presumably BGR (OpenCV order) -> RGB for
        # display.
        return out_im[:, :, ::-1], has_child, has_female, has_male

    def valid_img(
        self,
        img_path,
        score_threshold: float = 0.4,
        iou_threshold: float = 0.7,
        mode: str = "Use persons and faces",
    ):
        """Load ``img_path`` with OpenCV and run detection on it.

        Args:
            img_path: Path of the image file to analyse.
            score_threshold: Detector confidence threshold.
            iou_threshold: Detector IoU threshold.
            mode: Mode label, see ``_MODE_FLAGS``.

        Returns:
            The tuple produced by :meth:`_detect`.

        Raises:
            ValueError: If OpenCV cannot read the file, or ``mode`` is
                unknown.
        """
        image = cv2.imread(img_path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise ValueError(f"Could not read image file: {img_path!r}")

        return self._detect(
            image, score_threshold, iou_threshold, mode, self.predictor
        )
|
|
|
|
|
def infer(photo: str):
    """Gradio handler: analyse a local image path or an image URL.

    Args:
        photo: Path of an image file, or a URL that is downloaded to
            ``TMP_DIR`` first. May be empty/None when nothing was supplied.

    Returns:
        For valid input, whatever ``ValidImgDetector.valid_img`` returns.
        For missing or non-image input,
        ``(None, None, None, <error message>)`` so the message shows up in
        the last output slot.
    """
    # Only ask is_url about non-empty input; the validation below already
    # treats empty input as an expected case.
    if photo and is_url(photo):
        # Start from a clean scratch directory for every download.
        if os.path.exists(TMP_DIR):
            shutil.rmtree(TMP_DIR)

        photo = download_file(photo, f"{TMP_DIR}/download.jpg")

    # NOTE: imghdr is deprecated since Python 3.11 and removed in 3.13
    # (PEP 594); this check needs a replacement before upgrading.
    if not photo or not os.path.exists(photo) or imghdr.what(photo) is None:
        return None, None, None, "请正确输入图片 Please input image correctly"

    # Build the detector (which loads the models) only once the input is
    # known to be usable, so rejected requests stay cheap.
    detector = ValidImgDetector()
    return detector.valid_img(photo)
|
|
|
|
|
if __name__ == "__main__":
    # Labels of the three text outputs, in the order infer() returns them
    # after the annotated image.
    flag_labels = ("存在儿童 Has Child", "存在女性 Has Female", "存在男性 Has Male")

    with gr.Blocks() as iface:
        warnings.filterwarnings("ignore")

        # Tab 1: the user uploads a picture; Gradio passes its file path.
        with gr.Tab("上传模式 Upload Mode"):
            upload_input = gr.Image(label="上传照片 Upload Photo", type="filepath")
            upload_outputs = [gr.Image(label="检测结果 Detection Result", type="numpy")]
            upload_outputs += [gr.Textbox(label=text) for text in flag_labels]
            example_files = get_jpg_files(f"{MODEL_DIR}/examples")
            gr.Interface(
                fn=infer,
                inputs=upload_input,
                outputs=upload_outputs,
                examples=example_files,
                allow_flagging="never",
            )

        # Tab 2: the user pastes a picture URL; infer() downloads it.
        with gr.Tab("在线模式 Online Mode"):
            url_input = gr.Textbox(label="网络图片链接 Online Picture URL")
            url_outputs = [gr.Image(label="检测结果 Detection Result", type="numpy")]
            url_outputs += [gr.Textbox(label=text) for text in flag_labels]
            gr.Interface(
                fn=infer,
                inputs=url_input,
                outputs=url_outputs,
                allow_flagging="never",
                cache_examples=False,
            )

    iface.launch()
|
|