from pickle import FALSE  # NOTE(review): not referenced anywhere in this file; looks like an accidental auto-import — confirm and drop.
import gradio as gr
import numpy as np
import plotly.graph_objects as go
from sam2point import dataset
import sam2point.configs as configs
from demo_utils import run_demo, create_box

# Display labels shown in the "Select 3D Object/Scene" radio, keyed by the
# label shown in the "Select 3D Data Type" dropdown. The radio uses
# type="index", so the position of each label must match the position of
# its file in PATH below.
samples = {
    "3D Indoor Scene - S3DIS": ["Conference Room", "Restroom", "Lobby", "Office1", "Office2"],
    "3D Indoor Scene - ScanNet": ["Scene1", "Scene2", "Scene3", "Scene4", "Scene5", "Scene6"],
    "3D Outdoor Driving Scene - KITTI": ["Scene1", "Scene2", "Scene3", "Scene4", "Scene5", "Scene6"],
    "3D Outdoor Street Scene - Semantic3D": [
        "Scene1", "Scene2", "Scene3", "Scene4", "Scene5", "Scene6", "Scene7",
    ],
    # NOTE(review): "Lego" is paired with human.npy in PATH — confirm the pairing.
    "3D Object - Objaverse": [
        "Plant", "Lego", "Lock", "Elephant", "Knife Rest", "Skateboard",
        "Popcorn Machine", "Stove", "Bus Shelter", "Thor Hammer", "Horse",
    ],
}

# Sample files loaded from data/<dataset_key>/, keyed by the short dataset key
# (the part of a `samples` label after the last "-", with spaces removed).
PATH = {
    "S3DIS": [
        'Area_1_conferenceRoom_1.txt', 'Area_2_WC_1.txt', 'Area_4_lobby_2.txt',
        'Area_5_office_3.txt', 'Area_6_office_9.txt',
    ],
    "ScanNet": [
        'scene0005_01.pth', 'scene0010_01.pth', 'scene0016_02.pth',
        'scene0019_01.pth', 'scene0000_00.pth', 'scene0002_00.pth',
    ],
    "Objaverse": [
        "plant.npy", "human.npy", "lock.npy", "elephant.npy", "knife_rest.npy",
        "skateboard.npy", "popcorn_machine.npy", "stove.npy", "bus_shelter.npy",
        "thor_hammer.npy", "horse.npy",
    ],
    "KITTI": ["scene1.npy", "scene2.npy", "scene3.npy", "scene4.npy", "scene5.npy", "scene6.npy"],
    "Semantic3D": [
        "scene1.npy", "scene2.npy", "patch19.npy", "patch0.npy", "patch1.npy",
        "patch50.npy", "patch62.npy",
    ],
}

prompt_types = ["Point", "Box", "Mask"]

# Point clouds with more points than this are randomly subsampled before plotting.
MAX_PLOT_POINTS = 100000

# Datasets for which run_demo gets theta=0.0 and mode "bilinear"; every other
# dataset gets theta=0.5 and mode "nearest".
_INDOOR_DATASETS = ("S3DIS", "ScanNet")

# Name of the per-dataset sample table in sam2point.configs.
_SAMPLE_CONFIG_ATTRS = {
    "S3DIS": "S3DIS_samples",
    "ScanNet": "ScanNet_samples",
    "Objaverse": "Objaverse_samples",
    "KITTI": "KITTI_samples",
    "Semantic3D": "Semantic3D_samples",
}

_TITLE = """

Sam2Point

Segment Any 3D as Videos in Zero-shot and Promptable Manners

"""

_CSS = """
.contain { display: flex; flex-direction: column; }
.gradio-container { height: 100vh !important; }
#col_container { height: 100%; }
pre {
    white-space: pre-wrap;       /* Since CSS 2.1 */
    white-space: -moz-pre-wrap;  /* Mozilla, since 1999 */
    white-space: -pre-wrap;      /* Opera 4-6 */
    white-space: -o-pre-wrap;    /* Opera 7 */
    word-wrap: break-word;       /* Internet Explorer 5.5+ */
}
"""

# Ensures the page URL carries ?__theme=light, reloading the page if it does not.
_FORCE_LIGHT_THEME_JS = """
function refresh() {
    const url = new URL(window.location);
    if (url.searchParams.get('__theme') !== 'light') {
        url.searchParams.set('__theme', 'light');
        window.location.href = url.href;
    }
}
"""


