# Find3D / inference / utils.py
# (source listing header: author ziqima, commit 164964b "set seed", 5.47 kB)
import functools
import random

import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

from model.model import Find3D, PointSemSeg
# Device used by load_model, grid_sample_numpy and encode_text for models and tensors.
# NOTE(review): hard-coded to the first CUDA GPU. The availability check below is
# commented out, so on a CPU-only machine the first `.to(DEVICE)` call will fail.
# Presumably intentional for the deployment target; confirm before re-enabling a
# CPU fallback such as `"cuda:0" if torch.cuda.is_available() else "cpu"`.
DEVICE = "cuda:0"
#if torch.cuda.is_available():
#DEVICE = "cuda:0"
def get_seg_color(labels):
    """Map integer part labels to per-point RGB colors.

    Label 0 (unlabeled) is white. Labels 1..16 use a fixed 16-color palette,
    in the same order as the color names listed in ``get_legend``. Labels
    above 16 wrap around onto colors 1..16 instead of failing with a shape
    mismatch.

    Args:
        labels: tensor of non-negative integer(-valued) labels, any shape.

    Returns:
        float32 tensor of shape ``labels.shape + (3,)`` with RGB values in
        [0, 1], on the same device as ``labels``.

    Raises:
        ValueError: if any label is negative.
    """
    # Built on labels' device so CUDA label tensors do not hit a device mismatch.
    cmap_matrix = torch.tensor(
        [[1, 1, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1],
         [0, 1, 1], [0.5, 0.5, 0.5], [0.5, 0.5, 0], [0.5, 0, 0.5], [0, 0.5, 0.5],
         [0.1, 0.2, 0.3], [0.2, 0.5, 0.3], [0.6, 0.3, 0.2], [0.5, 0.3, 0.5],
         [0.6, 0.7, 0.2], [0.5, 0.8, 0.3]],
        device=labels.device,
    )
    labels = labels.long()
    if (labels < 0).any():
        raise ValueError("get_seg_color: labels must be non-negative")
    n_part_colors = cmap_matrix.shape[0] - 1  # row 0 is reserved for unlabeled points
    # Keep 0 as unlabeled; wrap part labels 1, 2, ... onto palette rows 1..16.
    color_idx = torch.where(labels > 0, (labels - 1) % n_part_colors + 1, labels)
    # Row lookup gives the same values as one_hot(labels) @ palette, without
    # coupling the one-hot width to the palette size.
    return cmap_matrix[color_idx]
def get_legend(parts):
    """Build a space-separated "color:part" legend string.

    The k-th part (0-based) is named with palette color k + 1, because color 0
    (white) is reserved for unlabeled points. With more than 16 parts the color
    names wrap around instead of raising IndexError.

    Args:
        parts: iterable of part names (formatted with ``str``).

    Returns:
        e.g. ``"red:seat green:leg"``; an empty string for no parts.
    """
    colors = ["white", "red", "green", "blue", "yellow", "magenta", "cyan", "grey", "olive",
              "purple", "teal", "navy", "darkgreen", "brown", "pinkpurple", "yellowgreen", "limegreen"]
    n_part_colors = len(colors) - 1  # index 0 ("white") is never used for a part
    return " ".join(
        f"{colors[1 + k % n_part_colors]}:{part}" for k, part in enumerate(parts)
    )
def load_model():
    """Fetch the pretrained Find3D checkpoint and ready it for inference.

    Returns:
        The model in eval mode, placed on ``DEVICE``.
    """
    # dim_output=768 lines up with the 768-d SigLIP text features from encode_text.
    find3d = Find3D.from_pretrained("ziqima/find3d-checkpt0", dim_output=768)
    find3d.eval()
    return find3d.to(DEVICE)
