import spaces

import os
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from functools import partial
import gradio as gr
from huggingface_hub import hf_hub_download

import torch
from torchvision import transforms
import rembg
import cv2
from pytorch_lightning import seed_everything

from src.visualizer import CameraVisualizer
from src.pose_estimation import load_model_from_config, estimate_poses, estimate_elevs
from src.pose_funcs import find_optimal_poses
from src.utils import spherical_to_cartesian, elu_to_c2w


# Use the first CUDA device when one is available, otherwise run on CPU.
if torch.cuda.is_available():
    _device_ = 'cuda:0'
else:
    _device_ = 'cpu'

_config_path_ = 'src/configs/sd-objaverse-finetune-c_concat-256.yaml'
# Checkpoints are fetched from the Hugging Face Hub at import time.
_ckpt_path_ = hf_hub_download(repo_id='tokenid/ID-Pose', filename='ckpts/zero123-xl.ckpt', repo_type='model')
_matcher_ckpt_path_ = hf_hub_download(repo_id='tokenid/ID-Pose', filename='ckpts/indoor_ds_new.ckpt', repo_type='model')

_config_ = OmegaConf.load(_config_path_)
# Load weights on CPU first, then move the whole model to the selected device.
_model_ = load_model_from_config(_config_, _ckpt_path_, device='cpu')
_model_ = _model_.to(_device_)
_model_.eval()


def rgba_to_rgb(img):
    """Composite an RGBA PIL image over a white background and return it as RGB.

    Args:
        img: PIL image whose mode must be 'RGBA'.

    Returns:
        A PIL RGB image of the same size.
    """
    assert img.mode == 'RGBA'
    img = np.asarray(img, dtype=np.float32)
    # Alpha-blend: rgb * a + white * (1 - a), with a in [0, 1].
    img[:, :, :3] = img[:, :, :3] * (img[..., 3:] / 255.) + (255 - img[..., 3:])
    img = img.clip(0, 255).astype(np.uint8)
    return Image.fromarray(img[:, :, :3])


def remove_background(image, rembg_session = None, force = False, **rembg_kwargs):
    """Run rembg background removal on a PIL image, unless it already has transparency.

    Args:
        image: PIL image.
        rembg_session: optional session forwarded to ``rembg.remove``.
        force: when True, always run rembg even if the image is already
            RGBA with some transparent pixels.
        **rembg_kwargs: extra keyword arguments forwarded to ``rembg.remove``.

    Returns:
        The (possibly) background-removed image as returned by ``rembg.remove``,
        or the input image unchanged when removal is skipped.
    """
    do_remove = True
    # An RGBA image whose minimum alpha is below 255 is treated as already masked.
    if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
        do_remove = False
    do_remove = do_remove or force
    if do_remove:
        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image


def _alpha_bbox(rgba, mask_thres):
    """Return (y0, y1, x0, x1), the inclusive bbox of pixels with alpha > mask_thres.

    Raises:
        ValueError: if no pixel's alpha exceeds ``mask_thres``.
    """
    yy, xx = np.where(rgba[:, :, 3] > mask_thres)
    if yy.size == 0:
        raise ValueError('No foreground pixels found in the alpha channel; cannot recenter the object.')
    return yy.min(), yy.max(), xx.min(), xx.max()


