Spaces:
Runtime error
Runtime error
import os | |
import math | |
import numpy as np | |
from PIL import Image | |
import torch | |
import torch.nn.functional as F | |
from torch.utils.data import Dataset, DataLoader, IterableDataset | |
import torchvision.transforms.functional as TF | |
import pytorch_lightning as pl | |
import datasets | |
from datasets.colmap_utils import \ | |
read_cameras_binary, read_images_binary, read_points3d_binary | |
from models.ray_utils import get_ray_directions | |
from utils.misc import get_rank | |
def get_center(pts):
    """Estimate a robust center of a point cloud.

    The plain mean of ``pts`` is used as a first guess. Points whose distance to
    it is an outlier are discarded, and the mean of the remaining points is
    returned. A distance is an outlier if it lies outside ``mean +/- 1.5 * std``
    or ``mean +/- 1.5 * IQR`` of the distance distribution.

    Args:
        pts: (N, D) tensor of points (this file passes 3D points), N >= 1.

    Returns:
        (D,) tensor with the estimated center. If the outlier test rejects every
        point, the plain mean is returned instead of NaN. This happens for a
        single point (std is NaN) and for points all equidistant from the mean
        (std is 0).
    """
    center = pts.mean(0)
    dis = (pts - center[None, :]).norm(p=2, dim=-1)
    mean, std = dis.mean(), dis.std()
    q25, q75 = torch.quantile(dis, 0.25), torch.quantile(dis, 0.75)
    iqr = q75 - q25
    # NOTE(review): the IQR bounds are centred on the mean distance rather than on
    # the quartiles (q25 - 1.5*IQR, q75 + 1.5*IQR); kept as-is -- confirm intent.
    valid = (
        (dis > mean - 1.5 * std) & (dis < mean + 1.5 * std)
        & (dis > mean - iqr * 1.5) & (dis < mean + iqr * 1.5)
    )
    if not valid.any():
        # The strict inequalities reject everything when std is 0 or NaN.
        # Averaging an empty selection would give NaN, so keep the plain mean.
        return center
    return pts[valid].mean(0)
def _homogenize_poses(poses):
    """Append the row [0, 0, 0, 1] to every (3, 4) pose, giving (N, 4, 4)."""
    bottom = torch.as_tensor([[[0., 0., 0., 1.]]], dtype=poses.dtype, device=poses.device)
    return torch.cat([poses, bottom.expand(poses.shape[0], -1, -1)], dim=1)


def _rigid_transform(R, t):
    """Assemble a (4, 4) homogeneous matrix from a (3, 3) block R and a (3, 1) column t."""
    bottom = torch.as_tensor([[0., 0., 0., 1.]], dtype=R.dtype, device=R.device)
    return torch.cat([torch.cat([R, t], dim=1), bottom], dim=0)


def _transform_points(mat, pts):
    """Apply a (4, 4) homogeneous matrix to (M, 3) points, returning (M, 3)."""
    pts_homo = torch.cat([pts, torch.ones_like(pts[:, 0:1])], dim=-1)
    return (mat @ pts_homo[..., None])[:, :3, 0]


