MiVOLO / app.py
admin
fix cache example
e314b9a
raw
history blame
4 kB
import os
import cv2
import imghdr
import shutil
import warnings
import numpy as np
import gradio as gr
from dataclasses import dataclass
from mivolo.predictor import Predictor
from utils import is_url, download_file, get_jpg_files, MODEL_DIR
# Scratch directory for images fetched by URL in infer(): the directory is
# removed with shutil.rmtree before every URL download, and the image is then
# written to f"{TMP_DIR}/download.jpg".
# NOTE(review): "__pycache__" is also the directory name CPython uses for the
# bytecode caches of modules next to this script (if the app is started from
# its own directory), so those .pyc files are wiped along with the download.
# A dedicated directory name would be safer.
TMP_DIR = "./__pycache__"
@dataclass
class Cfg:
    """Configuration object passed to mivolo's Predictor.

    Only the two paths are required. ValidImgDetector constructs this with
    them positionally (detector weights first, checkpoint second), so the
    field order is part of the interface and must not change.
    """

    # Path to the detector weights (ValidImgDetector passes a .pt file
    # under MODEL_DIR).
    detector_weights: str
    # Path to the age/gender model checkpoint.
    checkpoint: str
    # Device selection string; defaults to "cpu".
    device: str = "cpu"
    # NOTE(review): the flags below are presumably consumed by Predictor; their
    # exact effect is defined in mivolo, not in this file.
    with_persons: bool = True
    disable_faces: bool = False
    draw: bool = True
class ValidImgDetector:
    """Wrap a mivolo Predictor and report which groups appear in an image.

    The predictor is configured with a detector-weights file and an
    age/gender checkpoint, both located under MODEL_DIR.
    """

    # Class-level default; __init__ replaces it with an instance attribute.
    predictor = None

    # Detection-mode label -> (use_persons, disable_faces), written onto the
    # age/gender model's meta before each recognition.
    _MODES = {
        "Use persons and faces": (True, False),
        "Use persons only": (True, True),
        "Use faces only": (False, False),
    }

    def __init__(self):
        """Build the Predictor from the two checkpoint files in MODEL_DIR."""
        detector_path = f"{MODEL_DIR}/yolov8x_person_face.pt"
        age_gender_path = f"{MODEL_DIR}/model_imdb_cross_person_4.22_99.46.pth.tar"
        predictor_cfg = Cfg(detector_path, age_gender_path)
        self.predictor = Predictor(predictor_cfg)

    def _detect(
        self,
        image: np.ndarray,
        score_threshold: float,
        iou_threshold: float,
        mode: str,
        predictor: "Predictor",
    ) -> "tuple[np.ndarray, bool, bool, bool]":
        """Run detection plus age/gender recognition on one image.

        Args:
            image: array handed unchanged to ``predictor.recognize``.
                valid_img passes the BGR array returned by ``cv2.imread``.
            score_threshold: stored as the detector's "conf" kwarg.
            iou_threshold: stored as the detector's "iou" kwarg.
            mode: one of the ``_MODES`` labels; selects whether persons
                and/or faces are used.
            predictor: Predictor whose detector kwargs and age/gender meta
                are reconfigured in place before recognition.

        Returns:
            ``(annotated_image, has_child, has_female, has_male)``. The
            annotated image has its channel axis reversed. ``has_child`` is
            True when any estimated age is below 18. All three flags are
            False when nothing is detected.

        Raises:
            ValueError: if ``mode`` is not a known label.
        """
        # Validate first: an unknown label previously fell through every
        # branch and crashed with UnboundLocalError.
        try:
            use_persons, disable_faces = self._MODES[mode]
        except KeyError:
            raise ValueError(f"unknown detection mode: {mode!r}") from None

        predictor.detector.detector_kwargs["conf"] = score_threshold
        predictor.detector.detector_kwargs["iou"] = iou_threshold
        predictor.age_gender_model.meta.use_persons = use_persons
        predictor.age_gender_model.meta.disable_faces = disable_faces

        detected_objects, out_im = predictor.recognize(image)

        has_child = has_female = has_male = False
        if len(detected_objects.ages) > 0:
            # NOTE(review): drop None entries in case some objects receive no
            # age estimate; min() over a list containing None raises TypeError.
            known_ages = [age for age in detected_objects.ages if age is not None]
            has_child = bool(known_ages) and min(known_ages) < 18
            has_female = "female" in detected_objects.genders
            has_male = "male" in detected_objects.genders

        # Reverse the channel axis. Presumably recognize() draws in OpenCV's
        # BGR order, so this yields RGB for the numpy gr.Image output.
        return out_im[:, :, ::-1], has_child, has_female, has_male

    def valid_img(
        self,
        img_path,
        score_threshold: float = 0.4,
        iou_threshold: float = 0.7,
        mode: str = "Use persons and faces",
    ):
        """Load ``img_path`` with OpenCV and run ``_detect`` on it.

        The keyword arguments default to the values previously hard-coded
        here, so existing ``valid_img(path)`` calls behave as before.

        Raises:
            ValueError: if OpenCV cannot decode the file, or ``mode`` is
                unknown.
        """
        image = cv2.imread(img_path)
        # cv2.imread reports failure (missing file, unsupported format such
        # as GIF) by returning None rather than raising.
        if image is None:
            raise ValueError(f"could not decode image: {img_path}")
        return self._detect(
            image, score_threshold, iou_threshold, mode, self.predictor
        )
