import logging
from typing import Optional

import numpy as np
import torch
from mivolo.data.misc import prepare_classification_images
from mivolo.model.create_timm_model import create_model
from mivolo.structures import PersonAndFaceCrops, PersonAndFaceResult
from timm.data import resolve_data_config

_logger = logging.getLogger("MiVOLO")
# torch.compile is only present in newer torch releases; checked before use in MiVOLO.__init__.
has_compile = hasattr(torch, "compile")


class Meta:
    """Model configuration read from a MiVOLO checkpoint.

    Holds the age de-normalization constants stored in the checkpoint and the
    flags that decide which crops (face, person) are fed to the model.
    """

    def __init__(self):
        self.min_age = None
        self.max_age = None
        self.avg_age = None
        self.num_classes = None
        self.in_chans = 3
        self.with_persons_model = False
        self.disable_faces = False
        self.use_persons = True
        self.only_age = False
        self.num_classes_gender = 2

    def load_from_ckpt(self, ckpt_path: str, disable_faces: bool = False, use_persons: bool = True) -> "Meta":
        """Populate this object from the checkpoint at ``ckpt_path``.

        Args:
            ckpt_path: path to a checkpoint readable by ``torch.load``; it must
                contain "min_age", "max_age", "avg_age", "no_gender" and
                (when "with_persons_model" is absent) "state_dict".
            disable_faces: do not feed face crops (only allowed for person+face models).
            use_persons: feed person crops when the model supports them.

        Returns:
            ``self``, so ``Meta().load_from_ckpt(...)`` can be chained.

        Raises:
            ValueError: if faces are disabled on a faces-only model, or if
                faces and persons are both disabled.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state = torch.load(ckpt_path, map_location="cpu")

        self.min_age = state["min_age"]
        self.max_age = state["max_age"]
        self.avg_age = state["avg_age"]
        self.only_age = state["no_gender"]
        self.disable_faces = disable_faces

        if "with_persons_model" in state:
            self.with_persons_model = state["with_persons_model"]
        else:
            # No explicit flag stored: treat the presence of this weight key as the
            # marker of the person+face variant.
            self.with_persons_model = "patch_embed.conv1.0.weight" in state["state_dict"]

        # Age-only models emit 1 value; otherwise 2 gender logits + 1 age value
        # (see MiVOLO.fill_in_results for how the columns are consumed).
        self.num_classes = 1 if self.only_age else 3
        # Face and person crops are concatenated along the channel axis (3 + 3).
        self.in_chans = 6 if self.with_persons_model else 3
        self.use_persons = use_persons and self.with_persons_model

        if not self.with_persons_model and self.disable_faces:
            raise ValueError("You can not use disable-faces for faces-only model")
        if self.with_persons_model and self.disable_faces and not self.use_persons:
            raise ValueError("You can not disable faces and persons together")

        return self

    def __str__(self):
        # vars() returns the instance __dict__ itself; copy it so that adding the
        # derived properties below does not write them into the object's state.
        attrs = dict(vars(self))
        attrs.update({"use_person_crops": self.use_person_crops, "use_face_crops": self.use_face_crops})
        return ", ".join("%s: %s" % item for item in attrs.items())

    @property
    def use_person_crops(self) -> bool:
        """True when the model accepts person crops and they were not disabled."""
        return self.with_persons_model and self.use_persons

    @property
    def use_face_crops(self) -> bool:
        """True unless faces were disabled; faces-only models always use faces."""
        return not self.disable_faces or not self.with_persons_model


class MiVOLO:
    """Age (and, unless the checkpoint is age-only, gender) predictor."""

    def __init__(
        self,
        ckpt_path: str,
        device: str = "cpu",
        half: bool = True,
        disable_faces: bool = False,
        use_persons: bool = True,
        verbose: bool = False,
        torchcompile: Optional[str] = None,
    ):
        """Load metadata and weights from ``ckpt_path`` and prepare the model.

        Args:
            ckpt_path: checkpoint used both for metadata (``Meta``) and weights.
            device: torch device string, e.g. "cpu" or "cuda:0".
            half: run in float16; forced off when ``device`` is CPU.
            disable_faces: forwarded to ``Meta.load_from_ckpt``.
            use_persons: forwarded to ``Meta.load_from_ckpt``.
            verbose: log the metadata and pass verbosity to ``resolve_data_config``.
            torchcompile: if set, the backend name passed to ``torch.compile``.
        """
        self.verbose = verbose
        self.device = torch.device(device)
        self.half = half and self.device.type != "cpu"

        self.meta: Meta = Meta().load_from_ckpt(ckpt_path, disable_faces, use_persons)
        if self.verbose:
            _logger.info(f"Model meta:\n{str(self.meta)}")

        model_name = "mivolo_d1_224"
        self.model = create_model(
            model_name=model_name,
            num_classes=self.meta.num_classes,
            in_chans=self.meta.in_chans,
            pretrained=False,
            checkpoint_path=ckpt_path,
            filter_keys=["fds."],
        )
        self.param_count = sum(p.numel() for p in self.model.parameters())
        _logger.info(f"Model {model_name} created, param count: {self.param_count}")

        self.data_config = resolve_data_config(
            model=self.model,
            verbose=verbose,
            use_test_size=True,
        )

        self.data_config["crop_pct"] = 1.0
        c, h, w = self.data_config["input_size"]
        assert h == w, "Incorrect data_config"
        self.input_size = w

        self.model = self.model.to(self.device)

        if torchcompile:
            assert has_compile, "A version of torch w/ torch.compile() is required for --compile, possibly a nightly."
            torch._dynamo.reset()
            self.model = torch.compile(self.model, backend=torchcompile)

        self.model.eval()
        if self.half:
            self.model = self.model.half()

    def warmup(self, batch_size: int, steps=10):
        """Run ``steps`` forward passes on random input of the model's input shape."""
        if self.meta.with_persons_model:
            input_size = (6, self.input_size, self.input_size)
        else:
            input_size = self.data_config["input_size"]
        dummy_input = torch.randn((batch_size,) + tuple(input_size)).to(self.device)

        for _ in range(steps):
            out = self.inference(dummy_input)  # noqa: F841

        # Only a CUDA model has queued kernels to wait for; sync the model's own device
        # rather than whichever CUDA device happens to be current.
        if self.device.type == "cuda":
            torch.cuda.synchronize(self.device)

    def inference(self, model_input: torch.Tensor) -> torch.Tensor:
        """Forward ``model_input`` through the model without gradient tracking."""
        with torch.no_grad():
            if self.half:
                model_input = model_input.half()
            output = self.model(model_input)
        return output

    def predict(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
        """Predict age/gender for ``detected_bboxes`` and store results in place.

        Returns None; results are written into ``detected_bboxes``. Does nothing
        when there are no detected objects.
        """
        if detected_bboxes.n_objects == 0:
            return

        faces_input, person_input, faces_inds, bodies_inds = self.prepare_crops(image, detected_bboxes)

        if self.meta.with_persons_model:
            # Face and person crops are stacked along the channel dimension.
            model_input = torch.cat((faces_input, person_input), dim=1)
        else:
            model_input = faces_input
        output = self.inference(model_input)

        # write gender and age results into detected_bboxes
        self.fill_in_results(output, detected_bboxes, faces_inds, bodies_inds)

    def fill_in_results(self, output, detected_bboxes, faces_inds, bodies_inds):
        """Decode model ``output`` row by row into ``detected_bboxes``.

        Row ``i`` of ``output`` is written to both ``faces_inds[i]`` and
        ``bodies_inds[i]``. Columns 0-1 are gender logits (index 0 = male,
        1 = female) and column 2 is the normalized age, unless the model is
        age-only, in which case ``output`` holds only the age.
        """
        if self.meta.only_age:
            age_output = output
            gender_probs, gender_indx = None, None
        else:
            age_output = output[:, 2]
            gender_output = output[:, :2].softmax(-1)
            gender_probs, gender_indx = gender_output.topk(1)

        assert output.shape[0] == len(faces_inds) == len(bodies_inds)

        # per face
        for index in range(output.shape[0]):
            face_ind = faces_inds[index]
            body_ind = bodies_inds[index]

            # Undo the training-time normalization of the age target.
            age = age_output[index].item()
            age = age * (self.meta.max_age - self.meta.min_age) + self.meta.avg_age
            age = round(age, 2)

            detected_bboxes.set_age(face_ind, age)
            detected_bboxes.set_age(body_ind, age)

            _logger.info(f"\tage: {age}")

            if gender_probs is not None:
                gender = "male" if gender_indx[index].item() == 0 else "female"
                gender_score = gender_probs[index].item()

                _logger.info(f"\tgender: {gender} [{int(gender_score * 100)}%]")

                detected_bboxes.set_gender(face_ind, gender, gender_score)
                detected_bboxes.set_gender(body_ind, gender, gender_score)

    def prepare_crops(self, image: np.ndarray, detected_bboxes: PersonAndFaceResult):
        """Cut face/person crops from ``image`` and convert them to model tensors.

        Returns:
            ``(faces_input, person_input, faces_inds, bodies_inds)``. The inputs
            may be None (see the log line below); presumably when the
            corresponding crop type is unused or empty. TODO confirm in
            ``prepare_classification_images``.
        """
        if self.meta.use_person_crops and self.meta.use_face_crops:
            detected_bboxes.associate_faces_with_persons()

        crops: PersonAndFaceCrops = detected_bboxes.collect_crops(image)
        (bodies_inds, bodies_crops), (faces_inds, faces_crops) = crops.get_faces_with_bodies(
            self.meta.use_person_crops, self.meta.use_face_crops
        )

        if not self.meta.use_face_crops:
            assert all(f is None for f in faces_crops)

        faces_input = prepare_classification_images(
            faces_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device
        )

        if not self.meta.use_person_crops:
            assert all(p is None for p in bodies_crops)

        person_input = prepare_classification_images(
            bodies_crops, self.input_size, self.data_config["mean"], self.data_config["std"], device=self.device
        )

        _logger.info(
            f"faces_input: {faces_input.shape if faces_input is not None else None}, "
            f"person_input: {person_input.shape if person_input is not None else None}"
        )

        return faces_input, person_input, faces_inds, bodies_inds


if __name__ == "__main__":
    model = MiVOLO("../pretrained/checkpoint-377.pth.tar", half=True, device="cuda:0")