# MAMA video-to-text pipeline: extract keyframes from an uploaded video,
# tile them into a 4x3 grid, and caption the grid with TinyLLaVA via Gradio.
import gradio as gr
import argparse
import shutil
import os
from video_keyframe_detector.cli import keyframeDetection
import numpy as np
import cv2
from PIL import Image
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load TinyLLaVA (Phi-2 + SigLIP) once at import time; reused for every request.
hf_path = 'tinyllava/TinyLLaVA-Phi-2-SigLIP-3.1B'
# NOTE(review): trust_remote_code executes Python shipped with the hub repo —
# acceptable for this pinned checkpoint, but verify the repo is trusted.
model = AutoModelForCausalLM.from_pretrained(hf_path, trust_remote_code=True)
config = model.config
# Slow tokenizer, configured with the checkpoint's own max length and padding
# side so prompts are truncated/padded exactly as the model expects.
tokenizer = AutoTokenizer.from_pretrained(
    hf_path,
    use_fast=False,
    model_max_length=config.tokenizer_model_max_length,
    padding_side=config.tokenizer_padding_side,
)
def extract_keyframes(video_path, num_keyframes=12):
    """Detect keyframes in *video_path* and copy ``num_keyframes`` evenly
    spaced ones to ``video_frames/<video_id>/frame_<i>.jpg``.

    Args:
        video_path: path to the input video file.
        num_keyframes: number of frames to keep (default 12 — the 4x3 grid
            size the rest of the pipeline expects).
    """
    # NOTE(review): '/'-based parsing breaks on Windows paths and on names
    # containing extra dots; kept as-is so every stage derives the same id.
    video_id = video_path.split('/')[-1].strip().split('.')[0]
    os.makedirs("temp", exist_ok=True)
    # Writes detected frames to temp/keyFrames/ — presumably named
    # "keyframe<N>.jpg"; the 0.2 threshold is the detector's sensitivity.
    keyframeDetection(video_path, "temp", 0.2)
    frames_dir = os.path.join("temp", "keyFrames")
    # Sort numerically: drop the extension, then the 8-char "keyframe" prefix.
    frame_names = sorted(os.listdir(frames_dir), key=lambda f: int(f.split('.')[0][8:]))
    out_dir = os.path.join("video_frames", video_id)
    os.makedirs(out_dir, exist_ok=True)
    # Evenly spaced indices over [1, len-1] (frame 0 is deliberately skipped);
    # when fewer frames exist than requested, duplicates collapse in the set.
    selected = set(np.linspace(1, len(frame_names) - 1, num_keyframes).astype(int))
    copied = 0
    for idx, name in enumerate(frame_names):
        if idx in selected:
            shutil.copyfile(os.path.join(frames_dir, name),
                            os.path.join(out_dir, f"frame_{copied}.jpg"))
            copied += 1
    # Best-effort cleanup of the detector's scratch directory.
    shutil.rmtree("temp", ignore_errors=True)
def concatenate_frames(video_path):
    """Tile the 12 extracted frames for *video_path* into a 4x3 grid and
    write it to ``concatenated_frames/<video_id>.jpg``.

    Raises:
        ValueError: if fewer than 12 frames were extracted (fail fast with a
            clear message instead of a cryptic cv2 concat error).
    """
    os.makedirs("concatenated_frames", exist_ok=True)
    # Same id derivation as extract_keyframes so the directories line up.
    video_id = video_path.split('/')[-1].strip().split('.')[0]
    frame_dir = os.path.join("video_frames", video_id)
    # Files are named frame_<i>.jpg; sort by the numeric suffix.
    frame_names = sorted(os.listdir(frame_dir), key=lambda f: int(f.split('.')[0].split('_')[1]))
    images = [cv2.imread(os.path.join(frame_dir, name)) for name in frame_names]
    if len(images) < 12:
        raise ValueError(f"expected 12 frames in {frame_dir}, found {len(images)}")
    # Three rows of four frames each, stacked vertically.
    rows = [cv2.hconcat(images[i:i + 4]) for i in range(0, 12, 4)]
    grid = cv2.vconcat(rows)
    cv2.imwrite(os.path.join("concatenated_frames", f"{video_id}.jpg"), grid)
def image_parser(args):
    """Split ``args.image_file`` on ``args.sep`` into a list of image paths."""
    return args.image_file.split(args.sep)
def generate_video_caption(video_path):
    """Caption the concatenated frame grid for *video_path* with TinyLLaVA.

    Returns:
        The model's one-sentence description of the video (a string).
    """
    # Same id derivation as the other stages so the grid file is found.
    video_id = video_path.split('/')[-1].strip().split('.')[0]
    image_file = os.path.join("concatenated_frames", f"{video_id}.jpg")
    prompt = "In a short sentence, describe the process in the video."
    # model.chat returns (text, generation_time); the timing is unused here.
    output_text, _ = model.chat(prompt=prompt, image=image_file, tokenizer=tokenizer)
    return output_text
def clean_files_and_folders():
    """Remove the per-request scratch directories.

    Missing directories are ignored (``ignore_errors=True``) so cleanup is
    safe to call even after a partial pipeline failure, where one or both
    directories may never have been created.
    """
    shutil.rmtree("concatenated_frames", ignore_errors=True)
    shutil.rmtree("video_frames", ignore_errors=True)
def video_to_text(video_file):
    """Gradio handler: run the full pipeline on an uploaded video.

    Extracts keyframes, tiles them into a grid, captions the grid, then
    cleans up the scratch directories and returns the caption text.
    """
    # gr.File hands us a tempfile-like object; .name is its path on disk.
    path = video_file.name
    extract_keyframes(path)
    concatenate_frames(path)
    caption = generate_video_caption(path)
    clean_files_and_folders()
    return caption
# Gradio UI: a single file upload in, a text caption out.
iface = gr.Interface(
    fn=video_to_text,
    inputs=gr.File(file_types=["video"]),
    outputs="text",
    title="MAMA Video-Text Generation Pipeline",
    description="Upload a video and get the description. Due to limited budget, we can only use TinyLLaVA on CPUs. Please only try videos which are less than 1MB. Thank you so much and Welcome to MAMA!",
)
# share=True publishes a temporary public URL in addition to the local server.
iface.launch(share=True)