def group_recenter(images, ratio=1.5, mask_thres=127, bkg_color=(255, 255, 255, 255)):
    """Crop every image around its object using one shared relative crop size.

    The crop size is the largest object extent over all images (as a fraction
    of each image's own height/width), scaled by ``ratio``. Each image is
    cropped around its object's bbox centre with zero-padding where the crop
    leaves the image. It is then alpha-composited over ``bkg_color`` and
    resized to 256x256.

    Args:
        images: sequence of RGBA images (PIL or arrays with 4 channels).
        ratio: margin factor applied to the largest object extent.
        mask_thres: alpha value above which a pixel counts as foreground.
        bkg_color: RGBA background colour; RGB is used for compositing and
            the last entry fills the alpha channel before it is dropped.

    Returns:
        List of uint8 arrays of shape (256, 256, 3), one per input image.

    Raises:
        ValueError: if an image has no pixel with alpha above ``mask_thres``.
    """
    images = [np.asarray(img) for img in images]

    # Relative object extent: width against image width, height against image height.
    ws = []
    hs = []
    for img in images:
        y0, y1, x0, x1 = _alpha_bbox(img, mask_thres)
        ws.append(float(x1 - x0) / img.shape[1])
        hs.append(float(y1 - y0) / img.shape[0])

    sz_w = np.max(ws)
    sz_h = np.max(hs)
    sz = max(ratio * sz_w, ratio * sz_h)

    bkg = np.array(bkg_color)
    out_rgbs = []
    for rgba in images:
        height, width = rgba.shape[:2]
        y0, y1, x0, x1 = _alpha_bbox(rgba, mask_thres)

        cy = (y0 + y1) // 2
        cx = (x0 + x1) // 2

        # Crop window centred on the object; may extend past the image border.
        y0 = cy - int(np.floor(sz * height / 2))
        y1 = cy + int(np.ceil(sz * height / 2))
        x0 = cx - int(np.floor(sz * width / 2))
        x1 = cx + int(np.ceil(sz * width / 2))

        out = rgba[max(y0, 0):min(y1, height), max(x0, 0):min(x1, width), :].copy()
        # Zero-pad (alpha = 0) whatever part of the window fell outside the image.
        pads = [(max(0 - y0, 0), max(y1 - height, 0)), (max(0 - x0, 0), max(x1 - width, 0)), (0, 0)]
        out = np.pad(out, pads, mode='constant', constant_values=0)

        alpha = out[..., 3:] / 255.
        out[:, :, :3] = out[:, :, :3] * alpha + bkg[None, None, :3] * (1 - alpha)
        out[:, :, -1] = bkg_color[-1]

        out = cv2.resize(out.astype(np.uint8), (256, 256))
        out_rgbs.append(out[:, :, :3])

    return out_rgbs


def run_preprocess(image1, image2, preprocess_chk, seed_value):
    """Optionally remove backgrounds, then jointly recenter both images.

    Args:
        image1, image2: RGBA PIL images from the input widgets.
        preprocess_chk: when True, run rembg (forced) on both images.
        seed_value: seed passed to ``seed_everything``.

    Returns:
        Tuple of two 256x256 RGB PIL images.
    """
    seed_everything(seed_value)

    if preprocess_chk:
        # Share one rembg session between both images.
        rembg_session = rembg.new_session()
        image1 = remove_background(image1, force=True, rembg_session = rembg_session)
        image2 = remove_background(image2, force=True, rembg_session = rembg_session)

    # NOTE(review): recentering runs regardless of preprocess_chk, so the checkbox
    # only toggles background removal here — confirm this matches the UI intent.
    rgbs = group_recenter([image1, image2])

    image1 = Image.fromarray(rgbs[0])
    image2 = Image.fromarray(rgbs[1])

    return image1, image2


def image_to_tensor(img, width=256, height=256):
    """Convert an image to a (1, C, height, width) float tensor scaled to [-1, 1]."""
    img = transforms.ToTensor()(img).unsqueeze(0)
    # ToTensor yields [0, 1]; the model consumes [-1, 1].
    img = img * 2 - 1
    img = transforms.functional.resize(img, [height, width])
    return img


