# Julien Simon
# Add message
# 5995a28
import cv2
import gradio as gr
from PIL import Image
from transformers import BridgeTowerForImageAndTextRetrieval, BridgeTowerProcessor
model_id = "BridgeTower/bridgetower-large-itm-mlm"
# Load the processor and the image-text retrieval model once, at import time,
# so every request served by the app reuses the same objects.
processor, model = (
    BridgeTowerProcessor.from_pretrained(model_id),
    BridgeTowerForImageAndTextRetrieval.from_pretrained(model_id),
)
# Process a frame
def process_frame(image, texts):
    """Score *image* against each comma-separated text query in *texts*.

    Args:
        image: Image handed to the BridgeTower processor (a PIL image when
            called from ``process``).
        texts: One or more text queries separated by commas. Each part is
            used as-is; surrounding whitespace is not stripped.

    Returns:
        A dict mapping each query part to its score, formatted as a string
        with two decimals, ordered from the highest to the lowest score.
    """
    import torch  # local import: only needed for the no-grad context

    scores = {}
    # Pure inference: don't build an autograd graph for each forward pass.
    with torch.no_grad():
        for t in texts.split(","):
            encoding = processor(image, t, return_tensors="pt")
            outputs = model(**encoding)
            # NOTE(review): index 1 is presumably the "match" logit of the
            # image-text matching head — confirm against the model card.
            scores[t] = "{:.2f}".format(outputs.logits[0, 1].item())
    # Sort numerically, in descending order. The values are strings, so a
    # plain sort would put "10.00" before "9.00" and mis-order negatives.
    return dict(
        sorted(scores.items(), key=lambda item: float(item[1]), reverse=True)
    )
# Process a video
def process(video, text, sample_rate, min_score):
    """Find the time ranges of *video* whose sampled frames match *text*.

    One frame every ``sample_rate`` seconds is scored with ``process_frame``.
    Consecutive sampled frames scoring above ``min_score`` are merged into a
    single clip.

    Args:
        video: Path of the video file to open with OpenCV.
        text: Text query. ``process_frame`` scores each comma-separated part
            separately; the best part's score is used for the frame.
        sample_rate: Seconds between two sampled frames. Values too small to
            span one frame mean "sample every frame".
        min_score: A sampled frame matches when its score is strictly
            greater than this value.

    Returns:
        A ``(clip_images, clips)`` pair. ``clip_images`` holds every matching
        sampled frame as an RGB PIL image. ``clips`` is a list of
        ``(start_score, start_time, end_time)`` tuples. ``start_score`` is the
        formatted score string of the clip's first frame, and times are in
        seconds. ``end_time`` is the time of the first non-matching sampled
        frame, or the end of the video if the clip was still open.

    Raises:
        gr.Error: If the video's frame rate cannot be read.
    """
    cap = cv2.VideoCapture(video)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        # Every timestamp divides by fps; `not fps > 0` also rejects NaN.
        if not fps > 0:
            raise gr.Error("Could not read the frame rate of this video.")
        frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"{round(fps)} fps, {frames} frames, {int(frames / fps)} seconds")
        # Number of frames between two samples, clamped to at least 1 so a
        # zero or tiny sample rate cannot cause a modulo by zero.
        step = max(1, round(fps * sample_rate))

        clips = []
        clip_images = []
        clip_started = False
        start_score = start_time = None
        frame_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % step == 0:
                # Use the exact fps (not a rounded one) so times don't drift.
                timestamp = frame_count / fps
                # OpenCV decodes frames as BGR; PIL and the model expect RGB.
                image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                scores = process_frame(image, text)
                # process_frame keys its result by the comma-separated parts
                # of `text`, so indexing with the full query would raise
                # KeyError when it contains commas; take the best part.
                score = max(scores.values(), key=float)
                if float(score) > min_score:
                    if not clip_started:
                        clip_started = True
                        start_time = timestamp
                        start_score = score
                    clip_images.append(image)
                elif clip_started:
                    # A clip ends at the first non-matching sampled frame.
                    clip_started = False
                    clips.append((start_score, start_time, timestamp))
            frame_count += 1
        if clip_started:
            # The video ended inside a clip: close it at the end of the video.
            clips.append((start_score, start_time, frame_count / fps))
    finally:
        # Always free the decoder, even if scoring raised.
        cap.release()
    return clip_images, clips
# Inputs
video = gr.Video(label="Video")
text = gr.Text(label="Text query")
sample_rate = gr.Number(value=5, label="Sample rate (1 frame every 'n' seconds)")
min_score = gr.Number(value=3, label="Minimum score")
# Output
gallery = gr.Gallery(label="Images")
clips = gr.Text(label="Clips (score, start time, end time)")
description = "This Space lets you run semantic search on a video."
# Wire the inputs/outputs to `process`. The example row follows the input
# order: video path, query, sample rate, minimum score.
# NOTE(review): the example assumes a "video.mp4" file next to this script.
iface = gr.Interface(
    description=description,
    fn=process,
    inputs=[video, text, sample_rate, min_score],
    outputs=[gallery, clips],
    examples=[
        [
            "video.mp4",
            "wild bears",
            5,
            3,
        ]
    ],
    allow_flagging="never",
)

# Start the web server only when run as a script, so importing this module
# does not launch the app as a side effect.
if __name__ == "__main__":
    iface.launch()