Spaces:
Running
Running
import os | |
import cv2 | |
import imghdr | |
import shutil | |
import warnings | |
import numpy as np | |
import gradio as gr | |
from dataclasses import dataclass | |
from mivolo.predictor import Predictor | |
from utils import is_url, download_file, get_jpg_files, MODEL_DIR | |
# Scratch directory that infer() downloads URL images into.
# NOTE(review): this reuses the Python bytecode cache directory, and infer()
# removes it with shutil.rmtree before each download.
TMP_DIR: str = "./__pycache__"
@dataclass
class Cfg:
    """Settings handed to MiVOLO's ``Predictor`` (built in ValidImgDetector).

    ``@dataclass`` is required: callers construct this positionally as
    ``Cfg(detector_weights, checkpoint)``. Without the decorator the class
    has no such ``__init__``, and that call raises TypeError.
    """

    detector_weights: str  # path to the person/face detector weights file
    checkpoint: str  # path to the age/gender model checkpoint
    device: str = "cpu"
    with_persons: bool = True
    disable_faces: bool = False
    draw: bool = True  # presumably asks the predictor to render annotations - confirm
class ValidImgDetector:
    """Flag whether an image contains children, women and/or men using MiVOLO.

    A single ``Predictor`` is built in ``__init__`` from the detector and
    age/gender checkpoints stored under ``MODEL_DIR``.
    """

    # Set in __init__; kept as a class attribute so the name always exists.
    predictor = None

    # UI mode label -> (use_persons, disable_faces) for the age/gender model.
    _MODES = {
        "Use persons and faces": (True, False),
        "Use persons only": (True, True),
        "Use faces only": (False, False),
    }

    # Anyone whose estimated age is below this is reported as a child.
    _CHILD_AGE = 18

    def __init__(self):
        detector_path = f"{MODEL_DIR}/yolov8x_person_face.pt"
        age_gender_path = f"{MODEL_DIR}/model_imdb_cross_person_4.22_99.46.pth.tar"
        predictor_cfg = Cfg(detector_path, age_gender_path)
        self.predictor = Predictor(predictor_cfg)

    def _detect(
        self,
        image: np.ndarray,
        score_threshold: float,
        iou_threshold: float,
        mode: str,
        predictor: "Predictor",
    ) -> "tuple[np.ndarray, bool, bool, bool]":
        """Configure ``predictor``, run it on ``image`` and summarise the result.

        Args:
            image: Array passed unchanged to ``predictor.recognize``. The
                ``valid_img`` path supplies it straight from ``cv2.imread``,
                i.e. in BGR channel order.
            score_threshold: Value written to the detector's ``conf`` kwarg.
            iou_threshold: Value written to the detector's ``iou`` kwarg.
            mode: One of the keys of ``_MODES``. It selects whether persons
                and/or faces feed the age/gender model.
            predictor: The MiVOLO predictor to run. Its detector kwargs and
                age/gender meta flags are overwritten in place.

        Returns:
            ``(annotated_image, has_child, has_female, has_male)``. The
            annotated image has its channel order reversed for display
            (BGR -> RGB, assuming the predictor returns BGR like its input).

        Raises:
            ValueError: if ``mode`` is not a known mode. The predictor is left
                untouched in that case.
        """
        # Validate before mutating shared predictor state. An unknown mode
        # previously left locals unbound (UnboundLocalError) after the
        # thresholds had already been changed.
        try:
            use_persons, disable_faces = self._MODES[mode]
        except KeyError:
            raise ValueError(f"Unknown detection mode: {mode!r}") from None

        predictor.detector.detector_kwargs["conf"] = score_threshold
        predictor.detector.detector_kwargs["iou"] = iou_threshold
        predictor.age_gender_model.meta.use_persons = use_persons
        predictor.age_gender_model.meta.disable_faces = disable_faces

        detected_objects, out_im = predictor.recognize(image)

        has_child = has_female = has_male = False
        if len(detected_objects.ages) > 0:
            # Defensive: skip detections without an age estimate (None) so
            # min() cannot fail on mixed None/float entries.
            known_ages = [age for age in detected_objects.ages if age is not None]
            has_child = bool(known_ages) and min(known_ages) < self._CHILD_AGE
            has_female = "female" in detected_objects.genders
            has_male = "male" in detected_objects.genders

        return out_im[:, :, ::-1], has_child, has_female, has_male

    def valid_img(
        self,
        img_path: str,
        score_threshold: float = 0.4,
        iou_threshold: float = 0.7,
        mode: str = "Use persons and faces",
    ):
        """Load ``img_path`` with OpenCV and run ``_detect`` on it.

        The defaults reproduce the previously hard-coded settings.

        Returns:
            The 4-tuple produced by ``_detect``.

        Raises:
            ValueError: if OpenCV cannot read the file (``cv2.imread`` returned
                None), or if ``mode`` is unknown.
        """
        image = cv2.imread(img_path)
        if image is None:
            raise ValueError(f"Could not read image: {img_path}")
        return self._detect(image, score_threshold, iou_threshold, mode, self.predictor)
