File size: 3,973 Bytes
4893ce0
 
 
 
9a7ea82
4893ce0
7591db0
 
 
4893ce0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9a7ea82
f0dc24f
4893ce0
 
 
 
 
 
9a7ea82
f0dc24f
4893ce0
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import numpy as np
import matplotlib.pyplot as plt
from inference.utils import get_seg_color, load_model, preprocess_pcd, encode_text
import spaces

# Target device for the `.to(DEVICE)` calls made by the functions in this file.
# NOTE(review): the value is hard-coded to the first CUDA device and the
# availability check below is commented out, so there is no CPU fallback —
# confirm this is intentional for the deployment environment.
DEVICE = "cuda:0"
#if torch.cuda.is_available():
    #DEVICE = "cuda:0"

def _nearest_neighbor_indices(query, reference, n_chunks):
    """Return, for every row of `query`, the index of the closest row of `reference`.

    `query` is processed in `n_chunks` slices so that only one
    (chunk, n_reference) distance matrix is alive at a time.
    """
    chunk_len = query.shape[0] // n_chunks + 1
    nearest = []
    for i in range(n_chunks):
        chunk = query[chunk_len * i:chunk_len * (i + 1)].to(reference.device)
        # Squared Euclidean distance is enough: the square root is monotonic,
        # so it does not change which reference point is the closest one.
        # The (chunk, n_reference, 3) difference tensor is a temporary and is
        # released as soon as the sum has been computed.
        sq_dist = ((reference.unsqueeze(0) - chunk.unsqueeze(1)) ** 2).sum(dim=-1)
        nearest.append(torch.min(sq_dist, 1)[1])
        del sq_dist
    return torch.cat(nearest, dim=0)


def pred_3d_upsample(
        pred, # n_subsampled_pts, feat_dim
        part_text_embeds, # n_parts, feat_dim
        temperature,
        xyz_sub,
        xyz_full, # n_pts, 3
        N_CHUNKS=1
        ):
    """Propagate per-point predictions from a subsampled cloud to the full cloud.

    Each full-resolution point receives the logits and probabilities of its
    nearest subsampled point (nearest-neighbour upsampling).

    Args:
        pred: features of the subsampled points, (n_subsampled_pts, feat_dim).
        part_text_embeds: text embeddings of the parts, (n_parts, feat_dim).
        temperature: scalar the logits are multiplied by before the softmax.
        xyz_sub: coordinates of the subsampled points, one row per point.
        xyz_full: coordinates of the full cloud, (n_pts, 3); leading singleton
            dimensions (e.g. a batch dimension of 1) are flattened away.
        N_CHUNKS: number of slices the full cloud is split into for the
            distance computation; must be >= 1.

    Returns:
        A tuple (all_logits, all_probs, pred_full):
        all_logits is (n_pts, n_parts) raw similarity logits;
        all_probs is (n_pts, n_parts + 1) softmax probabilities, where
        column 0 is the "unlabeled" class;
        pred_full is a CPU tensor of (n_pts,) labels, 0 = unlabeled and
        1..n_parts = part index + 1.
    """
    # Flatten to (n_pts, 3). Unlike squeeze(), this keeps the point axis even
    # when the cloud contains a single point.
    xyz_full = xyz_full.reshape(-1, xyz_full.shape[-1])
    logits = pred @ part_text_embeds.T # n_subsampled_pts, n_parts

    # Prepend a constant zero logit that acts as the "unlabeled" class. It is
    # created on the logits' own device rather than a hard-coded global one.
    zero_col = torch.zeros(logits.shape[0], 1, device=logits.device)
    logits_prepend0 = torch.cat([zero_col, logits], dim=1)
    pred_softmax = torch.softmax(logits_prepend0 * temperature, dim=1)

    all_nn_idxs = _nearest_neighbor_indices(xyz_full, xyz_sub, N_CHUNKS)
    all_nn_idxs = all_nn_idxs.to(logits.device)

    # Nearest-neighbour lookup: copy each subsampled point's prediction to the
    # full-resolution points that are closest to it.
    all_probs = pred_softmax[all_nn_idxs]
    all_logits = logits[all_nn_idxs]
    pred_full = all_probs.argmax(dim=1).cpu() # here, 0 is unlabeled, 1,...n_part correspond to actual part assignment
    return all_logits, all_probs, pred_full

