Arnaudding001 commited on
Commit
674f9be
1 Parent(s): 34f7f3a

Create raft_alt_cuda_corr_correlation.cpp

Browse files
Files changed (1) hide show
  1. raft_alt_cuda_corr_correlation.cpp +54 -0
raft_alt_cuda_corr_correlation.cpp ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <torch/extension.h>
2
+ #include <vector>
3
+
4
+ // CUDA forward declarations
5
+ std::vector<torch::Tensor> corr_cuda_forward(
6
+ torch::Tensor fmap1,
7
+ torch::Tensor fmap2,
8
+ torch::Tensor coords,
9
+ int radius);
10
+
11
+ std::vector<torch::Tensor> corr_cuda_backward(
12
+ torch::Tensor fmap1,
13
+ torch::Tensor fmap2,
14
+ torch::Tensor coords,
15
+ torch::Tensor corr_grad,
16
+ int radius);
17
+
18
+ // C++ interface
19
+ #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
20
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
21
+ #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
22
+
23
+ std::vector<torch::Tensor> corr_forward(
24
+ torch::Tensor fmap1,
25
+ torch::Tensor fmap2,
26
+ torch::Tensor coords,
27
+ int radius) {
28
+ CHECK_INPUT(fmap1);
29
+ CHECK_INPUT(fmap2);
30
+ CHECK_INPUT(coords);
31
+
32
+ return corr_cuda_forward(fmap1, fmap2, coords, radius);
33
+ }
34
+
35
+
36
+ std::vector<torch::Tensor> corr_backward(
37
+ torch::Tensor fmap1,
38
+ torch::Tensor fmap2,
39
+ torch::Tensor coords,
40
+ torch::Tensor corr_grad,
41
+ int radius) {
42
+ CHECK_INPUT(fmap1);
43
+ CHECK_INPUT(fmap2);
44
+ CHECK_INPUT(coords);
45
+ CHECK_INPUT(corr_grad);
46
+
47
+ return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
48
+ }
49
+
50
+
51
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
52
+ m.def("forward", &corr_forward, "CORR forward");
53
+ m.def("backward", &corr_backward, "CORR backward");
54
+ }