Virtual-Try-On / densepose /data /video /video_keyframe_dataset.py
IDM-VTON
update IDM-VTON Demo
938e515
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
import csv
import logging
import numpy as np
from typing import Any, Callable, Dict, List, Optional, Union
import av
import torch
from torch.utils.data.dataset import Dataset
from detectron2.utils.file_io import PathManager
from ..utils import maybe_prepend_base_path
from .frame_selector import FrameSelector, FrameTsList
# List of decoded PyAV video frames (the return type of `read_keyframes`).
FrameList = List[av.frame.Frame] # pyre-ignore[16]
# Callable applied by `VideoKeyframeDataset` to the stacked frames tensor
# (float, [N, 3, H, W], BGR channel order) before it is returned.
FrameTransform = Callable[[torch.Tensor], torch.Tensor]
def list_keyframes(video_fpath: str, video_stream_idx: int = 0) -> FrameTsList:
    """
    Traverses all keyframes of a video file. Returns a list of keyframe
    timestamps. Timestamps are counts in timebase units.

    Keyframes are discovered by repeatedly seeking forward (keyframes only)
    to the timestamp following the last one found, and inspecting the first
    packet demuxed from the selected video stream.

    Args:
        video_fpath (str): Video file path
        video_stream_idx (int): Video stream index (default: 0)
    Returns:
        List[int]: list of keyframe timestamps (timestamp is a count in timebase
            units). The keyframes collected so far are returned when an AV seek
            error occurs (e.g. the video length is exceeded) or the stream ends.
            An empty list is returned if the file cannot be opened, if seeking
            fails with an OS error, or after two consecutive backward seeks.
    """
    logger = logging.getLogger(__name__)
    try:
        with PathManager.open(video_fpath, "rb") as io:
            container = av.open(io, mode="r")
            # Close the container on every exit path: normal end, early returns
            # and exceptions propagating out of the loop.
            try:
                stream = container.streams.video[video_stream_idx]
                keyframes = []
                pts = -1
                # Note: even though we request forward seeks for keyframes, sometimes
                # a keyframe in backwards direction is returned. We introduce tolerance
                # as a max count of ignored backward seeks
                tolerance_backward_seeks = 2
                while True:
                    try:
                        container.seek(pts + 1, backward=False, any_frame=False, stream=stream)
                    except av.AVError as e:
                        # the exception occurs when the video length is exceeded,
                        # we then return whatever data we've already collected
                        logger.debug(
                            f"List keyframes: Error seeking video file {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts + 1}, AV error: {e}"
                        )
                        return keyframes
                    except OSError as e:
                        logger.warning(
                            f"List keyframes: Error seeking video file {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts + 1}, OS error: {e}"
                        )
                        return []
                    try:
                        packet = next(container.demux(video=video_stream_idx))
                    except StopIteration:
                        # Nothing left to demux after the seek: treat as end of stream.
                        # Letting StopIteration escape could silently end an outer
                        # iterator (e.g. a DataLoader) instead of raising an error.
                        return keyframes
                    if packet.pts is not None and packet.pts <= pts:
                        logger.warning(
                            f"Video file {video_fpath}, stream {video_stream_idx}: "
                            f"bad seek for packet {pts + 1} (got packet {packet.pts}), "
                            f"tolerance {tolerance_backward_seeks}."
                        )
                        tolerance_backward_seeks -= 1
                        if tolerance_backward_seeks == 0:
                            return []
                        # Skip past the problematic timestamp and try again
                        pts += 1
                        continue
                    tolerance_backward_seeks = 2
                    pts = packet.pts
                    # A packet without a timestamp marks the end of the stream
                    if pts is None:
                        return keyframes
                    if packet.is_keyframe:
                        keyframes.append(pts)
            finally:
                container.close()
    except OSError as e:
        logger.warning(
            f"List keyframes: Error opening video file container {video_fpath}, " f"OS error: {e}"
        )
    except RuntimeError as e:
        logger.warning(
            f"List keyframes: Error opening video file container {video_fpath}, "
            f"Runtime error: {e}"
        )
    return []
