import argparse
import imageio
import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path

import trimesh
from omegaconf import OmegaConf
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor, Callback
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning import Trainer
from skimage.io import imsave
from tqdm import tqdm
import mcubes

from renderer.renderer import NeuSRenderer, DEFAULT_SIDE_LENGTH
from util import instantiate_from_config, read_pickle


class ResumeCallBacks(Callback):
    def __init__(self):
        pass

    def on_train_start(self, trainer, pl_module):
        # rebind the wrapped optimizer's param_groups to the underlying optimizer after resuming
        pl_module.optimizers().param_groups = pl_module.optimizers()._optimizer.param_groups


def render_images(model, output):
    # render a 360-degree turntable video from the trained model
    n = 180
    azimuths = (np.arange(n) / n * np.pi * 2).astype(np.float32)
    elevations = np.deg2rad(np.asarray([30] * n).astype(np.float32))
    K, _, _, _, poses = read_pickle('meta_info/camera-16.pkl')
    h, w = 256, 256
    default_size = 256
    K = np.diag([w / default_size, h / default_size, 1.0]) @ K

    imgs = []
    for ni in tqdm(range(n)):
        # R = euler2mat(azimuths[ni], elevations[ni], 0, 'szyx')
        # R = np.asarray([[0,-1,0],[0,0,-1],[1,0,0]]) @ R
        # build the camera rotation from the elevation/azimuth of the current view
        e, a = elevations[ni], azimuths[ni]
        row1 = np.asarray([np.sin(e) * np.cos(a), np.sin(e) * np.sin(a), -np.cos(e)])
        row0 = np.asarray([-np.sin(a), np.cos(a), 0])
        row2 = np.cross(row0, row1)
        R = np.stack([row0, row1, row2], 0)
        t = np.asarray([0, 0, 1.5])
        pose = np.concatenate([R, t[:, None]], 1)
        pose_ = torch.from_numpy(pose.astype(np.float32)).unsqueeze(0)
        K_ = torch.from_numpy(K.astype(np.float32)).unsqueeze(0)  # [1,3,3]

        # one ray per pixel: unproject pixel coordinates with K, then rotate into world space
        coords = torch.stack(torch.meshgrid(torch.arange(h), torch.arange(w)), -1)[:, :, (1, 0)]  # h,w,2
        coords = coords.float()[None, :, :, :].repeat(1, 1, 1, 1)  # imn,h,w,2
        coords = coords.reshape(1, h * w, 2)
        coords = torch.cat([coords, torch.ones(1, h * w, 1, dtype=torch.float32)], 2)  # imn,h*w,3

        # imn,h*w,3 @ imn,3,3 => imn,h*w,3
        rays_d = coords @ torch.inverse(K_).permute(0, 2, 1)
        R, t = pose_[:, :, :3], pose_[:, :, 3:]
        rays_d = rays_d @ R
        rays_d = F.normalize(rays_d, dim=-1)
        rays_o = -R.permute(0, 2, 1) @ t  # imn,3,3 @ imn,3,1
        rays_o = rays_o.permute(0, 2, 1).repeat(1, h * w, 1)  # imn,h*w,3

        ray_batch = {
            'rays_o': rays_o.reshape(-1, 3).cuda(),
            'rays_d': rays_d.reshape(-1, 3).cuda(),
        }
        with torch.no_grad():
            image = model.renderer.render(ray_batch, False, 5000)['rgb'].reshape(h, w, 3)
        image = (image.cpu().numpy() * 255).astype(np.uint8)
        imgs.append(image)

    imageio.mimsave(f'{output}/rendering.mp4', imgs, fps=30)


def extract_fields(bound_min, bound_max, resolution, query_func, batch_size=64, outside_val=1.0):
    # evaluate query_func (the SDF) on a dense grid, in chunks of batch_size samples per axis
    N = batch_size
    X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N)
    Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N)
    Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N)

    u = np.zeros([resolution, resolution, resolution], dtype=np.float32)
    with torch.no_grad():
        for xi, xs in enumerate(X):
            for yi, ys in enumerate(Y):
                for zi, zs in enumerate(Z):
                    xx, yy, zz = torch.meshgrid(xs, ys, zs)
                    pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1).cuda()
                    val = query_func(pts).detach()
                    # points outside the unit sphere are clamped to a constant "outside" value
                    outside_mask = torch.norm(pts, dim=-1) >= 1.0
                    val[outside_mask] = outside_val
                    val = val.reshape(len(xs), len(ys), len(zs)).cpu().numpy()
                    u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val
    return u


def extract_geometry(bound_min, bound_max, resolution, threshold, query_func, color_func, outside_val=1.0):
    # run marching cubes on the sampled SDF grid and map vertices back to world coordinates
    u = extract_fields(bound_min, bound_max, resolution, query_func, outside_val=outside_val)
    vertices, triangles = mcubes.marching_cubes(u, threshold)
    b_max_np = bound_max.detach().cpu().numpy()
    b_min_np = bound_min.detach().cpu().numpy()
    vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :]
    vertex_colors = color_func(vertices)
    return vertices, triangles, vertex_colors
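
# A minimal sketch of how extract_geometry can be exercised in isolation, using an analytic
# sphere SDF in place of the trained SDF network. The helper name `_demo_extract_sphere` and
# its parameters are illustrative only and not part of the training pipeline; it assumes a
# CUDA device is available, since extract_fields moves query points to the GPU.
def _demo_extract_sphere(resolution=64, radius=0.5):
    sphere_sdf = lambda pts: torch.norm(pts, dim=-1) - radius               # signed distance to a sphere
    gray_colors = lambda verts: np.full_like(verts, 0.5, dtype=np.float32)  # constant mid-gray vertex colors
    bound_min, bound_max = -torch.ones(3), torch.ones(3)
    vertices, triangles, vertex_colors = extract_geometry(
        bound_min, bound_max, resolution, 0.0, sphere_sdf, gray_colors)
    return trimesh.Trimesh(vertices, triangles, vertex_colors=vertex_colors)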
def extract_mesh(model, output, resolution=512):
    if not isinstance(model.renderer, NeuSRenderer):
        return
    bbox_min = -torch.ones(3) * DEFAULT_SIDE_LENGTH
    bbox_max = torch.ones(3) * DEFAULT_SIDE_LENGTH
    with torch.no_grad():
        vertices, triangles, vertex_colors = extract_geometry(
            bbox_min, bbox_max, resolution, 0,
            lambda x: model.renderer.sdf_network.sdf(x),
            lambda x: model.renderer.get_vertex_colors(x))

    # output geometry
    mesh = trimesh.Trimesh(vertices, triangles, vertex_colors=vertex_colors)
    mesh.export(f'{output}/mesh.ply')


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--image_path', type=str, required=True)
    parser.add_argument('-n', '--name', type=str, required=True)
    parser.add_argument('-b', '--base', type=str, default='configs/neus.yaml')
    parser.add_argument('-d', '--data_path', type=str, default='/data/GSO/')
    parser.add_argument('-l', '--log', type=str, default='output/renderer')
    parser.add_argument('-s', '--seed', type=int, default=6033)
    parser.add_argument('-g', '--gpus', type=str, default='0,')
    parser.add_argument('-r', '--resume', action='store_true', default=False, dest='resume')
    parser.add_argument('--fp16', action='store_true', default=False, dest='fp16')
    opt = parser.parse_args()
    # seed_everything(opt.seed)

    # configs
    cfg = OmegaConf.load(opt.base)
    name = opt.name
    log_dir, ckpt_dir = Path(opt.log) / name, Path(opt.log) / name / 'ckpt'
    cfg.model.params['image_path'] = opt.image_path
    cfg.model.params['log_dir'] = log_dir
    cfg.model.params['data_path'] = opt.data_path

    # setup
    log_dir.mkdir(exist_ok=True, parents=True)
    ckpt_dir.mkdir(exist_ok=True, parents=True)
    trainer_config = cfg.trainer
    callback_config = cfg.callbacks
    model_config = cfg.model
    data_config = cfg.data

    data_config.params.seed = opt.seed
    data = instantiate_from_config(data_config)
    data.prepare_data()
    data.setup('fit')

    model = instantiate_from_config(model_config)
    model.cpu()
    model.learning_rate = model_config.base_lr

    # logger
    logger = TensorBoardLogger(save_dir=log_dir, name='tensorboard_logs')
    callbacks = []
    callbacks.append(LearningRateMonitor(logging_interval='step'))
    callbacks.append(ModelCheckpoint(dirpath=ckpt_dir, filename="{epoch:06}", verbose=True, save_last=True,
                                     every_n_train_steps=callback_config.save_interval))

    # trainer
    trainer_config.update({
        "accelerator": "cuda",
        "check_val_every_n_epoch": None,
        "benchmark": True,
        "num_sanity_val_steps": 0,
        "devices": 1,
        "gpus": opt.gpus,
    })
    if opt.fp16:
        trainer_config['precision'] = 16

    if opt.resume:
        callbacks.append(ResumeCallBacks())
        trainer_config['resume_from_checkpoint'] = str(ckpt_dir / 'last.ckpt')
    else:
        if (ckpt_dir / 'last.ckpt').exists():
            raise RuntimeError(f"checkpoint {ckpt_dir / 'last.ckpt'} already exists ...")
    trainer = Trainer.from_argparse_args(args=argparse.Namespace(), **trainer_config, logger=logger, callbacks=callbacks)

    trainer.fit(model, data)

    model = model.cuda().eval()
    # render_images(model, log_dir)
    extract_mesh(model, log_dir)


if __name__ == "__main__":
    main()
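
# Example invocation (the script name and image path below are illustrative; substitute your own inputs):
#   python train_renderer.py -i output/sample/object.png -n object
# With the default -l, checkpoints are written to output/renderer/object/ckpt/ and, once training
# finishes, the extracted mesh is exported to output/renderer/object/mesh.ply.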