# lavila / app.py
# Last commit a5d0544 ("Update app.py") by nateraw; file size 6.44 kB.
import sys
sys.path.insert(0, './')
import decord
import numpy as np
import torch
import os
from lavila.data.video_transforms import Permute
from lavila.data.datasets import get_frame_ids, video_loader_by_frames
from lavila.models.models import VCLM_OPENAI_TIMESFORMER_BASE_GPT2
from lavila.models.tokenizer import MyGPT2Tokenizer
from collections import OrderedDict
import torch
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
import gradio as gr
def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
    """Pick one frame index from each of ``num_segments`` slices of a frame range.

    The indices ``start_frame .. end_frame - 1`` are split into ``num_segments``
    contiguous segments whose boundaries are rounded with ``np.round``. From each
    segment either a uniformly random index (``jitter=True``) or the integer
    midpoint of its boundaries (``jitter=False``) is taken.

    Args:
        start_frame: First frame index of the range.
        end_frame: One past the last frame index of the range.
        num_segments: Number of segments, i.e. number of indices returned.
        jitter: Sample randomly inside each segment instead of using its midpoint.

    Returns:
        A list of ``num_segments`` frame indices, one per segment, in segment order.
    """
    step = float(end_frame - start_frame - 1) / num_segments
    picked = []
    for seg in range(num_segments):
        seg_lo = int(np.round(step * seg) + start_frame)
        seg_hi = min(int(np.round(step * (seg + 1)) + start_frame), end_frame)
        if not jitter:
            picked.append((seg_lo + seg_hi) // 2)
            continue
        # randint's upper bound is exclusive; +1 makes seg_hi itself reachable.
        picked.append(np.random.randint(low=seg_lo, high=(seg_hi + 1)))
    return picked
def video_loader_by_frames(root, vid, frame_ids):
    """Decode selected frames of ``os.path.join(root, vid)`` into a float32 tensor.

    Opening the file is not guarded, so failures there propagate to the caller.
    If decoding the requested frames raises ``IndexError`` or
    ``decord.DECORDError``, the error is printed and a stack of all-zero frames of
    shape ``(240, 320, 3)`` is returned instead (one per requested index).

    Args:
        root: Directory joined in front of ``vid`` (ignored by ``os.path.join``
            when ``vid`` is absolute).
        vid: Video file path or name.
        frame_ids: Frame indices to decode.

    Returns:
        ``torch.stack`` of the per-frame float32 tensors along dim 0.
    """
    reader = decord.VideoReader(os.path.join(root, vid))
    try:
        batch = reader.get_batch(frame_ids).asnumpy()
    except (IndexError, decord.DECORDError) as err:
        # Best effort: keep the app running on a bad video, but say so.
        print(err)
        print("Erroneous video: ", vid)
        tensors = [torch.zeros((240, 320, 3)) for _ in range(len(frame_ids))]
    else:
        tensors = [torch.tensor(one_frame, dtype=torch.float32) for one_frame in batch]
    return torch.stack(tensors, dim=0)
def iter_clips(video_path, num_segments=4, stride_size=16):
    """Yield consecutive, non-overlapping clips of a video.

    The video is walked in windows of ``num_segments * stride_size`` frames.
    From each window, ``num_segments`` frames are decoded: the segment
    midpoints, chosen via ``get_frame_ids(..., jitter=False)``. A video shorter
    than one window still yields a single clip covering all of its frames. A
    trailing partial window after the last full one is skipped.

    Args:
        video_path: Path to the video file (absolute, or relative to the CWD).
        num_segments: Frames sampled per clip. The default of 4 matches the
            ``num_frames=4`` the model in this file is built with.
        stride_size: Frames per segment. The window advances by
            ``num_segments * stride_size`` frames each step.

    Yields:
        ``(start_sec, stop_sec, frames)``: the window bounds in seconds (using
        the reader's average fps) and the tensor returned by
        ``video_loader_by_frames`` for the sampled frame ids.
    """
    vr = decord.VideoReader(video_path)
    num_frames = len(vr)
    fps = vr.get_avg_fps()
    frame_sample_size = num_segments * stride_size
    # Largest start index for which a full window still fits in the video.
    max_start_frame = num_frames - frame_sample_size
    curr_frame = 0
    # `curr_frame == 0` guarantees at least one clip for short videos; `<=` keeps
    # a final window that ends exactly on the last frame.
    while curr_frame == 0 or curr_frame <= max_start_frame:
        # The window end is relative to the current start, clamped to the video length.
        stop_frame = min(curr_frame + frame_sample_size, num_frames)
        curr_sec, stop_sec = curr_frame / fps, stop_frame / fps
        frame_ids = get_frame_ids(curr_frame, stop_frame, num_segments=num_segments, jitter=False)
        frames = video_loader_by_frames('./', video_path, frame_ids)
        yield curr_sec, stop_sec, frames
        curr_frame += frame_sample_size
class Pipeline:
    """Callable captioning pipeline around a LaViLa VCLM (TimeSformer + GPT-2) checkpoint.

    ``Pipeline()(video_path, ...)`` samples several candidate narrations for each
    of the first few clips of a video and returns them as one text report.
    """

    def __init__(self, path=""):
        """Load the checkpoint and build the model, tokenizer and frame transform.

        Args:
            path: Directory containing the checkpoint file ('' = current directory).
        """
        ckpt_path = os.path.join(path, 'vclm_openai_timesformer_base_gpt2_base.pt_ego4d.jobid_319630.ep_0002.md5sum_68a71f.pth')
        ckpt = torch.load(ckpt_path, map_location='cpu')
        # Drop every 'module.' occurrence from the keys (presumably a prefix added by
        # a DataParallel/DDP wrapper during training) so the strict load below matches.
        state_dict = OrderedDict(
            (k.replace('module.', ''), v) for k, v in ckpt['state_dict'].items()
        )
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = VCLM_OPENAI_TIMESFORMER_BASE_GPT2(
            text_use_cls_token=False,
            project_embed_dim=256,
            gated_xattn=True,
            timesformer_gated_xattn=False,
            freeze_lm_vclm=False,
            freeze_visual_vclm=False,
            freeze_visual_vclm_temporal=False,
            num_frames=4,
            drop_path_rate=0.
        )
        self.model.load_state_dict(state_dict, strict=True)
        self.model.to(self.device)
        self.model.eval()
        self.tokenizer = MyGPT2Tokenizer('gpt2', add_bos=True)
        crop_size = 224
        # Permute presumably reorders the loader's (T, H, W, C) stack to (C, T, H, W)
        # -- confirm against lavila's Permute. Mean/std are in raw 0-255 pixel units,
        # because the loaded frames are float copies of the decoded pixel values.
        self.val_transform = transforms.Compose([
            Permute([3, 0, 1, 2]),
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])
        ])

    def decode_one(self, generated_ids, tokenizer):
        """Decode one generated token-id sequence to text.

        Position 0 is skipped (assumed to be BOS, since the tokenizer is built with
        ``add_bos=True``), and decoding stops before the first EOS. When BOS and EOS
        share an id, the EOS search starts at position 1 so the leading token is not
        mistaken for EOS.

        Args:
            generated_ids: Sequence of token ids exposing ``.tolist()``; assumes a
                1-D tensor. TODO: confirm against ``model.generate``.
            tokenizer: Object with ``eos_token_id``, ``bos_token_id`` and an inner
                ``tokenizer.decode(list_of_ids)``.

        Returns:
            The decoded text.
        """
        ids = generated_ids.tolist()
        eos = tokenizer.eos_token_id
        search_from = 1 if eos == tokenizer.bos_token_id else 0
        try:
            eos_id = ids.index(eos, search_from)
        except ValueError:
            # NOTE(review): with no EOS, ids[1:len - 1] also drops the final token.
            # This behavior is kept unchanged; confirm it is intended.
            eos_id = len(ids) - 1
        return tokenizer.tokenizer.decode(ids[1:eos_id])

    def __call__(self, video_path, temperature=0.7, top_p=0.95, max_text_length=77,
                 num_return_sequences=10, max_clips=5):
        """Caption consecutive clips of a video with sampled narrations.

        Args:
            video_path: Path to the video file.
            temperature: Sampling temperature passed to ``model.generate``.
            top_p: Nucleus-sampling threshold passed to ``model.generate``.
            max_text_length: Maximum generated length passed to ``model.generate``.
            num_return_sequences: Candidate narrations sampled per clip.
            max_clips: Stop after this many clips. ``None`` processes every clip.
                Values <= 1 still process the first clip.

        Returns:
            One string with a header line per clip followed by its numbered
            candidate narrations. Each piece is also printed as it is produced.
        """
        chunks = []  # joined once at the end instead of repeated string concatenation
        with torch.autocast(self.device):
            for clip_idx, (start, stop, frames) in enumerate(iter_clips(video_path)):
                header = f"{'-'*30} Predictions From: {start:2.3f}-{stop:2.3f} seconds {'-'*30}\n"
                print(header)
                chunks.append(header)
                # Add a leading batch dimension of 1.
                frames = self.val_transform(frames).unsqueeze(0)
                if self.device == 'cuda':
                    frames = frames.to(self.device).half()
                with torch.no_grad():
                    image_features = self.model.encode_image(frames)
                    generated_text_ids, _ppls = self.model.generate(
                        image_features,
                        self.tokenizer,
                        target=None,  # free-form generation
                        max_text_length=max_text_length,
                        top_k=None,
                        top_p=top_p,  # nucleus sampling
                        num_return_sequences=num_return_sequences,
                        temperature=temperature,
                        early_stopping=True,
                    )
                for i in range(num_return_sequences):
                    generated_text_str = self.decode_one(generated_text_ids[i], self.tokenizer)
                    line = '\t{}: {}\n'.format(i, generated_text_str)
                    print(line)
                    chunks.append(line)
                # Check after processing so the next clip is never decoded needlessly.
                if max_clips is not None and clip_idx + 1 >= max_clips:
                    break
        return ''.join(chunks)
# Gradio UI. Building it instantiates Pipeline() at import time, which loads the
# checkpoint. The three inputs are passed positionally to
# Pipeline.__call__(video_path, temperature, top_p).
interface = gr.Interface(
    fn=Pipeline(),
    inputs=[
        gr.Video(label='video_path'),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.7, label='temperature'),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.95, label='top_p'),
    ],
    outputs='text',
    examples=[
        ['eating_spaghetti.mp4', 0.7, 0.95],
        ['assets/3c0dffd0-e38e-4643-bc48-d513943dc20b_012_014.mp4', 0.7, 0.95],
    ],
)

if __name__ == '__main__':
    interface.launch(debug=True)