def _dataset_key(name):
    """Return the short dataset key for a dropdown label.

    e.g. "3D Indoor Scene - S3DIS" -> "S3DIS".
    """
    return name.rsplit("-", 1)[-1].replace(" ", "")


def _demo_params(dataset_key):
    """Return the (theta, mode) pair that run_demo is called with for a dataset."""
    if dataset_key in _INDOOR_DATASETS:
        return 0., "bilinear"
    return 0.5, "nearest"


def _sample_configs(dataset_key):
    """Return the per-sample table from sam2point.configs for a dataset key."""
    return getattr(configs, _SAMPLE_CONFIG_ATTRS[dataset_key])


def _response(value=""):
    """Build the visible "Response" textbox update with the given text."""
    return gr.Textbox(label="Response", value=value, visible=True)


def _voxel_slider(value=0.02):
    """Build the "Voxel Size" slider update with the given initial value."""
    return gr.Slider(minimum=0.01, maximum=0.15, step=0.001, label="Voxel Size", value=value)


def _subsample(point, color, limit=MAX_PLOT_POINTS):
    """Randomly keep at most `limit` rows (without replacement), keeping point/color aligned."""
    if point.shape[0] > limit:
        indices = np.random.choice(point.shape[0], limit, replace=False)
        point, color = point[indices], color[indices]
    return point, color


def _cloud_trace(point, color, size, alpha, name="3D Object/Scene"):
    """Scatter3d trace for the point cloud itself (columns 0/1/2 are x/y/z)."""
    return go.Scatter3d(
        x=point[:, 0], y=point[:, 1], z=point[:, 2],
        mode='markers',
        marker=dict(size=size, color=color, opacity=alpha),
        name=name,
    )


def _legend_marker(color, name):
    """Single, practically invisible point (size 0.0001) at (0.1, 0.1, 0.1).

    Its purpose appears to be adding a named, colored legend entry.
    """
    anchor = np.array([[0.1, 0.1, 0.1]])
    return go.Scatter3d(
        x=anchor[:, 0], y=anchor[:, 1], z=anchor[:, 2],
        mode='markers',
        marker=dict(size=0.0001, color=color, opacity=1),
        name=name,
    )


def _point_prompt_trace(prompt):
    """Dark-red marker for a point prompt (assumes `prompt` is one (x, y, z) — TODO confirm)."""
    seed = np.array([prompt])
    return go.Scatter3d(
        x=seed[:, 0], y=seed[:, 1], z=seed[:, 2],
        mode='markers',
        marker=dict(size=5, color='rgb(139, 0, 0)', opacity=1),
        name="Point Prompt",
    )


def _build_figure(traces, asp):
    """Wrap traces in a Figure with hidden axes, z aspect `asp`, and a fixed camera."""
    return go.Figure(
        data=traces,
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
                aspectratio=dict(x=1, y=1, z=asp),
                camera=dict(eye=dict(x=1.5, y=1.5, z=1.5)),
            )
        ),
    )


def _load_point_cloud(dataset_key, path, sample_idx):
    """Load one sample and its dataset-specific plot settings.

    Returns:
        (point, color, alpha, size, asp): the arrays returned by the
        sam2point.dataset loader, marker opacity, marker size, and the
        z aspect ratio used for the plot.

    Raises:
        ValueError: if `dataset_key` has no loader here.
    """
    asp, size = 1., 1
    if dataset_key == 'S3DIS':
        point, color = dataset.load_S3DIS_sample(path, sample=True)
        alpha = 1
    elif dataset_key == 'ScanNet':
        point, color = dataset.load_ScanNet_sample(path)
        alpha = 1
    elif dataset_key == 'Objaverse':
        point, color = dataset.load_Objaverse_sample(path)
        alpha = 1
        size = 2
    elif dataset_key == 'KITTI':
        point, color = dataset.load_KITTI_sample(path)
        asp = 0.3
        alpha = 0.7
    elif dataset_key == 'Semantic3D':
        point, color = dataset.load_Semantic3D_sample(path, sample_idx, sample=True)
        alpha = 0.2
    else:
        raise ValueError(f"Unknown dataset: {dataset_key!r}")
    return point, color, alpha, size, asp


