import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from pointops._C import subtraction_backward_cuda, subtraction_forward_cuda


class Subtraction(Function):
    """Autograd wrapper around the ``pointops`` CUDA subtraction kernels.

    The per-element arithmetic lives in the compiled extension and is not
    visible here. Presumably ``output[i, j] = input1[i] - input2[idx[i, j]]``
    — NOTE(review): confirm against the kernel source.
    """

    @staticmethod
    def forward(ctx, input1, input2, idx):
        """Run the forward subtraction kernel.

        Args:
            ctx: autograd context; ``idx`` is saved on it for ``backward``.
            input1: ``(n, c)`` tensor; must be contiguous.
            input2: ``(n, c)`` tensor; must be contiguous.
            idx: ``(n, nsample)`` index tensor.

        Returns:
            ``(n, nsample, c)`` float32 tensor on ``input1``'s device.

        Raises:
            AssertionError: if ``input1`` or ``input2`` is not contiguous.
        """
        assert input1.is_contiguous() and input2.is_contiguous()
        # The kernels are handed idx as-is as well; make sure it is dense
        # too (no-op when it already is).
        idx = idx.contiguous()
        n, c = input1.shape
        nsample = idx.shape[-1]
        # Allocate on the inputs' device rather than the *current* CUDA
        # device (which is what the legacy torch.cuda.FloatTensor
        # constructor used), so the kernel never mixes pointers from two
        # GPUs. float32 is kept as before since the kernel presumably
        # expects float buffers — NOTE(review): confirm and consider
        # validating the input dtypes.
        output = torch.zeros(
            (n, nsample, c), dtype=torch.float32, device=input1.device
        )
        subtraction_forward_cuda(n, nsample, c, input1, input2, idx, output)
        ctx.save_for_backward(idx)
        return output

    @staticmethod
    @once_differentiable  # backward is an opaque kernel: no double-backward
    def backward(ctx, grad_output):
        """Run the backward subtraction kernel.

        Args:
            ctx: autograd context holding the ``idx`` saved by ``forward``.
            grad_output: ``(n, nsample, c)`` gradient w.r.t. the output.

        Returns:
            ``(grad_input1, grad_input2, None)``: two ``(n, c)`` float32
            gradients on ``grad_output``'s device, and no gradient for
            ``idx``.
        """
        (idx,) = ctx.saved_tensors
        # Autograd may pass a non-contiguous gradient (e.g. after a
        # downstream expand/permute); the kernel presumably requires dense
        # memory, as forward's contiguity assertions suggest.
        grad_output = grad_output.contiguous()
        n, nsample, c = grad_output.shape
        grad_input1 = torch.zeros(
            (n, c), dtype=torch.float32, device=grad_output.device
        )
        grad_input2 = torch.zeros_like(grad_input1)
        subtraction_backward_cuda(
            n, nsample, c, idx, grad_output, grad_input1, grad_input2
        )
        return grad_input1, grad_input2, None


subtraction = Subtraction.apply