def normalize_poses(poses, pts, up_est_method, center_est_method):
    """Re-express camera poses and a point cloud in a normalized world frame.

    The new frame has its z-axis along an estimated "up" direction and its
    origin at an estimated scene center. It is scaled so that the closest
    camera lies at distance 1 from the origin.

    Args:
        poses: (N, 3, 4) camera-to-world matrices ``[R | t]``.
        pts: (M, 3) point cloud in the same world frame as ``poses``.
        up_est_method: how the up axis is estimated.
            - 'ground': normal of a RANSAC plane fitted to ``pts``, oriented
              towards the side where the points lie on average. Requires
              ``pyransac3d``.
            - 'camera': mean offset of the camera positions from the estimated
              center.
        center_est_method: how the scene center is estimated.
            - 'camera': mean camera position.
            - 'lookat': mean of the closest points between each camera's
              viewing ray (its local -z axis) and the previous camera's ray.
            - 'point': robust center (see ``get_center``) of the points lying
              within the cameras' x/y extent after rotation.

    Returns:
        ``(poses_norm, pts)``: the (N, 3, 4) normalized poses and the (M, 3)
        normalized points.

    Raises:
        NotImplementedError: for an unknown estimation method.
    """
    if center_est_method == 'camera':
        # scene center = average of all camera positions
        center = poses[..., 3].mean(0)
    elif center_est_method == 'lookat':
        # Scene center = average of the closest points between the viewing ray of
        # each camera and that of the previous camera (pairs formed with roll).
        cams_ori = poses[..., 3]
        forward = torch.as_tensor([0., 0., -1.], dtype=poses.dtype, device=poses.device)
        cams_dir = F.normalize(poses[:, :3, :3] @ forward, dim=-1)
        A = torch.stack([cams_dir, -cams_dir.roll(1, 0)], dim=-1)
        b = -cams_ori + cams_ori.roll(1, 0)
        t = torch.linalg.lstsq(A, b).solution
        center = (
            torch.stack([cams_dir, cams_dir.roll(1, 0)], dim=-1) * t[:, None, :]
            + torch.stack([cams_ori, cams_ori.roll(1, 0)], dim=-1)
        ).mean((0, 2))
    elif center_est_method == 'point':
        # Provisional center = average camera position. The final center is
        # computed from the points further down, once the rotation is known.
        center = poses[..., 3].mean(0)
    else:
        raise NotImplementedError(f'Unknown center estimation method: {center_est_method}')

    if up_est_method == 'ground':
        # Up direction = normal of a ground plane fitted to the point cloud with RANSAC.
        import pyransac3d as pyrsc
        ground = pyrsc.Plane()
        plane_eq, _ = ground.fit(pts.numpy(), thresh=0.01)  # TODO: determine thresh based on scene scale
        # A, B, C, D in Ax + By + Cz + D = 0; pin the dtype so later matmuls match the points
        plane_eq = torch.as_tensor(plane_eq, dtype=pts.dtype)
        z = F.normalize(plane_eq[:3], dim=-1)  # plane normal as up direction
        signed_distance = (torch.cat([pts, torch.ones_like(pts[..., 0:1])], dim=-1) * plane_eq).sum(-1)
        if signed_distance.mean() < 0:
            z = -z  # flip the direction if points lie under the plane
    elif up_est_method == 'camera':
        # Up direction = direction from the estimated center to the mean camera position.
        # NOTE(review): with center_est_method == 'camera' this offset is ~0 by construction.
        z = F.normalize((poses[..., 3] - center).mean(0), dim=0)
    else:
        raise NotImplementedError(f'Unknown up estimation method: {up_est_method}')

    # Build a right-handed orthonormal frame (x, y, z) around the up axis.
    y_ = torch.stack([z[1], -z[0], torch.zeros_like(z[0])])
    if y_.norm() < torch.finfo(z.dtype).eps:
        # z is (anti)parallel to the world z-axis, so [z1, -z0, 0] vanishes and the
        # frame would collapse; any vector perpendicular to z works, use world y.
        y_ = torch.as_tensor([0., 1., 0.], dtype=z.dtype, device=z.device)
    x = F.normalize(torch.cross(y_, z, dim=0), dim=0)
    y = torch.cross(z, x, dim=0)
    Rc = torch.stack([x, y, z], dim=1)  # columns: new axes expressed in old coordinates
    R = Rc.T

    if center_est_method == 'point':
        # 1) Rotate only, so the foreground can be selected by the cameras' x/y extent.
        rotation = _rigid_transform(R, torch.zeros(3, 1, dtype=R.dtype, device=R.device))
        poses_norm = (rotation @ _homogenize_poses(poses))[:, :3]
        pts = _transform_points(rotation, pts)
        poses_min = poses_norm[..., 3].min(0).values
        poses_max = poses_norm[..., 3].max(0).values
        in_x = (poses_min[0] < pts[:, 0]) & (pts[:, 0] < poses_max[0])
        in_y = (poses_min[1] < pts[:, 1]) & (pts[:, 1] < poses_max[1])
        pts_fg = pts[in_x & in_y]
        if pts_fg.shape[0] == 0:
            # No point lies inside the cameras' footprint (e.g. a nearly 1D camera
            # path). get_center cannot handle an empty set, so use every point.
            pts_fg = pts
        center = get_center(pts_fg)
        # 2) Translate the estimated center to the origin.
        translation = _rigid_transform(
            torch.eye(3, dtype=center.dtype, device=center.device), -center.reshape(3, 1))
        poses_norm = (translation @ _homogenize_poses(poses_norm))[:, :3]
        pts = _transform_points(translation, pts)
    else:
        # Rotation and translation in one step: p' = R (p - center).
        tc = center.reshape(3, 1)
        inv_trans = _rigid_transform(R, -R @ tc)
        poses_norm = (inv_trans @ _homogenize_poses(poses))[:, :3]  # (N_images, 3, 4)
        pts = _transform_points(inv_trans, pts)

    # Scale so that the closest camera ends up at unit distance from the origin.
    scale = poses_norm[..., 3].norm(p=2, dim=-1).min()
    poses_norm[..., 3] /= scale
    pts = pts / scale
    return poses_norm, pts
