|
import tempfile

import gradio as gr
import torch
from pytorchvideo.data import make_clip_sampler
from pytorchvideo.data.clip_sampling import ClipInfoList
from pytorchvideo.data.encoded_video_pyav import EncodedVideoPyAV
from pytorchvideo.data.video import VideoPathHandler
from pytorchvideo.transforms import (
    Normalize,
    RandomShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    Compose,
    Lambda,
    RandomCrop,
    Resize,
)
from transformers import (
    VideoMAEFeatureExtractor,
    VideoMAEForVideoClassification,
    VideoMAEImageProcessor,
)

from video_utils import change_video_resolution_and_fps
|
|
|
# Hugging Face Hub checkpoint of the fine-tuned VideoMAE classifier.
MODEL_CKPT = "omermazig/videomae-finetuned-nba-5-class-4-batch-8000-vid-multiclass"

# Run on CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# How many clips are sampled from each input video (see parse_video_to_clips).
CLIPS_FROM_SINGLE_VIDEO = 5
|
|
|
# Fine-tuned VideoMAE classifier, moved to the selected device.
trained_model = VideoMAEForVideoClassification.from_pretrained(MODEL_CKPT).to(DEVICE)

# Preprocessing configuration (mean/std and target size) saved with the checkpoint.
# VideoMAEImageProcessor replaces the deprecated VideoMAEFeatureExtractor alias;
# the dict-valued ``size`` read below already requires the image-processor API.
image_processor = VideoMAEImageProcessor.from_pretrained(MODEL_CKPT)
|
|
|
# Per-channel normalization statistics used by the Normalize step of the transform.
mean, std = image_processor.image_mean, image_processor.image_std

# The processor's ``size`` either carries explicit "height"/"width" keys or a
# single "shortest_edge"; in the latter case a square target is used.
if "shortest_edge" not in image_processor.size:
    height, width = image_processor.size["height"], image_processor.size["width"]
else:
    height = width = image_processor.size["shortest_edge"]

# (height, width) of the crop fed to the model.
resize_to = (height, width)
|
|
|
# Number of frames per clip, taken from the model configuration.
num_frames_to_sample = trained_model.config.num_frames

# The decoded span is sized to cover ``num_frames_to_sample * sample_rate``
# source frames at ``fps`` frames per second.
sample_rate, fps = 4, 30