@spaces.GPU(duration=110)
def run_pose_exploration(image1, image2, probe_bsz, adj_bsz, adj_iters, seed_value):
    """Estimate the anchor elevation and the relative pose of image2 w.r.t. image1.

    Args:
        image1, image2: preprocessed images (anything ``image_to_tensor`` accepts).
        probe_bsz: probe batch size; also the leading dim of the sampled noise.
        adj_bsz: adjustment batch size forwarded to ``estimate_poses``.
        adj_iters: adjustment iterations forwarded to ``estimate_poses``.
        seed_value: seed passed to ``seed_everything``.

    Returns:
        ``(anchor_polar, (theta, azimuth, radius))`` as Python floats. The
        anchor defaults to ``np.pi/2`` when no elevation estimate is returned.
    """
    seed_everything(seed_value)

    image1 = image_to_tensor(image1).to(_device_)
    image2 = image_to_tensor(image2).to(_device_)

    images = [image1, image2]

    elevs, elev_ranges = estimate_elevs(
        _model_, images,
        est_type='all',
        matcher_ckpt_path=_matcher_ckpt_path_
    )

    anchor_polar = elevs[0]

    # Near-identical inputs: skip the search and report a zero relative pose.
    if torch.mean(torch.abs(image1 - image2)) < 0.005:
        theta = azimuth = radius = 0
        print('Identical images found!')
    else:
        noise = np.random.randn(probe_bsz, 4, 32, 32)

        result_poses, _aux_data = estimate_poses(
            _model_, images,
            seed_cand_num=8,
            explore_type='triangular',
            refine_type='triangular',
            probe_ts_range=[0.2, 0.21],
            ts_range=[0.2, 0.21],
            probe_bsz=probe_bsz,
            adjust_factor=10.0,
            adjust_iters=adj_iters,
            adjust_bsz=adj_bsz,
            refine_factor=1.0,
            refine_iters=0,
            refine_bsz=4,
            noise=noise,
            elevs=elevs,
            elev_ranges=elev_ranges
        )

        theta, azimuth, radius = result_poses[0]

    if anchor_polar is None:
        anchor_polar = np.pi / 2

    explored_sph = (float(theta), float(azimuth), float(radius))

    return float(anchor_polar), explored_sph


def _render_pose_figure(anchor_polar, rel_sph, image1, image2):
    """Build the camera-pose figure for the anchor view and the relative target view.

    The anchor camera sits at (anchor_polar, 0, 4) in spherical coordinates.
    The target adds ``rel_sph = (d_theta, d_azimuth, d_radius)`` to it. Both
    cameras look at the origin with up vector (0, 0, 1).

    Returns:
        The figure produced by ``CameraVisualizer.update_figure``.
    """
    xyz0 = spherical_to_cartesian((anchor_polar, 0., 4.))
    c2w0 = elu_to_c2w(xyz0, np.zeros(3), np.array([0., 0., 1.]))

    xyz1 = spherical_to_cartesian((rel_sph[0] + anchor_polar, 0. + rel_sph[1], 4. + rel_sph[2]))
    c2w1 = elu_to_c2w(xyz1, np.zeros(3), np.array([0., 0., 1.]))

    cam_vis = CameraVisualizer([c2w0, c2w1], ['Image 1', 'Image 2'], ['red', 'blue'], images=[np.asarray(image1, dtype=np.uint8), np.asarray(image2, dtype=np.uint8)])
    return cam_vis.update_figure(5, base_radius=-1.2, font_size=16, show_background=True, show_grid=True, show_ticklabels=True)


@spaces.GPU(duration=110)
def run_pose_refinement(image1, image2, est_result, refine_iters, seed_value):
    """Refine a previously explored relative pose and redraw the cameras.

    Args:
        image1, image2: preprocessed images.
        est_result: ``(anchor_polar, explored_sph)`` from a previous estimate.
        refine_iters: number of iterations passed to ``find_optimal_poses``.
        seed_value: seed passed to ``seed_everything``.

    Returns:
        ``((anchor_polar, final_sph), figure)``.
    """
    seed_everything(seed_value)

    anchor_polar = est_result[0]
    explored_sph = est_result[1]

    images = [image_to_tensor(image1).to(_device_), image_to_tensor(image2).to(_device_)]
    # image_to_tensor gives (N, C, H, W); find_optimal_poses is given (N, H, W, C) here.
    images = [img.permute(0, 2, 3, 1) for img in images]

    out_poses, _, _loss = find_optimal_poses(
        _model_, images,
        1.0,
        bsz=1,
        n_iter=refine_iters,
        init_poses={1: explored_sph},
        ts_range=[0.2, 0.21],
        combinations=[(0, 1), (1, 0)],
        avg_last_n=20,
        print_n=100
    )

    final_sph = out_poses[0]
    fig = _render_pose_figure(anchor_polar, final_sph, image1, image2)

    return (anchor_polar, final_sph), fig