# Process-wide detector, built lazily by _get_detector().
_detector = None


def _get_detector():
    """Return the shared ValidImgDetector, constructing it on first use.

    Constructing a detector builds a MiVOLO ``Predictor`` from the weights in
    ``MODEL_DIR``. That is done once, not on every request.
    """
    global _detector
    if _detector is None:
        _detector = ValidImgDetector()
    return _detector


def infer(photo: str):
    """Run the child / female / male check on a local image path or an image URL.

    Args:
        photo: Path to a local image file, or a URL (as judged by ``is_url``).
            A URL is first downloaded to ``TMP_DIR/download.jpg``.

    Returns:
        ``(annotated_rgb_image, has_child, has_female, has_male)`` on success.
        If the input is empty, missing on disk, or not recognised as an image,
        returns ``(None, None, None, message)``. In the UI that message appears
        in the fourth ("Has Male") output box.
    """
    error = (None, None, None, "请正确输入图片 Please input image correctly")

    # Empty upload / empty textbox: reject before touching the URL helpers.
    if not photo:
        return error

    if is_url(photo):
        # Start from a clean scratch directory for every download.
        if os.path.exists(TMP_DIR):
            shutil.rmtree(TMP_DIR)
        photo = download_file(photo, f"{TMP_DIR}/download.jpg")

    # NOTE(review): imghdr is deprecated (PEP 594) and removed in Python 3.13.
    if not photo or not os.path.exists(photo) or imghdr.what(photo) is None:
        return error

    # The model is loaded only once, and only after the input has been validated.
    return _get_detector().valid_img(photo)
if __name__ == "__main__":

    def _result_outputs():
        """Create one tab's output components.

        The order matches the 4-tuple returned by ``infer``: the annotated
        image, then the child / female / male flags.
        """
        return [
            gr.Image(label="检测结果 Detection Result", type="numpy"),
            gr.Textbox(label="存在儿童 Has Child"),
            gr.Textbox(label="存在女性 Has Female"),
            gr.Textbox(label="存在男性 Has Male"),
        ]

    with gr.Blocks() as iface:
        warnings.filterwarnings("ignore")

        # Tab 1: photo uploaded from disk, with example images from MODEL_DIR/examples.
        with gr.Tab("上传模式 Upload Mode"):
            upload_input = gr.Image(label="上传照片 Upload Photo", type="filepath")
            gr.Interface(
                fn=infer,
                inputs=upload_input,
                outputs=_result_outputs(),
                examples=get_jpg_files(f"{MODEL_DIR}/examples"),
                allow_flagging="never",
            )

        # Tab 2: photo fetched from a URL typed by the user.
        with gr.Tab("在线模式 Online Mode"):
            url_input = gr.Textbox(label="网络图片链接 Online Picture URL")
            gr.Interface(
                fn=infer,
                inputs=url_input,
                outputs=_result_outputs(),
                allow_flagging="never",
                cache_examples=False,
            )

    iface.launch()