def create_spheric_poses(cameras, n_steps=120):
    """Generate a circular orbit of camera-to-world poses looking at the origin.

    The orbit lies in the plane z = mean camera height. Its radius is chosen so
    that every orbit position is at the cameras' mean distance from the origin.
    Each pose looks at the origin along its local -z axis, using world +z as the
    up reference.

    Args:
        cameras: (N, 3) camera positions.
        n_steps: number of poses. Angles are evenly spaced over [0, 2*pi]
            inclusive, so the first and last poses coincide.

    Returns:
        (n_steps, 3, 4) tensor of camera-to-world matrices with columns
        [right, up, -look, position], on the device/dtype of ``cameras``.
    """
    center = torch.as_tensor([0., 0., 0.], dtype=cameras.dtype, device=cameras.device)
    mean_d = (cameras - center[None, :]).norm(p=2, dim=-1).mean()
    mean_h = cameras[:, 2].mean()
    # mean_d >= |mean_h| mathematically; clamp guards against float round-off -> NaN
    r = (mean_d ** 2 - mean_h ** 2).clamp(min=0).sqrt()
    up = torch.as_tensor([0., 0., 1.], dtype=center.dtype, device=center.device)
    all_c2w = []
    # create the angles alongside the cameras so GPU inputs do not mix devices
    thetas = torch.linspace(0, 2 * math.pi, n_steps, dtype=cameras.dtype, device=cameras.device)
    for theta in thetas:
        cam_pos = torch.stack([r * theta.cos(), r * theta.sin(), mean_h])
        l = F.normalize(center - cam_pos, p=2, dim=0)
        s = F.normalize(torch.cross(l, up, dim=0), p=2, dim=0)
        u = F.normalize(torch.cross(s, l, dim=0), p=2, dim=0)
        c2w = torch.cat([torch.stack([s, u, -l], dim=1), cam_pos[:, None]], dim=1)
        all_c2w.append(c2w)
    return torch.stack(all_c2w, dim=0)
