|
import functools
import random

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

from model.model import PointSemSeg, Find3D
|
|
|
# Device used for models and returned tensors; fall back to CPU so the helpers
# below still run on machines without CUDA.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
def get_seg_color(labels):
    """Map integer part labels to RGB colors in [0, 1].

    Label 0 is white. Labels 1..16 get the distinct colors below. Larger labels
    wrap around the non-white colors (17 -> red, ...), so any number of parts
    can be rendered.

    Args:
        labels: tensor of non-negative labels (any shape; cast to long).

    Returns:
        float32 tensor of shape ``labels.shape + (3,)`` on ``labels.device``.

    Raises:
        ValueError: if any label is negative.
    """
    # Built on the labels' device so GPU label tensors work without a transfer.
    palette = torch.tensor(
        [
            [1, 1, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1],
            [0, 1, 1], [0.5, 0.5, 0.5], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5],
            [0.1, 0.2, 0.3], [0.2, 0.5, 0.3], [0.6, 0.3, 0.2], [0.5, 0.3, 0.5],
            [0.6, 0.7, 0.2], [0.5, 0.8, 0.3],
        ],
        device=labels.device,
    )
    labels = labels.long()
    if labels.numel() > 0 and labels.min() < 0:
        raise ValueError("labels must be non-negative")

    num_colors = palette.shape[0]
    # Keep 0 -> white; wrap positive labels over rows 1..num_colors-1.
    color_idx = torch.where(labels > 0, (labels - 1) % (num_colors - 1) + 1, labels)
    return palette[color_idx]
|
|
|
def get_legend(parts):
    """Build a space-separated "color:part" legend string.

    The i-th part (1-based) is paired with color name ``colors[i]``. Index 0
    ("white") is never used. Beyond 16 parts the names wrap around the
    non-white colors instead of raising ``IndexError``.

    Args:
        parts: iterable of part names (formatted with ``str``).

    Returns:
        e.g. ``"red:leg green:seat"``; an empty string for no parts.
    """
    colors = ["white", "red", "green", "blue", "yellow", "magenta", "cyan", "grey", "olive",
              "purple", "teal", "navy", "darkgreen", "brown", "pinkpurple", "yellowgreen", "limegreen"]
    num_non_white = len(colors) - 1
    return " ".join(
        f"{colors[1 + (i - 1) % num_non_white]}:{part}"
        for i, part in enumerate(parts, start=1)
    )
|
|
|
|
|
def load_model():
    """Load the pretrained Find3D checkpoint, switch it to eval mode, and move it to ``DEVICE``."""
    find3d = Find3D.from_pretrained("ziqima/find3d-checkpt0", dim_output=768)
    find3d.eval()
    return find3d.to(DEVICE)