def get_segmentation_rgb(model, data, N_CHUNKS=5): # evaluate loader can only have batch size=1
    """Segment one point cloud and return a color for every full-resolution point.

    Every tensor in `data` whose key does not contain "full" is moved to
    DEVICE in place before the model is run. The model output is upsampled
    to the full cloud with `pred_3d_upsample`, and the resulting labels are
    converted to colors by `get_seg_color`.

    Args:
        model: callable as `model(x=data)`; must expose `ln_logit_scale`.
        data: mapping that provides at least 'label_embeds', "coord" and
            "xyz_full".
        N_CHUNKS: forwarded to `pred_3d_upsample`.

    Returns:
        Whatever `get_seg_color` returns for the predicted labels.
    """
    # The model stores the log of its logit scale; exponentiate to get the scale.
    logit_scale = np.exp(model.ln_logit_scale.item())
    with torch.no_grad():
        for name, value in data.items():
            if not isinstance(value, torch.Tensor) or "full" in name:
                continue
            data[name] = value.to(DEVICE)
        features = model(x=data)
        _, _, labels = pred_3d_upsample(
            features, # n_subsampled_pts, feat_dim
            data['label_embeds'], # n_parts, feat_dim
            logit_scale,
            data["coord"],
            data["xyz_full"], # n_pts, 3
            N_CHUNKS=N_CHUNKS,
        )
        return get_seg_color(labels.cpu())

def get_heatmap_rgb(model, data, N_CHUNKS=5): # evaluate loader can only have batch size=1
    """Compute a jet-colormap heatmap of the query logits over the full cloud.

    Every tensor in `data` whose key does not contain "full" is moved to
    DEVICE in place before the model is run. The model output is upsampled
    to the full cloud with `pred_3d_upsample`, and the upsampled raw logits
    are passed through matplotlib's jet colormap.

    Args:
        model: callable as `model(x=data)`; must expose `ln_logit_scale`.
        data: mapping that provides at least 'label_embeds', "coord" and
            "xyz_full".
        N_CHUNKS: forwarded to `pred_3d_upsample`.

    Returns:
        A tensor holding the first three colormap channels for each point.
    """
    # The model stores the log of its logit scale; exponentiate to get the scale.
    logit_scale = np.exp(model.ln_logit_scale.item())
    with torch.no_grad():
        for name, value in data.items():
            if not isinstance(value, torch.Tensor) or "full" in name:
                continue
            data[name] = value.to(DEVICE)
        features = model(x=data)
        point_logits, _, _ = pred_3d_upsample(
            features, # n_subsampled_pts, feat_dim
            data['label_embeds'], # n_parts, feat_dim
            logit_scale,
            data["coord"],
            data["xyz_full"], # n_pts, 3
            N_CHUNKS=N_CHUNKS,
        )
        scores = point_logits.squeeze().cpu()
        # NOTE(review): the logits are fed to the colormap without rescaling;
        # presumably they are expected to lie in [0, 1] — confirm.
        colors = plt.cm.jet(scores.numpy())[:,:3]
        return torch.tensor(colors).squeeze()
    
@spaces.GPU
def segment_obj(xyz, rgb, normal, queries):
    """Segment a point cloud into the parts named by `queries`.

    Loads the model, converts the three per-point arrays to float tensors on
    DEVICE, preprocesses them with `preprocess_pcd`, attaches the text
    embeddings of `queries`, and returns the colors produced by
    `get_segmentation_rgb`.
    """
    model = load_model()
    xyz_t, rgb_t, normal_t = (
        torch.tensor(arr).float().to(DEVICE) for arr in (xyz, rgb, normal)
    )
    data_dict = preprocess_pcd(xyz_t, rgb_t, normal_t)
    data_dict["label_embeds"] = encode_text(queries)
    return get_segmentation_rgb(model, data_dict)
    
@spaces.GPU
def get_heatmap(xyz, rgb, normal, query):
    """Compute a per-point heatmap of a point cloud for a single text `query`.

    Loads the model, converts the three per-point arrays to float tensors on
    DEVICE, preprocesses them with `preprocess_pcd`, attaches the text
    embedding of `query` (wrapped in a one-element list), and returns the
    colors produced by `get_heatmap_rgb`.
    """
    model = load_model()
    xyz_t, rgb_t, normal_t = (
        torch.tensor(arr).float().to(DEVICE) for arr in (xyz, rgb, normal)
    )
    data_dict = preprocess_pcd(xyz_t, rgb_t, normal_t)
    data_dict["label_embeds"] = encode_text([query])
    return get_heatmap_rgb(model, data_dict)