# SAM2Point / app.py
# Hugging Face Space source by ZiyuG; last commit 201bc5c ("Update app.py", verified).
import spaces
from pickle import FALSE
import gradio as gr
import numpy as np
import plotly.graph_objects as go
from sam2point import dataset
import sam2point.configs as configs
from demo_utils import run_demo, create_box
# Display names offered in the UI, keyed by the "Select 3D Data Type" label.
# The text after '-' (spaces removed) is the dataset folder / PATH key, and the
# i-th display name corresponds to the i-th file listed in PATH for that dataset
# (the scene selector reports the chosen position as an index).
samples = {
    "3D Indoor Scene - S3DIS": ["Conference Room", "Restroom", "Lobby", "Office1", "Office2"],
    "3D Indoor Scene - ScanNet": ["Scene1", "Scene2", "Scene3", "Scene4", "Scene5", "Scene6"],
    "3D Raw LiDAR - KITTI": ["Scene1", "Scene2", "Scene3", "Scene4", "Scene5", "Scene6"],
    "3D Outdoor Scene - Semantic3D": ["Scene1", "Scene2", "Scene3", "Scene4", "Scene5", "Scene6", "Scene7"],
    "3D Object - Objaverse": ["Plant", "Lego", "Lock", "Elephant", "Knife Rest", "Skateboard", "Popcorn Machine", "Stove", "Bus Shelter", "Thor Hammer", "Horse"],
}
# Sample files under data/<DATASET>/, in the same order as the display names above.
PATH = {
    "S3DIS": ['Area_1_conferenceRoom_1.txt', 'Area_2_WC_1.txt', 'Area_4_lobby_2.txt', 'Area_5_office_3.txt', 'Area_6_office_9.txt'],
    "ScanNet": ['scene0005_01.pth', 'scene0010_01.pth', 'scene0016_02.pth', 'scene0019_01.pth', 'scene0000_00.pth', 'scene0002_00.pth'],
    "Objaverse": ["plant.npy", "human.npy", "lock.npy", "elephant.npy", "knife_rest.npy", "skateboard.npy", "popcorn_machine.npy", "stove.npy", "bus_shelter.npy", "thor_hammer.npy", "horse.npy"],
    "KITTI": ["scene1.npy", "scene2.npy", "scene3.npy", "scene4.npy", "scene5.npy", "scene6.npy"],
    "Semantic3D": ["scene1.npy", "scene2.npy", "patch19.npy", "patch0.npy", "patch1.npy", "patch50.npy", "patch62.npy"],
}
# Prompt kinds selectable in the UI; handlers lower-case the chosen value before use.
prompt_types = ["Point", "Box", "Mask"]
# Upper bound on how many points are sent to the browser; larger clouds are
# randomly subsampled before plotting.
_MAX_POINTS = 100000


def _read_point_cloud(dataset_name, path, sample_idx):
    """Load one sample and choose its display settings.

    Returns ``(point, color, alpha, asp, size)``: the loader's points and colors,
    the marker opacity, the z aspect ratio and the marker size.
    """
    alpha, asp, size = 1, 1., 1
    if dataset_name == 'S3DIS':
        point, color = dataset.load_S3DIS_sample(path, sample=True)
    elif dataset_name == 'ScanNet':
        point, color = dataset.load_ScanNet_sample(path)
    elif dataset_name == 'Objaverse':
        point, color = dataset.load_Objaverse_sample(path)
        size = 2
    elif dataset_name == 'KITTI':
        point, color = dataset.load_KITTI_sample(path)
        asp, alpha = 0.3, 0.7
    elif dataset_name == 'Semantic3D':
        point, color = dataset.load_Semantic3D_sample(path, sample_idx, sample=True)
        alpha = 0.2
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name!r}")
    return point, color, alpha, asp, size


