|
import torch |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from inference.utils import get_seg_color, load_model, preprocess_pcd, encode_text |
|
import spaces |
|
|
|
# Torch device string used by the `.to(DEVICE)` transfers in this module.
DEVICE = "cuda:0"
|
|
|
|
|
|
|
def pred_3d_upsample(
    pred,
    part_text_embeds,
    temperature,
    xyz_sub,
    xyz_full,
    N_CHUNKS=1
):
    """Propagate predictions from a subsampled cloud to the full cloud.

    Logits are computed as ``pred @ part_text_embeds.T`` for the subsampled
    points. Each full-resolution point then copies the logits and class
    probabilities of its nearest subsampled point (Euclidean distance).

    Args:
        pred: Per-point model output for the subsampled cloud; must be
            matrix-multipliable with ``part_text_embeds.T`` (presumably
            ``(M, D)`` -- confirm against the model).
        part_text_embeds: Text embeddings, one row per query (presumably
            ``(K, D)`` -- confirm against the text encoder).
        temperature: Scalar the logits are multiplied by before the softmax.
        xyz_sub: ``(M, C)`` coordinates of the subsampled points. The
            nearest-neighbour search runs on this tensor's device.
        xyz_full: Full-resolution coordinates, ``(N, C)`` after singleton
            dimensions are squeezed away. It may live on another device;
            chunks are transferred on demand.
        N_CHUNKS: Number of slices the full cloud is split into so the
            pairwise distance tensor stays small. Must be >= 1.

    Returns:
        Tuple ``(all_logits, all_probs, pred_full)``:
        ``all_logits`` holds the raw logits per full-resolution point,
        ``all_probs`` the softmax over the logits with a zero logit prepended
        as column 0, and ``pred_full`` the argmax of ``all_probs`` as a CPU
        tensor (so index 0 means the prepended zero logit won).

    Raises:
        ValueError: If ``N_CHUNKS`` is smaller than 1.
    """
    if N_CHUNKS < 1:
        raise ValueError(f"N_CHUNKS must be >= 1, got {N_CHUNKS}")

    xyz_full = xyz_full.squeeze()
    if xyz_full.dim() == 1:
        # squeeze() also drops the point axis of a single-point cloud;
        # restore it so the pairwise broadcast below stays (n, M, C).
        xyz_full = xyz_full.unsqueeze(0)

    logits = pred @ part_text_embeds.T

    # Prepend a constant zero logit as class 0. The column is allocated
    # directly on the logits' device instead of a hard-coded one.
    zero_col = torch.zeros(logits.shape[0], 1, device=logits.device)
    logits_prepend0 = torch.cat([zero_col, logits], dim=1)
    pred_softmax = torch.softmax(logits_prepend0 * temperature, dim=1)

    n_full = xyz_full.shape[0]
    chunk_len = n_full // N_CHUNKS + 1
    search_device = xyz_sub.device
    sub_points = xyz_sub.unsqueeze(0)  # loop-invariant, hoisted

    closest_idx_list = []
    # Stepping by chunk_len visits only non-empty slices; chunk_len * N_CHUNKS
    # always exceeds n_full, so at most N_CHUNKS slices are produced.
    for start in range(0, n_full, chunk_len):
        cur_chunk = xyz_full[start:start + chunk_len].to(search_device)
        # Squared distances suffice: sqrt is monotonic, so the argmin is the
        # same and one full-size temporary is avoided.
        sq_dist = ((sub_points - cur_chunk.unsqueeze(1)) ** 2).sum(dim=-1)
        closest_idx_list.append(torch.min(sq_dist, dim=1).indices)
        del sq_dist  # free the (chunk, M) buffer before the next chunk

    if closest_idx_list:
        all_nn_idxs = torch.cat(closest_idx_list, dim=0)
    else:
        # Empty full cloud: keep returning empty results instead of failing.
        all_nn_idxs = torch.zeros(0, dtype=torch.long, device=search_device)
    # Indices must share a device with the tensors they index.
    all_nn_idxs = all_nn_idxs.to(logits.device)

    all_probs = pred_softmax[all_nn_idxs]
    all_logits = logits[all_nn_idxs]
    pred_full = all_probs.argmax(dim=1).cpu()
    return all_logits, all_probs, pred_full