def run_example(image1, image2):
    """Preprocess and explore an example pair with fixed default settings (seed 0)."""
    image1, image2 = run_preprocess(image1, image2, True, 0)
    anchor_polar, explored_sph = run_pose_exploration(image1, image2, 16, 4, 10, 0)

    return (anchor_polar, explored_sph), image1, image2


def run_or_visualize(image1, image2, probe_bsz, adj_bsz, adj_iters, seed_value, est_result):
    """Run pose exploration unless a cached result exists, then draw the cameras.

    Returns:
        ``((anchor_polar, explored_sph), figure, update)``, where ``update``
        makes the Refine button interactive.
    """
    if est_result is None:
        anchor_polar, explored_sph = run_pose_exploration(image1, image2, probe_bsz, adj_bsz, adj_iters, seed_value)
    else:
        anchor_polar = est_result[0]
        explored_sph = est_result[1]
        print('Using cache result.')

    fig = _render_pose_figure(anchor_polar, explored_sph, image1, image2)

    return (anchor_polar, explored_sph), fig, gr.update(interactive=True)


_HEADER_ = '''
# Official 🤗 Gradio Demo for [ID-Pose: Sparse-view Camera Pose Estimation By Inverting Diffusion Models](https://github.com/xt4d/id-pose)
- ID-Pose accepts input images with NO overlapping appearance.
- The estimation takes about 1 minute. ZeroGPU may be halted during processing due to quota restrictions.
'''

_FOOTER_ = '''
[Project Page](https://xt4d.github.io/id-pose-web/) | ⭐ [Github](https://github.com/xt4d/id-pose) ⭐ [![GitHub Stars](https://img.shields.io/github/stars/xt4d/id-pose?style=social)](https://github.com/xt4d/id-pose)
---
'''

_CITE_ = r"""
```bibtex
@article{cheng2023id,
  title={ID-Pose: Sparse-view Camera Pose Estimation by Inverting Diffusion Models},
  author={Cheng, Weihao and Cao, Yan-Pei and Shan, Ying},
  journal={arXiv preprint arXiv:2306.17140},
  year={2023}
}
```
"""


