import os
import json
import decord
decord.bridge.set_bridge('torch')

import torch
from torch.utils.data import Dataset
import torchvision
import torchvision.transforms as T

from itertools import islice
from glob import glob
from PIL import Image
from einops import rearrange, repeat


def read_caption_file(caption_file):
    with open(caption_file, 'r', encoding="utf8") as t:
        return t.read()


def get_text_prompt(
    text_prompt: str = '',
    fallback_prompt: str = '',
    file_path: str = '',
    ext_types=['.mp4'],
    use_caption=False
):
    try:
        if use_caption:
            if len(text_prompt) > 1:
                return text_prompt

            caption_file = ''

            # Use caption on a per-video basis (one caption PER video):
            # look for a .txt file with the same stem as the video file.
            for ext in ext_types:
                maybe_file = file_path.replace(ext, '.txt')

                # str.endswith requires a tuple, not a list. If replace()
                # was a no-op, maybe_file still ends with a video
                # extension, so skip it.
                if maybe_file.endswith(tuple(ext_types)):
                    continue

                if os.path.exists(maybe_file):
                    caption_file = maybe_file
                    break

            if os.path.exists(caption_file):
                return read_caption_file(caption_file)

            # Return the fallback prompt if no caption file was found.
            return fallback_prompt

        return text_prompt
    except Exception:
        print(f"Couldn't read prompt caption for {file_path}. Using fallback.")
        return fallback_prompt


def get_video_frames(vr, start_idx, sample_rate=1, max_frames=24):
    max_range = len(vr)

    # Clamp start_idx to [0, max_range] by taking the middle element of
    # the sorted triple.
    frame_number = sorted((0, start_idx, max_range))[1]

    frame_range = range(frame_number, max_range, sample_rate)
    frame_range_indices = list(frame_range)[:max_frames]

    return frame_range_indices


def get_prompt_ids(prompt, tokenizer):
    prompt_ids = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids

    return prompt_ids


def process_video(vid_path, use_bucketing, w, h, get_frame_buckets, get_frame_batch):
    if use_bucketing:
        vr = decord.VideoReader(vid_path)
        resize = get_frame_buckets(vr)
        video = get_frame_batch(vr, resize=resize)
    else:
        vr = decord.VideoReader(vid_path, width=w, height=h)
        video = get_frame_batch(vr)

    return video, vr


def min_res(size, min_size):
    # Enforce the minimum resolution.
    return max(size, min_size)


def up_down_bucket(m_size, in_size, direction):
    if direction == 'down':
        return abs(int(m_size - in_size))
    if direction == 'up':
        return abs(int(m_size + in_size))


def get_bucket_sizes(size, direction: str, min_size):
    multipliers = [64, 128]
    for i, m in enumerate(multipliers):
        res = up_down_bucket(m, size, direction)
        multipliers[i] = min_res(res, min_size=min_size)
    return multipliers


def closest_bucket(m_size, size, direction, min_size):
    lst = get_bucket_sizes(m_size, direction, min_size)
    return lst[min(range(len(lst)), key=lambda i: abs(lst[i] - size))]


def resolve_bucket(i, h, w):
    # Rescale dimension i by the aspect ratio w / h.
    return i / (h / w)


def sensible_buckets(m_width, m_height, w, h, min_size=192):
    # Keep the longer side at the maximum size and rescale the shorter
    # side to preserve aspect ratio, snapped to the closest bucket.
    if h > w:
        w = resolve_bucket(m_width, h, w)
        w = closest_bucket(m_width, w, 'down', min_size=min_size)
        return w, m_height

    if h < w:
        h = resolve_bucket(m_height, w, h)
        h = closest_bucket(m_height, h, 'down', min_size=min_size)
        return m_width, h

    return m_width, m_height
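
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the video path below is a hypothetical
# placeholder): how the frame-sampling and bucketing helpers fit together.
# Worked example for sensible_buckets(384, 384, w=1280, h=720):
#   h < w, so the height is rescaled: 384 / (1280 / 720) = 216, then
#   snapped to the closest 'down' bucket of 384 (candidates [320, 256]
#   from the 64/128 multipliers), giving (384, 256).
if __name__ == '__main__':
    demo_path = 'example.mp4'  # hypothetical placeholder file

    if os.path.exists(demo_path):
        vr = decord.VideoReader(demo_path)
        h, w, _ = vr[0].shape  # torch bridge yields HWC uint8 frame tensors

        # Pick a training resolution that respects the clip's aspect ratio.
        new_w, new_h = sensible_buckets(384, 384, w, h)

        # Sample every 2nd frame from the start, capped at 16 frames.
        indices = get_video_frames(vr, start_idx=0, sample_rate=2, max_frames=16)
        frames = vr.get_batch(indices)  # (F, H, W, C)

        print(f'{w}x{h} -> bucketed {new_w}x{new_h}, sampled {len(indices)} frames')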