import os
import random
import bisect

import pandas as pd
import omegaconf
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from decord import VideoReader, cpu
import torchvision.transforms._transforms_video as transforms_video
class WebVid(Dataset):
    """
    WebVid Dataset.
    Assumes webvid data is structured as follows.
    Webvid/
        videos/
            000001_000050/      ($page_dir)
                1.mp4           (videoid.mp4)
                ...
                5000.mp4
                ...
    """
    def __init__(self,
                 meta_path,
                 data_dir,
                 subsample=None,
                 video_length=16,
                 resolution=[256, 512],
                 frame_stride=1,
                 spatial_transform=None,
                 crop_resolution=None,
                 fps_max=None,
                 load_raw_resolution=False,
                 fps_schedule=None,
                 fs_probs=None,
                 bs_per_gpu=None,
                 trigger_word='',
                 dataname='',
                 ):
        self.meta_path = meta_path
        self.data_dir = data_dir
        self.subsample = subsample
        self.video_length = video_length
        self.resolution = [resolution, resolution] if isinstance(resolution, int) else resolution
        self.frame_stride = frame_stride
        self.fps_max = fps_max
        self.load_raw_resolution = load_raw_resolution
        self.fs_probs = fs_probs
        self.trigger_word = trigger_word
        self.dataname = dataname
        self._load_metadata()

        if spatial_transform is not None:
            if spatial_transform == "random_crop":
                self.spatial_transform = transforms_video.RandomCropVideo(crop_resolution)
            elif spatial_transform == "resize_center_crop":
                assert self.resolution[0] == self.resolution[1]
                self.spatial_transform = transforms.Compose([
                    transforms.Resize(resolution),
                    transforms_video.CenterCropVideo(resolution),
                ])
            else:
                raise NotImplementedError
        else:
            self.spatial_transform = None

        self.fps_schedule = fps_schedule
        self.bs_per_gpu = bs_per_gpu
        if self.fps_schedule is not None:
            assert self.bs_per_gpu is not None
            self.counter = 0
            self.stage_idx = 0
    def _load_metadata(self):
        metadata = pd.read_csv(self.meta_path)
        if self.subsample is not None:
            metadata = metadata.sample(self.subsample, random_state=0)
        metadata['caption'] = metadata['name']
        del metadata['name']
        self.metadata = metadata
        self.metadata.dropna(inplace=True)
        # self.metadata['caption'] = self.metadata['caption'].str[:350]
    def _get_video_path(self, sample):
        if self.dataname == "loradata":
            rel_video_fp = str(sample['videoid']) + '.mp4'
            full_video_fp = os.path.join(self.data_dir, rel_video_fp)
        else:
            rel_video_fp = os.path.join(sample['page_dir'], str(sample['videoid']) + '.mp4')
            full_video_fp = os.path.join(self.data_dir, 'videos', rel_video_fp)
        return full_video_fp, rel_video_fp
    def get_fs_based_on_schedule(self, frame_strides, schedule):
        assert len(frame_strides) == len(schedule) + 1  # nstage = len_fps_schedule + 1
        global_step = self.counter // self.bs_per_gpu  # TODO: support resume.
        stage_idx = bisect.bisect(schedule, global_step)
        frame_stride = frame_strides[stage_idx]
        # log stage change
        if stage_idx != self.stage_idx:
            print(f'fps stage: {stage_idx} start ... new frame stride = {frame_stride}')
            self.stage_idx = stage_idx
        return frame_stride

    def get_fs_based_on_probs(self, frame_strides, probs):
        assert len(frame_strides) == len(probs)
        return random.choices(frame_strides, weights=probs)[0]

    def get_fs_randomly(self, frame_strides):
        return random.choice(frame_strides)
    def __getitem__(self, index):
        # resolve the frame stride: either a fixed int, or sampled from a list
        # via the fps schedule, the per-stride probabilities, or uniformly
        if isinstance(self.frame_stride, (list, omegaconf.listconfig.ListConfig)):
            if self.fps_schedule is not None:
                frame_stride = self.get_fs_based_on_schedule(self.frame_stride, self.fps_schedule)
            elif self.fs_probs is not None:
                frame_stride = self.get_fs_based_on_probs(self.frame_stride, self.fs_probs)
            else:
                frame_stride = self.get_fs_randomly(self.frame_stride)
        else:
            frame_stride = self.frame_stride
        assert isinstance(frame_stride, int), type(frame_stride)

        while True:
            index = index % len(self.metadata)
            sample = self.metadata.iloc[index]
            video_path, rel_fp = self._get_video_path(sample)
            caption = sample['caption'] + self.trigger_word

            # make reader
            try:
                if self.load_raw_resolution:
                    video_reader = VideoReader(video_path, ctx=cpu(0))
                else:
                    video_reader = VideoReader(video_path, ctx=cpu(0), width=self.resolution[1], height=self.resolution[0])
                if len(video_reader) < self.video_length:
                    print(f"video length ({len(video_reader)}) is smaller than target length ({self.video_length})")
                    index += 1
                    continue
            except Exception:
                index += 1
                print(f"Load video failed! path = {video_path}")
                continue

            # sample strided frames
            all_frames = list(range(0, len(video_reader), frame_stride))
            if len(all_frames) < self.video_length:  # recalculate the max feasible frame stride
                frame_stride = len(video_reader) // self.video_length
                assert frame_stride != 0
                all_frames = list(range(0, len(video_reader), frame_stride))

            # select a random clip
            rand_idx = random.randint(0, len(all_frames) - self.video_length)
            frame_indices = all_frames[rand_idx:rand_idx + self.video_length]
            try:
                frames = video_reader.get_batch(frame_indices)
                break
            except Exception:
                print(f"Get frames failed! path = {video_path}")
                index += 1
                continue

        assert frames.shape[0] == self.video_length, f'{len(frames)}, self.video_length={self.video_length}'
        frames = torch.tensor(frames.asnumpy()).permute(3, 0, 1, 2).float()  # [t,h,w,c] -> [c,t,h,w]

        if self.spatial_transform is not None:
            frames = self.spatial_transform(frames)
        if self.resolution is not None:
            assert (frames.shape[2] == self.resolution[0] and frames.shape[3] == self.resolution[1]), \
                f'frames={frames.shape}, self.resolution={self.resolution}'
        frames = (frames / 255 - 0.5) * 2  # normalize pixel values to [-1, 1]

        fps_ori = video_reader.get_avg_fps()
        fps_clip = fps_ori // frame_stride
        if self.fps_max is not None and fps_clip > self.fps_max:
            fps_clip = self.fps_max

        data = {'video': frames, 'caption': caption, 'path': video_path, 'fps': fps_clip, 'frame_stride': frame_stride}
        if self.fps_schedule is not None:
            self.counter += 1
        return data
    def __len__(self):
        return len(self.metadata)
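

# A minimal usage sketch, not part of the original file: it assumes a WebVid-style
# metadata CSV containing 'videoid', 'page_dir', and 'name' columns, plus a matching
# Webvid/videos/ directory. The paths below are placeholders, not real data.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = WebVid(
        meta_path="Webvid/results_2M_train.csv",  # hypothetical metadata CSV path
        data_dir="Webvid/",                       # root holding the videos/ sub-directory
        video_length=16,
        resolution=256,                           # int -> square [256, 256] resolution
        frame_stride=4,
        spatial_transform="resize_center_crop",
    )
    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)
    batch = next(iter(loader))
    # each sample's 'video' is [c, t, h, w]; default collation stacks them to [b, c, t, h, w]
    print(batch['video'].shape, batch['caption'])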