import torch
import gradio as gr
from transformers import AutoProcessor, AutoModel

from utils import (
    convert_frames_to_gif,
    download_youtube_video,
    get_num_total_frames,
    sample_frames_from_video_file,
)

# Default frame sampling rate passed to sample_frames_from_video_file;
# predict() lowers it for videos too short to sample at this rate.
FRAME_SAMPLING_RATE = 4
DEFAULT_MODEL = "microsoft/xclip-base-patch16-zero-shot"

# Checkpoints intended for the (currently disabled) model picker below.
VALID_ZEROSHOT_VIDEOCLASSIFICATION_MODELS = [
    "microsoft/xclip-base-patch32",
    "microsoft/xclip-base-patch16-zero-shot",
    "microsoft/xclip-base-patch16-kinetics-600",
    # NOTE(review): not a valid "owner/name" Hub repo id -- looks like two ids
    # were merged together; verify the intended checkpoint before enabling
    # the model picker.
    "microsoft/xclip-large-patch14ft/xclip-base-patch32-16-frames",
    "microsoft/xclip-large-patch14",
    "microsoft/xclip-base-patch16-hmdb-4-shot",
    "microsoft/xclip-base-patch16-16-frames",
    "microsoft/xclip-base-patch16-hmdb-2-shot",
    "microsoft/xclip-base-patch16-ucf-2-shot",
    "microsoft/xclip-base-patch16-ucf-8-shot",
    "microsoft/xclip-base-patch16",
    "microsoft/xclip-base-patch16-hmdb-8-shot",
    "microsoft/xclip-base-patch16-hmdb-16-shot",
    "microsoft/xclip-base-patch16-ucf-16-shot",
]

# Loaded once at import time and shared by every prediction request.
processor = AutoProcessor.from_pretrained(DEFAULT_MODEL)
model = AutoModel.from_pretrained(DEFAULT_MODEL)

# Disabled model picker; re-enable together with the Dropdown in the UI below.
# def select_model(model_name):
#     global processor, model
#     processor = AutoProcessor.from_pretrained(model_name)
#     model = AutoModel.from_pretrained(model_name)


def predict(youtube_url_or_file_path, labels_text):
    """Score a video against free-text labels with the loaded model.

    Args:
        youtube_url_or_file_path: Either a URL starting with "http" (fetched
            with ``download_youtube_video``) or a local video file path, as
            supplied by the URL Textbox or the Video upload component.
        labels_text: Comma-separated candidate labels, e.g. "fight, walking".
            Surrounding whitespace is stripped, empty entries are dropped and
            duplicates are removed (first occurrence wins).

    Returns:
        A ``(label_to_prob, gif_path)`` tuple: a dict mapping each label to
        its softmax probability across the given labels, and the path
        returned by ``convert_frames_to_gif`` for the preview image.

    Raises:
        gr.Error: If no video source or no non-empty label was provided.
    """
    source = (youtube_url_or_file_path or "").strip()
    if not source:
        raise gr.Error("Please provide a YouTube URL or upload a video file.")

    # Parse labels before any download/decoding so bad input fails fast.
    # dict.fromkeys de-duplicates while keeping the user's order; a duplicate
    # would otherwise split probability mass and collapse in the result dict.
    labels = list(
        dict.fromkeys(
            label.strip() for label in (labels_text or "").split(",") if label.strip()
        )
    )
    if not labels:
        raise gr.Error("Please enter at least one comma-separated label.")

    if source.startswith("http"):
        video_path = download_youtube_video(source)
    else:
        video_path = youtube_url_or_file_path

    num_total_frames = get_num_total_frames(video_path)
    num_model_input_frames = model.config.vision_config.num_frames
    if num_total_frames < FRAME_SAMPLING_RATE * num_model_input_frames:
        # Too short for the default rate: shrink it so that
        # rate * num_model_input_frames fits within num_total_frames. Floor
        # division gives 0 when the video has fewer frames than the model
        # consumes, which is not a usable rate, so clamp to 1.
        # NOTE(review): such very short videos may still fail inside
        # sample_frames_from_video_file -- confirm its behavior.
        frame_sampling_rate = max(1, num_total_frames // num_model_input_frames)
    else:
        frame_sampling_rate = FRAME_SAMPLING_RATE

    frames = sample_frames_from_video_file(
        video_path, num_model_input_frames, frame_sampling_rate
    )
    # NOTE(review): every request writes the same "video.gif", so concurrent
    # predictions may overwrite each other's preview.
    gif_path = convert_frames_to_gif(frames, save_path="video.gif")

    inputs = processor(
        text=labels, videos=list(frames), return_tensors="pt", padding=True
    )
    with torch.no_grad():
        outputs = model(**inputs)

    # One video is passed to the processor, so take row 0 of the
    # video-to-text logits and softmax over the label axis.
    probs = outputs.logits_per_video[0].softmax(dim=-1).cpu().numpy()
    label_to_prob = {label: float(prob) for label, prob in zip(labels, probs)}
    return label_to_prob, gif_path


app = gr.Blocks()
with app:
    gr.Markdown("# **PROTOG - VIOLENCE DETECTION MODULE**")
    with gr.Row():
        with gr.Column():
            # model_names_dropdown = gr.Dropdown(
            #     choices=VALID_ZEROSHOT_VIDEOCLASSIFICATION_MODELS,
            #     label="Model:",
            #     show_label=True,
            #     value=DEFAULT_MODEL,
            # )
            # model_names_dropdown.change(fn=select_model, inputs=model_names_dropdown)
            with gr.Tab(label="Youtube URL"):
                gr.Markdown("### **Enter Youtube URL**")
                youtube_url = gr.Textbox(label="Youtube URL:", show_label=True)
                youtube_url_labels_text = gr.Textbox(
                    label="Labels Text:", show_label=True
                )
                youtube_url_predict_btn = gr.Button(value="Predict")
            with gr.Tab(label="Local File"):
                gr.Markdown("### **Video Upload**")
                video_file = gr.Video(label="Video File:", show_label=True)
                local_video_labels_text = gr.Textbox(
                    label="Labels Text:", show_label=True
                )
                local_video_predict_btn = gr.Button(value="Predict")
        with gr.Column():
            video_gif = gr.Image(
                label="Input Clip",
                show_label=True,
            )
        with gr.Column():
            predictions = gr.Label(label="Predictions:", show_label=True)

    # Both tabs share predict(); only the source component differs.
    youtube_url_predict_btn.click(
        predict,
        inputs=[youtube_url, youtube_url_labels_text],
        outputs=[predictions, video_gif],
    )
    local_video_predict_btn.click(
        predict,
        inputs=[video_file, local_video_labels_text],
        outputs=[predictions, video_gif],
    )

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module does not
    # start a blocking server.
    app.launch()