def read_keyframes(
    video_fpath: str, keyframes: FrameTsList, video_stream_idx: int = 0
) -> FrameList:  # pyre-ignore[11]
    """
    Reads keyframe data from a video file.

    For each timestamp the container is seeked to it (keyframes only,
    `any_frame=False`) and the first frame decoded from the selected video
    stream is kept. Reading stops at the first seek/decode failure, in which
    case the frames decoded so far are returned.

    Args:
        video_fpath (str): Video file path
        keyframes (List[int]): List of keyframe timestamps (as counts in
            timebase units to be used in container seek operations)
        video_stream_idx (int): Video stream index (default: 0)
    Returns:
        List[Frame]: list of frames that correspond to the specified timestamps;
            an empty list if the video file container cannot be opened
    """
    logger = logging.getLogger(__name__)
    try:
        with PathManager.open(video_fpath, "rb") as io:
            container = av.open(io)
            # Close the container on every exit path, including unexpected errors
            try:
                stream = container.streams.video[video_stream_idx]
                frames = []
                for pts in keyframes:
                    try:
                        container.seek(pts, any_frame=False, stream=stream)
                        # Decode from the same stream that was used for seeking,
                        # not always from the first video stream.
                        frame = next(container.decode(video=video_stream_idx))
                    except av.AVError as e:
                        logger.warning(
                            f"Read keyframes: Error seeking video file {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts}, AV error: {e}"
                        )
                        return frames
                    except OSError as e:
                        logger.warning(
                            f"Read keyframes: Error seeking video file {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts}, OS error: {e}"
                        )
                        return frames
                    except StopIteration:
                        logger.warning(
                            f"Read keyframes: Error decoding frame from {video_fpath}, "
                            f"video stream {video_stream_idx}, pts {pts}"
                        )
                        return frames
                    frames.append(frame)
                return frames
            finally:
                container.close()
    except OSError as e:
        logger.warning(
            f"Read keyframes: Error opening video file container {video_fpath}, OS error: {e}"
        )
    except RuntimeError as e:
        logger.warning(
            f"Read keyframes: Error opening video file container {video_fpath}, Runtime error: {e}"
        )
    return []
def video_list_from_file(video_list_fpath: str, base_path: Optional[str] = None):
    """
    Build a list of video file paths from a plain text file.

    Every line of the file yields exactly one entry, in file order: the line
    is stripped of surrounding whitespace and passed, together with
    `base_path`, through `maybe_prepend_base_path`. Blank lines are not
    skipped, so an entry's index always equals its (0-based) line number.

    Args:
        video_list_fpath (str): path to a plain text file with the list of videos
        base_path (str): base path for entries from the video list (default: None)
    Returns:
        list: one path per line, as produced by `maybe_prepend_base_path`
    """
    with PathManager.open(video_list_fpath, "r") as list_file:
        return [
            maybe_prepend_base_path(base_path, str(entry.strip()))
            for entry in list_file
        ]
def read_keyframe_helper_data(fpath: str):
    """
    Read keyframe data from a file in CSV format: the header should contain
    "video_id" and "keyframes" fields. Value specifications are:
      video_id: int
      keyframes: list(int)
    Example of contents:
      video_id,keyframes
      2,"[1,11,21,31,41,51,61,71,81]"

    Blank rows are skipped. Whitespace around the bracketed list and around
    individual values is tolerated (e.g. " [1, 11] " or "[ ]").

    If the file cannot be read or parsed (missing header field, malformed
    value, duplicate video ID, ...), a warning is logged and the entries
    parsed before the error are returned.

    Args:
        fpath (str): File containing keyframe data

    Return:
        video_id_to_keyframes (dict: int -> list(int)): for a given video ID it
            contains a list of keyframes for that video
    """
    video_id_to_keyframes = {}
    try:
        with PathManager.open(fpath, "r") as io:
            csv_reader = csv.reader(io)
            header = next(csv_reader)
            video_id_idx = header.index("video_id")
            keyframes_idx = header.index("keyframes")
            for row in csv_reader:
                if not row:
                    # csv.reader yields an empty list for a blank line
                    continue
                video_id = int(row[video_id_idx])
                if video_id in video_id_to_keyframes:
                    # Explicit raise (not assert) so the check survives `python -O`;
                    # handled by the except clause below like any other parse error.
                    raise ValueError(
                        f"Duplicate keyframes entry for video {video_id} in {fpath}"
                    )
                # Drop the surrounding brackets, e.g. "[1,11,21]" -> "1,11,21"
                keyframes_str = row[keyframes_idx].strip()[1:-1].strip()
                video_id_to_keyframes[video_id] = (
                    [int(v) for v in keyframes_str.split(",")] if keyframes_str else []
                )
    except Exception as e:
        logger = logging.getLogger(__name__)
        logger.warning(f"Error reading keyframe helper data from {fpath}: {e}")
    return video_id_to_keyframes
