Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
from torch.autograd import Function | |
from pointops._C import interpolation_forward_cuda, interpolation_backward_cuda | |
from .query import knn_query | |
def interpolation(xyz, new_xyz, feat, offset, new_offset, k=3):
    """Interpolate per-point features onto new points by inverse-distance weighting.

    Each target point in ``new_xyz`` receives a weighted sum of the features of
    the ``k`` neighbours returned by ``knn_query``. The weights are the
    reciprocals of the returned distances, normalised to sum to 1 per target
    point.

    input: xyz: (m, 3), new_xyz: (n, 3), feat: (m, c), offset: (b), new_offset: (b)
    output: (n, c)

    Args:
        xyz: source point coordinates; must be contiguous.
        new_xyz: target point coordinates; must be contiguous.
        feat: per-source-point features; must be contiguous.
        offset: batch offsets for ``xyz``, passed through to ``knn_query``.
        new_offset: batch offsets for ``new_xyz``, passed through to ``knn_query``.
        k: number of neighbours used per target point.

    Returns:
        A float32 tensor of shape (n, c) on ``feat``'s device.
    """
    assert xyz.is_contiguous() and new_xyz.is_contiguous() and feat.is_contiguous()
    idx, dist = knn_query(k, xyz, offset, new_xyz, new_offset)  # (n, k), (n, k)
    # The epsilon keeps a zero distance (coincident points) from producing inf,
    # which would turn the normalised weights into NaN.
    dist_recip = 1.0 / (dist + 1e-8)  # (n, k)
    norm = torch.sum(dist_recip, dim=1, keepdim=True)
    weight = dist_recip / norm  # (n, k)

    # Allocate on feat's device instead of the "current" CUDA device (which is
    # what the legacy torch.cuda.FloatTensor constructor used); dtype stays float32.
    new_feat = torch.zeros(
        new_xyz.shape[0], feat.shape[1], dtype=torch.float32, device=feat.device
    )
    idx = idx.long()  # convert once instead of once per neighbour column
    # Accumulate one neighbour column at a time rather than gathering an
    # (n, k, c) tensor, so the intermediate stays (n, c)-sized.
    for i in range(k):
        new_feat += feat[idx[:, i], :] * weight[:, i].unsqueeze(-1)
    return new_feat
class Interpolation(Function):
    """Autograd function for k-nearest-neighbour inverse-distance interpolation.

    The neighbour search and weight computation mirror ``interpolation``. The
    weighted combination (forward) and its gradient with respect to the
    features (backward) run in the ``pointops`` CUDA extension. Only the
    feature tensor receives a gradient; ``xyz``, ``new_xyz``, the offsets and
    ``k`` get ``None``.
    """

    @staticmethod
    def forward(ctx, xyz, new_xyz, input, offset, new_offset, k=3):
        """Interpolate ``input`` from the points ``xyz`` onto ``new_xyz``.

        input: xyz: (m, 3), new_xyz: (n, 3), input: (m, c), offset: (b), new_offset: (b)
        output: (n, c)

        Args:
            ctx: autograd context; stores ``idx``/``weight`` and ``m``/``k`` for backward.
            xyz: source point coordinates; must be contiguous.
            new_xyz: target point coordinates; must be contiguous.
            input: per-source-point features; must be contiguous.
            offset: batch offsets for ``xyz``, passed through to ``knn_query``.
            new_offset: batch offsets for ``new_xyz``, passed through to ``knn_query``.
            k: number of neighbours used per target point.

        Returns:
            A float32 tensor of shape (n, c) on ``input``'s device, filled by
            ``interpolation_forward_cuda``.
        """
        assert xyz.is_contiguous() and new_xyz.is_contiguous() and input.is_contiguous()
        idx, dist = knn_query(k, xyz, offset, new_xyz, new_offset)  # (n, k), (n, k)
        # The epsilon keeps a zero distance from producing inf (and NaN weights).
        dist_recip = 1.0 / (dist + 1e-8)  # (n, k)
        norm = torch.sum(dist_recip, dim=1, keepdim=True)
        weight = dist_recip / norm  # (n, k)
        n, c, m = new_xyz.shape[0], input.shape[1], input.shape[0]
        # Allocate on input's device, not the "current" CUDA device that the
        # legacy torch.cuda.FloatTensor constructor would pick.
        output = torch.zeros(n, c, dtype=torch.float32, device=input.device)
        interpolation_forward_cuda(n, c, k, input, idx, weight, output)
        ctx.m, ctx.k = m, k
        ctx.save_for_backward(idx, weight)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate the output gradient back to the interpolated features.

        input: grad_output: (n, c)
        output: gradients for (xyz, new_xyz, input, offset, new_offset, k); only
            the ``input`` slot is non-None, a float32 tensor of shape (m, c).
        """
        m, k = ctx.m, ctx.k
        idx, weight = ctx.saved_tensors
        # Upstream gradients may be non-contiguous (e.g. after a transpose).
        # Forward requires contiguous inputs, so the kernel presumably expects
        # dense row-major storage here as well — TODO confirm against the kernel.
        grad_output = grad_output.contiguous()
        n, c = grad_output.shape
        grad_input = torch.zeros(m, c, dtype=torch.float32, device=grad_output.device)
        interpolation_backward_cuda(n, c, k, grad_output, idx, weight, grad_input)
        return None, None, grad_input, None, None, None


# Functional entry point: call with forward's arguments (without ctx).
interpolation2 = Interpolation.apply