"""Inference helpers for text-queried part segmentation of point clouds.

``segment_obj`` and ``get_heatmap`` are the ``@spaces.GPU`` entry points. They
run the model on the preprocessed (subsampled) point cloud, score each
subsampled point against text-query embeddings, and copy those scores to every
full-resolution point via a nearest-neighbour lookup.
"""
import torch
import numpy as np
import matplotlib.pyplot as plt
from inference.utils import get_seg_color, load_model, preprocess_pcd, encode_text
import spaces

# NOTE(review): DEVICE is pinned to CUDA. A `torch.cuda.is_available()` guard
# used to sit here but was commented out -- presumably because on ZeroGPU
# Spaces a GPU is only attached inside @spaces.GPU calls. Confirm before
# reinstating such a check.
DEVICE = "cuda:0"


def pred_3d_upsample(
    pred,  # n_subsampled_pts, feat_dim
    part_text_embeds,  # n_parts, feat_dim
    temperature,
    xyz_sub,  # n_subsampled_pts, 3
    xyz_full,  # n_pts, 3
    N_CHUNKS=1,
):
    """Score subsampled points against text queries and propagate to all points.

    Every full-resolution point receives the scores of its single nearest
    subsampled point (Euclidean distance): a 1-nearest-neighbour copy, not a
    weighted interpolation.

    Args:
        pred: per-point features of the subsampled cloud, (n_sub, feat_dim).
        part_text_embeds: one embedding per query, (n_parts, feat_dim).
        temperature: factor multiplied into the logits before the softmax.
        xyz_sub: coordinates of the subsampled points, (n_sub, 3).
        xyz_full: coordinates of all points; it is squeezed first, so a
            leading singleton batch dimension is accepted.
        N_CHUNKS: number of pieces ``xyz_full`` is split into for the distance
            computation; larger values lower peak memory.

    Returns:
        Tuple ``(all_logits, all_probs, pred_full)``:
        all_logits -- raw logits for each full point, (n_pts, n_parts);
        all_probs -- softmax over [unlabeled, part_1, ..., part_n],
            (n_pts, n_parts + 1);
        pred_full -- argmax of ``all_probs`` on the CPU: 0 is unlabeled,
            i >= 1 is the i-th query.

    Raises:
        ValueError: if ``N_CHUNKS`` is smaller than 1.
    """
    if N_CHUNKS < 1:
        raise ValueError(f"N_CHUNKS must be >= 1, got {N_CHUNKS}")

    xyz_full = xyz_full.squeeze()
    logits = pred @ part_text_embeds.T  # n_sub, n_parts
    # Column 0 is a constant zero logit acting as the "unlabeled" option: with a
    # positive temperature a point is assigned to a part only when that part's
    # scaled logit is greater than 0.
    zero_col = logits.new_zeros((logits.shape[0], 1))
    logits_prepend0 = torch.cat([zero_col, logits], dim=1)
    pred_softmax = torch.softmax(logits_prepend0 * temperature, dim=1)

    # Same chunk boundaries as slicing [chunk_len*i : chunk_len*(i+1)] for
    # i in range(N_CHUNKS); empty trailing chunks are simply skipped.
    chunk_len = xyz_full.shape[0] // N_CHUNKS + 1
    # Compare in the common dtype, as a broadcasted subtraction would.
    dist_dtype = torch.promote_types(xyz_sub.dtype, xyz_full.dtype)
    xyz_sub_d = xyz_sub.to(dist_dtype)
    nn_idx_chunks = []
    for cur_chunk in xyz_full.split(chunk_len):
        if cur_chunk.shape[0] == 0:
            continue
        cur_chunk = cur_chunk.to(device=xyz_sub.device, dtype=dist_dtype)
        # cdist avoids materialising the (chunk, n_sub, 3) difference tensor.
        # The non-matmul mode keeps the direct sqrt(sum((a - b) ** 2)) formula
        # so nearest-neighbour choices are not perturbed by the mm shortcut.
        dists = torch.cdist(
            cur_chunk, xyz_sub_d, compute_mode="donot_use_mm_for_euclid_dist"
        )
        # argmin returns the first index on ties, like torch.min(dim=...).
        nn_idx_chunks.append(dists.argmin(dim=1))
        del dists  # free before the next chunk allocates its distance matrix

    if nn_idx_chunks:
        all_nn_idxs = torch.cat(nn_idx_chunks, dim=0)
    else:  # xyz_full has no points
        all_nn_idxs = torch.empty(0, dtype=torch.long, device=pred_softmax.device)

    all_probs = pred_softmax[all_nn_idxs]
    all_logits = logits[all_nn_idxs]
    # 0 is unlabeled; 1..n_parts correspond to the actual part assignment.
    pred_full = all_probs.argmax(dim=1).cpu()
    return all_logits, all_probs, pred_full


