# ID-Pose / app.py — Hugging Face Space by tokenid
# Commit fc6f56d: "add random seed" (13.7 kB)
import spaces
import os
import numpy as np
from PIL import Image
from omegaconf import OmegaConf
from functools import partial
import gradio as gr
from huggingface_hub import hf_hub_download
import torch
from torchvision import transforms
import rembg
import cv2
from pytorch_lightning import seed_everything
from src.visualizer import CameraVisualizer
from src.pose_estimation import load_model_from_config, estimate_poses, estimate_elevs
from src.pose_funcs import find_optimal_poses
from src.utils import spherical_to_cartesian, elu_to_c2w
# Inference device: the first CUDA GPU when one is visible, otherwise the CPU.
_device_ = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Model configuration shipped with the repo.
_config_path_ = 'src/configs/sd-objaverse-finetune-c_concat-256.yaml'

# Checkpoints fetched from the tokenid/ID-Pose model repo on the Hugging Face Hub:
# the zero123-xl diffusion weights, and a second checkpoint that is later handed to
# estimate_elevs as `matcher_ckpt_path`.
_ckpt_path_ = hf_hub_download(
    repo_id='tokenid/ID-Pose',
    filename='ckpts/zero123-xl.ckpt',
    repo_type='model',
)
_matcher_ckpt_path_ = hf_hub_download(
    repo_id='tokenid/ID-Pose',
    filename='ckpts/indoor_ds_new.ckpt',
    repo_type='model',
)

# Build the model on the CPU first, then move it to the inference device and
# switch it to evaluation mode.
_config_ = OmegaConf.load(_config_path_)
_model_ = load_model_from_config(_config_, _ckpt_path_, device='cpu').to(_device_)
_model_.eval()
def rgba_to_rgb(img):
    """Composite an RGBA PIL image over a white background and return it as RGB.

    Each colour channel becomes ``c * a/255 + (255 - a)``, i.e. the pixel is
    alpha-blended with white. Asserts that the input mode is 'RGBA'.
    """
    assert img.mode == 'RGBA'
    pixels = np.asarray(img, dtype=np.float32)
    colour, alpha = pixels[..., :3], pixels[..., 3:]
    # (255 - alpha) is 255 * (1 - alpha/255): the white background's contribution.
    blended = colour * (alpha / 255.) + (255 - alpha)
    return Image.fromarray(blended.clip(0, 255).astype(np.uint8))
def remove_background(image, rembg_session=None, force=False, **rembg_kwargs):
    """Strip the background of ``image`` with rembg, unless it is already transparent.

    An RGBA image whose smallest alpha value is below 255 is treated as already
    segmented and returned unchanged. Pass ``force=True`` to run rembg anyway.
    ``rembg_session`` and any extra keyword arguments are forwarded to
    ``rembg.remove``.
    """
    has_transparency = image.mode == "RGBA" and image.getextrema()[3][0] < 255
    if force or not has_transparency:
        image = rembg.remove(image, session=rembg_session, **rembg_kwargs)
    return image
def _alpha_bbox(alpha, mask_thres):
    """Return inclusive ``(y0, y1, x0, x1)`` bounds of pixels with ``alpha > mask_thres``.

    Raises:
        ValueError: if no pixel passes the threshold (nothing to recenter on).
    """
    yy, xx = np.where(alpha > mask_thres)
    if yy.size == 0:
        raise ValueError(
            f'No foreground pixels found (alpha > {mask_thres}); cannot recenter the object.'
        )
    return yy.min(), yy.max(), xx.min(), xx.max()