# Function to load and display 3D scene or object
def load_3d_scene(name, sample_idx=-1, type_=None, prompt=None, final=False, new_color=None):
    """Load a sample point cloud and build a Plotly 3D figure for it.

    Args:
        name: dataset label (a key of `samples`).
        sample_idx: index into PATH[<dataset key>]; the default -1 picks the last sample.
        type_: falsy for a plain preview, otherwise "box", "point" or "mask".
        prompt: prompt returned by run_demo. For "box" it is passed to
            create_box; for "point" it is plotted as a single marker; for a
            non-final "mask" it is scaled by 255 and used as the per-point
            colors (assumes values in [0, 1], aligned with the points — TODO confirm).
        final: when True, the cloud is colored with `new_color` and a
            "Segmentation Results" legend entry is added.
        new_color: per-point colors from run_demo, used when `final` is True.

    Returns:
        plotly.graph_objects.Figure

    Raises:
        gr.Error: if no dataset or sample is selected.
        ValueError: if `type_` is not a known prompt type.
    """
    if name is None or sample_idx is None:
        raise gr.Error("Please select a 3D data type and a 3D object/scene first.")

    dataset_key = _dataset_key(name)
    path = f"data/{dataset_key}/{PATH[dataset_key][sample_idx]}"
    print(path)
    point, color, alpha, size, asp = _load_point_cloud(dataset_key, path, sample_idx)
    print("Loading Dataset:", dataset_key, ", Point Cloud Size:", point.shape)

    # Plain preview: just the (subsampled) cloud, no prompt overlay.
    if not type_:
        point, color = _subsample(point, color)
        return _build_figure([_cloud_trace(point, color, size, alpha, name="")], asp)

    if type_ not in ("box", "point", "mask"):
        # Previously exit(1), which would terminate the whole demo server.
        raise ValueError(f"Wrong Prompt Type: {type_!r}")

    # Pick the per-point colors before subsampling so indices stay aligned.
    if final:
        color = new_color
    elif type_ == "mask":
        color = np.clip(prompt * 255, 0, 255).astype(np.uint8)
    point, color = _subsample(point, color)

    # Trace order matches the legend order: cloud, prompt marker, result entry, box.
    traces = [_cloud_trace(point, color, size, alpha)]
    if type_ == "point":
        traces.append(_point_prompt_trace(prompt))
    elif type_ == "mask" and not final:
        traces.append(_legend_marker('red', "Mask Prompt"))
    if final:
        traces.append(_legend_marker('green', "Segmentation Results"))
    if type_ == "box":
        traces += create_box(prompt)

    return _build_figure(traces, asp)


# Function to display prompt in 3D
def show_prompt_in_3d(name, sample_idx, prompt_type, prompt_idx):
    """Render the selected sample with the selected prompt example overlaid.

    Raises:
        gr.Error: if any selection is missing.
    """
    if any(v is None for v in (name, sample_idx, prompt_type, prompt_idx)):
        raise gr.Error("Please ensure all options are selected.")
    dataset_key = _dataset_key(name)
    prompt_kind = prompt_type.lower()
    theta, mode = _demo_params(dataset_key)
    prompt = run_demo(dataset_key, prompt_kind, sample_idx, prompt_idx, 0.02, theta, mode, ret_prompt=True)
    return load_3d_scene(name, sample_idx, prompt_kind, prompt)


# Function to start segmentation
def start_segmentation(name=None, sample_idx=None, prompt_type=None, prompt_idx=None, vx=0.02):
    """Run segmentation for the current selection and render the colored result.

    Args:
        vx: voxel size forwarded to run_demo.

    Returns:
        (figure or empty gr.Plot, "Response" textbox update). If any
        selection is missing, an empty plot and a prompt to select all
        options are returned instead.
    """
    if any(v is None for v in (name, sample_idx, prompt_type, prompt_idx)):
        return gr.Plot(), _response("Please ensure all options are selected.")
    dataset_key = _dataset_key(name)
    prompt_kind = prompt_type.lower()
    theta, mode = _demo_params(dataset_key)
    new_color, prompt = run_demo(dataset_key, prompt_kind, sample_idx, prompt_idx, vx, theta, mode, ret_prompt=False)
    fig = load_3d_scene(name, sample_idx, prompt_kind, prompt, final=True, new_color=new_color)
    return fig, _response("Segmentation completed successfully!")


