File size: 6,391 Bytes
2cd560a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import numpy as np
import torch
import torch.nn as nn
from .models.data_processor import DataProcessor
from .models.mean_vfe import MeanVFE
from .models.spconv_backbone_voxelnext import VoxelResBackBone8xVoxelNeXt
from .models.voxelnext_head import VoxelNeXtHead

from .utils.image_projection import _proj_voxel_image
from segment_anything import SamPredictor, sam_model_registry

class VoxelNeXt(nn.Module):
    """Inference-side VoxelNeXt detector assembled from a config node.

    Holds the point-cloud preprocessor (built with ``training=False``) and the
    three network stages ``voxelization`` (MeanVFE), ``backbone_3d`` and
    ``dense_head``, plus the point-cloud range and voxel size as float tensors.
    """

    def __init__(self, model_cfg):
        super().__init__()

        # The preprocessor is handed the range as a float32 numpy array.
        pc_range_np = np.array(model_cfg.POINT_CLOUD_RANGE, dtype=np.float32)
        self.data_processor = DataProcessor(
            model_cfg.DATA_PROCESSOR,
            point_cloud_range=pc_range_np,
            training=False,
            num_point_features=len(model_cfg.USED_FEATURE_LIST),
        )

        in_channels = model_cfg.get('INPUT_CHANNELS', 5)
        grid = np.array(model_cfg.get('GRID_SIZE', [1440, 1440, 40]))

        # Kept as (non-buffer) tensor attributes; callers move them to the
        # working device themselves when needed.
        self.point_cloud_range = torch.Tensor(
            model_cfg.get('POINT_CLOUD_RANGE', [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0])
        )
        self.voxel_size = torch.Tensor(model_cfg.get('VOXEL_SIZE', [0.075, 0.075, 0.2]))

        # Positional arguments for VoxelNeXtHead, in the order it expects them.
        head_args = (
            model_cfg.get('CLASS_NAMES'),
            self.point_cloud_range,
            self.voxel_size,
            model_cfg.get('KERNEL_SIZE_HEAD', 1),
            model_cfg.get('CLASS_NAMES_EACH_HEAD'),
            model_cfg.get('SEPARATE_HEAD_CFG'),
            model_cfg.get('POST_PROCESSING'),
        )

        self.voxelization = MeanVFE()
        self.backbone_3d = VoxelResBackBone8xVoxelNeXt(in_channels, grid)
        self.dense_head = VoxelNeXtHead(*head_args)