def group_recenter(images, ratio=1.5, mask_thres=127, bkg_color=(255, 255, 255, 255), out_size=256):
    """Crop all RGBA images to one common square around their foreground.

    Each image's foreground is the set of pixels whose alpha exceeds
    ``mask_thres``. Every crop is a square of side ``ratio`` times the largest
    foreground width/height over *all* images, so the objects keep the same
    relative scale across views. The square is centred on each image's own
    foreground box. Areas outside the image are padded with fully transparent
    pixels. The crop is alpha-blended over ``bkg_color`` and resized to
    ``out_size`` x ``out_size``.

    Args:
        images: sequence of H x W x 4 images (PIL RGBA images or uint8 arrays).
        ratio: crop side relative to the largest foreground extent.
        mask_thres: alpha threshold that defines foreground pixels.
        bkg_color: RGBA background colour. Its RGB part is blended in and its
            last entry overwrites the alpha channel before resizing.
        out_size: side length of the returned images.

    Returns:
        list of ``out_size`` x ``out_size`` x 3 uint8 RGB arrays, in input order.

    Raises:
        ValueError: if an image has no pixel above ``mask_thres``.
    """
    if not images:
        return []
    images = [np.asarray(img) for img in images]
    bboxes = [_alpha_bbox(img[:, :, 3], mask_thres) for img in images]

    sz_w = max(x1 - x0 for _, _, x0, x1 in bboxes)
    sz_h = max(y1 - y0 for y0, y1, _, _ in bboxes)
    # At least 1 px, so a single-pixel object does not give an empty crop.
    sz = max(int(max(ratio * sz_w, ratio * sz_h)), 1)
    half_lo, half_hi = sz // 2, sz - sz // 2  # floor(sz/2), ceil(sz/2)

    bkg = np.asarray(bkg_color)
    out_rgbs = []
    for rgba, (y0, y1, x0, x1) in zip(images, bboxes):
        height, width = rgba.shape[:2]
        cy, cx = (y0 + y1) // 2, (x0 + x1) // 2
        top, bottom = cy - half_lo, cy + half_hi
        left, right = cx - half_lo, cx + half_hi

        # Crop the in-bounds part, then zero-pad (transparent) up to sz x sz.
        out = rgba[max(top, 0):min(bottom, height), max(left, 0):min(right, width), :].copy()
        pads = [
            (max(-top, 0), max(bottom - height, 0)),
            (max(-left, 0), max(right - width, 0)),
            (0, 0),
        ]
        out = np.pad(out, pads, mode='constant', constant_values=0)
        assert out.shape[:2] == (sz, sz)

        # Alpha-blend over the background colour (the result is cast back into out's dtype).
        alpha = out[..., 3:] / 255.
        out[:, :, :3] = out[:, :, :3] * alpha + bkg[None, None, :3] * (1 - alpha)
        out[:, :, -1] = bkg_color[-1]
        out = cv2.resize(out.astype(np.uint8), (out_size, out_size))
        out_rgbs.append(out[:, :, :3])
    return out_rgbs
# Lazily created rembg session, reused across requests so the background
# removal model is not rebuilt on every click of "Estimate".
_rembg_session_ = None


def _get_rembg_session():
    """Return the shared rembg session, creating it on first use."""
    global _rembg_session_
    if _rembg_session_ is None:
        _rembg_session_ = rembg.new_session()
    return _rembg_session_


def run_preprocess(image1, image2, preprocess_chk, seed_value):
    """Prepare the two input images for pose estimation.

    Seeds all RNGs with ``seed_value``. If ``preprocess_chk`` is set, the
    backgrounds are removed with rembg (forced, even if the images already have
    transparency). Both images are then cropped to a common square around the
    object and resized by ``group_recenter``.

    Returns:
        (image1, image2) as PIL RGB images.

    Raises:
        gr.Error: if either input image is missing, so the UI shows a message
            instead of a numpy traceback.
    """
    if image1 is None or image2 is None:
        raise gr.Error('Please provide both input images.')
    seed_everything(seed_value)
    if preprocess_chk:
        rembg_session = _get_rembg_session()
        image1 = remove_background(image1, force=True, rembg_session=rembg_session)
        image2 = remove_background(image2, force=True, rembg_session=rembg_session)
    rgbs = group_recenter([image1, image2])
    return Image.fromarray(rgbs[0]), Image.fromarray(rgbs[1])
