Spaces:
Running
on
Zero
Running
on
Zero
from utils.dataset_utils import * | |
class SingleVideoDataset(Dataset):
    """Dataset built from a single video file split into chunks of frame indices.

    The video's frame indices ``0, frame_step, 2*frame_step, ...`` are grouped
    into consecutive chunks of ``n_sample_frames`` indices (the final chunk may
    be shorter). Item ``i`` decodes chunk ``i`` and pairs it with the same
    ``single_video_prompt``.

    Args:
        tokenizer: Passed to ``get_prompt_ids`` to tokenize the prompt.
        width: Target width, forwarded to ``process_video`` and
            ``sensible_buckets``.
        height: Target height, forwarded to ``process_video`` and
            ``sensible_buckets``.
        n_sample_frames: Number of frame indices per chunk.
        frame_step: Stride between sampled frame indices.
        single_video_path: Path of the video; must end with one of
            ``vid_types``.
        single_video_prompt: Text prompt returned with every item.
        use_caption: Accepted for interface compatibility; not used here.
        use_bucketing: Forwarded to ``process_video``.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        tokenizer=None,
        width: int = 256,
        height: int = 256,
        n_sample_frames: int = 4,
        frame_step: int = 1,
        single_video_path: str = "",
        single_video_prompt: str = "",
        use_caption: bool = False,
        use_bucketing: bool = False,
        **kwargs,
    ):
        self.tokenizer = tokenizer
        self.use_bucketing = use_bucketing
        # Chunks (tuples) of frame indices; filled by create_video_chunks().
        self.frames = []
        # Chunk to decode on the next get_frame_batch() call. get_frame_batch()
        # has no index parameter, so single_video_batch() passes the item
        # index through this attribute.
        self.index = 1
        self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg")
        self.n_sample_frames = n_sample_frames
        self.frame_step = frame_step
        self.single_video_path = single_video_path
        self.single_video_prompt = single_video_prompt
        self.width = width
        self.height = height

    def create_video_chunks(self):
        """Split the video's frame indices into chunks and store them.

        Opens ``single_video_path`` with ``decord.VideoReader``, takes every
        ``frame_step``-th frame index and groups the indices into tuples of
        ``n_sample_frames``.

        Returns:
            The list of index tuples (also stored in ``self.frames``).
        """
        vr = decord.VideoReader(self.single_video_path)
        vr_range = range(0, len(vr), self.frame_step)
        self.frames = list(self.chunk(vr_range, self.n_sample_frames))
        return self.frames

    def chunk(self, it, size):
        """Return an iterator of consecutive tuples of up to ``size`` items.

        The last tuple may be shorter than ``size``; iteration stops when the
        underlying iterator is exhausted (an empty tuple is the sentinel).
        """
        it = iter(it)
        return iter(lambda: tuple(islice(it, size)), ())

    def get_frame_batch(self, vr, resize=None):
        """Decode the chunk at ``self.index`` and return it channels-first.

        Args:
            vr: Video reader exposing ``get_batch(indices)``.
            resize: Optional callable applied to the rearranged frames.

        Returns:
            The frames rearranged from ``f h w c`` to ``f c h w`` (then
            resized, if ``resize`` is given).
        """
        # Build the chunks lazily: __len__() normally does this, but an item
        # may be requested on an instance where __len__() was never called.
        if not self.frames:
            self.create_video_chunks()
        frames = vr.get_batch(self.frames[self.index])
        # decord may hand back its own NDArray; convert it to a torch tensor.
        if isinstance(frames, decord.ndarray.NDArray):
            frames = torch.from_numpy(frames.asnumpy())
        video = rearrange(frames, "f h w c -> f c h w")
        if resize is not None:
            video = resize(video)
        return video

    def get_frame_buckets(self, vr):
        """Build a resize transform sized from the first frame of ``vr``.

        The output size comes from ``sensible_buckets``, given the configured
        width/height and the first frame's width/height (its exact bucketing
        rule is defined elsewhere).
        """
        h, w, _ = vr[0].shape
        width, height = sensible_buckets(self.width, self.height, w, h)
        resize = T.transforms.Resize((height, width), antialias=True)
        return resize

    def process_video_wrapper(self, vid_path):
        """Run ``process_video`` with this dataset's settings and callbacks.

        Returns:
            The ``(video, vr)`` pair produced by ``process_video``.
        """
        video, vr = process_video(
            vid_path,
            self.use_bucketing,
            self.width,
            self.height,
            self.get_frame_buckets,
            self.get_frame_batch,
        )
        return video, vr

    def single_video_batch(self, index):
        """Load chunk ``index`` of the video together with its prompt.

        Returns:
            ``(video, prompt, prompt_ids)``.

        Raises:
            ValueError: If ``single_video_path`` does not end with one of
                ``vid_types``.
        """
        train_data = self.single_video_path
        self.index = index
        if not train_data.endswith(self.vid_types):
            raise ValueError(f"Single video is not a video type. Types: {self.vid_types}")
        video, _ = self.process_video_wrapper(train_data)
        prompt = self.single_video_prompt
        prompt_ids = get_prompt_ids(prompt, self.tokenizer)
        return video, prompt, prompt_ids

    @staticmethod
    def __getname__():
        """Return the identifier stored under the ``'dataset'`` key of items."""
        return 'single_video'

    def __len__(self):
        # Re-reads the video and rebuilds self.frames on every call.
        return len(self.create_video_chunks())

    def __getitem__(self, index):
        """Return chunk ``index`` as a training example dict."""
        video, prompt, prompt_ids = self.single_video_batch(index)
        example = {
            # Maps values in [0, 255] to [-1, 1]; assumes 8-bit pixel values
            # from the decoder — TODO confirm.
            "pixel_values": (video / 127.5 - 1.0),
            "prompt_ids": prompt_ids[0],
            "text_prompt": prompt,
            'dataset': self.__getname__()
        }
        return example