# Video dataset utilities: decord-backed frame sampling, prompt/caption
# resolution, and resolution-bucketing helpers.
# Standard library.
import os
import json
from glob import glob
from itertools import islice

# Third-party.
import decord
import torch
import torchvision
import torchvision.transforms as T
from einops import rearrange, repeat
from PIL import Image
from torch.utils.data import Dataset

# Have decord return torch tensors instead of its native NDArray type.
decord.bridge.set_bridge('torch')
def read_caption_file(caption_file):
    """Return the entire contents of *caption_file*, decoded as UTF-8."""
    with open(caption_file, encoding="utf8") as handle:
        return handle.read()
def get_text_prompt(
    text_prompt: str = '',
    fallback_prompt: str = '',
    file_path: str = '',
    ext_types=('.mp4',),
    use_caption=False
):
    """Resolve the text prompt to use for a video file.

    Resolution order when ``use_caption`` is True:
    1. a non-trivial ``text_prompt`` (more than one character);
    2. a per-video caption file (``<video>.<ext>`` -> ``<video>.txt``);
    3. ``fallback_prompt``.
    When ``use_caption`` is False, ``text_prompt`` is returned as-is.

    Fixes over the original:
    - ``str.endswith`` was called with a *list*, which raises TypeError
      (and was silently swallowed by a bare ``except``); it now gets a tuple.
    - mutable default ``['.mp4']`` replaced by a tuple.
    - bare ``except:`` narrowed to ``except Exception``.
    """
    try:
        if use_caption:
            if len(text_prompt) > 1:
                return text_prompt
            caption_file = ''
            # Use caption on per-video basis (one caption PER video).
            for ext in ext_types:
                maybe_file = file_path.replace(ext, '.txt')
                # replace() was a no-op: the path did not contain this ext.
                if maybe_file.endswith(tuple(ext_types)):
                    continue
                if os.path.exists(maybe_file):
                    caption_file = maybe_file
                    break
            if os.path.exists(caption_file):
                return read_caption_file(caption_file)
            # Return fallback prompt if no conditions are met.
            return fallback_prompt
        return text_prompt
    except Exception:
        print(f"Couldn't read prompt caption for {file_path}. Using fallback.")
        return fallback_prompt
def get_video_frames(vr, start_idx, sample_rate=1, max_frames=24):
    """Select up to *max_frames* frame indices from reader *vr*.

    *start_idx* is clamped into [0, len(vr)]; indices then step by
    *sample_rate* up to the end of the video.
    """
    total = len(vr)
    # Clamp the requested start into the valid range.
    start = min(max(start_idx, 0), total)
    return list(islice(range(start, total, sample_rate), max_frames))
def get_prompt_ids(prompt, tokenizer):
    """Tokenize *prompt*, padded/truncated to the tokenizer's max length,
    and return the resulting ``input_ids`` tensor."""
    encoded = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    )
    return encoded.input_ids
def process_video(vid_path, use_bucketing, w, h, get_frame_buckets, get_frame_batch):
    """Open *vid_path* with decord and sample a batch of frames.

    With bucketing enabled the reader keeps its native resolution and
    *get_frame_buckets* supplies the resize target; otherwise decord
    scales frames to (w, h) directly. Returns ``(video, reader)``.
    """
    if not use_bucketing:
        vr = decord.VideoReader(vid_path, width=w, height=h)
        return get_frame_batch(vr), vr
    vr = decord.VideoReader(vid_path)
    video = get_frame_batch(vr, resize=get_frame_buckets(vr))
    return video, vr
def min_res(size, min_size):
    """Clamp *size* from below to *min_size*.

    Bug fix: the original hard-coded 192 in both the comparison and the
    return value, silently ignoring the *min_size* argument that callers
    (``closest_bucket``, ``sensible_buckets``) pass through. Behavior is
    unchanged for the default floor of 192.
    """
    return max(size, min_size)
def up_down_bucket(m_size, in_size, direction):
    """Combine multiplier *m_size* with *in_size* according to *direction*:
    'down' -> |m_size - in_size|, 'up' -> |m_size + in_size|."""
    if direction == 'down':
        return abs(int(m_size - in_size))
    if direction == 'up':
        return abs(int(m_size + in_size))
    # Any other direction falls through and yields None (original behavior).
def get_bucket_sizes(size, direction: str = 'down', min_size: int = 192):
    """Return the candidate bucket sizes for *size*.

    Each of the 64/128 multipliers is offset against *size* in *direction*
    (via ``up_down_bucket``) and clamped to *min_size* (via ``min_res``).

    Fix: the original declared ``direction: 'down'`` — a string-literal
    *annotation* with no default, making the parameter effectively
    required; it now has a real default. ``min_size`` also gets the
    project-wide default of 192. All existing positional call sites
    are unaffected.
    """
    candidates = []
    for multiplier in (64, 128):
        offset = up_down_bucket(multiplier, size, direction)
        candidates.append(min_res(offset, min_size=min_size))
    return candidates
def closest_bucket(m_size, size, direction, min_size):
    """Of the bucket sizes derived from *m_size*, return the one whose
    value is nearest to *size* (first wins on ties)."""
    buckets = get_bucket_sizes(m_size, direction, min_size)
    return min(buckets, key=lambda b: abs(b - size))
def resolve_bucket(i, h, w):
    """Rescale *i* by the inverse of the aspect ratio h/w."""
    aspect = h / w
    return i / aspect
def sensible_buckets(m_width, m_height, w, h, min_size=192):
    """Map the model resolution (m_width, m_height) onto a bucketed size
    that follows the orientation of the source frame (w, h).

    Square sources keep the model resolution unchanged.
    """
    if h > w:
        # Portrait source: bucket the width, keep the model height.
        scaled_w = resolve_bucket(m_width, h, w)
        return closest_bucket(m_width, scaled_w, 'down', min_size=min_size), m_height
    if h < w:
        # Landscape source: bucket the height, keep the model width.
        scaled_h = resolve_bucket(m_height, w, h)
        return m_width, closest_bucket(m_height, scaled_h, 'down', min_size=min_size)
    return m_width, m_height