def image_to_tensor(img, width=256, height=256):
    """Turn an image into a batched tensor in the [-1, 1] range.

    ``transforms.ToTensor`` yields a C x H x W tensor (scaled to [0, 1] for
    uint8/PIL input). A leading batch axis is added, the values are mapped to
    [-1, 1], and the result is resized to ``height`` x ``width``.
    """
    batched = transforms.ToTensor()(img)[None]
    signed = batched * 2 - 1
    return transforms.functional.resize(signed, [height, width])
@spaces.GPU
def run_pose_exploration_a(image1, image2, seed_value):
    """First estimation stage: estimate elevations for both processed images.

    Seeds all RNGs, converts both images to tensors on the inference device and
    calls ``estimate_elevs`` with ``est_type='all'`` and the matcher checkpoint.

    Returns:
        (elevs, elev_ranges, None). The trailing None is what the caller sends
        to its figure output, since no poses exist yet at this stage.
    """
    seed_everything(seed_value)
    tensors = [image_to_tensor(img).to(_device_) for img in (image1, image2)]
    elevs, elev_ranges = estimate_elevs(
        _model_,
        tensors,
        est_type='all',
        matcher_ckpt_path=_matcher_ckpt_path_,
    )
    return elevs, elev_ranges, None
def _render_pose_figure(cam_vis, anchor_polar, sph):
    """Place the anchor and target cameras in ``cam_vis`` and return its redrawn figure.

    The anchor camera sits at spherical coordinates ``(anchor_polar, 0, 4)``.
    The target camera sits at the anchor plus the relative offset
    ``sph = (theta, azimuth, radius)``. Both camera-to-world matrices come from
    ``elu_to_c2w`` with the origin and the +z axis (presumably the look-at point
    and up vector; confirm against src.utils).
    """
    theta, azimuth, radius = sph
    xyz0 = spherical_to_cartesian((anchor_polar, 0., 4.))
    c2w0 = elu_to_c2w(xyz0, np.zeros(3), np.array([0., 0., 1.]))
    xyz1 = spherical_to_cartesian((theta + anchor_polar, 0. + azimuth, 4. + radius))
    c2w1 = elu_to_c2w(xyz1, np.zeros(3), np.array([0., 0., 1.]))
    cam_vis._poses = [c2w0, c2w1]
    return cam_vis.update_figure(
        5, base_radius=-1.2, font_size=16,
        show_background=True, show_grid=True, show_ticklabels=True,
    )


@spaces.GPU
def run_pose_exploration_b(cam_vis, image1, image2, elevs, elev_ranges, probe_bsz, adj_bsz, adj_iters, seed_value):
    """Second estimation stage: search for the relative pose of image 2 and plot it.

    Runs ``estimate_poses`` with triangular exploration, seeded with the
    elevations from the first stage. If no anchor elevation is returned, the
    anchor polar angle falls back to pi/2.

    Returns:
        (anchor_polar, explored_sph, fig, update). ``explored_sph`` is the
        ``(theta, azimuth, radius)`` offset. ``update`` makes the Refine button
        interactive.
    """
    seed_everything(seed_value)
    # Drawn right after seeding so the same seed reproduces the same noise.
    noise = np.random.randn(probe_bsz, 4, 32, 32)
    cam_vis.set_images([np.asarray(image1, dtype=np.uint8), np.asarray(image2, dtype=np.uint8)])
    image1 = image_to_tensor(image1).to(_device_)
    image2 = image_to_tensor(image2).to(_device_)
    images = [image1, image2]
    result_poses, aux_data = estimate_poses(
        _model_, images,
        seed_cand_num=8,
        explore_type='triangular',
        refine_type='triangular',
        probe_ts_range=[0.2, 0.21],
        ts_range=[0.2, 0.21],
        probe_bsz=probe_bsz,
        adjust_factor=10.0,
        adjust_iters=adj_iters,
        adjust_bsz=adj_bsz,
        refine_factor=1.0,
        refine_iters=0,
        refine_bsz=4,
        noise=noise,
        elevs=elevs,
        elev_ranges=elev_ranges
    )
    theta, azimuth, radius = result_poses[0]
    anchor_polar = aux_data['elev'][0]
    if anchor_polar is None:
        anchor_polar = np.pi/2
    explored_sph = (theta, azimuth, radius)
    fig = _render_pose_figure(cam_vis, anchor_polar, explored_sph)
    return anchor_polar, explored_sph, fig, gr.update(interactive=True)


