UniVTG / run_on_video /video_loader.py
KevinQHLin's picture
Upload 60 files
9d0a4ae
raw
history blame
No virus
4.27 kB
import torch as th
from torch.utils.data import Dataset
import pandas as pd
import os
import numpy as np
import ffmpeg
import math
def convert_to_float(frac_str):
    """Convert a number, fraction, or mixed-number string to a float.

    Accepted forms are a plain number ("29.97"), a simple fraction
    ("30000/1001"), and a mixed number whose whole part and fraction are
    separated by a single space ("1 1/2", "-1 1/2").

    Args:
        frac_str: The text to convert.

    Returns:
        The parsed value as a float, or None when the text cannot be
        interpreted as a number. A zero denominator (for example "0/0")
        also yields None.
    """
    try:
        return float(frac_str)
    except ValueError:
        pass

    # Anything other than exactly one '/' is not a fraction.
    try:
        num_part, denom_text = frac_str.split('/')
    except ValueError:
        return None

    # A mixed number has exactly one space between the whole part and the
    # numerator; otherwise treat the entire left side as the numerator.
    try:
        whole_text, num_text = num_part.split(' ')
    except ValueError:
        whole_text, num_text = None, num_part

    try:
        fraction = float(num_text) / float(denom_text)
        whole = 0.0 if whole_text is None else float(whole_text)
    except (ValueError, ZeroDivisionError):
        # Non-numeric pieces or a zero denominator: report "not convertible"
        # the same way as for input that contains no fraction at all.
        return None

    if whole_text is None:
        return fraction
    # The fractional part takes the sign of the whole part; copysign also
    # recognises negative zero, so "-0 1/2" is -0.5.
    return whole + math.copysign(1.0, whole) * fraction
