# Hugging Face Space status: Running on Zero
import torch | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from inference.utils import get_seg_color, load_model, preprocess_pcd, encode_text | |
import spaces | |
# Device that model inputs and the nearest-neighbour search are moved to.
# NOTE(review): deliberately hard-coded to CUDA; a `torch.cuda.is_available()`
# guard used to be commented out here. Presumably this is tied to how GPUs are
# provisioned on the Hugging Face ZeroGPU Space this module runs on
# (`import spaces`) -- confirm before reinstating a CPU fallback.
DEVICE = "cuda:0"
def pred_3d_upsample(
    pred,  # n_subsampled_pts, feat_dim
    part_text_embeds,  # n_parts, feat_dim
    temperature,
    xyz_sub,
    xyz_full,  # n_pts, 3
    N_CHUNKS=1,
):
    """Propagate part predictions from a subsampled cloud to the full cloud.

    Each point of ``xyz_full`` inherits the prediction of its nearest
    (Euclidean) neighbour in ``xyz_sub``. The nearest-neighbour search runs
    over at most ``N_CHUNKS`` slices of ``xyz_full``. This bounds the size of
    the pairwise-distance matrix held in memory at once.

    Args:
        pred: (n_sub, feat_dim) features of the subsampled points.
        part_text_embeds: (n_parts, feat_dim) text embeddings, one per query.
        temperature: multiplier applied to the logits before the softmax.
        xyz_sub: (n_sub, 3) coordinates matching the rows of ``pred``. Its
            device is used for the nearest-neighbour search.
        xyz_full: full-resolution coordinates. Leading singleton dims (e.g. a
            batch of 1) are flattened away, giving (n_pts, 3).
        N_CHUNKS: maximum number of slices for the NN search; must be >= 1.

    Returns:
        all_logits: (n_pts, n_parts) raw logits of each point's nearest
            subsampled point.
        all_probs: (n_pts, n_parts + 1) softmax probabilities. Column 0 is an
            "unlabeled" class whose logit is fixed at 0.
        pred_full: (n_pts,) CPU tensor of argmax labels. 0 means unlabeled;
            k >= 1 means part k-1 of ``part_text_embeds``.

    Raises:
        ValueError: if ``N_CHUNKS`` is smaller than 1.
    """
    if N_CHUNKS < 1:
        raise ValueError(f"N_CHUNKS must be >= 1, got {N_CHUNKS}")
    device = xyz_sub.device
    # reshape rather than squeeze(): a one-point cloud must stay (1, 3), not (3,).
    xyz_full = xyz_full.reshape(-1, xyz_full.shape[-1])

    logits = pred @ part_text_embeds.T  # n_sub, n_parts
    # Prepend a constant-zero "unlabeled" logit. new_zeros matches logits' device/dtype.
    logits_prepend0 = torch.cat(
        [logits.new_zeros((logits.shape[0], 1)), logits], dim=1
    )
    pred_softmax = torch.softmax(logits_prepend0 * temperature, dim=1)

    n_full = xyz_full.shape[0]
    chunk_len = n_full // N_CHUNKS + 1
    nn_chunks = []
    for start in range(0, n_full, chunk_len):  # at most N_CHUNKS non-empty slices
        cur_chunk = xyz_full[start:start + chunk_len].to(device)
        # Squared distances give the same argmin (sqrt is monotonic). The
        # (chunk, n_sub, 3) difference tensor is a temporary freed once summed.
        sq_dist = ((cur_chunk.unsqueeze(1) - xyz_sub.unsqueeze(0)) ** 2).sum(dim=-1)
        nn_chunks.append(sq_dist.argmin(dim=1))
        del sq_dist
    if nn_chunks:
        all_nn_idxs = torch.cat(nn_chunks, dim=0)
    else:  # empty full-resolution cloud
        all_nn_idxs = torch.zeros(0, dtype=torch.long, device=device)

    all_probs = pred_softmax[all_nn_idxs]
    all_logits = logits[all_nn_idxs]
    # 0 is unlabeled; 1..n_parts correspond to the actual part assignment.
    pred_full = all_probs.argmax(dim=1).cpu()
    return all_logits, all_probs, pred_full
