|
import logging |
|
from typing import Optional |
|
|
|
import numpy as np |
|
import torch |
|
from mivolo.data.misc import prepare_classification_images |
|
from mivolo.model.create_timm_model import create_model |
|
from mivolo.structures import PersonAndFaceCrops, PersonAndFaceResult |
|
from timm.data import resolve_data_config |
|
|
|
# Package-wide logger; handlers/levels are expected to be configured by the caller.
_logger = logging.getLogger("MiVOLO")

# True when this torch build exposes torch.compile (PyTorch 2.x builds).
has_compile = hasattr(torch, "compile")
|
|
|
|
|
class Meta:
    """Metadata describing a MiVOLO checkpoint.

    Holds the age de-normalization range, the output-head layout and flags
    describing which crop types (face and/or person) the model consumes.
    """

    def __init__(self):
        # Age de-normalization parameters; populated by load_from_ckpt().
        self.min_age = None
        self.max_age = None
        self.avg_age = None
        # Output head size: 1 for age-only models, 3 for age + 2 gender logits.
        self.num_classes = None

        # 3 input channels for a faces-only model, 6 when person crops are stacked in.
        self.in_chans = 3
        self.with_persons_model = False
        self.disable_faces = False
        self.use_persons = True
        self.only_age = False

        self.num_classes_gender = 2

    def load_from_ckpt(self, ckpt_path: str, disable_faces: bool = False, use_persons: bool = True) -> "Meta":
        """Populate metadata from a checkpoint file and validate the crop configuration.

        Args:
            ckpt_path: path to a torch checkpoint containing "min_age", "max_age",
                "avg_age", "no_gender" and a "state_dict".
            disable_faces: do not feed face crops (valid only for person-capable models).
            use_persons: allow person crops when the model supports them.

        Returns:
            self, to allow call chaining.

        Raises:
            ValueError: when the requested configuration leaves the model with no usable input.
        """
        state = torch.load(ckpt_path, map_location="cpu")

        self.min_age = state["min_age"]
        self.max_age = state["max_age"]
        self.avg_age = state["avg_age"]
        self.only_age = state["no_gender"]

        self.disable_faces = disable_faces
        if "with_persons_model" in state:
            self.with_persons_model = state["with_persons_model"]
        else:
            # Older checkpoints lack the explicit flag; detect a person-capable
            # model by the presence of the 6-channel patch-embed weight.
            self.with_persons_model = "patch_embed.conv1.0.weight" in state["state_dict"]

        self.num_classes = 1 if self.only_age else 3
        self.in_chans = 3 if not self.with_persons_model else 6
        # Person crops can only be used when the model was trained with them.
        self.use_persons = use_persons and self.with_persons_model

        if not self.with_persons_model and self.disable_faces:
            raise ValueError("You can not use disable-faces for faces-only model")
        if self.with_persons_model and self.disable_faces and not self.use_persons:
            raise ValueError("You can not disable faces and persons together")

        return self

    def __str__(self):
        attrs = vars(self)
        attrs.update({"use_person_crops": self.use_person_crops, "use_face_crops": self.use_face_crops})
        return ", ".join("%s: %s" % item for item in attrs.items())

    @property
    def use_person_crops(self) -> bool:
        # Feed person crops only when the model supports them AND the caller asked for them.
        return self.with_persons_model and self.use_persons

    @property
    def use_face_crops(self) -> bool:
        # Face crops are used unless explicitly disabled on a person-capable model.
        return not self.disable_faces or not self.with_persons_model
