import torch
from torch.autograd import Function

from pointops._C import (
    attention_relation_step_forward_cuda,
    attention_relation_step_backward_cuda,
    attention_fusion_step_forward_cuda,
    attention_fusion_step_backward_cuda,
)


class AttentionRelationStep(Function):
    @staticmethod
    def forward(ctx, query, key, weight, index_target, index_refer):
        """
        input - query: (n, g, c), key: (n, g, c),
                weight: (c), all ones (1_c) for plain scatter attention,
                index_target: (m), index_refer: (m)
        output - relation: (m, g)
        """
        assert (
            query.is_contiguous()
            and key.is_contiguous()
            and index_target.is_contiguous()
            and index_refer.is_contiguous()
            and weight.is_contiguous()
        )
        assert index_target.shape[0] == index_refer.shape[0]

        _, g, c = query.shape
        m = index_target.shape[0]
        # relation buffer, one score per (target, reference) pair and group
        output = torch.cuda.FloatTensor(m, g).zero_()
        attention_relation_step_forward_cuda(
            m, g, c, query, key, weight, index_target.int(), index_refer.int(), output
        )
        ctx.save_for_backward(query, key, weight, index_target, index_refer)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        query, key, weight, index_target, index_refer = ctx.saved_tensors
        n, g, c = query.shape
        m = index_target.shape[0]
        grad_query = torch.cuda.FloatTensor(n, g, c).zero_()
        grad_key = torch.cuda.FloatTensor(n, g, c).zero_()
        grad_weight = torch.cuda.FloatTensor(c).zero_()
        attention_relation_step_backward_cuda(
            m,
            g,
            c,
            query,
            grad_query,
            key,
            grad_key,
            weight,
            grad_weight,
            index_target.int(),
            index_refer.int(),
            grad_output,
        )
        # grad_weight is filled by the kernel but not returned: weight and the
        # index tensors are treated as non-differentiable inputs.
        return grad_query, grad_key, None, None, None


class AttentionFusionStep(Function):
    @staticmethod
    def forward(ctx, weight, value, index_target, index_refer):
        """
        input - weight: (m, g), value: (n, g, c),
                index_target: (m), index_refer: (m)
        output - output: (n, g, c)
        """
        assert (
            weight.is_contiguous()
            and value.is_contiguous()
            and index_target.is_contiguous()
            and index_refer.is_contiguous()
        )
        assert index_target.shape[0] == index_refer.shape[0]

        n, g, c = value.shape
        m = index_refer.shape[0]
        # fused output, one feature vector per point and group
        output = torch.cuda.FloatTensor(n, g, c).zero_()
        attention_fusion_step_forward_cuda(
            m, g, c, weight, value, index_target.int(), index_refer.int(), output
        )
        ctx.save_for_backward(weight, value, index_target, index_refer)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input - grad_output: (n, g, c)
        output - grad_weight: (m, g), grad_value: (n, g, c), None, None
        """
        weight, value, index_target, index_refer = ctx.saved_tensors
        n, g, c = value.shape
        m = index_target.shape[0]
        grad_weight = torch.cuda.FloatTensor(m, g).zero_()
        grad_value = torch.cuda.FloatTensor(n, g, c).zero_()
        attention_fusion_step_backward_cuda(
            m,
            g,
            c,
            weight,
            grad_weight,
            value,
            grad_value,
            index_target.int(),
            index_refer.int(),
            grad_output,
        )
        # index tensors are non-differentiable, hence the trailing None grads
        return grad_weight, grad_value, None, None


attention_relation_step = AttentionRelationStep.apply
attention_fusion_step = AttentionFusionStep.apply
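

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes the pointops CUDA extension is built, a CUDA device is available,
# and that index_target / index_refer enumerate (target point, reference point)
# pairs. The layout below, with a fixed k reference points per target, is an
# assumption made so the softmax over each point's pairs reduces to a reshape;
# all sizes are arbitrary example values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    n, k, g, c = 128, 16, 4, 8  # points, pairs per point, groups, channels per group

    query = torch.rand(n, g, c, device="cuda")
    key = torch.rand(n, g, c, device="cuda")
    value = torch.rand(n, g, c, device="cuda")
    weight = torch.ones(c, device="cuda")  # 1_c, i.e. plain scatter attention

    # m = n * k pairs: point i is the target of pairs [i * k, (i + 1) * k)
    index_target = torch.arange(n, device="cuda").repeat_interleave(k)
    index_refer = torch.randint(0, n, (n * k,), device="cuda")  # random references for the sketch

    relation = attention_relation_step(query, key, weight, index_target, index_refer)  # (m, g)
    attn = torch.softmax(relation.view(n, k, g), dim=1).view(n * k, g)  # normalise per target point
    output = attention_fusion_step(attn, value, index_target, index_refer)  # (n, g, c)
    print(output.shape)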