# Spaces: Running on Zero
import spaces | |
import os | |
import numpy as np | |
from PIL import Image | |
from omegaconf import OmegaConf | |
from functools import partial | |
import gradio as gr | |
from huggingface_hub import hf_hub_download | |
import torch | |
from torchvision import transforms | |
import rembg | |
import cv2 | |
from pytorch_lightning import seed_everything | |
from src.visualizer import CameraVisualizer | |
from src.pose_estimation import load_model_from_config, estimate_poses, estimate_elevs | |
from src.pose_funcs import find_optimal_poses | |
from src.utils import spherical_to_cartesian, elu_to_c2w | |
# Run on the first CUDA device when one is visible, otherwise stay on the CPU.
_device_ = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Model configuration file, relative to the working directory.
_config_path_ = 'src/configs/sd-objaverse-finetune-c_concat-256.yaml'

# Both checkpoints are fetched from the 'tokenid/ID-Pose' model repo on the Hub.
_ckpt_path_ = hf_hub_download(
    repo_id='tokenid/ID-Pose',
    filename='ckpts/zero123-xl.ckpt',
    repo_type='model',
)
_matcher_ckpt_path_ = hf_hub_download(
    repo_id='tokenid/ID-Pose',
    filename='ckpts/indoor_ds_new.ckpt',
    repo_type='model',
)

_config_ = OmegaConf.load(_config_path_)

# The model is instantiated on the CPU first and only then moved to the
# selected device; it is used for inference only, hence eval().
_model_ = load_model_from_config(_config_, _ckpt_path_, device='cpu').to(_device_)
_model_.eval()
def rgba_to_rgb(img):
    """Composite an RGBA PIL image over a white background.

    Args:
        img: PIL image whose ``mode`` is ``'RGBA'``.

    Returns:
        A PIL image built from the blended 3-channel uint8 array.
    """
    assert img.mode == 'RGBA'
    pixels = np.asarray(img, dtype=np.float32)
    color = pixels[:, :, :3]
    alpha = pixels[..., 3:]
    # color * a + white * (1 - a), with a = alpha / 255 and white = 255,
    # which simplifies the second term to (255 - alpha).
    blended = color * (alpha / 255.) + (255 - alpha)
    blended = blended.clip(0, 255).astype(np.uint8)
    return Image.fromarray(blended)
def remove_background(image, rembg_session = None, force = False, **rembg_kwargs):
    """Strip the background from ``image`` with rembg when needed.

    An image that is already RGBA and whose alpha channel has a minimum
    below 255 is treated as already cut out and is returned untouched,
    unless ``force`` is true.

    Args:
        image: PIL image (must expose ``mode`` and ``getextrema``).
        rembg_session: Session object forwarded to ``rembg.remove``.
        force: Run rembg even if the image already carries transparency.
        **rembg_kwargs: Extra keyword arguments forwarded to ``rembg.remove``.

    Returns:
        The original image, or the result of ``rembg.remove``.
    """
    has_transparency = image.mode == "RGBA" and image.getextrema()[3][0] < 255
    if has_transparency and not force:
        return image
    return rembg.remove(image, session=rembg_session, **rembg_kwargs)