def _subsample(point, color, max_points=_MAX_POINTS):
    """Randomly keep at most ``max_points`` rows of ``point`` and the same rows of ``color``."""
    if point.shape[0] > max_points:
        indices = np.random.choice(point.shape[0], max_points, replace=False)
        point, color = point[indices], color[indices]
    return point, color


def _cloud_trace(point, color, size, alpha, name):
    """Scatter trace drawing every point of the (possibly subsampled) cloud."""
    return go.Scatter3d(
        x=point[:, 0], y=point[:, 1], z=point[:, 2],
        mode='markers',
        marker=dict(size=size, color=color, opacity=alpha),
        name=name,
    )


def _legend_trace(color, name):
    """Single near-invisible marker whose only purpose is to add a colored legend entry."""
    anchor = np.array([[0.1, 0.1, 0.1]])
    return go.Scatter3d(
        x=anchor[:, 0], y=anchor[:, 1], z=anchor[:, 2],
        mode='markers',
        marker=dict(size=0.0001, color=color, opacity=1),
        name=name,
    )


def _scene_figure(traces, asp):
    """Wrap traces in a figure with hidden axes, z aspect ratio ``asp`` and a fixed camera."""
    return go.Figure(
        data=traces,
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
                aspectratio=dict(x=1, y=1, z=asp),
                camera=dict(eye=dict(x=1.5, y=1.5, z=1.5)),
            )
        ),
    )


def load_3d_scene(name, sample_idx=-1, type_=None, prompt=None, final=False, new_color=None):
    """Build the Plotly figure for a sample, optionally overlaid with a prompt or result.

    Args:
        name: a key of ``samples`` such as ``"3D Indoor Scene - S3DIS"``; the text
            after '-' (spaces removed) selects the dataset and its ``PATH`` entry.
        sample_idx: index into ``PATH[dataset]``.
        type_: falsy for the plain point cloud, otherwise ``"box"``, ``"point"``
            or ``"mask"``.
        prompt: for ``"box"`` it is passed to ``create_box``; for ``"point"`` its
            first three entries are plotted as one red marker; for a non-final
            ``"mask"`` it is scaled by 255, clipped to [0, 255] and used as the
            point colors (presumably values in [0, 1] — TODO confirm).
        final: when True, points are colored with ``new_color`` and a
            "Segmentation Results" legend entry is added.
        new_color: per-point colors used when ``final`` is True.

    Returns:
        A ``plotly.graph_objects.Figure``.

    Raises:
        gr.Error: if no data type or object/scene was selected.
        ValueError: if ``type_`` is not one of the supported prompt types.
    """
    if name is None or sample_idx is None:
        # Show a readable message in the UI instead of failing while indexing PATH.
        raise gr.Error("Please select a 3D data type and a 3D object/scene first.")
    DATASET = name.split('-')[1].replace(" ", "")
    path = 'data/' + DATASET + '/' + PATH[DATASET][sample_idx]
    print(path)
    point, color, alpha, asp, size = _read_point_cloud(DATASET, path, sample_idx)
    print("Loading Dataset:", DATASET, "Point Cloud Size:", point.shape, "Path:", path)

    ##### Initial Show #####
    if not type_:
        point, color = _subsample(point, color)
        return _scene_figure([_cloud_trace(point, color, size, alpha, "")], asp)

    ##### Prompt / Final Results #####
    if final:
        color = new_color

    if type_ == "box":
        point, color = _subsample(point, color)
        traces = [_cloud_trace(point, color, size, alpha, "3D Object/Scene")]
        if final:
            traces.append(_legend_trace('green', "Segmentation Results"))
        traces.extend(create_box(prompt))
    elif type_ == "point":
        prompt = np.array([prompt])
        prompt_trace = go.Scatter3d(
            x=prompt[:, 0], y=prompt[:, 1], z=prompt[:, 2],
            mode='markers',
            marker=dict(size=5, color='red', opacity=1),
            name="Point Prompt",
        )
        point, color = _subsample(point, color)
        traces = [_cloud_trace(point, color, size, alpha, "3D Object/Scene"), prompt_trace]
        if final:
            traces.append(_legend_trace('green', "Segmentation Results"))
    elif type_ == "mask":
        if final:
            legend = _legend_trace('green', "Segmentation Results")
        else:
            # Visualize the mask prompt itself as the point colors.
            color = np.clip(prompt * 255, 0, 255).astype(np.uint8)
            legend = _legend_trace('red', "Mask Prompt")
        point, color = _subsample(point, color)
        traces = [_cloud_trace(point, color, size, alpha, "3D Object/Scene"), legend]
    else:
        # Raise rather than call exit(), so a bad request reports an error
        # instead of attempting to stop the running app.
        raise ValueError(f"Wrong Prompt Type: {type_!r}")
    return _scene_figure(traces, asp)
