File size: 6,234 Bytes
21ce843 5f16544 3789374 5f16544 21ce843 5f16544 ce0031d 5f16544 21ce843 8a19b0b 71a75ca 8a19b0b 38df08a 8a19b0b 21ce843 8a19b0b 5f16544 8a19b0b 21ce843 e752590 8a19b0b 21ce843 e752590 025ec32 e752590 21ce843 8a19b0b f4e3f66 5f16544 21ce843 f4e3f66 5f16544 025ec32 21ce843 ce0031d 8a19b0b 21ce843 ce0031d e752590 21ce843 f4e3f66 21ce843 e752590 21ce843 e752590 8a19b0b 2c10ed3 8a19b0b e752590 8a19b0b e752590 8a19b0b e752590 86551c8 e752590 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
#!/usr/bin/env python
# Gradio demo app for MiVOLO age & gender estimation.
import os
import shlex
import subprocess

# On Hugging Face Spaces the SYSTEM env var is 'spaces': install the MiVOLO
# package from GitHub at startup, since it is not bundled with the Space.
if os.getenv('SYSTEM') == 'spaces':
    git_repo = "https://github.com/WildChlamydia/MiVOLO.git"
    subprocess.call(shlex.split(f'pip install git+{git_repo}'))

import pathlib
import os  # NOTE(review): duplicate of the import above — redundant but harmless
import gradio as gr
import huggingface_hub
import numpy as np
import functools
from dataclasses import dataclass

# Imported after the conditional install above so the package is importable.
from mivolo.predictor import Predictor
@dataclass
class Cfg:
    """Configuration object consumed by mivolo.predictor.Predictor."""
    detector_weights: str  # path to the YOLOv8 person/face detector weights (.pt)
    checkpoint: str  # path to the MiVOLO age/gender checkpoint
    device: str = "cpu"  # inference device; CPU by default for this demo
    with_persons: bool = True  # use person crops in addition to face crops
    disable_faces: bool = False  # if True, ignore face crops
    draw: bool = True  # draw detections/labels onto the output image
# Markdown rendered at the top of the demo UI.
DESCRIPTION = """
# MiVOLO: Multi-input Transformer for Age and Gender Estimation
This is an official demo for https://github.com/WildChlamydia/MiVOLO.\n
Telegram channel: https://t.me/+K0i2fLGpVKBjNzUy (Russian language)
"""
# Hugging Face auth token (may be None); forwarded to hf_hub_download below.
HF_TOKEN = os.getenv('HF_TOKEN')
def load_models():
    """Download model weights from the Hugging Face Hub and build both predictors.

    Returns:
        tuple: ``(predictor_v1, predictor_v2)`` — MiVOLO predictors sharing the
        same YOLOv8 person/face detector weights but using different
        age/gender checkpoints (v1 and v2).
    """
    def _download(repo_id: str, filename: str) -> str:
        # One place for the repeated download call; HF_TOKEN may be None for
        # public repos and is forwarded for gated ones.
        return huggingface_hub.hf_hub_download(repo_id, filename,
                                               use_auth_token=HF_TOKEN)

    detector_path = _download('iitolstykh/demo_yolov8_detector', 'yolov8x_person_face.pt')
    age_gender_path_v1 = _download('iitolstykh/demo_xnet_volo_cross', 'checkpoint-377.pth.tar')
    age_gender_path_v2 = _download('iitolstykh/demo_xnet_volo_cross', 'mivolo_v2_1.tar')

    predictor_v1 = Predictor(Cfg(detector_path, age_gender_path_v1))
    predictor_v2 = Predictor(Cfg(detector_path, age_gender_path_v2))
    return predictor_v1, predictor_v2