class VideoKeyframeDataset(Dataset):
    """
    Dataset that provides keyframes for a set of videos.

    Each item corresponds to one video of the video list; see `__getitem__`
    for the format of an item.
    """

    # Returned as "images" when no keyframes / frames could be obtained
    _EMPTY_FRAMES = torch.empty((0, 3, 1, 1))

    def __init__(
        self,
        video_list: List[str],
        category_list: Union[str, List[str], None] = None,
        frame_selector: Optional[FrameSelector] = None,
        transform: Optional[FrameTransform] = None,
        keyframe_helper_fpath: Optional[str] = None,
    ):
        """
        Dataset constructor

        Args:
            video_list (List[str]): list of paths to video files
            category_list (Union[str, List[str], None]): list of animal categories for each
                video file. If it is a string, or None, this applies to all videos
            frame_selector (Callable: KeyFrameList -> KeyFrameList):
                selects keyframes to process, keyframes are given by
                packet timestamps in timebase counts. If None, all keyframes
                are selected (default: None)
            transform (Callable: torch.Tensor -> torch.Tensor):
                transforms a batch of images (float tensors of size [B, 3, H, W]
                in BGR channel order) and is expected to return a tensor of the
                same size. If None, no transform is applied (default: None)
            keyframe_helper_fpath (Optional[str]): path to a CSV file with
                precomputed keyframes (see `read_keyframe_helper_data`). Its
                video IDs are indices into `video_list`; for videos present in
                it, keyframes are taken from the file instead of being listed
                from the video itself (default: None)
        Raises:
            AssertionError: if `category_list` is a list whose length differs
                from the length of `video_list`
        """
        if isinstance(category_list, list):
            self.category_list = category_list
        else:
            # A single category (or None) is shared by all videos
            self.category_list = [category_list] * len(video_list)
        assert len(video_list) == len(
            self.category_list
        ), "length of video and category lists must be equal"
        self.video_list = video_list
        self.frame_selector = frame_selector
        self.transform = transform
        self.keyframe_helper_data = (
            read_keyframe_helper_data(keyframe_helper_fpath)
            if keyframe_helper_fpath is not None
            else None
        )

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """
        Gets selected keyframes from a given video

        Args:
            idx (int): video index in the video list file
        Returns:
            A dictionary containing two keys:
                images (torch.Tensor): float tensor of size [N, 3, H, W] with
                    channels in BGR order (or of size defined by the transform)
                    that contains keyframes data; a tensor of size [0, 3, 1, 1]
                    if no frames could be obtained
                categories (List[str]): a one-element list with the category of
                    the video, or an empty list if no frames could be obtained
        """
        categories = [self.category_list[idx]]
        fpath = self.video_list[idx]
        # Prefer precomputed keyframes (keyed by video index) over scanning the video
        keyframes = (
            list_keyframes(fpath)
            if self.keyframe_helper_data is None or idx not in self.keyframe_helper_data
            else self.keyframe_helper_data[idx]
        )
        transform = self.transform
        frame_selector = self.frame_selector
        if not keyframes:
            return {"images": self._EMPTY_FRAMES, "categories": []}
        if frame_selector is not None:
            keyframes = frame_selector(keyframes)
        frames = read_keyframes(fpath, keyframes)
        if not frames:
            return {"images": self._EMPTY_FRAMES, "categories": []}
        frames = np.stack([frame.to_rgb().to_ndarray() for frame in frames])
        frames = torch.as_tensor(frames, device=torch.device("cpu"))
        frames = frames[..., [2, 1, 0]]  # RGB -> BGR
        frames = frames.permute(0, 3, 1, 2).float()  # NHWC -> NCHW
        if transform is not None:
            frames = transform(frames)
        return {"images": frames, "categories": categories}

    def __len__(self):
        return len(self.video_list)