@spaces.GPU()
def show_prompt_in_3d(name, sample_idx, prompt_type, prompt_idx):
    """Render the selected prompt example on top of the chosen point cloud.

    Args:
        name: selected "3D Data Type" label (a key of ``samples``).
        sample_idx: index of the selected object/scene.
        prompt_type: "Point", "Box" or "Mask".
        prompt_idx: index of the selected prompt example.

    Returns:
        ``(figure, response textbox)``. If any selection is missing, an empty
        ``gr.Plot()`` and a hint message are returned instead.
    """
    if name is None or sample_idx is None or prompt_type is None or prompt_idx is None:
        return gr.Plot(), gr.Textbox(label="Response", value="Please ensure all options are selected.", visible=True)
    DATASET = name.split('-')[1].replace(" ", "")
    TYPE = prompt_type.lower()
    # Exact-name membership; a substring test against one string would also
    # accept fragments such as "Scan".
    is_indoor = DATASET in ("S3DIS", "ScanNet")
    theta = 0. if is_indoor else 0.5
    mode = "bilinear" if is_indoor else 'nearest'
    # ret_prompt=True makes run_demo return the prompt for display (0.02 fills the
    # slot start_segmentation uses for its voxel size `vx`).
    prompt = run_demo(DATASET, TYPE, sample_idx, prompt_idx, 0.02, theta, mode, ret_prompt=True)
    fig = load_3d_scene(name, sample_idx, TYPE, prompt)
    return fig, gr.Textbox(label="Response", value="Prompt has been shown in 3D Object/Scene!", visible=True)
@spaces.GPU()
def start_segmentation(name=None, sample_idx=None, prompt_type=None, prompt_idx=None, vx=0.02):
    """Run segmentation for the selected sample and prompt, then plot the result.

    Args:
        name: selected "3D Data Type" label (a key of ``samples``).
        sample_idx: index of the selected object/scene.
        prompt_type: "Point", "Box" or "Mask".
        prompt_idx: index of the selected prompt example.
        vx: voxel size forwarded to ``run_demo``. The button wiring in ``main()``
            passes only the first four inputs, so the default is used there.

    Returns:
        ``(figure, response textbox)``. If any selection is missing, an empty
        ``gr.Plot()`` and a hint message are returned instead.
    """
    if name is None or sample_idx is None or prompt_type is None or prompt_idx is None:
        return gr.Plot(), gr.Textbox(label="Response", value="Please ensure all options are selected.", visible=True)
    DATASET = name.split('-')[1].replace(" ", "")
    TYPE = prompt_type.lower()
    # Exact-name membership; a substring test against one string would also
    # accept fragments such as "Scan".
    is_indoor = DATASET in ("S3DIS", "ScanNet")
    theta = 0. if is_indoor else 0.5
    mode = "bilinear" if is_indoor else 'nearest'
    new_color, prompt = run_demo(DATASET, TYPE, sample_idx, prompt_idx, vx, theta, mode, ret_prompt=False)
    fig = load_3d_scene(name, sample_idx, TYPE, prompt, final=True, new_color=new_color)
    return fig, gr.Textbox(label="Response", value="Segmentation completed successfully!", visible=True)