def group_recenter(images, ratio=1.5, mask_thres=127, bkg_color=(255, 255, 255, 255)):
    """Crop a group of RGBA images around their foregrounds using one shared scale.

    The foreground of each image is the set of pixels whose alpha exceeds
    ``mask_thres``. The largest relative foreground extent over the whole
    group (times ``ratio``) defines a crop size, expressed as a fraction of
    each image's own height/width. Every image is cropped around the centre
    of its foreground bounding box, zero-padded where the crop leaves the
    image, blended onto ``bkg_color`` and resized to 256x256.

    Args:
        images: Iterable of RGBA images (anything ``np.asarray`` turns into
            an ``H x W x 4`` array).
        ratio: Multiplier applied to the largest relative foreground extent.
        mask_thres: Alpha threshold separating foreground from background.
        bkg_color: Background colour; the first three entries are blended
            in, the last one overwrites the alpha channel before resizing.

    Returns:
        A list with one ``256 x 256 x 3`` uint8 array per input image
        (an empty list for an empty input).

    Raises:
        ValueError: If an image has no pixel with alpha above ``mask_thres``.
    """
    rgbas = [np.asarray(img) for img in images]
    if not rgbas:
        return []

    # Locate every foreground bounding box once; it is needed both for the
    # shared scale and for the per-image crop centre.
    boxes = []
    for idx, rgba in enumerate(rgbas):
        yy, xx = np.where(rgba[:, :, 3] > mask_thres)
        if yy.size == 0:
            raise ValueError(
                f'image {idx} has no pixel with alpha above {mask_thres}; '
                'cannot locate the foreground'
            )
        boxes.append((yy.min(), yy.max(), xx.min(), xx.max()))

    # Relative extents: x spans are measured against the width (axis 1) and
    # y spans against the height (axis 0), so non-square inputs are scaled
    # correctly.
    rel_ws = [float(x1 - x0) / rgba.shape[1] for rgba, (_, _, x0, x1) in zip(rgbas, boxes)]
    rel_hs = [float(y1 - y0) / rgba.shape[0] for rgba, (y0, y1, _, _) in zip(rgbas, boxes)]
    sz = max(ratio * np.max(rel_ws), ratio * np.max(rel_hs))

    background = np.array(bkg_color)[None, None, :3]

    out_rgbs = []
    for rgba, (y0, y1, x0, x1) in zip(rgbas, boxes):
        height, width = rgba.shape[:2]

        cy = (y0 + y1) // 2
        cx = (x0 + x1) // 2

        top = cy - int(np.floor(sz * height / 2))
        bottom = cy + int(np.ceil(sz * height / 2))
        left = cx - int(np.floor(sz * width / 2))
        right = cx + int(np.ceil(sz * width / 2))

        # Take the part of the crop window that lies inside the image, then
        # pad the rest with fully transparent pixels.
        out = rgba[max(top, 0):min(bottom, height), max(left, 0):min(right, width), :].copy()
        pads = [
            (max(0 - top, 0), max(bottom - height, 0)),
            (max(0 - left, 0), max(right - width, 0)),
            (0, 0),
        ]
        out = np.pad(out, pads, mode='constant', constant_values=0)

        # Alpha-blend the colour channels onto the background colour.
        alpha = out[..., 3:] / 255.
        out[:, :, :3] = out[:, :, :3] * alpha + background * (1 - alpha)
        out[:, :, -1] = bkg_color[-1]

        out = cv2.resize(out.astype(np.uint8), (256, 256))
        out_rgbs.append(out[:, :, :3])

    return out_rgbs
def run_preprocess(image1, image2, preprocess_chk, seed_value):
    """Optionally remove the background of both inputs and recenter them.

    Args:
        image1: First input image (PIL).
        image2: Second input image (PIL).
        preprocess_chk: When true, run background removal on both images
            (sharing one rembg session) and recenter them as a group.
        seed_value: Seed passed to ``seed_everything``.

    Returns:
        The pair ``(image1, image2)``: PIL images rebuilt from the recentered
        arrays when preprocessing is enabled, the unchanged inputs otherwise.

    Raises:
        gr.Error: If either input image is missing.
    """
    # Fail with a readable UI message instead of an AttributeError on None
    # when the user triggers the pipeline without uploading both images.
    if image1 is None or image2 is None:
        raise gr.Error('Please provide both input images.')

    seed_everything(seed_value)

    if preprocess_chk:
        rembg_session = rembg.new_session()
        image1 = remove_background(image1, force=True, rembg_session=rembg_session)
        image2 = remove_background(image2, force=True, rembg_session=rembg_session)

        rgbs = group_recenter([image1, image2])
        image1 = Image.fromarray(rgbs[0])
        image2 = Image.fromarray(rgbs[1])

    return image1, image2
def image_to_tensor(img, width=256, height=256):
    """Turn an image into a batched tensor, rescaled and resized.

    The image goes through ``transforms.ToTensor``, gains a leading batch
    dimension, is mapped through ``x * 2 - 1`` and is finally resized to
    ``[height, width]``.

    Args:
        img: Input accepted by ``transforms.ToTensor``.
        width: Target width.
        height: Target height.

    Returns:
        The resized tensor with a leading batch dimension of 1.
    """
    batch = transforms.ToTensor()(img).unsqueeze(0)
    rescaled = batch * 2 - 1
    return transforms.functional.resize(rescaled, [height, width])
