LucidDreamer-mini / luciddreamer.py
ironjr's picture
Upload luciddreamer.py with huggingface_hub
6a43180
# Copyright (C) 2023, Computer Vision Lab, Seoul National University, https://cv.snu.ac.kr
#
# Copyright 2023 LucidDreamer Authors
#
# Computer Vision Lab, SNU, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from the Computer Vision Lab, SNU or
# its affiliates is strictly prohibited.
#
# For permission requests, please contact [email protected], [email protected], [email protected], [email protected].
import os
import glob
import json
import time
import datetime
import warnings
import shutil
from random import randint
from argparse import ArgumentParser
warnings.filterwarnings(action='ignore')
import pickle
import imageio
import numpy as np
import open3d as o3d
from PIL import Image
from tqdm import tqdm
from scipy.interpolate import griddata as interp_grid
from scipy.ndimage import minimum_filter, maximum_filter
import torch
import torch.nn.functional as F
import gradio as gr
from diffusers import (
StableDiffusionInpaintPipeline, StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetInpaintPipeline)
from arguments import GSParams, CameraParams
from gaussian_renderer import render
from scene import Scene, GaussianModel
from scene.dataset_readers import loadCameraPreset
from utils.loss import l1_loss, ssim
from utils.camera import load_json
from utils.depth import colorize
from utils.lama import LaMa
from utils.trajectory import get_camerapaths, get_pcdGenPoses
get_kernel = lambda p: torch.ones(1, 1, p * 2 + 1, p * 2 + 1).to('cuda')
t2np = lambda x: (x[0].permute(1, 2, 0).clamp_(0, 1) * 255.0).to(torch.uint8).detach().cpu().numpy()
np2t = lambda x: (torch.as_tensor(x).to(torch.float32).permute(2, 0, 1) / 255.0)[None, ...].to('cuda')
pad_mask = lambda x, padamount=1: t2np(
F.conv2d(np2t(x[..., None]), get_kernel(padamount), padding=padamount))[..., 0].astype(bool)
class LucidDreamer:
def __init__(self):
self.opt = GSParams()
self.cam = CameraParams()
self.root = 'outputs'
self.default_model = 'SD1.5 (default)'
self.timestamp = datetime.datetime.now().strftime('%y%m%d_%H%M%S')
self.gaussians = GaussianModel(self.opt.sh_degree)
bg_color = [1, 1, 1] if self.opt.white_background else [0, 0, 0]
self.background = torch.tensor(bg_color, dtype=torch.float32, device='cuda')
self.rgb_model = StableDiffusionInpaintPipeline.from_pretrained(
'stablediffusion/SD1-5', revision='fp16', torch_dtype=torch.float16).to('cuda')
self.d_model = torch.hub.load('./ZoeDepth', 'ZoeD_N', source='local', pretrained=True).to('cuda')
self.controlnet = None
self.lama = None
self.current_model = self.default_model
def load_model(self, model_name, use_lama=True):
if model_name is None:
model_name = self.default_model
if self.current_model == model_name:
return
if model_name == self.default_model:
self.controlnet = None
self.lama = None
self.rgb_model = StableDiffusionInpaintPipeline.from_pretrained(
# 'runwayml/stable-diffusion-inpainting',
'stablediffusion/SD1-5',
revision='fp16',
torch_dtype=torch.float16,
safety_checker=None,
).to('cuda')
else:
if self.controlnet is None:
self.controlnet = ControlNetModel.from_pretrained(
'lllyasviel/control_v11p_sd15_inpaint', torch_dtype=torch.float16)
if self.lama is None and use_lama:
self.lama = LaMa('cuda')
self.rgb_model = StableDiffusionControlNetInpaintPipeline.from_pretrained(
f'stablediffusion/{model_name}',
controlnet=self.controlnet,
revision='fp16',
torch_dtype=torch.float16,
safety_checker=None,
).to('cuda')
# self.rgb_model.enable_model_cpu_offload()
torch.cuda.empty_cache()
self.current_model = model_name
def rgb(self, prompt, image, negative_prompt='', generator=None, num_inference_steps=50, mask_image=None):
image_pil = Image.fromarray(np.round(image * 255.).astype(np.uint8))
mask_pil = Image.fromarray(np.round((1 - mask_image) * 255.).astype(np.uint8))
if self.current_model == self.default_model:
return self.rgb_model(
prompt=prompt,
negative_prompt=negative_prompt,
generator=generator,
num_inference_steps=num_inference_steps,
image=image_pil,
mask_image=mask_pil,
).images[0]
kwargs = {
'negative_prompt': negative_prompt,
'generator': generator,
'strength': 0.9,
'num_inference_steps': num_inference_steps,
'height': self.cam.H,
'width': self.cam.W,
}
# image_np = np.array(image).astype(float) / 255.0
# mask_np = 1.0 - np.array(mask_image) / 255.0
image_np = np.round(np.clip(image, 0, 1) * 255.).astype(np.uint8)
mask_sum = np.clip((image.prod(axis=-1) == 0) + (1 - mask_image), 0, 1)
mask_padded = pad_mask(mask_sum, 3)
masked = image_np * np.logical_not(mask_padded[..., None])
if self.lama is not None:
lama_image = Image.fromarray(self.lama(masked, mask_padded).astype(np.uint8))
else:
lama_image = image
mask_image = Image.fromarray(mask_padded.astype(np.uint8) * 255)
control_image = self.make_controlnet_inpaint_condition(lama_image, mask_image)
return self.rgb_model(
prompt=prompt,
image=lama_image,
control_image=control_image,
mask_image=mask_image,
**kwargs,
).images[0]
def d(self, im):
return self.d_model.infer_pil(im)
def make_controlnet_inpaint_condition(self, image, image_mask):
image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
image[image_mask > 0.5] = -1.0 # set as masked pixel
image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return image
def run(self, rgb_cond, txt_cond, neg_txt_cond, pcdgenpath, seed, diff_steps, render_camerapath, model_name=None, example_name=None):
# gaussians, default_gallery = self.create(
gaussians = self.create(
rgb_cond, txt_cond, neg_txt_cond, pcdgenpath, seed, diff_steps, model_name, example_name)
gallery, depth = self.render_video(render_camerapath, example_name=example_name)
return (gaussians, gallery, depth)
# return (gaussians, default_gallery, gallery)
def create(self, rgb_cond, txt_cond, neg_txt_cond, pcdgenpath, seed, diff_steps, model_name=None, example_name=None):
self.cleaner()
self.load_model(model_name)
if example_name and example_name != 'DON\'T':
outfile = os.path.join('examples', f'{example_name}.ply')
if not os.path.exists(outfile):
self.traindata = self.generate_pcd(rgb_cond, txt_cond, neg_txt_cond, pcdgenpath, seed, diff_steps)
self.scene = Scene(self.traindata, self.gaussians, self.opt)
self.training()
outfile = self.save_ply(outfile)
else:
self.traindata = self.generate_pcd(rgb_cond, txt_cond, neg_txt_cond, pcdgenpath, seed, diff_steps)
self.scene = Scene(self.traindata, self.gaussians, self.opt)
self.training()
self.timestamp = datetime.datetime.now().strftime('%y%m%d_%H%M%S')
outfile = self.save_ply()
# default_gallery = self.render_video('llff', example_name=example_name)
return outfile #, default_gallery
def save_ply(self, fpath=None):
if fpath is None:
dpath = os.path.join(self.root, self.timestamp)
fpath = os.path.join(dpath, 'gsplat.ply')
os.makedirs(dpath, exist_ok=True)
if not os.path.exists(fpath):
self.gaussians.save_ply(fpath)
else:
self.gaussians.load_ply(fpath)
return fpath
def cleaner(self):
# Remove the temporary file created yesterday.
for dpath in glob.glob(os.path.join('/tmp/gradio', '*', self.root, '*')):
timestamp = datetime.datetime.strptime(os.path.basename(dpath), '%y%m%d_%H%M%S')
if timestamp < datetime.datetime.now() - datetime.timedelta(days=1):
try:
shutil.rmtree(dpath)
except OSError as e:
print("Error: %s - %s." % (e.filename, e.strerror))
def render_video(self, preset, example_name=None, progress=gr.Progress()):
if example_name and example_name != 'DON\'T':
videopath = os.path.join('examples', f'{example_name}_{preset}.mp4')
depthpath = os.path.join('examples', f'depth_{example_name}_{preset}.mp4')
else:
videopath = os.path.join(self.root, self.timestamp, f'{preset}.mp4')
depthpath = os.path.join(self.root, self.timestamp, f'depth_{preset}.mp4')
if os.path.exists(videopath) and os.path.exists(depthpath):
return videopath, depthpath
if not hasattr(self, 'scene'):
views = load_json(os.path.join('cameras', f'{preset}.json'), self.cam.H, self.cam.W)
else:
views = self.scene.getPresetCameras(preset)
framelist = []
depthlist = []
dmin, dmax = 1e8, -1e8
for view in progress.tqdm(views, desc='[4/4] Rendering a video'):
results = render(view, self.gaussians, self.opt, self.background, render_only=True)
frame, depth = results['render'], results['depth']
framelist.append(
np.round(frame.permute(1,2,0).detach().cpu().numpy().clip(0,1)*255.).astype(np.uint8))
depth = -(depth * (depth > 0)).detach().cpu().numpy()
dmin_local = depth.min().item()
dmax_local = depth.max().item()
if dmin_local < dmin:
dmin = dmin_local
if dmax_local > dmax:
dmax = dmax_local
depthlist.append(depth)
progress(1, desc='[4/4] Rendering a video...')
# depthlist = [colorize(depth, vmin=dmin, vmax=dmax) for depth in depthlist]
depthlist = [colorize(depth) for depth in depthlist]
if not os.path.exists(videopath):
imageio.mimwrite(videopath, framelist, fps=60, quality=8)
if not os.path.exists(depthpath):
imageio.mimwrite(depthpath, depthlist, fps=60, quality=8)
return videopath, depthpath
def training(self, progress=gr.Progress()):
if not self.scene:
raise('Build 3D Scene First!')
for iteration in progress.tqdm(range(1, self.opt.iterations + 1), desc='[3/4] Baking Gaussians'):
self.gaussians.update_learning_rate(iteration)
# Every 1000 its we increase the levels of SH up to a maximum degree
if iteration % 1000 == 0:
self.gaussians.oneupSHdegree()
# Pick a random Camera
viewpoint_stack = self.scene.getTrainCameras().copy()
viewpoint_cam = viewpoint_stack.pop(randint(0, len(viewpoint_stack)-1))
# import pdb; pdb.set_trace()
# Render
render_pkg = render(viewpoint_cam, self.gaussians, self.opt, self.background)
image, viewspace_point_tensor, visibility_filter, radii = (
render_pkg['render'], render_pkg['viewspace_points'], render_pkg['visibility_filter'], render_pkg['radii'])
# Loss
gt_image = viewpoint_cam.original_image.cuda()
Ll1 = l1_loss(image, gt_image)
loss = (1.0 - self.opt.lambda_dssim) * Ll1 + self.opt.lambda_dssim * (1.0 - ssim(image, gt_image))
loss.backward()
with torch.no_grad():
# Densification
if iteration < self.opt.densify_until_iter:
# Keep track of max radii in image-space for pruning
self.gaussians.max_radii2D[visibility_filter] = torch.max(
self.gaussians.max_radii2D[visibility_filter], radii[visibility_filter])
self.gaussians.add_densification_stats(viewspace_point_tensor, visibility_filter)
if iteration > self.opt.densify_from_iter and iteration % self.opt.densification_interval == 0:
size_threshold = 20 if iteration > self.opt.opacity_reset_interval else None
self.gaussians.densify_and_prune(
self.opt.densify_grad_threshold, 0.005, self.scene.cameras_extent, size_threshold)
if (iteration % self.opt.opacity_reset_interval == 0
or (self.opt.white_background and iteration == self.opt.densify_from_iter)
):
self.gaussians.reset_opacity()
# Optimizer step
if iteration < self.opt.iterations:
self.gaussians.optimizer.step()
self.gaussians.optimizer.zero_grad(set_to_none = True)
def generate_pcd(self, rgb_cond, prompt, negative_prompt, pcdgenpath, seed, diff_steps, progress=gr.Progress()):
## processing inputs
generator=torch.Generator(device='cuda').manual_seed(seed)
w_in, h_in = rgb_cond.size
if w_in/h_in > 1.1 or h_in/w_in > 1.1: # if height and width are similar, do center crop
in_res = max(w_in, h_in)
image_in, mask_in = np.zeros((in_res, in_res, 3), dtype=np.uint8), 255*np.ones((in_res, in_res, 3), dtype=np.uint8)
image_in[int(in_res/2-h_in/2):int(in_res/2+h_in/2), int(in_res/2-w_in/2):int(in_res/2+w_in/2)] = np.array(rgb_cond)
mask_in[int(in_res/2-h_in/2):int(in_res/2+h_in/2), int(in_res/2-w_in/2):int(in_res/2+w_in/2)] = 0
image2 = np.array(Image.fromarray(image_in).resize((self.cam.W, self.cam.H))).astype(float) / 255.0
mask2 = np.array(Image.fromarray(mask_in).resize((self.cam.W, self.cam.H))).astype(float) / 255.0
image_curr = self.rgb(
prompt=prompt,
image=image2,
negative_prompt=negative_prompt, generator=generator,
mask_image=mask2,
)
else: # if there is a large gap between height and width, do inpainting
if w_in > h_in:
image_curr = rgb_cond.crop((int(w_in/2-h_in/2), 0, int(w_in/2+h_in/2), h_in)).resize((self.cam.W, self.cam.H))
else: # w <= h
image_curr = rgb_cond.crop((0, int(h_in/2-w_in/2), w_in, int(h_in/2+w_in/2))).resize((self.cam.W, self.cam.H))
render_poses = get_pcdGenPoses(pcdgenpath)
depth_curr = self.d(image_curr)
center_depth = np.mean(depth_curr[h_in//2-10:h_in//2+10, w_in//2-10:w_in//2+10])
###########################################################################################################################
# Iterative scene generation
H, W, K = self.cam.H, self.cam.W, self.cam.K
x, y = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy') # pixels
edgeN = 2
edgemask = np.ones((H-2*edgeN, W-2*edgeN))
edgemask = np.pad(edgemask, ((edgeN,edgeN),(edgeN,edgeN)))
### initialize
R0, T0 = render_poses[0,:3,:3], render_poses[0,:3,3:4]
pts_coord_cam = np.matmul(np.linalg.inv(K), np.stack((x*depth_curr, y*depth_curr, 1*depth_curr), axis=0).reshape(3,-1))
new_pts_coord_world2 = (np.linalg.inv(R0).dot(pts_coord_cam) - np.linalg.inv(R0).dot(T0)).astype(np.float32) ## new_pts_coord_world2
new_pts_colors2 = (np.array(image_curr).reshape(-1,3).astype(np.float32)/255.) ## new_pts_colors2
pts_coord_world, pts_colors = new_pts_coord_world2.copy(), new_pts_colors2.copy()
progress(0, desc='[1/4] Dreaming...')
# time.sleep(0.5)
for i in progress.tqdm(range(1, len(render_poses)), desc='[1/4] Dreaming'):
R, T = render_poses[i,:3,:3], render_poses[i,:3,3:4]
### Transform world to pixel
pts_coord_cam2 = R.dot(pts_coord_world) + T ### Same with c2w*world_coord (in homogeneous space)
pixel_coord_cam2 = np.matmul(K, pts_coord_cam2) #.reshape(3,H,W).transpose(1,2,0).astype(np.float32)
valid_idx = np.where(np.logical_and.reduce((pixel_coord_cam2[2]>0,
pixel_coord_cam2[0]/pixel_coord_cam2[2]>=0,
pixel_coord_cam2[0]/pixel_coord_cam2[2]<=W-1,
pixel_coord_cam2[1]/pixel_coord_cam2[2]>=0,
pixel_coord_cam2[1]/pixel_coord_cam2[2]<=H-1)))[0]
pixel_coord_cam2 = pixel_coord_cam2[:2, valid_idx]/pixel_coord_cam2[-1:, valid_idx]
round_coord_cam2 = np.round(pixel_coord_cam2).astype(np.int32)
x, y = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
grid = np.stack((x,y), axis=-1).reshape(-1,2)
image2 = interp_grid(pixel_coord_cam2.transpose(1,0), pts_colors[valid_idx], grid, method='linear', fill_value=0).reshape(H,W,3)
image2 = edgemask[...,None]*image2 + (1-edgemask[...,None])*np.pad(image2[1:-1,1:-1], ((1,1),(1,1),(0,0)), mode='edge')
round_mask2 = np.zeros((H,W), dtype=np.float32)
round_mask2[round_coord_cam2[1], round_coord_cam2[0]] = 1
round_mask2 = maximum_filter(round_mask2, size=(9,9), axes=(0,1))
image2 = round_mask2[...,None]*image2 + (1-round_mask2[...,None])*(-1)
mask2 = minimum_filter((image2.sum(-1)!=-3)*1, size=(11,11), axes=(0,1))
image2 = mask2[...,None]*image2 + (1-mask2[...,None])*0
mask_hf = np.abs(mask2[:H-1, :W-1] - mask2[1:, :W-1]) + np.abs(mask2[:H-1, :W-1] - mask2[:H-1, 1:])
mask_hf = np.pad(mask_hf, ((0,1), (0,1)), 'edge')
mask_hf = np.where(mask_hf < 0.3, 0, 1)
border_valid_idx = np.where(mask_hf[round_coord_cam2[1], round_coord_cam2[0]] == 1)[0] # use valid_idx[border_valid_idx] for world1
image_curr = self.rgb(
prompt=prompt, image=image2, #Image.fromarray(np.round(image2*255.).astype(np.uint8)),
negative_prompt=negative_prompt, generator=generator, num_inference_steps=diff_steps,
mask_image=mask2, #Image.fromarray(np.round((1-mask2[:,:])*255.).astype(np.uint8))
)
depth_curr = self.d(image_curr)
### depth optimize
t_z2 = torch.tensor(depth_curr)
sc = torch.ones(1).float().requires_grad_(True)
optimizer = torch.optim.Adam(params=[sc], lr=0.001)
for idx in range(100):
trans3d = torch.tensor([[sc,0,0,0], [0,sc,0,0], [0,0,sc,0], [0,0,0,1]]).requires_grad_(True)
coord_cam2 = torch.matmul(torch.tensor(np.linalg.inv(K)), torch.stack((torch.tensor(x)*t_z2, torch.tensor(y)*t_z2, 1*t_z2), axis=0)[:,round_coord_cam2[1], round_coord_cam2[0]].reshape(3,-1))
coord_world2 = (torch.tensor(np.linalg.inv(R)).float().matmul(coord_cam2) - torch.tensor(np.linalg.inv(R)).float().matmul(torch.tensor(T).float()))
coord_world2_warp = torch.cat((coord_world2, torch.ones((1,valid_idx.shape[0]))), dim=0)
coord_world2_trans = torch.matmul(trans3d, coord_world2_warp)
coord_world2_trans = coord_world2_trans[:3] / coord_world2_trans[-1]
loss = torch.mean((torch.tensor(pts_coord_world[:,valid_idx]).float() - coord_world2_trans)**2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
with torch.no_grad():
coord_cam2 = torch.matmul(torch.tensor(np.linalg.inv(K)), torch.stack((torch.tensor(x)*t_z2, torch.tensor(y)*t_z2, 1*t_z2), axis=0)[:,round_coord_cam2[1, border_valid_idx], round_coord_cam2[0, border_valid_idx]].reshape(3,-1))
coord_world2 = (torch.tensor(np.linalg.inv(R)).float().matmul(coord_cam2) - torch.tensor(np.linalg.inv(R)).float().matmul(torch.tensor(T).float()))
coord_world2_warp = torch.cat((coord_world2, torch.ones((1, border_valid_idx.shape[0]))), dim=0)
coord_world2_trans = torch.matmul(trans3d, coord_world2_warp)
coord_world2_trans = coord_world2_trans[:3] / coord_world2_trans[-1]
trans3d = trans3d.detach().numpy()
pts_coord_cam2 = np.matmul(np.linalg.inv(K), np.stack((x*depth_curr, y*depth_curr, 1*depth_curr), axis=0).reshape(3,-1))[:,np.where(1-mask2.reshape(-1))[0]]
camera_origin_coord_world2 = - np.linalg.inv(R).dot(T).astype(np.float32) # 3, 1
new_pts_coord_world2 = (np.linalg.inv(R).dot(pts_coord_cam2) - np.linalg.inv(R).dot(T)).astype(np.float32)
new_pts_coord_world2_warp = np.concatenate((new_pts_coord_world2, np.ones((1, new_pts_coord_world2.shape[1]))), axis=0)
new_pts_coord_world2 = np.matmul(trans3d, new_pts_coord_world2_warp)
new_pts_coord_world2 = new_pts_coord_world2[:3] / new_pts_coord_world2[-1]
new_pts_colors2 = (np.array(image_curr).reshape(-1,3).astype(np.float32)/255.)[np.where(1-mask2.reshape(-1))[0]]
vector_camorigin_to_campixels = coord_world2_trans.detach().numpy() - camera_origin_coord_world2
vector_camorigin_to_pcdpixels = pts_coord_world[:,valid_idx[border_valid_idx]] - camera_origin_coord_world2
compensate_depth_coeff = np.sum(vector_camorigin_to_pcdpixels * vector_camorigin_to_campixels, axis=0) / np.sum(vector_camorigin_to_campixels * vector_camorigin_to_campixels, axis=0) # N_correspond
compensate_pts_coord_world2_correspond = camera_origin_coord_world2 + vector_camorigin_to_campixels * compensate_depth_coeff.reshape(1,-1)
compensate_coord_cam2_correspond = R.dot(compensate_pts_coord_world2_correspond) + T
homography_coord_cam2_correspond = R.dot(coord_world2_trans.detach().numpy()) + T
compensate_depth_correspond = compensate_coord_cam2_correspond[-1] - homography_coord_cam2_correspond[-1] # N_correspond
compensate_depth_zero = np.zeros(4)
compensate_depth = np.concatenate((compensate_depth_correspond, compensate_depth_zero), axis=0) # N_correspond+4
pixel_cam2_correspond = pixel_coord_cam2[:, border_valid_idx] # 2, N_correspond (xy)
pixel_cam2_zero = np.array([[0,0,W-1,W-1],[0,H-1,0,H-1]])
pixel_cam2 = np.concatenate((pixel_cam2_correspond, pixel_cam2_zero), axis=1).transpose(1,0) # N+H, 2
# Calculate for masked pixels
masked_pixels_xy = np.stack(np.where(1-mask2), axis=1)[:, [1,0]]
new_depth_linear, new_depth_nearest = interp_grid(pixel_cam2, compensate_depth, masked_pixels_xy), interp_grid(pixel_cam2, compensate_depth, masked_pixels_xy, method='nearest')
new_depth = np.where(np.isnan(new_depth_linear), new_depth_nearest, new_depth_linear)
pts_coord_cam2 = np.matmul(np.linalg.inv(K), np.stack((x*depth_curr, y*depth_curr, 1*depth_curr), axis=0).reshape(3,-1))[:,np.where(1-mask2.reshape(-1))[0]]
x_nonmask, y_nonmask = x.reshape(-1)[np.where(1-mask2.reshape(-1))[0]], y.reshape(-1)[np.where(1-mask2.reshape(-1))[0]]
compensate_pts_coord_cam2 = np.matmul(np.linalg.inv(K), np.stack((x_nonmask*new_depth, y_nonmask*new_depth, 1*new_depth), axis=0))
new_warp_pts_coord_cam2 = pts_coord_cam2 + compensate_pts_coord_cam2
new_pts_coord_world2 = (np.linalg.inv(R).dot(new_warp_pts_coord_cam2) - np.linalg.inv(R).dot(T)).astype(np.float32)
new_pts_coord_world2_warp = np.concatenate((new_pts_coord_world2, np.ones((1, new_pts_coord_world2.shape[1]))), axis=0)
new_pts_coord_world2 = np.matmul(trans3d, new_pts_coord_world2_warp)
new_pts_coord_world2 = new_pts_coord_world2[:3] / new_pts_coord_world2[-1]
new_pts_colors2 = (np.array(image_curr).reshape(-1,3).astype(np.float32)/255.)[np.where(1-mask2.reshape(-1))[0]]
pts_coord_world = np.concatenate((pts_coord_world, new_pts_coord_world2), axis=-1) ### Same with inv(c2w) * cam_coord (in homogeneous space)
pts_colors = np.concatenate((pts_colors, new_pts_colors2), axis=0)
#################################################################################################
yz_reverse = np.array([[1,0,0], [0,-1,0], [0,0,-1]])
traindata = {
'camera_angle_x': self.cam.fov[0],
'W': W,
'H': H,
'pcd_points': pts_coord_world,
'pcd_colors': pts_colors,
'frames': [],
}
# render_poses = get_pcdGenPoses(pcdgenpath)
internel_render_poses = get_pcdGenPoses('hemisphere', {'center_depth': center_depth})
progress(0, desc='[2/4] Aligning...')
# time.sleep(0.5)
for i in progress.tqdm(range(len(render_poses)), desc='[2/4] Aligning'):
for j in range(len(internel_render_poses)):
idx = i * len(internel_render_poses) + j
print(f'{idx+1} / {len(render_poses)*len(internel_render_poses)}')
### Transform world to pixel
Rw2i = render_poses[i,:3,:3]
Tw2i = render_poses[i,:3,3:4]
Ri2j = internel_render_poses[j,:3,:3]
Ti2j = internel_render_poses[j,:3,3:4]
Rw2j = np.matmul(Ri2j, Rw2i)
Tw2j = np.matmul(Ri2j, Tw2i) + Ti2j
# Transfrom cam2 to world + change sign of yz axis
Rj2w = np.matmul(yz_reverse, Rw2j).T
Tj2w = -np.matmul(Rj2w, np.matmul(yz_reverse, Tw2j))
Pc2w = np.concatenate((Rj2w, Tj2w), axis=1)
Pc2w = np.concatenate((Pc2w, np.array([[0,0,0,1]])), axis=0)
pts_coord_camj = Rw2j.dot(pts_coord_world) + Tw2j
pixel_coord_camj = np.matmul(K, pts_coord_camj)
valid_idxj = np.where(np.logical_and.reduce((pixel_coord_camj[2]>0,
pixel_coord_camj[0]/pixel_coord_camj[2]>=0,
pixel_coord_camj[0]/pixel_coord_camj[2]<=W-1,
pixel_coord_camj[1]/pixel_coord_camj[2]>=0,
pixel_coord_camj[1]/pixel_coord_camj[2]<=H-1)))[0]
if len(valid_idxj) == 0:
continue
pts_depthsj = pixel_coord_camj[-1:, valid_idxj]
pixel_coord_camj = pixel_coord_camj[:2, valid_idxj]/pixel_coord_camj[-1:, valid_idxj]
round_coord_camj = np.round(pixel_coord_camj).astype(np.int32)
x, y = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy') # pixels
grid = np.stack((x,y), axis=-1).reshape(-1,2)
imagej = interp_grid(pixel_coord_camj.transpose(1,0), pts_colors[valid_idxj], grid, method='linear', fill_value=0).reshape(H,W,3)
imagej = edgemask[...,None]*imagej + (1-edgemask[...,None])*np.pad(imagej[1:-1,1:-1], ((1,1),(1,1),(0,0)), mode='edge')
depthj = interp_grid(pixel_coord_camj.transpose(1,0), pts_depthsj.T, grid, method='linear', fill_value=0).reshape(H,W)
depthj = edgemask*depthj + (1-edgemask)*np.pad(depthj[1:-1,1:-1], ((1,1),(1,1)), mode='edge')
maskj = np.zeros((H,W), dtype=np.float32)
maskj[round_coord_camj[1], round_coord_camj[0]] = 1
maskj = maximum_filter(maskj, size=(9,9), axes=(0,1))
imagej = maskj[...,None]*imagej + (1-maskj[...,None])*(-1)
maskj = minimum_filter((imagej.sum(-1)!=-3)*1, size=(11,11), axes=(0,1))
imagej = maskj[...,None]*imagej + (1-maskj[...,None])*0
traindata['frames'].append({
'image': Image.fromarray(np.round(imagej*255.).astype(np.uint8)),
'transform_matrix': Pc2w.tolist(),
})
progress(1, desc='[3/4] Baking Gaussians...')
return traindata