# NOTE: removed non-Python extraction artifacts (byte count / blob hashes / line-number run).
import torch
import torch.nn.functional as F
from model.model import PointSemSeg, Find3D
import numpy as np
import random
from transformers import AutoTokenizer, AutoModel
# Prefer GPU 0 but fall back to CPU so the module still imports and runs on
# CPU-only machines (restores the intent of the previously commented-out guard).
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
def get_seg_color(labels):
    """Map integer part labels to RGB colors via a fixed 17-color palette.

    Args:
        labels: 1-D integer tensor of per-point part ids; 0 means unlabeled
            and maps to white.

    Returns:
        (n_pts, 3) float tensor of RGB values in [0, 1].

    Raises:
        ValueError: if labels require more colors than the palette provides
            (previously this surfaced as an obscure matmul shape error).
    """
    part_num = int(labels.max())  # cast: .max() returns a 0-dim tensor
    cmap_matrix = torch.tensor([[1,1,1], [1,0,0], [0,1,0], [0,0,1], [1,1,0], [1,0,1],
                                [0,1,1], [0.5,0.5,0.5], [0.5,0.5,0], [0.5,0,0.5], [0,0.5,0.5],
                                [0.1,0.2,0.3], [0.2,0.5,0.3], [0.6,0.3,0.2], [0.5,0.3,0.5],
                                [0.6,0.7,0.2], [0.5,0.8,0.3]])
    if part_num + 1 > cmap_matrix.shape[0]:
        raise ValueError(
            f"labels need {part_num + 1} colors but palette only has {cmap_matrix.shape[0]}"
        )
    cmap_matrix = cmap_matrix[:part_num + 1, :]
    # one row per point: 0...010...0 — first column corresponds to unlabeled (0)
    onehot = F.one_hot(labels.long(), num_classes=part_num + 1) * 1.0
    pts_rgb = torch.matmul(onehot, cmap_matrix)
    return pts_rgb
def get_legend(parts):
    """Build a space-separated "color:part" legend string for a list of parts.

    Colors start at index 1 (red), keeping white reserved for unlabeled
    points, matching the palette order in get_seg_color.

    Args:
        parts: iterable of part-name strings (at most 16 — one per color).

    Returns:
        Single string like "red:handle green:blade".
    """
    colors = ["white", "red", "green", "blue", "yellow", "magenta", "cyan", "grey", "olive",
              "purple", "teal", "navy", "darkgreen", "brown", "pinkpurple", "yellowgreen", "limegreen"]
    # enumerate from 1 replaces the original hand-rolled counter
    return " ".join(f"{colors[i]}:{part}" for i, part in enumerate(parts, start=1))
def load_model():
    """Load the pretrained Find3D checkpoint, in eval mode, on DEVICE."""
    net = Find3D.from_pretrained("ziqima/find3d-checkpt0", dim_output=768)
    net.eval()
    return net.to(DEVICE)
def set_seed(seed):
    """Seed torch (CPU and CUDA), numpy and random for reproducibility.

    Args:
        seed: integer seed applied to all four RNG sources.
    """
    torch.manual_seed(seed)
    # Guard on actual CUDA availability rather than the DEVICE string, so
    # this is safe on CPU-only machines regardless of how DEVICE is set.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
def fnv_hash_vec(arr):
    """
    Vectorized 64-bit FNV hash of each row of an integer (n, d) array.

    Returns a uint64 array of length n, one hash per row.
    NOTE(review): the update order here is multiply-then-xor, which is FNV-1;
    the previous docstring said "FNV64-1A" (xor-then-multiply) — confirm which
    variant downstream code expects before changing.
    """
    assert arr.ndim == 2
    # uint64 arithmetic wraps modulo 2**64, which is exactly what FNV needs
    rows = arr.copy().astype(np.uint64, copy=False)
    prime = np.uint64(1099511628211)
    hashes = np.full(rows.shape[0], 14695981039346656037, dtype=np.uint64)
    for column in rows.T:
        hashes *= prime
        hashes = np.bitwise_xor(hashes, column)
    return hashes
def grid_sample_numpy(xyz, rgb, normal, grid_size):  # output should hopefully be ~5000 points or fewer
    """Voxel-grid downsample: keep one randomly chosen point per occupied cell.

    Args:
        xyz, rgb, normal: (n, 3) torch tensors (moved to CPU internally).
        grid_size: cell edge length in xyz units (scalar, or per-axis array).

    Returns:
        Tuple (xyz, rgb, normal, grid_coord) of torch tensors on DEVICE with
        one row per occupied cell; grid_coord holds integer cell indices.
    """
    xyz = xyz.cpu().numpy()
    rgb = rgb.cpu().numpy()
    normal = normal.cpu().numpy()
    scaled_coord = xyz / np.array(grid_size)
    grid_coord = np.floor(scaled_coord).astype(int)
    # shift so cell indices are non-negative (floor handles negative coords)
    grid_coord -= grid_coord.min(0)
    # Hash each cell-index triple to a single key; sorting groups identical
    # keys contiguously so each voxel's points form one run.
    key = fnv_hash_vec(grid_coord)
    idx_sort = np.argsort(key)
    key_sort = key[idx_sort]
    _, count = np.unique(key_sort, return_counts=True)
    # start offset of each run + a random offset within the run
    idx_select = (
        np.cumsum(np.insert(count, 0, 0)[0:-1])
        + np.random.randint(0, count.max(), count.size) % count
    )
    idx_unique = idx_sort[idx_select]
    grid_coord = grid_coord[idx_unique]
    xyz = torch.tensor(xyz[idx_unique]).to(DEVICE)
    rgb = torch.tensor(rgb[idx_unique]).to(DEVICE)
    normal = torch.tensor(normal[idx_unique]).to(DEVICE)
    grid_coord = torch.tensor(grid_coord).to(DEVICE)
    return xyz, rgb, normal, grid_coord