def run_pose_exploration(image1, image2, probe_bsz, adj_bsz, adj_iters, seed_value):
    """Estimate the relative pose of ``image2`` with respect to ``image1``.

    Args:
        image1: Anchor image.
        image2: Target image.
        probe_bsz: Probe batch size; also the leading dimension of the
            random noise handed to ``estimate_poses``.
        adj_bsz: Adjust batch size forwarded to ``estimate_poses``.
        adj_iters: Adjust iterations forwarded to ``estimate_poses``.
        seed_value: Seed passed to ``seed_everything``.

    Returns:
        ``(anchor_polar, (theta, azimuth, radius))`` where ``anchor_polar``
        is the first estimated elevation, or ``np.pi/2`` when that is None.
    """
    seed_everything(seed_value)

    tensors = [image_to_tensor(img).to(_device_) for img in (image1, image2)]

    elevs, elev_ranges = estimate_elevs(
        _model_, tensors,
        est_type='all',
        matcher_ckpt_path=_matcher_ckpt_path_
    )

    # Fall back to pi/2 when no elevation was estimated for the anchor.
    anchor_polar = elevs[0]
    if anchor_polar is None:
        anchor_polar = np.pi/2

    mean_abs_diff = torch.mean(torch.abs(tensors[0] - tensors[1]))
    if mean_abs_diff < 0.005:
        # Near-identical inputs: skip the search and report a zero offset.
        theta = azimuth = radius = 0
        print('Identical images found!')
    else:
        noise = np.random.randn(probe_bsz, 4, 32, 32)
        result_poses, _aux_data = estimate_poses(
            _model_, tensors,
            seed_cand_num=8,
            explore_type='triangular',
            refine_type='triangular',
            probe_ts_range=[0.2, 0.21],
            ts_range=[0.2, 0.21],
            probe_bsz=probe_bsz,
            adjust_factor=10.0,
            adjust_iters=adj_iters,
            adjust_bsz=adj_bsz,
            refine_factor=1.0,
            refine_iters=0,
            refine_bsz=4,
            noise=noise,
            elevs=elevs,
            elev_ranges=elev_ranges
        )
        theta, azimuth, radius = result_poses[0]

    return anchor_polar, (theta, azimuth, radius)
def run_pose_refinement(image1, image2, est_result, refine_iters, seed_value):
    """Refine a previously explored pose and redraw the camera figure.

    Args:
        image1: Anchor image.
        image2: Target image.
        est_result: Indexable pair ``(anchor_polar, explored_sph)`` from a
            previous estimation.
        refine_iters: Iteration count passed to ``find_optimal_poses``.
        seed_value: Seed passed to ``seed_everything``.

    Returns:
        ``((anchor_polar, final_sph), fig)`` where ``final_sph`` is the first
        pose returned by ``find_optimal_poses`` and ``fig`` is the updated
        camera visualisation.
    """
    seed_everything(seed_value)

    anchor_polar, explored_sph = est_result[0], est_result[1]

    # find_optimal_poses is fed the tensors with the channel axis moved last.
    tensors = [
        image_to_tensor(img).to(_device_).permute(0, 2, 3, 1)
        for img in (image1, image2)
    ]

    out_poses, _, _loss = find_optimal_poses(
        _model_, tensors,
        1.0,
        bsz=1,
        n_iter=refine_iters,
        init_poses={1: explored_sph},
        ts_range=[0.2, 0.21],
        combinations=[(0, 1), (1, 0)],
        avg_last_n=20,
        print_n=100
    )

    final_sph = out_poses[0]
    theta, azimuth, radius = final_sph

    origin = np.zeros(3)
    up_axis = np.array([0., 0., 1.])

    # Anchor camera sits at (anchor_polar, 0, 4); the target camera is the
    # anchor offset by the refined spherical pose.
    anchor_xyz = spherical_to_cartesian((anchor_polar, 0., 4.))
    anchor_c2w = elu_to_c2w(anchor_xyz, origin, up_axis)

    target_xyz = spherical_to_cartesian((theta + anchor_polar, 0. + azimuth, 4. + radius))
    target_c2w = elu_to_c2w(target_xyz, np.zeros(3), np.array([0., 0., 1.]))

    thumbnails = [np.asarray(image1, dtype=np.uint8), np.asarray(image2, dtype=np.uint8)]
    cam_vis = CameraVisualizer(
        [anchor_c2w, target_c2w],
        ['Image 1', 'Image 2'],
        ['red', 'blue'],
        images=thumbnails,
    )
    fig = cam_vis.update_figure(
        5,
        base_radius=-1.2,
        font_size=16,
        show_background=True,
        show_grid=True,
        show_ticklabels=True,
    )

    return (anchor_polar, final_sph), fig