@spaces.GPU
def run_pose_refinement(cam_vis, image1, image2, anchor_polar, explored_sph, refine_iters, seed_value):
    """Refine the explored relative pose with ``find_optimal_poses`` and re-plot it.

    Starts from ``explored_sph`` as the initial pose of image 1 and optimises
    for ``refine_iters`` iterations over both image orderings.

    Returns:
        (final_sph, fig), where ``final_sph`` is the refined
        ``(theta, azimuth, radius)`` offset.
    """
    seed_everything(seed_value)
    cam_vis.set_images([np.asarray(image1, dtype=np.uint8), np.asarray(image2, dtype=np.uint8)])
    image1 = image_to_tensor(image1).to(_device_)
    image2 = image_to_tensor(image2).to(_device_)
    # find_optimal_poses receives channels-last tensors (1 x H x W x C).
    images = [img.permute(0, 2, 3, 1) for img in (image1, image2)]
    out_poses, _, loss = find_optimal_poses(
        _model_, images,
        1.0,
        bsz=1,
        n_iter=refine_iters,
        init_poses={1: explored_sph},
        ts_range=[0.2, 0.21],
        combinations=[(0, 1), (1, 0)],
        avg_last_n=20,
        print_n=100
    )
    final_sph = out_poses[0]
    fig = _render_pose_figure(cam_vis, anchor_polar, final_sph)
    return final_sph, fig
# Markdown shown above the demo.
_HEADER_ = (
    '\n'
    '# Official 🤗 Gradio Demo for [ID-Pose: Sparse-view Camera Pose Estimation By Inverting Diffusion Models](https://github.com/xt4d/id-pose)\n'
    '- ID-Pose accepts input images with NO overlapping appearance.\n'
    '- The estimation takes about 1 minute. ZeroGPU may be halted during processing due to quota restrictions.\n'
)

# Markdown links shown under the controls.
_FOOTER_ = (
    '\n'
    '- Project Page: [https://xt4d.github.io/id-pose-web/](https://xt4d.github.io/id-pose-web/)\n'
    '- Github: [https://github.com/xt4d/id-pose](https://github.com/xt4d/id-pose)\n'
)

