import sys
# Make the repository-local `lavila` package importable; must run before the
# lavila imports below.
sys.path.insert(0, './')

import decord
import numpy as np
import torch
import os

from lavila.data.video_transforms import Permute
# NOTE: get_frame_ids / video_loader_by_frames imported here are shadowed by
# the local definitions further down in this file.
from lavila.data.datasets import get_frame_ids, video_loader_by_frames
from lavila.models.models import VCLM_OPENAI_TIMESFORMER_BASE_GPT2
from lavila.models.tokenizer import MyGPT2Tokenizer
from collections import OrderedDict
import torch
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
import gradio as gr


def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
    """Pick one frame id from each of ``num_segments`` equal segments.

    The range ``[start_frame, end_frame)`` is split into ``num_segments``
    segments; from each one a frame id is chosen uniformly at random when
    ``jitter`` is true, otherwise the segment midpoint is used.

    Returns:
        list[int]: ``num_segments`` frame ids in ``[start_frame, end_frame - 1]``.
    """
    seg_size = float(end_frame - start_frame - 1) / num_segments
    seq = []
    for i in range(num_segments):
        start = int(np.round(seg_size * i) + start_frame)
        end = int(np.round(seg_size * (i + 1)) + start_frame)
        end = min(end, end_frame)
        if jitter:
            # randint's `high` is exclusive, so `end` itself can be drawn.
            frame_id = np.random.randint(low=start, high=(end + 1))
        else:
            frame_id = (start + end) // 2
        seq.append(frame_id)
    return seq


def video_loader_by_frames(root, vid, frame_ids):
    """Decode ``frame_ids`` from the video at ``os.path.join(root, vid)``.

    Returns:
        torch.Tensor: the decoded frames stacked along a new first axis as
        float32 (presumably ``(T, H, W, C)``, matching the fallback below).
        If decoding fails with ``IndexError`` or ``decord.DECORDError`` the
        error is printed and zero frames of shape ``(240, 320, 3)`` are
        returned instead, so one bad video does not abort the caller.
    """
    vr = decord.VideoReader(os.path.join(root, vid))
    try:
        frames = vr.get_batch(frame_ids).asnumpy()
        frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
    except (IndexError, decord.DECORDError) as error:
        print(error)
        print("Erroneous video: ", vid)
        frames = [torch.zeros((240, 320, 3)) for _ in range(len(frame_ids))]
    return torch.stack(frames, dim=0)


def iter_clips(video_path, num_segments=4, stride_size=16):
    """Split a video into consecutive, non-overlapping clips.

    Each clip spans ``num_segments * stride_size`` frames and is represented
    by ``num_segments`` frames (the midpoints of equal-length segments).
    Only full-length clips are produced, except that a video shorter than one
    clip yields a single clip covering the whole video. An empty video yields
    nothing.

    Args:
        video_path: path of the video file.
        num_segments: number of frames sampled per clip.
        stride_size: frames per segment; a clip covers
            ``num_segments * stride_size`` frames.

    Yields:
        tuple: ``(start_sec, stop_sec, frames)`` -- the clip boundaries in
        seconds (derived from the video's average fps) and the sampled frames
        as returned by ``video_loader_by_frames``.
    """
    vr = decord.VideoReader(video_path)
    num_frames = len(vr)
    if num_frames == 0:
        return
    fps = vr.get_avg_fps()
    clip_len = num_segments * stride_size
    # Start frame of the last full clip. For a video shorter than one clip
    # this clamps to 0 so that a single, shorter clip is still produced.
    last_start = max(num_frames - clip_len, 0)
    for start_frame in range(0, last_start + 1, clip_len):
        # The clip end is measured from this clip's own start frame.
        stop_frame = min(start_frame + clip_len, num_frames)
        frame_ids = get_frame_ids(start_frame, stop_frame,
                                  num_segments=num_segments, jitter=False)
        frames = video_loader_by_frames('./', video_path, frame_ids)
        yield start_frame / fps, stop_frame / fps, frames
