# Spaces status: Runtime error
import itertools
import os
import sys
from collections import OrderedDict

# Put the working directory first on sys.path so the local ``lavila`` package
# is importable; this must run before the imports below.
sys.path.insert(0, './')

import decord
import numpy as np
import torch
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
import gradio as gr

from lavila.data.video_transforms import Permute
# NOTE: get_frame_ids / video_loader_by_frames are redefined locally in this file.
from lavila.data.datasets import get_frame_ids, video_loader_by_frames
from lavila.models.models import VCLM_OPENAI_TIMESFORMER_BASE_GPT2
from lavila.models.tokenizer import MyGPT2Tokenizer
def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
    """Pick one frame index per equal-width segment between two frame indices.

    The span is cut into ``num_segments`` segments of width
    ``(end_frame - start_frame - 1) / num_segments``. Segment bounds are
    rounded with ``np.round`` and offset by ``start_frame``, and each upper
    bound is capped at ``end_frame``.

    Args:
        start_frame: first frame index of the span.
        end_frame: frame index the span ends at; used as the cap for every
            segment's upper bound.
        num_segments: number of indices to return.
        jitter: if true, draw each index with ``np.random.randint`` from the
            inclusive range ``[lo, hi]`` of its segment (NumPy's global RNG).
            Otherwise take the integer midpoint ``(lo + hi) // 2``.

    Returns:
        list: ``num_segments`` frame indices in segment order.
    """
    width = float(end_frame - start_frame - 1) / num_segments

    def segment_bounds(k):
        # Rounded [lo, hi] bounds of segment k, with hi capped at end_frame.
        lo = int(np.round(width * k) + start_frame)
        hi = min(int(np.round(width * (k + 1)) + start_frame), end_frame)
        return lo, hi

    picked = []
    for lo, hi in map(segment_bounds, range(num_segments)):
        if not jitter:
            picked.append((lo + hi) // 2)
            continue
        picked.append(np.random.randint(low=lo, high=(hi + 1)))
    return picked
def video_loader_by_frames(root, vid, frame_ids):
    """Decode the frames ``frame_ids`` of the video at ``os.path.join(root, vid)``.

    Each decoded frame is converted to a float32 tensor, and the frames are
    stacked along a new leading dimension.

    Fetching or converting the batch may raise ``IndexError`` or
    ``decord.DECORDError``. In that case the error and the video name are
    printed, and a stack of ``len(frame_ids)`` zero frames of shape
    ``(240, 320, 3)`` is returned instead. Errors raised while opening the
    video are not caught.
    """
    reader = decord.VideoReader(os.path.join(root, vid))
    try:
        batch = reader.get_batch(frame_ids).asnumpy()
        tensors = [torch.tensor(img, dtype=torch.float32) for img in batch]
    except (IndexError, decord.DECORDError) as exc:
        # Best effort: report the failure and fall back to blank frames.
        print(exc)
        print("Erroneous video: ", vid)
        tensors = [torch.zeros((240, 320, 3)) for _ in range(len(frame_ids))]
    return torch.stack(tensors, dim=0)
def iter_clips(video_path, num_segments=4, stride_size=16):
    """Yield consecutive, non-overlapping clips of the video at ``video_path``.

    Each clip spans ``num_segments * stride_size`` frames. It is represented by
    ``num_segments`` frames chosen with ``get_frame_ids(..., jitter=False)``,
    i.e. the midpoint of each segment.

    The first clip is always produced. If the video is shorter than one clip,
    that first clip covers all of its frames. After that, only clips that fit
    entirely inside the video are produced; trailing frames that do not fill a
    whole clip are skipped.

    Args:
        video_path: path of the video file.
        num_segments: frames sampled per clip.
        stride_size: frames per segment.

    Yields:
        ``(start_sec, stop_sec, frames)``, where the times are clip boundaries
        in seconds (frame index / average fps) and ``frames`` is the tensor
        returned by ``video_loader_by_frames``.
    """
    vr = decord.VideoReader(video_path)
    num_frames = len(vr)
    fps = vr.get_avg_fps()
    frame_sample_size = num_segments * stride_size
    max_start_frame = num_frames - frame_sample_size
    curr_frame = 0
    # `<=` so that a last clip ending exactly at the final frame is included.
    while curr_frame == 0 or curr_frame <= max_start_frame:
        # The window ends relative to the current start. Previously it ended at
        # frame_sample_size for every clip, so later clips had stop <= start.
        stop_frame = min(curr_frame + frame_sample_size, num_frames)
        frame_ids = get_frame_ids(curr_frame, stop_frame, num_segments=num_segments, jitter=False)
        frames = video_loader_by_frames('./', video_path, frame_ids)
        yield curr_frame / fps, stop_frame / fps, frames
        curr_frame += frame_sample_size
class Pipeline:
    """Caption a video clip by clip with a LaViLa VCLM (TimeSformer + GPT-2).

    ``__init__`` loads the checkpoint and builds the model, the tokenizer and
    the frame preprocessing. Calling the instance returns the generated
    captions of the first clips of a video as one text block.
    """

    # Frames per clip the model is built with (``num_frames``). The clip
    # sampler is asked for the same number of segments so the two cannot drift.
    NUM_FRAMES = 4

    def __init__(self, path=""):
        """Load the checkpoint from directory ``path`` and prepare for inference.

        Uses CUDA when ``torch.cuda.is_available()``, otherwise the CPU.
        """
        ckpt_path = os.path.join(path, 'vclm_openai_timesformer_base_gpt2_base.pt_ego4d.jobid_319630.ep_0002.md5sum_68a71f.pth')
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        # Drop the 'module.' prefixes, presumably left by a DataParallel/DDP
        # wrapper at training time, so the keys match the bare model.
        state_dict = OrderedDict(
            (k.replace('module.', ''), v) for k, v in ckpt['state_dict'].items()
        )
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = VCLM_OPENAI_TIMESFORMER_BASE_GPT2(
            text_use_cls_token=False,
            project_embed_dim=256,
            gated_xattn=True,
            timesformer_gated_xattn=False,
            freeze_lm_vclm=False,
            freeze_visual_vclm=False,
            freeze_visual_vclm_temporal=False,
            num_frames=self.NUM_FRAMES,
            drop_path_rate=0.,
        )
        self.model.load_state_dict(state_dict, strict=True)
        self.model.to(self.device)
        self.model.eval()
        self.tokenizer = MyGPT2Tokenizer('gpt2', add_bos=True)
        crop_size = 224
        # Assumes input frames are stacked as (T, H, W, C) by
        # video_loader_by_frames; Permute([3, 0, 1, 2]) presumably reorders
        # them to (C, T, H, W) for the video transforms — TODO confirm.
        self.val_transform = transforms.Compose([
            Permute([3, 0, 1, 2]),
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305]),
        ])

    def decode_one(self, generated_ids, tokenizer):
        """Decode one generated id sequence into text.

        The first id (the start token) is skipped, and decoding stops before
        the first end-of-sequence id. When BOS and EOS share an id, the EOS
        search starts at position 1 so the leading start token is not taken
        for the end. If no EOS is found, the last id is excluded as well.
        """
        ids = generated_ids.tolist()
        eos = tokenizer.eos_token_id
        search_from = 1 if eos == tokenizer.bos_token_id else 0
        try:
            eos_id = ids.index(eos, search_from)
        except ValueError:
            # NOTE(review): without an EOS the final id is also dropped; this
            # matches the previous behaviour and is kept deliberately.
            eos_id = len(ids) - 1
        return tokenizer.tokenizer.decode(ids[1:eos_id])

    def _caption_clip(self, frames, temperature, top_p, max_text_length, num_return_sequences):
        """Preprocess one clip and return ``num_return_sequences`` decoded captions."""
        frames = self.val_transform(frames).unsqueeze(0)  # add a batch dimension
        if self.device == 'cuda':
            frames = frames.to(self.device).half()
        with torch.no_grad():
            image_features = self.model.encode_image(frames)
            generated_text_ids, _ppls = self.model.generate(
                image_features,
                self.tokenizer,
                target=None,  # free-form generation
                max_text_length=max_text_length,
                top_k=None,
                top_p=top_p,  # nucleus sampling
                num_return_sequences=num_return_sequences,
                temperature=temperature,
                early_stopping=True,
            )
        return [
            self.decode_one(generated_text_ids[i], self.tokenizer)
            for i in range(num_return_sequences)
        ]

    def __call__(self, video_path, temperature=0.7, top_p=0.95, max_text_length=77,
                 num_return_sequences=10, max_iterations=5):
        """Caption up to ``max_iterations`` consecutive clips of ``video_path``.

        Args:
            video_path: path of the video to caption.
            temperature: sampling temperature forwarded to ``model.generate``.
            top_p: nucleus-sampling threshold forwarded to ``model.generate``.
            max_text_length: maximum caption length forwarded to ``model.generate``.
            num_return_sequences: number of candidate captions per clip.
            max_iterations: maximum number of clips to caption (default 5).
                ``None`` means no limit; values <= 0 caption nothing.

        Returns:
            str: for each clip, a dashed header with its time span followed by
            one tab-indented, numbered line per candidate caption. Each piece
            is also printed as it is produced.
        """
        limit = None if max_iterations is None else max(int(max_iterations), 0)
        pieces = []
        with torch.autocast(self.device):
            clips = iter_clips(video_path, num_segments=self.NUM_FRAMES)
            # islice stops without pulling (and decoding) the clip after the limit.
            for start, stop, frames in itertools.islice(clips, limit):
                header = f"{'-'*30} Predictions From: {start:2.3f}-{stop:2.3f} seconds {'-'*30}\n"
                print(header)
                pieces.append(header)
                captions = self._caption_clip(
                    frames, temperature, top_p, max_text_length, num_return_sequences
                )
                for i, caption in enumerate(captions):
                    line = '\t{}: {}\n'.format(i, caption)
                    print(line)
                    pieces.append(line)
        return "".join(pieces)
# Gradio UI. Gradio calls the pipeline with one argument per input component,
# in order: (video_path, temperature, top_p). The remaining Pipeline.__call__
# arguments keep their defaults.
# NOTE(review): the sliders allow temperature/top_p of 0.0; confirm that
# model.generate handles those values.
_captioner = Pipeline()
_ui_inputs = [
    gr.Video(label='video_path'),
    gr.Slider(0.0, 1.0, 0.7, label='temperature'),
    gr.Slider(0.0, 1.0, 0.95, label='top_p'),
]
_ui_examples = [
    ['eating_spaghetti.mp4', 0.7, 0.95],
    ['assets/3c0dffd0-e38e-4643-bc48-d513943dc20b_012_014.mp4', 0.7, 0.95],
]
interface = gr.Interface(
    _captioner,
    inputs=_ui_inputs,
    outputs='text',
    examples=_ui_examples,
)

if __name__ == '__main__':
    interface.launch(debug=True)