File size: 832 Bytes
4893ce0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from torch.autograd import Function

from pointops._C import farthest_point_sampling_cuda


class FarthestPointSampling(Function):
    """Autograd ``Function`` wrapping the ``farthest_point_sampling_cuda`` op.

    Only ``forward`` is defined here; the result is a tensor of integer
    indices.
    """

    @staticmethod
    def forward(ctx, xyz, offset, new_offset):
        """
        input: coords: (n, 3), offset: (b), new_offset: (b)
        output: idx: (m)

        ``offset`` is treated as cumulative per-batch end positions: the
        size of batch ``i`` is ``offset[i] - offset[i - 1]`` (``offset[0]``
        for the first batch).  ``m`` is ``new_offset[b - 1]``, i.e. the last
        entry of ``new_offset``.

        The returned ``idx`` is an int32 tensor allocated on ``xyz``'s
        device and handed to the extension, which is expected to fill it.
        """
        assert xyz.is_contiguous()
        n, b = xyz.shape[0], offset.shape[0]

        # Largest per-batch point count, computed in one vectorised pass and
        # converted to a plain int once (instead of comparing 0-dim tensors
        # in a Python loop, which costs one host sync per batch on GPU).
        counts = torch.cat((offset[:1], offset[1:] - offset[:-1]))
        n_max = int(counts.max().item())
        m = int(new_offset[b - 1].item())

        # Allocate on the input's device rather than whatever CUDA device
        # happens to be current, and make that device current for the call.
        # NOTE(review): the extension source is not visible here; the device
        # switch guards against it launching on the current device.
        with torch.cuda.device_of(xyz):
            idx = torch.zeros(m, dtype=torch.int32, device=xyz.device)
            # Scratch buffer, one float per input point, initialised to a
            # large value (presumably a running min-distance -- TODO confirm).
            tmp = torch.full((n,), 1e10, dtype=torch.float32, device=xyz.device)
            farthest_point_sampling_cuda(
                b, n_max, xyz, offset.int(), new_offset.int(), tmp, idx
            )
        return idx


# Functional entry point: call as
# ``farthest_point_sampling(xyz, offset, new_offset)``; arguments are
# forwarded to ``FarthestPointSampling.forward`` via ``Function.apply``.
farthest_point_sampling = FarthestPointSampling.apply