# -*- coding: utf-8 -*-
import os
import torch
import argparse
import numpy as np
import open3d as o3d
from huggingface_hub import hf_hub_download, HfFolder
from segment import seg_point, seg_box, seg_mask
import sam2point.dataset as dataset
import sam2point.configs as configs
from sam2point.voxelizer import Voxelizer
from sam2point.utils import cal
import matplotlib.pyplot as plt
import plotly.graph_objects as go

print("Torch CUDA:", torch.cuda.is_available())
# use bfloat16 for the entire notebook
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
# Optional: on Ampere+ GPUs (compute capability >= 8) TF32 could be enabled via
# torch.backends.cuda.matmul.allow_tf32 / torch.backends.cudnn.allow_tf32
# (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices).

# Hugging Face dataset repo holding precomputed demo results.
_HF_CACHE_REPO = "ZiyuG/Cache"

# dataset name -> (name of the sample table in ``configs``,
#                  whether the per-prompt voxel size stored in the sample config
#                  replaces the caller-supplied ``voxel_size``)
_DATASETS = {
    'S3DIS': ('S3DIS_samples', False),
    'ScanNet': ('ScanNet_samples', False),
    'Objaverse': ('Objaverse_samples', True),
    'KITTI': ('KITTI_samples', True),
    'Semantic3D': ('Semantic3D_samples', True),
}

_PROMPT_TYPES = ('point', 'box', 'mask')


def _local_cache_paths(name):
    """Return the local ``(result_path, prompt_path)`` .npy files for cache key ``name``."""
    return "./cache_results/" + name + '.npy', "./cache_prompt/" + name + '.npy'


def _load_cached(name):
    """Return cached ``(new_color, PROMPT)`` for ``name``, or ``None`` on a cache miss.

    The Hugging Face dataset repo is tried first. If that fails for any reason,
    the local ``./cache_results`` / ``./cache_prompt`` files are used when both exist.
    """
    token = os.getenv('HF_TOKEN')
    try:
        result_file = hf_hub_download(repo_id=_HF_CACHE_REPO, filename="cache_results/" + name + '.npy',
                                      use_auth_token=token, repo_type='dataset')
        prompt_file = hf_hub_download(repo_id=_HF_CACHE_REPO, filename="cache_prompt/" + name + '.npy',
                                      use_auth_token=token, repo_type='dataset')
        return np.load(result_file), np.load(prompt_file)
    except Exception as e:
        # Best effort: offline, missing file or auth problems all fall back to the local cache.
        print(f"Hub cache unavailable for {name}: {e}")

    result_path, prompt_path = _local_cache_paths(name)
    if os.path.exists(result_path) and os.path.exists(prompt_path):
        return np.load(result_path), np.load(prompt_path)
    return None


def _load_points(dataset_name, info, sample_idx):
    """Load the ``(point, color)`` arrays of one sample with the dataset-specific loader."""
    path = info['path']
    if dataset_name == 'S3DIS':
        return dataset.load_S3DIS_sample(path)
    if dataset_name == 'ScanNet':
        return dataset.load_ScanNet_sample(path)
    if dataset_name == 'Objaverse':
        return dataset.load_Objaverse_sample(path)
    if dataset_name == 'KITTI':
        return dataset.load_KITTI_sample(path)
    # Semantic3D's loader additionally takes the sample index.
    return dataset.load_Semantic3D_sample(path, sample_idx)


def _paint_masked(color, voxel_mask, point_locs, rgb):
    """Return a copy of ``color`` with points whose voxel is set in ``voxel_mask`` recoloured to ``rgb``.

    ``voxel_mask`` is indexed by the per-point voxel coordinates ``point_locs``.
    ``unsqueeze``/``numpy`` imply it is a torch bool tensor (presumably on CPU; TODO confirm).
    """
    point_mask = voxel_mask[point_locs[:, 0], point_locs[:, 1], point_locs[:, 2]].unsqueeze(-1)
    point_mask_not = ~point_mask
    return color * point_mask_not.numpy() + (color * 0 + np.array([rgb])) * point_mask.numpy()


