# genwarp/app.py
import sys
from subprocess import check_call
import tempfile
from os.path import basename, splitext, join
from io import BytesIO
import numpy as np
from scipy.spatial import KDTree
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor, to_pil_image
from einops import rearrange
import gradio as gr
from huggingface_hub import hf_hub_download
from extern.ZoeDepth.zoedepth.utils.misc import colorize
from gradio_model3dgscamera import Model3DGSCamera
IMAGE_SIZE = 512
NEAR, FAR = 0.01, 100
FOVY = np.deg2rad(55)
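# Pinhole intrinsics implied by FOVY: f = IMAGE_SIZE / (2 * tan(FOVY / 2)).
# The Model3DGSCamera viewer below is configured with the same focal length.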
def download_models():
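# Fetch the VAE, CLIP image encoder, and GenWarp checkpoints from the
# Hugging Face Hub into the local checkpoints/ directory.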
models = [
{
'repo': 'stabilityai/sd-vae-ft-mse',
'sub': None,
'dst': 'checkpoints/sd-vae-ft-mse',
'files': ['config.json', 'diffusion_pytorch_model.safetensors'],
'token': None
},
{
'repo': 'lambdalabs/sd-image-variations-diffusers',
'sub': 'image_encoder',
'dst': 'checkpoints',
'files': ['config.json', 'pytorch_model.bin'],
'token': None
},
{
'repo': 'Sony/genwarp',
'sub': 'multi1',
'dst': 'checkpoints',
'files': ['config.json', 'denoising_unet.pth', 'pose_guider.pth', 'reference_unet.pth'],
'token': None
}
]
for model in models:
for file in model['files']:
hf_hub_download(
repo_id=model['repo'],
subfolder=model['sub'],
filename=file,
local_dir=model['dst'],
token=model['token']
)
# Crop the image to the shorter side.
def crop(img: Image) -> Image:
W, H = img.size
if W < H:
left, right = 0, W
top, bottom = np.ceil((H - W) / 2.), np.floor((H - W) / 2.) + W
else:
left, right = np.ceil((W - H) / 2.), np.floor((W - H) / 2.) + H
top, bottom = 0, H
return img.crop((left, top, right, bottom))
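# Back-project per-pixel depth into a 3D point cloud: screen coordinates are
# pushed through the inverse viewport/projection transform, scaled by depth,
# mapped to world space, and the cloud is shifted so the camera orbits its
# mean depth.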
def unproject(depth):
H, W = depth.shape[2:4]
mean_depth = depth.mean(dim=(2, 3)).squeeze().item()
viewport_mtx = get_viewport_matrix(
IMAGE_SIZE, IMAGE_SIZE,
batch_size=1
).to(depth)
# Projection matrix.
fovy = torch.ones(1) * FOVY
proj_mtx = get_projection_matrix(
fovy=fovy,
aspect_wh=1.,
near=NEAR,
far=FAR
).to(depth)
view_mtx = camera_lookat(
torch.tensor([[0., 0., 0.]]),
torch.tensor([[0., 0., 1.]]),
torch.tensor([[0., -1., 0.]])
).to(depth)
scr_mtx = (viewport_mtx @ proj_mtx).to(depth)
grid = torch.stack(torch.meshgrid(
torch.arange(W), torch.arange(H), indexing='xy'), dim=-1
).to(depth)[None] # BHW2
screen = F.pad(grid, (0, 1), 'constant', 0)
screen = F.pad(screen, (0, 1), 'constant', 1)
screen_flat = rearrange(screen, 'b h w c -> b (h w) c')
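# Map homogeneous screen coordinates back to eye space through the inverse
# of (viewport @ projection); scaling by per-pixel depth then recovers the
# eye-space position along each pixel ray.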
eye = screen_flat @ torch.linalg.inv_ex(
scr_mtx.float()
)[0].mT.to(depth)
eye = eye * rearrange(depth, 'b c h w -> b (h w) c')
eye[..., 3] = 1
points = eye @ torch.linalg.inv_ex(view_mtx.float())[0].mT.to(depth)
points = points[0, :, :3]
# Translate to the origin.
points[..., 2] -= mean_depth
camera_pos = (0, 0, -mean_depth)
view_mtx = camera_lookat(
torch.tensor([[0., 0., -mean_depth]]),
torch.tensor([[0., 0., 0.]]),
torch.tensor([[0., -1., 0.]])
).to(depth)
return points, camera_pos, view_mtx, proj_mtx
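# Mean squared distance from each point to its 3 nearest neighbours, used to
# set per-Gaussian scales (similar to the scale initialisation used in 3D
# Gaussian Splatting).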
def calc_dist2(points: np.ndarray):
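# KDTree.query with k=4 returns each point itself (distance 0) first, so the
# self column is dropped before averaging.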
dists, _ = KDTree(points).query(points, k=4)
mean_dists = (dists[:, 1:] ** 2).mean(1)
return mean_dists
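# Serialise the cloud in the layout expected by antimatter15-style .splat
# viewers: per Gaussian, position (3 x float32), scale (3 x float32),
# RGBA colour (4 x uint8), and a normalised quaternion packed into 4 x uint8.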
def save_as_splat(
filepath: str,
xyz: np.ndarray,
rgb: np.ndarray
):
# To gaussian splat
inv_sigmoid = lambda x: np.log(x / (1 - x))
dist2 = np.clip(calc_dist2(xyz), a_min=1e-7, a_max=None)
scales = np.repeat(np.log(np.sqrt(dist2))[..., np.newaxis], 3, axis=1)
rots = np.zeros((xyz.shape[0], 4))
rots[:, 0] = 1
opacities = inv_sigmoid(0.1 * np.ones((xyz.shape[0], 1)))
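# Draw-order heuristic: sort by (scale product) x opacity, largest first;
# exp(sum(log-scales)) is the product of the three per-axis scales.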
sorted_indices = np.argsort((
-np.exp(np.sum(scales, axis=-1, keepdims=True))
/ (1 + np.exp(-opacities))
).squeeze())
buffer = BytesIO()
for idx in sorted_indices:
position = xyz[idx].astype(np.float32)  # .splat stores float32 positions.
scale = np.exp(scales[idx]).astype(np.float32)
rot = rots[idx].astype(np.float32)
color = np.concatenate(
(rgb[idx], 1 / (1 + np.exp(-opacities[idx]))),
axis=-1
)
buffer.write(position.tobytes())
buffer.write(scale.tobytes())
buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
buffer.write(
((rot / np.linalg.norm(rot)) * 128 + 128)
.clip(0, 255)
.astype(np.uint8)
.tobytes()
)
with open(filepath, "wb") as f:
f.write(buffer.getvalue())
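# Convert the viewer camera's position and Euler rotation into a view
# matrix; the final diagonal flip converts between the viewer's and the
# renderer's Y/Z axis conventions.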
def view_from_rt(position, rotation):
t = np.array(position)
euler = np.array(rotation)
cx = np.cos(euler[0])
sx = np.sin(euler[0])
cy = np.cos(euler[1])
sy = np.sin(euler[1])
cz = np.cos(euler[2])
sz = np.sin(euler[2])
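# Flattened 3x3 rotation matrix composed from the per-axis rotations above.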
R = np.array([
cy * cz + sy * sx * sz,
-cy * sz + sy * sx * cz,
sy * cx,
cx * sz,
cx * cz,
-sx,
-sy * cz + cy * sx * sz,
sy * sz + cy * sx * cz,
cy * cx
])
view_mtx = np.array([
[R[0], R[1], R[2], 0],
[R[3], R[4], R[5], 0],
[R[6], R[7], R[8], 0],
[
-t[0] * R[0] - t[1] * R[3] - t[2] * R[6],
-t[0] * R[1] - t[1] * R[4] - t[2] * R[7],
-t[0] * R[2] - t[1] * R[5] - t[2] * R[8],
1
]
]).T
B = np.array([
[1, 0, 0, 0],
[0, -1, 0, 0],
[0, 0, -1, 0],
[0, 0, 0, 1]
])
return B @ view_mtx
# Setup.
download_models()
mde = torch.hub.load(
'./extern/ZoeDepth',
'ZoeD_N',
source='local',
pretrained=True,
trust_repo=True
)
import spaces
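# The bundled splatting wheel must be installed before genwarp (which
# depends on it) is imported below.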
check_call([
sys.executable, '-m', 'pip', 'install',
'extern/splatting-0.0.1-py3-none-any.whl'
])
from genwarp import GenWarp
from genwarp.ops import (
camera_lookat, get_projection_matrix, get_viewport_matrix
)
# GenWarp
genwarp_cfg = dict(
pretrained_model_path='checkpoints',
checkpoint_name='multi1',
half_precision_weights=True
)
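# Instantiate on CPU; the @spaces.GPU callbacks move modules to CUDA on
# demand (ZeroGPU allocates a GPU only while a callback runs).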
genwarp_nvs = GenWarp(cfg=genwarp_cfg, device='cpu')
with tempfile.TemporaryDirectory() as tmpdir:
with gr.Blocks(
title='GenWarp Demo',
css='img {display: inline;}'
) as demo:
# Internal states.
src_image = gr.State()
src_depth = gr.State()
proj_mtx = gr.State()
src_view_mtx = gr.State()
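# gr.State keeps per-session tensors (source image, depth, camera matrices)
# that are passed between callbacks without going through the UI.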
# Blocks.
gr.Markdown(
"""
# GenWarp: Single Image to Novel Views with Semantic-Preserving Generative Warping
[![Project Site](https://img.shields.io/badge/Project-Web-green)](https://genwarp-nvs.github.io/) &nbsp;
[![Spaces](https://img.shields.io/badge/Spaces-Demo-yellow?logo=huggingface)](https://huggingface.co/spaces/Sony/GenWarp) &nbsp;
[![Github](https://img.shields.io/badge/Github-Repo-orange?logo=github)](https://github.com/sony/genwarp/) &nbsp;
[![Models](https://img.shields.io/badge/Models-checkpoints-blue?logo=huggingface)](https://huggingface.co/Sony/genwarp) &nbsp;
[![arXiv](https://img.shields.io/badge/arXiv-2405.17251-red?logo=arxiv)](https://arxiv.org/abs/2405.17251)
## Introduction
This is an official demo for the paper "[GenWarp: Single Image to Novel Views with Semantic-Preserving Generative Warping](https://genwarp-nvs.github.io/)". GenWarp generates novel-view images from a single input image, conditioned on camera poses. This demo offers basic inference with the model; for details, please refer to the [paper](https://arxiv.org/abs/2405.17251).
## How to Use
1. Upload a reference image to "Reference Input".
   - You can also select an image from "Examples".
2. Move the camera to your desired view in the "Unprojected 3DGS" 3D viewer.
3. Hit the "Generate a novel view" button and check the result.
"""
)
file = gr.File(label='Reference Input', file_types=['image'])
examples = gr.Examples(
examples=['./assets/pexels-heyho-5998120_19mm.jpg',
'./assets/pexels-itsterrymag-12639296_24mm.jpg'],
inputs=file
)
with gr.Row():
image_widget = gr.Image(
label='Reference View', type='filepath',
interactive=False
)
depth_widget = gr.Image(label='Estimated Depth', type='pil')
viewer = Model3DGSCamera(
label='Unprojected 3DGS',
width=IMAGE_SIZE,
height=IMAGE_SIZE,
camera_width=IMAGE_SIZE,
camera_height=IMAGE_SIZE,
camera_fx=IMAGE_SIZE / (2. * np.tan(FOVY / 2.)),
camera_fy=IMAGE_SIZE / (2. * np.tan(FOVY / 2.)),
camera_near=NEAR,
camera_far=FAR
)
button = gr.Button('Generate a novel view', size='lg', variant='primary')
with gr.Row():
warped_widget = gr.Image(
label='Warped Image', type='pil', interactive=False
)
gen_widget = gr.Image(
label='Generated View', type='pil', interactive=False
)
# Callbacks
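# cb_mde: crop and resize the upload, run ZoeDepth, and return preview
# images plus detached CPU tensors for gr.State.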
@spaces.GPU
def cb_mde(image_file: str):
image = to_tensor(crop(Image.open(
image_file
).convert('RGB')).resize((IMAGE_SIZE, IMAGE_SIZE)))[None].cuda()
depth = mde.cuda().infer(image)
depth_image = to_pil_image(colorize(depth[0]))
return to_pil_image(image[0]), depth_image, image.cpu().detach(), depth.cpu().detach()
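# cb_3d: unproject depth to a point cloud, save it as a .splat file, and
# load it into the viewer along with the source camera pose.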
@spaces.GPU
def cb_3d(image, depth, image_file):
xyz, camera_pos, view_mtx, proj_mtx = unproject(depth.cuda())
rgb = rearrange(image, 'b c h w -> b (h w) c')[0]
splat_file = join(tmpdir, f'{splitext(basename(image_file))[0]}.splat')
save_as_splat(splat_file, xyz.cpu().detach().numpy(), rgb.cpu().detach().numpy())
return (splat_file, camera_pos, None), view_mtx.cpu().detach(), proj_mtx.cpu().detach()
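# cb_generate: compute the relative view matrix from the source camera to
# the viewer's current camera and run GenWarp in half precision.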
@spaces.GPU
def cb_generate(viewer, image, depth, src_view_mtx, proj_mtx):
image = image.cuda()
depth = depth.cuda()
src_view_mtx = src_view_mtx.cuda()
proj_mtx = proj_mtx.cuda()
# The viewer value is (splat_file, camera_pos, camera_rot); the current
# camera pose defines the target view.
tar_camera_pos = viewer[1]
tar_camera_rot = viewer[2]
tar_view_mtx = view_from_rt(tar_camera_pos, tar_camera_rot)
tar_view_mtx = torch.from_numpy(tar_view_mtx).to(image)
rel_view_mtx = (
tar_view_mtx @ torch.linalg.inv(src_view_mtx.to(image))
).to(image)
# GenWarp.
renders = genwarp_nvs.to('cuda')(
src_image=image.half(),
src_depth=depth.half(),
rel_view_mtx=rel_view_mtx.half(),
src_proj_mtx=proj_mtx.half(),
tar_proj_mtx=proj_mtx.half()
)
warped = renders['warped']
synthesized = renders['synthesized']
warped_pil = to_pil_image(warped[0])
synthesized_pil = to_pil_image(synthesized[0])
return warped_pil, synthesized_pil
# Events
file.change(
fn=cb_mde,
inputs=file,
outputs=[image_widget, depth_widget, src_image, src_depth]
).then(
fn=cb_3d,
inputs=[src_image, src_depth, image_widget],
outputs=[viewer, src_view_mtx, proj_mtx])
button.click(
fn=cb_generate,
inputs=[viewer, src_image, src_depth, src_view_mtx, proj_mtx],
outputs=[warped_widget, gen_widget])
if __name__ == '__main__':
demo.launch()