def update1(datasets):
    """Refresh the object/scene radio when the dataset dropdown changes; clear the response."""
    if datasets is None:
        return gr.Radio(label="Select 3D Object/Scene", choices=[]), _response()
    label = "Select 3D Object" if 'Objaverse' in datasets else "Select 3D Scene"
    return gr.Radio(label=label, choices=samples[datasets]), _response()


def update2(name, sample_idx, prompt_type):
    """Refresh the prompt-example radio with one "Example N" per configured prompt.

    The examples come from sam2point.configs.<Dataset>_samples[sample_idx]["<type>_prompts"].
    The radio is emptied while any selection is missing.
    """
    if any(v is None for v in (name, sample_idx, prompt_type)):
        return gr.Radio(label="Select Prompt Example", choices=[]), _response()
    dataset_key = _dataset_key(name)
    info = _sample_configs(dataset_key)[sample_idx][prompt_type.lower() + '_prompts']
    examples = [f"Example {i}" for i in range(1, len(info) + 1)]
    return gr.Radio(label="Select Prompt Example", choices=examples), _response()


def update3(name, sample_idx, prompt_type, prompt_idx):
    """Return (response textbox, voxel-size slider) preset for the selected prompt.

    S3DIS/ScanNet always get 0.02. Other datasets read the value from
    configs.<Dataset>_samples[sample_idx][configs.VOXEL[type]][prompt_idx].
    NOTE: not wired into the UI built by main() (no voxel-size slider is created there).
    """
    if any(v is None for v in (name, sample_idx, prompt_type, prompt_idx)):
        return _response(), _voxel_slider()
    dataset_key = _dataset_key(name)
    voxel_key = configs.VOXEL[prompt_type.lower()]
    if dataset_key in _INDOOR_DATASETS:
        vx_ = 0.02
    else:
        vx_ = _sample_configs(dataset_key)[sample_idx][voxel_key][prompt_idx]
    return _response(), _voxel_slider(vx_)


def main():
    """Build the SAM2Point Gradio UI, wire its events, and launch it with share=True."""
    with gr.Blocks(
        css=_CSS,
        js=_FORCE_LIGHT_THEME_JS,
        title="SAM2Point: Segment Any 3D as Videos in Zero-shot and Promptable Manners",
        theme=gr.themes.Soft(),
    ) as app:
        gr.HTML(_TITLE)
        with gr.Row():
            with gr.Column(elem_id="col_container"):
                sample_dropdown = gr.Dropdown(label="Select 3D Data Type", choices=list(samples), type="value")
                # type="index": handlers receive the position, which indexes PATH.
                scene_dropdown = gr.Radio(label="Select 3D Object/Scene", choices=[], type="index")
                show_button = gr.Button("Show 3D Scene/Object")
                prompt_type_dropdown = gr.Radio(label="Select Prompt Type", choices=prompt_types)
                prompt_sample_dropdown = gr.Radio(label="Select Prompt Example", choices=[], type="index")
                show_prompt_button = gr.Button("Show Prompt in 3D Scene/Object")
            with gr.Column():
                start_segment_button = gr.Button("Start Segmentation")
                plot1 = gr.Plot()
                response = gr.Textbox(label="Response")

        selection = [sample_dropdown, scene_dropdown, prompt_type_dropdown]
        sample_dropdown.change(update1, sample_dropdown, [scene_dropdown, response])
        # Any change to dataset, sample, or prompt type refreshes the prompt examples.
        for trigger in selection:
            trigger.change(update2, selection, [prompt_sample_dropdown, response])

        # Logic to handle interactions
        show_button.click(load_3d_scene, inputs=[sample_dropdown, scene_dropdown], outputs=plot1)
        show_prompt_button.click(
            show_prompt_in_3d,
            inputs=selection + [prompt_sample_dropdown],
            outputs=plot1,
        )
        start_segment_button.click(
            start_segmentation,
            inputs=selection + [prompt_sample_dropdown],
            outputs=[plot1, response],
        )

    app.queue(status_update_rate="auto")
    app.launch(share=True, favicon_path="./logo.png")


if __name__ == "__main__":
    main()