# Process-wide detector, created lazily on first valid request.
_DETECTOR = None


def _get_detector():
    """Return the shared ValidImgDetector, constructing it on first use.

    ValidImgDetector.__init__ builds a Predictor from checkpoint files, so it
    is created once per process rather than once per request.
    """
    global _DETECTOR
    if _DETECTOR is None:
        _DETECTOR = ValidImgDetector()
    return _DETECTOR


def infer(photo: str):
    """Run the detector on an uploaded file path or an image URL.

    Args:
        photo: a local file path (upload tab) or a URL (online tab). May be
            None or empty when nothing was submitted.

    Returns:
        On valid input, the result of ``ValidImgDetector.valid_img``: the
        annotated image plus has-child / has-female / has-male flags. On
        invalid input, ``(None, None, None, message)``.
    """
    # `photo and` keeps None/empty input away from is_url.
    if photo and is_url(photo):
        # Clear the previous download; TMP_DIR only ever holds the latest one.
        if os.path.exists(TMP_DIR):
            shutil.rmtree(TMP_DIR)
        photo = download_file(photo, f"{TMP_DIR}/download.jpg")

    # Validate before touching the model so bad input never triggers model
    # loading. NOTE: imghdr is deprecated since Python 3.11 and removed in 3.13.
    if not photo or not os.path.exists(photo) or imghdr.what(photo) is None:
        return None, None, None, "请正确输入图片 Please input image correctly"

    return _get_detector().valid_img(photo)
if __name__ == "__main__":

    def _result_components():
        """Create the four output widgets shared by both tabs.

        They are the annotated image and one textbox per presence flag.
        """
        return [
            gr.Image(label="检测结果 Detection Result", type="numpy"),
            gr.Textbox(label="存在儿童 Has Child"),
            gr.Textbox(label="存在女性 Has Female"),
            gr.Textbox(label="存在男性 Has Male"),
        ]

    with gr.Blocks() as iface:
        warnings.filterwarnings("ignore")

        # Tab 1: local upload, handed to infer() as a file path, with the
        # bundled example images.
        with gr.Tab("上传模式 Upload Mode"):
            upload_input = gr.Image(label="上传照片 Upload Photo", type="filepath")
            upload_outputs = _result_components()
            upload_examples = get_jpg_files(f"{MODEL_DIR}/examples")
            gr.Interface(
                fn=infer,
                inputs=upload_input,
                outputs=upload_outputs,
                examples=upload_examples,
                allow_flagging="never",
                cache_examples=False,
            )

        # Tab 2: an image URL typed by the user; infer() downloads it first.
        with gr.Tab("在线模式 Online Mode"):
            url_input = gr.Textbox(label="网络图片链接 Online Picture URL")
            gr.Interface(
                fn=infer,
                inputs=url_input,
                outputs=_result_components(),
                allow_flagging="never",
            )

    iface.launch()