import io
import math
import os
import random
import sys
import threading
from fractions import Fraction
from functools import partial
from typing import Any, Dict, Optional, Tuple, Union

import decord
import numpy as np
import torch
import torchvision.transforms as TT
from decord import VideoReader
from torch.utils.data import Dataset
from torchvision.io import _video_opt
from torchvision.io.video import _align_audio_frames, _check_av_available, _read_from_stream, av
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import center_crop, resize

from sgm.webds import MetaDistributedWebDataset

def read_video(
    filename: str,
    start_pts: Union[float, Fraction] = 0,
    end_pts: Optional[Union[float, Fraction]] = None,
    pts_unit: str = "pts",
    output_format: str = "THWC",
) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]:
    """
    Reads a video from a file, returning both the video frames and the audio frames.

    Args:
        filename (str): path to the video file
        start_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The start presentation time of the video
        end_pts (int if pts_unit = 'pts', float / Fraction if pts_unit = 'sec', optional):
            The end presentation time
        pts_unit (str, optional): unit in which start_pts and end_pts values will be interpreted,
            either 'pts' or 'sec'. Defaults to 'pts'.
        output_format (str, optional): The format of the output video tensors. Can be either "THWC" (default) or "TCHW".

    Returns:
        vframes (Tensor[T, H, W, C] or Tensor[T, C, H, W]): the `T` video frames
        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points
        info (Dict): metadata for the video and audio. Can contain the fields video_fps (float) and audio_fps (int)
    """
    output_format = output_format.upper()
    if output_format not in ("THWC", "TCHW"):
        raise ValueError(f"output_format should be either 'THWC' or 'TCHW', got {output_format}.")

    _check_av_available()

    if end_pts is None:
        end_pts = float("inf")

    if end_pts < start_pts:
        raise ValueError(f"end_pts should be larger than start_pts, got start_pts={start_pts} and end_pts={end_pts}")
    info = {}
    video_frames = []
    audio_frames = []
    audio_timebase = _video_opt.default_timebase

    with av.open(filename, metadata_errors="ignore") as container:
        if container.streams.audio:
            audio_timebase = container.streams.audio[0].time_base
        if container.streams.video:
            video_frames = _read_from_stream(
                container,
                start_pts,
                end_pts,
                pts_unit,
                container.streams.video[0],
                {"video": 0},
            )
            video_fps = container.streams.video[0].average_rate
            # guard against potentially corrupted files
            if video_fps is not None:
                info["video_fps"] = float(video_fps)

        if container.streams.audio:
            audio_frames = _read_from_stream(
                container,
                start_pts,
                end_pts,
                pts_unit,
                container.streams.audio[0],
                {"audio": 0},
            )
            info["audio_fps"] = container.streams.audio[0].rate

    aframes_list = [frame.to_ndarray() for frame in audio_frames]
    vframes_list = [frame.to_rgb().to_ndarray() for frame in video_frames]

    if vframes_list:
        vframes = torch.as_tensor(np.stack(vframes_list))
    else:
        vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

    if aframes_list:
        aframes = np.concatenate(aframes_list, 1)
        aframes = torch.as_tensor(aframes)
        if pts_unit == "sec":
            start_pts = int(math.floor(start_pts * (1 / audio_timebase)))
            if end_pts != float("inf"):
                end_pts = int(math.ceil(end_pts * (1 / audio_timebase)))
        aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
    else:
        aframes = torch.empty((1, 0), dtype=torch.float32)

    if output_format == "TCHW":
        # [T, H, W, C] -> [T, C, H, W]
        vframes = vframes.permute(0, 3, 1, 2)

    return vframes, aframes, info
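
# Usage sketch (illustrative; "sample.mp4" is a hypothetical file, and the shapes shown
# depend on the clip). With pts_unit="sec" the start/end bounds are interpreted as seconds:
#
#   vframes, aframes, info = read_video("sample.mp4", start_pts=0, end_pts=2.0,
#                                       pts_unit="sec", output_format="TCHW")
#   vframes.shape, info.get("video_fps")  # e.g. (torch.Size([48, 3, 720, 1280]), 24.0)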

def resize_for_rectangle_crop(arr, image_size, reshape_mode="random"):
    # arr: [T, C, H, W]; image_size: (height, width)
    if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
        # wider than the target aspect ratio: match height, then crop width
        arr = resize(
            arr,
            size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
            interpolation=InterpolationMode.BICUBIC,
        )
    else:
        # taller than the target aspect ratio: match width, then crop height
        arr = resize(
            arr,
            size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
            interpolation=InterpolationMode.BICUBIC,
        )

    h, w = arr.shape[2], arr.shape[3]
    arr = arr.squeeze(0)

    delta_h = h - image_size[0]
    delta_w = w - image_size[1]

    if reshape_mode == "random" or reshape_mode == "none":
        top = np.random.randint(0, delta_h + 1)
        left = np.random.randint(0, delta_w + 1)
    elif reshape_mode == "center":
        top, left = delta_h // 2, delta_w // 2
    else:
        raise NotImplementedError
    arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
    return arr
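
# Shape sketch (assumed inputs, illustration only): the helper resizes so the frame
# covers image_size, then crops to exactly image_size:
#
#   frames = torch.zeros(9, 3, 480, 854, dtype=torch.uint8)  # hypothetical 480p clip
#   out = resize_for_rectangle_crop(frames, [480, 720], reshape_mode="center")
#   out.shape  # torch.Size([9, 3, 480, 720])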

def pad_last_frame(tensor, num_frames):
    # tensor: [T, H, W, C]
    if tensor.shape[0] < num_frames:
        pad_length = num_frames - tensor.shape[0]
        # repeat the last frame instead of zero-padding
        pad_tensor = tensor[-1:].expand(pad_length, *tensor.shape[1:])
        return torch.cat([tensor, pad_tensor], dim=0)
    else:
        return tensor[:num_frames]
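
# Padding sketch: a 3-frame clip requested at 5 frames repeats its last frame twice.
#   clip = torch.zeros(3, 2, 2, 3)
#   pad_last_frame(clip, 5).shape  # torch.Size([5, 2, 2, 3])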

def load_video(
    video_data,
    sampling="uniform",
    duration=None,
    num_frames=4,
    wanted_fps=None,
    actual_fps=None,
    skip_frms_num=0.0,
    nb_read_frames=None,
):
    decord.bridge.set_bridge("torch")
    vr = VideoReader(uri=video_data, height=-1, width=-1)
    if nb_read_frames is not None:
        ori_vlen = nb_read_frames
    else:
        ori_vlen = min(int(duration * actual_fps) - 1, len(vr))

    max_seek = int(ori_vlen - skip_frms_num - num_frames / wanted_fps * actual_fps)
    # randint is inclusive on both ends; clamp so start never runs past max_seek
    start = random.randint(int(skip_frms_num), max(int(skip_frms_num), max_seek))
    end = int(start + num_frames / wanted_fps * actual_fps)
    n_frms = num_frames

    if sampling == "uniform":
        indices = np.arange(start, end, (end - start) / n_frms).astype(int)
    else:
        raise NotImplementedError

    # get_batch -> [T, H, W, C]
    temp_frms = vr.get_batch(np.arange(start, end))
    assert temp_frms is not None
    tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
    tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]

    return pad_last_frame(tensor_frms, num_frames)
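
# Sampling arithmetic: with num_frames=4, wanted_fps=8 and actual_fps=24, the random
# window spans 4 / 8 * 24 = 12 source frames, from which 4 evenly spaced frames are kept.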

def load_video_with_timeout(*args, **kwargs):
    video_container = {}

    def target_function():
        video = load_video(*args, **kwargs)
        video_container["video"] = video

    thread = threading.Thread(target=target_function)
    thread.start()
    timeout = 20
    thread.join(timeout)

    if thread.is_alive():
        print("Loading video timed out")
        raise TimeoutError
    video = video_container.get("video", None)
    if video is None:
        raise RuntimeError("Failed to load video")
    return video.contiguous()
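
# Usage sketch (hedged; "clip.mp4" and the fps values are hypothetical):
#   frames = load_video_with_timeout("clip.mp4", duration=10.0, num_frames=16,
#                                    wanted_fps=8, actual_fps=24, skip_frms_num=3)
#   frames.shape  # torch.Size([16, H, W, 3]); raises TimeoutError after 20 s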

def process_video(
    video_path,
    image_size=None,
    duration=None,
    num_frames=4,
    wanted_fps=None,
    actual_fps=None,
    skip_frms_num=0.0,
    nb_read_frames=None,
):
    """
    video_path: str or io.BytesIO
    image_size: target (height, width) for resize-and-crop; skipped if None.
    duration: known duration in seconds, used to seek straight to the sampled start. TODO: bypass if unknown.
    num_frames: number of frames to sample.
    wanted_fps: target sampling rate in frames per second.
    skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
    """
    video = load_video_with_timeout(
        video_path,
        duration=duration,
        num_frames=num_frames,
        wanted_fps=wanted_fps,
        actual_fps=actual_fps,
        skip_frms_num=skip_frms_num,
        nb_read_frames=nb_read_frames,
    )

    # --- copy and modify the image process ---
    video = video.permute(0, 3, 1, 2)  # [T, H, W, C] -> [T, C, H, W]

    # resize
    if image_size is not None:
        video = resize_for_rectangle_crop(video, image_size, reshape_mode="center")

    return video
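
# Pipeline sketch: decode -> [T, H, W, C] -> permute -> center resize/crop -> [T, C, H, W].
#   frames = process_video("clip.mp4", image_size=[480, 720], duration=10.0,  # hypothetical values
#                          num_frames=16, wanted_fps=8, actual_fps=24)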

def process_fn_video(src, image_size, fps, num_frames, skip_frms_num=0.0, txt_key="caption"):
    while True:
        r = next(src)
        if "mp4" in r:
            video_data = r["mp4"]
        elif "avi" in r:
            video_data = r["avi"]
        else:
            print("No video data found")
            continue

        if txt_key not in r:
            txt = ""
        else:
            txt = r[txt_key]

        if isinstance(txt, bytes):
            txt = txt.decode("utf-8")
        else:
            txt = str(txt)

        duration = r.get("duration", None)
        if duration is not None:
            duration = float(duration)
        else:
            continue

        actual_fps = r.get("fps", None)
        if actual_fps is not None:
            actual_fps = float(actual_fps)
        else:
            continue

        # skip clips too short to yield num_frames at the wanted fps plus the skipped margins
        required_duration = num_frames / fps + 2 * skip_frms_num / actual_fps
        if duration < required_duration:
            continue

        try:
            frames = process_video(
                io.BytesIO(video_data),
                num_frames=num_frames,
                wanted_fps=fps,
                image_size=image_size,
                duration=duration,
                actual_fps=actual_fps,
                skip_frms_num=skip_frms_num,
            )
            frames = (frames - 127.5) / 127.5  # normalize uint8 pixels to [-1, 1]
        except Exception as e:
            print(e)
            continue

        item = {
            "mp4": frames,
            "txt": txt,
            "num_frames": num_frames,
            "fps": fps,
        }

        yield item
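
# Expected webdataset sample layout (assumed from the keys read above):
#   {"mp4": <raw video bytes>, "caption": b"a cat playing", "duration": "10.0", "fps": "24"}
# Each yielded item carries frames normalized to [-1, 1] under the "mp4" key.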

class VideoDataset(MetaDistributedWebDataset):
    def __init__(
        self,
        path,
        image_size,
        num_frames,
        fps,
        skip_frms_num=0.0,
        nshards=sys.maxsize,
        seed=1,
        meta_names=None,
        shuffle_buffer=1000,
        include_dirs=None,
        txt_key="caption",
        **kwargs,
    ):
        if seed == -1:
            seed = random.randint(0, 1000000)
        if meta_names is None:
            meta_names = []

        if path.startswith(";"):
            path, include_dirs = path.split(";", 1)
        super().__init__(
            path,
            partial(
                process_fn_video,
                num_frames=num_frames,
                image_size=image_size,
                fps=fps,
                skip_frms_num=skip_frms_num,
                txt_key=txt_key,
            ),
            seed,
            meta_names=meta_names,
            shuffle_buffer=shuffle_buffer,
            nshards=nshards,
            include_dirs=include_dirs,
        )

    @classmethod
    def create_dataset_function(cls, path, args, **kwargs):
        return cls(path, **kwargs)
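
# Construction sketch (hedged; the brace-expanded shard path is hypothetical):
#   ds = VideoDataset("/data/shards/videos-{000000..000099}.tar",
#                     image_size=[480, 720], num_frames=16, fps=8, skip_frms_num=3)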

class SFTDataset(Dataset):
    def __init__(self, data_dir, video_size, fps, max_num_frames, skip_frms_num=3):
        """
        skip_frms_num: ignore the first and the last xx frames, avoiding transitions.
        """
        super(SFTDataset, self).__init__()

        self.videos_list = []
        self.captions_list = []
        self.num_frames_list = []
        self.fps_list = []

        decord.bridge.set_bridge("torch")
        for root, dirnames, filenames in os.walk(data_dir):
            for filename in filenames:
                if filename.endswith(".mp4"):
                    video_path = os.path.join(root, filename)
                    vr = VideoReader(uri=video_path, height=-1, width=-1)
                    actual_fps = vr.get_avg_fps()
                    ori_vlen = len(vr)

                    if ori_vlen / actual_fps * fps > max_num_frames:
                        # clip is long enough: sample max_num_frames at the target fps
                        num_frames = max_num_frames
                        start = int(skip_frms_num)
                        end = int(start + num_frames / fps * actual_fps)
                        indices = np.arange(start, end, (end - start) / num_frames).astype(int)
                        temp_frms = vr.get_batch(np.arange(start, end))
                        assert temp_frms is not None
                        tensor_frms = torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
                        tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
                    else:
                        if ori_vlen > max_num_frames:
                            # too short at the target fps, but enough raw frames:
                            # sample max_num_frames uniformly over the whole clip
                            num_frames = max_num_frames
                            start = int(skip_frms_num)
                            end = int(ori_vlen - skip_frms_num)
                            indices = np.arange(start, end, (end - start) / num_frames).astype(int)
                            temp_frms = vr.get_batch(np.arange(start, end))
                            assert temp_frms is not None
                            tensor_frms = (
                                torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
                            )
                            tensor_frms = tensor_frms[torch.tensor((indices - start).tolist())]
                        else:

                            def nearest_smaller_4k_plus_1(n):
                                remainder = n % 4
                                if remainder == 0:
                                    return n - 3
                                else:
                                    return n - remainder + 1

                            start = int(skip_frms_num)
                            end = int(ori_vlen - skip_frms_num)
                            # the 3D VAE requires the number of frames to be 4k + 1
                            num_frames = nearest_smaller_4k_plus_1(end - start)
                            end = int(start + num_frames)
                            temp_frms = vr.get_batch(np.arange(start, end))
                            assert temp_frms is not None
                            tensor_frms = (
                                torch.from_numpy(temp_frms) if type(temp_frms) is not torch.Tensor else temp_frms
                            )

                    tensor_frms = pad_last_frame(
                        tensor_frms, num_frames
                    )  # the len of indices may be less than num_frames, due to rounding error
                    tensor_frms = tensor_frms.permute(0, 3, 1, 2)  # [T, H, W, C] -> [T, C, H, W]
                    tensor_frms = resize_for_rectangle_crop(tensor_frms, video_size, reshape_mode="center")
                    tensor_frms = (tensor_frms - 127.5) / 127.5
                    self.videos_list.append(tensor_frms)

                    # caption
                    caption_path = os.path.join(root, filename.replace(".mp4", ".txt")).replace("videos", "labels")
                    if os.path.exists(caption_path):
                        with open(caption_path, "r") as f:
                            caption = f.read().splitlines()[0]
                    else:
                        caption = ""
                    self.captions_list.append(caption)
                    self.num_frames_list.append(num_frames)
                    self.fps_list.append(fps)

    def __getitem__(self, index):
        item = {
            "mp4": self.videos_list[index],
            "txt": self.captions_list[index],
            "num_frames": self.num_frames_list[index],
            "fps": self.fps_list[index],
        }
        return item

    def __len__(self):
        return len(self.fps_list)

    @classmethod
    def create_dataset_function(cls, path, args, **kwargs):
        return cls(data_dir=path, **kwargs)
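
# Usage sketch (hedged; "/data/sft" is a hypothetical directory of .mp4 files with
# matching .txt captions). max_num_frames=49 satisfies the 4k+1 constraint (49 = 4*12 + 1):
#   ds = SFTDataset(data_dir="/data/sft", video_size=[480, 720], fps=8, max_num_frames=49)
#   loader = torch.utils.data.DataLoader(ds, batch_size=1, shuffle=True)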