def detect(
    image: np.ndarray,
    score_threshold: float,
    iou_threshold: float,
    mode: str,
    predictor: "Predictor"
) -> np.ndarray:
    """Run detection + age/gender recognition on an RGB image.

    Args:
        image: input image in RGB channel order (as supplied by Gradio).
        score_threshold: detector confidence threshold ('conf').
        iou_threshold: detector NMS IoU threshold ('iou').
        mode: one of "Use persons and faces", "Use persons only",
            "Use faces only".
        predictor: MiVOLO Predictor whose detector/model settings are mutated
            in place before inference.

    Returns:
        Annotated output image in RGB channel order.

    Raises:
        ValueError: if ``mode`` is not one of the three supported options.
            (Previously an unknown mode left the flags unbound -> NameError.)
    """
    predictor.detector.detector_kwargs['conf'] = score_threshold
    predictor.detector.detector_kwargs['iou'] = iou_threshold

    if mode == "Use persons and faces":
        use_persons, disable_faces = True, False
    elif mode == "Use persons only":
        use_persons, disable_faces = True, True
    elif mode == "Use faces only":
        use_persons, disable_faces = False, False
    else:
        raise ValueError(f"Unknown inference mode: {mode!r}")

    predictor.age_gender_model.meta.use_persons = use_persons
    predictor.age_gender_model.meta.disable_faces = disable_faces

    # The predictor works in BGR; Gradio supplies and expects RGB.
    bgr_image = image[:, :, ::-1]
    _detected_objects, out_im = predictor.recognize(bgr_image)
    return out_im[:, :, ::-1]  # BGR -> RGB
def clear():
    """Reset all demo widgets: image, thresholds, mode, and result output."""
    defaults = (None, 0.4, 0.7, "Use persons and faces", None)
    return defaults
# Build both predictors once at import time so all requests reuse them.
predictor_v1, predictor_v2 = load_models()
# Bind each predictor into the shared detect() handler for Gradio callbacks.
prediction_func_v1 = functools.partial(detect, predictor=predictor_v1)
prediction_func_v2 = functools.partial(detect, predictor=predictor_v2)
# Example gallery: every .jpg under images/ with the default thresholds/mode.
image_dir = pathlib.Path('images')
examples = [[path.as_posix(), 0.4, 0.7, "Use persons and faces"] for path in sorted(image_dir.glob('*.jpg'))]
def _build_demo(prediction_func):
    """Build one demo Blocks app wired to *prediction_func*.

    The V1 and V2 demos were two byte-identical copies of this layout that
    differed only in the bound inference function; this factory removes the
    duplication. Layout: input image + threshold sliders + mode radio on the
    left, annotated result image on the right.
    """
    with gr.Blocks(theme=gr.themes.Default(), css="style.css") as demo_block:
        with gr.Row():
            with gr.Column():
                image = gr.Image(label='Input', type='numpy')
                score_threshold = gr.Slider(0, 1, value=0.4, step=0.05, label='Detector Score Threshold')
                iou_threshold = gr.Slider(0, 1, value=0.7, step=0.05, label='NMS Iou Threshold')
                mode = gr.Radio(["Use persons and faces", "Use persons only", "Use faces only"],
                                value="Use persons and faces",
                                label="Inference mode",
                                info="What to use for gender and age recognition")
                with gr.Row():
                    clear_button = gr.Button("Clear")
                    with gr.Column():
                        run_button = gr.Button("Submit", variant="primary")
            with gr.Column():
                result = gr.Image(label='Output', type='numpy')
        inputs = [image, score_threshold, iou_threshold, mode]
        gr.Examples(examples=examples,
                    inputs=inputs,
                    outputs=result,
                    fn=prediction_func,
                    cache_examples=False)
        run_button.click(fn=prediction_func, inputs=inputs, outputs=result, api_name='predict')
        clear_button.click(fn=clear, inputs=None, outputs=[image, score_threshold, iou_threshold, mode, result])
    return demo_block


# Same layout for both model versions; only the inference function differs.
demo_v1 = _build_demo(prediction_func_v1)
demo_v2 = _build_demo(prediction_func_v2)
# Top-level app: description banner plus one tab per model version.
with gr.Blocks(theme=gr.themes.Default(), css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tabs():
        with gr.Tab(label="MiVOLO_V1"):
            demo_v1.render()
        # NOTE(review): the V2 tab is deliberately disabled here even though
        # demo_v2 is built above — confirm before re-enabling.
        # with gr.Tab(label="MiVOLO_V2"):
        #     demo_v2.render()

if __name__ == "__main__":
    # Queue caps pending requests at 15 before launching the server.
    demo.queue(max_size=15).launch()
|