class Model(nn.Module):
    """Couple SAM (2D segmentation) with VoxelNeXt (3D detection).

    SAM segments the image around a prompt point. Among the VoxelNeXt
    predictions that score above a threshold, the highest-scoring one whose
    voxel projects into that mask is returned as the 3D box. If none project
    into the mask, the one projecting closest to the mask centroid is used.
    """

    def __init__(self, model_cfg, device="cuda"):
        """Build SAM and VoxelNeXt from ``model_cfg`` and load both checkpoints.

        Args:
            model_cfg: config node read with ``.get``. Besides the keys used by
                ``VoxelNeXt`` it may set ``SAM_TYPE``, ``SAM_CHECKPOINT`` and
                ``VOXELNEXT_CHECKPOINT`` (defaults shown below).
            device: device both networks are moved to.
        """
        super().__init__()

        sam_type = model_cfg.get('SAM_TYPE', "vit_b")
        sam_checkpoint = model_cfg.get('SAM_CHECKPOINT', "/data/sam_vit_b_01ec64.pth")

        sam = sam_model_registry[sam_type](checkpoint=sam_checkpoint).to(device=device)
        self.sam_predictor = SamPredictor(sam)

        voxelnext_checkpoint = model_cfg.get('VOXELNEXT_CHECKPOINT', "/data/voxelnext_nuscenes_kernel1.pth")
        # Deserialize onto the CPU so that a checkpoint saved from a GPU run also
        # opens when ``device`` is "cpu". load_state_dict then copies the
        # weights into parameters that already live on ``device``.
        model_dict = torch.load(voxelnext_checkpoint, map_location="cpu")
        self.voxelnext = VoxelNeXt(model_cfg).to(device=device)
        self.voxelnext.load_state_dict(model_dict)
        # This wrapper only does inference (its DataProcessor is built with
        # training=False). Switch the detector to eval mode so any dropout or
        # batch-norm layers behave deterministically and do not update
        # running statistics on every call.
        self.voxelnext.eval()
        # image_id -> backbone output dict. This cache is never evicted; clear it
        # manually when iterating over many scenes.
        self.point_features = {}
        self.device = device

    def image_embedding(self, image):
        """Compute and store SAM's image embedding for ``image`` (via ``set_image``)."""
        self.sam_predictor.set_image(image)

    def point_embedding(self, data_dict, image_id):
        """Run VoxelNeXt on a point-cloud sample and return its predictions.

        Backbone features are cached per ``image_id``. A repeated id reuses the
        cached features and skips preprocessing of ``data_dict`` entirely, since
        that result would be discarded anyway.

        Args:
            data_dict: raw sample accepted by the VoxelNeXt ``DataProcessor``.
                It is ignored when ``image_id`` is already cached.
            image_id: hashable cache key for this sample.

        Returns:
            ``(pred_dicts, voxel_coords)``: the dense head's predictions and,
            for each prediction, the coordinates of its output voxel
            (``out_voxels[voxel_ids]``) multiplied by the head's
            ``feature_map_stride``.
        """
        # Pure inference: without no_grad the cached dicts would keep
        # autograd state alive for every image_id.
        with torch.no_grad():
            if image_id in self.point_features:
                data_dict = self.point_features[image_id]
            else:
                data_dict = self.voxelnext.data_processor.forward(data_dict=data_dict)
                for key in ('voxels', 'voxel_num_points', 'voxel_coords'):
                    data_dict[key] = torch.Tensor(data_dict[key]).to(self.device)

                data_dict = self.voxelnext.voxelization(data_dict)
                coords = data_dict['voxel_coords']
                # Prepend a batch-index column (all zeros: a single sample).
                batch_idx = coords.new_zeros((coords.shape[0], 1))
                data_dict['voxel_coords'] = torch.cat([batch_idx, coords], dim=1)
                data_dict['batch_size'] = 1

                data_dict = self.voxelnext.backbone_3d(data_dict)
                self.point_features[image_id] = data_dict

            pred_dicts = self.voxelnext.dense_head(data_dict)

        voxel_ids = pred_dicts[0]['voxel_ids'].squeeze(-1)
        voxel_coords = data_dict['out_voxels'][voxel_ids] * self.voxelnext.dense_head.feature_map_stride
        return pred_dicts, voxel_coords

    def generate_3D_box(self, lidar2img_rt, mask, voxel_coords, pred_dicts, quality_score=0.1):
        """Select the predicted 3D box that corresponds to a 2D image mask.

        Each prediction's voxel is projected into the image with
        ``lidar2img_rt``. Among predictions scoring above ``quality_score``:

        * the highest-scoring one whose projection falls inside ``mask`` wins;
        * if none falls inside, the one projecting closest to the mask
          centroid is chosen.

        Args:
            lidar2img_rt: projection passed through to ``_proj_voxel_image``.
            mask: 2D (H, W) boolean-like numpy array or tensor, e.g. a SAM mask.
            voxel_coords: per-prediction voxel coordinates from ``point_embedding``.
            pred_dicts: dense-head output. ``pred_dicts[0]`` must contain
                ``pred_scores`` and ``pred_boxes``.
            quality_score: minimum score for a prediction to be considered.

        Returns:
            One row of ``pred_dicts[0]['pred_boxes']``, or ``None`` when no
            prediction passes ``quality_score`` or ``mask`` is empty.
        """
        preds = pred_dicts[0]
        mask_extra = preds['pred_scores'] > quality_score
        # Cheap check first: skip the projection when nothing can be selected.
        if not mask_extra.any():
            print("no high quality 3D box related.")
            return None

        if torch.is_tensor(mask):
            mask = mask.detach().cpu().numpy()
        mask = np.asarray(mask, dtype=bool)
        height, width = mask.shape[0], mask.shape[1]

        device = voxel_coords.device
        # NOTE(review): the returned depth is not used. Voxels behind the
        # camera are therefore excluded only if _proj_voxel_image already
        # handles them; confirm.
        points_image, _ = _proj_voxel_image(
            voxel_coords, lidar2img_rt,
            self.voxelnext.voxel_size.to(device),
            self.voxelnext.point_cloud_range.to(device),
        )

        # points_image is used as (2, N): row 0 is x (column), row 1 is y (row).
        xy = points_image.permute(1, 0).detach().cpu().numpy()
        # Send NaN/inf and far-off projections to an off-image value before the
        # integer cast, and floor (rather than truncate) so that coordinates in
        # (-1, 0) stay off-image instead of becoming pixel 0.
        xy = np.nan_to_num(xy, nan=-1.0, posinf=-1.0, neginf=-1.0)
        xy = np.clip(xy, -1, max(height, width))
        cols, rows = np.floor(xy).astype(np.int64).T
        inside = (cols >= 0) & (rows >= 0) & (cols < width) & (rows < height)
        hits = np.zeros(cols.shape[0], dtype=bool)
        hits[inside] = mask[rows[inside], cols[inside]]

        selected = torch.from_numpy(hits).to(mask_extra.device) & mask_extra
        if selected.any():
            best = preds['pred_scores'][selected].argmax()
            return preds['pred_boxes'][selected][best]

        # No confident voxel projects into the mask: fall back to the one whose
        # projection is nearest to the mask centroid, given as (x, y).
        mask_rows, mask_cols = np.nonzero(mask)
        if mask_rows.size == 0:
            print("empty 2D mask, no 3D box related.")
            return None
        mask_center = torch.tensor(
            [mask_cols.mean(), mask_rows.mean()], dtype=torch.float32
        ).to(points_image.device).unsqueeze(1)
        dist = ((points_image - mask_center) ** 2).sum(0)
        # Treat NaN distances as infinitely far so argmin cannot pick them.
        dist = torch.nan_to_num(dist, nan=float("inf"))
        selected_id = dist[mask_extra].argmin()
        return preds['pred_boxes'][mask_extra][selected_id]

    def forward(self, image, point_dict, prompt_point, lidar2img_rt, image_id, quality_score=0.1):
        """Segment ``image`` at ``prompt_point`` with SAM and lift the mask to a 3D box.

        Args:
            image: image passed to ``SamPredictor.set_image``.
            point_dict: point-cloud sample for ``point_embedding``.
            prompt_point: foreground prompt point(s) in pixel (x, y)
                coordinates, shape (K, 2). A single (2,) point is also accepted.
            lidar2img_rt: LiDAR-to-image projection for ``generate_3D_box``.
            image_id: cache key for the point-cloud features.
            quality_score: minimum prediction score for ``generate_3D_box``.

        Returns:
            ``(mask, box3d)``: the first mask returned by SAM, and the selected
            3D box or ``None``.
        """
        self.image_embedding(image)
        pred_dicts, voxel_coords = self.point_embedding(point_dict, image_id)

        prompt_point = np.atleast_2d(np.asarray(prompt_point))
        # SAM needs one label per prompt point; label 1 marks foreground.
        point_labels = np.ones(prompt_point.shape[0], dtype=int)
        # predict() is called with its default options; the first returned
        # mask is used and the mask scores are not consulted.
        masks, _, _ = self.sam_predictor.predict(point_coords=prompt_point, point_labels=point_labels)
        mask = masks[0]

        box3d = self.generate_3D_box(lidar2img_rt, mask, voxel_coords, pred_dicts, quality_score=quality_score)
        return mask, box3d


if __name__ == '__main__':
    # Smoke test: run VoxelNeXt on the first nuScenes sample.
    #
    # NOTE(review): `cfg`, `cfg_from_yaml_file` and `NuScenesDataset` are not
    # imported anywhere in this file, so this block raises NameError as written.
    # Import them from wherever the project defines them. The module also uses
    # relative imports (`from .models ...`), so it has to be launched as
    # `python -m <package>.<module>`, not as a plain script.
    # NOTE(review): if `cfg_from_yaml_file` merges into and returns the shared
    # `cfg` object, `dataset_cfg` and `model_cfg` are the same object; confirm.
    cfg_dataset = 'nuscenes_dataset.yaml'
    cfg_model = 'config.yaml'

    dataset_cfg = cfg_from_yaml_file(cfg_dataset, cfg)
    model_cfg = cfg_from_yaml_file(cfg_model, cfg)

    nuscenes_dataset = NuScenesDataset(dataset_cfg)
    model = Model(model_cfg)

    index = 0
    data_dict = nuscenes_dataset._get_points(index)
    # point_embedding requires a cache key; the sample index serves as one.
    pred_dicts, voxel_coords = model.point_embedding(data_dict, image_id=index)