class ColmapDatasetBase():
    """Shared loader for a COLMAP reconstruction (cameras, images, masks, 3D points).

    The reconstruction is parsed once per process. The first ``setup`` call
    stores the results in the class-level ``properties`` dict and sets
    ``initialized``. Every later call, from any subclass instance, copies those
    cached values onto ``self``.

    NOTE(review): images are only read when the *first* ``setup`` call uses
    split 'train' or 'val'. If the first call is for 'test', the cached image
    list stays empty and a later 'train'/'val' setup will fail in
    ``torch.stack``.
    """
    # the data only has to be processed once
    initialized = False
    properties = {}

    @staticmethod
    def _scaled_intrinsics(camera, factor):
        """Return ``(fx, fy, cx, cy)`` of a COLMAP camera record, multiplied by ``factor``.

        Only the leading focal-length / principal-point parameters are read.
        Any remaining entries of ``camera.params`` are ignored; for
        SIMPLE_RADIAL / OPENCV these are presumably distortion coefficients.

        Raises:
            ValueError: for camera models other than SIMPLE_RADIAL, PINHOLE and OPENCV.
        """
        if camera.model == 'SIMPLE_RADIAL':
            fx = fy = camera.params[0] * factor
            cx = camera.params[1] * factor
            cy = camera.params[2] * factor
        elif camera.model in ['PINHOLE', 'OPENCV']:
            fx = camera.params[0] * factor
            fy = camera.params[1] * factor
            cx = camera.params[2] * factor
            cy = camera.params[3] * factor
        else:
            raise ValueError(f"Please parse the intrinsics for camera model {camera.model}!")
        return fx, fy, cx, cy

    def setup(self, config, split):
        """Load (or reuse) the scene and prepare per-split tensors on ``self``.

        Args:
            config: config object supporting ``in`` and attribute access. This
                method reads ``root_dir``, ``img_wh`` or ``img_downscale``,
                ``apply_mask``, ``load_data_on_gpu``, ``up_est_method``,
                ``center_est_method`` and ``n_test_traj_steps``.
            split: images and masks are read only for 'train' / 'val'. For
                'test', the poses are replaced by a circular trajectory and the
                images/masks by zero tensors.

        Raises:
            KeyError: if neither ``img_wh`` nor ``img_downscale`` is configured.
            ValueError: if ``img_wh`` does not preserve the image aspect ratio,
                if an image does not have exactly one mask file, or if the
                camera model is unsupported.
        """
        self.config = config
        self.split = split
        self.rank = get_rank()

        if not ColmapDatasetBase.initialized:
            root_dir = self.config.root_dir
            camdata = read_cameras_binary(os.path.join(root_dir, 'sparse/0/cameras.bin'))
            # NOTE(review): only camera id 1 is used; all images are assumed to share its intrinsics.
            camera = camdata[1]
            H = int(camera.height)
            W = int(camera.width)

            if 'img_wh' in self.config:
                w, h = self.config.img_wh
                # explicit check instead of `assert`, which is stripped under `python -O`
                if round(W / w * h) != H:
                    raise ValueError(
                        f"img_wh {(w, h)} does not preserve the aspect ratio of the {W}x{H} images.")
            elif 'img_downscale' in self.config:
                w, h = int(W / self.config.img_downscale + 0.5), int(H / self.config.img_downscale + 0.5)
            else:
                raise KeyError("Either img_wh or img_downscale should be specified.")
            img_wh = (w, h)
            factor = w / W

            fx, fy, cx, cy = self._scaled_intrinsics(camera, factor)
            directions = get_ray_directions(w, h, fx, fy, cx, cy).to(self.rank)

            imdata = read_images_binary(os.path.join(root_dir, 'sparse/0/images.bin'))

            mask_dir = os.path.join(root_dir, 'masks')
            has_mask = os.path.exists(mask_dir)  # TODO: support partial masks
            apply_mask = has_mask and self.config.apply_mask

            all_c2w, all_images, all_fg_masks = [], [], []
            for d in imdata.values():
                # COLMAP stores world-to-camera (R, t); invert to camera-to-world [R^T | -R^T t].
                R = d.qvec2rotmat()
                t = d.tvec.reshape(3, 1)
                c2w = torch.from_numpy(np.concatenate([R.T, -R.T @ t], axis=1)).float()
                c2w[:, 1:3] *= -1.  # COLMAP => OpenGL
                all_c2w.append(c2w)

                if self.split in ['train', 'val']:
                    img_path = os.path.join(root_dir, 'images', d.name)
                    # context manager closes the underlying file handle once the image is resized
                    with Image.open(img_path) as pil_img:
                        img = TF.to_tensor(pil_img.resize(img_wh, Image.BICUBIC)).permute(1, 2, 0)[..., :3]
                    img = img.to(self.rank) if self.config.load_data_on_gpu else img.cpu()

                    if has_mask:
                        # masks may be named like the image or with its first 3 characters dropped
                        mask_paths = [os.path.join(mask_dir, d.name), os.path.join(mask_dir, d.name[3:])]
                        mask_paths = [p for p in mask_paths if os.path.exists(p)]
                        if len(mask_paths) != 1:
                            raise ValueError(
                                f"Expected exactly one mask for image {d.name} in {mask_dir}, "
                                f"found {len(mask_paths)}.")
                        with Image.open(mask_paths[0]) as pil_mask:
                            mask_img = pil_mask.convert('L').resize(img_wh, Image.BICUBIC)
                        mask = TF.to_tensor(mask_img)[0]  # (h, w)
                    else:
                        mask = torch.ones_like(img[..., 0], device=img.device)
                    all_fg_masks.append(mask)  # (h, w)
                    all_images.append(img)

            all_c2w = torch.stack(all_c2w, dim=0)

            pts3d = read_points3d_binary(os.path.join(root_dir, 'sparse/0/points3D.bin'))
            pts3d = torch.from_numpy(np.array([pts3d[k].xyz for k in pts3d])).float()
            all_c2w, pts3d = normalize_poses(
                all_c2w, pts3d,
                up_est_method=self.config.up_est_method,
                center_est_method=self.config.center_est_method,
            )

            ColmapDatasetBase.properties = {
                'w': w,
                'h': h,
                'img_wh': img_wh,
                'factor': factor,
                'has_mask': has_mask,
                'apply_mask': apply_mask,
                'directions': directions,
                'pts3d': pts3d,
                'all_c2w': all_c2w,
                'all_images': all_images,
                'all_fg_masks': all_fg_masks
            }
            ColmapDatasetBase.initialized = True

        # expose the cached scene data as instance attributes
        for k, v in ColmapDatasetBase.properties.items():
            setattr(self, k, v)

        if self.split == 'test':
            # render along a synthetic orbit; images/masks are placeholders
            self.all_c2w = create_spheric_poses(self.all_c2w[:, :, 3], n_steps=self.config.n_test_traj_steps)
            self.all_images = torch.zeros(
                (self.config.n_test_traj_steps, self.h, self.w, 3), dtype=torch.float32)
            self.all_fg_masks = torch.zeros(
                (self.config.n_test_traj_steps, self.h, self.w), dtype=torch.float32)
        else:
            self.all_images = torch.stack(self.all_images, dim=0).float()
            self.all_fg_masks = torch.stack(self.all_fg_masks, dim=0).float()

        self.all_c2w = self.all_c2w.float().to(self.rank)
        if self.config.load_data_on_gpu:
            self.all_images = self.all_images.to(self.rank)
            self.all_fg_masks = self.all_fg_masks.to(self.rank)