class Pipeline:
    """Caption a video clip-by-clip with a LaViLa VCLM narrator model.

    Calling an instance with a video path splits the video into clips via
    ``iter_clips`` and, for each clip, samples several free-form narrations.
    All narrations are returned as one formatted string.
    """

    def __init__(self, path=""):
        """Load the narrator checkpoint from directory ``path`` and build the model.

        The model is moved to CUDA when available (otherwise CPU) and put in
        eval mode. The validation-time video transform is built here as well.
        """
        ckpt_path = os.path.join(path, 'vclm_openai_timesformer_base_gpt2_base.pt_ego4d.jobid_319630.ep_0002.md5sum_68a71f.pth')
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        # Strip the 'module.' key prefix (presumably added by DataParallel
        # wrapping at training time) so keys match the bare model.
        state_dict = OrderedDict()
        for k, v in ckpt['state_dict'].items():
            state_dict[k.replace('module.', '')] = v

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = VCLM_OPENAI_TIMESFORMER_BASE_GPT2(
            text_use_cls_token=False,
            project_embed_dim=256,
            gated_xattn=True,
            timesformer_gated_xattn=False,
            freeze_lm_vclm=False,
            freeze_visual_vclm=False,
            freeze_visual_vclm_temporal=False,
            num_frames=4,
            drop_path_rate=0.
        )
        self.model.load_state_dict(state_dict, strict=True)
        self.model.to(self.device)
        self.model.eval()

        self.tokenizer = MyGPT2Tokenizer('gpt2', add_bos=True)

        crop_size = 224
        self.val_transform = transforms.Compose([
            # Reorder frame axes (dim 3 first), presumably THWC -> CTHW.
            Permute([3, 0, 1, 2]),
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms_video.NormalizeVideo(mean=[108.3272985, 116.7460125, 104.09373615000001], std=[68.5005327, 66.6321579, 70.32316305])
        ])

    def decode_one(self, generated_ids, tokenizer):
        """Decode one generated token-id tensor into a string.

        Position 0 (the BOS token) is always dropped, and decoding stops
        before the first EOS token. If no EOS token is present, the final
        token is also dropped (matching the original implementation).
        """
        ids = generated_ids.tolist()
        eos = tokenizer.eos_token_id
        # Find the index of the first <EOS> token.
        if eos == tokenizer.bos_token_id:
            # BOS and EOS share an id, so skip position 0 when searching.
            eos_id = ids.index(eos, 1) if eos in ids[1:] else len(ids) - 1
        elif eos in ids:
            eos_id = ids.index(eos)
        else:
            eos_id = len(ids) - 1
        return tokenizer.tokenizer.decode(ids[1:eos_id])

    def __call__(self, video_path, temperature=0.7, top_p=0.95, max_text_length=77, num_return_sequences=10):
        """Generate narrations for up to the first 5 clips of ``video_path``.

        Args:
            video_path: path of the video to caption.
            temperature: sampling temperature passed to ``model.generate``.
            top_p: nucleus-sampling threshold passed to ``model.generate``.
            max_text_length: maximum generated length passed to
                ``model.generate``.
            num_return_sequences: number of narrations sampled per clip.

        Returns:
            str: a header line per clip with its time range, followed by one
            tab-indented, numbered narration per line. Each piece is also
            printed as it is produced.
        """
        MAX_ITERATIONS = 5  # maximum number of clips captioned per video
        chunks = []
        # NOTE(review): on CPU, autocast('cpu') runs in bfloat16 by default.
        with torch.autocast(self.device):
            # enumerate yields (index, item); the clip tuple must be unpacked
            # as a nested target.
            for clip_idx, (start, stop, frames) in enumerate(iter_clips(video_path)):
                header = f"{'-'*30} Predictions From: {start:2.3f}-{stop:2.3f} seconds {'-'*30}\n"
                print(header)
                chunks.append(header)

                frames = self.val_transform(frames).unsqueeze(0)  # add batch dim
                if self.device == 'cuda':
                    frames = frames.to(self.device).half()

                with torch.no_grad():
                    image_features = self.model.encode_image(frames)
                    generated_text_ids, _ppls = self.model.generate(
                        image_features,
                        self.tokenizer,
                        target=None,  # free-form generation
                        max_text_length=max_text_length,
                        top_k=None,
                        top_p=top_p,  # nucleus sampling
                        num_return_sequences=num_return_sequences,  # number of candidates
                        temperature=temperature,
                        early_stopping=True,
                    )

                # Use a distinct index so the clip counter checked below is
                # not overwritten.
                for seq_idx in range(num_return_sequences):
                    generated_text_str = self.decode_one(generated_text_ids[seq_idx], self.tokenizer)
                    line = '\t{}: {}\n'.format(seq_idx, generated_text_str)
                    print(line)
                    chunks.append(line)

                if (clip_idx + 1) >= MAX_ITERATIONS:
                    break
        return ''.join(chunks)


interface = gr.Interface(
    Pipeline(),
    inputs=[
        gr.Video(label='video_path'),
        gr.Slider(0.0, 1.0, 0.7, label='temperature'),
        gr.Slider(0.0, 1.0, 0.95, label='top_p'),
    ],
    outputs='text',
    examples=[['eating_spaghetti.mp4', 0.7, 0.95], ['assets/3c0dffd0-e38e-4643-bc48-d513943dc20b_012_014.mp4', 0.7, 0.95]]
)

if __name__ == '__main__':
    interface.launch(debug=True)