|
|
|
def get_segmentation_rgb(model, data, N_CHUNKS=5):
    """Run the model on ``data`` and colour the full cloud by predicted label.

    Tensor entries of ``data`` whose key does not contain ``"full"`` are moved
    to ``DEVICE`` in place, so the caller's mapping is mutated. The model
    output is upsampled to the full-resolution cloud by ``pred_3d_upsample``
    and the resulting label indices are handed to ``get_seg_color``.

    Args:
        model: Callable as ``model(x=data)``; must expose ``ln_logit_scale``,
            whose exponential is used as the softmax temperature.
        data: Mapping providing at least ``"label_embeds"``, ``"coord"`` and
            ``"xyz_full"``.
        N_CHUNKS: Forwarded to ``pred_3d_upsample``.

    Returns:
        The result of ``get_seg_color`` on the predicted labels (presumably
        one RGB colour per full-resolution point -- confirm in that helper).
    """
    softmax_temperature = np.exp(model.ln_logit_scale.item())

    with torch.no_grad():
        # Keys containing "full" are skipped here; they stay where they are.
        keys_to_move = [
            name
            for name in data.keys()
            if isinstance(data[name], torch.Tensor) and "full" not in name
        ]
        for name in keys_to_move:
            data[name] = data[name].to(DEVICE)

        features = model(x=data)
        upsampled = pred_3d_upsample(
            features,
            data['label_embeds'],
            softmax_temperature,
            data["coord"],
            data["xyz_full"],
            N_CHUNKS=N_CHUNKS,
        )
        labels_full = upsampled[2]
        return get_seg_color(labels_full.cpu())
|
|
|
def get_heatmap_rgb(model, data, N_CHUNKS=5):
    """Run the model on ``data`` and colour the full cloud by query logit.

    Tensor entries of ``data`` whose key does not contain ``"full"`` are moved
    to ``DEVICE`` in place, so the caller's mapping is mutated. The raw logits
    returned by ``pred_3d_upsample`` are fed to matplotlib's ``jet`` colormap
    and the alpha channel is dropped.

    Args:
        model: Callable as ``model(x=data)``; must expose ``ln_logit_scale``,
            whose exponential is used as the softmax temperature.
        data: Mapping providing at least ``"label_embeds"``, ``"coord"`` and
            ``"xyz_full"``.
        N_CHUNKS: Forwarded to ``pred_3d_upsample``.

    Returns:
        A tensor holding the first three colormap channels per score
        (presumably ``(N, 3)`` for a single text query -- confirm).
    """
    softmax_temperature = np.exp(model.ln_logit_scale.item())

    with torch.no_grad():
        # Keys containing "full" are skipped here; they stay where they are.
        keys_to_move = [
            name
            for name in data.keys()
            if isinstance(data[name], torch.Tensor) and "full" not in name
        ]
        for name in keys_to_move:
            data[name] = data[name].to(DEVICE)

        features = model(x=data)
        upsampled = pred_3d_upsample(
            features,
            data['label_embeds'],
            softmax_temperature,
            data["coord"],
            data["xyz_full"],
            N_CHUNKS=N_CHUNKS,
        )
        logits_full = upsampled[0]

        scores = logits_full.squeeze().cpu()
        # NOTE(review): the logits go into the colormap unnormalised; confirm
        # they fall in the range the colormap expects.
        rgba = plt.cm.jet(scores.numpy())
        return torch.tensor(rgba[:, :3]).squeeze()
|
|
|
@spaces.GPU
def segment_obj(xyz, rgb, normal, queries):
    """Segment a point cloud according to a list of text queries.

    The three per-point arrays are converted to float tensors on ``DEVICE``,
    preprocessed with ``preprocess_pcd``, paired with the encoded ``queries``
    and passed to ``get_segmentation_rgb``.

    Args:
        xyz: Point coordinates, anything ``torch.tensor`` accepts.
        rgb: Point colours, anything ``torch.tensor`` accepts.
        normal: Point normals, anything ``torch.tensor`` accepts.
        queries: Text queries handed to ``encode_text`` unchanged.

    Returns:
        The segmentation colouring produced by ``get_segmentation_rgb``.
    """
    model = load_model()
    xyz_t, rgb_t, normal_t = (
        torch.tensor(values).float().to(DEVICE) for values in (xyz, rgb, normal)
    )
    data_dict = preprocess_pcd(xyz_t, rgb_t, normal_t)
    data_dict["label_embeds"] = encode_text(queries)
    return get_segmentation_rgb(model, data_dict)
|
|
|
@spaces.GPU
def get_heatmap(xyz, rgb, normal, query):
    """Compute a heatmap colouring of a point cloud for one text query.

    The three per-point arrays are converted to float tensors on ``DEVICE``,
    preprocessed with ``preprocess_pcd``, paired with the encoding of the
    single ``query`` and passed to ``get_heatmap_rgb``.

    Args:
        xyz: Point coordinates, anything ``torch.tensor`` accepts.
        rgb: Point colours, anything ``torch.tensor`` accepts.
        normal: Point normals, anything ``torch.tensor`` accepts.
        query: One text query; it is wrapped in a one-element list before
            being handed to ``encode_text``.

    Returns:
        The heatmap colouring produced by ``get_heatmap_rgb``.
    """
    model = load_model()
    xyz_t, rgb_t, normal_t = (
        torch.tensor(values).float().to(DEVICE) for values in (xyz, rgb, normal)
    )
    data_dict = preprocess_pcd(xyz_t, rgb_t, normal_t)
    data_dict["label_embeds"] = encode_text([query])
    return get_heatmap_rgb(model, data_dict)