def encode_text(texts):
    """Encode text prompts into L2-normalized SigLIP text embeddings.

    Args:
        texts: string or list of strings to embed.

    Returns:
        (len(texts), 768) tensor of unit-norm text features on DEVICE.
    """
    # Cache the model/tokenizer on the function object so repeated calls do
    # not re-download and re-initialize SigLIP every time.
    if not hasattr(encode_text, "_cache"):
        # dim 768; alternative checkpoint: "google/siglip-so400m-patch14-384"
        siglip = AutoModel.from_pretrained("google/siglip-base-patch16-224").to(DEVICE)
        tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
        encode_text._cache = (siglip, tokenizer)
    siglip, tokenizer = encode_text._cache
    inputs = tokenizer(texts, padding="max_length", return_tensors="pt")
    inputs = {key: value.to(DEVICE) for key, value in inputs.items()}
    with torch.no_grad():
        text_feat = siglip.get_text_features(**inputs)
    # normalize to unit norm; epsilon guards against division by zero
    text_feat = text_feat / (text_feat.norm(dim=-1, keepdim=True) + 1e-12)
    return text_feat
def preprocess_pcd(xyz, rgb, normal):  # rgb should be 0-1
    """Normalize a point cloud and pack it into the model's input dict.

    Pipeline: center + scale into a 0.75-size box, swap axes, center-shift,
    resample to 5000 points, voxel grid-sample at 0.02, shift again (z frozen),
    rescale colors to [-1, 1], and concatenate color + normal as features.

    Args:
        xyz, rgb, normal: (n, 3) torch tensors; rgb values must lie in [0, 1].
            NOTE(review): the in-place -= / *= below mutate the caller's xyz
            tensor before it is rebound — pass a copy if the original is still
            needed. TODO confirm this side effect is intended.

    Returns:
        dict with keys "coord" (grid-sampled xyz), "feat" (rgb+normal, 6 cols),
        "grid_coord" (integer voxel indices), "xyz_full" (all points after the
        same transforms/shifts), and "offset" (grid-sampled point count).
    """
    assert rgb.max() <= 1
    # normalize: center, then scale so the cloud fits a 0.75-size box
    # (same preprocessing as applied before training)
    center = xyz.mean(0)
    scale = max((xyz - center).abs().max(0)[0])
    xyz -= center
    xyz *= (0.75 / float(scale))  # put in 0.75-size box
    # axis swap: (x, y, z) -> (-x, z, y)
    xyz = torch.cat([-xyz[:, 0].reshape(-1, 1), xyz[:, 2].reshape(-1, 1), xyz[:, 1].reshape(-1, 1)], dim=1)
    # center shift; only xyz_max[2] is zeroed here, so the z-shift is half of
    # xyz_min[2] — presumably to keep the object near the ground plane; verify
    xyz_min = xyz.min(dim=0)[0]
    xyz_max = xyz.max(dim=0)[0]
    xyz_max[2] = 0
    shift = (xyz_min + xyz_max) / 2
    xyz -= shift
    # subsample/upsample (with replacement) to exactly 5000 pts for grid sampling
    if xyz.shape[0] != 5000:
        random_indices = torch.randint(0, xyz.shape[0], (5000,))
        pts_xyz_subsampled = xyz[random_indices]
        pts_rgb_subsampled = rgb[random_indices]
        normal_subsampled = normal[random_indices]
    else:
        pts_xyz_subsampled = xyz
        pts_rgb_subsampled = rgb
        normal_subsampled = normal
    # grid sampling at 0.02 cell size: one random point kept per occupied voxel
    pts_xyz_gridsampled, pts_rgb_gridsampled, normal_gridsampled, grid_coord = grid_sample_numpy(pts_xyz_subsampled, pts_rgb_subsampled, normal_subsampled, 0.02)
    # another center shift, z=false (both min and max z zeroed: no z shift)
    xyz_min = pts_xyz_gridsampled.min(dim=0)[0]
    xyz_min[2] = 0
    xyz_max = pts_xyz_gridsampled.max(dim=0)[0]
    xyz_max[2] = 0
    shift = (xyz_min + xyz_max) / 2
    pts_xyz_gridsampled -= shift
    xyz -= shift  # keep the full cloud aligned with the grid-sampled one
    # normalize color from [0, 1] to [-1, 1]
    pts_rgb_gridsampled = pts_rgb_gridsampled / 0.5 - 1
    # combine color and normal as per-point features
    feat = torch.cat([pts_rgb_gridsampled, normal_gridsampled], dim=1)
    data_dict = {}
    data_dict["coord"] = pts_xyz_gridsampled
    data_dict["feat"] = feat
    data_dict["grid_coord"] = grid_coord
    data_dict["xyz_full"] = xyz
    data_dict["offset"] = torch.tensor([pts_xyz_gridsampled.shape[0]])
    return data_dict