def update1(datasets):
    """Refresh the object/scene selector after the data-type dropdown changes.

    Args:
        datasets: selected "3D Data Type" label (a key of ``samples``) or None.

    Returns:
        ``(gr.Radio, gr.Textbox)``: a selector listing that dataset's samples,
        labelled for objects (Objaverse) or scenes, and a cleared response box.
        With no selection, the selector is emptied.
    """
    # NOTE(review): the currently selected index is not reset here; if Gradio keeps
    # it across a choices change, an index from the previous dataset may be out of
    # range for the new one — confirm.
    if datasets is None:
        return gr.Radio(label="Select 3D Object/Scene", choices=[]), gr.Textbox(label="Response", value="", visible=True)
    label = "Select 3D Object" if 'Objaverse' in datasets else "Select 3D Scene"
    return gr.Radio(label=label, choices=samples[datasets]), gr.Textbox(label="Response", value="", visible=True)
def update2(name, sample_idx, prompt_type):
    """List the prompt examples available for the current selection.

    The examples come from ``configs.<Dataset>_samples[sample_idx]["<type>_prompts"]``
    and are labelled "Example 1", "Example 2", ...

    Returns:
        ``(gr.Radio, gr.Textbox)``: the example selector and a cleared response box.
        With any selection missing, the selector is emptied.
    """
    def reply(choices):
        return gr.Radio(label="Select Prompt Example", choices=choices), gr.Textbox(label="Response", value="", visible=True)

    if None in (name, sample_idx, prompt_type):
        return reply([])
    dataset_key = name.split('-')[1].replace(" ", "")
    sample_table = getattr(configs, dataset_key + '_samples')
    examples = sample_table[sample_idx][prompt_type.lower() + '_prompts']
    return reply([f'Example {n}' for n in range(1, len(examples) + 1)])
def update3(name, sample_idx, prompt_type, prompt_idx):
    """Return a cleared response box and a voxel-size slider preset for the selection.

    S3DIS and ScanNet use 0.02. Objaverse, KITTI and Semantic3D read the value from
    ``configs.<Dataset>_samples[sample_idx][configs.VOXEL[prompt_type.lower()]][prompt_idx]``.
    Falls back to 0.02 when a selection is missing or the dataset is not listed.
    Note: this handler is not attached to any event in ``main()``.

    Returns:
        ``(gr.Textbox, gr.Slider)``.
    """
    default_vx = 0.02

    def reply(value):
        return gr.Textbox(label="Response", value="", visible=True), gr.Slider(minimum=0.01, maximum=0.15, step=0.001, label="Voxel Size", value=value)

    # prompt_idx is checked too, because the per-dataset lookups below index with it.
    if name is None or sample_idx is None or prompt_type is None or prompt_idx is None:
        return reply(default_vx)
    DATASET = name.split('-')[1].replace(" ", "")
    TYPE = configs.VOXEL[prompt_type.lower()]
    if DATASET in ("S3DIS", "ScanNet"):
        return reply(default_vx)
    if DATASET == 'Objaverse':
        table = configs.Objaverse_samples
    elif DATASET == 'KITTI':
        table = configs.KITTI_samples
    elif DATASET == 'Semantic3D':
        table = configs.Semantic3D_samples
    else:
        return reply(default_vx)
    return reply(table[sample_idx][TYPE][prompt_idx])