class VideoLoader(Dataset):
    """Single-video dataset that decodes frames through ffmpeg.

    The dataset always has length 1. Indexing it decodes ``vid_path`` with
    ffmpeg's ``fps`` and ``scale`` filters and returns the frames as a float32
    tensor laid out as (frames, channels, height, width).
    """

    def __init__(
            self,
            vid_path,
            framerate=1,
            size=112,
            centercrop=False,
            overwrite=False,
            model_version="ViT-B/32",
    ):
        """
        Args:
            vid_path: Path of the video file to decode.
            framerate: Frames per second requested from ffmpeg's fps filter.
            size: Either an int (length of the shorter output side; the
                aspect ratio is preserved) or a 2-tuple that is used as-is
                as the output (height, width).
            centercrop: If True and ``size`` is not a 2-tuple, crop the
                scaled frames to a centered ``size`` x ``size`` square.
            overwrite: Stored on the instance; not read by this class.
            model_version: Stored on the instance; not read by this class.
        """
        self.vid_path = vid_path
        self.centercrop = centercrop
        self.size = size
        self.framerate = framerate
        self.overwrite = overwrite
        self.model_version = model_version

    def __len__(self):
        """Return 1: the dataset wraps exactly one video."""
        return 1

    def _get_video_info(self, video_path):
        """Probe ``video_path`` and summarise its first video stream.

        Returns:
            A dict with keys ``duration``, ``frames_length``, ``fps``,
            ``height`` and ``width``. ``duration``, ``frames_length`` and
            ``fps`` are -1 when the stream does not report them or the
            reported value cannot be parsed.

        Raises:
            ValueError: If the file contains no video stream.
            Exception: Whatever ``ffmpeg.probe`` raises is propagated.
        """
        probe = ffmpeg.probe(video_path)
        video_stream = next(
            (stream for stream in probe['streams']
             if stream['codec_type'] == 'video'),
            None,
        )
        if video_stream is None:
            raise ValueError(
                'no video stream found in: {}'.format(video_path))

        width = int(video_stream['width'])
        height = int(video_stream['height'])

        # The probed fps is informational only (decoding uses
        # self.framerate), so an unknown rate such as "0/0" must not make
        # the whole probe fail.
        try:
            fps = math.floor(convert_to_float(video_stream['avg_frame_rate']))
        except (KeyError, TypeError, ValueError, ArithmeticError):
            fps = -1

        # Parse the two optional fields independently so that a missing
        # frame count does not discard a valid duration.
        try:
            frames_length = int(video_stream['nb_frames'])
        except (KeyError, TypeError, ValueError):
            frames_length = -1
        try:
            duration = float(video_stream['duration'])
        except (KeyError, TypeError, ValueError):
            duration = -1

        info = {"duration": duration, "frames_length": frames_length,
                "fps": fps, "height": height, "width": width}
        return info

    def _get_output_dim(self, h, w):
        """Return the (height, width) that frames are scaled to.

        A 2-tuple ``size`` is returned unchanged. Otherwise the shorter side
        becomes ``size`` and the longer side is scaled by the same ratio
        (truncated to an int).
        """
        if isinstance(self.size, tuple) and len(self.size) == 2:
            return self.size
        elif h >= w:
            return int(h * self.size / w), self.size
        else:
            return self.size, int(w * self.size / h)

    def __getitem__(self, id):
        """Decode the video and return it together with its path.

        Args:
            id: Ignored; the dataset holds a single video.

        Returns:
            ``{'video': tensor, 'input': vid_path}`` where ``tensor`` is the
            float32 (frames, 3, height, width) frame stack, or ``th.zeros(1)``
            when ``vid_path`` is not an existing file. When probing or
            decoding fails, ``{'video': th.zeros(1), 'input': vid_path,
            'info': {}}`` is returned instead.
        """
        video_path = self.vid_path
        if not os.path.isfile(video_path):
            return {'video': th.zeros(1), 'input': video_path}

        try:
            info = self._get_video_info(video_path)
            h, w = info["height"], info["width"]
        except Exception:
            print('ffprobe failed at: {}'.format(video_path))
            return {'video': th.zeros(1), 'input': video_path, 'info': {}}

        try:
            height, width = self._get_output_dim(h, w)

            # Clips shorter than roughly one sampling period are sampled at
            # a higher rate than requested.
            try:
                duration = info["duration"]
                fps = self.framerate
                if duration > 0 and duration < 1/fps+0.1:
                    fps = 2/max(int(duration), 1)
                    print(duration, fps)
            except Exception:
                fps = self.framerate

            cmd = (
                ffmpeg
                .input(video_path)
                .filter('fps', fps=fps)
                .filter('scale', width, height)
            )

            # A 2-tuple size already fixes the output dimensions exactly, so
            # a square center crop only applies to a scalar size. (Cropping
            # with a tuple size used to raise and return the zero
            # placeholder.)
            fixed_dims = isinstance(self.size, tuple) and len(self.size) == 2
            crop_to_square = self.centercrop and not fixed_dims
            if crop_to_square:
                x = int((width - self.size) / 2.0)
                y = int((height - self.size) / 2.0)
                cmd = cmd.crop(x, y, self.size, self.size)

            out, _ = (
                cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24')
                .run(capture_stdout=True, quiet=True)
            )
            if crop_to_square:
                height, width = self.size, self.size

            # Raw rgb24 bytes -> (frames, H, W, 3) -> (frames, 3, H, W).
            video = np.frombuffer(out, np.uint8).reshape(
                [-1, height, width, 3])
            video = th.from_numpy(video.astype('float32'))
            video = video.permute(0, 3, 1, 2)
        except Exception:
            # Exception rather than a bare except, so KeyboardInterrupt and
            # SystemExit still propagate; the failure is reported instead of
            # being swallowed silently.
            print('ffmpeg failed at: {}'.format(video_path))
            return {'video': th.zeros(1), 'input': video_path, 'info': {}}

        return {'video': video, 'input': video_path}