# Spaces: Running on Zero
import torch | |
from torch.autograd import Function | |
from pointops._C import grouping_forward_cuda, grouping_backward_cuda | |
class Grouping(Function):
    """Custom autograd op wrapping the ``pointops._C`` grouping kernels.

    Forward maps ``input`` ``(n, c)`` and ``idx`` ``(m, nsample)`` to an
    ``(m, nsample, c)`` tensor. Backward maps the ``(m, nsample, c)`` output
    gradient to an ``(n, c)`` gradient for ``input``. The indexing itself happens
    inside ``grouping_forward_cuda`` / ``grouping_backward_cuda``, which are not
    visible here. Invoke it through ``Grouping.apply(input, idx)``.
    """

    # torch.autograd.Function requires forward/backward to be static methods;
    # PyTorch passes the context object explicitly as the first argument.
    @staticmethod
    def forward(ctx, input, idx):
        """Run the grouping forward kernel.

        Args:
            ctx: autograd context; ``n`` and ``idx`` are stored on it for backward.
            input: ``(n, c)`` contiguous tensor.
            idx: ``(m, nsample)`` contiguous index tensor.

        Returns:
            ``(m, nsample, c)`` float32 tensor on ``input``'s device, filled in
            by ``grouping_forward_cuda``.
        """
        assert input.is_contiguous() and idx.is_contiguous()
        m, nsample, n, c = idx.shape[0], idx.shape[1], input.shape[0], input.shape[1]
        # Allocate on input's device: the legacy torch.cuda.FloatTensor
        # constructor used the *current* CUDA device instead. float32 matches the
        # dtype that constructor produced. Left uninitialised; the kernel fills it.
        output = torch.empty((m, nsample, c), dtype=torch.float32, device=input.device)
        grouping_forward_cuda(m, nsample, c, input, idx, output)
        ctx.n = n
        ctx.save_for_backward(idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Run the grouping backward kernel.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_output: ``(m, nsample, c)`` gradient w.r.t. the forward output.

        Returns:
            ``(grad_input, None)``: an ``(n, c)`` float32 gradient for ``input``
            (zero-initialised before the kernel writes into it), and ``None``
            for ``idx``.
        """
        n = ctx.n
        (idx,) = ctx.saved_tensors
        # forward asserts contiguity for the tensors it hands to the kernel;
        # incoming gradients carry no such guarantee, so enforce it here.
        grad_output = grad_output.contiguous()
        m, nsample, c = grad_output.shape
        grad_input = torch.zeros((n, c), dtype=torch.float32, device=grad_output.device)
        grouping_backward_cuda(m, nsample, c, grad_output, idx, grad_input)
        return grad_input, None
def grouping(idx, feat, xyz, new_xyz=None, with_xyz=False):
    """Gather features (and optionally relative coordinates) by row index.

    Each entry of ``idx`` selects one row of ``feat`` (and of ``xyz``). An index
    of ``-1`` selects an all-zero padding row appended below, so padded slots
    yield zero features and, with ``with_xyz``, zero relative coordinates.

    Args:
        idx: ``(m, nsample)`` integer tensor of row indices into ``feat`` /
            ``xyz``; ``-1`` marks an empty slot.
        feat: ``(n, c)`` contiguous feature tensor.
        xyz: ``(n, 3)`` contiguous coordinate tensor.
        new_xyz: coordinates subtracted from each gathered ``xyz`` row, expected
            to broadcast as ``(m, 1, 3)`` after ``unsqueeze(1)``. Defaults to
            ``xyz``, which only broadcasts when ``m == n``.
        with_xyz: if true, prepend the masked relative coordinates to the
            gathered features.

    Returns:
        ``(m, nsample, 3 + c)`` tensor when ``with_xyz`` is true, otherwise
        ``(m, nsample, c)``. The gathered features keep ``feat``'s dtype.
    """
    if new_xyz is None:
        new_xyz = xyz
    assert xyz.is_contiguous() and feat.is_contiguous()
    m, nsample, c = idx.shape[0], idx.shape[1], feat.shape[1]
    # reshape (not view) also accepts a non-contiguous idx. It is computed once
    # and reused for both gathers.
    flat_idx = idx.reshape(-1).long()

    # Append a zero row so that idx == -1 (i.e. the last row) gathers zeros.
    # new_zeros keeps feat's dtype/device; torch.zeros would be float32 and
    # silently promote e.g. half-precision features.
    feat = torch.cat([feat, feat.new_zeros((1, c))], dim=0)
    grouped_feat = feat[flat_idx, :].view(m, nsample, c)  # (m, nsample, c)
    if not with_xyz:
        return grouped_feat

    assert new_xyz.is_contiguous()
    xyz = torch.cat([xyz, xyz.new_zeros((1, 3))], dim=0)
    grouped_xyz = xyz[flat_idx, :].view(m, nsample, 3) - new_xyz.unsqueeze(1)
    # sign(idx + 1) is 0 exactly where idx == -1, so padded slots get zero offsets
    # instead of -new_xyz. This equals the previous "n s c, n s -> n s c" einsum.
    mask = torch.sign(idx + 1).unsqueeze(-1).to(grouped_xyz.dtype)
    grouped_xyz = grouped_xyz * mask  # (m, nsample, 3)
    return torch.cat((grouped_xyz, grouped_feat), -1)
# Functional entry point for the custom autograd op: grouping2(input, idx)
# invokes Grouping.forward and registers Grouping.backward with autograd.
grouping2 = Grouping.apply