jaimin's picture
Upload 78 files
bf53f45 verified
import argparse
import logging
import os
import random
import cv2
import torch
import yt_dlp
import sys
# Append this file's own directory to the import search path ('././' normalizes
# to the directory of __file__), presumably so the `mivolo` imports below resolve.
# NOTE(review): if a parent directory was intended (e.g. '../../'), this path
# needs correcting — confirm against the repository layout.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '././')))
from mivolo.data.data_reader import InputType, get_all_files, get_input_type
from mivolo.predictor import Predictor
from timm.utils import setup_default_logging
# Module-level logger named "inference"; used to report where annotated frames are saved.
_logger = logging.getLogger("inference")
def get_direct_video_url(video_url):
    """Resolve a video page URL to a directly streamable media URL via yt-dlp.

    Only metadata is extracted; nothing is downloaded.

    Args:
        video_url: URL of the page hosting the video (e.g. a YouTube link).

    Returns:
        A tuple ``(direct_url, (width, height), fps, video_id)``. Each of
        ``width``, ``height``, ``fps`` and ``video_id`` is ``None`` when the
        extractor does not report it. When no direct URL can be resolved the
        result is ``(None, None, None, None)``.
    """
    ydl_opts = {
        "format": "bestvideo",
        "quiet": True,  # Suppress terminal output
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info_dict = ydl.extract_info(video_url, download=False)

    # Guard against a missing result as well as a result without a usable URL,
    # so callers always get the 4-tuple instead of a TypeError/KeyError.
    direct_url = info_dict.get("url") if info_dict else None
    if not direct_url:
        return None, None, None, None

    # Metadata fields are optional in the extractor output; read them with
    # .get() so an absent field does not abort an otherwise valid lookup.
    resolution = (info_dict.get("width"), info_dict.get("height"))
    fps = info_dict.get("fps")
    yid = info_dict.get("id")
    return direct_url, resolution, fps, yid
def get_random_frames(cap, num_frames):
    """Read up to ``num_frames`` frames from randomly chosen positions of a video.

    Args:
        cap: An opened ``cv2.VideoCapture`` (or an object with the same
            ``get``/``set``/``read`` methods).
        num_frames: Maximum number of distinct frame positions to sample.

    Returns:
        A list of the frames that were read successfully, in sampling order.
        It holds fewer than ``num_frames`` items when the video is shorter
        than requested or when individual reads fail, and is empty when the
        reported frame count is not positive.
    """
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    # random.sample raises ValueError when asked for more items than the
    # population holds, so never request more positions than the video has.
    # A non-positive reported count is treated as "no frames available".
    available = max(total_frames, 0)
    sample_size = min(num_frames, available)
    frame_indices = random.sample(range(available), sample_size)

    frames = []
    for idx in frame_indices:
        # Seek to the chosen position, then decode the frame located there.
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            frames.append(frame)
    return frames
def get_parser():
    """Build the command-line argument parser for video age inference.

    Returns:
        An ``argparse.ArgumentParser`` with the required options ``--input``,
        ``--output``, ``--detector-weights`` and ``--checkpoint``, the boolean
        flags ``--with_persons``, ``--disable_faces`` and ``--draw``, and the
        optional ``--device`` (default ``"cpu"``).
    """
    parser = argparse.ArgumentParser(description="PyTorch MiVOLO Inference")
    # The input is opened as a video source (and resolved through yt-dlp when
    # it is a YouTube link), so the help text describes a video, not images.
    parser.add_argument(
        "--input", type=str, default=None, required=True, help="path to a video file or a YouTube video URL"
    )
    parser.add_argument("--output", type=str, default=None, required=True, help="folder for output results")
    parser.add_argument("--detector-weights", type=str, default=None, required=True, help="Detector weights (YOLOv8).")
    parser.add_argument("--checkpoint", default="", type=str, required=True, help="path to mivolo checkpoint")
    parser.add_argument(
        "--with_persons", action="store_true", default=False, help="If set model will run with persons, if available"
    )
    parser.add_argument(
        "--disable_faces", action="store_true", default=False, help="If set model will use only persons if available"
    )
    parser.add_argument("--draw", action="store_true", default=False, help="If set, resulted images will be drawn")
    parser.add_argument("--device", default="cpu", type=str, help="Device (accelerator) to use.")
    return parser
def main(video_path, output_folder, detector_weights, checkpoint, device, with_persons, disable_faces, draw=False):
    """Estimate a person's age from randomly sampled frames of a video.

    Args:
        video_path: Path or URL of the video. YouTube links (containing
            "youtube" or "youtu.be") are first resolved to a direct stream URL.
        output_folder: Directory for output files; created if missing.
        detector_weights: Path to the detector weights, forwarded to the predictor.
        checkpoint: Path to the MiVOLO checkpoint, forwarded to the predictor.
        device: Device string forwarded to the predictor.
        with_persons: Forwarded to the predictor configuration.
        disable_faces: Forwarded to the predictor configuration.
        draw: When true, the annotated frame is written to ``output_folder``.
            Every frame is written to the same file name, so only the last
            frame with a detection remains on disk.

    Returns:
        A tuple ``(absolute_age, lower_bound, upper_bound)`` where
        ``absolute_age`` is the rounded average of the per-frame ages and the
        bounds are that value minus/plus 2.

    Raises:
        ValueError: If the YouTube URL cannot be resolved, the video source
            cannot be opened, or no sampled frame yields an age.
    """
    setup_default_logging()

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
    os.makedirs(output_folder, exist_ok=True)

    # Initialize predictor
    args = argparse.Namespace(
        input=video_path,
        output=output_folder,
        detector_weights=detector_weights,
        checkpoint=checkpoint,
        draw=draw,
        device=device,
        with_persons=with_persons,
        disable_faces=disable_faces,
    )
    predictor = Predictor(args, verbose=True)

    # Keep the URL the caller passed: it is needed for the error message after
    # video_path has been replaced by the (possibly empty) resolved URL.
    source_url = video_path
    yid = None
    if "youtube" in source_url or "youtu.be" in source_url:
        video_path, _resolution, _fps, yid = get_direct_video_url(source_url)
        if not video_path:
            raise ValueError(f"Failed to get direct video url {source_url}")

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Failed to open video source {video_path}")
        # Sample up to 10 random frames from the video
        random_frames = get_random_frames(cap, num_frames=10)
    finally:
        # Always free the capture handle, including on the error path.
        cap.release()

    # Base name of the output image. For YouTube input the resolved stream URL
    # is a long signed URL that makes a poor file name, so use the video id.
    bname = yid if yid else os.path.splitext(os.path.basename(video_path))[0]
    filename = os.path.join(output_folder, f"out_{bname}.jpg")

    age_list = []
    for frame in random_frames:
        _detected_objects, out_im, age = predictor.recognize(frame)
        # Only the lookup itself is guarded, so an IndexError raised elsewhere
        # is not mistaken for "nothing detected in this frame".
        try:
            first_age = age[0]
        except IndexError:
            continue
        age_list.append(first_age)

        if draw:
            if cv2.imwrite(filename, out_im):
                _logger.info(f"Saved result to {filename}")
            else:
                _logger.warning(f"Failed to save result to {filename}")

    if not age_list:
        raise ValueError("No person was detected in the frame. Please upload a proper face video.")

    # Calculate and print average age
    avg_age = sum(age_list) / len(age_list)
    print(f"Age list: {age_list}")
    print(f"Average age: {avg_age:.2f}")

    absolute_age = round(abs(avg_age))

    # Define the range
    lower_bound = absolute_age - 2
    upper_bound = absolute_age + 2

    return absolute_age, lower_bound, upper_bound
if __name__ == "__main__":
    # Script entry point: parse the command line, run the estimation, report the result.
    cli_args = get_parser().parse_args()
    age_estimate = main(
        video_path=cli_args.input,
        output_folder=cli_args.output,
        detector_weights=cli_args.detector_weights,
        checkpoint=cli_args.checkpoint,
        device=cli_args.device,
        with_persons=cli_args.with_persons,
        disable_faces=cli_args.disable_faces,
        draw=cli_args.draw,
    )
    absolute_age, lower_bound, upper_bound = age_estimate

    # Output the results in the desired format
    print(f"Absolute Age: {absolute_age}")
    print(f"Range: {lower_bound} - {upper_bound}")