import argparse
import logging
import os
import random

import cv2
import torch
import yt_dlp
from mivolo.data.data_reader import InputType, get_all_files, get_input_type
from mivolo.predictor import Predictor
from timm.utils import setup_default_logging

_logger = logging.getLogger("inference")
def get_direct_video_url(video_url):
    """Resolve a YouTube page URL into a direct stream URL via yt-dlp.

    Returns (direct_url, (width, height), fps, video_id), or four Nones on failure.
    """
    ydl_opts = {
        "format": "bestvideo",
        "quiet": True,  # Suppress terminal output
    }

    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info_dict = ydl.extract_info(video_url, download=False)

        if "url" in info_dict:
            direct_url = info_dict["url"]
            resolution = (info_dict["width"], info_dict["height"])
            fps = info_dict["fps"]
            yid = info_dict["id"]
            return direct_url, resolution, fps, yid

    return None, None, None, None
def get_local_video_info(vid_uri):
    """Return ((width, height), fps) of a local video file or stream URI."""
    cap = cv2.VideoCapture(vid_uri)
    if not cap.isOpened():
        raise ValueError(f"Failed to open video source {vid_uri}")
    res = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    fps = cap.get(cv2.CAP_PROP_FPS)
    return res, fps
def get_random_frames(cap, num_frames):
    """Sample up to `num_frames` random frames from an opened cv2.VideoCapture."""
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Clamp so short videos (or streams reporting few/no frames) do not crash random.sample
    num_frames = min(num_frames, max(total_frames, 0))
    frame_indices = random.sample(range(total_frames), num_frames)
    frames = []
    for idx in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
        ret, frame = cap.read()
        if ret:
            frames.append(frame)
    return frames
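
# Minimal usage sketch for the helper above (illustrative only; "some_clip.mp4" is a
# placeholder path, not a file bundled with this Space):
#
#   cap = cv2.VideoCapture("some_clip.mp4")
#   frames = get_random_frames(cap, num_frames=5)
#   print(f"Sampled {len(frames)} frames")
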
def get_parser():
    parser = argparse.ArgumentParser(description="PyTorch MiVOLO Inference")
    parser.add_argument("--input", type=str, default=None, required=True, help="image file or folder with images")
    parser.add_argument("--output", type=str, default=None, required=True, help="folder for output results")
    parser.add_argument("--detector-weights", type=str, default=None, required=True, help="Detector weights (YOLOv8).")
    parser.add_argument("--checkpoint", default="", type=str, required=True, help="path to mivolo checkpoint")
    parser.add_argument(
        "--with-persons", action="store_true", default=False, help="If set, the model will also use person boxes when available"
    )
    parser.add_argument(
        "--disable-faces", action="store_true", default=False, help="If set, the model will use only person boxes when available"
    )
    parser.add_argument("--draw", action="store_true", default=False, help="If set, annotated result images will be saved")
    parser.add_argument("--device", default="cuda", type=str, help="Device (accelerator) to use.")
    return parser
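
# Example invocation (illustrative only): the script name and weight file paths below
# are placeholders for this sketch, not files guaranteed to exist alongside this code.
#
#   python demo.py \
#       --input path/to/photo.jpg \
#       --output results \
#       --detector-weights path/to/yolov8_detector.pt \
#       --checkpoint path/to/mivolo_checkpoint.pth.tar \
#       --draw
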
def main():
    parser = get_parser()
    setup_default_logging()
    args = parser.parse_args()

    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    os.makedirs(args.output, exist_ok=True)

    predictor = Predictor(args, verbose=True)

    input_type = get_input_type(args.input)
    if input_type == InputType.Video or input_type == InputType.VideoStream:
        if "youtube" in args.input:
            args.input, res, fps, yid = get_direct_video_url(args.input)
            if not args.input:
                raise ValueError(f"Failed to get direct video url {args.input}")

        # Open the capture on the (possibly resolved) video URI
        cap = cv2.VideoCapture(args.input)
        if not cap.isOpened():
            raise ValueError(f"Failed to open video source {args.input}")

        # Extract a handful of random frames from the video
        random_frames = get_random_frames(cap, num_frames=5)
        age_list = []
        for frame in random_frames:
            detected_objects, out_im, age = predictor.recognize(frame)
            age_list.append(age[0])

            if args.draw:
                bname = os.path.splitext(os.path.basename(args.input))[0]
                filename = os.path.join(args.output, f"out_{bname}.jpg")
                cv2.imwrite(filename, out_im)
                _logger.info(f"Saved result to {filename}")

        # Calculate and print average age
        avg_age = sum(age_list) / len(age_list) if age_list else 0
        print(f"Age list: {age_list}")
        print(f"Average age: {avg_age:.2f}")

        absolute_age = round(abs(avg_age))
        # Define the range
        lower_bound = absolute_age - 2
        upper_bound = absolute_age + 2

        return absolute_age, lower_bound, upper_bound
    elif input_type == InputType.Image:
        image_files = get_all_files(args.input) if os.path.isdir(args.input) else [args.input]

        for img_p in image_files:
            img = cv2.imread(img_p)
            detected_objects, out_im, age = predictor.recognize(img)

            if args.draw:
                bname = os.path.splitext(os.path.basename(img_p))[0]
                filename = os.path.join(args.output, f"out_{bname}.jpg")
                cv2.imwrite(filename, out_im)
                _logger.info(f"Saved result to {filename}")
if __name__ == "__main__":
    result = main()
    # main() only returns age estimates for video inputs; skip the summary otherwise
    if result is not None:
        absolute_age, lower_bound, upper_bound = result
        # Output the results in the desired format
        print(f"Absolute Age: {absolute_age}")
        print(f"Range: {lower_bound} - {upper_bound}")