def run_example(image1, image2):
    """Run preprocessing and pose exploration with fixed settings for an example pair.

    Preprocessing is always enabled with seed 0; exploration uses probe
    batch size 16, adjust batch size 4, 10 adjust iterations and seed 0.

    Returns:
        ``((anchor_polar, explored_sph), processed_image1, processed_image2)``.
    """
    processed1, processed2 = run_preprocess(image1, image2, True, 0)
    anchor_polar, explored_sph = run_pose_exploration(processed1, processed2, 16, 4, 10, 0)
    estimate = (anchor_polar, explored_sph)
    return estimate, processed1, processed2
def run_or_visualize(image1, image2, probe_bsz, adj_bsz, adj_iters, seed_value, est_result):
    """Estimate the pose (or reuse a cached estimate) and draw the cameras.

    Args:
        image1: Processed anchor image.
        image2: Processed target image.
        probe_bsz: Probe batch size for ``run_pose_exploration``.
        adj_bsz: Adjust batch size for ``run_pose_exploration``.
        adj_iters: Adjust iterations for ``run_pose_exploration``.
        seed_value: Seed for ``run_pose_exploration``.
        est_result: Cached ``(anchor_polar, explored_sph)`` pair, or None to
            trigger a fresh estimation.

    Returns:
        ``((anchor_polar, explored_sph), fig, update)`` where ``update`` is a
        ``gr.update(interactive=True)`` for the component bound to the third
        output.
    """
    if est_result is not None:
        anchor_polar, explored_sph = est_result[0], est_result[1]
        print('Using cache result.')
    else:
        anchor_polar, explored_sph = run_pose_exploration(
            image1, image2, probe_bsz, adj_bsz, adj_iters, seed_value
        )

    d_theta, d_azimuth, d_radius = explored_sph[0], explored_sph[1], explored_sph[2]

    # Anchor camera at (anchor_polar, 0, 4); target camera is the anchor
    # offset by the explored spherical pose.
    anchor_xyz = spherical_to_cartesian((anchor_polar, 0., 4.))
    anchor_c2w = elu_to_c2w(anchor_xyz, np.zeros(3), np.array([0., 0., 1.]))

    target_xyz = spherical_to_cartesian((d_theta + anchor_polar, 0. + d_azimuth, 4. + d_radius))
    target_c2w = elu_to_c2w(target_xyz, np.zeros(3), np.array([0., 0., 1.]))

    thumbnails = [np.asarray(image1, dtype=np.uint8), np.asarray(image2, dtype=np.uint8)]
    cam_vis = CameraVisualizer(
        [anchor_c2w, target_c2w],
        ['Image 1', 'Image 2'],
        ['red', 'blue'],
        images=thumbnails,
    )
    fig = cam_vis.update_figure(
        5,
        base_radius=-1.2,
        font_size=16,
        show_background=True,
        show_grid=True,
        show_ticklabels=True,
    )

    return (anchor_polar, explored_sph), fig, gr.update(interactive=True)