# Clip length in seconds handed to the clip sampler.
clip_duration = (num_frames_to_sample * sample_rate) / fps
|
|
|
|
|
# Per-clip preprocessing applied in parse_video_to_clips:
#   1. keep ``num_frames_to_sample`` evenly spaced frames,
#   2. divide pixel values by 255 (presumably uint8 -> [0, 1]; TODO confirm input dtype),
#   3. normalize with the processor's mean/std,
#   4. rescale so the short side is a random size in [256, 320],
#   5. take a random crop of size ``resize_to``.
# NOTE(review): steps 4-5 are random, so the same clip can yield different
# model inputs on every call; a deterministic resize/center crop may be
# preferable at inference time.
inference_transform = Compose(
    [
        UniformTemporalSubsample(num_samples=num_frames_to_sample),
        Lambda(lambda frames: frames / 255.0),
        Normalize(mean=mean, std=std),
        RandomShortSideScale(min_size=256, max_size=320),
        RandomCrop(size=resize_to),
    ]
)
|
|
|
# Class names ordered by class id, so that ``labels[i]`` names the i-th logit
# produced by the model (``infer`` relies on this). Taking ``label2id.keys()``
# in dict order would only be correct if the key order happened to match the
# ids, e.g. not when the saved config stores its keys sorted alphabetically.
labels = [
    label
    for label, _ in sorted(trained_model.config.label2id.items(), key=lambda item: item[1])
]
|
|
|
|
|
def parse_video_to_clips(video_file):
    """Turn a video file into a batch of preprocessed clips for the model.

    The input is re-encoded to a fixed resolution and frame rate into a
    temporary file, and clips are decoded from that re-encoded copy (the
    module-level ``clip_duration`` assumes 30 fps, which is what the
    conversion targets). ``CLIPS_FROM_SINGLE_VIDEO`` clips of
    ``clip_duration`` seconds are sampled at random positions, and each one is
    passed through ``inference_transform``.

    Args:
        video_file: Path of the input video, as handed over by the Gradio
            ``Video`` input (presumably a filesystem path; TODO confirm for the
            Gradio version in use).

    Returns:
        A tensor on ``DEVICE`` stacking the transformed clips along a new
        leading axis, with each clip's first two axes swapped by
        ``permute(1, 0, 2, 3)``. Assumes pytorchvideo yields clips as
        (C, T, H, W), giving (num_clips, T, C, H, W) as VideoMAE's
        ``pixel_values`` expects; TODO confirm.
    """
    new_resolution = (320, 256)
    new_fps = 30
    acceptable_fps_violation = 5

    with tempfile.NamedTemporaryFile() as new_video:
        change_video_resolution_and_fps(
            video_file, new_video.name, new_resolution, new_fps, acceptable_fps_violation
        )
        # Decode the re-encoded copy (not ``video_file``) so the frame rate
        # matches the one assumed by ``clip_duration``. Decoding happens inside
        # the ``with`` because the temporary file is removed when it closes.
        video: EncodedVideoPyAV = VideoPathHandler().video_from_path(new_video.name)

        clip_sampler = make_clip_sampler("random_multi", clip_duration, CLIPS_FROM_SINGLE_VIDEO)
        clip_info: ClipInfoList = clip_sampler(0, video.duration, {})

        video_clips_list = [
            inference_transform(video.get_clip(clip_start, clip_end)["video"])
            for clip_start, clip_end in zip(clip_info.clip_start_sec, clip_info.clip_end_sec)
        ]

    videos_tensor = torch.stack([clip.permute(1, 0, 2, 3) for clip in video_clips_list])
    return videos_tensor.to(DEVICE)
|
|
|
|
|
def infer(video_file):
    """Classify a video and return per-label confidences for ``gr.Label``.

    Args:
        video_file: Video path passed through to ``parse_video_to_clips``.

    Returns:
        Dict mapping each entry of ``labels`` to its softmax probability.
    """
    videos_tensor = parse_video_to_clips(video_file)

    with torch.no_grad():
        outputs = trained_model(pixel_values=videos_tensor)

    # Combine the clips by summing their logits over the batch (clip) axis.
    # NOTE(review): a sum scales the logits by the number of clips, which makes
    # the softmax sharper than averaging would; the top class is unaffected.
    logits = outputs.logits.sum(dim=0)
    softmax_scores = torch.nn.functional.softmax(logits, dim=-1)

    # One device-to-host transfer for the whole vector instead of one per label.
    return {label: float(score) for label, score in zip(labels, softmax_scores.tolist())}
|
|
|
|
|
# Build the demo UI around ``infer`` and start the server at import time.
# ``allow_flagging`` takes a string ("never" / "manual" / "auto"); the boolean
# form is deprecated. The former ``allow_screenshot`` argument and
# ``gr.Video(type=...)`` are not parameters of current Gradio and were dropped.
gr.Interface(
    fn=infer,
    inputs=gr.Video(),
    outputs=gr.Label(num_top_classes=3),
    examples=[
        ["examples/DUNK.avi"],
        ["examples/FLOATING_JUMP_SHOT.avi"],
        ["examples/JUMP_SHOT.avi"],
        ["examples/REVERSE_LAYUP.avi"],
        ["examples/TURNAROUND_HOOK_SHOT.avi"],
    ],
    title="VideoMAE fine-tuned on nba data",
    description=(
        "Gradio demo for VideoMAE for video classification. To use it, simply upload your video or click one of the"
        " examples to load them. Read more at the links below."
    ),
    article=(
        "<div style='text-align: center;'><a href='https://huggingface.co/docs/transformers/model_doc/videomae' target='_blank'>VideoMAE</a>"
        " <center><a href='https://huggingface.co/omermazig/videomae-finetuned-nba-5-class-4-batch-8000-vid-multiclass' target='_blank'>Fine-tuned Model</a></center></div>"
    ),
    allow_flagging="never",
).launch()
|
|