# Hugging Face file-viewer header (kept for provenance; not part of the module):
# UniVTG / run_on_video /data_utils.py
# KevinQHLin's picture
# Upload 60 files
# 9d0a4ae
# raw
# history blame
# No virus
# 6.11 kB
import torch
import os
import numpy as np
import ffmpeg
import math
from run_on_video import clip
class ClipFeatureExtractor:
    """Extract CLIP features for the frames of a video and for text queries.

    Combines a ``VideoLoader`` (frame decoding), a ``Preprocessing`` step
    (pixel scaling/normalization) and the model returned by ``clip.load``.
    """

    def __init__(self, framerate=1/2, size=224, centercrop=True, model_name_or_path="ViT-B/32", device="cuda"):
        """Create the frame loader and preprocessor, and load the CLIP model.

        Args:
            framerate: Sampling rate in frames per second, forwarded to
                ``VideoLoader``.
            size: Output frame size, forwarded to ``VideoLoader``.
            centercrop: Whether frames are center-cropped, forwarded to
                ``VideoLoader``.
            model_name_or_path: Model identifier handed to ``clip.load``.
            device: Device the model is loaded on; input batches are moved to
                it before encoding.
        """
        self.video_loader = VideoLoader(framerate=framerate, size=size, centercrop=centercrop)
        print("Loading CLIP models")
        self.clip_extractor, _ = clip.load(model_name_or_path, device=device, jit=False)
        self.tokenizer = clip.tokenize
        self.video_preprocessor = Preprocessing()
        self.device = device

    @torch.no_grad()
    def encode_video(self, video_path: str, bsz=60):
        """Encode the sampled frames of a video, ``bsz`` frames at a time.

        Args:
            video_path: Path of the video file to decode.
            bsz: Number of frames sent to the image encoder per batch.

        Returns:
            The per-batch outputs of ``encode_image`` concatenated along
            dim 0, i.e. one feature row per sampled frame.

        Raises:
            RuntimeError: If the loader did not return a frame tensor (it
                returns a dict when probing the file fails) or if the video
                produced no frames.
        """
        video_frames = self.video_loader.read_video_from_file(video_path)  # (T, 3, H, W)
        # Fail early with the offending path instead of crashing later with an
        # opaque error inside preprocessing or ``torch.cat``.
        if not torch.is_tensor(video_frames):
            raise RuntimeError(
                "Could not load frames from video: {}".format(video_path))
        if video_frames.numel() == 0:
            raise RuntimeError(
                "Video yielded no frames: {}".format(video_path))

        video_frames = self.video_preprocessor(video_frames)
        n_frames = len(video_frames)
        n_batch = int(math.ceil(n_frames / bsz))
        video_features = []
        for i in range(n_batch):
            st_idx = i * bsz
            ed_idx = (i + 1) * bsz
            _video_frames = video_frames[st_idx:ed_idx].to(self.device)
            _video_features = self.clip_extractor.encode_image(_video_frames)
            video_features.append(_video_features)
        video_features = torch.cat(video_features, dim=0)
        return video_features  # (T=#frames, d) torch tensor

    @torch.no_grad()
    def encode_text(self, text_list, bsz=60):
        """Encode texts into per-token features, ``bsz`` texts at a time.

        Args:
            text_list: Sequence of strings to tokenize and encode.
            bsz: Number of texts sent to the text encoder per batch.

        Returns:
            A list with one tensor per input text. Each tensor holds the
            leading rows of that text's ``"last_hidden_state"`` entry, as many
            rows as the text has non-zero token ids. An empty ``text_list``
            yields an empty list.
        """
        n_text = len(text_list)
        n_batch = int(math.ceil(n_text / bsz))
        text_features = []
        for i in range(n_batch):
            st_idx = i * bsz
            ed_idx = (i + 1) * bsz
            encoded_texts = self.tokenizer(text_list[st_idx:ed_idx], context_length=77).to(self.device)
            output = self.clip_extractor.encode_text(encoded_texts)
            # Count non-zero token ids per text; id 0 is treated as padding.
            valid_lengths = (encoded_texts != 0).sum(1).tolist()
            batch_last_hidden_states = output["last_hidden_state"]
            for j, valid_len in enumerate(valid_lengths):
                text_features.append(batch_last_hidden_states[j, :valid_len])
        return text_features  # List([L_j, d]) torch tensor
def convert_to_float(frac_str):
    """Parse a number written as a decimal, a fraction or a mixed fraction.

    Accepted forms are a plain float literal (``"29.97"``), ``"num/denom"``
    (``"30000/1001"``) and ``"whole num/denom"`` (``"1 1/2"``, ``"-1 1/2"``).

    Returns:
        The parsed float, or ``None`` when the string is not a float literal
        and does not contain exactly one ``/``.

    Raises:
        ValueError: If a component of a fraction is not a valid float.
        ZeroDivisionError: If the denominator is zero.
    """
    try:
        return float(frac_str)
    except ValueError:
        pass  # Not a plain float literal; try the fraction forms below.

    slash_parts = frac_str.split('/')
    if len(slash_parts) != 2:
        return None
    numerator_text, denominator_text = slash_parts

    space_parts = numerator_text.split(' ')
    if len(space_parts) != 2:
        # Simple "num/denom" fraction.
        return float(numerator_text) / float(denominator_text)

    # Mixed fraction: the fractional part takes the sign of the whole part.
    whole_text, fraction_numerator_text = space_parts
    whole = float(whole_text)
    direction = -1 if whole < 0 else 1
    fraction = float(fraction_numerator_text) / float(denominator_text)
    return whole + direction * fraction