# Markdown rendered at the top of the demo page (passed to gr.Markdown in run_demo).
_HEADER_ = '''
# Official 🤗 Gradio Demo for [ID-Pose: Sparse-view Camera Pose Estimation By Inverting Diffusion Models](https://github.com/xt4d/id-pose)
- ID-Pose accepts input images with NO overlapping appearance.
- The estimation takes about 1 minute. ZeroGPU may be halted during processing due to quota restrictions.
'''
# Markdown with project links, rendered below the controls.
_FOOTER_ = '''
[Project Page](https://xt4d.github.io/id-pose-web/) | ⭐ [Github](https://github.com/xt4d/id-pose) ⭐ [![GitHub Stars](https://img.shields.io/github/stars/xt4d/id-pose?style=social)](https://github.com/xt4d/id-pose)
---
'''
# BibTeX citation block; a raw string so the markdown/BibTeX text is taken verbatim.
_CITE_ = r"""
```bibtex
@article{cheng2023id,
title={ID-Pose: Sparse-view Camera Pose Estimation by Inverting Diffusion Models},
author={Cheng, Weihao and Cao, Yan-Pei and Shan, Ying},
journal={arXiv preprint arXiv:2306.17140},
year={2023}
}
```
"""
def run_demo():
    """Build the Gradio Blocks interface, wire up its events and launch it.

    Layout: a left column with the two input images, their processed
    versions, the preprocessing checkbox, advanced options, the Estimate /
    Refine controls and the footer; a right column with the camera plot and
    three groups of example image pairs.

    Events:
        * Estimate: ``run_preprocess`` and, on success, ``run_or_visualize``
          (which also enables the Refine button).
        * Refine: ``run_pose_refinement``.
        * Clearing either input image resets the cached estimate to None.
    """
    # (label, example stems); each stem expands to '<stem>_0.png' / '<stem>_1.png'
    # under data/gradio_demo/.
    example_groups = [
        ('Examples (Captured)', ['duck', 'chair', 'foosball', 'bunny', 'circo']),
        ('Examples (Internet)', ['arc', 'husky', 'cybertruck', 'plane', 'christ']),
        ('Examples (Generated)', ['status', 'cat', 'ferrari', 'elon', 'ride_horse']),
    ]

    demo = gr.Blocks(title='ID-Pose: Sparse-view Camera Pose Estimation By Inverting Diffusion Models')

    with demo:
        # Hidden holder for the cached (anchor_polar, spherical pose) estimate.
        est_result = gr.JSON(visible=False)
        gr.Markdown(_HEADER_)

        with gr.Row(variant='panel'):
            # Gradio expects integer `scale` values; 5:7 keeps the intended
            # 1:1.4 width ratio between the two columns.
            with gr.Column(scale=5):
                with gr.Row():
                    with gr.Column(min_width=280):
                        input_image1 = gr.Image(type='pil', image_mode='RGBA', label='Input Image 1', width=280)
                    with gr.Column(min_width=280):
                        input_image2 = gr.Image(type='pil', image_mode='RGBA', label='Input Image 2', width=280)
                with gr.Row():
                    with gr.Column(min_width=280):
                        processed_image1 = gr.Image(type='numpy', image_mode='RGB', label='Processed Image 1', width=280, interactive=False)
                    with gr.Column(min_width=280):
                        processed_image2 = gr.Image(type='numpy', image_mode='RGB', label='Processed Image 2', width=280, interactive=False)
                with gr.Row():
                    preprocess_chk = gr.Checkbox(True, label='Remove background and recenter object')
                with gr.Accordion('Advanced options', open=False):
                    probe_bsz = gr.Slider(4, 32, value=16, step=4, label='Probe Batch Size')
                    adj_bsz = gr.Slider(1, 8, value=4, step=1, label='Adjust Batch Size')
                    adj_iters = gr.Slider(1, 20, value=10, step=1, label='Adjust Iterations')
                    seed_value = gr.Number(value=0, label="Seed Value", precision=0)
                with gr.Row():
                    run_btn = gr.Button('Estimate', variant='primary', interactive=True)
                with gr.Row():
                    refine_iters = gr.Slider(0, 1000, value=0, step=50, label='Refinement Iterations')
                with gr.Row():
                    # Disabled until a first estimate exists (enabled by run_or_visualize).
                    refine_btn = gr.Button('Refine', variant='primary', interactive=False)
                with gr.Row():
                    gr.Markdown(_FOOTER_)
                with gr.Row():
                    gr.Markdown(_CITE_)

            with gr.Column(scale=7):
                with gr.Row():
                    vis_output = gr.Plot(label='Camera Pose Results: anchor (red) and target (blue)')
                with gr.Row():
                    for group_label, stems in example_groups:
                        with gr.Column(min_width=200):
                            gr.Examples(
                                examples=[
                                    [f'data/gradio_demo/{stem}_0.png', f'data/gradio_demo/{stem}_1.png']
                                    for stem in stems
                                ],
                                inputs=[input_image1, input_image2],
                                fn=run_example,
                                outputs=[est_result, processed_image1, processed_image2],
                                label=group_label,
                                cache_examples='lazy',
                                examples_per_page=5,
                            )

        # Estimate: preprocess first; only run the estimation if that succeeded.
        run_btn.click(
            fn=run_preprocess,
            inputs=[input_image1, input_image2, preprocess_chk, seed_value],
            outputs=[processed_image1, processed_image2],
        ).success(
            fn=run_or_visualize,
            inputs=[processed_image1, processed_image2, probe_bsz, adj_bsz, adj_iters, seed_value, est_result],
            outputs=[est_result, vis_output, refine_btn],
        )

        refine_btn.click(
            fn=run_pose_refinement,
            inputs=[processed_image1, processed_image2, est_result, refine_iters, seed_value],
            outputs=[est_result, vis_output],
        )

        # Clearing an input invalidates the cached estimate.
        input_image1.clear(
            fn=lambda: None,
            outputs=[est_result],
        )
        input_image2.clear(
            fn=lambda: None,
            outputs=[est_result],
        )

    demo.launch()
# Launch the demo only when this file is executed as a script.
if __name__ == '__main__':
    run_demo()