class ColmapDataset(Dataset, ColmapDatasetBase):
    """Map-style dataset over a COLMAP scene whose items are frame indices.

    All scene tensors live on the instance (populated by ``setup``). Each item
    only tells the consumer which frame to use.
    """

    def __init__(self, config, split):
        self.setup(config, split)

    def __len__(self):
        # one item per frame of the prepared split
        return len(self.all_images)

    def __getitem__(self, index):
        item = dict(index=index)
        return item
class ColmapIterableDataset(IterableDataset, ColmapDatasetBase):
    """Endless iterable over a COLMAP scene that yields empty dicts.

    The items carry no data. The scene tensors are attached to the instance by
    ``setup``; presumably the consumer samples from them directly (confirm
    against the training system).
    """

    def __init__(self, config, split):
        self.setup(config, split)

    def __iter__(self):
        # never terminates; each step yields a fresh empty dict
        while 1:
            yield dict()
class ColmapDataModule(pl.LightningDataModule):
    """Lightning data module exposing COLMAP datasets for fit / validate / test / predict."""

    def __init__(self, config):
        super().__init__()
        self.config = config

    def setup(self, stage=None):
        """Create the datasets needed for ``stage`` (all of them when ``stage`` is None).

        The validation split is ``config.val_split`` (default 'train'). The
        test split is ``config.test_split`` (default 'test'). Prediction always
        uses 'train'.
        """
        if stage in [None, 'fit']:
            self.train_dataset = ColmapIterableDataset(self.config, 'train')
        if stage in [None, 'fit', 'validate']:
            self.val_dataset = ColmapDataset(self.config, self.config.get('val_split', 'train'))
        if stage in [None, 'test']:
            self.test_dataset = ColmapDataset(self.config, self.config.get('test_split', 'test'))
        if stage in [None, 'predict']:
            self.predict_dataset = ColmapDataset(self.config, 'train')

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size):
        """Wrap ``dataset`` in a pinned-memory DataLoader with one worker per CPU."""
        # os.cpu_count() returns None when the count cannot be determined and
        # DataLoader requires an int; fall back to loading in the main process.
        num_workers = os.cpu_count() or 0
        return DataLoader(
            dataset,
            num_workers=num_workers,
            batch_size=batch_size,
            pin_memory=True,
            sampler=None
        )

    def train_dataloader(self):
        return self.general_loader(self.train_dataset, batch_size=1)

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.predict_dataset, batch_size=1)