Spaces:
Runtime error
Runtime error
File size: 4,355 Bytes
0fb1163 dea9d86 0fb1163 dea9d86 0fb1163 90e2561 0fb1163 0bb375d 0fb1163 90e2561 0fb1163 88226d6 fbe1110 88226d6 0fb1163 fbe1110 0fb1163 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
import torch
import gradio as gr
from transformers import AutoProcessor, AutoModel
from utils import (
convert_frames_to_gif,
download_youtube_video,
get_num_total_frames,
sample_frames_from_video_file,
)
# Default stride (in source-video frames) between consecutive frames fed to the
# model; predict() uses a smaller stride when the clip is too short for it.
FRAME_SAMPLING_RATE = 4

# Hugging Face Hub checkpoint loaded at startup.
DEFAULT_MODEL = "microsoft/xclip-base-patch16-zero-shot"

# Microsoft X-CLIP checkpoints on the Hugging Face Hub that can be used for
# zero-shot video classification. Only referenced by the (currently
# commented-out) model dropdown in the UI. Every entry must be a valid
# "owner/name" repo id.
VALID_ZEROSHOT_VIDEOCLASSIFICATION_MODELS = [
    "microsoft/xclip-base-patch32",
    "microsoft/xclip-base-patch16-zero-shot",
    "microsoft/xclip-base-patch16-kinetics-600",
    # Was the garbled two-slash id
    # "microsoft/xclip-large-patch14ft/xclip-base-patch32-16-frames".
    "microsoft/xclip-base-patch32-16-frames",
    "microsoft/xclip-large-patch14",
    "microsoft/xclip-base-patch16-hmdb-4-shot",
    "microsoft/xclip-base-patch16-16-frames",
    "microsoft/xclip-base-patch16-hmdb-2-shot",
    "microsoft/xclip-base-patch16-ucf-2-shot",
    "microsoft/xclip-base-patch16-ucf-8-shot",
    "microsoft/xclip-base-patch16",
    "microsoft/xclip-base-patch16-hmdb-8-shot",
    "microsoft/xclip-base-patch16-hmdb-16-shot",
    "microsoft/xclip-base-patch16-ucf-16-shot",
]
# Load the preprocessing pipeline and the weights for DEFAULT_MODEL once, at
# import time. predict() reads these two module-level objects on every call.
processor, model = (
    AutoProcessor.from_pretrained(DEFAULT_MODEL),
    AutoModel.from_pretrained(DEFAULT_MODEL),
)