# Header rendered at the top of the demo page.
_HEADER_HTML = """<h1 style="text-align: center;">
<div style="width: 1.2em; height: 1.2em; display: inline-block;"><img src="https://github.com/ZiyuGuo99/ZiyuGuo99.github.io/blob/main/assets/img/logo.png?raw=true" style='width: 100%; height: 100%; object-fit: contain;' /></div>
<span style="font-variant: small-caps; font-weight: bold;">Sam2Point</span>
</h1>
<h3 align="center"><span style="font-variant: small-caps; ">Segment Any 3D as Videos in Zero-shot and Promptable Manners
</span></h3>
<div style="text-align: center;">
<div style="display: flex; align-items: center; justify-content: center; gap: 0.5rem; margin-bottom: 0.5rem; font-size: 1rem; flex-wrap: wrap;">
<a href="https://sam2point.github.io/" target="_blank">[Webpage]</a>
<a href="https://arxiv.org/pdf/2408.16768" target="_blank">[Paper]</a>
<a href="https://github.com/ZiyuGuo99/SAM2Point" target="_blank">[Code]</a>
</div>
</div>
<p style="text-align: center;">
Select an example and a 3D prompt to start segmentation using <span style="font-variant: small-caps;">Sam2Point</span>.
</p>
<p style="text-align: center;">
Custom 3D input and prompts will be supported soon.
</p>
"""

# Page CSS: flex column layout, full-height container, wrapped <pre> text.
_PAGE_CSS = """
.contain { display: flex; flex-direction: column; }
.gradio-container { height: 100vh !important; }
#col_container { height: 100%; }
pre {
white-space: pre-wrap; /* Since CSS 2.1 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
white-space: -pre-wrap; /* Opera 4-6 */
white-space: -o-pre-wrap; /* Opera 7 */
word-wrap: break-word; /* Internet Explorer 5.5+ */
}"""

# Reloads the page with ?__theme=light unless __theme is already 'light'.
_FORCE_LIGHT_THEME_JS = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'light') {
url.searchParams.set('__theme', 'light');
window.location.href = url.href;
}
}"""


def main():
    """Build the SAM2Point Gradio demo, wire its event handlers and launch it."""
    with gr.Blocks(
        css=_PAGE_CSS,
        js=_FORCE_LIGHT_THEME_JS,
        title="SAM2Point: Segment Any 3D as Videos in Zero-shot and Promptable Manners",
        theme=gr.themes.Soft(),
    ) as app:
        gr.HTML(_HEADER_HTML)
        with gr.Row():
            # Left column: choose the data sample and the prompt.
            with gr.Column(elem_id="col_container"):
                data_type_picker = gr.Dropdown(label="Select 3D Data Type", choices=samples, type="value")
                scene_picker = gr.Radio(label="Select 3D Object/Scene", choices=[], type="index")
                show_scene_btn = gr.Button("Show 3D Scene/Object")
                prompt_type_picker = gr.Radio(label="Select Prompt Type", choices=prompt_types)
                prompt_example_picker = gr.Radio(label="Select Prompt Example", choices=[], type="index")
                show_prompt_btn = gr.Button("Show Prompt in 3D Scene/Object")
            # Right column: run segmentation and inspect the plot / status.
            with gr.Column():
                segment_btn = gr.Button("Start Segmentation")
                viewer = gr.Plot()
                status_box = gr.Textbox(label="Response")

        data_type_picker.change(update1, inputs=data_type_picker, outputs=[scene_picker, status_box])
        # Any change of dataset, sample or prompt type refreshes the prompt examples.
        for trigger in (data_type_picker, scene_picker, prompt_type_picker):
            trigger.change(
                update2,
                inputs=[data_type_picker, scene_picker, prompt_type_picker],
                outputs=[prompt_example_picker, status_box],
            )
        show_scene_btn.click(load_3d_scene, inputs=[data_type_picker, scene_picker], outputs=viewer)
        show_prompt_btn.click(
            show_prompt_in_3d,
            inputs=[data_type_picker, scene_picker, prompt_type_picker, prompt_example_picker],
            outputs=[viewer, status_box],
        )
        segment_btn.click(
            start_segmentation,
            inputs=[data_type_picker, scene_picker, prompt_type_picker, prompt_example_picker],
            outputs=[viewer, status_box],
        )

    app.queue(max_size=20, api_open=False)
    app.launch(max_threads=400)


if __name__ == "__main__":
    # Build the UI and start serving it.
    main()