def _predict_full(model, data, N_CHUNKS):
    """Run ``model`` on ``data`` and upsample its query scores to every point.

    Tensors in ``data`` are moved to DEVICE in place, except those whose key
    contains "full" (``xyz_full`` is instead moved chunk by chunk during the
    upsampling). ``data`` must provide "label_embeds", "coord" and "xyz_full".
    As in the original loader, only a batch size of 1 is expected.

    Returns the ``(all_logits, all_probs, pred_full)`` tuple of
    ``pred_3d_upsample``.
    """
    temperature = np.exp(model.ln_logit_scale.item())
    # The upsampling also runs without autograd so no graph is kept alive for
    # the large per-point tensors.
    with torch.no_grad():
        for key, value in data.items():
            if isinstance(value, torch.Tensor) and "full" not in key:
                data[key] = value.to(DEVICE)
        net_out = model(x=data)
        return pred_3d_upsample(
            net_out,  # n_subsampled_pts, feat_dim
            data["label_embeds"],  # n_parts, feat_dim
            temperature,
            data["coord"],
            data["xyz_full"],  # n_pts, 3
            N_CHUNKS=N_CHUNKS,
        )


def get_segmentation_rgb(model, data, N_CHUNKS=5):
    """Return per-point segmentation colours for every point in ``data``.

    Point labels are 0 for unlabeled and i for the i-th query in
    ``data["label_embeds"]``; they are turned into colours by ``get_seg_color``.
    """
    _, _, pred_full = _predict_full(model, data, N_CHUNKS)
    return get_seg_color(pred_full)  # pred_full is already on the CPU


def get_heatmap_rgb(model, data, N_CHUNKS=5):
    """Return per-point jet-colormap RGB values of the logits for a single query.

    Expects ``data["label_embeds"]`` to hold one query embedding.
    """
    all_logits, _, _ = _predict_full(model, data, N_CHUNKS)
    # Drop only the query axis so a single-point cloud stays 1-D.
    scores = all_logits.squeeze(-1).cpu()
    # NOTE(review): logits go to the colormap unnormalised; matplotlib maps
    # floats outside [0, 1] to the colormap's under/over colours -- confirm
    # the logit range is what is intended here.
    heatmap_rgb = torch.tensor(plt.cm.jet(scores.numpy())[:, :3]).squeeze()
    return heatmap_rgb


def _prepare_inputs(xyz, rgb, normal, queries):
    """Build the model input dict from raw point arrays and a list of text queries."""
    data_dict = preprocess_pcd(
        torch.tensor(xyz).float().to(DEVICE),
        torch.tensor(rgb).float().to(DEVICE),
        torch.tensor(normal).float().to(DEVICE),
    )
    data_dict["label_embeds"] = encode_text(queries)
    return data_dict


@spaces.GPU
def segment_obj(xyz, rgb, normal, queries):
    """Segment a point cloud into the parts named by ``queries``; return per-point colours."""
    # NOTE(review): the model is (re)loaded on every call; cache it here if
    # load_model does not already do so.
    model = load_model()
    data_dict = _prepare_inputs(xyz, rgb, normal, queries)
    return get_segmentation_rgb(model, data_dict)


@spaces.GPU
def get_heatmap(xyz, rgb, normal, query):
    """Return a per-point RGB heatmap of how strongly each point matches ``query``."""
    model = load_model()
    data_dict = _prepare_inputs(xyz, rgb, normal, [query])
    return get_heatmap_rgb(model, data_dict)