|
|
|
|
|
class MiVOLO:
    """Age/gender predictor wrapping a MiVOLO timm model.

    Loads a checkpoint, resolves the timm data config, optionally compiles the
    model with torch.compile and casts it to fp16, and exposes predict() over
    detected face/person bounding boxes.
    """

    def __init__(
        self,
        ckpt_path: str,
        device: str = "cpu",
        half: bool = True,
        disable_faces: bool = False,
        use_persons: bool = True,
        verbose: bool = False,
        torchcompile: Optional[str] = None,
    ):
        """
        Args:
            ckpt_path: path to the MiVOLO checkpoint.
            device: torch device string, e.g. "cpu" or "cuda:0".
            half: run the model in fp16 (ignored on CPU).
            disable_faces: skip face crops (person-capable models only).
            use_persons: use person crops when the model supports them.
            verbose: log model meta and data-config details.
            torchcompile: torch.compile backend name, or None to skip compilation.
        """
        self.verbose = verbose
        self.device = torch.device(device)
        # fp16 inference is only enabled off-CPU.
        self.half = half and self.device.type != "cpu"

        self.meta: Meta = Meta().load_from_ckpt(ckpt_path, disable_faces, use_persons)
        if self.verbose:
            _logger.info(f"Model meta:\n{str(self.meta)}")

        model_name = "mivolo_d1_224"
        self.model = create_model(
            model_name=model_name,
            num_classes=self.meta.num_classes,
            in_chans=self.meta.in_chans,
            pretrained=False,
            checkpoint_path=ckpt_path,
            filter_keys=["fds."],
        )
        self.param_count = sum(p.numel() for p in self.model.parameters())
        _logger.info(f"Model {model_name} created, param count: {self.param_count}")

        self.data_config = resolve_data_config(
            model=self.model,
            verbose=verbose,
            use_test_size=True,
        )
        # Crops are already tight around the subject: disable timm's center-crop shrink.
        self.data_config["crop_pct"] = 1.0
        c, h, w = self.data_config["input_size"]
        assert h == w, "Incorrect data_config"
        self.input_size = w

        self.model = self.model.to(self.device)

        if torchcompile:
            assert has_compile, "A version of torch w/ torch.compile() is required for --compile, possibly a nightly."
            torch._dynamo.reset()
            self.model = torch.compile(self.model, backend=torchcompile)

        self.model.eval()
        if self.half:
            self.model = self.model.half()

    def warmup(self, batch_size: int, steps: int = 10):
        """Run `steps` dummy forward passes to warm up kernels / compiled graphs."""
        if self.meta.with_persons_model:
            # Face (3) + person (3) channels stacked along the channel dim.
            input_size = (6, self.input_size, self.input_size)
        else:
            input_size = self.data_config["input_size"]

        batch = torch.randn((batch_size,) + tuple(input_size)).to(self.device)

        for _ in range(steps):
            self.inference(batch)

        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def inference(self, model_input: torch.Tensor) -> torch.Tensor:
        """Gradient-free forward pass; casts the input to fp16 when the model is halved."""
        with torch.no_grad():
            if self.half:
                model_input = model_input.half()
            output = self.model(model_input)
        return output

    def predict(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
        """Run age/gender inference on all detections and write results back into them.

        Args:
            image: original BGR/RGB frame as a numpy array (crops are taken from it).
            detected_bboxes: detection results; mutated in place with ages/genders.
        """
        if detected_bboxes.n_objects == 0:
            return

        faces_input, person_input, faces_inds, bodies_inds = self.prepare_crops(image, detected_bboxes)

        if self.meta.with_persons_model:
            # NOTE(review): assumes faces_input and person_input are both non-None
            # and row-aligned here — verify against prepare_classification_images.
            model_input = torch.cat((faces_input, person_input), dim=1)
        else:
            model_input = faces_input
        output = self.inference(model_input)

        # Write age (and gender, when available) into the detection results.
        self.fill_in_results(output, detected_bboxes, faces_inds, bodies_inds)

    def fill_in_results(self, output, detected_bboxes, faces_inds, bodies_inds):
        """Decode model outputs and store them per face/body index in detected_bboxes."""
        if self.meta.only_age:
            age_output = output
            gender_probs, gender_indx = None, None
        else:
            # Output layout: [:, :2] gender logits, [:, 2] normalized age.
            age_output = output[:, 2]
            gender_output = output[:, :2].softmax(-1)
            gender_probs, gender_indx = gender_output.topk(1)

        assert output.shape[0] == len(faces_inds) == len(bodies_inds)

        for index in range(output.shape[0]):
            face_ind = faces_inds[index]
            body_ind = bodies_inds[index]

            # De-normalize: the model predicts age scaled by the training range
            # and centered on the dataset's average age.
            age = age_output[index].item()
            age = age * (self.meta.max_age - self.meta.min_age) + self.meta.avg_age
            age = round(age, 2)

            detected_bboxes.set_age(face_ind, age)
            detected_bboxes.set_age(body_ind, age)

            _logger.info(f"\tage: {age}")

            if gender_probs is not None:
                gender = "male" if gender_indx[index].item() == 0 else "female"
                gender_score = gender_probs[index].item()

                _logger.info(f"\tgender: {gender} [{int(gender_score * 100)}%]")

                detected_bboxes.set_gender(face_ind, gender, gender_score)
                detected_bboxes.set_gender(body_ind, gender, gender_score)

    def prepare_crops(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
        """Extract, normalize and batch face/person crops for the model.

        Returns:
            (faces_input, person_input, faces_inds, bodies_inds) — batched tensors
            (either may be None when that crop type is unused) plus the detection
            indices each batch row corresponds to.
        """
        if self.meta.use_person_crops and self.meta.use_face_crops:
            detected_bboxes.associate_faces_with_persons()

        crops: PersonAndFaceCrops = detected_bboxes.collect_crops(image)
        (bodies_inds, bodies_crops), (faces_inds, faces_crops) = crops.get_faces_with_bodies(
            self.meta.use_person_crops, self.meta.use_face_crops
        )

        if not self.meta.use_face_crops:
            assert all(f is None for f in faces_crops)

        faces_input = prepare_classification_images(
            faces_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device
        )

        if not self.meta.use_person_crops:
            assert all(p is None for p in bodies_crops)

        person_input = prepare_classification_images(
            bodies_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device
        )

        _logger.info(
            f"faces_input: {faces_input.shape if faces_input is not None else None}, "
            f"person_input: {person_input.shape if person_input is not None else None}"
        )

        return faces_input, person_input, faces_inds, bodies_inds
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: load the default checkpoint in fp16 on the first GPU.
    mivolo_model = MiVOLO("../pretrained/checkpoint-377.pth.tar", half=True, device="cuda:0")
|
|