from collections import defaultdict
from typing import Dict, Generator, List, Optional, Tuple

import cv2
import numpy as np
import tqdm
from mivolo.model.mi_volo import MiVOLO
from mivolo.model.yolo_detector import Detector
from mivolo.structures import AGE_GENDER_TYPE, PersonAndFaceResult


class Predictor:
    """Wraps a YOLO person/face detector and the MiVOLO age-gender model."""

    def __init__(self, config, verbose: bool = False):
        self.detector = Detector(config.detector_weights, config.device, verbose=verbose)
        self.age_gender_model = MiVOLO(
            config.checkpoint,
            config.device,
            half=True,
            use_persons=config.with_persons,
            disable_faces=config.disable_faces,
            verbose=verbose,
        )
        self.draw = config.draw

    def recognize(self, image: np.ndarray) -> Tuple[PersonAndFaceResult, Optional[np.ndarray]]:
        """Detect persons and faces on a single image and predict their age and gender."""
        detected_objects: PersonAndFaceResult = self.detector.predict(image)
        self.age_gender_model.predict(image, detected_objects)

        out_im = None
        if self.draw:
            # plot results on image
            out_im = detected_objects.plot()

        return detected_objects, out_im

    def recognize_video(self, source: str) -> Generator:
        """Track persons and faces through a video, yielding per-frame results."""
        video_capture = cv2.VideoCapture(source)
        if not video_capture.isOpened():
            raise ValueError(f"Failed to open video source {source}")

        detected_objects_history: Dict[int, List[AGE_GENDER_TYPE]] = defaultdict(list)

        total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
        for _ in tqdm.tqdm(range(total_frames)):
            ret, frame = video_capture.read()
            if not ret:
                break

            detected_objects: PersonAndFaceResult = self.detector.track(frame)
            self.age_gender_model.predict(frame, detected_objects)

            current_frame_objs = detected_objects.get_results_for_tracking()
            cur_persons: Dict[int, AGE_GENDER_TYPE] = current_frame_objs[0]
            cur_faces: Dict[int, AGE_GENDER_TYPE] = current_frame_objs[1]

            # add current persons and faces to the per-track history,
            # skipping entries with a missing age or gender prediction
            for guid, data in cur_persons.items():
                if None not in data:
                    detected_objects_history[guid].append(data)
            for guid, data in cur_faces.items():
                if None not in data:
                    detected_objects_history[guid].append(data)

            detected_objects.set_tracked_age_gender(detected_objects_history)
            if self.draw:
                frame = detected_objects.plot()
            yield detected_objects_history, frame

        # release the capture once the stream is exhausted or reading fails
        video_capture.release()
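

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the upstream module): a minimal way
# to drive Predictor on a single image. `ExampleConfig` is a hypothetical
# stand-in for the config object normally built from CLI arguments; its field
# names mirror the attributes read in Predictor.__init__, and the weight and
# image paths below are placeholders that must point at real files.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from dataclasses import dataclass

    @dataclass
    class ExampleConfig:
        detector_weights: str = "path/to/detector_weights.pt"  # placeholder path
        checkpoint: str = "path/to/mivolo_checkpoint.pth.tar"  # placeholder path
        device: str = "cuda:0"  # or "cpu"
        with_persons: bool = True
        disable_faces: bool = False
        draw: bool = True

    predictor = Predictor(ExampleConfig(), verbose=True)
    image = cv2.imread("path/to/example.jpg")  # placeholder image path
    detected_objects, annotated = predictor.recognize(image)
    if annotated is not None:
        cv2.imwrite("out.jpg", annotated)

    # recognize_video works the same way but yields per-frame results:
    # for history, frame in predictor.recognize_video("path/to/video.mp4"):
    #     ...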