import os.path as osp
import random
from glob import glob
from torchvision import transforms
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F
from torchvision.transforms import Lambda
from ....dataset.transform import ToTensorVideo, CenterCropVideo
from ....utils.dataset_utils import DecordInit


def TemporalRandomCrop(total_frames, size):
    """
    Performs a random temporal crop on a video sequence, selecting a contiguous
    window of up to `size` frames.

    Parameters:
    - total_frames (int): The total number of frames in the video sequence.
    - size (int): The length of the frame window to be cropped.

    Returns:
    - (int, int): A tuple containing two integers. The first is the starting frame
      index of the cropped window and the second is its ending frame index
      (exclusive); the window is shorter than `size` only when the video itself
      has fewer than `size` frames.
    """
rand_end = max(0, total_frames - size - 1)
begin_index = random.randint(0, rand_end)
end_index = min(begin_index + size, total_frames)
return begin_index, end_index
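
# A minimal usage sketch (the values are illustrative, not taken from the pipeline):
# pairing the crop with np.linspace yields evenly spaced frame indices, which is
# exactly how `decord_read` below consumes it.
#
#   start, end = TemporalRandomCrop(total_frames=300, size=64)
#   indices = np.linspace(start, end - 1, num=16, dtype=int)  # 16 frames inside the window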


def resize(x, resolution):
    # Resize so the shorter spatial side becomes min(2 * resolution, H, W),
    # i.e. scale toward twice the requested resolution but never upscale,
    # preserving the aspect ratio.
height, width = x.shape[-2:]
resolution = min(2 * resolution, height, width)
aspect_ratio = width / height
if width <= height:
new_width = resolution
new_height = int(resolution / aspect_ratio)
else:
new_height = resolution
new_width = int(resolution * aspect_ratio)
resized_x = F.interpolate(x, size=(new_height, new_width), mode='bilinear', align_corners=True, antialias=True)
return resized_x
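
# Shape-only sketch of `resize` (the Lambda call is commented out in the transform
# below, but this is what it would do); the tensor sizes are illustrative:
#
#   frames = torch.rand(16, 3, 480, 854)    # T C H W
#   small = resize(frames, resolution=128)
#   # target short side = min(2 * 128, 480, 854) = 256, so `small` is
#   # 16 x 3 x 256 x W', with W' chosen to preserve the 854/480 aspect ratio.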


class VideoDataset(data.Dataset):
    """ Generic dataset for video files stored in folders.
    Returns CTHW videos scaled to the range [-1, 1]; batching with a DataLoader
    yields BCTHW tensors. """
    video_exts = ['avi', 'mp4', 'webm']

def __init__(self, video_folder, sequence_length, image_folder=None, train=True, resolution=64, sample_rate=1, dynamic_sample=True):
self.train = train
self.sequence_length = sequence_length
self.sample_rate = sample_rate
self.resolution = resolution
self.v_decoder = DecordInit()
self.video_folder = video_folder
self.dynamic_sample = dynamic_sample
self.transform = transforms.Compose([
ToTensorVideo(),
# Lambda(lambda x: resize(x, self.resolution)),
CenterCropVideo(self.resolution),
Lambda(lambda x: 2.0 * x - 1.0)
])
print('Building datasets...')
self.samples = self._make_dataset()

    def _make_dataset(self):
        # Recursively collect every file under `video_folder` whose extension is
        # listed in `video_exts`.
        samples = []
        samples += sum([glob(osp.join(self.video_folder, '**', f'*.{ext}'), recursive=True)
                        for ext in self.video_exts], [])
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        video_path = self.samples[idx]
        try:
            video = self.decord_read(video_path)
            video = self.transform(video)  # T C H W, center-cropped and scaled to [-1, 1]
            video = video.transpose(0, 1)  # T C H W -> C T H W
            return dict(video=video, label="")
        except Exception as e:
            # Corrupt or unreadable file: log it and fall back to a random sample.
            print(f'Error with {e}, {video_path}')
            return self.__getitem__(random.randint(0, self.__len__() - 1))

    def decord_read(self, path):
        # Decode the video with decord and return `sequence_length` uniformly
        # spaced frames as a (T, C, H, W) tensor.
        decord_vr = self.v_decoder(path)
        total_frames = len(decord_vr)
        # Sample a temporal window, optionally with a random stride.
        if self.dynamic_sample:
            sample_rate = random.randint(1, self.sample_rate)
        else:
            sample_rate = self.sample_rate
        size = self.sequence_length * sample_rate
        start_frame_ind, end_frame_ind = TemporalRandomCrop(total_frames, size)
        # assert end_frame_ind - start_frame_ind >= self.num_frames
        frame_indices = np.linspace(start_frame_ind, end_frame_ind - 1, self.sequence_length, dtype=int)
        video_data = decord_vr.get_batch(frame_indices).asnumpy()
        video_data = torch.from_numpy(video_data)
        video_data = video_data.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)
        return video_data
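

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original pipeline: the folder path and
    # hyper-parameters below are placeholders, and running it requires the
    # package-relative transform/decoder imports above to resolve.
    dataset = VideoDataset(
        video_folder='path/to/videos',
        sequence_length=16,
        resolution=256,
        sample_rate=2,
    )
    loader = data.DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)
    batch = next(iter(loader))
    print(batch['video'].shape)  # expected: (B, C, T, H, W) = (2, 3, 16, 256, 256)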