# zero123 / app.py
# rliu's picture
# Update app.py
# 8fa8203
# raw
# history blame
# 5.95 kB
import numpy as np
import gradio as gr
import os
from PIL import Image
from functools import partial
def retrieve_input_image_wild(dataset, inputs):
    """Load the input image of an in-the-wild example, shrunk to fit 256x256.

    The image is looked up as
    ``online_demo/<dataset>/step-100_scale-6.0/<inputs>.jpg`` and, when that
    file does not exist, as the ``.png`` file with the same stem.

    Args:
        dataset: Name of the dataset directory under ``online_demo``.
        inputs: Image identifier, i.e. the file name without its extension.

    Returns:
        The opened PIL image after ``thumbnail`` has resized it in place to
        fit within 256x256 using LANCZOS resampling.

    Raises:
        FileNotFoundError: If neither the ``.jpg`` nor the ``.png`` exists.
    """
    img_dir = os.path.join('online_demo', dataset, 'step-100_scale-6.0')
    try:
        image = Image.open(os.path.join(img_dir, '%s.jpg' % inputs))
    except FileNotFoundError:
        # Only a missing .jpg should trigger the .png fallback; any other
        # failure (corrupt file, interrupt, ...) is a real error and must
        # not be masked by a second lookup.
        image = Image.open(os.path.join(img_dir, '%s.png' % inputs))
    image.thumbnail([256, 256], Image.Resampling.LANCZOS)
    return image
def retrieve_input_image(dataset, inputs):
    """Open the stored input image of one object in a dataset.

    Args:
        dataset: Name of the dataset directory under ``online_demo``.
        inputs: Object identifier; used as the name of the object's directory.

    Returns:
        The PIL image opened from
        ``online_demo/<dataset>/step-100_scale-6.0/<inputs>/input.png``.
    """
    return Image.open(
        os.path.join('online_demo', dataset, 'step-100_scale-6.0', inputs, 'input.png')
    )
def retrieve_novel_view(dataset, img_id, polar, azimuth, zoom, seed):
    """Open the precomputed novel-view image matching the requested settings.

    The angle and zoom values are converted to the indices used in the stored
    file names: ``polar`` is floor-divided by 30 and offset by one,
    ``azimuth`` is floor-divided by 30, and ``zoom`` is mapped to
    ``int(zoom * 2 + 1)``.

    Args:
        dataset: Name of the dataset directory under ``online_demo``.
        img_id: Object identifier; used as the name of the object's directory.
        polar: Polar angle value.
        azimuth: Azimuth angle value.
        zoom: Zoom value.
        seed: Seed number embedded in the file name.

    Returns:
        The PIL image opened from the object's directory.
    """
    polar_idx = polar // 30 + 1
    azimuth_idx = azimuth // 30
    distance_idx = int(zoom * 2 + 1)
    file_name = 'polar-%d_azimuth-%d_distance-%d_seed-%d.png' % (
        polar_idx, azimuth_idx, distance_idx, seed)
    object_dir = os.path.join('online_demo', dataset, 'step-100_scale-6.0', img_id)
    return Image.open(os.path.join(object_dir, file_name))
def _build_dataset_tab(title, dataset, default_id, default_image_rel, load_input_image):
    """Create one demo tab: an input-image picker and a novel-view viewer.

    Must be called inside an active ``gr.Blocks`` context so the components
    created here are attached to that layout.

    Args:
        title: Label shown on the tab.
        dataset: Name of the dataset directory under ``online_demo``.
        default_id: Dropdown entry selected when the tab is first shown.
        default_image_rel: Path components, relative to the dataset's
            ``step-100_scale-6.0`` directory, of the image displayed initially.
        load_input_image: Callable ``(dataset, img_id) -> image`` used by the
            "Load Input Image" button.
    """
    root = os.path.join('online_demo', dataset, 'step-100_scale-6.0')
    # Offer only directories: retrieve_novel_view always resolves the selected
    # id as a directory below ``root``, so a stray file could never be served.
    with os.scandir(root) as entries:
        options = sorted(entry.name for entry in entries if entry.is_dir())

    with gr.Tab(title):
        with gr.Row():
            # Left column: choose an object and show its input image.
            with gr.Column(scale=1):
                default_input_image = Image.open(os.path.join(root, *default_image_rel))
                default_input_image.thumbnail([256, 256], Image.Resampling.LANCZOS)
                input_image = gr.Image(default_input_image, shape=[256, 256])
                img_id = gr.Dropdown(options, value=default_id, label='options')
                text_button = gr.Button("Load Input Image")
                text_button.click(partial(load_input_image, dataset),
                                  inputs=img_id, outputs=input_image)
            # Right column: pick a viewpoint and show the precomputed result.
            with gr.Column(scale=1):
                novel_view = gr.Image(shape=[256, 256])
                inputs = [
                    img_id,
                    gr.Slider(-30, 30, value=0, step=30, label='Polar angle (vertical rotation in degrees)'),
                    gr.Slider(0, 330, value=0, step=30, label='Azimuth angle (horizontal rotation in degrees)'),
                    gr.Slider(-0.5, 0.5, value=0, step=0.5, label='Zoom'),
                    gr.Slider(0, 3, value=1, step=1, label='Random seed'),
                ]
                submit_button = gr.Button("Generate Novel View")
                submit_button.click(partial(retrieve_novel_view, dataset),
                                    inputs=inputs, outputs=novel_view)


# One tab per dataset. The in-the-wild images are stored as flat files next to
# the per-object directories, so that tab uses its own loader and default path.
with gr.Blocks() as demo:
    # gr.Markdown("Stable Diffusion Novel View Synthesis (Precomputed Results)")
    _build_dataset_tab("In-the-wild Images", 'nerf_wild', 'car1',
                       ('car1.png',), retrieve_input_image_wild)
    _build_dataset_tab("Google Scanned Objects", 'GSO', 'SAMBA_HEMP',
                       ('SAMBA_HEMP', 'input.png'), retrieve_input_image)
    _build_dataset_tab("RTMV", 'RTMV', '00000',
                       ('00000', 'input.png'), retrieve_input_image)

if __name__ == "__main__":
    demo.launch()