def fnv_hash_vec(arr):
    """Hash each row of a 2-D integer array to a single uint64 key.

    Uses the 64-bit FNV offset basis and prime. For every column the running
    hash is multiplied by the prime *before* XOR-ing in the column value. That
    is FNV-1 ordering, not FNV-1a as this function was previously labelled.
    The ordering is kept on purpose: ``grid_sample_numpy`` sorts points by
    these keys, so a different hash would change which points it keeps under
    a fixed seed.

    Input is cast with ``astype(np.uint64)``, so pass integer coordinates;
    negative integers wrap modulo 2**64. Arithmetic wraps modulo 2**64.

    Args:
        arr: array-like of shape (N, D).

    Returns:
        np.ndarray of shape (N,) and dtype uint64.

    Raises:
        ValueError: if ``arr`` is not 2-D.
    """
    arr = np.asarray(arr)
    if arr.ndim != 2:
        raise ValueError(f"fnv_hash_vec expects a 2-D array, got ndim={arr.ndim}")
    # astype returns a new array unless arr is already uint64; arr is never mutated.
    cols = arr.astype(np.uint64, copy=False)
    hashed = np.full(cols.shape[0], np.uint64(14695981039346656037), dtype=np.uint64)
    prime = np.uint64(1099511628211)
    for j in range(cols.shape[1]):
        hashed *= prime
        hashed ^= cols[:, j]
    return hashed
def grid_sample_numpy(xyz, rgb, normal, grid_size):
    """Voxel-downsample a point cloud, keeping one random point per occupied voxel.

    Points are bucketed into voxels of edge ``grid_size``. For each occupied
    voxel one member point is chosen with ``np.random``, so the result depends
    on NumPy's global seed.

    Args:
        xyz: (N, 3) coordinate tensor (any device).
        rgb: per-point color tensor, row-aligned with ``xyz``.
        normal: per-point normal tensor, row-aligned with ``xyz``.
        grid_size: voxel edge length (scalar or per-axis sequence).

    Returns:
        (xyz, rgb, normal, grid_coord) for the kept points, all on ``DEVICE``.
        ``grid_coord`` holds integer voxel indices shifted so each axis starts at 0.
    """
    xyz_np = xyz.cpu().numpy()
    rgb_np = rgb.cpu().numpy()
    normal_np = normal.cpu().numpy()

    # Integer voxel index of every point, shifted so each axis starts at 0.
    grid_coord = np.floor(xyz_np / np.array(grid_size)).astype(int)
    grid_coord -= grid_coord.min(0)

    # Sort by voxel hash so all points of one voxel form a contiguous run.
    key = fnv_hash_vec(grid_coord)
    idx_sort = np.argsort(key)
    _, count = np.unique(key[idx_sort], return_counts=True)

    # One point per run: the run's start offset plus a random position inside it.
    # The randint(0, count.max()) % count draw is kept unchanged so seeded runs
    # reproduce the same selection.
    run_start = np.cumsum(np.insert(count, 0, 0)[0:-1])
    idx_select = run_start + np.random.randint(0, count.max(), count.size) % count
    idx_unique = idx_sort[idx_select]

    return (
        torch.as_tensor(xyz_np[idx_unique], device=DEVICE),
        torch.as_tensor(rgb_np[idx_unique], device=DEVICE),
        torch.as_tensor(normal_np[idx_unique], device=DEVICE),
        torch.as_tensor(grid_coord[idx_unique], device=DEVICE),
    )
@functools.lru_cache(maxsize=1)
def _load_siglip():
    """Load the SigLIP text encoder and tokenizer once and reuse them across calls."""
    # dim 768 (larger alternative: "google/siglip-so400m-patch14-384")
    siglip = AutoModel.from_pretrained("google/siglip-base-patch16-224").to(DEVICE)
    tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224")
    return siglip, tokenizer


def encode_text(texts):
    """Embed text queries with SigLIP and L2-normalize the embeddings.

    The model and tokenizer are loaded on the first call and cached, instead
    of being reloaded from disk on every call. The model therefore stays
    resident on ``DEVICE`` afterwards.

    Args:
        texts: a string or list of strings to encode.

    Returns:
        Tensor of unit-norm text features, one row per input text, on ``DEVICE``.
    """
    siglip, tokenizer = _load_siglip()
    # SigLIP is used with fixed-length ("max_length") padding.
    inputs = tokenizer(texts, padding="max_length", return_tensors="pt")
    inputs = {name: tensor.to(DEVICE) for name, tensor in inputs.items()}
    with torch.no_grad():
        text_feat = siglip.get_text_features(**inputs)
    # The epsilon guards against division by zero for a zero-norm embedding.
    return text_feat / (text_feat.norm(dim=-1, keepdim=True) + 1e-12)
