File size: 4,355 Bytes
0fb1163
 
 
 
 
 
 
 
 
 
 
 
 
 
dea9d86
0fb1163
dea9d86
 
 
 
 
 
 
 
 
 
 
 
0fb1163
 
 
 
 
90e2561
 
 
 
0fb1163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0bb375d
0fb1163
 
 
 
90e2561
 
 
 
 
 
 
0fb1163
88226d6
fbe1110
88226d6
0fb1163
 
 
 
 
 
 
fbe1110
0fb1163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import torch
import gradio as gr
from transformers import AutoProcessor, AutoModel
from utils import (
    convert_frames_to_gif,
    download_youtube_video,
    get_num_total_frames,
    sample_frames_from_video_file,
)

# Default sampling rate handed to sample_frames_from_video_file; predict()
# lowers it for clips that are too short to be sampled at this rate.
FRAME_SAMPLING_RATE = 4
# Checkpoint whose processor and model are loaded when the module is imported.
DEFAULT_MODEL = "microsoft/xclip-base-patch16-zero-shot"

# X-CLIP checkpoints offered for selection. Every entry must be a valid Hub
# repo id of the form "<namespace>/<name>" (exactly one slash).
# NOTE(review): one entry used to be two ids fused together
# ("microsoft/xclip-large-patch14" + a truncated "microsoft/xclip-base-patch32-16-frames").
# The base-patch32-16-frames id is restored below; the large-patch14 variant
# that was cut off could not be recovered from this file - confirm on the Hub.
VALID_ZEROSHOT_VIDEOCLASSIFICATION_MODELS = [
    "microsoft/xclip-base-patch32",
    "microsoft/xclip-base-patch16-zero-shot",
    "microsoft/xclip-base-patch16-kinetics-600",
    "microsoft/xclip-base-patch32-16-frames",
    "microsoft/xclip-large-patch14",
    "microsoft/xclip-base-patch16-hmdb-4-shot",
    "microsoft/xclip-base-patch16-16-frames",
    "microsoft/xclip-base-patch16-hmdb-2-shot",
    "microsoft/xclip-base-patch16-ucf-2-shot",
    "microsoft/xclip-base-patch16-ucf-8-shot",
    "microsoft/xclip-base-patch16",
    "microsoft/xclip-base-patch16-hmdb-8-shot",
    "microsoft/xclip-base-patch16-hmdb-16-shot",
    "microsoft/xclip-base-patch16-ucf-16-shot",
]

# Load the default checkpoint's processor and model once, at import time.
# predict() reads these module-level globals on every call.
processor = AutoProcessor.from_pretrained(DEFAULT_MODEL)
model = AutoModel.from_pretrained(DEFAULT_MODEL)

# Disabled: callback for the (also commented-out) model dropdown in the UI.
# It would rebind the module-level processor/model to another checkpoint.
# def select_model(model_name):
#     global processor, model
#     processor = AutoProcessor.from_pretrained(model_name)
#     model = AutoModel.from_pretrained(model_name)


def predict(youtube_url_or_file_path, labels_text):
    """Score a video against comma-separated text labels with the loaded model.

    Args:
        youtube_url_or_file_path: If it starts with "http" it is fetched with
            download_youtube_video; otherwise it is used as the video path
            as-is.
        labels_text: Candidate labels separated by commas. Surrounding
            whitespace is stripped; blank and repeated labels are dropped.

    Returns:
        A tuple ``(label_to_prob, gif_path)``: a dict mapping each label to
        its softmax probability over the given labels, and the value returned
        by convert_frames_to_gif for the sampled frames.

    Raises:
        gr.Error: If no video was provided or no non-empty label was given.
    """
    # Nothing entered / uploaded: fail with a readable message instead of an
    # AttributeError on None.startswith.
    if not youtube_url_or_file_path:
        raise gr.Error("Please provide a Youtube URL or upload a video file.")

    # Normalise labels before any expensive work. "a, b" must yield "b", not
    # " b"; blank entries and duplicates would otherwise take part in the
    # softmax and skew the probabilities of the real labels.
    stripped = (label.strip() for label in (labels_text or "").split(","))
    labels = list(dict.fromkeys(label for label in stripped if label))
    if not labels:
        raise gr.Error("Please provide at least one label (comma-separated).")

    if youtube_url_or_file_path.startswith("http"):
        video_path = download_youtube_video(youtube_url_or_file_path)
    else:
        video_path = youtube_url_or_file_path

    num_total_frames = get_num_total_frames(video_path)
    num_model_input_frames = model.config.vision_config.num_frames
    if num_total_frames < FRAME_SAMPLING_RATE * num_model_input_frames:
        # Short clip: shrink the rate so the requested frames fit, but never
        # let it reach 0 (happens when the clip has fewer frames than the
        # model consumes).
        frame_sampling_rate = max(1, num_total_frames // num_model_input_frames)
    else:
        frame_sampling_rate = FRAME_SAMPLING_RATE

    frames = sample_frames_from_video_file(
        video_path, num_model_input_frames, frame_sampling_rate
    )
    gif_path = convert_frames_to_gif(frames, save_path="video.gif")

    inputs = processor(
        text=labels, videos=list(frames), return_tensors="pt", padding=True
    )
    with torch.no_grad():
        outputs = model(**inputs)

    # Only one video is passed in, so row 0 holds its per-label logits.
    probs = outputs.logits_per_video[0].softmax(dim=-1).tolist()
    label_to_prob = dict(zip(labels, probs))

    return label_to_prob, gif_path


# Gradio UI: inputs on the left (one tab per video source), the sampled clip
# in the middle, and the per-label predictions on the right.
with gr.Blocks() as app:
    gr.Markdown("# **<p align='center'> PROTOG - VIOLENCE DETECTION MODULE</p>**")

    with gr.Row():
        with gr.Column():
            # Disabled model picker, kept for reference:
            # model_names_dropdown = gr.Dropdown(
            #     choices=VALID_ZEROSHOT_VIDEOCLASSIFICATION_MODELS,
            #     label="Model:",
            #     show_label=True,
            #     value=DEFAULT_MODEL,
            # )
            # model_names_dropdown.change(fn=select_model, inputs=model_names_dropdown)

            # Tab 1: video referenced by URL.
            with gr.Tab(label="Youtube URL"):
                gr.Markdown("### **Enter Youtube URL**")
                youtube_url = gr.Textbox(label="Youtube URL:", show_label=True)
                youtube_url_labels_text = gr.Textbox(label="Labels Text:", show_label=True)
                youtube_url_predict_btn = gr.Button(value="Predict")

            # Tab 2: video uploaded from the user's machine.
            with gr.Tab(label="Local File"):
                gr.Markdown("### **Video Upload**")
                video_file = gr.Video(label="Video File:", show_label=True)
                local_video_labels_text = gr.Textbox(label="Labels Text:", show_label=True)
                local_video_predict_btn = gr.Button(value="Predict")

        with gr.Column():
            video_gif = gr.Image(label="Input Clip", show_label=True)

        with gr.Column():
            predictions = gr.Label(label="Predictions:", show_label=True)

    # Both buttons run the same predict function and fill the same outputs;
    # only the source of the video argument differs.
    youtube_url_predict_btn.click(
        fn=predict,
        inputs=[youtube_url, youtube_url_labels_text],
        outputs=[predictions, video_gif],
    )
    local_video_predict_btn.click(
        fn=predict,
        inputs=[video_file, local_video_labels_text],
        outputs=[predictions, video_gif],
    )

app.launch()