# MotionInversion / dataset / single_video_dataset.py
# (ziyangmai, commit 113884e: "page demo")
from utils.dataset_utils import *
class SingleVideoDataset(Dataset):
    """Dataset serving consecutive frame chunks of a single video with one fixed prompt.

    Frame indices ``0, frame_step, 2 * frame_step, ...`` of the video at
    ``single_video_path`` are grouped into tuples of ``n_sample_frames``; each
    dataset item is one such chunk, returned together with ``single_video_prompt``.

    NOTE(review): the final chunk holds fewer than ``n_sample_frames`` frames
    whenever the sampled frame count is not a multiple of ``n_sample_frames``;
    collating such an item with full-size ones (``batch_size > 1``) would see
    mismatched frame counts.
    """

    def __init__(
        self,
        tokenizer=None,
        width: int = 256,
        height: int = 256,
        n_sample_frames: int = 4,
        frame_step: int = 1,
        single_video_path: str = "",
        single_video_prompt: str = "",
        use_caption: bool = False,
        use_bucketing: bool = False,
        **kwargs,
    ):
        """Store configuration only; no video is opened here.

        Args:
            tokenizer: forwarded to ``get_prompt_ids`` to encode the prompt.
            width: target width, forwarded to ``process_video`` / ``sensible_buckets``.
            height: target height, forwarded to ``process_video`` / ``sensible_buckets``.
            n_sample_frames: number of frame indices per chunk (one chunk per item).
            frame_step: stride between sampled frame indices.
            single_video_path: path of the video file to read.
            single_video_prompt: text prompt returned with every item.
            use_caption: accepted for interface compatibility; not used by this class.
            use_bucketing: forwarded to ``process_video``.
            **kwargs: accepted and ignored.
        """
        self.tokenizer = tokenizer
        self.use_bucketing = use_bucketing
        # Chunks (tuples) of frame indices; built lazily on first use.
        self.frames = []
        # Chunk selected for the current item. Set by single_video_batch() and read
        # by get_frame_batch(), which is handed to process_video as a callback.
        self.index = 1
        self.vid_types = (".mp4", ".avi", ".mov", ".webm", ".flv", ".mjpeg")
        self.n_sample_frames = n_sample_frames
        self.frame_step = frame_step
        self.single_video_path = single_video_path
        self.single_video_prompt = single_video_prompt
        self.width = width
        self.height = height

    def create_video_chunks(self):
        """Open the video, rebuild the index chunks, store them in ``self.frames`` and return them."""
        vr = decord.VideoReader(self.single_video_path)
        self.frames = self._build_chunks(len(vr))
        return self.frames

    def _build_chunks(self, num_frames):
        """Group indices ``range(0, num_frames, frame_step)`` into tuples of ``n_sample_frames``."""
        sampled = range(0, num_frames, self.frame_step)
        return list(self.chunk(sampled, self.n_sample_frames))

    def chunk(self, it, size):
        """Yield successive tuples of up to ``size`` items from ``it``.

        The last tuple may be shorter; iteration stops at the first empty tuple.
        """
        it = iter(it)
        return iter(lambda: tuple(islice(it, size)), ())

    def get_frame_batch(self, vr, resize=None):
        """Read the frames of chunk ``self.index`` from reader ``vr``.

        Returns the frames rearranged from ``(f, h, w, c)`` to ``(f, c, h, w)``,
        passed through ``resize`` when one is given.
        """
        # Populate the chunks from this reader if nothing has built them yet
        # (e.g. __getitem__ called before __len__), instead of failing with IndexError.
        # Assumes vr supports len(), as the decord.VideoReader used elsewhere does.
        if not self.frames:
            self.frames = self._build_chunks(len(vr))
        frames = vr.get_batch(self.frames[self.index])
        # decord may hand back its own NDArray type; convert it to a torch tensor.
        if isinstance(frames, decord.ndarray.NDArray):
            frames = torch.from_numpy(frames.asnumpy())
        video = rearrange(frames, "f h w c -> f c h w")
        if resize is not None:
            video = resize(video)
        return video

    def get_frame_buckets(self, vr):
        """Return a Resize transform to the (height, width) chosen by ``sensible_buckets``.

        The source frame size is taken from the first frame of ``vr``.
        """
        h, w, _ = vr[0].shape
        width, height = sensible_buckets(self.width, self.height, w, h)
        return T.transforms.Resize((height, width), antialias=True)

    def process_video_wrapper(self, vid_path):
        """Call ``process_video`` with this dataset's settings and frame callbacks.

        Returns ``process_video``'s ``(video, vr)`` pair unchanged.
        """
        video, vr = process_video(
            vid_path,
            self.use_bucketing,
            self.width,
            self.height,
            self.get_frame_buckets,
            self.get_frame_batch,
        )
        return video, vr

    def single_video_batch(self, index):
        """Load chunk ``index`` of the video and encode the prompt.

        Returns:
            ``(video, prompt, prompt_ids)``.

        Raises:
            ValueError: if ``single_video_path`` does not end with one of ``vid_types``.
        """
        train_data = self.single_video_path
        self.index = index
        # Case-insensitive so that e.g. "clip.MP4" is accepted (vid_types are lowercase).
        if not train_data.lower().endswith(self.vid_types):
            raise ValueError(f"Single video is not a video type. Types: {self.vid_types}")
        video, _ = self.process_video_wrapper(train_data)
        prompt = self.single_video_prompt
        prompt_ids = get_prompt_ids(prompt, self.tokenizer)
        return video, prompt, prompt_ids

    @staticmethod
    def __getname__():
        """Identifier stored under the ``'dataset'`` key of each example."""
        return 'single_video'

    def __len__(self):
        """Number of frame chunks; the video is opened only when no chunks exist yet."""
        if not self.frames:
            self.create_video_chunks()
        return len(self.frames)

    def __getitem__(self, index):
        """Return the example dict for chunk ``index``."""
        video, prompt, prompt_ids = self.single_video_batch(index)
        example = {
            # Maps pixel values 0..255 to -1..1 (assumes 8-bit input range).
            "pixel_values": (video / 127.5 - 1.0),
            "prompt_ids": prompt_ids[0],
            "text_prompt": prompt,
            'dataset': self.__getname__()
        }
        return example