# Spaces: Running on T4
import decord | |
decord.bridge.set_bridge('torch') | |
from torch.utils.data import Dataset | |
from einops import rearrange | |
class TuneAVideoDataset(Dataset):
    """Single-video dataset that yields one clip of sampled frames and its prompt ids.

    The dataset always has length 1; every index decodes a clip from
    ``video_path``. Frames are read with ``decord.VideoReader`` at the requested
    ``width`` x ``height``. Sampling starts at ``sample_start_idx``, steps by
    ``sample_frame_rate``, and is capped at ``n_sample_frames`` frames.

    Args:
        video_path: Path of the video file handed to ``decord.VideoReader``.
        prompt: Text prompt associated with the video (stored on the instance).
        width: Frame width requested from the reader.
        height: Frame height requested from the reader.
        n_sample_frames: Maximum number of frames per clip; must be >= 1.
        sample_start_idx: Index of the first sampled frame; must be >= 0.
        sample_frame_rate: Stride between sampled frames; must be >= 1.

    Raises:
        ValueError: If a sampling parameter is out of range.
    """

    def __init__(
        self,
        video_path: str,
        prompt: str,
        width: int = 512,
        height: int = 512,
        n_sample_frames: int = 8,
        sample_start_idx: int = 0,
        sample_frame_rate: int = 1,
    ):
        # Fail fast on values that would otherwise break or misbehave at
        # __getitem__ time:
        # - a zero stride makes range() raise;
        # - a non-positive frame count yields no frames, or drops trailing
        #   frames through a negative slice;
        # - a negative start index produces negative frame indices.
        if n_sample_frames < 1:
            raise ValueError(f"n_sample_frames must be >= 1, got {n_sample_frames}")
        if sample_start_idx < 0:
            raise ValueError(f"sample_start_idx must be >= 0, got {sample_start_idx}")
        if sample_frame_rate < 1:
            raise ValueError(f"sample_frame_rate must be >= 1, got {sample_frame_rate}")

        self.video_path = video_path
        self.prompt = prompt
        # Left as None here and returned as-is by __getitem__. Presumably the
        # caller assigns tokenized prompt ids before iterating. TODO confirm.
        self.prompt_ids = None

        self.width = width
        self.height = height
        self.n_sample_frames = n_sample_frames
        self.sample_start_idx = sample_start_idx
        self.sample_frame_rate = sample_frame_rate

    def __len__(self):
        """Return 1: the dataset wraps exactly one video."""
        return 1

    def _sample_indices(self, num_frames: int) -> list:
        """Return the frame indices to decode from a video with ``num_frames`` frames.

        Indices start at ``sample_start_idx`` and step by ``sample_frame_rate``.
        At most ``n_sample_frames`` indices are returned, and fewer when the
        video is too short.

        Raises:
            ValueError: If no frame can be sampled (start index past the end).
        """
        # Slicing the range object avoids building the full index list first.
        indices = list(
            range(self.sample_start_idx, num_frames, self.sample_frame_rate)[: self.n_sample_frames]
        )
        if not indices:
            raise ValueError(
                f"cannot sample frames from {self.video_path!r}: "
                f"sample_start_idx={self.sample_start_idx} but the video has "
                f"only {num_frames} frames"
            )
        return indices

    def __getitem__(self, index):
        """Decode and return the sampled clip; ``index`` is ignored (single item).

        Returns:
            A dict with two keys:
            - ``"pixel_values"``: the sampled frames rearranged from
              ``(f, h, w, c)`` to ``(f, c, h, w)`` and mapped by
              ``x / 127.5 - 1.0``. Assuming 8-bit pixels in [0, 255], the
              result lies in [-1, 1].
            - ``"prompt_ids"``: whatever ``self.prompt_ids`` currently holds.

        Raises:
            ValueError: If no frame can be sampled from the video.
        """
        # Load the video and pick the sampled frames.
        vr = decord.VideoReader(self.video_path, width=self.width, height=self.height)
        sample_index = self._sample_indices(len(vr))
        video = vr.get_batch(sample_index)
        video = rearrange(video, "f h w c -> f c h w")

        example = {
            "pixel_values": (video / 127.5 - 1.0),
            "prompt_ids": self.prompt_ids,
        }
        return example