# Disabled: runtime checkpoint switching. It pairs with the commented-out model
# dropdown in the UI; re-enable both together.
# def select_model(model_name):
#     global processor, model
#     processor = AutoProcessor.from_pretrained(model_name)
#     model = AutoModel.from_pretrained(model_name)
def predict(youtube_url_or_file_path, labels_text):
    """Score a video against user-supplied labels with the loaded X-CLIP model.

    Args:
        youtube_url_or_file_path: Either a URL (any string that starts with
            ``"http"`` after stripping whitespace), which is fetched with
            ``download_youtube_video``, or a path to a local video file (as
            supplied by the ``gr.Video`` input), which is used verbatim.
        labels_text: Comma-separated candidate labels, e.g.
            ``"violence, non-violence"``. Whitespace around each label is
            stripped, empty entries are ignored and duplicates are removed
            (first occurrence wins).

    Returns:
        A ``(label_to_prob, gif_path)`` tuple:

        - ``label_to_prob`` maps each label to its softmax probability; the
          values sum to 1 across the labels.
        - ``gif_path`` is the path returned by ``convert_frames_to_gif`` for
          the sampled frames.

    Raises:
        gr.Error: If no video source or no non-empty label is provided. The
            message is then shown in the Gradio UI rather than as a traceback.
    """
    # Validate cheap inputs first, so bad input never triggers a download or a
    # model forward pass. A cleared gr.Video component passes None here.
    if not (youtube_url_or_file_path or "").strip():
        raise gr.Error("Please enter a Youtube URL or upload a video file.")

    # "a, b," -> ["a", "b"]; dict.fromkeys de-duplicates while keeping order.
    labels = list(
        dict.fromkeys(
            label.strip()
            for label in (labels_text or "").split(",")
            if label.strip()
        )
    )
    if not labels:
        raise gr.Error("Please enter at least one label (comma-separated).")

    # Strip only for URL detection, so a pasted URL with stray spaces still
    # counts as a URL. Local paths are passed through unchanged.
    source = youtube_url_or_file_path.strip()
    if source.startswith("http"):
        video_path = download_youtube_video(source)
    else:
        video_path = youtube_url_or_file_path

    num_total_frames = get_num_total_frames(video_path)
    num_model_input_frames = model.config.vision_config.num_frames
    if num_total_frames < FRAME_SAMPLING_RATE * num_model_input_frames:
        # Clip too short for the default stride: spread the model's frames
        # over the whole clip instead. Clamp to 1, because integer division
        # gives 0 when the clip has fewer frames than the model needs. A zero
        # stride is presumably invalid for sample_frames_from_video_file
        # (defined in utils, not visible here) -- TODO confirm.
        frame_sampling_rate = max(1, num_total_frames // num_model_input_frames)
    else:
        frame_sampling_rate = FRAME_SAMPLING_RATE

    frames = sample_frames_from_video_file(
        video_path, num_model_input_frames, frame_sampling_rate
    )
    # NOTE(review): every request writes to the same output path, so
    # concurrent calls may overwrite each other's GIF.
    gif_path = convert_frames_to_gif(frames, save_path="video.gif")

    # Assumes `frames` iterates over individual frames (e.g. an array shaped
    # (num_frames, H, W, C)) -- TODO confirm against utils. list(frames) is
    # passed as a single video.
    inputs = processor(
        text=labels, videos=list(frames), return_tensors="pt", padding=True
    )
    with torch.no_grad():
        outputs = model(**inputs)

    # Row 0 is the first (here: only) video. The softmax over the last axis
    # normalises its per-label scores into probabilities.
    probs = outputs.logits_per_video[0].softmax(dim=-1).cpu().numpy()
    label_to_prob = dict(zip(labels, probs.tolist()))
    return label_to_prob, gif_path
# Build the Gradio UI. Left column: two input tabs (Youtube URL / local upload),
# each with its own labels textbox and Predict button. Middle column: the
# sampled clip as a GIF. Right column: per-label predictions.
with gr.Blocks() as app:
    gr.Markdown("# **<p align='center'> PROTOG - VIOLENCE DETECTION MODULE</p>**")

    with gr.Row():
        with gr.Column():
            # Model picker, disabled together with select_model():
            # model_names_dropdown = gr.Dropdown(
            #     choices=VALID_ZEROSHOT_VIDEOCLASSIFICATION_MODELS,
            #     label="Model:",
            #     show_label=True,
            #     value=DEFAULT_MODEL,
            # )
            # model_names_dropdown.change(fn=select_model, inputs=model_names_dropdown)
            with gr.Tab(label="Youtube URL"):
                gr.Markdown("### **Enter Youtube URL**")
                youtube_url = gr.Textbox(label="Youtube URL:", show_label=True)
                youtube_url_labels_text = gr.Textbox(label="Labels Text:", show_label=True)
                youtube_url_predict_btn = gr.Button(value="Predict")

            with gr.Tab(label="Local File"):
                gr.Markdown("### **Video Upload**")
                video_file = gr.Video(label="Video File:", show_label=True)
                local_video_labels_text = gr.Textbox(label="Labels Text:", show_label=True)
                local_video_predict_btn = gr.Button(value="Predict")

        with gr.Column():
            video_gif = gr.Image(label="Input Clip", show_label=True)

        with gr.Column():
            predictions = gr.Label(label="Predictions:", show_label=True)

    # Both Predict buttons run the same handler and share the output widgets.
    # Listeners are registered inside the Blocks context, Youtube tab first.
    for trigger_btn, video_source, labels_box in (
        (youtube_url_predict_btn, youtube_url, youtube_url_labels_text),
        (local_video_predict_btn, video_file, local_video_labels_text),
    ):
        trigger_btn.click(
            predict,
            inputs=[video_source, labels_box],
            outputs=[predictions, video_gif],
        )

app.launch()