# Find3D / app.py
# ziqima — commit 64c5824: "cleanup readme and raise errors"
# (raw · history · blame · 7.05 kB)
import os
import re
import subprocess
import sys

import gradio as gr
from huggingface_hub import login

from utils import read_pcd, render_point_cloud, render_pcd_file, set_seed
from inference.utils import get_legend
from inference.inference import segment_obj, get_heatmap
# Build and install the pointops extension shipped under Pointcept/libs/pointops.
# - sys.executable: run setup.py with the interpreter that is running this app;
#   a bare "python" on PATH may resolve to a different environment.
# - cwd=: run the build inside that directory without changing this process's
#   working directory, so no "../../../" bookkeeping is needed.
# - check=True: a failed build raises CalledProcessError at startup instead of
#   being silently ignored (os.system's exit status was discarded).
# NOTE(review): this runs after `inference.inference` is imported at the top of
# the file; if that module imports pointops eagerly, it will not see the freshly
# built extension — confirm the import order is intended.
subprocess.run(
    [sys.executable, "setup.py", "install"],
    cwd=os.path.join("Pointcept", "libs", "pointops"),
    check=True,
)

# Authenticate with the Hugging Face Hub using the token in the `hfkey`
# environment variable (os.getenv returns None when it is unset).
login(token=os.getenv('hfkey'))
# Bundled examples: name -> (source folder under examples/, default part queries).
# The two public lookup tables below are derived from this single table so the
# example names in both can never drift apart.
_EXAMPLES = {
    "fireplug": ("objaverse", "bonnet of a fireplug,side cap of a fireplug,barrel of a fireplug,base of a fireplug"),
    "mickey": ("objaverse", "ear,head,arms,hands,body,legs"),
    "motorvehicle": ("objaverse", "wheel of a motor vehicle,seat of a motor vehicle,handle of a motor vehicle"),
    "teddy": ("objaverse", "head,body,arms,legs"),
    "lamppost": ("objaverse", "lighting of a lamppost,pole of a lamppost"),
    "shirt": ("objaverse", "sleeve of a shirt,collar of a shirt,body of a shirt"),
    "capybara": ("wild", "hat worn by a capybara,head,body,feet"),
    "corgi": ("wild", "head,leg,body,ear"),
    "pushcar": ("wild", "wheel,body,handle"),
    "plant": ("wild", "pot,plant"),
    "chair": ("wild", "back of chair,leg,seat"),
}

# Example name -> default comma-separated part queries shown in the query box.
parts_dict = {name: queries for name, (_, queries) in _EXAMPLES.items()}
# Example name -> subfolder of examples/ that holds its .pcd file.
source_dict = {name: folder for name, (folder, _) in _EXAMPLES.items()}
# Characters accepted as separators between part names in the query box.
_PART_SEPARATORS = re.compile(r"[,;.|]")


def _split_part_queries(part_queries):
    """Return the non-empty, whitespace-stripped part names in *part_queries*."""
    if not part_queries:
        return []
    # Dropping empty pieces matters: "head,body," or "a." would otherwise yield
    # a "" part that inflates the part count and is sent to the model.
    stripped = (part.strip() for part in _PART_SEPARATORS.split(part_queries))
    return [part for part in stripped if part]


def predict(pcd_path, inference_mode, part_queries):
    """Run Find3D on a point cloud and render the result.

    Args:
        pcd_path: Filepath of the selected ``.pcd`` file (the ``gr.File``
            input uses ``type="filepath"``); falsy when no file is loaded.
        inference_mode: ``"Segmentation"`` or ``"Localization"``.
        part_queries: Raw text from the "Part Queries" box. Part names are
            separated by ``,``, ``;``, ``.`` or ``|``.

    Returns:
        Whatever ``render_point_cloud`` returns (shown in a ``gr.Plot``), or
        ``None`` for an unrecognised ``inference_mode``.

    Raises:
        gr.Error: no file is loaded; no part query is given; segmentation
            mode gets fewer than 2 parts; or localization mode gets more
            than 1 part.
    """
    if inference_mode not in ("Segmentation", "Localization"):
        return None
    if not pcd_path:
        raise gr.Error("Please upload a point cloud (.pcd) file first", duration=5)

    # Validate the queries before doing any (potentially slow) loading work.
    # Both modes use the same separator set, so "a|b" is rejected in
    # localization mode just as "a,b" is.
    parts = _split_part_queries(part_queries)
    if not parts:
        raise gr.Error("Please provide at least one part query", duration=5)
    if inference_mode == "Segmentation" and len(parts) < 2:
        raise gr.Error("For segmentation mode, please provide 2 or more parts", duration=5)
    if inference_mode == "Localization" and len(parts) > 1:
        raise gr.Error("For localization mode, please provide only one part", duration=5)

    # Seed before loading/inference, matching the original call order.
    set_seed()
    xyz, rgb, normal = read_pcd(pcd_path)
    if inference_mode == "Segmentation":
        seg_rgb = segment_obj(xyz, rgb, normal, parts).cpu().numpy()
        return render_point_cloud(xyz, seg_rgb, legend=get_legend(parts))
    # Localization: exactly one part remains; it is the stripped query text.
    heatmap_rgb = get_heatmap(xyz, rgb, normal, parts[0]).cpu().numpy()
    return render_point_cloud(xyz, heatmap_rgb)