def run_demo():
    """Build the Gradio Blocks UI, wire the event handlers and launch the app."""
    demo = gr.Blocks(title='ID-Pose: Sparse-view Camera Pose Estimation By Inverting Diffusion Models')

    with demo:
        # Hidden cache of the latest (anchor_polar, relative pose) estimate.
        est_result = gr.JSON(visible=False)

        gr.Markdown(_HEADER_)
        with gr.Row(variant='panel'):
            with gr.Column(scale=1):
                with gr.Row():
                    input_image1 = gr.Image(type='pil', image_mode='RGBA', label='Input Image 1')
                    input_image2 = gr.Image(type='pil', image_mode='RGBA', label='Input Image 2')

                with gr.Row():
                    processed_image1 = gr.Image(type='numpy', image_mode='RGB', label='Processed Image 1', interactive=False)
                    processed_image2 = gr.Image(type='numpy', image_mode='RGB', label='Processed Image 2', interactive=False)

                with gr.Row():
                    preprocess_chk = gr.Checkbox(True, label='Remove background and recenter object')

                with gr.Accordion('Advanced options', open=False):
                    probe_bsz = gr.Slider(4, 32, value=16, step=4, label='Probe Batch Size')
                    adj_bsz = gr.Slider(1, 8, value=4, step=1, label='Adjust Batch Size')
                    adj_iters = gr.Slider(1, 20, value=10, step=1, label='Adjust Iterations')
                    seed_value = gr.Number(value=0, label="Seed Value", precision=0)

                with gr.Row():
                    run_btn = gr.Button('Estimate', variant='primary', interactive=True)

                with gr.Row():
                    refine_iters = gr.Slider(0, 1000, value=0, step=50, label='Refinement Iterations')
                with gr.Row():
                    # Enabled by run_or_visualize once an estimate exists.
                    refine_btn = gr.Button('Refine', variant='primary', interactive=False)

                with gr.Row():
                    gr.Markdown(_FOOTER_)

                with gr.Row():
                    gr.Markdown(_CITE_)

            with gr.Column(scale=1.4):
                with gr.Row():
                    vis_output = gr.Plot(label='Camera Pose Results: anchor (red) and target (blue)')

                with gr.Row():
                    with gr.Column(min_width=200):
                        gr.Examples(
                            examples = [
                                ['data/gradio_demo/duck_0.png', 'data/gradio_demo/duck_1.png'],
                                ['data/gradio_demo/chair_0.png', 'data/gradio_demo/chair_1.png'],
                                ['data/gradio_demo/foosball_0.png', 'data/gradio_demo/foosball_1.png'],
                                ['data/gradio_demo/bunny_0.png', 'data/gradio_demo/bunny_1.png'],
                                ['data/gradio_demo/circo_0.png', 'data/gradio_demo/circo_1.png'],
                            ],
                            inputs=[input_image1, input_image2],
                            fn=run_example,
                            outputs=[est_result, processed_image1, processed_image2],
                            label='Examples (Captured)',
                            cache_examples='lazy',
                            examples_per_page=5
                        )

                    with gr.Column(min_width=200):
                        gr.Examples(
                            examples = [
                                ['data/gradio_demo/arc_0.png', 'data/gradio_demo/arc_1.png'],
                                ['data/gradio_demo/husky_0.png', 'data/gradio_demo/husky_1.png'],
                                ['data/gradio_demo/cybertruck_0.png', 'data/gradio_demo/cybertruck_1.png'],
                                ['data/gradio_demo/plane_0.png', 'data/gradio_demo/plane_1.png'],
                                ['data/gradio_demo/christ_0.png', 'data/gradio_demo/christ_1.png'],
                            ],
                            inputs=[input_image1, input_image2],
                            fn=run_example,
                            outputs=[est_result, processed_image1, processed_image2],
                            label='Examples (Internet)',
                            cache_examples='lazy',
                            examples_per_page=5
                        )

                    with gr.Column(min_width=200):
                        gr.Examples(
                            examples = [
                                ['data/gradio_demo/status_0.png', 'data/gradio_demo/status_1.png'],
                                ['data/gradio_demo/cat_0.png', 'data/gradio_demo/cat_1.png'],
                                ['data/gradio_demo/ferrari_0.png', 'data/gradio_demo/ferrari_1.png'],
                                ['data/gradio_demo/elon_0.png', 'data/gradio_demo/elon_1.png'],
                                ['data/gradio_demo/ride_horse_0.png', 'data/gradio_demo/ride_horse_1.png'],
                            ],
                            inputs=[input_image1, input_image2],
                            fn=run_example,
                            outputs=[est_result, processed_image1, processed_image2],
                            label='Examples (Generated)',
                            cache_examples='lazy',
                            examples_per_page=5
                        )

        # Preprocess first; only if that succeeds, estimate (or reuse the cache) and draw.
        run_btn.click(
            fn=run_preprocess,
            inputs=[input_image1, input_image2, preprocess_chk, seed_value],
            outputs=[processed_image1, processed_image2],
        ).success(
            fn=run_or_visualize,
            inputs=[processed_image1, processed_image2, probe_bsz, adj_bsz, adj_iters, seed_value, est_result],
            outputs=[est_result, vis_output, refine_btn]
        )

        refine_btn.click(
            fn=run_pose_refinement,
            inputs=[processed_image1, processed_image2, est_result, refine_iters, seed_value],
            outputs=[est_result, vis_output]
        )

        # Clearing either input invalidates the cached estimate.
        input_image1.clear(
            fn=lambda: None,
            outputs=[est_result]
        )

        input_image2.clear(
            fn=lambda: None,
            outputs=[est_result]
        )

    demo.launch()


if __name__ == '__main__':
    run_demo()