import os
import random
import json

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import numpy as np
from decord import VideoReader
from torch.utils.data.dataset import Dataset
from packaging import version as pver


class RandomHorizontalFlipWithPose(nn.Module):
    """Per-frame horizontal flip that exposes its flip flags, so the same
    flags can be reused to flip the corresponding camera rays."""

    def __init__(self, p=0.5):
        super(RandomHorizontalFlipWithPose, self).__init__()
        self.p = p

    def get_flip_flag(self, n_image):
        return torch.rand(n_image) < self.p

    def forward(self, image, flip_flag=None):
        n_image = image.shape[0]
        if flip_flag is not None:
            assert n_image == flip_flag.shape[0]
        else:
            flip_flag = self.get_flip_flag(n_image)

        ret_images = []
        for fflag, img in zip(flip_flag, image):
            if fflag:
                ret_images.append(F.hflip(img))
            else:
                ret_images.append(img)
        return torch.stack(ret_images, dim=0)
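
# Sketch of the intended pairing (illustrative names, not part of this module):
# compute the flags once, flip the frames with them, and pass the same flags to
# ray_condition() so the Plucker embeddings are mirrored consistently.
#
#   flip = RandomHorizontalFlipWithPose(p=0.5)
#   flags = flip.get_flip_flag(n_image=16)               # [16] bool
#   frames = flip(frames, flags)                          # frames: [16, C, H, W]
#   plucker = ray_condition(K, c2w, H, W, 'cpu', flags)   # see ray_condition below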


class Camera(object):
    def __init__(self, entry):
        fx, fy, cx, cy = entry[1:5]
        self.fx = fx
        self.fy = fy
        self.cx = cx
        self.cy = cy
        w2c_mat = np.array(entry[7:]).reshape(3, 4)
        w2c_mat_4x4 = np.eye(4)
        w2c_mat_4x4[:3, :] = w2c_mat
        self.w2c_mat = w2c_mat_4x4
        self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
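
# Assumed pose-line layout (RealEstate10K-style, inferred from the indexing above;
# only entry[1:5] and entry[7:] are actually used):
#   entry[0]    frame timestamp
#   entry[1:5]  fx, fy, cx, cy (normalized intrinsics)
#   entry[5:7]  two unused values (e.g. distortion placeholders)
#   entry[7:]   row-major 3x4 world-to-camera matrix (12 values)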


def custom_meshgrid(*args):
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    if pver.parse(torch.__version__) < pver.parse('1.10'):
        return torch.meshgrid(*args)
    else:
        return torch.meshgrid(*args, indexing='ij')


def ray_condition(K, c2w, H, W, device, flip_flag=None):
    # c2w: B, V, 4, 4
    # K: B, V, 4
    B, V = K.shape[:2]

    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    i = i.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5  # [B, V, HxW]
    j = j.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5  # [B, V, HxW]

    # For horizontally flipped frames, mirror the pixel grid so the rays match.
    n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0
    if n_flip > 0:
        j_flip, i_flip = custom_meshgrid(
            torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
            torch.linspace(W - 1, 0, W, device=device, dtype=c2w.dtype)
        )
        i_flip = i_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
        j_flip = j_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
        i[:, flip_flag, ...] = i_flip
        j[:, flip_flag, ...] = j_flip

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # B, V, 1

    zs = torch.ones_like(i)  # [B, V, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)

    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3
    directions = directions / directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3

    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, HW, 3
    rays_o = c2w[..., :3, 3]  # B, V, 3
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, HW, 3
    # Plucker coordinates: moment (rays_o x rays_d) concatenated with direction.
    rays_dxo = torch.linalg.cross(rays_o, rays_d)  # B, V, HW, 3
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6
    # plucker = plucker.permute(0, 1, 4, 2, 3)
    return plucker
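
# Shape summary / minimal call sketch (values below are illustrative only):
#   K      : [B, V, 4]    per-frame (fx, fy, cx, cy) in pixel units
#   c2w    : [B, V, 4, 4] camera-to-world matrices
#   output : [B, V, H, W, 6] Plucker embedding, (o x d, d) per pixel
#
#   K = torch.rand(1, 16, 4) * 256.
#   c2w = torch.eye(4).repeat(1, 16, 1, 1)
#   plucker = ray_condition(K, c2w, H=256, W=384, device='cpu')  # [1, 16, 256, 384, 6]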


class RealEstate10K(Dataset):
    def __init__(
        self,
        root_path,
        annotation_json,
        sample_stride=4,
        sample_n_frames=16,
        sample_size=[256, 384],
        is_image=False,
    ):
        self.root_path = root_path
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.is_image = is_image

        self.dataset = json.load(open(os.path.join(root_path, annotation_json), 'r'))
        self.length = len(self.dataset)

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        pixel_transforms = [transforms.Resize(sample_size),
                            transforms.RandomHorizontalFlip(),
                            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
        self.pixel_transforms = transforms.Compose(pixel_transforms)
    def load_video_reader(self, idx):
        video_dict = self.dataset[idx]

        video_path = os.path.join(self.root_path, video_dict['clip_path'])
        video_reader = VideoReader(video_path)

        return video_reader, video_dict['caption']

    def get_batch(self, idx):
        video_reader, video_caption = self.load_video_reader(idx)
        total_frames = len(video_reader)

        if self.is_image:
            frame_indice = [random.randint(0, total_frames - 1)]
        else:
            if isinstance(self.sample_stride, int):
                current_sample_stride = self.sample_stride
            else:
                assert len(self.sample_stride) == 2
                assert (self.sample_stride[0] >= 1) and (self.sample_stride[1] >= self.sample_stride[0])
                current_sample_stride = random.randint(self.sample_stride[0], self.sample_stride[1])

            cropped_length = self.sample_n_frames * current_sample_stride
            start_frame_ind = random.randint(0, max(0, total_frames - cropped_length - 1))
            end_frame_ind = min(start_frame_ind + cropped_length, total_frames)

            assert end_frame_ind - start_frame_ind >= self.sample_n_frames
            frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.sample_n_frames, dtype=int)

        pixel_values = torch.from_numpy(video_reader.get_batch(frame_indice).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.

        if self.is_image:
            pixel_values = pixel_values[0]

        return pixel_values, video_caption
    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            try:
                video, video_caption = self.get_batch(idx)
                break
            except Exception:
                idx = random.randint(0, self.length - 1)

        video = self.pixel_transforms(video)
        sample = dict(pixel_values=video, caption=video_caption)
        return sample


class RealEstate10KPose(Dataset):
    def __init__(
        self,
        root_path,
        annotation_json,
        sample_stride=4,
        minimum_sample_stride=1,
        sample_n_frames=16,
        relative_pose=False,
        zero_t_first_frame=False,
        sample_size=[256, 384],
        rescale_fxy=False,
        shuffle_frames=False,
        use_flip=False,
        return_clip_name=False,
    ):
        self.root_path = root_path
        self.relative_pose = relative_pose
        self.zero_t_first_frame = zero_t_first_frame
        self.sample_stride = sample_stride
        self.minimum_sample_stride = minimum_sample_stride
        self.sample_n_frames = sample_n_frames
        self.return_clip_name = return_clip_name

        self.dataset = json.load(open(os.path.join(root_path, annotation_json), 'r'))
        self.length = len(self.dataset)

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.sample_size = sample_size

        if use_flip:
            pixel_transforms = [transforms.Resize(sample_size),
                                RandomHorizontalFlipWithPose(),
                                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
        else:
            pixel_transforms = [transforms.Resize(sample_size),
                                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]

        self.rescale_fxy = rescale_fxy
        self.sample_wh_ratio = sample_size[1] / sample_size[0]

        self.pixel_transforms = pixel_transforms
        self.shuffle_frames = shuffle_frames
        self.use_flip = use_flip
    def get_relative_pose(self, cam_params):
        abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
        abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
        source_cam_c2w = abs_c2ws[0]
        if self.zero_t_first_frame:
            cam_to_origin = 0
        else:
            cam_to_origin = np.linalg.norm(source_cam_c2w[:3, 3])
        target_cam_c2w = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, -cam_to_origin],
            [0, 0, 1, 0],
            [0, 0, 0, 1]
        ])
        abs2rel = target_cam_c2w @ abs_w2cs[0]
        ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
        ret_poses = np.array(ret_poses, dtype=np.float32)
        return ret_poses
    def load_video_reader(self, idx):
        video_dict = self.dataset[idx]

        video_path = os.path.join(self.root_path, video_dict['clip_path'])
        video_reader = VideoReader(video_path)

        return video_dict['clip_name'], video_reader, video_dict['caption']

    def load_cameras(self, idx):
        video_dict = self.dataset[idx]
        pose_file = os.path.join(self.root_path, video_dict['pose_file'])
        with open(pose_file, 'r') as f:
            poses = f.readlines()
        poses = [pose.strip().split(' ') for pose in poses[1:]]
        cam_params = [[float(x) for x in pose] for pose in poses]
        cam_params = [Camera(cam_param) for cam_param in cam_params]
        return cam_params
    def get_batch(self, idx):
        clip_name, video_reader, video_caption = self.load_video_reader(idx)
        cam_params = self.load_cameras(idx)
        assert len(cam_params) >= self.sample_n_frames
        total_frames = len(cam_params)

        current_sample_stride = self.sample_stride
        if total_frames < self.sample_n_frames * current_sample_stride:
            maximum_sample_stride = int(total_frames // self.sample_n_frames)
            current_sample_stride = random.randint(self.minimum_sample_stride, maximum_sample_stride)

        cropped_length = self.sample_n_frames * current_sample_stride
        start_frame_ind = random.randint(0, max(0, total_frames - cropped_length - 1))
        end_frame_ind = min(start_frame_ind + cropped_length, total_frames)

        assert end_frame_ind - start_frame_ind >= self.sample_n_frames
        frame_indices = np.linspace(start_frame_ind, end_frame_ind - 1, self.sample_n_frames, dtype=int)

        condition_image_ind = random.sample(list(set(range(total_frames)) - set(frame_indices.tolist())), 1)
        condition_image = torch.from_numpy(video_reader.get_batch(condition_image_ind).asnumpy()).permute(0, 3, 1, 2).contiguous()
        condition_image = condition_image / 255.

        if self.shuffle_frames:
            perm = np.random.permutation(self.sample_n_frames)
            frame_indices = frame_indices[perm]

        pixel_values = torch.from_numpy(video_reader.get_batch(frame_indices).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.

        cam_params = [cam_params[indice] for indice in frame_indices]
        if self.rescale_fxy:
            ori_h, ori_w = pixel_values.shape[-2:]
            ori_wh_ratio = ori_w / ori_h
            if ori_wh_ratio > self.sample_wh_ratio:  # rescale fx
                resized_ori_w = self.sample_size[0] * ori_wh_ratio
                for cam_param in cam_params:
                    cam_param.fx = resized_ori_w * cam_param.fx / self.sample_size[1]
            else:  # rescale fy
                resized_ori_h = self.sample_size[1] / ori_wh_ratio
                for cam_param in cam_params:
                    cam_param.fy = resized_ori_h * cam_param.fy / self.sample_size[0]

        intrinsics = np.asarray([[cam_param.fx * self.sample_size[1],
                                  cam_param.fy * self.sample_size[0],
                                  cam_param.cx * self.sample_size[1],
                                  cam_param.cy * self.sample_size[0]]
                                 for cam_param in cam_params], dtype=np.float32)
        intrinsics = torch.as_tensor(intrinsics)[None]  # [1, n_frame, 4]

        if self.relative_pose:
            c2w_poses = self.get_relative_pose(cam_params)
        else:
            c2w_poses = np.array([cam_param.c2w_mat for cam_param in cam_params], dtype=np.float32)
        c2w = torch.as_tensor(c2w_poses)[None]  # [1, n_frame, 4, 4]

        if self.use_flip:
            flip_flag = self.pixel_transforms[1].get_flip_flag(self.sample_n_frames)
        else:
            flip_flag = torch.zeros(self.sample_n_frames, dtype=torch.bool, device=c2w.device)

        plucker_embedding = ray_condition(intrinsics, c2w, self.sample_size[0], self.sample_size[1], device='cpu',
                                          flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous()

        return pixel_values, condition_image, plucker_embedding, video_caption, flip_flag, clip_name
    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            try:
                video, condition_image, plucker_embedding, video_caption, flip_flag, clip_name = self.get_batch(idx)
                break
            except Exception:
                idx = random.randint(0, self.length - 1)

        if self.use_flip:
            video = self.pixel_transforms[0](video)
            video = self.pixel_transforms[1](video, flip_flag)
            for transform in self.pixel_transforms[2:]:
                video = transform(video)
        else:
            for transform in self.pixel_transforms:
                video = transform(video)

        for transform in self.pixel_transforms:
            condition_image = transform(condition_image)

        if self.return_clip_name:
            sample = dict(pixel_values=video, condition_image=condition_image,
                          plucker_embedding=plucker_embedding, video_caption=video_caption,
                          clip_name=clip_name)
        else:
            sample = dict(pixel_values=video, condition_image=condition_image,
                          plucker_embedding=plucker_embedding, video_caption=video_caption)
        return sample
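

# Minimal usage sketch (the root path and annotation file below are placeholders,
# not part of this module; point them at a local RealEstate10K copy prepared in
# the clip/pose layout this loader expects).
if __name__ == '__main__':
    dataset = RealEstate10KPose(
        root_path='/path/to/RealEstate10K',        # placeholder
        annotation_json='annotations/train.json',  # placeholder
        sample_stride=4,
        sample_n_frames=16,
        relative_pose=True,
        zero_t_first_frame=True,
        sample_size=[256, 384],
        rescale_fxy=True,
        use_flip=False,
    )
    sample = dataset[0]
    print(sample['pixel_values'].shape)       # expected: [16, 3, 256, 384]
    print(sample['plucker_embedding'].shape)  # expected: [16, 6, 256, 384]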