class Normalize(object):
    """Callable that standardizes a tensor with per-channel mean and std."""

    def __init__(self, mean, std):
        """Store ``mean`` and ``std`` (three values each) as float tensors.

        Both are reshaped to (1, 3, 1, 1) so they broadcast against an input
        whose second dimension has three channels.
        """
        stat_shape = (1, 3, 1, 1)
        self.mean = torch.FloatTensor(mean).view(*stat_shape)
        self.std = torch.FloatTensor(std).view(*stat_shape)

    def __call__(self, tensor):
        """Return ``(tensor - mean) / (std + 1e-8)``."""
        centered = tensor - self.mean
        # The small epsilon keeps the division finite when a std entry is 0.
        return centered / (self.std + 1e-8)
class Preprocessing(object):
    """Callable that rescales pixel values by 1/255 and then normalizes them."""

    def __init__(self):
        # Per-channel statistics for the normalization step; presumably the
        # values the CLIP image encoder expects -- confirm against the model.
        channel_mean = [0.48145466, 0.4578275, 0.40821073]
        channel_std = [0.26862954, 0.26130258, 0.27577711]
        self.norm = Normalize(mean=channel_mean, std=channel_std)

    def __call__(self, tensor):
        """Divide ``tensor`` by 255.0 and apply ``self.norm`` to the result."""
        scaled = tensor / 255.0
        return self.norm(scaled)
class VideoLoader:
    """Pytorch video loader.

    Decodes a video through ffmpeg into a float32 tensor of RGB frames that
    are resampled to ``framerate`` frames per second, rescaled, and optionally
    center-cropped.

    Copied and modified from:
    https://github.com/linjieli222/HERO_Video_Feature_Extractor/blob/main/clip/video_loader.py
    """

    def __init__(
            self,
            framerate=1/2,
            size=224,
            centercrop=True,
    ):
        """Store the decoding options.

        Args:
            framerate: Target sampling rate in frames per second.
            size: Either an int (length of the shorter side after rescaling,
                and the side of the square crop) or a 2-tuple used directly as
                the output ``(height, width)``.
            centercrop: Whether to center-crop frames to ``size`` x ``size``.
                Only applied when ``size`` is an int.
        """
        self.centercrop = centercrop
        self.size = size
        self.framerate = framerate

    def _get_video_info(self, video_path):
        """Probe ``video_path`` and return metadata of its first video stream.

        Returns:
            A dict with keys ``duration``, ``frames_length``, ``fps``,
            ``height`` and ``width``. ``duration`` and ``frames_length`` are
            both -1 when the stream does not report them.

        Raises:
            ValueError: If the file contains no video stream.
        """
        probe = ffmpeg.probe(video_path)
        video_stream = next((stream for stream in probe['streams']
                             if stream['codec_type'] == 'video'), None)
        if video_stream is None:
            # Explicit error instead of a TypeError from indexing ``None``.
            raise ValueError(
                'no video stream found in: {}'.format(video_path))
        width = int(video_stream['width'])
        height = int(video_stream['height'])
        fps = math.floor(convert_to_float(video_stream['avg_frame_rate']))
        try:
            frames_length = int(video_stream['nb_frames'])
            duration = float(video_stream['duration'])
        except Exception:
            # Some streams do not report these fields; mark both as unknown.
            frames_length, duration = -1, -1
        info = {"duration": duration, "frames_length": frames_length,
                "fps": fps, "height": height, "width": width}
        return info

    def _get_output_dim(self, h, w):
        """Return the ``(height, width)`` frames are rescaled to.

        A 2-tuple ``size`` is returned unchanged. Otherwise the shorter side
        becomes ``size`` and the longer side is scaled by the same factor
        (truncated to an int).
        """
        if isinstance(self.size, tuple) and len(self.size) == 2:
            return self.size
        elif h >= w:
            return int(h * self.size / w), self.size
        else:
            return self.size, int(w * self.size / h)

    def read_video_from_file(self, video_path):
        """Decode ``video_path`` into a tensor of frames.

        Returns:
            A float32 tensor of shape ``(T, 3, H, W)`` with values in the
            0-255 range. If probing the file fails, a dict
            ``{'video': torch.zeros(1), 'input': video_path, 'info': {}}`` is
            returned instead, so callers must check the result type.
        """
        try:
            info = self._get_video_info(video_path)
            h, w = info["height"], info["width"]
        except Exception:
            print('ffprobe failed at: {}'.format(video_path))
            return {'video': torch.zeros(1), 'input': video_path,
                    'info': {}}
        height, width = self._get_output_dim(h, w)
        try:
            duration = info["duration"]
            fps = self.framerate
            # For clips shorter than roughly one sampling interval, raise the
            # sampling rate so that the fps filter still emits frames.
            if duration > 0 and duration < 1/fps+0.1:
                fps = 2/max(int(duration), 1)
                print(duration, fps)
        except Exception:
            fps = self.framerate
        cmd = (
            ffmpeg
            .input(video_path)
            .filter('fps', fps=fps)
            .filter('scale', width, height)
        )
        # A square center crop is only defined for an int ``size``; with a
        # tuple ``size`` the frames are already scaled to the requested shape
        # and the crop arithmetic below would raise a TypeError.
        apply_crop = self.centercrop and isinstance(self.size, int)
        if apply_crop:
            x = int((width - self.size) / 2.0)
            y = int((height - self.size) / 2.0)
            cmd = cmd.crop(x, y, self.size, self.size)
        out, _ = (
            cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24')
            .run(capture_stdout=True, quiet=True)
        )
        if apply_crop:
            height, width = self.size, self.size
        # Raw rgb24 bytes -> (T, H, W, 3) uint8 -> float32 -> (T, 3, H, W).
        video = np.frombuffer(out, np.uint8).reshape(
            [-1, height, width, 3])
        video = torch.from_numpy(video.astype('float32'))
        video = video.permute(0, 3, 1, 2)
        return video