|
import os
import re
import subprocess
import sys

import gradio as gr
from huggingface_hub import login

from utils import read_pcd, render_point_cloud, render_pcd_file, set_seed
from inference.utils import get_legend
from inference.inference import segment_obj, get_heatmap
|
|
|
|
|
# Build and install the bundled ``pointops`` package found in
# Pointcept/libs/pointops at start-up.
# ``cwd=`` scopes the directory change to the child process, so this app's own
# working directory is never modified. ``sys.executable`` makes sure the
# package is installed into the same interpreter that is running this app.
_pointops_build = subprocess.run(
    [sys.executable, "setup.py", "install"],
    cwd=os.path.join("Pointcept", "libs", "pointops"),
)
if _pointops_build.returncode != 0:
    # Still best-effort (start-up continues), but a failed build is reported
    # instead of being silently ignored.
    print(
        f"WARNING: pointops build exited with status {_pointops_build.returncode}",
        file=sys.stderr,
    )

# Authenticate with the Hugging Face Hub using the token stored in the
# ``hfkey`` environment variable. Without a token, huggingface_hub's login()
# falls back to an interactive prompt, which cannot be answered on a headless
# server. In that case the login is skipped and a warning is printed instead.
_hf_token = os.getenv("hfkey")
if _hf_token:
    login(token=_hf_token)
else:
    print(
        "WARNING: 'hfkey' is not set; skipping Hugging Face Hub login.",
        file=sys.stderr,
    )
|
|
|
# Gallery examples shown in the demo, keyed by object name.
# Each value is (example sub-folder under examples/, default part queries).
# Keeping both fields in one table guarantees the two lookups below always
# share exactly the same keys, in the same order.
_EXAMPLES = {
    "fireplug": ("objaverse", "bonnet of a fireplug,side cap of a fireplug,barrel of a fireplug,base of a fireplug"),
    "mickey": ("objaverse", "ear,head,arms,hands,body,legs"),
    "motorvehicle": ("objaverse", "wheel of a motor vehicle,seat of a motor vehicle,handle of a motor vehicle"),
    "teddy": ("objaverse", "head,body,arms,legs"),
    "lamppost": ("objaverse", "lighting of a lamppost,pole of a lamppost"),
    "shirt": ("objaverse", "sleeve of a shirt,collar of a shirt,body of a shirt"),
    "capybara": ("wild", "hat worn by a capybara,head,body,feet"),
    "corgi": ("wild", "head,leg,body,ear"),
    "pushcar": ("wild", "wheel,body,handle"),
    "plant": ("wild", "pot,plant"),
    "chair": ("wild", "back of chair,leg,seat"),
}

# Object name -> comma-separated default part queries for that example.
parts_dict = {name: queries for name, (_, queries) in _EXAMPLES.items()}

# Object name -> sub-folder of examples/ holding its .pcd file.
source_dict = {name: folder for name, (folder, _) in _EXAMPLES.items()}
|
|
|
def predict(pcd_path, inference_mode, part_queries):
    """Run Find3D on a point cloud and render the result for the output plot.

    Args:
        pcd_path: Path of the selected/uploaded ``.pcd`` file, or ``None``
            when the upload box is empty.
        inference_mode: ``"Segmentation"`` or ``"Localization"`` (the choices
            offered by the inference-mode radio button).
        part_queries: Free-text query. In segmentation mode it is split on any
            of ``, ; . |`` into several part names; in localization mode the
            whole (stripped) string is used as one query.

    Returns:
        Whatever ``render_point_cloud`` returns (shown in the output
        ``gr.Plot``), or ``None`` for an unrecognised mode.

    Raises:
        gr.Error: If no point cloud is provided or the query is empty, so the
            message is shown in the UI instead of a stack trace.
    """
    if not pcd_path:
        raise gr.Error("Please upload a .pcd point cloud file first.")

    if inference_mode == "Segmentation":
        # Drop empty fragments (e.g. from a trailing comma or doubled
        # separator); they would otherwise be sent to the model as part names.
        fragments = (part.strip() for part in re.split(r"[,;.|]", part_queries or ""))
        parts = [part for part in fragments if part]
        if not parts:
            raise gr.Error("Please enter at least one part query, e.g. \"part1,part2\".")

        set_seed()
        xyz, rgb, normal = read_pcd(pcd_path)
        seg_rgb = segment_obj(xyz, rgb, normal, parts).cpu().numpy()
        legend = get_legend(parts)
        return render_point_cloud(xyz, seg_rgb, legend=legend)

    if inference_mode == "Localization":
        query = (part_queries or "").strip()
        if not query:
            raise gr.Error("Please enter a query to localize.")

        set_seed()
        xyz, rgb, normal = read_pcd(pcd_path)
        heatmap_rgb = get_heatmap(xyz, rgb, normal, query).cpu().numpy()
        return render_point_cloud(xyz, heatmap_rgb)

    return None
