import os
import random
import json
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import numpy as np
from decord import VideoReader
from torch.utils.data.dataset import Dataset
from packaging import version as pver

class RandomHorizontalFlipWithPose(nn.Module):
    """Horizontal flip that exposes its per-frame flip decisions so camera rays can be flipped consistently."""

    def __init__(self, p=0.5):
        super(RandomHorizontalFlipWithPose, self).__init__()
        self.p = p

    def get_flip_flag(self, n_image):
        return torch.rand(n_image) < self.p

    def forward(self, image, flip_flag=None):
        n_image = image.shape[0]
        if flip_flag is not None:
            assert n_image == flip_flag.shape[0]
        else:
            flip_flag = self.get_flip_flag(n_image)

        ret_images = []
        for fflag, img in zip(flip_flag, image):
            if fflag:
                ret_images.append(F.hflip(img))
            else:
                ret_images.append(img)
        return torch.stack(ret_images, dim=0)

class Camera(object):
    def __init__(self, entry):
        # entry[1:5] holds the (normalized) intrinsics fx, fy, cx, cy;
        # entry[7:] holds the flattened 3x4 world-to-camera matrix.
        fx, fy, cx, cy = entry[1:5]
        self.fx = fx
        self.fy = fy
        self.cx = cx
        self.cy = cy
        w2c_mat = np.array(entry[7:]).reshape(3, 4)
        w2c_mat_4x4 = np.eye(4)
        w2c_mat_4x4[:3, :] = w2c_mat
        self.w2c_mat = w2c_mat_4x4
        self.c2w_mat = np.linalg.inv(w2c_mat_4x4)

def custom_meshgrid(*args):
    # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
    if pver.parse(torch.__version__) < pver.parse('1.10'):
        return torch.meshgrid(*args)
    else:
        return torch.meshgrid(*args, indexing='ij')

def ray_condition(K, c2w, H, W, device, flip_flag=None):
    # c2w: B, V, 4, 4
    # K: B, V, 4
    B, V = K.shape[:2]

    j, i = custom_meshgrid(
        torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
        torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
    )
    i = i.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5  # [B, V, HxW]
    j = j.reshape([1, 1, H * W]).expand([B, V, H * W]) + 0.5  # [B, V, HxW]

    n_flip = torch.sum(flip_flag).item() if flip_flag is not None else 0
    if n_flip > 0:
        j_flip, i_flip = custom_meshgrid(
            torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
            torch.linspace(W - 1, 0, W, device=device, dtype=c2w.dtype)
        )
        i_flip = i_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
        j_flip = j_flip.reshape([1, 1, H * W]).expand(B, 1, H * W) + 0.5
        i[:, flip_flag, ...] = i_flip
        j[:, flip_flag, ...] = j_flip

    fx, fy, cx, cy = K.chunk(4, dim=-1)  # B, V, 1

    zs = torch.ones_like(i)  # [B, V, HxW]
    xs = (i - cx) / fx * zs
    ys = (j - cy) / fy * zs
    zs = zs.expand_as(ys)

    directions = torch.stack((xs, ys, zs), dim=-1)  # B, V, HW, 3
    directions = directions / directions.norm(dim=-1, keepdim=True)  # B, V, HW, 3

    rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2)  # B, V, HW, 3
    rays_o = c2w[..., :3, 3]  # B, V, 3
    rays_o = rays_o[:, :, None].expand_as(rays_d)  # B, V, HW, 3
    # c2w @ directions
    rays_dxo = torch.linalg.cross(rays_o, rays_d)  # B, V, HW, 3
    plucker = torch.cat([rays_dxo, rays_d], dim=-1)
    plucker = plucker.reshape(B, c2w.shape[1], H, W, 6)  # B, V, H, W, 6
    # plucker = plucker.permute(0, 1, 4, 2, 3)
    return plucker

class RealEstate10K(Dataset):
    def __init__(
            self,
            root_path,
            annotation_json,
            sample_stride=4,
            sample_n_frames=16,
            sample_size=[256, 384],
            is_image=False,
    ):
        self.root_path = root_path
        self.sample_stride = sample_stride
        self.sample_n_frames = sample_n_frames
        self.is_image = is_image

        self.dataset = json.load(open(os.path.join(root_path, annotation_json), 'r'))
        self.length = len(self.dataset)

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        pixel_transforms = [transforms.Resize(sample_size),
                            transforms.RandomHorizontalFlip(),
                            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
        self.pixel_transforms = transforms.Compose(pixel_transforms)

    def load_video_reader(self, idx):
        video_dict = self.dataset[idx]
        video_path = os.path.join(self.root_path, video_dict['clip_path'])
        video_reader = VideoReader(video_path)
        return video_reader, video_dict['caption']

    def get_batch(self, idx):
        video_reader, video_caption = self.load_video_reader(idx)
        total_frames = len(video_reader)

        if self.is_image:
            frame_indices = [random.randint(0, total_frames - 1)]
        else:
            if isinstance(self.sample_stride, int):
                current_sample_stride = self.sample_stride
            else:
                assert len(self.sample_stride) == 2
                assert (self.sample_stride[0] >= 1) and (self.sample_stride[1] >= self.sample_stride[0])
                current_sample_stride = random.randint(self.sample_stride[0], self.sample_stride[1])
            cropped_length = self.sample_n_frames * current_sample_stride
            start_frame_ind = random.randint(0, max(0, total_frames - cropped_length - 1))
            end_frame_ind = min(start_frame_ind + cropped_length, total_frames)

            assert end_frame_ind - start_frame_ind >= self.sample_n_frames
            frame_indices = np.linspace(start_frame_ind, end_frame_ind - 1, self.sample_n_frames, dtype=int)

        pixel_values = torch.from_numpy(video_reader.get_batch(frame_indices).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.

        if self.is_image:
            pixel_values = pixel_values[0]

        return pixel_values, video_caption

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            try:
                video, video_caption = self.get_batch(idx)
                break
            except Exception:
                idx = random.randint(0, self.length - 1)

        video = self.pixel_transforms(video)
        sample = dict(pixel_values=video, caption=video_caption)
        return sample

class RealEstate10KPose(Dataset):
    def __init__(
            self,
            root_path,
            annotation_json,
            sample_stride=4,
            minimum_sample_stride=1,
            sample_n_frames=16,
            relative_pose=False,
            zero_t_first_frame=False,
            sample_size=[256, 384],
            rescale_fxy=False,
            shuffle_frames=False,
            use_flip=False,
            return_clip_name=False,
    ):
        self.root_path = root_path
        self.relative_pose = relative_pose
        self.zero_t_first_frame = zero_t_first_frame
        self.sample_stride = sample_stride
        self.minimum_sample_stride = minimum_sample_stride
        self.sample_n_frames = sample_n_frames
        self.return_clip_name = return_clip_name

        self.dataset = json.load(open(os.path.join(root_path, annotation_json), 'r'))
        self.length = len(self.dataset)

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.sample_size = sample_size

        if use_flip:
            pixel_transforms = [transforms.Resize(sample_size),
                                RandomHorizontalFlipWithPose(),
                                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
        else:
            pixel_transforms = [transforms.Resize(sample_size),
                                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
        self.rescale_fxy = rescale_fxy
        self.sample_wh_ratio = sample_size[1] / sample_size[0]

        self.pixel_transforms = pixel_transforms
        self.shuffle_frames = shuffle_frames
        self.use_flip = use_flip

    def get_relative_pose(self, cam_params):
        # Map all camera-to-world poses into the frame of the first camera; the first camera is
        # placed at a canonical target pose, optionally preserving its distance to the world origin.
        abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
        abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
        source_cam_c2w = abs_c2ws[0]
        if self.zero_t_first_frame:
            cam_to_origin = 0
        else:
            cam_to_origin = np.linalg.norm(source_cam_c2w[:3, 3])
        target_cam_c2w = np.array([
            [1, 0, 0, 0],
            [0, 1, 0, -cam_to_origin],
            [0, 0, 1, 0],
            [0, 0, 0, 1]
        ])
        abs2rel = target_cam_c2w @ abs_w2cs[0]
        ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
        ret_poses = np.array(ret_poses, dtype=np.float32)
        return ret_poses

    def load_video_reader(self, idx):
        video_dict = self.dataset[idx]
        video_path = os.path.join(self.root_path, video_dict['clip_path'])
        video_reader = VideoReader(video_path)
        return video_dict['clip_name'], video_reader, video_dict['caption']

    def load_cameras(self, idx):
        video_dict = self.dataset[idx]
        pose_file = os.path.join(self.root_path, video_dict['pose_file'])
        with open(pose_file, 'r') as f:
            poses = f.readlines()
        poses = [pose.strip().split(' ') for pose in poses[1:]]
        cam_params = [[float(x) for x in pose] for pose in poses]
        cam_params = [Camera(cam_param) for cam_param in cam_params]
        return cam_params

    def get_batch(self, idx):
        clip_name, video_reader, video_caption = self.load_video_reader(idx)
        cam_params = self.load_cameras(idx)
        assert len(cam_params) >= self.sample_n_frames
        total_frames = len(cam_params)

        current_sample_stride = self.sample_stride
        if total_frames < self.sample_n_frames * current_sample_stride:
            maximum_sample_stride = int(total_frames // self.sample_n_frames)
            current_sample_stride = random.randint(self.minimum_sample_stride, maximum_sample_stride)

        cropped_length = self.sample_n_frames * current_sample_stride
        start_frame_ind = random.randint(0, max(0, total_frames - cropped_length - 1))
        end_frame_ind = min(start_frame_ind + cropped_length, total_frames)

        assert end_frame_ind - start_frame_ind >= self.sample_n_frames
        frame_indices = np.linspace(start_frame_ind, end_frame_ind - 1, self.sample_n_frames, dtype=int)

        condition_image_ind = random.sample(list(set(range(total_frames)) - set(frame_indices.tolist())), 1)
        condition_image = torch.from_numpy(video_reader.get_batch(condition_image_ind).asnumpy()).permute(0, 3, 1, 2).contiguous()
        condition_image = condition_image / 255.

        if self.shuffle_frames:
            perm = np.random.permutation(self.sample_n_frames)
            frame_indices = frame_indices[perm]

        pixel_values = torch.from_numpy(video_reader.get_batch(frame_indices).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.

        cam_params = [cam_params[indice] for indice in frame_indices]
        if self.rescale_fxy:
            ori_h, ori_w = pixel_values.shape[-2:]
            ori_wh_ratio = ori_w / ori_h
            if ori_wh_ratio > self.sample_wh_ratio:  # rescale fx
                resized_ori_w = self.sample_size[0] * ori_wh_ratio
                for cam_param in cam_params:
                    cam_param.fx = resized_ori_w * cam_param.fx / self.sample_size[1]
            else:  # rescale fy
                resized_ori_h = self.sample_size[1] / ori_wh_ratio
                for cam_param in cam_params:
                    cam_param.fy = resized_ori_h * cam_param.fy / self.sample_size[0]
        intrinsics = np.asarray([[cam_param.fx * self.sample_size[1],
                                  cam_param.fy * self.sample_size[0],
                                  cam_param.cx * self.sample_size[1],
                                  cam_param.cy * self.sample_size[0]]
                                 for cam_param in cam_params], dtype=np.float32)
        intrinsics = torch.as_tensor(intrinsics)[None]  # [1, n_frame, 4]
        if self.relative_pose:
            c2w_poses = self.get_relative_pose(cam_params)
        else:
            c2w_poses = np.array([cam_param.c2w_mat for cam_param in cam_params], dtype=np.float32)
        c2w = torch.as_tensor(c2w_poses)[None]  # [1, n_frame, 4, 4]
        if self.use_flip:
            flip_flag = self.pixel_transforms[1].get_flip_flag(self.sample_n_frames)
        else:
            flip_flag = torch.zeros(self.sample_n_frames, dtype=torch.bool, device=c2w.device)
        plucker_embedding = ray_condition(intrinsics, c2w, self.sample_size[0], self.sample_size[1], device='cpu',
                                          flip_flag=flip_flag)[0].permute(0, 3, 1, 2).contiguous()
        return pixel_values, condition_image, plucker_embedding, video_caption, flip_flag, clip_name

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            try:
                video, condition_image, plucker_embedding, video_caption, flip_flag, clip_name = self.get_batch(idx)
                break
            except Exception:
                idx = random.randint(0, self.length - 1)

        if self.use_flip:
            video = self.pixel_transforms[0](video)
            video = self.pixel_transforms[1](video, flip_flag)
            for transform in self.pixel_transforms[2:]:
                video = transform(video)
        else:
            for transform in self.pixel_transforms:
                video = transform(video)
        for transform in self.pixel_transforms:
            condition_image = transform(condition_image)

        if self.return_clip_name:
            sample = dict(pixel_values=video, condition_image=condition_image, plucker_embedding=plucker_embedding,
                          video_caption=video_caption, clip_name=clip_name)
        else:
            sample = dict(pixel_values=video, condition_image=condition_image, plucker_embedding=plucker_embedding,
                          video_caption=video_caption)
        return sample
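

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module): shows how RealEstate10KPose
    # might be wrapped in a DataLoader to inspect the shapes it produces. The root path and
    # annotation file name below are placeholders; substitute the actual RealEstate10K layout.
    from torch.utils.data import DataLoader

    dataset = RealEstate10KPose(
        root_path='/path/to/RealEstate10K',        # hypothetical dataset root
        annotation_json='annotations/train.json',  # hypothetical annotation file
        sample_stride=4,
        sample_n_frames=16,
        relative_pose=True,
        zero_t_first_frame=True,
        sample_size=[256, 384],
        rescale_fxy=True,
        use_flip=False,
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)

    batch = next(iter(loader))
    # pixel_values: [B, F, C, H, W] in [-1, 1]; plucker_embedding: [B, F, 6, H, W]
    print(batch['pixel_values'].shape, batch['plucker_embedding'].shape)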