import cv2
import gradio as gr
import torch
from PIL import Image
from transformers import BridgeTowerForImageAndTextRetrieval, BridgeTowerProcessor

model_id = "BridgeTower/bridgetower-large-itm-mlm"
processor = BridgeTowerProcessor.from_pretrained(model_id)
model = BridgeTowerForImageAndTextRetrieval.from_pretrained(model_id)
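# The ITM (image-text matching) head returns two logits per (image, text) pair;
# index 1 is the "match" logit, which process_frame below uses as the score.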

# Score a single frame against each comma-separated text query
def process_frame(image, texts):
    scores = {}
    texts = texts.split(",")
    for t in texts:
        encoding = processor(image, t, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**encoding)
        scores[t] = "{:.2f}".format(outputs.logits[0, 1].item())
    # Sort by numeric value (the scores are stored as formatted strings,
    # so a plain string sort would order them lexicographically)
    scores = dict(sorted(scores.items(), key=lambda item: float(item[1]), reverse=True))
    return scores
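# Illustrative call (hypothetical image file and score values):
#   process_frame(Image.open("frame.jpg"), "a bear,a car")
#   -> {"a bear": "4.21", "a car": "-1.37"}
# Higher logits indicate a stronger image-text match.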

# Scan a video, scoring one frame every `sample_rate` seconds, and collect
# clips where the score for the query stays above `min_score`
def process(video, text, sample_rate, min_score):
    video = cv2.VideoCapture(video)
    fps = round(video.get(cv2.CAP_PROP_FPS))
    frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    length = frames // fps
    print(f"{fps} fps, {frames} frames, {length} seconds")
    frame_count = 0
    clips = []
    clip_images = []
    clip_started = False
    while True:
        ret, frame = video.read()
        if not ret:
            break
        # Cast to int: sample_rate arrives from gr.Number as a float
        if frame_count % int(fps * sample_rate) == 0:
            # OpenCV decodes frames as BGR; convert to RGB for PIL and the model
            frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            # `score[text]` below assumes a single query (no commas) in `text`
            score = process_frame(frame, text)
            # print(f"{frame_count} {score}")
            if float(score[text]) > min_score:
                if clip_started:
                    end_time = frame_count / fps
                else:
                    clip_started = True
                    start_time = frame_count / fps
                    end_time = start_time
                    start_score = score[text]
                clip_images.append(frame)
            elif clip_started:
                clip_started = False
                end_time = frame_count / fps
                clips.append((start_score, start_time, end_time))
        frame_count += 1
    # Close any clip still open when the video ends
    if clip_started:
        clips.append((start_score, start_time, end_time))
    video.release()
    return clip_images, clips
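# Illustrative result (hypothetical values): process("video.mp4", "wild bears", 5, 3)
# -> the matching sampled frames, plus clip tuples such as ("4.12", 10.0, 25.0),
# i.e. the score at the start of the clip and its start/end times in seconds.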

# Inputs
video = gr.Video(label="Video")
text = gr.Text(label="Text query")
sample_rate = gr.Number(value=5, label="Sample rate (1 frame every 'n' seconds)")
min_score = gr.Number(value=3, label="Minimum score")

# Outputs
gallery = gr.Gallery(label="Images")
clips = gr.Text(label="Clips (score, start time, end time)")

description = "This Space lets you run semantic search on a video."

iface = gr.Interface(
    description=description,
    fn=process,
    inputs=[video, text, sample_rate, min_score],
    outputs=[gallery, clips],
    examples=[
        [
            "video.mp4",
            "wild bears",
            5,
            3,
        ]
    ],
    allow_flagging="never",
)

iface.launch()