|
|
|
def set_seed(seed):
    """Seed Python's, NumPy's and torch's RNGs.

    The CUDA RNGs are also seeded unless ``DEVICE`` is ``"cpu"``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if DEVICE == "cpu":
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
|
|
|
def fnv_hash_vec(arr):
    """Hash each row of a 2-D integer array to one uint64 value.

    Uses the 64-bit FNV offset basis and prime and folds one column at a time.
    Each step multiplies and then xors, which is the FNV-1 ordering, despite the
    historical "FNV64-1A" label. All arithmetic wraps modulo 2**64.

    Args:
        arr: 2-D NumPy array; values are cast to ``np.uint64`` on a copy,
            so the caller's array is not modified.

    Returns:
        1-D ``np.uint64`` array with one hash per row.
    """
    assert arr.ndim == 2

    columns = arr.astype(np.uint64)  # astype copies by default
    offset_basis = np.uint64(14695981039346656037)
    prime = np.uint64(1099511628211)

    hashes = np.full(columns.shape[0], offset_basis, dtype=np.uint64)
    for column in columns.T:
        hashes *= prime
        hashes ^= column
    return hashes
|
|
|
|
|
def grid_sample_numpy(xyz, rgb, normal, grid_size):
    """Keep one randomly chosen point per occupied voxel of edge ``grid_size``.

    Args:
        xyz, rgb, normal: per-point torch tensors indexed in parallel along
            dim 0. They are copied to CPU/NumPy for hashing.
        grid_size: voxel edge length (a scalar, or per-axis values that
            broadcast against ``xyz``).

    Returns:
        ``(xyz, rgb, normal, grid_coord)`` tensors on ``DEVICE`` with one row
        per occupied voxel. ``grid_coord`` holds integer voxel indices, shifted
        so each axis starts at 0.
    """
    xyz = xyz.cpu().numpy()
    rgb = rgb.cpu().numpy()
    normal = normal.cpu().numpy()

    grid_coord = np.floor(xyz / np.asarray(grid_size)).astype(int)
    grid_coord -= grid_coord.min(0)

    # Sort points by voxel hash so that points in the same voxel are contiguous.
    key = fnv_hash_vec(grid_coord)
    idx_sort = np.argsort(key)
    _, count = np.unique(key[idx_sort], return_counts=True)

    # Start of each voxel's run in sorted order plus a random offset inside it.
    # NOTE(review): randint(0, count.max()) % count is slightly non-uniform
    # when count does not divide count.max(). It is kept unchanged so that
    # seeded runs reproduce earlier results.
    idx_select = (
        np.cumsum(np.insert(count, 0, 0)[0:-1])
        + np.random.randint(0, count.max(), count.size) % count
    )
    idx_unique = idx_sort[idx_select]

    return (
        torch.tensor(xyz[idx_unique]).to(DEVICE),
        torch.tensor(rgb[idx_unique]).to(DEVICE),
        torch.tensor(normal[idx_unique]).to(DEVICE),
        torch.tensor(grid_coord[idx_unique]).to(DEVICE),
    )
|
|
|
|
|
@functools.lru_cache(maxsize=None)  # keyed only by device string, so few entries
def _load_siglip(device):
    """Load the SigLIP model (moved to ``device``) and its tokenizer once per device."""
    model = AutoModel.from_pretrained("google/siglip-base-patch16-224").to(device)
    tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
    return model, tokenizer


def encode_text(texts):
    """Embed text queries with SigLIP and L2-normalize the result.

    The model and tokenizer are loaded on the first call and reused afterwards.

    Args:
        texts: a string or list of strings accepted by the SigLIP tokenizer.

    Returns:
        Tensor of text features on ``DEVICE``, each row divided by its L2 norm
        (plus 1e-12 to avoid division by zero).
    """
    siglip, tokenizer = _load_siglip(DEVICE)
    inputs = tokenizer(texts, padding="max_length", return_tensors="pt")
    inputs = {key: value.to(DEVICE) for key, value in inputs.items()}
    with torch.no_grad():
        text_feat = siglip.get_text_features(**inputs)
    return text_feat / (text_feat.norm(dim=-1, keepdim=True) + 1e-12)
|
|
|
|
|
def preprocess_pcd(xyz, rgb, normal, num_points=5000, grid_size=0.02):
    """Normalize, subsample and voxelize a point cloud into a model input dict.

    The caller's ``xyz`` tensor is not modified.

    Args:
        xyz: point coordinates indexed as ``xyz[:, 0..2]`` (presumably an
            (N, 3) float tensor -- TODO confirm).
        rgb: per-point colors, indexed in parallel with ``xyz``; must satisfy
            ``rgb.max() <= 1``.
        normal: per-point normals, indexed in parallel with ``xyz``.
        num_points: number of points drawn (with replacement) before grid
            sampling, unless the input already has exactly this many.
        grid_size: voxel edge length passed to ``grid_sample_numpy``.

    Returns:
        dict with:
          "coord": grid-sampled coordinates, re-centered in x/y;
          "feat": colors rescaled to [-1, 1] concatenated with normals (dim 1);
          "grid_coord": integer voxel indices of the sampled points;
          "xyz_full": all N normalized input points with the same final shift;
          "offset": one-element tensor holding the number of sampled points.
    """
    assert rgb.max() <= 1

    # Center on the mean and scale so the largest absolute coordinate is 0.75.
    # Out-of-place on purpose: the caller's tensor must not be mutated.
    center = xyz.mean(0)
    scale = (xyz - center).abs().max()
    xyz = (xyz - center) * (0.75 / float(scale))

    # Axis remap (x, y, z) -> (-x, z, y).
    xyz = torch.stack([-xyz[:, 0], xyz[:, 2], xyz[:, 1]], dim=1)

    # Bounding-box centering. Only xyz_max[2] is zeroed here, so along axis 2
    # the shift is half of the minimum (kept as in the original pipeline).
    xyz_min = xyz.min(dim=0)[0]
    xyz_max = xyz.max(dim=0)[0]
    xyz_max[2] = 0
    xyz = xyz - (xyz_min + xyz_max) / 2

    if xyz.shape[0] != num_points:
        random_indices = torch.randint(0, xyz.shape[0], (num_points,))
        xyz_sub = xyz[random_indices]
        rgb_sub = rgb[random_indices]
        normal_sub = normal[random_indices]
    else:
        xyz_sub, rgb_sub, normal_sub = xyz, rgb, normal

    coord, rgb_grid, normal_grid, grid_coord = grid_sample_numpy(
        xyz_sub, rgb_sub, normal_sub, grid_size
    )

    # Re-center the sampled cloud in x/y only (axis 2 shift forced to 0).
    grid_min = coord.min(dim=0)[0]
    grid_max = coord.max(dim=0)[0]
    grid_min[2] = 0
    grid_max[2] = 0
    shift = (grid_min + grid_max) / 2
    coord = coord - shift
    # `shift` lives on DEVICE; move it so CPU inputs do not hit a device mismatch.
    xyz = xyz - shift.to(xyz.device)

    # Map colors from [0, 1] to [-1, 1].
    rgb_grid = rgb_grid / 0.5 - 1
    feat = torch.cat([rgb_grid, normal_grid], dim=1)

    return {
        "coord": coord,
        "feat": feat,
        "grid_coord": grid_coord,
        "xyz_full": xyz,
        "offset": torch.tensor([coord.shape[0]]),
    }