def run_demo(dataset_name, prompt_type, sample_idx, prompt_idx, voxel_size, theta, mode, ret_prompt):
    """Run (or fetch from cache) one SAM2Point segmentation demo.

    Args:
        dataset_name: one of 'S3DIS', 'ScanNet', 'Objaverse', 'KITTI', 'Semantic3D'.
        prompt_type: one of 'point', 'box', 'mask'.
        sample_idx: index of the scene/object in the dataset's sample config.
        prompt_idx: index of the prompt within that sample.
        voxel_size: voxel size. For Objaverse, KITTI and Semantic3D it is replaced
            by the per-prompt value stored in the sample config.
        theta, mode: forwarded to the segmentation functions through ``args``.
        ret_prompt: if true, return only the prompt instead of the segmentation.

    Returns:
        ``(new_color, PROMPT)`` where ``new_color`` holds per-point colours with the
        segmented points painted green. With ``ret_prompt`` only ``PROMPT`` is returned:
        the selected point/box prompt (a list when freshly computed, an ndarray when
        loaded from cache), or for mask prompts per-point colours with the prompt
        region painted red.

    Raises:
        ValueError: if ``dataset_name`` or ``prompt_type`` is not supported (and no
            cached result exists for it).
    """
    # Build the namespace the segmentation functions expect directly; parsing the real
    # sys.argv here would abort in notebooks / host processes with their own CLI flags.
    args = argparse.Namespace(
        dataset=dataset_name, prompt_type=prompt_type, sample_idx=sample_idx, prompt_idx=prompt_idx,
        voxel_size=voxel_size, theta=theta, mode=mode, ret_prompt=ret_prompt,
    )
    print(args)

    name = '_'.join([args.dataset, "sample" + str(args.sample_idx),
                     args.prompt_type + "-prompt" + str(args.prompt_idx)])

    cached = _load_cached(name)
    if cached is not None:
        new_color, PROMPT = cached
        return PROMPT if args.ret_prompt else (new_color, PROMPT)

    if args.dataset not in _DATASETS:
        raise ValueError(f"unsupported dataset: {args.dataset!r}")
    if args.prompt_type not in _PROMPT_TYPES:
        raise ValueError(f"unsupported prompt type: {args.prompt_type!r}")

    table_name, use_config_voxel = _DATASETS[args.dataset]
    info = getattr(configs, table_name)[args.sample_idx]

    # Point/box prompts are stored in the config, so they can be returned without
    # loading or segmenting the scene.
    if args.ret_prompt and args.prompt_type in ('point', 'box'):
        return list(np.array(info[args.prompt_type + '_prompts'])[args.prompt_idx])

    point, color = _load_points(args.dataset, info, args.sample_idx)
    if use_config_voxel:
        args.voxel_size = info[configs.VOXEL[args.prompt_type]][args.prompt_idx]

    # Keep an untouched copy of the inputs; colours are restored from it after segmentation.
    point_color = np.concatenate([point, color], axis=1)
    voxelizer = Voxelizer(voxel_size=args.voxel_size, clip_bound=None)
    labels_in = point[:, :1].astype(int)
    locs, feats, _labels, inds_reconstruct = voxelizer.voxelize(point, color, labels_in)

    if args.prompt_type == 'point':
        mask = seg_point(locs, feats, info['point_prompts'], args)
        PROMPT = list(np.array(info['point_prompts'])[args.prompt_idx])
    elif args.prompt_type == 'box':
        mask = seg_box(locs, feats, info['box_prompts'], args)
        PROMPT = list(np.array(info['box_prompts'])[args.prompt_idx])
    else:  # 'mask'
        # Fall back to point prompts without mutating the shared module-level config.
        mask_prompts = info['mask_prompts'] if 'mask_prompts' in info else info['point_prompts']
        mask, prompt_mask = seg_mask(locs, feats, mask_prompts, args)
        PROMPT = _paint_masked(color, prompt_mask, locs[inds_reconstruct], [1., 0., 0.])
        if args.ret_prompt:
            return PROMPT

    # Map each original point to its voxel to read the voxel-level mask.
    point_locs = locs[inds_reconstruct]
    color = point_color[:, 3:]
    new_color = _paint_masked(color, mask, point_locs, [0., 1., 0.])

    result_path, prompt_path = _local_cache_paths(name)
    os.makedirs("cache_results", exist_ok=True)
    os.makedirs("cache_prompt", exist_ok=True)
    np.save(result_path, new_color)
    np.save(prompt_path, PROMPT)
    return new_color, PROMPT


def create_box(prompt):
    """Build Plotly line traces outlining an axis-aligned 3D box.

    Args:
        prompt: six numbers ``(x_min, y_min, z_min, x_max, y_max, z_max)``.

    Returns:
        A list of 12 ``go.Scatter3d`` line traces, one per box edge. Only the first
        trace is named and shown in the legend, so the box appears there once.
    """
    x_min, y_min, z_min, x_max, y_max, z_max = tuple(prompt)
    corners = np.array([
        [x_min, y_min, z_min], [x_max, y_min, z_min], [x_max, y_max, z_min], [x_min, y_max, z_min],
        [x_min, y_min, z_max], [x_max, y_min, z_max], [x_max, y_max, z_max], [x_min, y_max, z_max],
    ])
    edges = [
        (0, 1), (1, 2), (2, 3), (3, 0),  # Bottom face
        (4, 5), (5, 6), (6, 7), (7, 4),  # Top face
        (0, 4), (1, 5), (2, 6), (3, 7),  # Vertical edges
    ]

    traces = []
    for i, (start, end) in enumerate(edges):
        first = i == 0
        traces.append(go.Scatter3d(
            x=[corners[start, 0], corners[end, 0]],
            y=[corners[start, 1], corners[end, 1]],
            z=[corners[start, 2], corners[end, 2]],
            mode='lines',
            line=dict(color='rgb(220, 20, 60)', width=6),
            name="Box Prompt" if first else "",
            showlegend=first,
        ))
    return traces