def _forward_and_upsample(model, data, N_CHUNKS):
    """Run ``model`` on ``data`` and upsample its output to the full cloud.

    Shared by ``get_segmentation_rgb`` and ``get_heatmap_rgb``.

    Side effect: each tensor value in ``data`` whose key does not contain
    "full" is moved to ``DEVICE`` in place, so the caller's dict is modified.
    Full-resolution entries (e.g. "xyz_full") are left where they are and are
    moved chunk by chunk inside ``pred_3d_upsample``.

    Returns:
        The ``(all_logits, all_probs, pred_full)`` triple from ``pred_3d_upsample``.
    """
    temperature = np.exp(model.ln_logit_scale.item())
    with torch.no_grad():
        for key, value in data.items():
            if isinstance(value, torch.Tensor) and "full" not in key:
                data[key] = value.to(DEVICE)
        net_out = model(x=data)
        return pred_3d_upsample(
            net_out,  # n_subsampled_pts, feat_dim
            data["label_embeds"],  # n_parts, feat_dim
            temperature,
            data["coord"],
            data["xyz_full"],  # n_pts, 3
            N_CHUNKS=N_CHUNKS,
        )


def get_segmentation_rgb(model, data, N_CHUNKS=5):  # evaluate loader can only have batch size=1
    """Colour the full-resolution cloud by predicted part label.

    Args:
        model: network exposing ``ln_logit_scale`` and called as ``model(x=data)``.
        data: input dict with at least "label_embeds", "coord" and "xyz_full".
            Non-"full" tensors are moved to ``DEVICE`` in place.
        N_CHUNKS: number of slices for the nearest-neighbour upsampling.

    Returns:
        ``get_seg_color`` applied to the per-point labels (0 = unlabeled,
        k >= 1 = k-th query).
    """
    _, _, pred_full = _forward_and_upsample(model, data, N_CHUNKS)
    # pred_full is already on the CPU (pred_3d_upsample calls .cpu() on it).
    return get_seg_color(pred_full)


def get_heatmap_rgb(model, data, N_CHUNKS=5):  # evaluate loader can only have batch size=1
    """Colour the full-resolution cloud by the raw logit of the first query.

    Args:
        model, data, N_CHUNKS: as for ``get_segmentation_rgb``. Only the first
            query column is used; ``get_heatmap`` passes a single query.

    Returns:
        (n_pts, 3) tensor of RGB values from the ``jet`` colormap.
    """
    all_logits, _, _ = _forward_and_upsample(model, data, N_CHUNKS)
    # Index the query column instead of squeeze(): a one-point cloud must stay
    # 1-D, otherwise the colormap returns a single RGBA tuple and [:, :3] fails.
    scores = all_logits[:, 0].cpu().numpy()
    # NOTE(review): logits are fed to the colormap unnormalised; per matplotlib,
    # values outside [0, 1] map to the colormap's under/over colours.
    rgba = plt.cm.jet(scores)
    return torch.tensor(rgba[:, :3])
def segment_obj(xyz, rgb, normal, queries):
    """Segment a point cloud against a list of text queries.

    Args:
        xyz, rgb, normal: per-point arrays. Each is converted with
            ``torch.tensor(...).float()`` and moved to ``DEVICE``.
        queries: list of text queries, embedded with ``encode_text``.

    Returns:
        The result of ``get_segmentation_rgb`` for the preprocessed cloud.
    """
    model = load_model()
    # Convert in the same order as the preprocess_pcd arguments: xyz, rgb, normal.
    xyz_t, rgb_t, normal_t = (
        torch.tensor(arr).float().to(DEVICE) for arr in (xyz, rgb, normal)
    )
    data_dict = preprocess_pcd(xyz_t, rgb_t, normal_t)
    data_dict["label_embeds"] = encode_text(queries)
    return get_segmentation_rgb(model, data_dict)
def get_heatmap(xyz, rgb, normal, query):
    """Compute a per-point heatmap for a single text query.

    Args:
        xyz, rgb, normal: per-point arrays. Each is converted with
            ``torch.tensor(...).float()`` and moved to ``DEVICE``.
        query: one text query; it is wrapped in a list for ``encode_text``.

    Returns:
        The result of ``get_heatmap_rgb`` for the preprocessed cloud.
    """
    model = load_model()
    point_tensors = [torch.tensor(arr).float().to(DEVICE) for arr in (xyz, rgb, normal)]
    data_dict = preprocess_pcd(*point_tensors)
    data_dict["label_embeds"] = encode_text([query])
    return get_heatmap_rgb(model, data_dict)