Arnaudding001 committed
Commit 511a2cd
1 Parent(s): 674f9be

Create raft_alt_cuda_corr_correlation_kernel.cu

raft_alt_cuda_corr_correlation_kernel.cu ADDED
@@ -0,0 +1,324 @@
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>

// Each CUDA block processes a BLOCK_H x BLOCK_W tile of fmap1; channels are
// streamed through shared memory CHANNEL_STRIDE at a time.
#define BLOCK_H 4
#define BLOCK_W 8
#define BLOCK_HW (BLOCK_H * BLOCK_W)
#define CHANNEL_STRIDE 32


__forceinline__ __device__
bool within_bounds(int h, int w, int H, int W) {
  return h >= 0 && h < H && w >= 0 && w < W;
}

// Forward pass: for every pixel of fmap1 and every sampling point in coords,
// accumulate the correlation with fmap2 over a (2r+1) x (2r+1) neighborhood,
// bilinearly weighted by the fractional part of the sampling coordinate.
//
// Layouts (channels last):
//   fmap1, fmap2 : [B, H, W, C]
//   coords       : [B, N, H, W, 2]        (x, y)
//   corr         : [B, N, (2r+1)^2, H, W] (output, pre-zeroed by the caller)
template <typename scalar_t>
__global__ void corr_forward_kernel(
    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr,
    int r)
{
  const int b = blockIdx.x;                // batch index
  const int h0 = blockIdx.y * blockDim.x;  // top-left corner of this block's tile
  const int w0 = blockIdx.z * blockDim.y;
  const int tid = threadIdx.x * blockDim.y + threadIdx.y;

  const int H1 = fmap1.size(1);
  const int W1 = fmap1.size(2);
  const int H2 = fmap2.size(1);
  const int W2 = fmap2.size(2);
  const int N = coords.size(1);
  const int C = fmap1.size(3);

  // Shared-memory staging buffers (+1 padding column to reduce bank conflicts).
  __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
  __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
  __shared__ scalar_t x2s[BLOCK_HW];
  __shared__ scalar_t y2s[BLOCK_HW];

  for (int c=0; c<C; c+=CHANNEL_STRIDE) {
    // Cooperatively load a CHANNEL_STRIDE-deep slice of the fmap1 tile.
    for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
      int k1 = k + tid / CHANNEL_STRIDE;
      int h1 = h0 + k1 / BLOCK_W;
      int w1 = w0 + k1 % BLOCK_W;
      int c1 = tid % CHANNEL_STRIDE;

      auto fptr = fmap1[b][h1][w1];
      if (within_bounds(h1, w1, H1, W1))
        f1[c1][k1] = fptr[c+c1];
      else
        f1[c1][k1] = 0.0;
    }

    __syncthreads();

    for (int n=0; n<N; n++) {
      int h1 = h0 + threadIdx.x;
      int w1 = w0 + threadIdx.y;
      if (within_bounds(h1, w1, H1, W1)) {
        x2s[tid] = coords[b][n][h1][w1][0];
        y2s[tid] = coords[b][n][h1][w1][1];
      }

      // Fractional offsets used for bilinear weighting.
      scalar_t dx = x2s[tid] - floor(x2s[tid]);
      scalar_t dy = y2s[tid] - floor(y2s[tid]);

      int rd = 2*r + 1;
      for (int iy=0; iy<rd+1; iy++) {
        for (int ix=0; ix<rd+1; ix++) {
          // Load the matching fmap2 slice for this displacement.
          for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
            int k1 = k + tid / CHANNEL_STRIDE;
            int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
            int c2 = tid % CHANNEL_STRIDE;

            auto fptr = fmap2[b][h2][w2];
            if (within_bounds(h2, w2, H2, W2))
              f2[c2][k1] = fptr[c+c2];
            else
              f2[c2][k1] = 0.0;
          }

          __syncthreads();

          // Partial dot product over the current channel slice.
          scalar_t s = 0.0;
          for (int k=0; k<CHANNEL_STRIDE; k++)
            s += f1[k][tid] * f2[k][tid];

          // Offsets into the displacement dimension of corr; stepping by
          // H1*W1 advances one displacement bin in the contiguous output.
          int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
          int ix_ne = H1*W1*((iy-1) + rd*ix);
          int ix_sw = H1*W1*(iy + rd*(ix-1));
          int ix_se = H1*W1*(iy + rd*ix);

          // Bilinear scatter of the correlation into four neighboring bins.
          scalar_t nw = s * (dy) * (dx);
          scalar_t ne = s * (dy) * (1-dx);
          scalar_t sw = s * (1-dy) * (dx);
          scalar_t se = s * (1-dy) * (1-dx);

          scalar_t* corr_ptr = &corr[b][n][0][h1][w1];

          if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
            *(corr_ptr + ix_nw) += nw;

          if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
            *(corr_ptr + ix_ne) += ne;

          if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
            *(corr_ptr + ix_sw) += sw;

          if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
            *(corr_ptr + ix_se) += se;
        }
      }
    }
  }
}


// Backward pass: redistributes the upstream gradient corr_grad onto fmap1 and
// fmap2. fmap1 gradients are accumulated per tile in shared memory and written
// once per channel slice; fmap2 locations can be hit by several tiles, so their
// gradients are accumulated with atomicAdd. coords_grad is never written here.
template <typename scalar_t>
__global__ void corr_backward_kernel(
    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1_grad,
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2_grad,
    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords_grad,
    int r)
{

  const int b = blockIdx.x;
  const int h0 = blockIdx.y * blockDim.x;
  const int w0 = blockIdx.z * blockDim.y;
  const int tid = threadIdx.x * blockDim.y + threadIdx.y;

  const int H1 = fmap1.size(1);
  const int W1 = fmap1.size(2);
  const int H2 = fmap2.size(1);
  const int W2 = fmap2.size(2);
  const int N = coords.size(1);
  const int C = fmap1.size(3);

  __shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
  __shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];

  __shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1];
  __shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1];

  __shared__ scalar_t x2s[BLOCK_HW];
  __shared__ scalar_t y2s[BLOCK_HW];

  for (int c=0; c<C; c+=CHANNEL_STRIDE) {

    // Load the fmap1 tile for this channel slice and clear its gradient buffer.
    for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
      int k1 = k + tid / CHANNEL_STRIDE;
      int h1 = h0 + k1 / BLOCK_W;
      int w1 = w0 + k1 % BLOCK_W;
      int c1 = tid % CHANNEL_STRIDE;

      auto fptr = fmap1[b][h1][w1];
      if (within_bounds(h1, w1, H1, W1))
        f1[c1][k1] = fptr[c+c1];
      else
        f1[c1][k1] = 0.0;

      f1_grad[c1][k1] = 0.0;
    }

    __syncthreads();

    int h1 = h0 + threadIdx.x;
    int w1 = w0 + threadIdx.y;

    for (int n=0; n<N; n++) {
      x2s[tid] = coords[b][n][h1][w1][0];
      y2s[tid] = coords[b][n][h1][w1][1];

      scalar_t dx = x2s[tid] - floor(x2s[tid]);
      scalar_t dy = y2s[tid] - floor(y2s[tid]);

      int rd = 2*r + 1;
      for (int iy=0; iy<rd+1; iy++) {
        for (int ix=0; ix<rd+1; ix++) {
          // Load the matching fmap2 slice and clear its gradient buffer.
          for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
            int k1 = k + tid / CHANNEL_STRIDE;
            int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
            int c2 = tid % CHANNEL_STRIDE;

            auto fptr = fmap2[b][h2][w2];
            if (within_bounds(h2, w2, H2, W2))
              f2[c2][k1] = fptr[c+c2];
            else
              f2[c2][k1] = 0.0;

            f2_grad[c2][k1] = 0.0;
          }

          __syncthreads();

          // Gather the bilinearly weighted upstream gradient for this
          // displacement (mirror of the forward scatter).
          const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1];
          scalar_t g = 0.0;

          int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
          int ix_ne = H1*W1*((iy-1) + rd*ix);
          int ix_sw = H1*W1*(iy + rd*(ix-1));
          int ix_se = H1*W1*(iy + rd*ix);

          if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
            g += *(grad_ptr + ix_nw) * dy * dx;

          if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
            g += *(grad_ptr + ix_ne) * dy * (1-dx);

          if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
            g += *(grad_ptr + ix_sw) * (1-dy) * dx;

          if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
            g += *(grad_ptr + ix_se) * (1-dy) * (1-dx);

          // d(corr)/d(f1) = f2 and d(corr)/d(f2) = f1.
          for (int k=0; k<CHANNEL_STRIDE; k++) {
            f1_grad[k][tid] += g * f2[k][tid];
            f2_grad[k][tid] += g * f1[k][tid];
          }

          // fmap2 locations may be shared across tiles: accumulate atomically.
          for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
            int k1 = k + tid / CHANNEL_STRIDE;
            int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
            int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
            int c2 = tid % CHANNEL_STRIDE;

            scalar_t* fptr = &fmap2_grad[b][h2][w2][0];
            if (within_bounds(h2, w2, H2, W2))
              atomicAdd(fptr+c+c2, f2_grad[c2][k1]);
          }
        }
      }
    }
    __syncthreads();

    // Each fmap1 location belongs to exactly one tile, so a plain
    // read-modify-write is sufficient here.
    for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
      int k1 = k + tid / CHANNEL_STRIDE;
      int h1 = h0 + k1 / BLOCK_W;
      int w1 = w0 + k1 % BLOCK_W;
      int c1 = tid % CHANNEL_STRIDE;

      scalar_t* fptr = &fmap1_grad[b][h1][w1][0];
      if (within_bounds(h1, w1, H1, W1))
        fptr[c+c1] += f1_grad[c1][k1];
    }
  }
}


// Host launcher: allocates the output volume and launches one block per
// (batch, BLOCK_H x BLOCK_W tile). Only float32 inputs are dispatched.
std::vector<torch::Tensor> corr_cuda_forward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  int radius)
{
  const auto B = coords.size(0);
  const auto N = coords.size(1);
  const auto H = coords.size(2);
  const auto W = coords.size(3);

  const auto rd = 2 * radius + 1;
  auto opts = fmap1.options();
  auto corr = torch::zeros({B, N, rd*rd, H, W}, opts);

  const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W);
  const dim3 threads(BLOCK_H, BLOCK_W);

  corr_forward_kernel<float><<<blocks, threads>>>(
    fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
    corr.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
    radius);

  return {corr};
}

// Host launcher for the backward pass. Note that coords_grad is allocated and
// returned but never written by the kernel, so it remains all zeros.
std::vector<torch::Tensor> corr_cuda_backward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  torch::Tensor corr_grad,
  int radius)
{
  const auto B = coords.size(0);
  const auto N = coords.size(1);

  const auto H1 = fmap1.size(1);
  const auto W1 = fmap1.size(2);
  const auto H2 = fmap2.size(1);
  const auto W2 = fmap2.size(2);
  const auto C = fmap1.size(3);

  auto opts = fmap1.options();
  auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts);
  auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts);
  auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts);

  const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W);
  const dim3 threads(BLOCK_H, BLOCK_W);

  corr_backward_kernel<float><<<blocks, threads>>>(
    fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
    corr_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
    fmap1_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    fmap2_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
    coords_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
    radius);

  return {fmap1_grad, fmap2_grad, coords_grad};
}
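This file contains only the device code and the two host launchers; it does not define a Python entry point by itself. A minimal binding sketch is shown below, assuming the file above is compiled together with a small C++ translation unit via torch.utils.cpp_extension.CUDAExtension under a module name such as alt_cuda_corr. The file name correlation.cpp, the wrapper names corr_forward/corr_backward, and the CHECK_* macros are illustrative assumptions, not part of this commit.

// correlation.cpp -- illustrative binding sketch (assumed companion file, not part of this commit).
// Declares the CUDA launchers defined in raft_alt_cuda_corr_correlation_kernel.cu
// and exposes them to Python via pybind11.
#include <torch/extension.h>
#include <vector>

// Implemented in the .cu file above.
std::vector<torch::Tensor> corr_cuda_forward(
    torch::Tensor fmap1, torch::Tensor fmap2, torch::Tensor coords, int radius);

std::vector<torch::Tensor> corr_cuda_backward(
    torch::Tensor fmap1, torch::Tensor fmap2, torch::Tensor coords,
    torch::Tensor corr_grad, int radius);

// Basic input validation before launching the kernels.
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> corr_forward(
    torch::Tensor fmap1, torch::Tensor fmap2, torch::Tensor coords, int radius) {
  CHECK_INPUT(fmap1);
  CHECK_INPUT(fmap2);
  CHECK_INPUT(coords);
  return corr_cuda_forward(fmap1, fmap2, coords, radius);
}

std::vector<torch::Tensor> corr_backward(
    torch::Tensor fmap1, torch::Tensor fmap2, torch::Tensor coords,
    torch::Tensor corr_grad, int radius) {
  CHECK_INPUT(fmap1);
  CHECK_INPUT(fmap2);
  CHECK_INPUT(coords);
  CHECK_INPUT(corr_grad);
  return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &corr_forward, "correlation forward (CUDA)");
  m.def("backward", &corr_backward, "correlation backward (CUDA)");
}

On the Python side, the compiled module's forward and backward would typically be wrapped in a torch.autograd.Function so that the returned fmap1 and fmap2 gradients participate in autograd; the coords gradient returned by corr_cuda_backward is always zero, as noted in the launcher above.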