Spaces:
Sleeping
Sleeping
import functools
import zipfile

import gradio as gr
import imageio
import numpy as np
import tensorflow as tf
from tensorflow import keras

from labels import K400_label_map, SSv2_label_map
from utils import read_video, frame_sampling
from utils import num_frames, patch_size, input_size
# Dataset tag -> label dictionary imported from `labels`.
LABEL_MAPS = {'K400': K400_label_map, 'SSv2': SSv2_label_map}

# Model identifiers offered in the dropdown; each is also the name of the
# release asset (<name>.zip) that the model loader downloads.
ALL_MODELS = [
    'TFUniFormerV2_K400_K710_L14_16x224',
    'TFUniFormerV2_SSV2_B16_16x224',
]

# Gradio examples: one [video path, model name] pair per bundled clip, each
# clip paired with the model at the same position in ALL_MODELS.
sample_example = [
    [clip, model]
    for clip, model in zip(("examples/k400.mp4", "examples/ssv2.mp4"), ALL_MODELS)
]
# Memoised so the archive is fetched, unpacked and deserialised once per
# process rather than on every inference request. Bounded (one slot per
# currently published model) so arbitrary model_type strings cannot pin an
# unbounded number of large models in memory; failed loads raise and are
# never cached.
@functools.lru_cache(maxsize=2)
def get_model(model_type):
    """Download, unpack and load a UniFormerV2 model plus its inverted label map.

    Args:
        model_type: Release asset name. It is used both as the ``.zip`` name
            on the v1.1 GitHub release and as the local path that
            ``keras.models.load_model`` reads after extraction into ``./``.

    Returns:
        ``(model, label_map)``. ``label_map`` is the dataset's label dictionary
        with keys and values swapped, so it can be indexed by the value type of
        the original map. NOTE(review): presumably the source maps are
        ``label -> class index``; confirm against ``labels``.

    Note:
        Because of memoisation, repeated calls with the same ``model_type``
        return the *same* model and dict objects; callers must not mutate the
        returned ``label_map``.
    """
    model_path = keras.utils.get_file(
        origin=f'https://github.com/innat/UniFormerV2/releases/download/v1.1/{model_type}.zip',
    )
    # Unpack into the working directory: load_model below reads ./<model_type>,
    # so the archive must unpack to a folder of that name.
    with zipfile.ZipFile(model_path, 'r') as zip_ref:
        zip_ref.extractall('./')
    model = keras.models.load_model(model_type)

    # Model names carry their training dataset tag; anything not K400 is SSv2.
    data_type = 'K400' if 'K400' in model_type else 'SSv2'
    # Direct indexing: a missing tag should fail loudly here, not as an
    # AttributeError on None.
    label_map = LABEL_MAPS[data_type]
    # Invert so the classifier's output position can be looked up directly.
    label_map = {v: k for k, v in label_map.items()}
    return model, label_map
def inference(video_file, model_type):
    """Classify a video clip with the selected UniFormerV2 model.

    Args:
        video_file: Video input as handed over by Gradio; passed unchanged to
            ``read_video``.
        model_type: Name of the model to run (one of the dropdown choices).

    Returns:
        Dict mapping each label from the model's label map to its softmax
        score (a Python float), inserted in descending score order.
    """
    # Decode the clip and sample the fixed number of frames the model expects.
    clip = frame_sampling(read_video(video_file), num_frames=num_frames)

    model, label_map = get_model(model_type)
    model.trainable = False

    # Add a leading batch axis of size 1, run the forward pass, then drop it.
    batch = clip[None, ...]
    raw_outputs = model(batch, training=False)
    scores = tf.nn.softmax(raw_outputs).numpy().squeeze(0)

    ranked = np.argsort(scores)[::-1]
    return {label_map[idx]: float(scores[idx]) for idx in ranked}
def main():
    """Build the Gradio demo around ``inference`` and start serving it."""
    # Inputs in the positional order ``inference`` expects:
    # (video_file, model_type).
    video_input = gr.Video(type="file", label="Input Video")
    model_choice = gr.Dropdown(choices=ALL_MODELS, label="Model")
    # Show only the three highest-scoring labels.
    top_scores = gr.Label(num_top_classes=3, label='scores')

    demo = gr.Interface(
        fn=inference,
        inputs=[video_input, model_choice],
        outputs=top_scores,
        examples=sample_example,
        title="UniFormerV2: Spatiotemporal Learning.",
        description="Keras reimplementation of <a href='https://github.com/innat/UniFormerV2'>UniFormerV2</a> is presented here.",
    )
    demo.launch()


if __name__ == '__main__':
    main()