File size: 3,270 Bytes
21ce843
 
5f16544
 
 
 
 
 
 
 
 
 
21ce843
 
 
 
 
 
5f16544
 
 
 
 
 
 
 
 
 
 
 
 
21ce843
 
 
 
 
 
 
5f16544
 
21ce843
 
 
5f16544
 
 
21ce843
5f16544
 
025ec32
5f16544
21ce843
f4e3f66
 
 
 
 
 
 
5f16544
21ce843
f4e3f66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f16544
 
025ec32
21ce843
 
5f16544
21ce843
 
f4e3f66
21ce843
5f16544
21ce843
 
 
f4e3f66
 
1b018f5
 
f4e3f66
 
 
 
 
21ce843
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/usr/bin/env python

import os
import shlex
import subprocess

if os.getenv('SYSTEM') == 'spaces':
    # Running inside a Hugging Face Space: the private `xnet` package is not
    # on PyPI, so install it straight from GitHub before importing it below.
    GITHUB_TOKEN = os.getenv('GITHUB_TOKEN')
    GITHUB_USER = os.getenv('GITHUB_USER')
    # Token-authenticated clone URL (NOTE(review): the token is embedded in the
    # URL and may leak into pip's console output — confirm this is acceptable).
    git_repo = f"https://{GITHUB_TOKEN}@github.com/{GITHUB_USER}/xnet_demo.git"
    subprocess.call(shlex.split(f'pip install git+{git_repo}'))

import pathlib
import os
import gradio as gr
import huggingface_hub
import numpy as np
import functools
from dataclasses import dataclass

from xnet.predictor import Predictor


@dataclass
class Cfg:
    """Configuration object handed to ``xnet.predictor.Predictor``.

    Field names mirror what the Predictor constructor expects; defaults
    match the demo's "use persons and faces" mode with drawing enabled.
    """
    detector_weights: str        # path to the person/face detector weights file
    checkpoint: str              # path to the age/gender model checkpoint
    device: str = "cpu"          # inference device (presumably a torch device string)
    with_persons: bool = True    # include person boxes, not only faces
    disable_faces: bool = False  # if True, skip face crops entirely
    draw: bool = True            # draw detections onto the returned image


# UI strings shown at the top of the Gradio page.
TITLE = 'Age and Gender Estimation with Transformers from Face and Body Images in the Wild'
DESCRIPTION = 'This is an official demo for https://github.com/...'

# Optional Hugging Face access token; needed only if the weight repos are gated.
HF_TOKEN = os.getenv('HF_TOKEN')

def load_models():
    """Fetch model weights from the Hugging Face Hub and build a Predictor.

    Downloads (or reuses the cached copies of) the YOLOv8 person/face
    detector and the age/gender checkpoint, then wraps them in a Predictor.
    """
    # Person/face detector weights.
    detector_weights = huggingface_hub.hf_hub_download(
        'iitolstykh/demo_yolov8_detector',
        'yolov8x_person_face.pt',
        use_auth_token=HF_TOKEN,
    )

    # Age/gender recognition checkpoint.
    recognizer_ckpt = huggingface_hub.hf_hub_download(
        'iitolstykh/demo_xnet_volo_cross',
        'checkpoint-377.pth.tar',
        use_auth_token=HF_TOKEN,
    )

    return Predictor(Cfg(detector_weights, recognizer_ckpt))

def detect(
        image: np.ndarray,
        score_threshold: float,
        iou_threshold: float,
        mode: str,
        predictor: "Predictor"
) -> np.ndarray:
    """Run detection + age/gender recognition on one RGB image.

    Args:
        image: input image as an RGB ``np.ndarray`` (H, W, 3).
        score_threshold: detector confidence threshold.
        iou_threshold: NMS IoU threshold for the detector.
        mode: one of "Use persons and faces", "Use persons only",
            "Use faces only" (the Radio widget's choices).
        predictor: shared xnet Predictor instance (bound via functools.partial).

    Returns:
        The annotated image as an RGB ``np.ndarray``.

    Raises:
        ValueError: if ``mode`` is not one of the three known choices.
    """
    # Per-request detector tuning: these kwargs are forwarded to the detector.
    predictor.detector.detector_kwargs['conf'] = score_threshold
    predictor.detector.detector_kwargs['iou'] = iou_threshold

    # Map the UI mode string to (use_persons, disable_faces) flags.
    # The original if/elif chain left both names unbound for an unknown mode,
    # which surfaced later as a confusing UnboundLocalError; fail fast instead.
    mode_flags = {
        "Use persons and faces": (True, False),
        "Use persons only": (True, True),
        "Use faces only": (False, False),
    }
    try:
        use_persons, disable_faces = mode_flags[mode]
    except KeyError:
        raise ValueError(f"Unknown inference mode: {mode!r}") from None

    predictor.age_gender_model.meta.use_persons = use_persons
    predictor.age_gender_model.meta.disable_faces = disable_faces

    image = image[:, :, ::-1]  # RGB -> BGR (the predictor works in BGR)
    detected_objects, out_im = predictor.recognize(image)
    return out_im[:, :, ::-1]  # BGR -> RGB for Gradio


# Build the predictor once at import time (downloads weights on first run).
predictor = load_models()

# Example gallery: every *.jpg under ./images, paired with the default
# slider values and inference mode so each row matches the inputs list below.
image_dir = pathlib.Path('images')
examples = [[path.as_posix(), 0.4, 0.7, "Use persons and faces"] for path in sorted(image_dir.glob('*.jpg'))]

# Pre-bind the shared predictor so Gradio only supplies the four UI inputs.
func = functools.partial(detect, predictor=predictor)

# Assemble the Gradio UI: the input widgets are positional and must stay in
# the same order as detect()'s parameters (image, score, iou, mode).
gr.Interface(
    fn=func,
    inputs=[
        gr.Image(label='Input', type='numpy'),
        # Detector confidence threshold -> detect(score_threshold=...)
        gr.Slider(0, 1, value=0.4, step=0.05, label='Detector Score Threshold'),
        # Detector NMS IoU threshold -> detect(iou_threshold=...)
        gr.Slider(0, 1, value=0.7, step=0.05, label='NMS Iou Threshold'),
        gr.Radio(["Use persons and faces", "Use persons only", "Use faces only"],
                 value="Use persons and faces",
                 label="Inference mode",
                 info="What to use for gender and age recognition"),
    ],
    outputs=gr.Image(label='Output', type='numpy'),
    examples=examples,
    examples_per_page=30,
    title=TITLE,
    description=DESCRIPTION,
).launch(show_api=False)