# BibTeX citation, rendered as a fenced Markdown code block.
_CITE_ = (
    '\n'
    '```bibtex\n'
    '@article{cheng2023id,\n'
    'title={ID-Pose: Sparse-view Camera Pose Estimation by Inverting Diffusion Models},\n'
    'author={Cheng, Weihao and Cao, Yan-Pei and Shan, Ying},\n'
    'journal={arXiv preprint arXiv:2306.17140},\n'
    'year={2023}\n'
    '}\n'
    '```\n'
)
def run_demo():
    """Build the ID-Pose Gradio interface, wire its events and launch it.

    "Estimate" chains preprocessing, elevation estimation and pose exploration.
    Each step runs only if the previous one succeeded. "Refine" starts disabled
    and is enabled by the exploration step.
    """
    demo = gr.Blocks(title='ID-Pose: Sparse-view Camera Pose Estimation By Inverting Diffusion Models')

    # (gallery label, object names). Each object has <name>_0.png and <name>_1.png.
    example_groups = [
        ('Examples (Self-captured)', ['duck', 'chair', 'foosball']),
        ('Examples (Images from NAVI)', ['bunny', 'bus', 'circo']),
        ('Examples (Generated)', ['status', 'bag', 'cat']),
    ]

    with demo:
        gr.Markdown(_HEADER_)
        with gr.Row(variant='panel'):
            # Gradio column scales must be integers; 5:7 keeps the intended 1:1.4 width ratio.
            with gr.Column(scale=5):
                with gr.Row():
                    with gr.Column(min_width=280):
                        input_image1 = gr.Image(type='pil', image_mode='RGBA', label='Input Image 1', width=280)
                    with gr.Column(min_width=280):
                        input_image2 = gr.Image(type='pil', image_mode='RGBA', label='Input Image 2', width=280)
                with gr.Row():
                    with gr.Column(min_width=280):
                        processed_image1 = gr.Image(type='numpy', image_mode='RGB', label='Processed Image 1', width=280, interactive=False)
                    with gr.Column(min_width=280):
                        processed_image2 = gr.Image(type='numpy', image_mode='RGB', label='Processed Image 2', width=280, interactive=False)
                with gr.Row():
                    preprocess_chk = gr.Checkbox(True, label='Remove background and recenter object')
                with gr.Accordion('Advanced options', open=False):
                    probe_bsz = gr.Slider(4, 32, value=16, step=4, label='Probe Batch Size')
                    adj_bsz = gr.Slider(1, 8, value=4, step=1, label='Adjust Batch Size')
                    adj_iters = gr.Slider(1, 20, value=8, step=1, label='Adjust Iterations')
                    seed_value = gr.Number(value=0, label="Seed Value", precision=0)
                with gr.Row():
                    run_btn = gr.Button('Estimate', variant='primary', interactive=True)
                with gr.Row():
                    refine_iters = gr.Slider(0, 1000, value=0, step=50, label='Refinement Iterations')
                with gr.Row():
                    # Made interactive by run_pose_exploration_b's gr.update output.
                    refine_btn = gr.Button('Refine', variant='primary', interactive=False)
                with gr.Row():
                    gr.Markdown(_FOOTER_)
                with gr.Row():
                    gr.Markdown(_CITE_)
            with gr.Column(scale=7):
                with gr.Row():
                    vis_output = gr.Plot(label='Camera Pose Results: anchor (red) and target (blue)')
                with gr.Row():
                    for label, names in example_groups:
                        with gr.Column(min_width=200):
                            gr.Examples(
                                examples=[
                                    [f'data/gradio_demo/{name}_0.png', f'data/gradio_demo/{name}_1.png']
                                    for name in names
                                ],
                                inputs=[input_image1, input_image2],
                                label=label,
                                cache_examples=False,
                                examples_per_page=3,
                            )

        # NOTE(review): this single visualizer is bound into both handlers via
        # partial(), so every session shares it. Concurrent requests may
        # overwrite each other's images/poses; consider per-request instances.
        cam_vis = CameraVisualizer([np.eye(4), np.eye(4)], ['Image 1', 'Image 2'], ['red', 'blue'])

        # Per-session values handed from one pipeline step to the next.
        explored_sph = gr.State()
        anchor_polar = gr.State()
        refined_sph = gr.State()
        elevs = gr.State()
        elev_ranges = gr.State()

        run_btn.click(
            fn=run_preprocess,
            inputs=[input_image1, input_image2, preprocess_chk, seed_value],
            outputs=[processed_image1, processed_image2],
        ).success(
            fn=run_pose_exploration_a,
            inputs=[processed_image1, processed_image2, seed_value],
            outputs=[elevs, elev_ranges, vis_output]
        ).success(
            fn=partial(run_pose_exploration_b, cam_vis),
            inputs=[processed_image1, processed_image2, elevs, elev_ranges, probe_bsz, adj_bsz, adj_iters, seed_value],
            outputs=[anchor_polar, explored_sph, vis_output, refine_btn]
        )

        refine_btn.click(
            fn=partial(run_pose_refinement, cam_vis),
            inputs=[processed_image1, processed_image2, anchor_polar, explored_sph, refine_iters, seed_value],
            outputs=[refined_sph, vis_output]
        )

    demo.launch()
if __name__ == "__main__":
    # Build and launch the Gradio app only when this file is executed directly.
    run_demo()