from utils.dataset_utils import *


class SingleVideoDataset(Dataset):
    """Dataset that serves fixed-size frame chunks cut from one video file.

    The video at ``single_video_path`` is split into consecutive groups of
    ``n_sample_frames`` frame indices (taking every ``frame_step``-th frame).
    Each dataset item is one such group, decoded to a tensor, together with
    the single prompt shared by every chunk.

    NOTE: the last chunk holds fewer than ``n_sample_frames`` indices when the
    number of strided frames is not an exact multiple of ``n_sample_frames``.
    """

    def __init__(
        self,
        tokenizer=None,
        width: int = 256,
        height: int = 256,
        n_sample_frames: int = 4,
        frame_step: int = 1,
        single_video_path: str = "",
        single_video_prompt: str = "",
        use_caption: bool = False,
        use_bucketing: bool = False,
        **kwargs
    ):
        """Store the configuration; the video itself is not opened here.

        Args:
            tokenizer: Passed to ``get_prompt_ids`` to tokenize the prompt.
            width: Target frame width handed to ``process_video`` and
                ``sensible_buckets``.
            height: Target frame height handed to ``process_video`` and
                ``sensible_buckets``.
            n_sample_frames: Number of frame indices per chunk.
            frame_step: Stride between sampled frame indices.
            single_video_path: Path of the video file to read.
            single_video_prompt: Text prompt returned with every item.
            use_caption: Accepted for signature compatibility; unused here.
            use_bucketing: Forwarded to ``process_video``.
            **kwargs: Ignored; lets callers pass a shared config dict.
        """
        self.tokenizer = tokenizer
        self.use_bucketing = use_bucketing

        # Chunk table (list of frame-index tuples); filled lazily.
        self.frames = []
        # Chunk currently being read; overwritten by ``single_video_batch``.
        self.index = 1

        self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg")
        self.n_sample_frames = n_sample_frames
        self.frame_step = frame_step

        self.single_video_path = single_video_path
        self.single_video_prompt = single_video_prompt

        self.width = width
        self.height = height

    def create_video_chunks(self):
        """(Re)build the chunk table from the video and return it.

        Opens the video on every call, so it always reflects the current
        ``single_video_path``, ``frame_step`` and ``n_sample_frames``. Call
        it explicitly after changing any of those attributes.

        Returns:
            list[tuple[int, ...]]: Frame indices grouped per chunk.
        """
        vr = decord.VideoReader(self.single_video_path)
        vr_range = range(0, len(vr), self.frame_step)

        self.frames = list(self.chunk(vr_range, self.n_sample_frames))
        return self.frames

    def _ensure_chunks(self):
        """Build the chunk table once, if it has not been built yet."""
        if not self.frames:
            self.create_video_chunks()

    def chunk(self, it, size):
        """Lazily split ``it`` into tuples of at most ``size`` items.

        Returns:
            An iterator of tuples; the final tuple may be shorter.
        """
        it = iter(it)
        # Two-argument iter(): keep calling the lambda until it returns the
        # sentinel ``()``, i.e. until the source iterator is exhausted.
        return iter(lambda: tuple(islice(it, size)), ())

    def get_frame_batch(self, vr, resize=None):
        """Decode the frames of chunk ``self.index`` from ``vr``.

        Args:
            vr: Video reader exposing ``get_batch(indices)``.
            resize: Optional callable applied to the decoded tensor.

        Returns:
            The decoded frames rearranged from ``f h w c`` to ``f c h w``.

        Raises:
            IndexError: If ``self.index`` is outside the chunk table.
        """
        # Previously the table only existed after ``len(dataset)`` had been
        # called; indexing first crashed on the empty list.
        self._ensure_chunks()

        frames = vr.get_batch(self.frames[self.index])
        # get_batch may already yield a torch tensor (presumably when decord's
        # torch bridge is enabled); only native decord arrays need converting.
        if isinstance(frames, decord.ndarray.NDArray):
            frames = torch.from_numpy(frames.asnumpy())

        video = rearrange(frames, "f h w c -> f c h w")
        if resize is not None:
            video = resize(video)
        return video

    def get_frame_buckets(self, vr):
        """Create a resize transform sized by ``sensible_buckets``.

        The first frame's height and width are read from ``vr`` and passed,
        with the configured target size, to ``sensible_buckets``.

        Returns:
            A ``Resize`` transform for the computed ``(height, width)``.
        """
        h, w, _ = vr[0].shape
        width, height = sensible_buckets(self.width, self.height, w, h)
        resize = T.transforms.Resize((height, width), antialias=True)

        return resize

    def process_video_wrapper(self, vid_path):
        """Run ``process_video`` with this dataset's settings and callbacks.

        Returns:
            tuple: ``(video, vr)`` exactly as returned by ``process_video``.
        """
        video, vr = process_video(
            vid_path,
            self.use_bucketing,
            self.width,
            self.height,
            self.get_frame_buckets,
            self.get_frame_batch
        )

        return video, vr

    def single_video_batch(self, index):
        """Load chunk ``index`` plus the prompt and its token ids.

        Args:
            index: Position of the chunk in the chunk table.

        Returns:
            tuple: ``(video, prompt, prompt_ids)``.

        Raises:
            ValueError: If the path does not end with a known video suffix.
        """
        train_data = self.single_video_path
        # get_frame_batch is invoked as a callback without an index argument,
        # so the requested chunk is passed to it through instance state.
        self.index = index

        if not train_data.endswith(self.vid_types):
            raise ValueError(f"Single video is not a video type. Types: {self.vid_types}")

        video, _ = self.process_video_wrapper(train_data)

        prompt = self.single_video_prompt
        prompt_ids = get_prompt_ids(prompt, self.tokenizer)

        return video, prompt, prompt_ids

    @staticmethod
    def __getname__():
        """Return the identifier stored under the ``dataset`` key of items."""
        return 'single_video'

    def __len__(self):
        """Return the number of chunks, building the chunk table on first use."""
        # Cached: the original re-opened the video on every len() call.
        self._ensure_chunks()
        return len(self.frames)

    def __getitem__(self, index):
        """Return one training example for chunk ``index``.

        Returns:
            dict: ``pixel_values`` (frames scaled by ``x / 127.5 - 1.0``;
            maps to [-1, 1] assuming 0-255 pixel values), ``prompt_ids``
            (first row of the tokenized prompt), ``text_prompt`` and
            ``dataset``.
        """
        video, prompt, prompt_ids = self.single_video_batch(index)

        example = {
            "pixel_values": (video / 127.5 - 1.0),
            "prompt_ids": prompt_ids[0],
            "text_prompt": prompt,
            'dataset': self.__getname__()
        }

        return example