Spaces:
Running
Running
admin
commited on
Commit
•
4c4ff57
1
Parent(s):
2e57b80
sync
Browse files- .gitattributes +10 -11
- .gitignore +6 -0
- README.md +4 -3
- app.py +119 -0
- mivolo/data/data_reader.py +125 -0
- mivolo/data/dataset/__init__.py +64 -0
- mivolo/data/dataset/age_gender_dataset.py +194 -0
- mivolo/data/dataset/age_gender_loader.py +169 -0
- mivolo/data/dataset/classification_dataset.py +48 -0
- mivolo/data/dataset/reader_age_gender.py +490 -0
- mivolo/data/misc.py +264 -0
- mivolo/model/create_timm_model.py +107 -0
- mivolo/model/cross_bottleneck_attn.py +116 -0
- mivolo/model/mi_volo.py +229 -0
- mivolo/model/mivolo_model.py +402 -0
- mivolo/model/yolo_detector.py +48 -0
- mivolo/predictor.py +68 -0
- mivolo/structures.py +493 -0
- mivolo/version.py +1 -0
- requirements.txt +6 -0
- utils.py +60 -0
.gitattributes
CHANGED
@@ -1,35 +1,34 @@
|
|
1 |
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
|
|
4 |
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.
|
35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
5 |
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
|
|
6 |
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
|
|
11 |
*.model filter=lfs diff=lfs merge=lfs -text
|
12 |
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
13 |
*.onnx filter=lfs diff=lfs merge=lfs -text
|
14 |
*.ot filter=lfs diff=lfs merge=lfs -text
|
15 |
*.parquet filter=lfs diff=lfs merge=lfs -text
|
16 |
*.pb filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
17 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
18 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
19 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
|
|
20 |
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
21 |
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
|
|
22 |
*.tflite filter=lfs diff=lfs merge=lfs -text
|
23 |
*.tgz filter=lfs diff=lfs merge=lfs -text
|
|
|
24 |
*.xz filter=lfs diff=lfs merge=lfs -text
|
25 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tfevents* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.db* filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.ark* filter=lfs diff=lfs merge=lfs -text
|
30 |
+
**/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
|
31 |
+
**/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
|
32 |
+
**/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
mivolo/model/__pycache__/*
|
2 |
+
mivolo/data/__pycache__/*
|
3 |
+
mivolo/__pycache__/*
|
4 |
+
__pycache__/*
|
5 |
+
model/*
|
6 |
+
rename.sh
|
README.md
CHANGED
@@ -1,13 +1,14 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
colorFrom: indigo
|
5 |
colorTo: pink
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
|
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: Gender Age Detector
|
3 |
+
emoji: 👩🧑🦲
|
4 |
colorFrom: indigo
|
5 |
colorTo: pink
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.36.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
license: apache-2.0
|
11 |
+
arxiv: 2307.04616
|
12 |
---
|
13 |
|
14 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import cv2
|
3 |
+
import imghdr
|
4 |
+
import shutil
|
5 |
+
import warnings
|
6 |
+
import numpy as np
|
7 |
+
import gradio as gr
|
8 |
+
from dataclasses import dataclass
|
9 |
+
from mivolo.predictor import Predictor
|
10 |
+
from utils import is_url, download_file, get_jpg_files, MODEL_DIR
|
11 |
+
|
12 |
+
TMP_DIR = "./__pycache__"
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass
|
16 |
+
class Cfg:
|
17 |
+
detector_weights: str
|
18 |
+
checkpoint: str
|
19 |
+
device: str = "cpu"
|
20 |
+
with_persons: bool = True
|
21 |
+
disable_faces: bool = False
|
22 |
+
draw: bool = True
|
23 |
+
|
24 |
+
|
25 |
+
class ValidImgDetector:
|
26 |
+
predictor = None
|
27 |
+
|
28 |
+
def __init__(self):
|
29 |
+
detector_path = f"{MODEL_DIR}/yolov8x_person_face.pt"
|
30 |
+
age_gender_path = f"{MODEL_DIR}/model_imdb_cross_person_4.22_99.46.pth.tar"
|
31 |
+
predictor_cfg = Cfg(detector_path, age_gender_path)
|
32 |
+
self.predictor = Predictor(predictor_cfg)
|
33 |
+
|
34 |
+
def _detect(
|
35 |
+
self,
|
36 |
+
image: np.ndarray,
|
37 |
+
score_threshold: float,
|
38 |
+
iou_threshold: float,
|
39 |
+
mode: str,
|
40 |
+
predictor: Predictor,
|
41 |
+
) -> np.ndarray:
|
42 |
+
# input is rgb image, output must be rgb too
|
43 |
+
predictor.detector.detector_kwargs["conf"] = score_threshold
|
44 |
+
predictor.detector.detector_kwargs["iou"] = iou_threshold
|
45 |
+
if mode == "Use persons and faces":
|
46 |
+
use_persons = True
|
47 |
+
disable_faces = False
|
48 |
+
|
49 |
+
elif mode == "Use persons only":
|
50 |
+
use_persons = True
|
51 |
+
disable_faces = True
|
52 |
+
|
53 |
+
elif mode == "Use faces only":
|
54 |
+
use_persons = False
|
55 |
+
disable_faces = False
|
56 |
+
|
57 |
+
predictor.age_gender_model.meta.use_persons = use_persons
|
58 |
+
predictor.age_gender_model.meta.disable_faces = disable_faces
|
59 |
+
# image = image[:, :, ::-1] # RGB -> BGR
|
60 |
+
detected_objects, out_im = predictor.recognize(image)
|
61 |
+
has_child, has_female, has_male = False, False, False
|
62 |
+
if len(detected_objects.ages) > 0:
|
63 |
+
has_child = min(detected_objects.ages) < 18
|
64 |
+
has_female = "female" in detected_objects.genders
|
65 |
+
has_male = "male" in detected_objects.genders
|
66 |
+
|
67 |
+
return out_im[:, :, ::-1], has_child, has_female, has_male
|
68 |
+
|
69 |
+
def valid_img(self, img_path):
|
70 |
+
image = cv2.imread(img_path)
|
71 |
+
return self._detect(image, 0.4, 0.7, "Use persons and faces", self.predictor)
|
72 |
+
|
73 |
+
|
74 |
+
def infer(photo: str):
|
75 |
+
if is_url(photo):
|
76 |
+
if os.path.exists(TMP_DIR):
|
77 |
+
shutil.rmtree(TMP_DIR)
|
78 |
+
|
79 |
+
photo = download_file(photo, f"{TMP_DIR}/download.jpg")
|
80 |
+
|
81 |
+
detector = ValidImgDetector()
|
82 |
+
if not photo or not os.path.exists(photo) or imghdr.what(photo) == None:
|
83 |
+
return None, None, None, "请正确输入图片 Please input image correctly"
|
84 |
+
|
85 |
+
return detector.valid_img(photo)
|
86 |
+
|
87 |
+
|
88 |
+
if __name__ == "__main__":
|
89 |
+
with gr.Blocks() as iface:
|
90 |
+
warnings.filterwarnings("ignore")
|
91 |
+
with gr.Tab("上传模式 Upload Mode"):
|
92 |
+
gr.Interface(
|
93 |
+
fn=infer,
|
94 |
+
inputs=gr.Image(label="上传照片 Upload Photo", type="filepath"),
|
95 |
+
outputs=[
|
96 |
+
gr.Image(label="检测结果 Detection Result", type="numpy"),
|
97 |
+
gr.Textbox(label="存在儿童 Has Child"),
|
98 |
+
gr.Textbox(label="存在女性 Has Female"),
|
99 |
+
gr.Textbox(label="存在男性 Has Male"),
|
100 |
+
],
|
101 |
+
examples=get_jpg_files(f"{MODEL_DIR}/examples"),
|
102 |
+
allow_flagging="never",
|
103 |
+
)
|
104 |
+
|
105 |
+
with gr.Tab("在线模式 Online Mode"):
|
106 |
+
gr.Interface(
|
107 |
+
fn=infer,
|
108 |
+
inputs=gr.Textbox(label="网络图片链接 Online Picture URL"),
|
109 |
+
outputs=[
|
110 |
+
gr.Image(label="检测结果 Detection Result", type="numpy"),
|
111 |
+
gr.Textbox(label="存在儿童 Has Child"),
|
112 |
+
gr.Textbox(label="存在女性 Has Female"),
|
113 |
+
gr.Textbox(label="存在男性 Has Male"),
|
114 |
+
],
|
115 |
+
allow_flagging="never",
|
116 |
+
cache_examples=False,
|
117 |
+
)
|
118 |
+
|
119 |
+
iface.launch()
|
mivolo/data/data_reader.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from collections import defaultdict
|
3 |
+
from dataclasses import dataclass, field
|
4 |
+
from enum import Enum
|
5 |
+
from typing import Dict, List, Optional, Tuple
|
6 |
+
|
7 |
+
import pandas as pd
|
8 |
+
|
9 |
+
IMAGES_EXT: Tuple = (".jpeg", ".jpg", ".png", ".webp", ".bmp", ".gif")
|
10 |
+
VIDEO_EXT: Tuple = (".mp4", ".avi", ".mov", ".mkv", ".webm")
|
11 |
+
|
12 |
+
|
13 |
+
@dataclass
|
14 |
+
class PictureInfo:
|
15 |
+
image_path: str
|
16 |
+
age: Optional[str] # age or age range(start;end format) or "-1"
|
17 |
+
gender: Optional[str] # "M" of "F" or "-1"
|
18 |
+
bbox: List[int] = field(default_factory=lambda: [-1, -1, -1, -1]) # face bbox: xyxy
|
19 |
+
person_bbox: List[int] = field(default_factory=lambda: [-1, -1, -1, -1]) # person bbox: xyxy
|
20 |
+
|
21 |
+
@property
|
22 |
+
def has_person_bbox(self) -> bool:
|
23 |
+
return any(coord != -1 for coord in self.person_bbox)
|
24 |
+
|
25 |
+
@property
|
26 |
+
def has_face_bbox(self) -> bool:
|
27 |
+
return any(coord != -1 for coord in self.bbox)
|
28 |
+
|
29 |
+
def has_gt(self, only_age: bool = False) -> bool:
|
30 |
+
if only_age:
|
31 |
+
return self.age != "-1"
|
32 |
+
else:
|
33 |
+
return not (self.age == "-1" and self.gender == "-1")
|
34 |
+
|
35 |
+
def clear_person_bbox(self):
|
36 |
+
self.person_bbox = [-1, -1, -1, -1]
|
37 |
+
|
38 |
+
def clear_face_bbox(self):
|
39 |
+
self.bbox = [-1, -1, -1, -1]
|
40 |
+
|
41 |
+
|
42 |
+
class AnnotType(Enum):
|
43 |
+
ORIGINAL = "original"
|
44 |
+
PERSONS = "persons"
|
45 |
+
NONE = "none"
|
46 |
+
|
47 |
+
@classmethod
|
48 |
+
def _missing_(cls, value):
|
49 |
+
print(f"WARN: Unknown annotation type {value}.")
|
50 |
+
return AnnotType.NONE
|
51 |
+
|
52 |
+
|
53 |
+
def get_all_files(path: str, extensions: Tuple = IMAGES_EXT):
|
54 |
+
files_all = []
|
55 |
+
for root, subFolders, files in os.walk(path):
|
56 |
+
for name in files:
|
57 |
+
# linux tricks with .directory that still is file
|
58 |
+
if "directory" not in name and sum([ext.lower() in name.lower() for ext in extensions]) > 0:
|
59 |
+
files_all.append(os.path.join(root, name))
|
60 |
+
return files_all
|
61 |
+
|
62 |
+
|
63 |
+
class InputType(Enum):
|
64 |
+
Image = 0
|
65 |
+
Video = 1
|
66 |
+
VideoStream = 2
|
67 |
+
|
68 |
+
|
69 |
+
def get_input_type(input_path: str) -> InputType:
|
70 |
+
if os.path.isdir(input_path):
|
71 |
+
print("Input is a folder, only images will be processed")
|
72 |
+
return InputType.Image
|
73 |
+
elif os.path.isfile(input_path):
|
74 |
+
if input_path.endswith(VIDEO_EXT):
|
75 |
+
return InputType.Video
|
76 |
+
if input_path.endswith(IMAGES_EXT):
|
77 |
+
return InputType.Image
|
78 |
+
else:
|
79 |
+
raise ValueError(
|
80 |
+
f"Unknown or unsupported input file format {input_path}, \
|
81 |
+
supported video formats: {VIDEO_EXT}, \
|
82 |
+
supported image formats: {IMAGES_EXT}"
|
83 |
+
)
|
84 |
+
elif input_path.startswith("http") and not input_path.endswith(IMAGES_EXT):
|
85 |
+
return InputType.VideoStream
|
86 |
+
else:
|
87 |
+
raise ValueError(f"Unknown input {input_path}")
|
88 |
+
|
89 |
+
|
90 |
+
def read_csv_annotation_file(annotation_file: str, images_dir: str, ignore_without_gt=False):
|
91 |
+
bboxes_per_image: Dict[str, List[PictureInfo]] = defaultdict(list)
|
92 |
+
|
93 |
+
df = pd.read_csv(annotation_file, sep=",")
|
94 |
+
|
95 |
+
annot_type = AnnotType("persons") if "person_x0" in df.columns else AnnotType("original")
|
96 |
+
print(f"Reading {annotation_file} (type: {annot_type})...")
|
97 |
+
|
98 |
+
missing_images = 0
|
99 |
+
for index, row in df.iterrows():
|
100 |
+
img_path = os.path.join(images_dir, row["img_name"])
|
101 |
+
if not os.path.exists(img_path):
|
102 |
+
missing_images += 1
|
103 |
+
continue
|
104 |
+
|
105 |
+
face_x1, face_y1, face_x2, face_y2 = row["face_x0"], row["face_y0"], row["face_x1"], row["face_y1"]
|
106 |
+
age, gender = str(row["age"]), str(row["gender"])
|
107 |
+
|
108 |
+
if ignore_without_gt and (age == "-1" or gender == "-1"):
|
109 |
+
continue
|
110 |
+
|
111 |
+
if annot_type == AnnotType.PERSONS:
|
112 |
+
p_x1, p_y1, p_x2, p_y2 = row["person_x0"], row["person_y0"], row["person_x1"], row["person_y1"]
|
113 |
+
person_bbox = list(map(int, [p_x1, p_y1, p_x2, p_y2]))
|
114 |
+
else:
|
115 |
+
person_bbox = [-1, -1, -1, -1]
|
116 |
+
|
117 |
+
bbox = list(map(int, [face_x1, face_y1, face_x2, face_y2]))
|
118 |
+
pic_info = PictureInfo(img_path, age, gender, bbox, person_bbox)
|
119 |
+
assert isinstance(pic_info.person_bbox, list)
|
120 |
+
|
121 |
+
bboxes_per_image[img_path].append(pic_info)
|
122 |
+
|
123 |
+
if missing_images > 0:
|
124 |
+
print(f"WARNING: Missing images: {missing_images}/{len(df)}")
|
125 |
+
return bboxes_per_image, annot_type
|
mivolo/data/dataset/__init__.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from mivolo.model.mi_volo import MiVOLO
|
5 |
+
|
6 |
+
from .age_gender_dataset import AgeGenderDataset
|
7 |
+
from .age_gender_loader import create_loader
|
8 |
+
from .classification_dataset import AdienceDataset, FairFaceDataset
|
9 |
+
|
10 |
+
DATASET_CLASS_MAP = {
|
11 |
+
"utk": AgeGenderDataset,
|
12 |
+
"lagenda": AgeGenderDataset,
|
13 |
+
"imdb": AgeGenderDataset,
|
14 |
+
"adience": AdienceDataset,
|
15 |
+
"fairface": FairFaceDataset,
|
16 |
+
}
|
17 |
+
|
18 |
+
|
19 |
+
def build(
|
20 |
+
name: str,
|
21 |
+
images_path: str,
|
22 |
+
annotations_path: str,
|
23 |
+
split: str,
|
24 |
+
mivolo_model: MiVOLO,
|
25 |
+
workers: int,
|
26 |
+
batch_size: int,
|
27 |
+
) -> Tuple[torch.utils.data.Dataset, torch.utils.data.DataLoader]:
|
28 |
+
|
29 |
+
dataset_class = DATASET_CLASS_MAP[name]
|
30 |
+
|
31 |
+
dataset: torch.utils.data.Dataset = dataset_class(
|
32 |
+
images_path=images_path,
|
33 |
+
annotations_path=annotations_path,
|
34 |
+
name=name,
|
35 |
+
split=split,
|
36 |
+
target_size=mivolo_model.input_size,
|
37 |
+
max_age=mivolo_model.meta.max_age,
|
38 |
+
min_age=mivolo_model.meta.min_age,
|
39 |
+
model_with_persons=mivolo_model.meta.with_persons_model,
|
40 |
+
use_persons=mivolo_model.meta.use_persons,
|
41 |
+
disable_faces=mivolo_model.meta.disable_faces,
|
42 |
+
only_age=mivolo_model.meta.only_age,
|
43 |
+
)
|
44 |
+
|
45 |
+
data_config = mivolo_model.data_config
|
46 |
+
|
47 |
+
in_chans = 3 if not mivolo_model.meta.with_persons_model else 6
|
48 |
+
input_size = (in_chans, mivolo_model.input_size, mivolo_model.input_size)
|
49 |
+
|
50 |
+
dataset_loader: torch.utils.data.DataLoader = create_loader(
|
51 |
+
dataset,
|
52 |
+
input_size=input_size,
|
53 |
+
batch_size=batch_size,
|
54 |
+
mean=data_config["mean"],
|
55 |
+
std=data_config["std"],
|
56 |
+
num_workers=workers,
|
57 |
+
crop_pct=data_config["crop_pct"],
|
58 |
+
crop_mode=data_config["crop_mode"],
|
59 |
+
pin_memory=False,
|
60 |
+
device=mivolo_model.device,
|
61 |
+
target_type=dataset.target_dtype,
|
62 |
+
)
|
63 |
+
|
64 |
+
return dataset, dataset_loader
|
mivolo/data/dataset/age_gender_dataset.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from typing import Any, List, Optional, Set
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from mivolo.data.dataset.reader_age_gender import ReaderAgeGender
|
8 |
+
from PIL import Image
|
9 |
+
from torchvision import transforms
|
10 |
+
|
11 |
+
_logger = logging.getLogger("AgeGenderDataset")
|
12 |
+
|
13 |
+
|
14 |
+
class AgeGenderDataset(torch.utils.data.Dataset):
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
images_path,
|
18 |
+
annotations_path,
|
19 |
+
name=None,
|
20 |
+
split="train",
|
21 |
+
load_bytes=False,
|
22 |
+
img_mode="RGB",
|
23 |
+
transform=None,
|
24 |
+
is_training=False,
|
25 |
+
seed=1234,
|
26 |
+
target_size=224,
|
27 |
+
min_age=None,
|
28 |
+
max_age=None,
|
29 |
+
model_with_persons=False,
|
30 |
+
use_persons=False,
|
31 |
+
disable_faces=False,
|
32 |
+
only_age=False,
|
33 |
+
):
|
34 |
+
reader = ReaderAgeGender(
|
35 |
+
images_path,
|
36 |
+
annotations_path,
|
37 |
+
split=split,
|
38 |
+
seed=seed,
|
39 |
+
target_size=target_size,
|
40 |
+
with_persons=use_persons,
|
41 |
+
disable_faces=disable_faces,
|
42 |
+
only_age=only_age,
|
43 |
+
)
|
44 |
+
|
45 |
+
self.name = name
|
46 |
+
self.model_with_persons = model_with_persons
|
47 |
+
self.reader = reader
|
48 |
+
self.load_bytes = load_bytes
|
49 |
+
self.img_mode = img_mode
|
50 |
+
self.transform = transform
|
51 |
+
self._consecutive_errors = 0
|
52 |
+
self.is_training = is_training
|
53 |
+
self.random_flip = 0.0
|
54 |
+
|
55 |
+
# Setting up classes.
|
56 |
+
# If min and max classes are passed - use them to have the same preprocessing for validation
|
57 |
+
self.max_age: float = None
|
58 |
+
self.min_age: float = None
|
59 |
+
self.avg_age: float = None
|
60 |
+
self.set_ages_min_max(min_age, max_age)
|
61 |
+
|
62 |
+
self.genders = ["M", "F"]
|
63 |
+
self.num_classes_gender = len(self.genders)
|
64 |
+
|
65 |
+
self.age_classes: Optional[List[str]] = self.set_age_classes()
|
66 |
+
|
67 |
+
self.num_classes_age = 1 if self.age_classes is None else len(self.age_classes)
|
68 |
+
self.num_classes: int = self.num_classes_age + self.num_classes_gender
|
69 |
+
self.target_dtype = torch.float32
|
70 |
+
|
71 |
+
def set_age_classes(self) -> Optional[List[str]]:
|
72 |
+
return None # for regression dataset
|
73 |
+
|
74 |
+
def set_ages_min_max(self, min_age: Optional[float], max_age: Optional[float]):
|
75 |
+
|
76 |
+
assert all(age is None for age in [min_age, max_age]) or all(
|
77 |
+
age is not None for age in [min_age, max_age]
|
78 |
+
), "Both min and max age must be passed or none of them"
|
79 |
+
|
80 |
+
if max_age is not None and min_age is not None:
|
81 |
+
_logger.info(f"Received predefined min_age {min_age} and max_age {max_age}")
|
82 |
+
self.max_age = max_age
|
83 |
+
self.min_age = min_age
|
84 |
+
else:
|
85 |
+
# collect statistics from loaded dataset
|
86 |
+
all_ages_set: Set[int] = set()
|
87 |
+
for img_path, image_samples in self.reader._ann.items():
|
88 |
+
for image_sample_info in image_samples:
|
89 |
+
if image_sample_info.age == "-1":
|
90 |
+
continue
|
91 |
+
age = round(float(image_sample_info.age))
|
92 |
+
all_ages_set.add(age)
|
93 |
+
|
94 |
+
self.max_age = max(all_ages_set)
|
95 |
+
self.min_age = min(all_ages_set)
|
96 |
+
|
97 |
+
self.avg_age = (self.max_age + self.min_age) / 2.0
|
98 |
+
|
99 |
+
def _norm_age(self, age):
|
100 |
+
return (age - self.avg_age) / (self.max_age - self.min_age)
|
101 |
+
|
102 |
+
def parse_gender(self, _gender: str) -> float:
|
103 |
+
if _gender != "-1":
|
104 |
+
gender = float(0 if _gender == "M" or _gender == "0" else 1)
|
105 |
+
else:
|
106 |
+
gender = -1
|
107 |
+
return gender
|
108 |
+
|
109 |
+
def parse_target(self, _age: str, gender: str) -> List[Any]:
|
110 |
+
if _age != "-1":
|
111 |
+
age = round(float(_age))
|
112 |
+
age = self._norm_age(float(age))
|
113 |
+
else:
|
114 |
+
age = -1
|
115 |
+
|
116 |
+
target: List[float] = [age, self.parse_gender(gender)]
|
117 |
+
return target
|
118 |
+
|
119 |
+
@property
|
120 |
+
def transform(self):
|
121 |
+
return self._transform
|
122 |
+
|
123 |
+
@transform.setter
|
124 |
+
def transform(self, transform):
|
125 |
+
# Disable pretrained monkey-patched transforms
|
126 |
+
if not transform:
|
127 |
+
return
|
128 |
+
|
129 |
+
_trans = []
|
130 |
+
for trans in transform.transforms:
|
131 |
+
if "Resize" in str(trans):
|
132 |
+
continue
|
133 |
+
if "Crop" in str(trans):
|
134 |
+
continue
|
135 |
+
_trans.append(trans)
|
136 |
+
self._transform = transforms.Compose(_trans)
|
137 |
+
|
138 |
+
def apply_tranforms(self, image: Optional[np.ndarray]) -> np.ndarray:
|
139 |
+
if image is None:
|
140 |
+
return None
|
141 |
+
|
142 |
+
if self.transform is None:
|
143 |
+
return image
|
144 |
+
|
145 |
+
image = convert_to_pil(image, self.img_mode)
|
146 |
+
for trans in self.transform.transforms:
|
147 |
+
image = trans(image)
|
148 |
+
return image
|
149 |
+
|
150 |
+
def __getitem__(self, index):
|
151 |
+
# get preprocessed face and person crops (np.ndarray)
|
152 |
+
# resize + pad, for person crops: cut off other bboxes
|
153 |
+
images, target = self.reader[index]
|
154 |
+
|
155 |
+
target = self.parse_target(*target)
|
156 |
+
|
157 |
+
if self.model_with_persons:
|
158 |
+
face_image, person_image = images
|
159 |
+
person_image: np.ndarray = self.apply_tranforms(person_image)
|
160 |
+
else:
|
161 |
+
face_image = images[0]
|
162 |
+
person_image = None
|
163 |
+
|
164 |
+
face_image: np.ndarray = self.apply_tranforms(face_image)
|
165 |
+
|
166 |
+
if person_image is not None:
|
167 |
+
img = np.concatenate([face_image, person_image], axis=0)
|
168 |
+
else:
|
169 |
+
img = face_image
|
170 |
+
|
171 |
+
return img, target
|
172 |
+
|
173 |
+
def __len__(self):
|
174 |
+
return len(self.reader)
|
175 |
+
|
176 |
+
def filename(self, index, basename=False, absolute=False):
|
177 |
+
return self.reader.filename(index, basename, absolute)
|
178 |
+
|
179 |
+
def filenames(self, basename=False, absolute=False):
|
180 |
+
return self.reader.filenames(basename, absolute)
|
181 |
+
|
182 |
+
|
183 |
+
def convert_to_pil(cv_im: Optional[np.ndarray], img_mode: str = "RGB") -> "Image":
|
184 |
+
if cv_im is None:
|
185 |
+
return None
|
186 |
+
|
187 |
+
if img_mode == "RGB":
|
188 |
+
cv_im = cv2.cvtColor(cv_im, cv2.COLOR_BGR2RGB)
|
189 |
+
else:
|
190 |
+
raise Exception("Incorrect image mode has been passed!")
|
191 |
+
|
192 |
+
cv_im = np.ascontiguousarray(cv_im)
|
193 |
+
pil_image = Image.fromarray(cv_im)
|
194 |
+
return pil_image
|
mivolo/data/dataset/age_gender_loader.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Code adapted from timm https://github.com/huggingface/pytorch-image-models
|
3 |
+
|
4 |
+
Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
|
5 |
+
"""
|
6 |
+
|
7 |
+
import logging
|
8 |
+
from contextlib import suppress
|
9 |
+
from functools import partial
|
10 |
+
from itertools import repeat
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import torch.utils.data
|
15 |
+
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
16 |
+
from timm.data.dataset import IterableImageDataset
|
17 |
+
from timm.data.loader import PrefetchLoader, _worker_init
|
18 |
+
from timm.data.transforms_factory import create_transform
|
19 |
+
|
20 |
+
_logger = logging.getLogger(__name__)
|
21 |
+
|
22 |
+
|
23 |
+
def fast_collate(batch, target_dtype=torch.uint8):
|
24 |
+
"""A fast collation function optimized for uint8 images (np array or torch) and target_dtype targets (labels)"""
|
25 |
+
assert isinstance(batch[0], tuple)
|
26 |
+
batch_size = len(batch)
|
27 |
+
if isinstance(batch[0][0], np.ndarray):
|
28 |
+
targets = torch.tensor([b[1] for b in batch], dtype=target_dtype)
|
29 |
+
assert len(targets) == batch_size
|
30 |
+
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
|
31 |
+
for i in range(batch_size):
|
32 |
+
tensor[i] += torch.from_numpy(batch[i][0])
|
33 |
+
return tensor, targets
|
34 |
+
else:
|
35 |
+
raise ValueError(f"Incorrect batch type: {type(batch[0][0])}")
|
36 |
+
|
37 |
+
|
38 |
+
def adapt_to_chs(x, n):
|
39 |
+
if not isinstance(x, (tuple, list)):
|
40 |
+
x = tuple(repeat(x, n))
|
41 |
+
elif len(x) != n:
|
42 |
+
# doubled channels
|
43 |
+
if len(x) * 2 == n:
|
44 |
+
x = np.concatenate((x, x))
|
45 |
+
_logger.warning(f"Pretrained mean/std different shape than model (doubled channes), using concat: {x}.")
|
46 |
+
else:
|
47 |
+
x_mean = np.mean(x).item()
|
48 |
+
x = (x_mean,) * n
|
49 |
+
_logger.warning(f"Pretrained mean/std different shape than model, using avg value {x}.")
|
50 |
+
else:
|
51 |
+
assert len(x) == n, "normalization stats must match image channels"
|
52 |
+
return x
|
53 |
+
|
54 |
+
|
55 |
+
class PrefetchLoaderForMultiInput(PrefetchLoader):
|
56 |
+
def __init__(
|
57 |
+
self,
|
58 |
+
loader,
|
59 |
+
mean=IMAGENET_DEFAULT_MEAN,
|
60 |
+
std=IMAGENET_DEFAULT_STD,
|
61 |
+
channels=3,
|
62 |
+
device=torch.device("cpu"),
|
63 |
+
img_dtype=torch.float32,
|
64 |
+
):
|
65 |
+
|
66 |
+
mean = adapt_to_chs(mean, channels)
|
67 |
+
std = adapt_to_chs(std, channels)
|
68 |
+
normalization_shape = (1, channels, 1, 1)
|
69 |
+
|
70 |
+
self.loader = loader
|
71 |
+
self.device = device
|
72 |
+
self.img_dtype = img_dtype
|
73 |
+
self.mean = torch.tensor([x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
|
74 |
+
self.std = torch.tensor([x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)
|
75 |
+
|
76 |
+
self.is_cuda = torch.cuda.is_available() and device.type == "cpu"
|
77 |
+
|
78 |
+
def __iter__(self):
|
79 |
+
first = True
|
80 |
+
if self.is_cuda:
|
81 |
+
stream = torch.cuda.Stream()
|
82 |
+
stream_context = partial(torch.cuda.stream, stream=stream)
|
83 |
+
else:
|
84 |
+
stream = None
|
85 |
+
stream_context = suppress
|
86 |
+
|
87 |
+
for next_input, next_target in self.loader:
|
88 |
+
|
89 |
+
with stream_context():
|
90 |
+
next_input = next_input.to(device=self.device, non_blocking=True)
|
91 |
+
next_target = next_target.to(device=self.device, non_blocking=True)
|
92 |
+
next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)
|
93 |
+
|
94 |
+
if not first:
|
95 |
+
yield input, target # noqa: F823, F821
|
96 |
+
else:
|
97 |
+
first = False
|
98 |
+
|
99 |
+
if stream is not None:
|
100 |
+
torch.cuda.current_stream().wait_stream(stream)
|
101 |
+
|
102 |
+
input = next_input
|
103 |
+
target = next_target
|
104 |
+
|
105 |
+
yield input, target
|
106 |
+
|
107 |
+
|
108 |
+
def create_loader(
|
109 |
+
dataset,
|
110 |
+
input_size,
|
111 |
+
batch_size,
|
112 |
+
mean=IMAGENET_DEFAULT_MEAN,
|
113 |
+
std=IMAGENET_DEFAULT_STD,
|
114 |
+
num_workers=1,
|
115 |
+
crop_pct=None,
|
116 |
+
crop_mode=None,
|
117 |
+
pin_memory=False,
|
118 |
+
img_dtype=torch.float32,
|
119 |
+
device=torch.device("cpu"),
|
120 |
+
persistent_workers=True,
|
121 |
+
worker_seeding="all",
|
122 |
+
target_type=torch.int64,
|
123 |
+
):
|
124 |
+
|
125 |
+
transform = create_transform(
|
126 |
+
input_size,
|
127 |
+
is_training=False,
|
128 |
+
use_prefetcher=True,
|
129 |
+
mean=mean,
|
130 |
+
std=std,
|
131 |
+
crop_pct=crop_pct,
|
132 |
+
crop_mode=crop_mode,
|
133 |
+
)
|
134 |
+
dataset.transform = transform
|
135 |
+
|
136 |
+
if isinstance(dataset, IterableImageDataset):
|
137 |
+
# give Iterable datasets early knowledge of num_workers so that sample estimates
|
138 |
+
# are correct before worker processes are launched
|
139 |
+
dataset.set_loader_cfg(num_workers=num_workers)
|
140 |
+
raise ValueError("Incorrect dataset type: IterableImageDataset")
|
141 |
+
|
142 |
+
loader_class = torch.utils.data.DataLoader
|
143 |
+
loader_args = dict(
|
144 |
+
batch_size=batch_size,
|
145 |
+
shuffle=False,
|
146 |
+
num_workers=num_workers,
|
147 |
+
sampler=None,
|
148 |
+
collate_fn=lambda batch: fast_collate(batch, target_dtype=target_type),
|
149 |
+
pin_memory=pin_memory,
|
150 |
+
drop_last=False,
|
151 |
+
worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
|
152 |
+
persistent_workers=persistent_workers,
|
153 |
+
)
|
154 |
+
try:
|
155 |
+
loader = loader_class(dataset, **loader_args)
|
156 |
+
except TypeError:
|
157 |
+
loader_args.pop("persistent_workers") # only in Pytorch 1.7+
|
158 |
+
loader = loader_class(dataset, **loader_args)
|
159 |
+
|
160 |
+
loader = PrefetchLoaderForMultiInput(
|
161 |
+
loader,
|
162 |
+
mean=mean,
|
163 |
+
std=std,
|
164 |
+
channels=input_size[0],
|
165 |
+
device=device,
|
166 |
+
img_dtype=img_dtype,
|
167 |
+
)
|
168 |
+
|
169 |
+
return loader
|
mivolo/data/dataset/classification_dataset.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, List, Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from .age_gender_dataset import AgeGenderDataset
|
6 |
+
|
7 |
+
|
8 |
+
class ClassificationDataset(AgeGenderDataset):
|
9 |
+
def __init__(self, *args, **kwargs):
|
10 |
+
super().__init__(*args, **kwargs)
|
11 |
+
|
12 |
+
self.target_dtype = torch.int32
|
13 |
+
|
14 |
+
def set_age_classes(self) -> Optional[List[str]]:
|
15 |
+
raise NotImplementedError
|
16 |
+
|
17 |
+
def parse_target(self, age: str, gender: str) -> List[Any]:
|
18 |
+
assert self.age_classes is not None
|
19 |
+
if age != "-1":
|
20 |
+
assert age in self.age_classes, f"Unknown category in {self.name} dataset: {age}"
|
21 |
+
age_ind = self.age_classes.index(age)
|
22 |
+
else:
|
23 |
+
age_ind = -1
|
24 |
+
|
25 |
+
target: List[int] = [age_ind, int(self.parse_gender(gender))]
|
26 |
+
return target
|
27 |
+
|
28 |
+
|
29 |
+
class FairFaceDataset(ClassificationDataset):
|
30 |
+
def set_age_classes(self) -> Optional[List[str]]:
|
31 |
+
age_classes = ["0;2", "3;9", "10;19", "20;29", "30;39", "40;49", "50;59", "60;69", "70;120"]
|
32 |
+
# a[i-1] <= v < a[i] => age_classes[i-1]
|
33 |
+
self._intervals = torch.tensor([0, 3, 10, 20, 30, 40, 50, 60, 70])
|
34 |
+
|
35 |
+
return age_classes
|
36 |
+
|
37 |
+
|
38 |
+
class AdienceDataset(ClassificationDataset):
|
39 |
+
def __init__(self, *args, **kwargs):
|
40 |
+
super().__init__(*args, **kwargs)
|
41 |
+
|
42 |
+
self.target_dtype = torch.int32
|
43 |
+
|
44 |
+
def set_age_classes(self) -> Optional[List[str]]:
|
45 |
+
age_classes = ["0;2", "4;6", "8;12", "15;20", "25;32", "38;43", "48;53", "60;100"]
|
46 |
+
# a[i-1] <= v < a[i] => age_classes[i-1]
|
47 |
+
self._intervals = torch.tensor([0, 4, 7, 14, 24, 36, 46, 57])
|
48 |
+
return age_classes
|
mivolo/data/dataset/reader_age_gender.py
ADDED
@@ -0,0 +1,490 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from functools import partial
|
4 |
+
from multiprocessing.pool import ThreadPool
|
5 |
+
from typing import Dict, List, Optional, Tuple
|
6 |
+
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
+
from mivolo.data.data_reader import AnnotType, PictureInfo, get_all_files, read_csv_annotation_file
|
10 |
+
from mivolo.data.misc import IOU, class_letterbox, cropout_black_parts
|
11 |
+
from timm.data.readers.reader import Reader
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
CROP_ROUND_TOL = 0.3
|
15 |
+
MIN_PERSON_SIZE = 100
|
16 |
+
MIN_PERSON_CROP_AFTERCUT_RATIO = 0.4
|
17 |
+
|
18 |
+
_logger = logging.getLogger("ReaderAgeGender")
|
19 |
+
|
20 |
+
|
21 |
+
class ReaderAgeGender(Reader):
|
22 |
+
"""
|
23 |
+
Reader for almost original imdb-wiki cleaned dataset.
|
24 |
+
Two changes:
|
25 |
+
1. Your annotation must be in ./annotation subdir of dataset root
|
26 |
+
2. Images must be in images subdir
|
27 |
+
|
28 |
+
"""
|
29 |
+
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
images_path,
|
33 |
+
annotations_path,
|
34 |
+
split="validation",
|
35 |
+
target_size=224,
|
36 |
+
min_size=5,
|
37 |
+
seed=1234,
|
38 |
+
with_persons=False,
|
39 |
+
min_person_size=MIN_PERSON_SIZE,
|
40 |
+
disable_faces=False,
|
41 |
+
only_age=False,
|
42 |
+
min_person_aftercut_ratio=MIN_PERSON_CROP_AFTERCUT_RATIO,
|
43 |
+
crop_round_tol=CROP_ROUND_TOL,
|
44 |
+
):
|
45 |
+
super().__init__()
|
46 |
+
|
47 |
+
self.with_persons = with_persons
|
48 |
+
self.disable_faces = disable_faces
|
49 |
+
self.only_age = only_age
|
50 |
+
|
51 |
+
# can be only black for now, even though it's not very good with further normalization
|
52 |
+
self.crop_out_color = (0, 0, 0)
|
53 |
+
|
54 |
+
self.empty_crop = np.ones((target_size, target_size, 3)) * self.crop_out_color
|
55 |
+
self.empty_crop = self.empty_crop.astype(np.uint8)
|
56 |
+
|
57 |
+
self.min_person_size = min_person_size
|
58 |
+
self.min_person_aftercut_ratio = min_person_aftercut_ratio
|
59 |
+
self.crop_round_tol = crop_round_tol
|
60 |
+
|
61 |
+
self.split = split
|
62 |
+
self.min_size = min_size
|
63 |
+
self.seed = seed
|
64 |
+
self.target_size = target_size
|
65 |
+
|
66 |
+
# Reading annotations. Can be multiple files if annotations_path dir
|
67 |
+
self._ann: Dict[str, List[PictureInfo]] = {} # list of samples for each image
|
68 |
+
self._associated_objects: Dict[str, Dict[int, List[List[int]]]] = {}
|
69 |
+
self._faces_list: List[Tuple[str, int]] = [] # samples from this list will be loaded in __getitem__
|
70 |
+
|
71 |
+
self._read_annotations(images_path, annotations_path)
|
72 |
+
_logger.info(f"Dataset length: {len(self._faces_list)} crops")
|
73 |
+
|
74 |
+
def __getitem__(self, index):
|
75 |
+
return self._read_img_and_label(index)
|
76 |
+
|
77 |
+
def __len__(self):
|
78 |
+
return len(self._faces_list)
|
79 |
+
|
80 |
+
def _filename(self, index, basename=False, absolute=False):
|
81 |
+
img_p = self._faces_list[index][0]
|
82 |
+
return os.path.basename(img_p) if basename else img_p
|
83 |
+
|
84 |
+
def _read_annotations(self, images_path, csvs_path):
|
85 |
+
self._ann = {}
|
86 |
+
self._faces_list = []
|
87 |
+
self._associated_objects = {}
|
88 |
+
|
89 |
+
csvs = get_all_files(csvs_path, [".csv"])
|
90 |
+
csvs = [c for c in csvs if self.split in os.path.basename(c)]
|
91 |
+
|
92 |
+
# load annotations per image
|
93 |
+
for csv in csvs:
|
94 |
+
db, ann_type = read_csv_annotation_file(csv, images_path)
|
95 |
+
if self.with_persons and ann_type != AnnotType.PERSONS:
|
96 |
+
raise ValueError(
|
97 |
+
f"Annotation type in file {csv} contains no persons, "
|
98 |
+
f"but annotations with persons are requested."
|
99 |
+
)
|
100 |
+
self._ann.update(db)
|
101 |
+
|
102 |
+
if len(self._ann) == 0:
|
103 |
+
raise ValueError("Annotations are empty!")
|
104 |
+
|
105 |
+
self._ann, self._associated_objects = self.prepare_annotations()
|
106 |
+
images_list = list(self._ann.keys())
|
107 |
+
|
108 |
+
for img_path in images_list:
|
109 |
+
for index, image_sample_info in enumerate(self._ann[img_path]):
|
110 |
+
assert image_sample_info.has_gt(
|
111 |
+
self.only_age
|
112 |
+
), "Annotations must be checked with self.prepare_annotations() func"
|
113 |
+
self._faces_list.append((img_path, index))
|
114 |
+
|
115 |
+
def _read_img_and_label(self, index):
|
116 |
+
if not isinstance(index, int):
|
117 |
+
raise TypeError("ReaderAgeGender expected index to be integer")
|
118 |
+
|
119 |
+
img_p, face_index = self._faces_list[index]
|
120 |
+
ann: PictureInfo = self._ann[img_p][face_index]
|
121 |
+
img = cv2.imread(img_p)
|
122 |
+
|
123 |
+
face_empty = True
|
124 |
+
if ann.has_face_bbox and not (self.with_persons and self.disable_faces):
|
125 |
+
face_crop, face_empty = self._get_crop(ann.bbox, img)
|
126 |
+
|
127 |
+
if not self.with_persons and face_empty:
|
128 |
+
# model without persons
|
129 |
+
raise ValueError("Annotations must be checked with self.prepare_annotations() func")
|
130 |
+
|
131 |
+
if face_empty:
|
132 |
+
face_crop = self.empty_crop
|
133 |
+
|
134 |
+
person_empty = True
|
135 |
+
if self.with_persons or self.disable_faces:
|
136 |
+
if ann.has_person_bbox:
|
137 |
+
# cut off all associated objects from person crop
|
138 |
+
objects = self._associated_objects[img_p][face_index]
|
139 |
+
person_crop, person_empty = self._get_crop(
|
140 |
+
ann.person_bbox,
|
141 |
+
img,
|
142 |
+
crop_out_color=self.crop_out_color,
|
143 |
+
asced_objects=objects,
|
144 |
+
)
|
145 |
+
|
146 |
+
if face_empty and person_empty:
|
147 |
+
raise ValueError("Annotations must be checked with self.prepare_annotations() func")
|
148 |
+
|
149 |
+
if person_empty:
|
150 |
+
person_crop = self.empty_crop
|
151 |
+
|
152 |
+
return (face_crop, person_crop), [ann.age, ann.gender]
|
153 |
+
|
154 |
+
def _get_crop(
|
155 |
+
self,
|
156 |
+
bbox,
|
157 |
+
img,
|
158 |
+
asced_objects=None,
|
159 |
+
crop_out_color=(0, 0, 0),
|
160 |
+
) -> Tuple[np.ndarray, bool]:
|
161 |
+
|
162 |
+
empty_bbox = False
|
163 |
+
|
164 |
+
xmin, ymin, xmax, ymax = bbox
|
165 |
+
assert not (
|
166 |
+
ymax - ymin < self.min_size or xmax - xmin < self.min_size
|
167 |
+
), "Annotations must be checked with self.prepare_annotations() func"
|
168 |
+
|
169 |
+
crop = img[ymin:ymax, xmin:xmax]
|
170 |
+
|
171 |
+
if asced_objects:
|
172 |
+
# cut off other objects for person crop
|
173 |
+
crop, empty_bbox = _cropout_asced_objs(
|
174 |
+
asced_objects,
|
175 |
+
bbox,
|
176 |
+
crop.copy(),
|
177 |
+
crop_out_color=crop_out_color,
|
178 |
+
min_person_size=self.min_person_size,
|
179 |
+
crop_round_tol=self.crop_round_tol,
|
180 |
+
min_person_aftercut_ratio=self.min_person_aftercut_ratio,
|
181 |
+
)
|
182 |
+
if empty_bbox:
|
183 |
+
crop = self.empty_crop
|
184 |
+
|
185 |
+
crop = class_letterbox(crop, new_shape=(self.target_size, self.target_size), color=crop_out_color)
|
186 |
+
return crop, empty_bbox
|
187 |
+
|
188 |
+
def prepare_annotations(self):
|
189 |
+
|
190 |
+
good_anns: Dict[str, List[PictureInfo]] = {}
|
191 |
+
all_associated_objects: Dict[str, Dict[int, List[List[int]]]] = {}
|
192 |
+
|
193 |
+
if not self.with_persons:
|
194 |
+
# remove all persons
|
195 |
+
for img_path, bboxes in self._ann.items():
|
196 |
+
for sample in bboxes:
|
197 |
+
sample.clear_person_bbox()
|
198 |
+
|
199 |
+
# check dataset and collect associated_objects
|
200 |
+
verify_images_func = partial(
|
201 |
+
verify_images,
|
202 |
+
min_size=self.min_size,
|
203 |
+
min_person_size=self.min_person_size,
|
204 |
+
with_persons=self.with_persons,
|
205 |
+
disable_faces=self.disable_faces,
|
206 |
+
crop_round_tol=self.crop_round_tol,
|
207 |
+
min_person_aftercut_ratio=self.min_person_aftercut_ratio,
|
208 |
+
only_age=self.only_age,
|
209 |
+
)
|
210 |
+
num_threads = min(8, os.cpu_count())
|
211 |
+
|
212 |
+
all_msgs = []
|
213 |
+
broken = 0
|
214 |
+
skipped = 0
|
215 |
+
all_skipped_crops = 0
|
216 |
+
desc = "Check annotations..."
|
217 |
+
with ThreadPool(num_threads) as pool:
|
218 |
+
pbar = tqdm(
|
219 |
+
pool.imap_unordered(verify_images_func, list(self._ann.items())),
|
220 |
+
desc=desc,
|
221 |
+
total=len(self._ann),
|
222 |
+
)
|
223 |
+
|
224 |
+
for (img_info, associated_objects, msgs, is_corrupted, is_empty_annotations, skipped_crops) in pbar:
|
225 |
+
broken += 1 if is_corrupted else 0
|
226 |
+
all_msgs.extend(msgs)
|
227 |
+
all_skipped_crops += skipped_crops
|
228 |
+
skipped += 1 if is_empty_annotations else 0
|
229 |
+
if img_info is not None:
|
230 |
+
img_path, img_samples = img_info
|
231 |
+
good_anns[img_path] = img_samples
|
232 |
+
all_associated_objects.update({img_path: associated_objects})
|
233 |
+
|
234 |
+
pbar.desc = (
|
235 |
+
f"{desc} {skipped} images skipped ({all_skipped_crops} crops are incorrect); "
|
236 |
+
f"{broken} images corrupted"
|
237 |
+
)
|
238 |
+
|
239 |
+
pbar.close()
|
240 |
+
|
241 |
+
for msg in all_msgs:
|
242 |
+
print(msg)
|
243 |
+
print(f"\nLeft images: {len(good_anns)}")
|
244 |
+
|
245 |
+
return good_anns, all_associated_objects
|
246 |
+
|
247 |
+
|
248 |
+
def verify_images(
|
249 |
+
img_info,
|
250 |
+
min_size: int,
|
251 |
+
min_person_size: int,
|
252 |
+
with_persons: bool,
|
253 |
+
disable_faces: bool,
|
254 |
+
crop_round_tol: float,
|
255 |
+
min_person_aftercut_ratio: float,
|
256 |
+
only_age: bool,
|
257 |
+
):
|
258 |
+
# If crop is too small, if image can not be read or if image does not exist
|
259 |
+
# then filter out this sample
|
260 |
+
|
261 |
+
disable_faces = disable_faces and with_persons
|
262 |
+
kwargs = dict(
|
263 |
+
min_person_size=min_person_size,
|
264 |
+
disable_faces=disable_faces,
|
265 |
+
with_persons=with_persons,
|
266 |
+
crop_round_tol=crop_round_tol,
|
267 |
+
min_person_aftercut_ratio=min_person_aftercut_ratio,
|
268 |
+
only_age=only_age,
|
269 |
+
)
|
270 |
+
|
271 |
+
def bbox_correct(bbox, min_size, im_h, im_w) -> Tuple[bool, List[int]]:
|
272 |
+
ymin, ymax, xmin, xmax = _correct_bbox(bbox, im_h, im_w)
|
273 |
+
crop_h, crop_w = ymax - ymin, xmax - xmin
|
274 |
+
if crop_h < min_size or crop_w < min_size:
|
275 |
+
return False, [-1, -1, -1, -1]
|
276 |
+
bbox = [xmin, ymin, xmax, ymax]
|
277 |
+
return True, bbox
|
278 |
+
|
279 |
+
msgs = []
|
280 |
+
skipped_crops = 0
|
281 |
+
is_corrupted = False
|
282 |
+
is_empty_annotations = False
|
283 |
+
|
284 |
+
img_path: str = img_info[0]
|
285 |
+
img_samples: List[PictureInfo] = img_info[1]
|
286 |
+
try:
|
287 |
+
im_cv = cv2.imread(img_path)
|
288 |
+
im_h, im_w = im_cv.shape[:2]
|
289 |
+
except Exception:
|
290 |
+
msgs.append(f"Can not load image {img_path}")
|
291 |
+
is_corrupted = True
|
292 |
+
return None, {}, msgs, is_corrupted, is_empty_annotations, skipped_crops
|
293 |
+
|
294 |
+
out_samples: List[PictureInfo] = []
|
295 |
+
for sample in img_samples:
|
296 |
+
# correct face bbox
|
297 |
+
if sample.has_face_bbox:
|
298 |
+
is_correct, sample.bbox = bbox_correct(sample.bbox, min_size, im_h, im_w)
|
299 |
+
if not is_correct and sample.has_gt(only_age):
|
300 |
+
msgs.append("Small face. Passing..")
|
301 |
+
skipped_crops += 1
|
302 |
+
|
303 |
+
# correct person bbox
|
304 |
+
if sample.has_person_bbox:
|
305 |
+
is_correct, sample.person_bbox = bbox_correct(
|
306 |
+
sample.person_bbox, max(min_person_size, min_size), im_h, im_w
|
307 |
+
)
|
308 |
+
if not is_correct and sample.has_gt(only_age):
|
309 |
+
msgs.append(f"Small person {img_path}. Passing..")
|
310 |
+
skipped_crops += 1
|
311 |
+
|
312 |
+
if sample.has_face_bbox or sample.has_person_bbox:
|
313 |
+
out_samples.append(sample)
|
314 |
+
elif sample.has_gt(only_age):
|
315 |
+
msgs.append("Sample hs no face and no body. Passing..")
|
316 |
+
skipped_crops += 1
|
317 |
+
|
318 |
+
# sort that samples with undefined age and gender be the last
|
319 |
+
out_samples = sorted(out_samples, key=lambda sample: 1 if not sample.has_gt(only_age) else 0)
|
320 |
+
|
321 |
+
# for each person find other faces and persons bboxes, intersected with it
|
322 |
+
associated_objects: Dict[int, List[List[int]]] = find_associated_objects(out_samples, only_age=only_age)
|
323 |
+
|
324 |
+
out_samples, associated_objects, skipped_crops = filter_bad_samples(
|
325 |
+
out_samples, associated_objects, im_cv, msgs, skipped_crops, **kwargs
|
326 |
+
)
|
327 |
+
|
328 |
+
out_img_info: Optional[Tuple[str, List]] = (img_path, out_samples)
|
329 |
+
if len(out_samples) == 0:
|
330 |
+
out_img_info = None
|
331 |
+
is_empty_annotations = True
|
332 |
+
|
333 |
+
return out_img_info, associated_objects, msgs, is_corrupted, is_empty_annotations, skipped_crops
|
334 |
+
|
335 |
+
|
336 |
+
def filter_bad_samples(
|
337 |
+
out_samples: List[PictureInfo],
|
338 |
+
associated_objects: dict,
|
339 |
+
im_cv: np.ndarray,
|
340 |
+
msgs: List[str],
|
341 |
+
skipped_crops: int,
|
342 |
+
**kwargs,
|
343 |
+
):
|
344 |
+
with_persons, disable_faces, min_person_size, crop_round_tol, min_person_aftercut_ratio, only_age = (
|
345 |
+
kwargs["with_persons"],
|
346 |
+
kwargs["disable_faces"],
|
347 |
+
kwargs["min_person_size"],
|
348 |
+
kwargs["crop_round_tol"],
|
349 |
+
kwargs["min_person_aftercut_ratio"],
|
350 |
+
kwargs["only_age"],
|
351 |
+
)
|
352 |
+
|
353 |
+
# left only samples with annotations
|
354 |
+
inds = [sample_ind for sample_ind, sample in enumerate(out_samples) if sample.has_gt(only_age)]
|
355 |
+
out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)
|
356 |
+
|
357 |
+
if kwargs["disable_faces"]:
|
358 |
+
# clear all faces
|
359 |
+
for ind, sample in enumerate(out_samples):
|
360 |
+
sample.clear_face_bbox()
|
361 |
+
|
362 |
+
# left only samples with person_bbox
|
363 |
+
inds = [sample_ind for sample_ind, sample in enumerate(out_samples) if sample.has_person_bbox]
|
364 |
+
out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)
|
365 |
+
|
366 |
+
if with_persons or disable_faces:
|
367 |
+
# check that preprocessing func
|
368 |
+
# _cropout_asced_objs() return not empty person_image for each out sample
|
369 |
+
|
370 |
+
inds = []
|
371 |
+
for ind, sample in enumerate(out_samples):
|
372 |
+
person_empty = True
|
373 |
+
if sample.has_person_bbox:
|
374 |
+
xmin, ymin, xmax, ymax = sample.person_bbox
|
375 |
+
crop = im_cv[ymin:ymax, xmin:xmax]
|
376 |
+
# cut off all associated objects from person crop
|
377 |
+
_, person_empty = _cropout_asced_objs(
|
378 |
+
associated_objects[ind],
|
379 |
+
sample.person_bbox,
|
380 |
+
crop.copy(),
|
381 |
+
min_person_size=min_person_size,
|
382 |
+
crop_round_tol=crop_round_tol,
|
383 |
+
min_person_aftercut_ratio=min_person_aftercut_ratio,
|
384 |
+
)
|
385 |
+
|
386 |
+
if person_empty and not sample.has_face_bbox:
|
387 |
+
msgs.append("Small person after preprocessing. Passing..")
|
388 |
+
skipped_crops += 1
|
389 |
+
else:
|
390 |
+
inds.append(ind)
|
391 |
+
out_samples, associated_objects = _filter_by_ind(out_samples, associated_objects, inds)
|
392 |
+
|
393 |
+
assert len(associated_objects) == len(out_samples)
|
394 |
+
return out_samples, associated_objects, skipped_crops
|
395 |
+
|
396 |
+
|
397 |
+
def _filter_by_ind(out_samples, associated_objects, inds):
|
398 |
+
_associated_objects = {}
|
399 |
+
_out_samples = []
|
400 |
+
for ind, sample in enumerate(out_samples):
|
401 |
+
if ind in inds:
|
402 |
+
_associated_objects[len(_out_samples)] = associated_objects[ind]
|
403 |
+
_out_samples.append(sample)
|
404 |
+
|
405 |
+
return _out_samples, _associated_objects
|
406 |
+
|
407 |
+
|
408 |
+
def find_associated_objects(
|
409 |
+
image_samples: List[PictureInfo], iou_thresh=0.0001, only_age=False
|
410 |
+
) -> Dict[int, List[List[int]]]:
|
411 |
+
"""
|
412 |
+
For each person (which has gt age and gt gender) find other faces and persons bboxes, intersected with it
|
413 |
+
"""
|
414 |
+
associated_objects: Dict[int, List[List[int]]] = {}
|
415 |
+
|
416 |
+
for iindex, image_sample_info in enumerate(image_samples):
|
417 |
+
# add own face
|
418 |
+
associated_objects[iindex] = [image_sample_info.bbox] if image_sample_info.has_face_bbox else []
|
419 |
+
|
420 |
+
if not image_sample_info.has_person_bbox or not image_sample_info.has_gt(only_age):
|
421 |
+
# if sample has not gt => not be used
|
422 |
+
continue
|
423 |
+
|
424 |
+
iperson_box = image_sample_info.person_bbox
|
425 |
+
for jindex, other_image_sample in enumerate(image_samples):
|
426 |
+
if iindex == jindex:
|
427 |
+
continue
|
428 |
+
if other_image_sample.has_face_bbox:
|
429 |
+
jface_bbox = other_image_sample.bbox
|
430 |
+
iou = _get_iou(jface_bbox, iperson_box)
|
431 |
+
if iou >= iou_thresh:
|
432 |
+
associated_objects[iindex].append(jface_bbox)
|
433 |
+
if other_image_sample.has_person_bbox:
|
434 |
+
jperson_bbox = other_image_sample.person_bbox
|
435 |
+
iou = _get_iou(jperson_bbox, iperson_box)
|
436 |
+
if iou >= iou_thresh:
|
437 |
+
associated_objects[iindex].append(jperson_bbox)
|
438 |
+
|
439 |
+
return associated_objects
|
440 |
+
|
441 |
+
|
442 |
+
def _cropout_asced_objs(
|
443 |
+
asced_objects,
|
444 |
+
person_bbox,
|
445 |
+
crop,
|
446 |
+
min_person_size,
|
447 |
+
crop_round_tol,
|
448 |
+
min_person_aftercut_ratio,
|
449 |
+
crop_out_color=(0, 0, 0),
|
450 |
+
):
|
451 |
+
empty = False
|
452 |
+
xmin, ymin, xmax, ymax = person_bbox
|
453 |
+
|
454 |
+
for a_obj in asced_objects:
|
455 |
+
aobj_xmin, aobj_ymin, aobj_xmax, aobj_ymax = a_obj
|
456 |
+
|
457 |
+
aobj_ymin = int(max(aobj_ymin - ymin, 0))
|
458 |
+
aobj_xmin = int(max(aobj_xmin - xmin, 0))
|
459 |
+
aobj_ymax = int(min(aobj_ymax - ymin, ymax - ymin))
|
460 |
+
aobj_xmax = int(min(aobj_xmax - xmin, xmax - xmin))
|
461 |
+
|
462 |
+
crop[aobj_ymin:aobj_ymax, aobj_xmin:aobj_xmax] = crop_out_color
|
463 |
+
|
464 |
+
crop, cropped_ratio = cropout_black_parts(crop, crop_round_tol)
|
465 |
+
if (
|
466 |
+
crop.shape[0] < min_person_size or crop.shape[1] < min_person_size
|
467 |
+
) or cropped_ratio < min_person_aftercut_ratio:
|
468 |
+
crop = None
|
469 |
+
empty = True
|
470 |
+
|
471 |
+
return crop, empty
|
472 |
+
|
473 |
+
|
474 |
+
def _correct_bbox(bbox, h, w):
|
475 |
+
xmin, ymin, xmax, ymax = bbox
|
476 |
+
ymin = min(max(ymin, 0), h)
|
477 |
+
ymax = min(max(ymax, 0), h)
|
478 |
+
xmin = min(max(xmin, 0), w)
|
479 |
+
xmax = min(max(xmax, 0), w)
|
480 |
+
return ymin, ymax, xmin, xmax
|
481 |
+
|
482 |
+
|
483 |
+
def _get_iou(bbox1, bbox2):
|
484 |
+
xmin1, ymin1, xmax1, ymax1 = bbox1
|
485 |
+
xmin2, ymin2, xmax2, ymax2 = bbox2
|
486 |
+
iou = IOU(
|
487 |
+
[ymin1, xmin1, ymax1, xmax1],
|
488 |
+
[ymin2, xmin2, ymax2, xmax2],
|
489 |
+
)
|
490 |
+
return iou
|
mivolo/data/misc.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import ast
|
3 |
+
import re
|
4 |
+
from typing import List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torchvision.transforms.functional as F
|
10 |
+
from scipy.optimize import linear_sum_assignment
|
11 |
+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
12 |
+
|
13 |
+
CROP_ROUND_RATE = 0.1
|
14 |
+
MIN_PERSON_CROP_NONZERO = 0.5
|
15 |
+
|
16 |
+
|
17 |
+
def aggregate_votes_winsorized(ages, max_age_dist=6):
|
18 |
+
# Replace any annotation that is more than a max_age_dist away from the median
|
19 |
+
# with the median + max_age_dist if higher or max_age_dist - max_age_dist if below
|
20 |
+
median = np.median(ages)
|
21 |
+
ages = np.clip(ages, median - max_age_dist, median + max_age_dist)
|
22 |
+
return np.mean(ages)
|
23 |
+
|
24 |
+
|
25 |
+
def cropout_black_parts(img, tol=0.3):
|
26 |
+
# Create a binary mask of zero pixels
|
27 |
+
zero_pixels_mask = np.all(img == 0, axis=2)
|
28 |
+
# Calculate the threshold for zero pixels in rows and columns
|
29 |
+
threshold = img.shape[0] - img.shape[0] * tol
|
30 |
+
# Calculate row sums and column sums of zero pixels mask
|
31 |
+
row_sums = np.sum(zero_pixels_mask, axis=1)
|
32 |
+
col_sums = np.sum(zero_pixels_mask, axis=0)
|
33 |
+
# Find the first and last rows with zero pixel sums above the threshold
|
34 |
+
start_row = np.argmin(row_sums > threshold)
|
35 |
+
end_row = img.shape[0] - np.argmin(row_sums[::-1] > threshold)
|
36 |
+
# Find the first and last columns with zero pixel sums above the threshold
|
37 |
+
start_col = np.argmin(col_sums > threshold)
|
38 |
+
end_col = img.shape[1] - np.argmin(col_sums[::-1] > threshold)
|
39 |
+
# Crop the image
|
40 |
+
cropped_img = img[start_row:end_row, start_col:end_col, :]
|
41 |
+
area = cropped_img.shape[0] * cropped_img.shape[1]
|
42 |
+
area_orig = img.shape[0] * img.shape[1]
|
43 |
+
return cropped_img, area / area_orig
|
44 |
+
|
45 |
+
|
46 |
+
def natural_key(string_):
|
47 |
+
"""See http://www.codinghorror.com/blog/archives/001018.html"""
|
48 |
+
return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
|
49 |
+
|
50 |
+
|
51 |
+
def add_bool_arg(parser, name, default=False, help=""):
|
52 |
+
dest_name = name.replace("-", "_")
|
53 |
+
group = parser.add_mutually_exclusive_group(required=False)
|
54 |
+
group.add_argument("--" + name, dest=dest_name, action="store_true", help=help)
|
55 |
+
group.add_argument("--no-" + name, dest=dest_name, action="store_false", help=help)
|
56 |
+
parser.set_defaults(**{dest_name: default})
|
57 |
+
|
58 |
+
|
59 |
+
def cumulative_score(pred_ages, gt_ages, L, tol=1e-6):
|
60 |
+
n = pred_ages.shape[0]
|
61 |
+
num_correct = torch.sum(torch.abs(pred_ages - gt_ages) <= L + tol)
|
62 |
+
cs_score = num_correct / n
|
63 |
+
return cs_score
|
64 |
+
|
65 |
+
|
66 |
+
def cumulative_error(pred_ages, gt_ages, L, tol=1e-6):
|
67 |
+
n = pred_ages.shape[0]
|
68 |
+
num_correct = torch.sum(torch.abs(pred_ages - gt_ages) >= L + tol)
|
69 |
+
cs_score = num_correct / n
|
70 |
+
return cs_score
|
71 |
+
|
72 |
+
|
73 |
+
class ParseKwargs(argparse.Action):
|
74 |
+
def __call__(self, parser, namespace, values, option_string=None):
|
75 |
+
kw = {}
|
76 |
+
for value in values:
|
77 |
+
key, value = value.split("=")
|
78 |
+
try:
|
79 |
+
kw[key] = ast.literal_eval(value)
|
80 |
+
except ValueError:
|
81 |
+
kw[key] = str(value) # fallback to string (avoid need to escape on command line)
|
82 |
+
setattr(namespace, self.dest, kw)
|
83 |
+
|
84 |
+
|
85 |
+
def box_iou(box1, box2, over_second=False):
|
86 |
+
"""
|
87 |
+
Return intersection-over-union (Jaccard index) of boxes.
|
88 |
+
If over_second == True, return mean(intersection-over-union, (inter / area2))
|
89 |
+
|
90 |
+
Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
|
91 |
+
|
92 |
+
Arguments:
|
93 |
+
box1 (Tensor[N, 4])
|
94 |
+
box2 (Tensor[M, 4])
|
95 |
+
Returns:
|
96 |
+
iou (Tensor[N, M]): the NxM matrix containing the pairwise
|
97 |
+
IoU values for every element in boxes1 and boxes2
|
98 |
+
"""
|
99 |
+
|
100 |
+
def box_area(box):
|
101 |
+
# box = 4xn
|
102 |
+
return (box[2] - box[0]) * (box[3] - box[1])
|
103 |
+
|
104 |
+
area1 = box_area(box1.T)
|
105 |
+
area2 = box_area(box2.T)
|
106 |
+
|
107 |
+
# inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
|
108 |
+
inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) - torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
|
109 |
+
|
110 |
+
iou = inter / (area1[:, None] + area2 - inter) # iou = inter / (area1 + area2 - inter)
|
111 |
+
if over_second:
|
112 |
+
return (inter / area2 + iou) / 2 # mean(inter / area2, iou)
|
113 |
+
else:
|
114 |
+
return iou
|
115 |
+
|
116 |
+
|
117 |
+
def split_batch(bs: int, dev: int) -> Tuple[int, int]:
|
118 |
+
full_bs = (bs // dev) * dev
|
119 |
+
part_bs = bs - full_bs
|
120 |
+
return full_bs, part_bs
|
121 |
+
|
122 |
+
|
123 |
+
def assign_faces(
|
124 |
+
persons_bboxes: List[torch.tensor], faces_bboxes: List[torch.tensor], iou_thresh: float = 0.0001
|
125 |
+
) -> Tuple[List[Optional[int]], List[int]]:
|
126 |
+
"""
|
127 |
+
Assign person to each face if it is possible.
|
128 |
+
Return:
|
129 |
+
- assigned_faces List[Optional[int]]: mapping of face_ind to person_ind
|
130 |
+
( assigned_faces[face_ind] = person_ind ). person_ind can be None
|
131 |
+
- unassigned_persons_inds List[int]: persons indexes without any assigned face
|
132 |
+
"""
|
133 |
+
|
134 |
+
assigned_faces: List[Optional[int]] = [None for _ in range(len(faces_bboxes))]
|
135 |
+
unassigned_persons_inds: List[int] = [p_ind for p_ind in range(len(persons_bboxes))]
|
136 |
+
|
137 |
+
if len(persons_bboxes) == 0 or len(faces_bboxes) == 0:
|
138 |
+
return assigned_faces, unassigned_persons_inds
|
139 |
+
|
140 |
+
cost_matrix = box_iou(torch.stack(persons_bboxes), torch.stack(faces_bboxes), over_second=True).cpu().numpy()
|
141 |
+
persons_indexes, face_indexes = [], []
|
142 |
+
|
143 |
+
if len(cost_matrix) > 0:
|
144 |
+
persons_indexes, face_indexes = linear_sum_assignment(cost_matrix, maximize=True)
|
145 |
+
|
146 |
+
matched_persons = set()
|
147 |
+
for person_idx, face_idx in zip(persons_indexes, face_indexes):
|
148 |
+
ciou = cost_matrix[person_idx][face_idx]
|
149 |
+
if ciou > iou_thresh:
|
150 |
+
if person_idx in matched_persons:
|
151 |
+
# Person can not be assigned twice, in reality this should not happen
|
152 |
+
continue
|
153 |
+
assigned_faces[face_idx] = person_idx
|
154 |
+
matched_persons.add(person_idx)
|
155 |
+
|
156 |
+
unassigned_persons_inds = [p_ind for p_ind in range(len(persons_bboxes)) if p_ind not in matched_persons]
|
157 |
+
|
158 |
+
return assigned_faces, unassigned_persons_inds
|
159 |
+
|
160 |
+
|
161 |
+
def class_letterbox(im, new_shape=(640, 640), color=(0, 0, 0), scaleup=True):
|
162 |
+
# Resize and pad image while meeting stride-multiple constraints
|
163 |
+
shape = im.shape[:2] # current shape [height, width]
|
164 |
+
if isinstance(new_shape, int):
|
165 |
+
new_shape = (new_shape, new_shape)
|
166 |
+
|
167 |
+
if im.shape[0] == new_shape[0] and im.shape[1] == new_shape[1]:
|
168 |
+
return im
|
169 |
+
|
170 |
+
# Scale ratio (new / old)
|
171 |
+
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
172 |
+
if not scaleup: # only scale down, do not scale up (for better val mAP)
|
173 |
+
r = min(r, 1.0)
|
174 |
+
|
175 |
+
# Compute padding
|
176 |
+
# ratio = r, r # width, height ratios
|
177 |
+
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
|
178 |
+
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
|
179 |
+
|
180 |
+
dw /= 2 # divide padding into 2 sides
|
181 |
+
dh /= 2
|
182 |
+
|
183 |
+
if shape[::-1] != new_unpad: # resize
|
184 |
+
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
|
185 |
+
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
|
186 |
+
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
|
187 |
+
im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border
|
188 |
+
return im
|
189 |
+
|
190 |
+
|
191 |
+
def prepare_classification_images(
|
192 |
+
img_list: List[Optional[np.ndarray]],
|
193 |
+
target_size: int = 224,
|
194 |
+
mean=IMAGENET_DEFAULT_MEAN,
|
195 |
+
std=IMAGENET_DEFAULT_STD,
|
196 |
+
device=None,
|
197 |
+
) -> torch.tensor:
|
198 |
+
|
199 |
+
prepared_images: List[torch.tensor] = []
|
200 |
+
|
201 |
+
for img in img_list:
|
202 |
+
if img is None:
|
203 |
+
img = torch.zeros((3, target_size, target_size), dtype=torch.float32)
|
204 |
+
img = F.normalize(img, mean=mean, std=std)
|
205 |
+
img = img.unsqueeze(0)
|
206 |
+
prepared_images.append(img)
|
207 |
+
continue
|
208 |
+
img = class_letterbox(img, new_shape=(target_size, target_size))
|
209 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
210 |
+
|
211 |
+
img = img / 255.0
|
212 |
+
img = (img - mean) / std
|
213 |
+
img = img.astype(dtype=np.float32)
|
214 |
+
|
215 |
+
img = img.transpose((2, 0, 1))
|
216 |
+
img = np.ascontiguousarray(img)
|
217 |
+
img = torch.from_numpy(img)
|
218 |
+
img = img.unsqueeze(0)
|
219 |
+
|
220 |
+
prepared_images.append(img)
|
221 |
+
|
222 |
+
prepared_input = torch.concat(prepared_images)
|
223 |
+
|
224 |
+
if device:
|
225 |
+
prepared_input = prepared_input.to(device)
|
226 |
+
|
227 |
+
return prepared_input
|
228 |
+
|
229 |
+
|
230 |
+
def IOU(bb1: Union[tuple, list], bb2: Union[tuple, list], norm_second_bbox: bool = False) -> float:
|
231 |
+
# expects [ymin, xmin, ymax, xmax], doesnt matter absolute or relative
|
232 |
+
assert bb1[1] < bb1[3]
|
233 |
+
assert bb1[0] < bb1[2]
|
234 |
+
assert bb2[1] < bb2[3]
|
235 |
+
assert bb2[0] < bb2[2]
|
236 |
+
|
237 |
+
# determine the coordinates of the intersection rectangle
|
238 |
+
x_left = max(bb1[1], bb2[1])
|
239 |
+
y_top = max(bb1[0], bb2[0])
|
240 |
+
x_right = min(bb1[3], bb2[3])
|
241 |
+
y_bottom = min(bb1[2], bb2[2])
|
242 |
+
|
243 |
+
if x_right < x_left or y_bottom < y_top:
|
244 |
+
return 0.0
|
245 |
+
|
246 |
+
# The intersection of two axis-aligned bounding boxes is always an
|
247 |
+
# axis-aligned bounding box
|
248 |
+
intersection_area = (x_right - x_left) * (y_bottom - y_top)
|
249 |
+
# compute the area of both AABBs
|
250 |
+
bb1_area = (bb1[3] - bb1[1]) * (bb1[2] - bb1[0])
|
251 |
+
bb2_area = (bb2[3] - bb2[1]) * (bb2[2] - bb2[0])
|
252 |
+
if not norm_second_bbox:
|
253 |
+
# compute the intersection over union by taking the intersection
|
254 |
+
# area and dividing it by the sum of prediction + ground-truth
|
255 |
+
# areas - the interesection area
|
256 |
+
iou = intersection_area / float(bb1_area + bb2_area - intersection_area)
|
257 |
+
else:
|
258 |
+
# for cases when we search if second bbox is inside first one
|
259 |
+
iou = intersection_area / float(bb2_area)
|
260 |
+
|
261 |
+
assert iou >= 0.0
|
262 |
+
assert iou <= 1.01
|
263 |
+
|
264 |
+
return iou
|
mivolo/model/create_timm_model.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Code adapted from timm https://github.com/huggingface/pytorch-image-models
|
3 |
+
|
4 |
+
Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
|
5 |
+
"""
|
6 |
+
|
7 |
+
import os
|
8 |
+
from typing import Any, Dict, Optional, Union
|
9 |
+
|
10 |
+
import timm
|
11 |
+
|
12 |
+
# register new models
|
13 |
+
from mivolo.model.mivolo_model import * # noqa: F403, F401
|
14 |
+
from timm.layers import set_layer_config
|
15 |
+
from timm.models._factory import parse_model_name
|
16 |
+
from timm.models._helpers import load_state_dict, remap_checkpoint
|
17 |
+
from timm.models._hub import load_model_config_from_hf
|
18 |
+
from timm.models._pretrained import PretrainedCfg, split_model_name_tag
|
19 |
+
from timm.models._registry import is_model, model_entrypoint
|
20 |
+
|
21 |
+
|
22 |
+
def load_checkpoint(
|
23 |
+
model, checkpoint_path, use_ema=True, strict=True, remap=False, filter_keys=None, state_dict_map=None
|
24 |
+
):
|
25 |
+
if os.path.splitext(checkpoint_path)[-1].lower() in (".npz", ".npy"):
|
26 |
+
# numpy checkpoint, try to load via model specific load_pretrained fn
|
27 |
+
if hasattr(model, "load_pretrained"):
|
28 |
+
timm.models._model_builder.load_pretrained(checkpoint_path)
|
29 |
+
else:
|
30 |
+
raise NotImplementedError("Model cannot load numpy checkpoint")
|
31 |
+
return
|
32 |
+
state_dict = load_state_dict(checkpoint_path, use_ema)
|
33 |
+
if remap:
|
34 |
+
state_dict = remap_checkpoint(model, state_dict)
|
35 |
+
if filter_keys:
|
36 |
+
for sd_key in list(state_dict.keys()):
|
37 |
+
for filter_key in filter_keys:
|
38 |
+
if filter_key in sd_key:
|
39 |
+
if sd_key in state_dict:
|
40 |
+
del state_dict[sd_key]
|
41 |
+
|
42 |
+
rep = []
|
43 |
+
if state_dict_map is not None:
|
44 |
+
# 'patch_embed.conv1.' : 'patch_embed.conv.'
|
45 |
+
for state_k in list(state_dict.keys()):
|
46 |
+
for target_k, target_v in state_dict_map.items():
|
47 |
+
if target_v in state_k:
|
48 |
+
target_name = state_k.replace(target_v, target_k)
|
49 |
+
state_dict[target_name] = state_dict[state_k]
|
50 |
+
rep.append(state_k)
|
51 |
+
for r in rep:
|
52 |
+
if r in state_dict:
|
53 |
+
del state_dict[r]
|
54 |
+
|
55 |
+
incompatible_keys = model.load_state_dict(state_dict, strict=strict if filter_keys is None else False)
|
56 |
+
return incompatible_keys
|
57 |
+
|
58 |
+
|
59 |
+
def create_model(
|
60 |
+
model_name: str,
|
61 |
+
pretrained: bool = False,
|
62 |
+
pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
|
63 |
+
pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
|
64 |
+
checkpoint_path: str = "",
|
65 |
+
scriptable: Optional[bool] = None,
|
66 |
+
exportable: Optional[bool] = None,
|
67 |
+
no_jit: Optional[bool] = None,
|
68 |
+
filter_keys=None,
|
69 |
+
state_dict_map=None,
|
70 |
+
**kwargs,
|
71 |
+
):
|
72 |
+
"""Create a model
|
73 |
+
Lookup model's entrypoint function and pass relevant args to create a new model.
|
74 |
+
"""
|
75 |
+
# Parameters that aren't supported by all models or are intended to only override model defaults if set
|
76 |
+
# should default to None in command line args/cfg. Remove them if they are present and not set so that
|
77 |
+
# non-supporting models don't break and default args remain in effect.
|
78 |
+
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
79 |
+
|
80 |
+
model_source, model_name = parse_model_name(model_name)
|
81 |
+
if model_source == "hf-hub":
|
82 |
+
assert not pretrained_cfg, "pretrained_cfg should not be set when sourcing model from Hugging Face Hub."
|
83 |
+
# For model names specified in the form `hf-hub:path/architecture_name@revision`,
|
84 |
+
# load model weights + pretrained_cfg from Hugging Face hub.
|
85 |
+
pretrained_cfg, model_name = load_model_config_from_hf(model_name)
|
86 |
+
else:
|
87 |
+
model_name, pretrained_tag = split_model_name_tag(model_name)
|
88 |
+
if not pretrained_cfg:
|
89 |
+
# a valid pretrained_cfg argument takes priority over tag in model name
|
90 |
+
pretrained_cfg = pretrained_tag
|
91 |
+
|
92 |
+
if not is_model(model_name):
|
93 |
+
raise RuntimeError("Unknown model (%s)" % model_name)
|
94 |
+
|
95 |
+
create_fn = model_entrypoint(model_name)
|
96 |
+
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
|
97 |
+
model = create_fn(
|
98 |
+
pretrained=pretrained,
|
99 |
+
pretrained_cfg=pretrained_cfg,
|
100 |
+
pretrained_cfg_overlay=pretrained_cfg_overlay,
|
101 |
+
**kwargs,
|
102 |
+
)
|
103 |
+
|
104 |
+
if checkpoint_path:
|
105 |
+
load_checkpoint(model, checkpoint_path, filter_keys=filter_keys, state_dict_map=state_dict_map)
|
106 |
+
|
107 |
+
return model
|
mivolo/model/cross_bottleneck_attn.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Code based on timm https://github.com/huggingface/pytorch-image-models
|
3 |
+
|
4 |
+
Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
|
5 |
+
"""
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from timm.layers.bottleneck_attn import PosEmbedRel
|
10 |
+
from timm.layers.helpers import make_divisible
|
11 |
+
from timm.layers.mlp import Mlp
|
12 |
+
from timm.layers.trace_utils import _assert
|
13 |
+
from timm.layers.weight_init import trunc_normal_
|
14 |
+
|
15 |
+
|
16 |
+
class CrossBottleneckAttn(nn.Module):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
dim,
|
20 |
+
dim_out=None,
|
21 |
+
feat_size=None,
|
22 |
+
stride=1,
|
23 |
+
num_heads=4,
|
24 |
+
dim_head=None,
|
25 |
+
qk_ratio=1.0,
|
26 |
+
qkv_bias=False,
|
27 |
+
scale_pos_embed=False,
|
28 |
+
):
|
29 |
+
super().__init__()
|
30 |
+
assert feat_size is not None, "A concrete feature size matching expected input (H, W) is required"
|
31 |
+
dim_out = dim_out or dim
|
32 |
+
assert dim_out % num_heads == 0
|
33 |
+
|
34 |
+
self.num_heads = num_heads
|
35 |
+
self.dim_head_qk = dim_head or make_divisible(dim_out * qk_ratio, divisor=8) // num_heads
|
36 |
+
self.dim_head_v = dim_out // self.num_heads
|
37 |
+
self.dim_out_qk = num_heads * self.dim_head_qk
|
38 |
+
self.dim_out_v = num_heads * self.dim_head_v
|
39 |
+
self.scale = self.dim_head_qk**-0.5
|
40 |
+
self.scale_pos_embed = scale_pos_embed
|
41 |
+
|
42 |
+
self.qkv_f = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
|
43 |
+
self.qkv_p = nn.Conv2d(dim, self.dim_out_qk * 2 + self.dim_out_v, 1, bias=qkv_bias)
|
44 |
+
|
45 |
+
# NOTE I'm only supporting relative pos embedding for now
|
46 |
+
self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head_qk, scale=self.scale)
|
47 |
+
|
48 |
+
self.norm = nn.LayerNorm([self.dim_out_v * 2, *feat_size])
|
49 |
+
mlp_ratio = 4
|
50 |
+
self.mlp = Mlp(
|
51 |
+
in_features=self.dim_out_v * 2,
|
52 |
+
hidden_features=int(dim * mlp_ratio),
|
53 |
+
act_layer=nn.GELU,
|
54 |
+
out_features=dim_out,
|
55 |
+
drop=0,
|
56 |
+
use_conv=True,
|
57 |
+
)
|
58 |
+
|
59 |
+
self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
|
60 |
+
self.reset_parameters()
|
61 |
+
|
62 |
+
def reset_parameters(self):
|
63 |
+
trunc_normal_(self.qkv_f.weight, std=self.qkv_f.weight.shape[1] ** -0.5) # fan-in
|
64 |
+
trunc_normal_(self.qkv_p.weight, std=self.qkv_p.weight.shape[1] ** -0.5) # fan-in
|
65 |
+
trunc_normal_(self.pos_embed.height_rel, std=self.scale)
|
66 |
+
trunc_normal_(self.pos_embed.width_rel, std=self.scale)
|
67 |
+
|
68 |
+
def get_qkv(self, x, qvk_conv):
|
69 |
+
B, C, H, W = x.shape
|
70 |
+
|
71 |
+
x = qvk_conv(x) # B, (2 * dim_head_qk + dim_head_v) * num_heads, H, W
|
72 |
+
|
73 |
+
q, k, v = torch.split(x, [self.dim_out_qk, self.dim_out_qk, self.dim_out_v], dim=1)
|
74 |
+
|
75 |
+
q = q.reshape(B * self.num_heads, self.dim_head_qk, -1).transpose(-1, -2)
|
76 |
+
k = k.reshape(B * self.num_heads, self.dim_head_qk, -1) # no transpose, for q @ k
|
77 |
+
v = v.reshape(B * self.num_heads, self.dim_head_v, -1).transpose(-1, -2)
|
78 |
+
|
79 |
+
return q, k, v
|
80 |
+
|
81 |
+
def apply_attn(self, q, k, v, B, H, W, dropout=None):
|
82 |
+
if self.scale_pos_embed:
|
83 |
+
attn = (q @ k + self.pos_embed(q)) * self.scale # B * num_heads, H * W, H * W
|
84 |
+
else:
|
85 |
+
attn = (q @ k) * self.scale + self.pos_embed(q)
|
86 |
+
attn = attn.softmax(dim=-1)
|
87 |
+
if dropout:
|
88 |
+
attn = dropout(attn)
|
89 |
+
|
90 |
+
out = (attn @ v).transpose(-1, -2).reshape(B, self.dim_out_v, H, W) # B, dim_out, H, W
|
91 |
+
return out
|
92 |
+
|
93 |
+
def forward(self, x):
|
94 |
+
B, C, H, W = x.shape
|
95 |
+
|
96 |
+
dim = int(C / 2)
|
97 |
+
x1 = x[:, :dim, :, :]
|
98 |
+
x2 = x[:, dim:, :, :]
|
99 |
+
|
100 |
+
_assert(H == self.pos_embed.height, "")
|
101 |
+
_assert(W == self.pos_embed.width, "")
|
102 |
+
|
103 |
+
q_f, k_f, v_f = self.get_qkv(x1, self.qkv_f)
|
104 |
+
q_p, k_p, v_p = self.get_qkv(x2, self.qkv_p)
|
105 |
+
|
106 |
+
# person to face
|
107 |
+
out_f = self.apply_attn(q_f, k_p, v_p, B, H, W)
|
108 |
+
# face to person
|
109 |
+
out_p = self.apply_attn(q_p, k_f, v_f, B, H, W)
|
110 |
+
|
111 |
+
x_pf = torch.cat((out_f, out_p), dim=1) # B, dim_out * 2, H, W
|
112 |
+
x_pf = self.norm(x_pf)
|
113 |
+
x_pf = self.mlp(x_pf) # B, dim_out, H, W
|
114 |
+
|
115 |
+
out = self.pool(x_pf)
|
116 |
+
return out
|
mivolo/model/mi_volo.py
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from mivolo.data.misc import prepare_classification_images
|
7 |
+
from mivolo.model.create_timm_model import create_model
|
8 |
+
from mivolo.structures import PersonAndFaceCrops, PersonAndFaceResult
|
9 |
+
from timm.data import resolve_data_config
|
10 |
+
|
11 |
+
_logger = logging.getLogger("MiVOLO")
|
12 |
+
has_compile = hasattr(torch, "compile")
|
13 |
+
|
14 |
+
|
15 |
+
class Meta:
|
16 |
+
def __init__(self):
|
17 |
+
self.min_age = None
|
18 |
+
self.max_age = None
|
19 |
+
self.avg_age = None
|
20 |
+
self.num_classes = None
|
21 |
+
|
22 |
+
self.in_chans = 3
|
23 |
+
self.with_persons_model = False
|
24 |
+
self.disable_faces = False
|
25 |
+
self.use_persons = True
|
26 |
+
self.only_age = False
|
27 |
+
|
28 |
+
self.num_classes_gender = 2
|
29 |
+
|
30 |
+
def load_from_ckpt(self, ckpt_path: str, disable_faces: bool = False, use_persons: bool = True) -> "Meta":
|
31 |
+
|
32 |
+
state = torch.load(ckpt_path, map_location="cpu")
|
33 |
+
|
34 |
+
self.min_age = state["min_age"]
|
35 |
+
self.max_age = state["max_age"]
|
36 |
+
self.avg_age = state["avg_age"]
|
37 |
+
self.only_age = state["no_gender"]
|
38 |
+
|
39 |
+
only_age = state["no_gender"]
|
40 |
+
|
41 |
+
self.disable_faces = disable_faces
|
42 |
+
if "with_persons_model" in state:
|
43 |
+
self.with_persons_model = state["with_persons_model"]
|
44 |
+
else:
|
45 |
+
self.with_persons_model = True if "patch_embed.conv1.0.weight" in state["state_dict"] else False
|
46 |
+
|
47 |
+
self.num_classes = 1 if only_age else 3
|
48 |
+
self.in_chans = 3 if not self.with_persons_model else 6
|
49 |
+
self.use_persons = use_persons and self.with_persons_model
|
50 |
+
|
51 |
+
if not self.with_persons_model and self.disable_faces:
|
52 |
+
raise ValueError("You can not use disable-faces for faces-only model")
|
53 |
+
if self.with_persons_model and self.disable_faces and not self.use_persons:
|
54 |
+
raise ValueError("You can not disable faces and persons together")
|
55 |
+
|
56 |
+
return self
|
57 |
+
|
58 |
+
def __str__(self):
|
59 |
+
attrs = vars(self)
|
60 |
+
attrs.update({"use_person_crops": self.use_person_crops, "use_face_crops": self.use_face_crops})
|
61 |
+
return ", ".join("%s: %s" % item for item in attrs.items())
|
62 |
+
|
63 |
+
@property
|
64 |
+
def use_person_crops(self) -> bool:
|
65 |
+
return self.with_persons_model and self.use_persons
|
66 |
+
|
67 |
+
@property
|
68 |
+
def use_face_crops(self) -> bool:
|
69 |
+
return not self.disable_faces or not self.with_persons_model
|
70 |
+
|
71 |
+
|
72 |
+
class MiVOLO:
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
ckpt_path: str,
|
76 |
+
device: str = "cpu",
|
77 |
+
half: bool = True,
|
78 |
+
disable_faces: bool = False,
|
79 |
+
use_persons: bool = True,
|
80 |
+
verbose: bool = False,
|
81 |
+
torchcompile: Optional[str] = None,
|
82 |
+
):
|
83 |
+
self.verbose = verbose
|
84 |
+
self.device = torch.device(device)
|
85 |
+
self.half = half and self.device.type != "cpu"
|
86 |
+
|
87 |
+
self.meta: Meta = Meta().load_from_ckpt(ckpt_path, disable_faces, use_persons)
|
88 |
+
if self.verbose:
|
89 |
+
_logger.info(f"Model meta:\n{str(self.meta)}")
|
90 |
+
|
91 |
+
model_name = "mivolo_d1_224"
|
92 |
+
self.model = create_model(
|
93 |
+
model_name=model_name,
|
94 |
+
num_classes=self.meta.num_classes,
|
95 |
+
in_chans=self.meta.in_chans,
|
96 |
+
pretrained=False,
|
97 |
+
checkpoint_path=ckpt_path,
|
98 |
+
filter_keys=["fds."],
|
99 |
+
)
|
100 |
+
self.param_count = sum([m.numel() for m in self.model.parameters()])
|
101 |
+
_logger.info(f"Model {model_name} created, param count: {self.param_count}")
|
102 |
+
|
103 |
+
self.data_config = resolve_data_config(
|
104 |
+
model=self.model,
|
105 |
+
verbose=verbose,
|
106 |
+
use_test_size=True,
|
107 |
+
)
|
108 |
+
self.data_config["crop_pct"] = 1.0
|
109 |
+
c, h, w = self.data_config["input_size"]
|
110 |
+
assert h == w, "Incorrect data_config"
|
111 |
+
self.input_size = w
|
112 |
+
|
113 |
+
self.model = self.model.to(self.device)
|
114 |
+
|
115 |
+
if torchcompile:
|
116 |
+
assert has_compile, "A version of torch w/ torch.compile() is required for --compile, possibly a nightly."
|
117 |
+
torch._dynamo.reset()
|
118 |
+
self.model = torch.compile(self.model, backend=torchcompile)
|
119 |
+
|
120 |
+
self.model.eval()
|
121 |
+
if self.half:
|
122 |
+
self.model = self.model.half()
|
123 |
+
|
124 |
+
def warmup(self, batch_size: int, steps=10):
|
125 |
+
if self.meta.with_persons_model:
|
126 |
+
input_size = (6, self.input_size, self.input_size)
|
127 |
+
else:
|
128 |
+
input_size = self.data_config["input_size"]
|
129 |
+
|
130 |
+
input = torch.randn((batch_size,) + tuple(input_size)).to(self.device)
|
131 |
+
|
132 |
+
for _ in range(steps):
|
133 |
+
out = self.inference(input) # noqa: F841
|
134 |
+
|
135 |
+
if torch.cuda.is_available():
|
136 |
+
torch.cuda.synchronize()
|
137 |
+
|
138 |
+
def inference(self, model_input: torch.tensor) -> torch.tensor:
|
139 |
+
|
140 |
+
with torch.no_grad():
|
141 |
+
if self.half:
|
142 |
+
model_input = model_input.half()
|
143 |
+
output = self.model(model_input)
|
144 |
+
return output
|
145 |
+
|
146 |
+
def predict(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
|
147 |
+
if detected_bboxes.n_objects == 0:
|
148 |
+
return
|
149 |
+
|
150 |
+
faces_input, person_input, faces_inds, bodies_inds = self.prepare_crops(image, detected_bboxes)
|
151 |
+
|
152 |
+
if self.meta.with_persons_model:
|
153 |
+
model_input = torch.cat((faces_input, person_input), dim=1)
|
154 |
+
else:
|
155 |
+
model_input = faces_input
|
156 |
+
output = self.inference(model_input)
|
157 |
+
|
158 |
+
# write gender and age results into detected_bboxes
|
159 |
+
self.fill_in_results(output, detected_bboxes, faces_inds, bodies_inds)
|
160 |
+
|
161 |
+
def fill_in_results(self, output, detected_bboxes, faces_inds, bodies_inds):
|
162 |
+
if self.meta.only_age:
|
163 |
+
age_output = output
|
164 |
+
gender_probs, gender_indx = None, None
|
165 |
+
else:
|
166 |
+
age_output = output[:, 2]
|
167 |
+
gender_output = output[:, :2].softmax(-1)
|
168 |
+
gender_probs, gender_indx = gender_output.topk(1)
|
169 |
+
|
170 |
+
assert output.shape[0] == len(faces_inds) == len(bodies_inds)
|
171 |
+
|
172 |
+
# per face
|
173 |
+
for index in range(output.shape[0]):
|
174 |
+
face_ind = faces_inds[index]
|
175 |
+
body_ind = bodies_inds[index]
|
176 |
+
|
177 |
+
# get_age
|
178 |
+
age = age_output[index].item()
|
179 |
+
age = age * (self.meta.max_age - self.meta.min_age) + self.meta.avg_age
|
180 |
+
age = round(age, 2)
|
181 |
+
|
182 |
+
detected_bboxes.set_age(face_ind, age)
|
183 |
+
detected_bboxes.set_age(body_ind, age)
|
184 |
+
|
185 |
+
_logger.info(f"\tage: {age}")
|
186 |
+
|
187 |
+
if gender_probs is not None:
|
188 |
+
gender = "male" if gender_indx[index].item() == 0 else "female"
|
189 |
+
gender_score = gender_probs[index].item()
|
190 |
+
|
191 |
+
_logger.info(f"\tgender: {gender} [{int(gender_score * 100)}%]")
|
192 |
+
|
193 |
+
detected_bboxes.set_gender(face_ind, gender, gender_score)
|
194 |
+
detected_bboxes.set_gender(body_ind, gender, gender_score)
|
195 |
+
|
196 |
+
def prepare_crops(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
|
197 |
+
|
198 |
+
if self.meta.use_person_crops and self.meta.use_face_crops:
|
199 |
+
detected_bboxes.associate_faces_with_persons()
|
200 |
+
|
201 |
+
crops: PersonAndFaceCrops = detected_bboxes.collect_crops(image)
|
202 |
+
(bodies_inds, bodies_crops), (faces_inds, faces_crops) = crops.get_faces_with_bodies(
|
203 |
+
self.meta.use_person_crops, self.meta.use_face_crops
|
204 |
+
)
|
205 |
+
|
206 |
+
if not self.meta.use_face_crops:
|
207 |
+
assert all(f is None for f in faces_crops)
|
208 |
+
|
209 |
+
faces_input = prepare_classification_images(
|
210 |
+
faces_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device
|
211 |
+
)
|
212 |
+
|
213 |
+
if not self.meta.use_person_crops:
|
214 |
+
assert all(p is None for p in bodies_crops)
|
215 |
+
|
216 |
+
person_input = prepare_classification_images(
|
217 |
+
bodies_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device
|
218 |
+
)
|
219 |
+
|
220 |
+
_logger.info(
|
221 |
+
f"faces_input: {faces_input.shape if faces_input is not None else None}, "
|
222 |
+
f"person_input: {person_input.shape if person_input is not None else None}"
|
223 |
+
)
|
224 |
+
|
225 |
+
return faces_input, person_input, faces_inds, bodies_inds
|
226 |
+
|
227 |
+
|
228 |
+
if __name__ == "__main__":
|
229 |
+
model = MiVOLO("../pretrained/checkpoint-377.pth.tar", half=True, device="cuda:0")
|
mivolo/model/mivolo_model.py
ADDED
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Code adapted from timm https://github.com/huggingface/pytorch-image-models
|
3 |
+
|
4 |
+
Modifications and additions for mivolo by / Copyright 2023, Irina Tolstykh, Maxim Kuprashevich
|
5 |
+
"""
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from mivolo.model.cross_bottleneck_attn import CrossBottleneckAttn
|
10 |
+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
11 |
+
from timm.layers import trunc_normal_
|
12 |
+
from timm.models._builder import build_model_with_cfg
|
13 |
+
from timm.models._registry import register_model
|
14 |
+
from timm.models.volo import VOLO
|
15 |
+
|
16 |
+
__all__ = ["MiVOLOModel"] # model_registry will add each entrypoint fn to this
|
17 |
+
|
18 |
+
|
19 |
+
def _cfg(url="", **kwargs):
|
20 |
+
return {
|
21 |
+
"url": url,
|
22 |
+
"num_classes": 1000,
|
23 |
+
"input_size": (3, 224, 224),
|
24 |
+
"pool_size": None,
|
25 |
+
"crop_pct": 0.96,
|
26 |
+
"interpolation": "bicubic",
|
27 |
+
"fixed_input_size": True,
|
28 |
+
"mean": IMAGENET_DEFAULT_MEAN,
|
29 |
+
"std": IMAGENET_DEFAULT_STD,
|
30 |
+
"first_conv": None,
|
31 |
+
"classifier": ("head", "aux_head"),
|
32 |
+
**kwargs,
|
33 |
+
}
|
34 |
+
|
35 |
+
|
36 |
+
default_cfgs = {
|
37 |
+
"mivolo_d1_224": _cfg(
|
38 |
+
url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_224_84.2.pth.tar", crop_pct=0.96
|
39 |
+
),
|
40 |
+
"mivolo_d1_384": _cfg(
|
41 |
+
url="https://github.com/sail-sg/volo/releases/download/volo_1/d1_384_85.2.pth.tar",
|
42 |
+
crop_pct=1.0,
|
43 |
+
input_size=(3, 384, 384),
|
44 |
+
),
|
45 |
+
"mivolo_d2_224": _cfg(
|
46 |
+
url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_224_85.2.pth.tar", crop_pct=0.96
|
47 |
+
),
|
48 |
+
"mivolo_d2_384": _cfg(
|
49 |
+
url="https://github.com/sail-sg/volo/releases/download/volo_1/d2_384_86.0.pth.tar",
|
50 |
+
crop_pct=1.0,
|
51 |
+
input_size=(3, 384, 384),
|
52 |
+
),
|
53 |
+
"mivolo_d3_224": _cfg(
|
54 |
+
url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_224_85.4.pth.tar", crop_pct=0.96
|
55 |
+
),
|
56 |
+
"mivolo_d3_448": _cfg(
|
57 |
+
url="https://github.com/sail-sg/volo/releases/download/volo_1/d3_448_86.3.pth.tar",
|
58 |
+
crop_pct=1.0,
|
59 |
+
input_size=(3, 448, 448),
|
60 |
+
),
|
61 |
+
"mivolo_d4_224": _cfg(
|
62 |
+
url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_224_85.7.pth.tar", crop_pct=0.96
|
63 |
+
),
|
64 |
+
"mivolo_d4_448": _cfg(
|
65 |
+
url="https://github.com/sail-sg/volo/releases/download/volo_1/d4_448_86.79.pth.tar",
|
66 |
+
crop_pct=1.15,
|
67 |
+
input_size=(3, 448, 448),
|
68 |
+
),
|
69 |
+
"mivolo_d5_224": _cfg(
|
70 |
+
url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_224_86.10.pth.tar", crop_pct=0.96
|
71 |
+
),
|
72 |
+
"mivolo_d5_448": _cfg(
|
73 |
+
url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_448_87.0.pth.tar",
|
74 |
+
crop_pct=1.15,
|
75 |
+
input_size=(3, 448, 448),
|
76 |
+
),
|
77 |
+
"mivolo_d5_512": _cfg(
|
78 |
+
url="https://github.com/sail-sg/volo/releases/download/volo_1/d5_512_87.07.pth.tar",
|
79 |
+
crop_pct=1.15,
|
80 |
+
input_size=(3, 512, 512),
|
81 |
+
),
|
82 |
+
}
|
83 |
+
|
84 |
+
|
85 |
+
def get_output_size(input_shape, conv_layer):
|
86 |
+
padding = conv_layer.padding
|
87 |
+
dilation = conv_layer.dilation
|
88 |
+
kernel_size = conv_layer.kernel_size
|
89 |
+
stride = conv_layer.stride
|
90 |
+
|
91 |
+
output_size = [
|
92 |
+
((input_shape[i] + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1) // stride[i]) + 1 for i in range(2)
|
93 |
+
]
|
94 |
+
return output_size
|
95 |
+
|
96 |
+
|
97 |
+
def get_output_size_module(input_size, stem):
|
98 |
+
output_size = input_size
|
99 |
+
|
100 |
+
for module in stem:
|
101 |
+
if isinstance(module, nn.Conv2d):
|
102 |
+
output_size = [
|
103 |
+
(
|
104 |
+
(output_size[i] + 2 * module.padding[i] - module.dilation[i] * (module.kernel_size[i] - 1) - 1)
|
105 |
+
// module.stride[i]
|
106 |
+
)
|
107 |
+
+ 1
|
108 |
+
for i in range(2)
|
109 |
+
]
|
110 |
+
|
111 |
+
return output_size
|
112 |
+
|
113 |
+
|
114 |
+
class PatchEmbed(nn.Module):
|
115 |
+
"""Image to Patch Embedding."""
|
116 |
+
|
117 |
+
def __init__(
|
118 |
+
self, img_size=224, stem_conv=False, stem_stride=1, patch_size=8, in_chans=3, hidden_dim=64, embed_dim=384
|
119 |
+
):
|
120 |
+
super().__init__()
|
121 |
+
assert patch_size in [4, 8, 16]
|
122 |
+
assert in_chans in [3, 6]
|
123 |
+
self.with_persons_model = in_chans == 6
|
124 |
+
self.use_cross_attn = True
|
125 |
+
|
126 |
+
if stem_conv:
|
127 |
+
if not self.with_persons_model:
|
128 |
+
self.conv = self.create_stem(stem_stride, in_chans, hidden_dim)
|
129 |
+
else:
|
130 |
+
self.conv = True # just to match interface
|
131 |
+
# split
|
132 |
+
self.conv1 = self.create_stem(stem_stride, 3, hidden_dim)
|
133 |
+
self.conv2 = self.create_stem(stem_stride, 3, hidden_dim)
|
134 |
+
else:
|
135 |
+
self.conv = None
|
136 |
+
|
137 |
+
if self.with_persons_model:
|
138 |
+
|
139 |
+
self.proj1 = nn.Conv2d(
|
140 |
+
hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
|
141 |
+
)
|
142 |
+
self.proj2 = nn.Conv2d(
|
143 |
+
hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
|
144 |
+
)
|
145 |
+
|
146 |
+
stem_out_shape = get_output_size_module((img_size, img_size), self.conv1)
|
147 |
+
self.proj_output_size = get_output_size(stem_out_shape, self.proj1)
|
148 |
+
|
149 |
+
self.map = CrossBottleneckAttn(embed_dim, dim_out=embed_dim, num_heads=1, feat_size=self.proj_output_size)
|
150 |
+
|
151 |
+
else:
|
152 |
+
self.proj = nn.Conv2d(
|
153 |
+
hidden_dim, embed_dim, kernel_size=patch_size // stem_stride, stride=patch_size // stem_stride
|
154 |
+
)
|
155 |
+
|
156 |
+
self.patch_dim = img_size // patch_size
|
157 |
+
self.num_patches = self.patch_dim**2
|
158 |
+
|
159 |
+
def create_stem(self, stem_stride, in_chans, hidden_dim):
|
160 |
+
return nn.Sequential(
|
161 |
+
nn.Conv2d(in_chans, hidden_dim, kernel_size=7, stride=stem_stride, padding=3, bias=False), # 112x112
|
162 |
+
nn.BatchNorm2d(hidden_dim),
|
163 |
+
nn.ReLU(inplace=True),
|
164 |
+
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112
|
165 |
+
nn.BatchNorm2d(hidden_dim),
|
166 |
+
nn.ReLU(inplace=True),
|
167 |
+
nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False), # 112x112
|
168 |
+
nn.BatchNorm2d(hidden_dim),
|
169 |
+
nn.ReLU(inplace=True),
|
170 |
+
)
|
171 |
+
|
172 |
+
def forward(self, x):
|
173 |
+
if self.conv is not None:
|
174 |
+
if self.with_persons_model:
|
175 |
+
x1 = x[:, :3]
|
176 |
+
x2 = x[:, 3:]
|
177 |
+
|
178 |
+
x1 = self.conv1(x1)
|
179 |
+
x1 = self.proj1(x1)
|
180 |
+
|
181 |
+
x2 = self.conv2(x2)
|
182 |
+
x2 = self.proj2(x2)
|
183 |
+
|
184 |
+
x = torch.cat([x1, x2], dim=1)
|
185 |
+
x = self.map(x)
|
186 |
+
else:
|
187 |
+
x = self.conv(x)
|
188 |
+
x = self.proj(x) # B, C, H, W
|
189 |
+
|
190 |
+
return x
|
191 |
+
|
192 |
+
|
193 |
+
class MiVOLOModel(VOLO):
|
194 |
+
"""
|
195 |
+
Vision Outlooker, the main class of our model
|
196 |
+
"""
|
197 |
+
|
198 |
+
def __init__(
|
199 |
+
self,
|
200 |
+
layers,
|
201 |
+
img_size=224,
|
202 |
+
in_chans=3,
|
203 |
+
num_classes=1000,
|
204 |
+
global_pool="token",
|
205 |
+
patch_size=8,
|
206 |
+
stem_hidden_dim=64,
|
207 |
+
embed_dims=None,
|
208 |
+
num_heads=None,
|
209 |
+
downsamples=(True, False, False, False),
|
210 |
+
outlook_attention=(True, False, False, False),
|
211 |
+
mlp_ratio=3.0,
|
212 |
+
qkv_bias=False,
|
213 |
+
drop_rate=0.0,
|
214 |
+
attn_drop_rate=0.0,
|
215 |
+
drop_path_rate=0.0,
|
216 |
+
norm_layer=nn.LayerNorm,
|
217 |
+
post_layers=("ca", "ca"),
|
218 |
+
use_aux_head=True,
|
219 |
+
use_mix_token=False,
|
220 |
+
pooling_scale=2,
|
221 |
+
):
|
222 |
+
super().__init__(
|
223 |
+
layers,
|
224 |
+
img_size,
|
225 |
+
in_chans,
|
226 |
+
num_classes,
|
227 |
+
global_pool,
|
228 |
+
patch_size,
|
229 |
+
stem_hidden_dim,
|
230 |
+
embed_dims,
|
231 |
+
num_heads,
|
232 |
+
downsamples,
|
233 |
+
outlook_attention,
|
234 |
+
mlp_ratio,
|
235 |
+
qkv_bias,
|
236 |
+
drop_rate,
|
237 |
+
attn_drop_rate,
|
238 |
+
drop_path_rate,
|
239 |
+
norm_layer,
|
240 |
+
post_layers,
|
241 |
+
use_aux_head,
|
242 |
+
use_mix_token,
|
243 |
+
pooling_scale,
|
244 |
+
)
|
245 |
+
|
246 |
+
self.patch_embed = PatchEmbed(
|
247 |
+
stem_conv=True,
|
248 |
+
stem_stride=2,
|
249 |
+
patch_size=patch_size,
|
250 |
+
in_chans=in_chans,
|
251 |
+
hidden_dim=stem_hidden_dim,
|
252 |
+
embed_dim=embed_dims[0],
|
253 |
+
)
|
254 |
+
|
255 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
256 |
+
self.apply(self._init_weights)
|
257 |
+
|
258 |
+
def forward_features(self, x):
|
259 |
+
x = self.patch_embed(x).permute(0, 2, 3, 1) # B,C,H,W-> B,H,W,C
|
260 |
+
|
261 |
+
# step2: tokens learning in the two stages
|
262 |
+
x = self.forward_tokens(x)
|
263 |
+
|
264 |
+
# step3: post network, apply class attention or not
|
265 |
+
if self.post_network is not None:
|
266 |
+
x = self.forward_cls(x)
|
267 |
+
x = self.norm(x)
|
268 |
+
return x
|
269 |
+
|
270 |
+
def forward_head(self, x, pre_logits: bool = False, targets=None, epoch=None):
|
271 |
+
if self.global_pool == "avg":
|
272 |
+
out = x.mean(dim=1)
|
273 |
+
elif self.global_pool == "token":
|
274 |
+
out = x[:, 0]
|
275 |
+
else:
|
276 |
+
out = x
|
277 |
+
if pre_logits:
|
278 |
+
return out
|
279 |
+
|
280 |
+
features = out
|
281 |
+
fds_enabled = hasattr(self, "_fds_forward")
|
282 |
+
if fds_enabled:
|
283 |
+
features = self._fds_forward(features, targets, epoch)
|
284 |
+
|
285 |
+
out = self.head(features)
|
286 |
+
if self.aux_head is not None:
|
287 |
+
# generate classes in all feature tokens, see token labeling
|
288 |
+
aux = self.aux_head(x[:, 1:])
|
289 |
+
out = out + 0.5 * aux.max(1)[0]
|
290 |
+
|
291 |
+
return (out, features) if (fds_enabled and self.training) else out
|
292 |
+
|
293 |
+
def forward(self, x, targets=None, epoch=None):
|
294 |
+
"""simplified forward (without mix token training)"""
|
295 |
+
x = self.forward_features(x)
|
296 |
+
x = self.forward_head(x, targets=targets, epoch=epoch)
|
297 |
+
return x
|
298 |
+
|
299 |
+
|
300 |
+
def _create_mivolo(variant, pretrained=False, **kwargs):
|
301 |
+
if kwargs.get("features_only", None):
|
302 |
+
raise RuntimeError("features_only not implemented for Vision Transformer models.")
|
303 |
+
return build_model_with_cfg(MiVOLOModel, variant, pretrained, **kwargs)
|
304 |
+
|
305 |
+
|
306 |
+
@register_model
|
307 |
+
def mivolo_d1_224(pretrained=False, **kwargs):
|
308 |
+
model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
|
309 |
+
model = _create_mivolo("mivolo_d1_224", pretrained=pretrained, **model_args)
|
310 |
+
return model
|
311 |
+
|
312 |
+
|
313 |
+
@register_model
|
314 |
+
def mivolo_d1_384(pretrained=False, **kwargs):
|
315 |
+
model_args = dict(layers=(4, 4, 8, 2), embed_dims=(192, 384, 384, 384), num_heads=(6, 12, 12, 12), **kwargs)
|
316 |
+
model = _create_mivolo("mivolo_d1_384", pretrained=pretrained, **model_args)
|
317 |
+
return model
|
318 |
+
|
319 |
+
|
320 |
+
@register_model
|
321 |
+
def mivolo_d2_224(pretrained=False, **kwargs):
|
322 |
+
model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
|
323 |
+
model = _create_mivolo("mivolo_d2_224", pretrained=pretrained, **model_args)
|
324 |
+
return model
|
325 |
+
|
326 |
+
|
327 |
+
@register_model
|
328 |
+
def mivolo_d2_384(pretrained=False, **kwargs):
|
329 |
+
model_args = dict(layers=(6, 4, 10, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
|
330 |
+
model = _create_mivolo("mivolo_d2_384", pretrained=pretrained, **model_args)
|
331 |
+
return model
|
332 |
+
|
333 |
+
|
334 |
+
@register_model
|
335 |
+
def mivolo_d3_224(pretrained=False, **kwargs):
|
336 |
+
model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
|
337 |
+
model = _create_mivolo("mivolo_d3_224", pretrained=pretrained, **model_args)
|
338 |
+
return model
|
339 |
+
|
340 |
+
|
341 |
+
@register_model
|
342 |
+
def mivolo_d3_448(pretrained=False, **kwargs):
|
343 |
+
model_args = dict(layers=(8, 8, 16, 4), embed_dims=(256, 512, 512, 512), num_heads=(8, 16, 16, 16), **kwargs)
|
344 |
+
model = _create_mivolo("mivolo_d3_448", pretrained=pretrained, **model_args)
|
345 |
+
return model
|
346 |
+
|
347 |
+
|
348 |
+
@register_model
|
349 |
+
def mivolo_d4_224(pretrained=False, **kwargs):
|
350 |
+
model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
|
351 |
+
model = _create_mivolo("mivolo_d4_224", pretrained=pretrained, **model_args)
|
352 |
+
return model
|
353 |
+
|
354 |
+
|
355 |
+
@register_model
|
356 |
+
def mivolo_d4_448(pretrained=False, **kwargs):
|
357 |
+
"""VOLO-D4 model, Params: 193M"""
|
358 |
+
model_args = dict(layers=(8, 8, 16, 4), embed_dims=(384, 768, 768, 768), num_heads=(12, 16, 16, 16), **kwargs)
|
359 |
+
model = _create_mivolo("mivolo_d4_448", pretrained=pretrained, **model_args)
|
360 |
+
return model
|
361 |
+
|
362 |
+
|
363 |
+
@register_model
|
364 |
+
def mivolo_d5_224(pretrained=False, **kwargs):
|
365 |
+
model_args = dict(
|
366 |
+
layers=(12, 12, 20, 4),
|
367 |
+
embed_dims=(384, 768, 768, 768),
|
368 |
+
num_heads=(12, 16, 16, 16),
|
369 |
+
mlp_ratio=4,
|
370 |
+
stem_hidden_dim=128,
|
371 |
+
**kwargs
|
372 |
+
)
|
373 |
+
model = _create_mivolo("mivolo_d5_224", pretrained=pretrained, **model_args)
|
374 |
+
return model
|
375 |
+
|
376 |
+
|
377 |
+
@register_model
|
378 |
+
def mivolo_d5_448(pretrained=False, **kwargs):
|
379 |
+
model_args = dict(
|
380 |
+
layers=(12, 12, 20, 4),
|
381 |
+
embed_dims=(384, 768, 768, 768),
|
382 |
+
num_heads=(12, 16, 16, 16),
|
383 |
+
mlp_ratio=4,
|
384 |
+
stem_hidden_dim=128,
|
385 |
+
**kwargs
|
386 |
+
)
|
387 |
+
model = _create_mivolo("mivolo_d5_448", pretrained=pretrained, **model_args)
|
388 |
+
return model
|
389 |
+
|
390 |
+
|
391 |
+
@register_model
|
392 |
+
def mivolo_d5_512(pretrained=False, **kwargs):
|
393 |
+
model_args = dict(
|
394 |
+
layers=(12, 12, 20, 4),
|
395 |
+
embed_dims=(384, 768, 768, 768),
|
396 |
+
num_heads=(12, 16, 16, 16),
|
397 |
+
mlp_ratio=4,
|
398 |
+
stem_hidden_dim=128,
|
399 |
+
**kwargs
|
400 |
+
)
|
401 |
+
model = _create_mivolo("mivolo_d5_512", pretrained=pretrained, **model_args)
|
402 |
+
return model
|
mivolo/model/yolo_detector.py
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Dict, Union
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import PIL
|
6 |
+
import torch
|
7 |
+
from mivolo.structures import PersonAndFaceResult
|
8 |
+
from ultralytics import YOLO
|
9 |
+
# from ultralytics.yolo.engine.results import Results
|
10 |
+
|
11 |
+
# because of ultralytics bug it is important to unset CUBLAS_WORKSPACE_CONFIG after the module importing
|
12 |
+
os.unsetenv("CUBLAS_WORKSPACE_CONFIG")
|
13 |
+
|
14 |
+
|
15 |
+
class Detector:
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
weights: str,
|
19 |
+
device: str = "cpu",
|
20 |
+
half: bool = True,
|
21 |
+
verbose: bool = False,
|
22 |
+
conf_thresh: float = 0.4,
|
23 |
+
iou_thresh: float = 0.7,
|
24 |
+
):
|
25 |
+
self.yolo = YOLO(weights)
|
26 |
+
self.yolo.fuse()
|
27 |
+
|
28 |
+
self.device = torch.device(device)
|
29 |
+
self.half = half and self.device.type != "cpu"
|
30 |
+
|
31 |
+
if self.half:
|
32 |
+
self.yolo.model = self.yolo.model.half()
|
33 |
+
|
34 |
+
self.detector_names: Dict[int, str] = self.yolo.model.names
|
35 |
+
|
36 |
+
# init yolo.predictor
|
37 |
+
self.detector_kwargs = {
|
38 |
+
"conf": conf_thresh, "iou": iou_thresh, "half": self.half, "verbose": verbose}
|
39 |
+
# self.yolo.predict(**self.detector_kwargs)
|
40 |
+
|
41 |
+
def predict(self, image: Union[np.ndarray, str, "PIL.Image"]) -> PersonAndFaceResult:
|
42 |
+
results = self.yolo.predict(image, **self.detector_kwargs)[0]
|
43 |
+
return PersonAndFaceResult(results)
|
44 |
+
|
45 |
+
def track(self, image: Union[np.ndarray, str, "PIL.Image"]) -> PersonAndFaceResult:
|
46 |
+
results = self.yolo.track(
|
47 |
+
image, persist=True, **self.detector_kwargs)[0]
|
48 |
+
return PersonAndFaceResult(results)
|
mivolo/predictor.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
from typing import Dict, Generator, List, Optional, Tuple
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
import tqdm
|
7 |
+
from mivolo.model.mi_volo import MiVOLO
|
8 |
+
from mivolo.model.yolo_detector import Detector
|
9 |
+
from mivolo.structures import AGE_GENDER_TYPE, PersonAndFaceResult
|
10 |
+
|
11 |
+
|
12 |
+
class Predictor:
|
13 |
+
def __init__(self, config, verbose: bool = False):
|
14 |
+
self.detector = Detector(config.detector_weights, config.device, verbose=verbose)
|
15 |
+
self.age_gender_model = MiVOLO(
|
16 |
+
config.checkpoint,
|
17 |
+
config.device,
|
18 |
+
half=True,
|
19 |
+
use_persons=config.with_persons,
|
20 |
+
disable_faces=config.disable_faces,
|
21 |
+
verbose=verbose,
|
22 |
+
)
|
23 |
+
self.draw = config.draw
|
24 |
+
|
25 |
+
def recognize(self, image: np.ndarray) -> Tuple[PersonAndFaceResult, Optional[np.ndarray]]:
|
26 |
+
detected_objects: PersonAndFaceResult = self.detector.predict(image)
|
27 |
+
self.age_gender_model.predict(image, detected_objects)
|
28 |
+
|
29 |
+
out_im = None
|
30 |
+
if self.draw:
|
31 |
+
# plot results on image
|
32 |
+
out_im = detected_objects.plot()
|
33 |
+
|
34 |
+
return detected_objects, out_im
|
35 |
+
|
36 |
+
def recognize_video(self, source: str) -> Generator:
|
37 |
+
video_capture = cv2.VideoCapture(source)
|
38 |
+
if not video_capture.isOpened():
|
39 |
+
raise ValueError(f"Failed to open video source {source}")
|
40 |
+
|
41 |
+
detected_objects_history: Dict[int, List[AGE_GENDER_TYPE]] = defaultdict(list)
|
42 |
+
|
43 |
+
total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
|
44 |
+
for _ in tqdm.tqdm(range(total_frames)):
|
45 |
+
ret, frame = video_capture.read()
|
46 |
+
if not ret:
|
47 |
+
break
|
48 |
+
|
49 |
+
detected_objects: PersonAndFaceResult = self.detector.track(frame)
|
50 |
+
self.age_gender_model.predict(frame, detected_objects)
|
51 |
+
|
52 |
+
current_frame_objs = detected_objects.get_results_for_tracking()
|
53 |
+
cur_persons: Dict[int, AGE_GENDER_TYPE] = current_frame_objs[0]
|
54 |
+
cur_faces: Dict[int, AGE_GENDER_TYPE] = current_frame_objs[1]
|
55 |
+
|
56 |
+
# add tr_persons and tr_faces to history
|
57 |
+
for guid, data in cur_persons.items():
|
58 |
+
# not useful for tracking :)
|
59 |
+
if None not in data:
|
60 |
+
detected_objects_history[guid].append(data)
|
61 |
+
for guid, data in cur_faces.items():
|
62 |
+
if None not in data:
|
63 |
+
detected_objects_history[guid].append(data)
|
64 |
+
|
65 |
+
detected_objects.set_tracked_age_gender(detected_objects_history)
|
66 |
+
if self.draw:
|
67 |
+
frame = detected_objects.plot()
|
68 |
+
yield detected_objects_history, frame
|
mivolo/structures.py
ADDED
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
from copy import deepcopy
|
4 |
+
from typing import Dict, List, Optional, Tuple
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from mivolo.data.misc import aggregate_votes_winsorized, assign_faces, box_iou, cropout_black_parts
|
10 |
+
from ultralytics.engine.results import Results
|
11 |
+
from ultralytics.utils.plotting import Annotator, colors
|
12 |
+
|
13 |
+
# because of ultralytics bug it is important to unset CUBLAS_WORKSPACE_CONFIG after the module importing
|
14 |
+
os.unsetenv("CUBLAS_WORKSPACE_CONFIG")
|
15 |
+
|
16 |
+
AGE_GENDER_TYPE = Tuple[float, str]
|
17 |
+
|
18 |
+
|
19 |
+
class PersonAndFaceCrops:
|
20 |
+
def __init__(self):
|
21 |
+
# int: index of person along results
|
22 |
+
self.crops_persons: Dict[int, np.ndarray] = {}
|
23 |
+
|
24 |
+
# int: index of face along results
|
25 |
+
self.crops_faces: Dict[int, np.ndarray] = {}
|
26 |
+
|
27 |
+
# int: index of face along results
|
28 |
+
self.crops_faces_wo_body: Dict[int, np.ndarray] = {}
|
29 |
+
|
30 |
+
# int: index of person along results
|
31 |
+
self.crops_persons_wo_face: Dict[int, np.ndarray] = {}
|
32 |
+
|
33 |
+
def _add_to_output(
|
34 |
+
self, crops: Dict[int, np.ndarray], out_crops: List[np.ndarray], out_crop_inds: List[Optional[int]]
|
35 |
+
):
|
36 |
+
inds_to_add = list(crops.keys())
|
37 |
+
crops_to_add = list(crops.values())
|
38 |
+
out_crops.extend(crops_to_add)
|
39 |
+
out_crop_inds.extend(inds_to_add)
|
40 |
+
|
41 |
+
def _get_all_faces(
|
42 |
+
self, use_persons: bool, use_faces: bool
|
43 |
+
) -> Tuple[List[Optional[int]], List[Optional[np.ndarray]]]:
|
44 |
+
"""
|
45 |
+
Returns
|
46 |
+
if use_persons and use_faces
|
47 |
+
faces: faces_with_bodies + faces_without_bodies + [None] * len(crops_persons_wo_face)
|
48 |
+
if use_persons and not use_faces
|
49 |
+
faces: [None] * n_persons
|
50 |
+
if not use_persons and use_faces:
|
51 |
+
faces: faces_with_bodies + faces_without_bodies
|
52 |
+
"""
|
53 |
+
|
54 |
+
def add_none_to_output(faces_inds, faces_crops, num):
|
55 |
+
faces_inds.extend([None for _ in range(num)])
|
56 |
+
faces_crops.extend([None for _ in range(num)])
|
57 |
+
|
58 |
+
faces_inds: List[Optional[int]] = []
|
59 |
+
faces_crops: List[Optional[np.ndarray]] = []
|
60 |
+
|
61 |
+
if not use_faces:
|
62 |
+
add_none_to_output(faces_inds, faces_crops, len(
|
63 |
+
self.crops_persons) + len(self.crops_persons_wo_face))
|
64 |
+
return faces_inds, faces_crops
|
65 |
+
|
66 |
+
self._add_to_output(self.crops_faces, faces_crops, faces_inds)
|
67 |
+
self._add_to_output(self.crops_faces_wo_body, faces_crops, faces_inds)
|
68 |
+
|
69 |
+
if use_persons:
|
70 |
+
add_none_to_output(faces_inds, faces_crops,
|
71 |
+
len(self.crops_persons_wo_face))
|
72 |
+
|
73 |
+
return faces_inds, faces_crops
|
74 |
+
|
75 |
+
def _get_all_bodies(
|
76 |
+
self, use_persons: bool, use_faces: bool
|
77 |
+
) -> Tuple[List[Optional[int]], List[Optional[np.ndarray]]]:
|
78 |
+
"""
|
79 |
+
Returns
|
80 |
+
if use_persons and use_faces
|
81 |
+
persons: bodies_with_faces + [None] * len(faces_without_bodies) + bodies_without_faces
|
82 |
+
if use_persons and not use_faces
|
83 |
+
persons: bodies_with_faces + bodies_without_faces
|
84 |
+
if not use_persons and use_faces
|
85 |
+
persons: [None] * n_faces
|
86 |
+
"""
|
87 |
+
|
88 |
+
def add_none_to_output(bodies_inds, bodies_crops, num):
|
89 |
+
bodies_inds.extend([None for _ in range(num)])
|
90 |
+
bodies_crops.extend([None for _ in range(num)])
|
91 |
+
|
92 |
+
bodies_inds: List[Optional[int]] = []
|
93 |
+
bodies_crops: List[Optional[np.ndarray]] = []
|
94 |
+
|
95 |
+
if not use_persons:
|
96 |
+
add_none_to_output(bodies_inds, bodies_crops, len(
|
97 |
+
self.crops_faces) + len(self.crops_faces_wo_body))
|
98 |
+
return bodies_inds, bodies_crops
|
99 |
+
|
100 |
+
self._add_to_output(self.crops_persons, bodies_crops, bodies_inds)
|
101 |
+
if use_faces:
|
102 |
+
add_none_to_output(bodies_inds, bodies_crops,
|
103 |
+
len(self.crops_faces_wo_body))
|
104 |
+
|
105 |
+
self._add_to_output(self.crops_persons_wo_face,
|
106 |
+
bodies_crops, bodies_inds)
|
107 |
+
|
108 |
+
return bodies_inds, bodies_crops
|
109 |
+
|
110 |
+
def get_faces_with_bodies(self, use_persons: bool, use_faces: bool):
|
111 |
+
"""
|
112 |
+
Return
|
113 |
+
faces: faces_with_bodies, faces_without_bodies, [None] * len(crops_persons_wo_face)
|
114 |
+
persons: bodies_with_faces, [None] * len(faces_without_bodies), bodies_without_faces
|
115 |
+
"""
|
116 |
+
|
117 |
+
bodies_inds, bodies_crops = self._get_all_bodies(
|
118 |
+
use_persons, use_faces)
|
119 |
+
faces_inds, faces_crops = self._get_all_faces(use_persons, use_faces)
|
120 |
+
|
121 |
+
return (bodies_inds, bodies_crops), (faces_inds, faces_crops)
|
122 |
+
|
123 |
+
def save(self, out_dir="output"):
|
124 |
+
ind = 0
|
125 |
+
os.makedirs(out_dir, exist_ok=True)
|
126 |
+
for crops in [self.crops_persons, self.crops_faces, self.crops_faces_wo_body, self.crops_persons_wo_face]:
|
127 |
+
for crop in crops.values():
|
128 |
+
if crop is None:
|
129 |
+
continue
|
130 |
+
out_name = os.path.join(out_dir, f"{ind}_crop.jpg")
|
131 |
+
cv2.imwrite(out_name, crop)
|
132 |
+
ind += 1
|
133 |
+
|
134 |
+
|
135 |
+
class PersonAndFaceResult:
|
136 |
+
def __init__(self, results: Results):
|
137 |
+
|
138 |
+
self.yolo_results = results
|
139 |
+
names = set(results.names.values())
|
140 |
+
assert "person" in names and "face" in names
|
141 |
+
|
142 |
+
# initially no faces and persons are associated to each other
|
143 |
+
self.face_to_person_map: Dict[int, Optional[int]] = {
|
144 |
+
ind: None for ind in self.get_bboxes_inds("face")}
|
145 |
+
self.unassigned_persons_inds: List[int] = self.get_bboxes_inds(
|
146 |
+
"person")
|
147 |
+
n_objects = len(self.yolo_results.boxes)
|
148 |
+
self.ages: List[Optional[float]] = [None for _ in range(n_objects)]
|
149 |
+
self.genders: List[Optional[str]] = [None for _ in range(n_objects)]
|
150 |
+
self.gender_scores: List[Optional[float]] = [
|
151 |
+
None for _ in range(n_objects)]
|
152 |
+
|
153 |
+
@property
|
154 |
+
def n_objects(self) -> int:
|
155 |
+
return len(self.yolo_results.boxes)
|
156 |
+
|
157 |
+
def get_bboxes_inds(self, category: str) -> List[int]:
|
158 |
+
bboxes: List[int] = []
|
159 |
+
for ind, det in enumerate(self.yolo_results.boxes):
|
160 |
+
name = self.yolo_results.names[int(det.cls)]
|
161 |
+
if name == category:
|
162 |
+
bboxes.append(ind)
|
163 |
+
|
164 |
+
return bboxes
|
165 |
+
|
166 |
+
def get_distance_to_center(self, bbox_ind: int) -> float:
|
167 |
+
"""
|
168 |
+
Calculate euclidian distance between bbox center and image center.
|
169 |
+
"""
|
170 |
+
im_h, im_w = self.yolo_results[bbox_ind].orig_shape
|
171 |
+
x1, y1, x2, y2 = self.get_bbox_by_ind(bbox_ind).cpu().numpy()
|
172 |
+
center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
|
173 |
+
dist = math.dist([center_x, center_y], [im_w / 2, im_h / 2])
|
174 |
+
return dist
|
175 |
+
|
176 |
+
def plot(
|
177 |
+
self,
|
178 |
+
conf=False,
|
179 |
+
line_width=None,
|
180 |
+
font_size=None,
|
181 |
+
font="Arial.ttf",
|
182 |
+
pil=False,
|
183 |
+
img=None,
|
184 |
+
labels=True,
|
185 |
+
boxes=True,
|
186 |
+
probs=True,
|
187 |
+
ages=True,
|
188 |
+
genders=True,
|
189 |
+
gender_probs=False,
|
190 |
+
):
|
191 |
+
"""
|
192 |
+
Plots the detection results on an input RGB image. Accepts a numpy array (cv2) or a PIL Image.
|
193 |
+
Args:
|
194 |
+
conf (bool): Whether to plot the detection confidence score.
|
195 |
+
line_width (float, optional): The line width of the bounding boxes. If None, it is scaled to the image size.
|
196 |
+
font_size (float, optional): The font size of the text. If None, it is scaled to the image size.
|
197 |
+
font (str): The font to use for the text.
|
198 |
+
pil (bool): Whether to return the image as a PIL Image.
|
199 |
+
img (numpy.ndarray): Plot to another image. if not, plot to original image.
|
200 |
+
labels (bool): Whether to plot the label of bounding boxes.
|
201 |
+
boxes (bool): Whether to plot the bounding boxes.
|
202 |
+
probs (bool): Whether to plot classification probability
|
203 |
+
ages (bool): Whether to plot the age of bounding boxes.
|
204 |
+
genders (bool): Whether to plot the genders of bounding boxes.
|
205 |
+
gender_probs (bool): Whether to plot gender classification probability
|
206 |
+
Returns:
|
207 |
+
(numpy.ndarray): A numpy array of the annotated image.
|
208 |
+
"""
|
209 |
+
|
210 |
+
# return self.yolo_results.plot()
|
211 |
+
colors_by_ind = {}
|
212 |
+
for face_ind, person_ind in self.face_to_person_map.items():
|
213 |
+
if person_ind is not None:
|
214 |
+
colors_by_ind[face_ind] = face_ind + 2
|
215 |
+
colors_by_ind[person_ind] = face_ind + 2
|
216 |
+
else:
|
217 |
+
colors_by_ind[face_ind] = 0
|
218 |
+
for person_ind in self.unassigned_persons_inds:
|
219 |
+
colors_by_ind[person_ind] = 1
|
220 |
+
|
221 |
+
names = self.yolo_results.names
|
222 |
+
annotator = Annotator(
|
223 |
+
deepcopy(self.yolo_results.orig_img if img is None else img),
|
224 |
+
line_width,
|
225 |
+
font_size,
|
226 |
+
font,
|
227 |
+
pil,
|
228 |
+
example=names,
|
229 |
+
)
|
230 |
+
pred_boxes, show_boxes = self.yolo_results.boxes, boxes
|
231 |
+
pred_probs, show_probs = self.yolo_results.probs, probs
|
232 |
+
|
233 |
+
if pred_boxes and show_boxes:
|
234 |
+
for bb_ind, (d, age, gender, gender_score) in enumerate(
|
235 |
+
zip(pred_boxes, self.ages, self.genders, self.gender_scores)
|
236 |
+
):
|
237 |
+
c, conf, guid = int(d.cls), float(
|
238 |
+
d.conf) if conf else None, None if d.id is None else int(d.id.item())
|
239 |
+
name = ("" if guid is None else f"id:{guid} ") + names[c]
|
240 |
+
label = (
|
241 |
+
f"{name} {conf:.2f}" if conf else name) if labels else None
|
242 |
+
if ages and age is not None:
|
243 |
+
label += f" {age:.1f}"
|
244 |
+
if genders and gender is not None:
|
245 |
+
label += f" {'F' if gender == 'female' else 'M'}"
|
246 |
+
if gender_probs and gender_score is not None:
|
247 |
+
label += f" ({gender_score:.1f})"
|
248 |
+
annotator.box_label(d.xyxy.squeeze(), label,
|
249 |
+
color=colors(colors_by_ind[bb_ind], True))
|
250 |
+
|
251 |
+
if pred_probs is not None and show_probs:
|
252 |
+
text = f"{', '.join(f'{names[j] if names else j} {pred_probs.data[j]:.2f}' for j in pred_probs.top5)}, "
|
253 |
+
annotator.text((32, 32), text, txt_color=(
|
254 |
+
255, 255, 255)) # TODO: allow setting colors
|
255 |
+
|
256 |
+
return annotator.result()
|
257 |
+
|
258 |
+
def set_tracked_age_gender(self, tracked_objects: Dict[int, List[AGE_GENDER_TYPE]]):
|
259 |
+
"""
|
260 |
+
Update age and gender for objects based on history from tracked_objects.
|
261 |
+
Args:
|
262 |
+
tracked_objects (dict[int, list[AGE_GENDER_TYPE]]): info about tracked objects by guid
|
263 |
+
"""
|
264 |
+
|
265 |
+
for face_ind, person_ind in self.face_to_person_map.items():
|
266 |
+
pguid = self._get_id_by_ind(person_ind)
|
267 |
+
fguid = self._get_id_by_ind(face_ind)
|
268 |
+
|
269 |
+
if fguid == -1 and pguid == -1:
|
270 |
+
# YOLO might not assign ids for some objects in some cases:
|
271 |
+
# https://github.com/ultralytics/ultralytics/issues/3830
|
272 |
+
continue
|
273 |
+
age, gender = self._gather_tracking_result(
|
274 |
+
tracked_objects, fguid, pguid)
|
275 |
+
if age is None or gender is None:
|
276 |
+
continue
|
277 |
+
self.set_age(face_ind, age)
|
278 |
+
self.set_gender(face_ind, gender, 1.0)
|
279 |
+
if pguid != -1:
|
280 |
+
self.set_gender(person_ind, gender, 1.0)
|
281 |
+
self.set_age(person_ind, age)
|
282 |
+
|
283 |
+
for person_ind in self.unassigned_persons_inds:
|
284 |
+
pid = self._get_id_by_ind(person_ind)
|
285 |
+
if pid == -1:
|
286 |
+
continue
|
287 |
+
age, gender = self._gather_tracking_result(
|
288 |
+
tracked_objects, -1, pid)
|
289 |
+
if age is None or gender is None:
|
290 |
+
continue
|
291 |
+
self.set_gender(person_ind, gender, 1.0)
|
292 |
+
self.set_age(person_ind, age)
|
293 |
+
|
294 |
+
def _get_id_by_ind(self, ind: Optional[int] = None) -> int:
|
295 |
+
if ind is None:
|
296 |
+
return -1
|
297 |
+
obj_id = self.yolo_results.boxes[ind].id
|
298 |
+
if obj_id is None:
|
299 |
+
return -1
|
300 |
+
return obj_id.item()
|
301 |
+
|
302 |
+
def get_bbox_by_ind(self, ind: int, im_h: int = None, im_w: int = None) -> torch.tensor:
|
303 |
+
bb = self.yolo_results.boxes[ind].xyxy.squeeze().type(torch.int32)
|
304 |
+
if im_h is not None and im_w is not None:
|
305 |
+
bb[0] = torch.clamp(bb[0], min=0, max=im_w - 1)
|
306 |
+
bb[1] = torch.clamp(bb[1], min=0, max=im_h - 1)
|
307 |
+
bb[2] = torch.clamp(bb[2], min=0, max=im_w - 1)
|
308 |
+
bb[3] = torch.clamp(bb[3], min=0, max=im_h - 1)
|
309 |
+
return bb
|
310 |
+
|
311 |
+
def set_age(self, ind: Optional[int], age: float):
|
312 |
+
if ind is not None:
|
313 |
+
self.ages[ind] = age
|
314 |
+
|
315 |
+
def set_gender(self, ind: Optional[int], gender: str, gender_score: float):
|
316 |
+
if ind is not None:
|
317 |
+
self.genders[ind] = gender
|
318 |
+
self.gender_scores[ind] = gender_score
|
319 |
+
|
320 |
+
@staticmethod
|
321 |
+
def _gather_tracking_result(
|
322 |
+
tracked_objects: Dict[int, List[AGE_GENDER_TYPE]],
|
323 |
+
fguid: int = -1,
|
324 |
+
pguid: int = -1,
|
325 |
+
minimum_sample_size: int = 10,
|
326 |
+
) -> AGE_GENDER_TYPE:
|
327 |
+
|
328 |
+
assert fguid != -1 or pguid != -1, "Incorrect tracking behaviour"
|
329 |
+
|
330 |
+
face_ages = [r[0] for r in tracked_objects[fguid] if r[0]
|
331 |
+
is not None] if fguid in tracked_objects else []
|
332 |
+
face_genders = [r[1] for r in tracked_objects[fguid]
|
333 |
+
if r[1] is not None] if fguid in tracked_objects else []
|
334 |
+
person_ages = [r[0] for r in tracked_objects[pguid]
|
335 |
+
if r[0] is not None] if pguid in tracked_objects else []
|
336 |
+
person_genders = [r[1] for r in tracked_objects[pguid]
|
337 |
+
if r[1] is not None] if pguid in tracked_objects else []
|
338 |
+
|
339 |
+
if not face_ages and not person_ages: # both empty
|
340 |
+
return None, None
|
341 |
+
|
342 |
+
# You can play here with different aggregation strategies
|
343 |
+
# Face ages - predictions based on face or face + person, depends on history of object
|
344 |
+
# Person ages - predictions based on person or face + person, depends on history of object
|
345 |
+
|
346 |
+
if len(person_ages + face_ages) >= minimum_sample_size:
|
347 |
+
age = aggregate_votes_winsorized(person_ages + face_ages)
|
348 |
+
else:
|
349 |
+
face_age = np.mean(face_ages) if face_ages else None
|
350 |
+
person_age = np.mean(person_ages) if person_ages else None
|
351 |
+
if face_age is None:
|
352 |
+
face_age = person_age
|
353 |
+
if person_age is None:
|
354 |
+
person_age = face_age
|
355 |
+
age = (face_age + person_age) / 2.0
|
356 |
+
|
357 |
+
genders = face_genders + person_genders
|
358 |
+
assert len(genders) > 0
|
359 |
+
# take mode of genders
|
360 |
+
gender = max(set(genders), key=genders.count)
|
361 |
+
|
362 |
+
return age, gender
|
363 |
+
|
364 |
+
def get_results_for_tracking(self) -> Tuple[Dict[int, AGE_GENDER_TYPE], Dict[int, AGE_GENDER_TYPE]]:
|
365 |
+
"""
|
366 |
+
Get objects from current frame
|
367 |
+
"""
|
368 |
+
persons: Dict[int, AGE_GENDER_TYPE] = {}
|
369 |
+
faces: Dict[int, AGE_GENDER_TYPE] = {}
|
370 |
+
|
371 |
+
names = self.yolo_results.names
|
372 |
+
pred_boxes = self.yolo_results.boxes
|
373 |
+
for _, (det, age, gender, _) in enumerate(zip(pred_boxes, self.ages, self.genders, self.gender_scores)):
|
374 |
+
if det.id is None:
|
375 |
+
continue
|
376 |
+
cat_id, _, guid = int(det.cls), float(det.conf), int(det.id.item())
|
377 |
+
name = names[cat_id]
|
378 |
+
if name == "person":
|
379 |
+
persons[guid] = (age, gender)
|
380 |
+
elif name == "face":
|
381 |
+
faces[guid] = (age, gender)
|
382 |
+
|
383 |
+
return persons, faces
|
384 |
+
|
385 |
+
def associate_faces_with_persons(self):
|
386 |
+
face_bboxes_inds: List[int] = self.get_bboxes_inds("face")
|
387 |
+
person_bboxes_inds: List[int] = self.get_bboxes_inds("person")
|
388 |
+
|
389 |
+
face_bboxes: List[torch.tensor] = [
|
390 |
+
self.get_bbox_by_ind(ind) for ind in face_bboxes_inds]
|
391 |
+
person_bboxes: List[torch.tensor] = [
|
392 |
+
self.get_bbox_by_ind(ind) for ind in person_bboxes_inds]
|
393 |
+
|
394 |
+
self.face_to_person_map = {ind: None for ind in face_bboxes_inds}
|
395 |
+
assigned_faces, unassigned_persons_inds = assign_faces(
|
396 |
+
person_bboxes, face_bboxes)
|
397 |
+
|
398 |
+
for face_ind, person_ind in enumerate(assigned_faces):
|
399 |
+
face_ind = face_bboxes_inds[face_ind]
|
400 |
+
person_ind = person_bboxes_inds[person_ind] if person_ind is not None else None
|
401 |
+
self.face_to_person_map[face_ind] = person_ind
|
402 |
+
|
403 |
+
self.unassigned_persons_inds = [
|
404 |
+
person_bboxes_inds[person_ind] for person_ind in unassigned_persons_inds]
|
405 |
+
|
406 |
+
def crop_object(
|
407 |
+
self, full_image: np.ndarray, ind: int, cut_other_classes: Optional[List[str]] = None
|
408 |
+
) -> Optional[np.ndarray]:
|
409 |
+
|
410 |
+
IOU_THRESH = 0.000001
|
411 |
+
MIN_PERSON_CROP_AFTERCUT_RATIO = 0.4
|
412 |
+
CROP_ROUND_RATE = 0.3
|
413 |
+
MIN_PERSON_SIZE = 50
|
414 |
+
|
415 |
+
obj_bbox = self.get_bbox_by_ind(ind, *full_image.shape[:2])
|
416 |
+
x1, y1, x2, y2 = obj_bbox
|
417 |
+
cur_cat = self.yolo_results.names[int(
|
418 |
+
self.yolo_results.boxes[ind].cls)]
|
419 |
+
# get crop of face or person
|
420 |
+
obj_image = full_image[y1:y2, x1:x2].copy()
|
421 |
+
crop_h, crop_w = obj_image.shape[:2]
|
422 |
+
|
423 |
+
if cur_cat == "person" and (crop_h < MIN_PERSON_SIZE or crop_w < MIN_PERSON_SIZE):
|
424 |
+
return None
|
425 |
+
|
426 |
+
if not cut_other_classes:
|
427 |
+
return obj_image
|
428 |
+
|
429 |
+
# calc iou between obj_bbox and other bboxes
|
430 |
+
other_bboxes: List[torch.tensor] = [
|
431 |
+
self.get_bbox_by_ind(other_ind, *full_image.shape[:2]) for other_ind in range(len(self.yolo_results.boxes))
|
432 |
+
]
|
433 |
+
|
434 |
+
iou_matrix = box_iou(torch.stack([obj_bbox]), torch.stack(
|
435 |
+
other_bboxes)).cpu().numpy()[0]
|
436 |
+
|
437 |
+
# cut out other objects in case of intersection
|
438 |
+
for other_ind, (det, iou) in enumerate(zip(self.yolo_results.boxes, iou_matrix)):
|
439 |
+
other_cat = self.yolo_results.names[int(det.cls)]
|
440 |
+
if ind == other_ind or iou < IOU_THRESH or other_cat not in cut_other_classes:
|
441 |
+
continue
|
442 |
+
o_x1, o_y1, o_x2, o_y2 = det.xyxy.squeeze().type(torch.int32)
|
443 |
+
|
444 |
+
# remap current_person_bbox to reference_person_bbox coordinates
|
445 |
+
o_x1 = max(o_x1 - x1, 0)
|
446 |
+
o_y1 = max(o_y1 - y1, 0)
|
447 |
+
o_x2 = min(o_x2 - x1, crop_w)
|
448 |
+
o_y2 = min(o_y2 - y1, crop_h)
|
449 |
+
|
450 |
+
if other_cat != "face":
|
451 |
+
if (o_y1 / crop_h) < CROP_ROUND_RATE:
|
452 |
+
o_y1 = 0
|
453 |
+
if ((crop_h - o_y2) / crop_h) < CROP_ROUND_RATE:
|
454 |
+
o_y2 = crop_h
|
455 |
+
if (o_x1 / crop_w) < CROP_ROUND_RATE:
|
456 |
+
o_x1 = 0
|
457 |
+
if ((crop_w - o_x2) / crop_w) < CROP_ROUND_RATE:
|
458 |
+
o_x2 = crop_w
|
459 |
+
|
460 |
+
obj_image[o_y1:o_y2, o_x1:o_x2] = 0
|
461 |
+
|
462 |
+
obj_image, remain_ratio = cropout_black_parts(
|
463 |
+
obj_image, CROP_ROUND_RATE)
|
464 |
+
if remain_ratio < MIN_PERSON_CROP_AFTERCUT_RATIO:
|
465 |
+
return None
|
466 |
+
|
467 |
+
return obj_image
|
468 |
+
|
469 |
+
def collect_crops(self, image) -> PersonAndFaceCrops:
|
470 |
+
|
471 |
+
crops_data = PersonAndFaceCrops()
|
472 |
+
for face_ind, person_ind in self.face_to_person_map.items():
|
473 |
+
face_image = self.crop_object(
|
474 |
+
image, face_ind, cut_other_classes=[])
|
475 |
+
|
476 |
+
if person_ind is None:
|
477 |
+
crops_data.crops_faces_wo_body[face_ind] = face_image
|
478 |
+
continue
|
479 |
+
|
480 |
+
person_image = self.crop_object(
|
481 |
+
image, person_ind, cut_other_classes=["face", "person"])
|
482 |
+
|
483 |
+
crops_data.crops_faces[face_ind] = face_image
|
484 |
+
crops_data.crops_persons[person_ind] = person_image
|
485 |
+
|
486 |
+
for person_ind in self.unassigned_persons_inds:
|
487 |
+
person_image = self.crop_object(
|
488 |
+
image, person_ind, cut_other_classes=["face", "person"])
|
489 |
+
crops_data.crops_persons_wo_face[person_ind] = person_image
|
490 |
+
|
491 |
+
# uncomment to save preprocessed crops
|
492 |
+
# crops_data.save()
|
493 |
+
return crops_data
|
mivolo/version.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__version__ = "0.3.0dev"
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
ultralytics==8.0.187
|
2 |
+
timm==0.8.13.dev0
|
3 |
+
tqdm
|
4 |
+
requests
|
5 |
+
opencv-python
|
6 |
+
omegaconf
|
utils.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import requests
|
3 |
+
from tqdm import tqdm
|
4 |
+
from modelscope import snapshot_download
|
5 |
+
from urllib.parse import urlparse
|
6 |
+
|
7 |
+
MODEL_DIR = snapshot_download("MuGeminorum/MiVOLO", cache_dir="./mivolo/__pycache__")
|
8 |
+
|
9 |
+
|
10 |
+
def is_url(s: str):
|
11 |
+
try:
|
12 |
+
# 解析字符串
|
13 |
+
result = urlparse(s)
|
14 |
+
# 检查scheme(如http, https)和netloc(域名)
|
15 |
+
return all([result.scheme, result.netloc])
|
16 |
+
|
17 |
+
except:
|
18 |
+
# 如果解析过程中发生异常,则返回False
|
19 |
+
return False
|
20 |
+
|
21 |
+
|
22 |
+
def download_file(url: str, save_path: str):
|
23 |
+
if os.path.exists(save_path):
|
24 |
+
print("目标已存在,无需下载")
|
25 |
+
return
|
26 |
+
|
27 |
+
create_dir(os.path.dirname(save_path))
|
28 |
+
response = requests.get(url, stream=True)
|
29 |
+
total_size = int(response.headers.get("content-length", 0))
|
30 |
+
# 使用 tqdm 创建一个进度条
|
31 |
+
progress_bar = tqdm(total=total_size, unit="B", unit_scale=True)
|
32 |
+
with open(save_path, "wb") as file:
|
33 |
+
for data in response.iter_content(chunk_size=1024):
|
34 |
+
file.write(data)
|
35 |
+
progress_bar.update(len(data))
|
36 |
+
|
37 |
+
progress_bar.close()
|
38 |
+
if total_size != 0 and progress_bar.n != total_size:
|
39 |
+
os.remove(save_path)
|
40 |
+
print("下载失败,重试中...")
|
41 |
+
download_file(url, save_path)
|
42 |
+
|
43 |
+
else:
|
44 |
+
print("下载完成")
|
45 |
+
|
46 |
+
return save_path
|
47 |
+
|
48 |
+
|
49 |
+
def create_dir(dir_path: str):
|
50 |
+
if not os.path.exists(dir_path):
|
51 |
+
os.makedirs(dir_path)
|
52 |
+
|
53 |
+
|
54 |
+
def get_jpg_files(folder_path: str):
|
55 |
+
all_files = os.listdir(folder_path)
|
56 |
+
return [
|
57 |
+
os.path.join(folder_path, file)
|
58 |
+
for file in all_files
|
59 |
+
if file.lower().endswith(".jpg")
|
60 |
+
]
|