"""Estimate a person's age from randomly sampled frames of a video using MiVOLO."""

import argparse
import logging
import os
import random
import sys

import cv2
import torch
import yt_dlp

# Put this script's own directory on sys.path before the mivolo/timm imports
# below ('././' normalises to the directory that contains this file).
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '././')))

from mivolo.data.data_reader import InputType, get_all_files, get_input_type  # noqa: E402,F401
from mivolo.predictor import Predictor  # noqa: E402
from timm.utils import setup_default_logging  # noqa: E402

_logger = logging.getLogger("inference")


def get_direct_video_url(video_url):
    """Resolve a video page URL to a direct stream URL with yt-dlp (no download).

    Args:
        video_url: URL handed to ``yt_dlp.YoutubeDL.extract_info``.

    Returns:
        A 4-tuple ``(direct_url, (width, height), fps, video_id)``. All four
        items are ``None`` when the extracted info has no ``"url"`` entry.
        Individual metadata items are ``None`` when yt-dlp does not report them.
    """
    ydl_opts = {
        "format": "bestvideo",
        "quiet": True,  # Suppress terminal output
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info_dict = ydl.extract_info(video_url, download=False)

    if not info_dict or "url" not in info_dict:
        return None, None, None, None

    # Use .get() for the optional metadata so a missing field does not turn a
    # usable stream URL into a KeyError.
    resolution = (info_dict.get("width"), info_dict.get("height"))
    return info_dict["url"], resolution, info_dict.get("fps"), info_dict.get("id")


def get_random_frames(cap, num_frames):
    """Read up to ``num_frames`` frames from random positions of an opened capture.

    Args:
        cap: An opened ``cv2.VideoCapture``.
        num_frames: Maximum number of frames to sample.

    Returns:
        A list of the frames that were read successfully. It may be shorter
        than ``num_frames`` (or empty) when the video has fewer frames, when
        the frame count is not reported, or when individual reads fail.
    """
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # random.sample raises ValueError when asked for more items than the
    # population holds, so never request more frames than the video reports.
    sample_size = min(num_frames, total_frames)
    if sample_size <= 0:
        return []

    frames = []
    for idx in random.sample(range(total_frames), sample_size):
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            frames.append(frame)
    return frames


def get_parser():
    """Build the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(description="PyTorch MiVOLO Inference")
    parser.add_argument("--input", type=str, default=None, required=True, help="image file or folder with images")
    parser.add_argument("--output", type=str, default=None, required=True, help="folder for output results")
    parser.add_argument("--detector-weights", type=str, default=None, required=True, help="Detector weights (YOLOv8).")
    parser.add_argument("--checkpoint", default="", type=str, required=True, help="path to mivolo checkpoint")

    parser.add_argument(
        "--with_persons", action="store_true", default=False, help="If set model will run with persons, if available"
    )
    parser.add_argument(
        "--disable_faces", action="store_true", default=False, help="If set model will use only persons if available"
    )

    parser.add_argument("--draw", action="store_true", default=False, help="If set, resulted images will be drawn")
    parser.add_argument("--device", default="cpu", type=str, help="Device (accelerator) to use.")

    return parser


def main(video_path, output_folder, detector_weights, checkpoint, device, with_persons, disable_faces, draw=False):
    """Estimate an age from random frames of a video and return it with a +/-2 range.

    Args:
        video_path: Path or URL of the video. URLs containing ``youtube`` or
            ``youtu.be`` are first resolved to a direct stream URL.
        output_folder: Directory for output images; created if missing.
        detector_weights: Forwarded to ``Predictor`` as ``detector_weights``.
        checkpoint: Forwarded to ``Predictor`` as ``checkpoint``.
        device: Forwarded to ``Predictor`` as ``device``.
        with_persons: Forwarded to ``Predictor`` as ``with_persons``.
        disable_faces: Forwarded to ``Predictor`` as ``disable_faces``.
        draw: When true, save the annotated image of every frame that yielded
            an age as ``out_<name>_<frame number>.jpg`` in ``output_folder``.

    Returns:
        ``(absolute_age, lower_bound, upper_bound)`` where ``absolute_age`` is
        the rounded absolute mean of the per-frame ages and the bounds are
        ``absolute_age - 2`` and ``absolute_age + 2``.

    Raises:
        ValueError: If the stream URL cannot be resolved, the video cannot be
            opened, no frame can be read, or no frame yields an age.
    """
    setup_default_logging()

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    os.makedirs(output_folder, exist_ok=True)

    # Initialize predictor
    args = argparse.Namespace(
        input=video_path,
        output=output_folder,
        detector_weights=detector_weights,
        checkpoint=checkpoint,
        draw=draw,
        device=device,
        with_persons=with_persons,
        disable_faces=disable_faces,
    )
    predictor = Predictor(args, verbose=True)

    # Base name for saved images; replaced by the video id for YouTube input,
    # because the resolved stream URL is not a usable file name.
    bname = os.path.splitext(os.path.basename(video_path))[0]

    if "youtube" in video_path or "youtu.be" in video_path:
        source_url = video_path  # keep the original for the error message
        video_path, _res, _fps, yid = get_direct_video_url(source_url)
        if not video_path:
            raise ValueError(f"Failed to get direct video url {source_url}")
        if yid:
            bname = yid

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Failed to open video source {video_path}")
        # Sample up to 10 random frames from the video
        random_frames = get_random_frames(cap, num_frames=10)
    finally:
        # Always free the decoder / network handle, even on failure.
        cap.release()

    if not random_frames:
        raise ValueError(f"Failed to read any frames from video source {video_path}")

    age_list = []
    for frame_no, frame in enumerate(random_frames):
        _detected_objects, out_im, age = predictor.recognize(frame)

        # An empty ``age`` means nothing usable was found in this frame
        # (presumably no face/person detected); skip it.
        try:
            first_age = age[0]
        except IndexError:
            continue
        age_list.append(first_age)

        if draw:
            # Include the frame number so successive frames do not overwrite
            # each other's output image.
            filename = os.path.join(output_folder, f"out_{bname}_{frame_no}.jpg")
            cv2.imwrite(filename, out_im)
            _logger.info(f"Saved result to {filename}")

    if not age_list:
        raise ValueError("No person was detected in the frame. Please upload a proper face video.")

    # Calculate and print average age
    avg_age = sum(age_list) / len(age_list)
    print(f"Age list: {age_list}")
    print(f"Average age: {avg_age:.2f}")

    absolute_age = round(abs(avg_age))

    # Define the range
    lower_bound = absolute_age - 2
    upper_bound = absolute_age + 2

    return absolute_age, lower_bound, upper_bound


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    absolute_age, lower_bound, upper_bound = main(
        args.input,
        args.output,
        args.detector_weights,
        args.checkpoint,
        args.device,
        args.with_persons,
        args.disable_faces,
        args.draw,
    )

    # Output the results in the desired format
    print(f"Absolute Age: {absolute_age}")
    print(f"Range: {lower_bound} - {upper_bound}")