# Hugging Face Space listing header (page chrome captured with the file):
#   Author: Irina Tolstykh — commit "new example" (f14a122), 6.04 kB
#!/usr/bin/env python
import pathlib
import os
import gradio as gr
import huggingface_hub
import numpy as np
import functools
from dataclasses import dataclass
from mivolo.predictor import Predictor
@dataclass
class Cfg:
    """Configuration consumed by mivolo's ``Predictor`` at construction time.

    Defaults match the CPU-only demo deployment. Field semantics follow the
    MiVOLO predictor API — NOTE(review): confirm against mivolo docs.
    """

    detector_weights: str  # path to the YOLOv8 person/face detector checkpoint
    checkpoint: str  # path to the MiVOLO age/gender model checkpoint
    device: str = "cpu"  # inference device (this demo runs on CPU)
    with_persons: bool = True  # presumably: use person crops in addition to faces
    disable_faces: bool = False  # presumably: ignore face crops when True
    draw: bool = True  # draw detections onto the returned image
# Markdown rendered at the top of the demo page.
DESCRIPTION = """
# MiVOLO: Multi-input Transformer for Age and Gender Estimation
This is an official demo for https://github.com/WildChlamydia/MiVOLO.\n
Telegram channel: https://t.me/+K0i2fLGpVKBjNzUy (Russian language)
"""

# Hub token for downloading the model weights; may be None for public repos.
HF_TOKEN = os.getenv('HF_TOKEN')
def load_models():
    """Download all checkpoints from the Hugging Face Hub and build predictors.

    Returns:
        tuple: ``(predictor_v1, predictor_v2)`` — one ``Predictor`` per
        MiVOLO model version, sharing the same person/face detector weights.
    """

    def _download(repo_id: str, filename: str) -> str:
        # One place for the Hub call; `token=` replaces the deprecated
        # (and later removed) `use_auth_token=` argument of hf_hub_download.
        return huggingface_hub.hf_hub_download(repo_id, filename, token=HF_TOKEN)

    detector_path = _download('iitolstykh/demo_yolov8_detector', 'yolov8x_person_face.pt')
    age_gender_path_v1 = _download('iitolstykh/demo_xnet_volo_cross', 'checkpoint-377.pth.tar')
    age_gender_path_v2 = _download('iitolstykh/demo_xnet_volo_cross', 'mivolo_v2_384_0.15.pth.tar')

    predictor_v1 = Predictor(Cfg(detector_path, age_gender_path_v1))
    predictor_v2 = Predictor(Cfg(detector_path, age_gender_path_v2))
    return predictor_v1, predictor_v2
def detect(
    image: np.ndarray,
    score_threshold: float,
    iou_threshold: float,
    mode: str,
    predictor: "Predictor",
) -> np.ndarray:
    """Run detection plus age/gender recognition on a single RGB image.

    Args:
        image: input image, RGB, shape (H, W, 3).
        score_threshold: detector confidence threshold.
        iou_threshold: detector NMS IoU threshold.
        mode: one of the three UI radio labels selecting which crops to use.
        predictor: ``Predictor`` instance; mutated in place (thresholds and
            meta flags) before inference.

    Returns:
        The annotated output image, RGB.

    Raises:
        ValueError: if ``mode`` is not a recognized option. (The previous
            if/elif chain left the flags unbound and crashed with
            UnboundLocalError on an unexpected mode.)
    """
    # Map each UI mode to (use_persons, disable_faces).
    mode_flags = {
        "Use persons and faces": (True, False),
        "Use persons only": (True, True),
        "Use faces only": (False, False),
    }
    try:
        use_persons, disable_faces = mode_flags[mode]
    except KeyError:
        raise ValueError(f"Unknown inference mode: {mode!r}") from None

    predictor.detector.detector_kwargs['conf'] = score_threshold
    predictor.detector.detector_kwargs['iou'] = iou_threshold
    predictor.age_gender_model.meta.use_persons = use_persons
    predictor.age_gender_model.meta.disable_faces = disable_faces

    # The predictor works on BGR (OpenCV convention); the UI supplies and
    # expects RGB, so convert on the way in and back on the way out.
    detected_objects, out_im = predictor.recognize(image[:, :, ::-1])
    return out_im[:, :, ::-1]
def clear():
    """Reset the demo widgets to their initial values.

    Returns a 5-tuple matching the outputs wired to the Clear button:
    (image, score_threshold, iou_threshold, mode, result).
    """
    defaults = (None, 0.4, 0.7, "Use persons and faces", None)
    return defaults
# Build both predictors once at import time (downloads weights on first run).
predictor_v1, predictor_v2 = load_models()

# Bind each predictor into the (image, score, iou, mode) signature Gradio calls.
prediction_func_v1 = functools.partial(detect, predictor=predictor_v1)
prediction_func_v2 = functools.partial(detect, predictor=predictor_v2)

# One example row per bundled .jpg, pre-filled with the UI default settings.
image_dir = pathlib.Path('images')
examples = [[path.as_posix(), 0.4, 0.7, "Use persons and faces"] for path in sorted(image_dir.glob('*.jpg'))]
def _build_demo(prediction_func):
    """Build one tab's UI wired to ``prediction_func`` and return the Blocks.

    The v1 and v2 tabs previously duplicated this entire layout verbatim;
    only the prediction function differs, so the UI is constructed once here.
    """
    with gr.Blocks(theme=gr.themes.Default(), css="style.css") as demo_tab:
        with gr.Row():
            with gr.Column():
                image = gr.Image(label='Input', type='numpy')
                score_threshold = gr.Slider(0, 1, value=0.4, step=0.05, label='Detector Score Threshold')
                iou_threshold = gr.Slider(0, 1, value=0.7, step=0.05, label='NMS Iou Threshold')
                mode = gr.Radio(["Use persons and faces", "Use persons only", "Use faces only"],
                                value="Use persons and faces",
                                label="Inference mode",
                                info="What to use for gender and age recognition")
                with gr.Row():
                    clear_button = gr.Button("Clear")
                    with gr.Column():
                        run_button = gr.Button("Submit", variant="primary")
            with gr.Column():
                result = gr.Image(label='Output', type='numpy')

        inputs = [image, score_threshold, iou_threshold, mode]
        gr.Examples(examples=examples,
                    inputs=inputs,
                    outputs=result,
                    fn=prediction_func,
                    cache_examples=False)
        # NOTE(review): both tabs register api_name='predict' (as the original
        # code did); Gradio de-duplicates names when both render in one app.
        run_button.click(fn=prediction_func, inputs=inputs, outputs=result, api_name='predict')
        clear_button.click(fn=clear, inputs=None, outputs=[image, score_threshold, iou_threshold, mode, result])
    return demo_tab


demo_v1 = _build_demo(prediction_func_v1)
demo_v2 = _build_demo(prediction_func_v2)
# Top-level app: shared description, then one tab per MiVOLO model version.
with gr.Blocks(theme=gr.themes.Default(), css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tabs():
        with gr.Tab(label="MiVOLO_V1"):
            demo_v1.render()
        with gr.Tab(label="MiVOLO_V2"):
            demo_v2.render()
if __name__ == "__main__":
    # Cap pending requests at 15 so the CPU-only backend isn't overwhelmed.
    demo.queue(max_size=15).launch()