|
|
|
def on_select(evt: gr.SelectData):
    """Load the gallery example the user clicked.

    Args:
        evt: Gradio selection event from a gallery. ``evt.value`` is expected
            to hold the clicked item with the image file name under
            ``["image"]["orig_name"]`` (e.g. ``"lamppost.jpg"``); this matches
            how the gallery items are built in this file, but the payload
            layout comes from Gradio, so confirm it if Gradio is upgraded.

    Returns:
        ``[pcd_path, default_queries]``, in the order of the select handlers'
        ``outputs=[file_upload, part_queries]``.
    """
    orig_name = evt.value["image"]["orig_name"]
    # splitext works for any extension length (".jpg", ".png", ".jpeg"), and
    # basename tolerates a name that carries a directory prefix, unlike
    # slicing off a fixed number of characters.
    obj_name = os.path.splitext(os.path.basename(orig_name))[0]
    src = source_dict[obj_name]
    return [f"examples/{src}/{obj_name}.pcd", parts_dict[obj_name]]
|
|
|
|
|
# Example pre-loaded when the page opens. Its point cloud path and default
# queries come from the same tables on_select() uses, so they stay in sync.
default_example = "lamppost"

with gr.Blocks(theme=gr.themes.Default(text_size="lg", radius_size="none")) as demo:
    # Centering uses a CSS style; ``text-align`` is not a valid HTML attribute.
    gr.HTML(
        '''<h1 style="text-align: center;">Find Any Part in 3D</h1>
<p style='font-size: 16px;'>This is a demo for Find3D: Find Any Part in 3D! Two modes are supported: segmentation and localization.
For segmentation mode, please provide multiple part queries in the "Part Queries" text box as a comma-separated string, such as "part1,part2,part3".
After hitting "Run", the model will segment the object into the provided parts.
For localization mode, please provide only one query string in the "Part Queries" text box. After hitting "Run", the model will generate a heatmap for the provided query text.
Please click on the images below "Objaverse" and "In the Wild" for some examples. You can also upload your own .pcd files.</p>
<p style='font-size: 16px;'>Hint: we provide some part names for the examples below.
When working with your own point cloud, feel free to rephrase the query (e.g. "part" vs "part of an object") to achieve better performance!</p>
'''
    )

    with gr.Row(variant="panel"):
        # Left column: point cloud file, mode, queries and the Run button.
        with gr.Column(scale=4):
            file_upload = gr.File(
                label="Upload Point Cloud File",
                type="filepath",
                file_types=[".pcd"],
                value=f"examples/{source_dict[default_example]}/{default_example}.pcd",
            )
            inference_mode = gr.Radio(
                choices=["Segmentation", "Localization"],
                label="Inference Mode",
                value="Segmentation",
            )
            part_queries = gr.Textbox(
                label="Part Queries",
                value=parts_dict[default_example],
            )
            run_button = gr.Button(
                value="Run",
                variant="primary",
            )

        # Middle column: rendering of the input point cloud.
        with gr.Column(scale=4):
            # Hidden component; not wired to any event in this file.
            input_image = gr.Image(label="Input Image", visible=False, type='pil', image_mode='RGBA', height=290)
            input_point_cloud = gr.Plot(label="Input Point Cloud")

        # Right column: segmentation / heatmap result.
        with gr.Column(scale=4):
            output_point_cloud = gr.Plot(label="Output Result")

    with gr.Row(variant="panel"):
        with gr.Column(scale=6):
            gr.HTML('''<h1 style="text-align: center;">Objaverse</h1>
<p style='font-size: 16px;'>Online 3D assets from Objaverse!</p>
''')
            gallery_objaverse = gr.Gallery(
                [
                    ("examples/objaverse/lamppost.jpg", "lamppost"),
                    ("examples/objaverse/fireplug.jpg", "fireplug"),
                    ("examples/objaverse/mickey.jpg", "Mickey"),
                    ("examples/objaverse/motorvehicle.jpg", "motor vehicle"),
                    ("examples/objaverse/teddy.jpg", "teddy bear"),
                    ("examples/objaverse/shirt.jpg", "shirt"),
                ],
                columns=3,
                allow_preview=False,
            )
            # Clicking an example loads its .pcd file and default queries.
            gallery_objaverse.select(
                fn=on_select,
                inputs=None,
                outputs=[file_upload, part_queries],
            )

        with gr.Column(scale=6):
            gr.HTML("""<h1 style="text-align: center;">In the Wild</h1>
<p style='font-size: 16px;'>Challenging in-the-wild reconstructions from iPhone photos & AI-generated images!</p>
""")
            gallery_wild = gr.Gallery(
                [
                    ("examples/wild/capybara.png", "DALLE-capybara"),
                    ("examples/wild/corgi.jpg", "DALLE-corgi"),
                    ("examples/wild/plant.jpg", "iPhone-plant"),
                    ("examples/wild/pushcar.jpg", "iPhone-pushcar"),
                    ("examples/wild/chair.jpg", "iPhone-chair"),
                ],
                columns=3,
                allow_preview=False,
            )
            gallery_wild.select(
                fn=on_select,
                inputs=None,
                outputs=[file_upload, part_queries],
            )

    # Re-render the input preview whenever the selected file changes.
    file_upload.change(
        fn=render_pcd_file,
        inputs=[file_upload],
        outputs=[input_point_cloud],
    )

    run_button.click(
        fn=predict,
        inputs=[file_upload, inference_mode, part_queries],
        outputs=[output_point_cloud],
    )

    # Render the default example's point cloud on page load.
    demo.load(
        fn=render_pcd_file,
        inputs=[file_upload],
        outputs=[input_point_cloud],
    )

demo.launch()
|
|