import spaces
import gradio as gr
import numpy as np
import plotly.graph_objects as go

from sam2point import dataset
import sam2point.configs as configs
from demo_utils import run_demo, create_box

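# Display names for the demo samples, grouped by data type; the keys populate the "Select 3D Data Type" dropdown.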
samples = {
    "3D Indoor Scene - S3DIS": ["Conference Room", "Restroom", "Lobby", "Office1", "Office2"],
    "3D Indoor Scene - ScanNet": ["Scene1", "Scene2", "Scene3", "Scene4", "Scene5", "Scene6"],
    "3D Raw LiDAR - KITTI": ["Scene1", "Scene2", "Scene3", "Scene4", "Scene5", "Scene6"],
    "3D Outdoor Scene - Semantic3D": ["Scene1", "Scene2", "Scene3", "Scene4", "Scene5", "Scene6", "Scene7"],
    "3D Object - Objaverse": ["Plant", "Lego", "Lock", "Elephant", "Knife Rest", "Skateboard", "Popcorn Machine", "Stove", "Bus Shelter", "Thor Hammer", "Horse"],
}

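# On-disk file names under data/<DATASET>/, index-aligned with the entries in `samples`.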
PATH = {
    "S3DIS": ['Area_1_conferenceRoom_1.txt', 'Area_2_WC_1.txt', 'Area_4_lobby_2.txt', 'Area_5_office_3.txt', 'Area_6_office_9.txt'],
    "ScanNet": ['scene0005_01.pth', 'scene0010_01.pth', 'scene0016_02.pth', 'scene0019_01.pth', 'scene0000_00.pth', 'scene0002_00.pth'],
    "Objaverse": ["plant.npy", "human.npy", "lock.npy", "elephant.npy", "knife_rest.npy", "skateboard.npy", "popcorn_machine.npy", "stove.npy", "bus_shelter.npy", "thor_hammer.npy", "horse.npy"],
    "KITTI": ["scene1.npy", "scene2.npy", "scene3.npy", "scene4.npy", "scene5.npy", "scene6.npy"],
    "Semantic3D": ["scene1.npy", "scene2.npy", "patch19.npy", "patch0.npy", "patch1.npy", "patch50.npy", "patch62.npy"]
}

prompt_types = ["Point", "Box", "Mask"]


def load_3d_scene(name, sample_idx=-1, type_=None, prompt=None, final=False, new_color=None):
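    """Load a point-cloud sample and return a Plotly figure.

    With type_ left as None, only the raw scene is drawn. Otherwise the prompt
    (point / box / mask) is overlaid, and final=True additionally colors the
    points with the segmentation result passed in new_color.
    """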
    # e.g. "3D Indoor Scene - S3DIS" -> "S3DIS"
    DATASET = name.split('-')[1].replace(" ", "")
    path = 'data/' + DATASET + '/' + PATH[DATASET][sample_idx]
    asp, SIZE = 1., 1

    print(path)
    if DATASET == 'S3DIS':
        point, color = dataset.load_S3DIS_sample(path, sample=True)
        alpha = 1
    elif DATASET == 'ScanNet':
        point, color = dataset.load_ScanNet_sample(path)
        alpha = 1
    elif DATASET == 'Objaverse':
        point, color = dataset.load_Objaverse_sample(path)
        alpha = 1
        SIZE = 2
    elif DATASET == 'KITTI':
        point, color = dataset.load_KITTI_sample(path)
        asp = 0.3
        alpha = 0.7
    elif DATASET == 'Semantic3D':
        point, color = dataset.load_Semantic3D_sample(path, sample_idx, sample=True)
        alpha = 0.2
    print("Loading Dataset:", DATASET, "Point Cloud Size:", point.shape, "Path:", path)

    if not type_:
        # No prompt selected: just render the raw point cloud.
        # Randomly subsample to at most 100k points to keep the Plotly figure responsive.
        if point.shape[0] > 100000:
            indices = np.random.choice(point.shape[0], 100000, replace=False)
            point = point[indices]
            color = color[indices]
        fig = go.Figure(
            data=[
                go.Scatter3d(
                    x=point[:,0], y=point[:,1], z=point[:,2],
                    mode='markers',
                    marker=dict(size=SIZE, color=color, opacity=alpha),
                    name=""
                )
            ],
            layout=dict(
                scene=dict(
                    xaxis=dict(visible=False),
                    yaxis=dict(visible=False),
                    zaxis=dict(visible=False),
                    aspectratio=dict(x=1, y=1, z=asp),
                    camera=dict(eye=dict(x=1.5, y=1.5, z=1.5))
                )
            )
        )
        return fig

    if final:
        color = new_color
        # Near-invisible marker whose only purpose is to add a "Segmentation Results" legend entry.
        green = np.array([[0.1, 0.1, 0.1]])
        add_green = go.Scatter3d(
            x=green[:,0], y=green[:,1], z=green[:,2],
            mode='markers',
            marker=dict(size=0.0001, color='green', opacity=1),
            name="Segmentation Results"
        )
    if type_ == "box":
        if point.shape[0] > 100000:
            indices = np.random.choice(point.shape[0], 100000, replace=False)
            point = point[indices]
            color = color[indices]
        scatter = go.Scatter3d(
            x=point[:,0], y=point[:,1], z=point[:,2],
            mode='markers',
            marker=dict(size=SIZE, color=color, opacity=alpha),
            name="3D Object/Scene"
        )
        if final: scatter = [scatter, add_green] + create_box(prompt)
        else: scatter = [scatter] + create_box(prompt)
    elif type_ == "point":
        prompt = np.array([prompt])
        new = go.Scatter3d(
            x=prompt[:,0], y=prompt[:,1], z=prompt[:,2],
            mode='markers',
            marker=dict(size=5, color='red', opacity=1),
            name="Point Prompt"
        )
        if point.shape[0] > 100000:
            indices = np.random.choice(point.shape[0], 100000, replace=False)
            point = point[indices]
            color = color[indices]
        scatter = go.Scatter3d(
            x=point[:,0], y=point[:,1], z=point[:,2],
            mode='markers',
            marker=dict(size=SIZE, color=color, opacity=alpha),
            name="3D Object/Scene"
        )
        if final: scatter = [scatter, new, add_green]
        else: scatter = [scatter, new]
    elif type_ == 'mask' and not final:
        # Color the points by the mask prompt itself.
        color = np.clip(prompt * 255, 0, 255).astype(np.uint8)
        if point.shape[0] > 100000:
            indices = np.random.choice(point.shape[0], 100000, replace=False)
            point = point[indices]
            color = color[indices]
        scatter = go.Scatter3d(
            x=point[:,0], y=point[:,1], z=point[:,2],
            mode='markers',
            marker=dict(size=SIZE, color=color, opacity=alpha),
            name="3D Object/Scene"
        )
        # Near-invisible marker used only to add a "Mask Prompt" legend entry.
        red = np.array([[0.1, 0.1, 0.1]])
        add_red = go.Scatter3d(
            x=red[:,0], y=red[:,1], z=red[:,2],
            mode='markers',
            marker=dict(size=0.0001, color='red', opacity=1),
            name="Mask Prompt"
        )
        scatter = [scatter, add_red]
    elif type_ == 'mask' and final:
        if point.shape[0] > 100000:
            indices = np.random.choice(point.shape[0], 100000, replace=False)
            point = point[indices]
            color = color[indices]
        scatter = go.Scatter3d(
            x=point[:,0], y=point[:,1], z=point[:,2],
            mode='markers',
            marker=dict(size=SIZE, color=color, opacity=alpha),
            name="3D Object/Scene"
        )
        scatter = [scatter, add_green]
    else:
        # Raise instead of exit(1) so a bad prompt type does not kill the running app.
        raise ValueError("Wrong Prompt Type: " + str(type_))

    fig = go.Figure(
        data=scatter,
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
                aspectratio=dict(x=1, y=1, z=asp),
                camera=dict(eye=dict(x=1.5, y=1.5, z=1.5))
            )
        )
    )
    return fig


# spaces.GPU allocates a GPU for the duration of the call on Hugging Face ZeroGPU Spaces.
@spaces.GPU()
def show_prompt_in_3d(name, sample_idx, prompt_type, prompt_idx):
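    """Visualize the selected prompt on top of the chosen 3D sample."""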
    if name is None or sample_idx is None or prompt_type is None or prompt_idx is None:
        return gr.Plot(), gr.Textbox(label="Response", value="Please ensure all options are selected.", visible=True)

    DATASET = name.split('-')[1].replace(" ", "")
    TYPE = prompt_type.lower()
    theta = 0. if DATASET in ("S3DIS", "ScanNet") else 0.5
    mode = "bilinear" if DATASET in ("S3DIS", "ScanNet") else 'nearest'

    prompt = run_demo(DATASET, TYPE, sample_idx, prompt_idx, 0.02, theta, mode, ret_prompt=True)
    fig = load_3d_scene(name, sample_idx, TYPE, prompt)
    return fig, gr.Textbox(label="Response", value="Prompt has been shown in 3D Object/Scene!", visible=True)


@spaces.GPU()
def start_segmentation(name=None, sample_idx=None, prompt_type=None, prompt_idx=None, vx=0.02):
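    """Run SAM2Point segmentation for the selected sample and prompt, then plot the result."""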
    if name is None or sample_idx is None or prompt_type is None or prompt_idx is None:
        return gr.Plot(), gr.Textbox(label="Response", value="Please ensure all options are selected.", visible=True)

    DATASET = name.split('-')[1].replace(" ", "")
    TYPE = prompt_type.lower()
    theta = 0. if DATASET in ("S3DIS", "ScanNet") else 0.5
    mode = "bilinear" if DATASET in ("S3DIS", "ScanNet") else 'nearest'

    new_color, prompt = run_demo(DATASET, TYPE, sample_idx, prompt_idx, vx, theta, mode, ret_prompt=False)
    fig = load_3d_scene(name, sample_idx, TYPE, prompt, final=True, new_color=new_color)
    return fig, gr.Textbox(label="Response", value="Segmentation completed successfully!", visible=True)


def update1(datasets):
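    """Refresh the 3D object/scene choices when the selected data type changes."""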
    if 'Objaverse' in datasets:
        return gr.Radio(label="Select 3D Object", choices=samples[datasets]), gr.Textbox(label="Response", value="", visible=True)
    return gr.Radio(label="Select 3D Scene", choices=samples[datasets]), gr.Textbox(label="Response", value="", visible=True)


def update2(name, sample_idx, prompt_type):
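    """Refresh the list of prompt examples for the current sample and prompt type."""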
    if name is None or sample_idx is None or prompt_type is None:
        return gr.Radio(label="Select Prompt Example", choices=[]), gr.Textbox(label="Response", value="", visible=True)

    DATASET = name.split('-')[1].replace(" ", "")
    TYPE = prompt_type.lower() + '_prompts'

    if DATASET == 'S3DIS':
        info = configs.S3DIS_samples[sample_idx][TYPE]
    elif DATASET == 'ScanNet':
        info = configs.ScanNet_samples[sample_idx][TYPE]
    elif DATASET == 'Objaverse':
        info = configs.Objaverse_samples[sample_idx][TYPE]
    elif DATASET == 'KITTI':
        info = configs.KITTI_samples[sample_idx][TYPE]
    elif DATASET == 'Semantic3D':
        info = configs.Semantic3D_samples[sample_idx][TYPE]

    cur = ['Example ' + str(i) for i in range(1, len(info) + 1)]
    return gr.Radio(label="Select Prompt Example", choices=cur), gr.Textbox(label="Response", value="", visible=True)


def update3(name, sample_idx, prompt_type, prompt_idx):
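    """Return a cleared response box and the preset voxel size for the chosen prompt example."""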
    if name is None or sample_idx is None or prompt_type is None:
        return gr.Textbox(label="Response", value="", visible=True), gr.Slider(minimum=0.01, maximum=0.15, step=0.001, label="Voxel Size", value=0.02)

    DATASET = name.split('-')[1].replace(" ", "")
    TYPE = configs.VOXEL[prompt_type.lower()]

    if DATASET in ("S3DIS", "ScanNet"):
        vx_ = 0.02
    elif DATASET == 'Objaverse':
        vx_ = configs.Objaverse_samples[sample_idx][TYPE][prompt_idx]
    elif DATASET == 'KITTI':
        vx_ = configs.KITTI_samples[sample_idx][TYPE][prompt_idx]
    elif DATASET == 'Semantic3D':
        vx_ = configs.Semantic3D_samples[sample_idx][TYPE][prompt_idx]

    return gr.Textbox(label="Response", value="", visible=True), gr.Slider(minimum=0.01, maximum=0.15, step=0.001, label="Voxel Size", value=vx_)


def main():
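    """Build the Gradio interface and launch the demo."""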
    title = """<h1 style="text-align: center;">
    <div style="width: 1.2em; height: 1.2em; display: inline-block;"><img src="https://github.com/ZiyuGuo99/ZiyuGuo99.github.io/blob/main/assets/img/logo.png?raw=true" style='width: 100%; height: 100%; object-fit: contain;' /></div>
    <span style="font-variant: small-caps; font-weight: bold;">Sam2Point</span>
    </h1>
    <h3 align="center"><span style="font-variant: small-caps; ">Segment Any 3D as Videos in Zero-shot and Promptable Manners</span></h3>

    <div style="text-align: center;">
      <div style="display: flex; align-items: center; justify-content: center; gap: 0.5rem; margin-bottom: 0.5rem; font-size: 1rem; flex-wrap: wrap;">
        <a href="https://sam2point.github.io/" target="_blank">[Webpage]</a>
        <a href="https://arxiv.org/pdf/2408.16768" target="_blank">[Paper]</a>
        <a href="https://github.com/ZiyuGuo99/SAM2Point" target="_blank">[Code]</a>
      </div>
    </div>
    <p style="text-align: center;">
      Select an example and a 3D prompt to start segmentation using <span style="font-variant: small-caps;">Sam2Point</span>.
    </p>
    <p style="text-align: center;">
      Custom 3D input and prompts will be supported soon.
    </p>
    """

    with gr.Blocks(
        css="""
        .contain { display: flex; flex-direction: column; }
        .gradio-container { height: 100vh !important; }
        #col_container { height: 100%; }
        pre {
            white-space: pre-wrap;       /* Since CSS 2.1 */
            white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */
            white-space: -pre-wrap;      /* Opera 4-6 */
            white-space: -o-pre-wrap;    /* Opera 7 */
            word-wrap: break-word;       /* Internet Explorer 5.5+ */
        }""",
        js="""
        function refresh() {
            const url = new URL(window.location);
            if (url.searchParams.get('__theme') !== 'light') {
                url.searchParams.set('__theme', 'light');
                window.location.href = url.href;
            }
        }""",
        title="SAM2Point: Segment Any 3D as Videos in Zero-shot and Promptable Manners",
        theme=gr.themes.Soft()
    ) as app:
        gr.HTML(title)
        with gr.Row():
            with gr.Column(elem_id="col_container"):
                sample_dropdown = gr.Dropdown(label="Select 3D Data Type", choices=list(samples.keys()), type="value")
                scene_dropdown = gr.Radio(label="Select 3D Object/Scene", choices=[], type="index")
                show_button = gr.Button("Show 3D Scene/Object")
                prompt_type_dropdown = gr.Radio(label="Select Prompt Type", choices=prompt_types)
                prompt_sample_dropdown = gr.Radio(label="Select Prompt Example", choices=[], type="index")
                show_prompt_button = gr.Button("Show Prompt in 3D Scene/Object")
            with gr.Column():
                start_segment_button = gr.Button("Start Segmentation")
                plot1 = gr.Plot()

        response = gr.Textbox(label="Response")

        # Keep the scene list and the prompt-example list in sync with the current selections.
        sample_dropdown.change(update1, sample_dropdown, [scene_dropdown, response])
        sample_dropdown.change(update2, [sample_dropdown, scene_dropdown, prompt_type_dropdown], [prompt_sample_dropdown, response])
        scene_dropdown.change(update2, [sample_dropdown, scene_dropdown, prompt_type_dropdown], [prompt_sample_dropdown, response])
        prompt_type_dropdown.change(update2, [sample_dropdown, scene_dropdown, prompt_type_dropdown], [prompt_sample_dropdown, response])

        show_button.click(load_3d_scene, inputs=[sample_dropdown, scene_dropdown], outputs=plot1)
        show_prompt_button.click(show_prompt_in_3d, inputs=[sample_dropdown, scene_dropdown, prompt_type_dropdown, prompt_sample_dropdown], outputs=[plot1, response])
        start_segment_button.click(start_segmentation, inputs=[sample_dropdown, scene_dropdown, prompt_type_dropdown, prompt_sample_dropdown], outputs=[plot1, response])

    app.queue(max_size=20, api_open=False)
    app.launch(max_threads=400)


if __name__ == "__main__":
    main()