def preprocess_pcd(xyz, rgb, normal, num_points=5000, grid_size=0.02):
    """Normalize, resample and grid-sample a point cloud for Find3D inference.

    This is the same preprocessing as used before training:

    1. Center on the mean and scale so the largest absolute coordinate is 0.75.
    2. Swap axes: (x, y, z) -> (-x, z, y).
    3. Shift by the bounding-box center (see the z note below).
    4. Resample to ``num_points`` points (random, with replacement) unless the
       cloud already has exactly that many.
    5. Grid-sample with voxel size ``grid_size``.
    6. Re-center x/y of the sampled points and apply the same shift to the
       full cloud.
    7. Map rgb from [0, 1] to [-1, 1] and concatenate it with the normals.

    The caller's ``xyz`` tensor is not modified.

    Args:
        xyz: (N, 3) float tensor of coordinates.
        rgb: per-point colors in [0, 1], row-aligned with ``xyz``.
        normal: per-point normals, row-aligned with ``xyz``.
        num_points: number of points to resample to before grid sampling.
        grid_size: voxel edge length used for grid sampling.

    Returns:
        dict with these keys:
        - "coord": grid-sampled coordinates.
        - "feat": [rgb in [-1, 1], normal], concatenated per point.
        - "grid_coord": integer voxel indices.
        - "xyz_full": the full cloud after the same normalization, axis swap
          and shifts, on its original device.
        - "offset": a 1-element tensor holding the sampled point count.

    Raises:
        ValueError: if the cloud is empty, rgb exceeds 1, or all points coincide.
    """
    if xyz.shape[0] == 0:
        raise ValueError("preprocess_pcd: empty point cloud")
    # Written as `not (... <= 1)` so NaN colors are rejected too.
    if not bool(rgb.max() <= 1):
        raise ValueError("preprocess_pcd: rgb must be in [0, 1]")

    # Normalize into a 0.75-sized box (out-of-place, so the caller's tensor is untouched).
    center = xyz.mean(0)
    scale = float((xyz - center).abs().max())
    if scale == 0:
        raise ValueError("preprocess_pcd: all points coincide; cannot normalize")
    xyz = (xyz - center) * (0.75 / scale)

    # Axis swap: (x, y, z) -> (-x, z, y).
    xyz = torch.stack([-xyz[:, 0], xyz[:, 2], xyz[:, 1]], dim=1)

    # Bounding-box center shift.
    # NOTE(review): only the max z is zeroed, so z is shifted by min_z / 2, unlike
    # the x/y-only shift further down. Kept as-is because this must match the
    # training-time preprocessing.
    xyz_min = xyz.min(dim=0)[0]
    xyz_max = xyz.max(dim=0)[0]
    xyz_max[2] = 0
    xyz = xyz - (xyz_min + xyz_max) / 2

    # Resample to num_points (with replacement) before grid sampling.
    n_pts = xyz.shape[0]
    if n_pts != num_points:
        pick = torch.randint(0, n_pts, (num_points,))
        sub_xyz, sub_rgb, sub_normal = xyz[pick], rgb[pick], normal[pick]
    else:
        sub_xyz, sub_rgb, sub_normal = xyz, rgb, normal

    coord, grid_rgb, grid_normal, grid_coord = grid_sample_numpy(
        sub_xyz, sub_rgb, sub_normal, grid_size
    )

    # Second center shift on the sampled points, x/y only (z is not shifted).
    xyz_min = coord.min(dim=0)[0]
    xyz_min[2] = 0
    xyz_max = coord.max(dim=0)[0]
    xyz_max[2] = 0
    shift = (xyz_min + xyz_max) / 2
    coord -= shift
    # `coord` lives on DEVICE; move the shift so a CPU input cloud does not crash.
    xyz = xyz - shift.to(xyz.device)

    # Colors [0, 1] -> [-1, 1]; features are [rgb, normal] per point.
    feat = torch.cat([grid_rgb / 0.5 - 1, grid_normal], dim=1)

    return {
        "coord": coord,
        "feat": feat,
        "grid_coord": grid_coord,
        "xyz_full": xyz,
        "offset": torch.tensor([coord.shape[0]]),
    }