# VideoCap / app.py
# Last commit by Kaori1707: 77bae0e ("change app"), 2.67 kB.
from decord import VideoReader
import torch
from transformers import AutoImageProcessor, AutoTokenizer, VisionEncoderDecoderModel
import gradio as gr
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Pretrained components, loaded once at import time and shared by every request:
# the frame preprocessor, the tokenizer used to decode generated ids, and the
# encoder-decoder captioning checkpoint (moved onto `device` for inference).
image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
_checkpoint = "Neleac/timesformer-gpt2-video-captioning"
model = VisionEncoderDecoderModel.from_pretrained(_checkpoint)
model = model.to(device)
def _sample_frame_indices(total_frames, num_samples):
    """Return exactly ``num_samples`` evenly spaced frame indices in ``[0, total_frames)``.

    Integer arithmetic (``i * total_frames // num_samples``) never produces a
    zero step. It always yields ``num_samples`` indices, repeating frames when
    the video is shorter than ``num_samples``.

    Raises:
        ValueError: If ``total_frames`` or ``num_samples`` is not positive.
    """
    if total_frames <= 0:
        raise ValueError("video contains no frames")
    if num_samples <= 0:
        raise ValueError("number of frames to sample must be positive")
    return [i * total_frames // num_samples for i in range(num_samples)]


def generate_caption(video, max_length, min_length, beam_size, througputs):
    """Generate a caption for an uploaded video.

    Args:
        video: Value from the ``gr.Video`` component, passed straight to
            ``decord.VideoReader``. Presumably a file path; ``None`` when
            nothing was uploaded.
        max_length: ``max_length`` passed to ``model.generate``.
        min_length: ``min_length`` passed to ``model.generate``.
        beam_size: Number of beams (``num_beams``) for ``model.generate``.
        througputs: Multiplier on the encoder's ``num_frames``. The clip given
            to the model has ``througputs * num_frames`` frames.

    Returns:
        The first decoded caption, with special tokens stripped.

    Raises:
        gr.Error: If no video was uploaded or no frames can be sampled from it.
    """
    if video is None:
        raise gr.Error("Please upload a video first.")

    # Read the video and sample a fixed number of evenly spaced frames.
    container = VideoReader(video)
    num_samples = int(througputs) * model.config.encoder.num_frames
    try:
        indices = _sample_frame_indices(len(container), num_samples)
    except ValueError as err:
        raise gr.Error(f"Could not read frames from the video: {err}") from err
    # Split the (T, H, W, C) batch into a list of per-frame arrays for the processor.
    frames = list(container.get_batch(indices).asnumpy())

    # Preprocess the frames, then generate and decode the caption.
    pixel_values = image_processor(frames, return_tensors="pt").pixel_values.to(device)
    # UI values are cast to int because generate() expects integer settings.
    tokens = model.generate(
        pixel_values,
        min_length=int(min_length),
        max_length=int(max_length),
        num_beams=int(beam_size),
    )
    return tokenizer.batch_decode(tokens, skip_special_tokens=True)[0]


with gr.Blocks(title="Video Captioning") as demo:
    gr.Markdown("# Video Captioning, demo by AISEED")
    with gr.Row():
        with gr.Column(scale=2):
            video = gr.Video(label="Upload Video", format="mp4")
            generate = gr.Button(value="Generate Caption")
        with gr.Column(scale=1):
            text = gr.Textbox(label="Caption", placeholder="Caption will appear here")
            with gr.Accordion("Settings", open=True):
                with gr.Row():
                    max_length = gr.Slider(
                        label="Max Length", minimum=10, maximum=100, value=20, step=1
                    )
                    min_length = gr.Slider(
                        label="Min Length", minimum=1, maximum=10, value=10, step=1
                    )
                beam_size = gr.Slider(
                    label="Beam size", minimum=1, maximum=8, value=8, step=1
                )
                througputs = gr.Radio(
                    label="Throughputs", choices=[1, 2, 3], value=1
                )

    # Event listeners must be registered inside the Blocks context.
    generate.click(
        generate_caption,
        inputs=[video, max_length, min_length, beam_size, througputs],
        outputs=text,
    )
def main():
    """Start the Gradio server for the ``demo`` app."""
    demo.launch()


# Only launch when run as a script, not when the module is imported.
if __name__ == "__main__":
    main()