def on_select(evt: gr.SelectData):
    """Load a gallery example into the file input and the query box.

    The selected gallery item's ``orig_name`` is expected to be the example
    image's filename (e.g. ``lamppost.jpg`` / ``capybara.png`` from the
    galleries); its stem is the key into ``source_dict`` and ``parts_dict``.

    Returns:
        ``[path to the example's .pcd file, its default part queries]``, in
        the order of the ``outputs=[file_upload, part_queries]`` wiring.
    """
    filename = os.path.basename(evt.value['image']['orig_name'])
    # splitext handles any extension length (".jpg", ".png", ".jpeg", ...)
    # instead of assuming exactly 4 trailing characters.
    obj_name = os.path.splitext(filename)[0]
    src = source_dict[obj_name]
    return [f"examples/{src}/{obj_name}.pcd", parts_dict[obj_name]]
# Page header: what the demo does and how to use each mode.
_INTRO_HTML = """<h1 style="text-align: center;">Find Any Part in 3D</h1>
<p style='font-size: 16px;'>This is a demo for Find3D: Find Any Part in 3D! Two modes are supported: <b>segmentation</b> and <b>localization</b>.
For <b>segmentation mode</b>, please provide multiple part queries in the "Part Queries" text box as a comma-separated string, such as "part1,part2,part3".
After hitting "Run", the model will segment the object into the provided parts.
<br>
For <b>localization mode</b>, please provide only one query string in the "Part Queries" text box. After hitting "Run", the model will generate a heatmap for the provided query text.
Please click on the example images under "Objaverse" and "In the Wild" to try them out. You can also upload your own .pcd files.</p>
<p style='font-size: 16px;'>Hint:
When uploading your own point cloud, please first close the existing point cloud by clicking on the "x" button.
<br>
We show some sample queries for the provided examples. When working with your own point cloud, feel free to rephrase the query (e.g. "part" vs "part of an object") to achieve better performance!</p>
"""

_OBJAVERSE_HTML = """<h1 style="text-align: center;">Objaverse</h1>
<p style='font-size: 16px;'>Online 3D assets from Objaverse!</p>
"""

_WILD_HTML = """<h1 style="text-align: center;">In the Wild</h1>
<p style='font-size: 16px;'>Challenging in-the-wild reconstructions from iPhone photos & AI-generated images!</p>
"""

with gr.Blocks(theme=gr.themes.Default(text_size="lg", radius_size="none")) as demo:
    gr.HTML(_INTRO_HTML)

    with gr.Row(variant="panel"):
        # Inputs: point-cloud file, inference mode, part queries, Run button.
        with gr.Column(scale=4):
            file_upload = gr.File(
                label="Upload Point Cloud File",
                type="filepath",
                file_types=[".pcd"],
                value="examples/objaverse/lamppost.pcd",
            )
            inference_mode = gr.Radio(
                choices=["Segmentation", "Localization"],
                label="Inference Mode",
                value="Segmentation",
            )
            part_queries = gr.Textbox(
                label="Part Queries",
                value="lighting of a lamppost,pole of a lamppost",
            )
            run_button = gr.Button(
                value="Run",
                variant="primary",
            )
        with gr.Column(scale=4):
            # Hidden and not wired to any event in this file; kept as-is.
            input_image = gr.Image(label="Input Image", visible=False, type='pil', image_mode='RGBA', height=290)
            input_point_cloud = gr.Plot(label="Input Point Cloud")
        with gr.Column(scale=4):
            output_point_cloud = gr.Plot(label="Output Result")

    with gr.Row(variant="panel"):
        with gr.Column(scale=6):
            gr.HTML(_OBJAVERSE_HTML)
            gallery_objaverse = gr.Gallery(
                [
                    ("examples/objaverse/lamppost.jpg", "lamppost"),
                    ("examples/objaverse/fireplug.jpg", "fireplug"),
                    ("examples/objaverse/mickey.jpg", "Mickey"),
                    ("examples/objaverse/motorvehicle.jpg", "motor vehicle"),
                    ("examples/objaverse/teddy.jpg", "teddy bear"),
                    ("examples/objaverse/shirt.jpg", "shirt"),
                ],
                columns=3,
                allow_preview=False,
            )
            # Selecting an example fills in both the file and its default queries.
            gallery_objaverse.select(fn=on_select,
                                     inputs=None,
                                     outputs=[file_upload, part_queries])
        with gr.Column(scale=6):
            gr.HTML(_WILD_HTML)
            gallery_wild = gr.Gallery(
                [
                    ("examples/wild/capybara.png", "DALLE-capybara"),
                    ("examples/wild/corgi.jpg", "DALLE-corgi"),
                    ("examples/wild/plant.jpg", "iPhone-plant"),
                    ("examples/wild/pushcar.jpg", "iPhone-pushcar"),
                    ("examples/wild/chair.jpg", "iPhone-chair"),
                ],
                columns=3,
                allow_preview=False,
            )
            gallery_wild.select(fn=on_select,
                                inputs=None,
                                outputs=[file_upload, part_queries])

    # Re-render the input preview whenever the selected file changes.
    file_upload.change(
        fn=render_pcd_file,
        inputs=[file_upload],
        outputs=[input_point_cloud],
    )
    run_button.click(
        fn=predict,
        inputs=[file_upload, inference_mode, part_queries],
        outputs=[output_point_cloud],
    )
    # Render the default file's preview on page load.
    demo.load(
        fn=render_pcd_file,
        inputs=[file_upload],
        outputs=[input_point_cloud],
    )

if __name__ == "__main__":
    demo.launch()