CorvaeOboro commited on
Commit
f87aee2
β€’
1 Parent(s): d0a89cf

Upload upfirdn2d.py

Browse files
Files changed (1) hide show
  1. torch_utils/ops/upfirdn2d.py +384 -0
torch_utils/ops/upfirdn2d.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Custom PyTorch ops for efficient resampling of 2D images."""
10
+
11
+ import os
12
+ import warnings
13
+ import numpy as np
14
+ import torch
15
+ import traceback
16
+
17
+ from .. import custom_ops
18
+ from .. import misc
19
+ from . import conv2d_gradfix
20
+
21
+ #----------------------------------------------------------------------------
22
+
23
+ _inited = False
24
+ _plugin = None
25
+
26
+ def _init():
27
+ global _inited, _plugin
28
+ if not _inited:
29
+ sources = ['upfirdn2d.cpp', 'upfirdn2d.cu']
30
+ sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
31
+ try:
32
+ _plugin = custom_ops.get_plugin('upfirdn2d_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
33
+ except:
34
+ warnings.warn('Failed to build CUDA kernels for upfirdn2d. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc())
35
+ return _plugin is not None
36
+
37
+ def _parse_scaling(scaling):
38
+ if isinstance(scaling, int):
39
+ scaling = [scaling, scaling]
40
+ assert isinstance(scaling, (list, tuple))
41
+ assert all(isinstance(x, int) for x in scaling)
42
+ sx, sy = scaling
43
+ assert sx >= 1 and sy >= 1
44
+ return sx, sy
45
+
46
+ def _parse_padding(padding):
47
+ if isinstance(padding, int):
48
+ padding = [padding, padding]
49
+ assert isinstance(padding, (list, tuple))
50
+ assert all(isinstance(x, int) for x in padding)
51
+ if len(padding) == 2:
52
+ padx, pady = padding
53
+ padding = [padx, padx, pady, pady]
54
+ padx0, padx1, pady0, pady1 = padding
55
+ return padx0, padx1, pady0, pady1
56
+
57
+ def _get_filter_size(f):
58
+ if f is None:
59
+ return 1, 1
60
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
61
+ fw = f.shape[-1]
62
+ fh = f.shape[0]
63
+ with misc.suppress_tracer_warnings():
64
+ fw = int(fw)
65
+ fh = int(fh)
66
+ misc.assert_shape(f, [fh, fw][:f.ndim])
67
+ assert fw >= 1 and fh >= 1
68
+ return fw, fh
69
+
70
+ #----------------------------------------------------------------------------
71
+
72
+ def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
73
+ r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
74
+
75
+ Args:
76
+ f: Torch tensor, numpy array, or python list of the shape
77
+ `[filter_height, filter_width]` (non-separable),
78
+ `[filter_taps]` (separable),
79
+ `[]` (impulse), or
80
+ `None` (identity).
81
+ device: Result device (default: cpu).
82
+ normalize: Normalize the filter so that it retains the magnitude
83
+ for constant input signal (DC)? (default: True).
84
+ flip_filter: Flip the filter? (default: False).
85
+ gain: Overall scaling factor for signal magnitude (default: 1).
86
+ separable: Return a separable filter? (default: select automatically).
87
+
88
+ Returns:
89
+ Float32 tensor of the shape
90
+ `[filter_height, filter_width]` (non-separable) or
91
+ `[filter_taps]` (separable).
92
+ """
93
+ # Validate.
94
+ if f is None:
95
+ f = 1
96
+ f = torch.as_tensor(f, dtype=torch.float32)
97
+ assert f.ndim in [0, 1, 2]
98
+ assert f.numel() > 0
99
+ if f.ndim == 0:
100
+ f = f[np.newaxis]
101
+
102
+ # Separable?
103
+ if separable is None:
104
+ separable = (f.ndim == 1 and f.numel() >= 8)
105
+ if f.ndim == 1 and not separable:
106
+ f = f.ger(f)
107
+ assert f.ndim == (1 if separable else 2)
108
+
109
+ # Apply normalize, flip, gain, and device.
110
+ if normalize:
111
+ f /= f.sum()
112
+ if flip_filter:
113
+ f = f.flip(list(range(f.ndim)))
114
+ f = f * (gain ** (f.ndim / 2))
115
+ f = f.to(device=device)
116
+ return f
117
+
118
+ #----------------------------------------------------------------------------
119
+
120
+ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
121
+ r"""Pad, upsample, filter, and downsample a batch of 2D images.
122
+
123
+ Performs the following sequence of operations for each channel:
124
+
125
+ 1. Upsample the image by inserting N-1 zeros after each pixel (`up`).
126
+
127
+ 2. Pad the image with the specified number of zeros on each side (`padding`).
128
+ Negative padding corresponds to cropping the image.
129
+
130
+ 3. Convolve the image with the specified 2D FIR filter (`f`), shrinking it
131
+ so that the footprint of all output pixels lies within the input image.
132
+
133
+ 4. Downsample the image by keeping every Nth pixel (`down`).
134
+
135
+ This sequence of operations bears close resemblance to scipy.signal.upfirdn().
136
+ The fused op is considerably more efficient than performing the same calculation
137
+ using standard PyTorch ops. It supports gradients of arbitrary order.
138
+
139
+ Args:
140
+ x: Float32/float64/float16 input tensor of the shape
141
+ `[batch_size, num_channels, in_height, in_width]`.
142
+ f: Float32 FIR filter of the shape
143
+ `[filter_height, filter_width]` (non-separable),
144
+ `[filter_taps]` (separable), or
145
+ `None` (identity).
146
+ up: Integer upsampling factor. Can be a single int or a list/tuple
147
+ `[x, y]` (default: 1).
148
+ down: Integer downsampling factor. Can be a single int or a list/tuple
149
+ `[x, y]` (default: 1).
150
+ padding: Padding with respect to the upsampled image. Can be a single number
151
+ or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
152
+ (default: 0).
153
+ flip_filter: False = convolution, True = correlation (default: False).
154
+ gain: Overall scaling factor for signal magnitude (default: 1).
155
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
156
+
157
+ Returns:
158
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
159
+ """
160
+ assert isinstance(x, torch.Tensor)
161
+ assert impl in ['ref', 'cuda']
162
+ if impl == 'cuda' and x.device.type == 'cuda' and _init():
163
+ return _upfirdn2d_cuda(up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain).apply(x, f)
164
+ return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
165
+
166
+ #----------------------------------------------------------------------------
167
+
168
+ @misc.profiled_function
169
+ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
170
+ """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
171
+ """
172
+ # Validate arguments.
173
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
174
+ if f is None:
175
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
176
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
177
+ assert f.dtype == torch.float32 and not f.requires_grad
178
+ batch_size, num_channels, in_height, in_width = x.shape
179
+ upx, upy = _parse_scaling(up)
180
+ downx, downy = _parse_scaling(down)
181
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
182
+
183
+ # Upsample by inserting zeros.
184
+ x = x.reshape([batch_size, num_channels, in_height, 1, in_width, 1])
185
+ x = torch.nn.functional.pad(x, [0, upx - 1, 0, 0, 0, upy - 1])
186
+ x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
187
+
188
+ # Pad or crop.
189
+ x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
190
+ x = x[:, :, max(-pady0, 0) : x.shape[2] - max(-pady1, 0), max(-padx0, 0) : x.shape[3] - max(-padx1, 0)]
191
+
192
+ # Setup filter.
193
+ f = f * (gain ** (f.ndim / 2))
194
+ f = f.to(x.dtype)
195
+ if not flip_filter:
196
+ f = f.flip(list(range(f.ndim)))
197
+
198
+ # Convolve with the filter.
199
+ f = f[np.newaxis, np.newaxis].repeat([num_channels, 1] + [1] * f.ndim)
200
+ if f.ndim == 4:
201
+ x = conv2d_gradfix.conv2d(input=x, weight=f, groups=num_channels)
202
+ else:
203
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(2), groups=num_channels)
204
+ x = conv2d_gradfix.conv2d(input=x, weight=f.unsqueeze(3), groups=num_channels)
205
+
206
+ # Downsample by throwing away pixels.
207
+ x = x[:, :, ::downy, ::downx]
208
+ return x
209
+
210
+ #----------------------------------------------------------------------------
211
+
212
+ _upfirdn2d_cuda_cache = dict()
213
+
214
+ def _upfirdn2d_cuda(up=1, down=1, padding=0, flip_filter=False, gain=1):
215
+ """Fast CUDA implementation of `upfirdn2d()` using custom ops.
216
+ """
217
+ # Parse arguments.
218
+ upx, upy = _parse_scaling(up)
219
+ downx, downy = _parse_scaling(down)
220
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
221
+
222
+ # Lookup from cache.
223
+ key = (upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
224
+ if key in _upfirdn2d_cuda_cache:
225
+ return _upfirdn2d_cuda_cache[key]
226
+
227
+ # Forward op.
228
+ class Upfirdn2dCuda(torch.autograd.Function):
229
+ @staticmethod
230
+ def forward(ctx, x, f): # pylint: disable=arguments-differ
231
+ assert isinstance(x, torch.Tensor) and x.ndim == 4
232
+ if f is None:
233
+ f = torch.ones([1, 1], dtype=torch.float32, device=x.device)
234
+ assert isinstance(f, torch.Tensor) and f.ndim in [1, 2]
235
+ y = x
236
+ if f.ndim == 2:
237
+ y = _plugin.upfirdn2d(y, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip_filter, gain)
238
+ else:
239
+ y = _plugin.upfirdn2d(y, f.unsqueeze(0), upx, 1, downx, 1, padx0, padx1, 0, 0, flip_filter, np.sqrt(gain))
240
+ y = _plugin.upfirdn2d(y, f.unsqueeze(1), 1, upy, 1, downy, 0, 0, pady0, pady1, flip_filter, np.sqrt(gain))
241
+ ctx.save_for_backward(f)
242
+ ctx.x_shape = x.shape
243
+ return y
244
+
245
+ @staticmethod
246
+ def backward(ctx, dy): # pylint: disable=arguments-differ
247
+ f, = ctx.saved_tensors
248
+ _, _, ih, iw = ctx.x_shape
249
+ _, _, oh, ow = dy.shape
250
+ fw, fh = _get_filter_size(f)
251
+ p = [
252
+ fw - padx0 - 1,
253
+ iw * upx - ow * downx + padx0 - upx + 1,
254
+ fh - pady0 - 1,
255
+ ih * upy - oh * downy + pady0 - upy + 1,
256
+ ]
257
+ dx = None
258
+ df = None
259
+
260
+ if ctx.needs_input_grad[0]:
261
+ dx = _upfirdn2d_cuda(up=down, down=up, padding=p, flip_filter=(not flip_filter), gain=gain).apply(dy, f)
262
+
263
+ assert not ctx.needs_input_grad[1]
264
+ return dx, df
265
+
266
+ # Add to cache.
267
+ _upfirdn2d_cuda_cache[key] = Upfirdn2dCuda
268
+ return Upfirdn2dCuda
269
+
270
+ #----------------------------------------------------------------------------
271
+
272
+ def filter2d(x, f, padding=0, flip_filter=False, gain=1, impl='cuda'):
273
+ r"""Filter a batch of 2D images using the given 2D FIR filter.
274
+
275
+ By default, the result is padded so that its shape matches the input.
276
+ User-specified padding is applied on top of that, with negative values
277
+ indicating cropping. Pixels outside the image are assumed to be zero.
278
+
279
+ Args:
280
+ x: Float32/float64/float16 input tensor of the shape
281
+ `[batch_size, num_channels, in_height, in_width]`.
282
+ f: Float32 FIR filter of the shape
283
+ `[filter_height, filter_width]` (non-separable),
284
+ `[filter_taps]` (separable), or
285
+ `None` (identity).
286
+ padding: Padding with respect to the output. Can be a single number or a
287
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
288
+ (default: 0).
289
+ flip_filter: False = convolution, True = correlation (default: False).
290
+ gain: Overall scaling factor for signal magnitude (default: 1).
291
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
292
+
293
+ Returns:
294
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
295
+ """
296
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
297
+ fw, fh = _get_filter_size(f)
298
+ p = [
299
+ padx0 + fw // 2,
300
+ padx1 + (fw - 1) // 2,
301
+ pady0 + fh // 2,
302
+ pady1 + (fh - 1) // 2,
303
+ ]
304
+ return upfirdn2d(x, f, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
305
+
306
+ #----------------------------------------------------------------------------
307
+
308
+ def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
309
+ r"""Upsample a batch of 2D images using the given 2D FIR filter.
310
+
311
+ By default, the result is padded so that its shape is a multiple of the input.
312
+ User-specified padding is applied on top of that, with negative values
313
+ indicating cropping. Pixels outside the image are assumed to be zero.
314
+
315
+ Args:
316
+ x: Float32/float64/float16 input tensor of the shape
317
+ `[batch_size, num_channels, in_height, in_width]`.
318
+ f: Float32 FIR filter of the shape
319
+ `[filter_height, filter_width]` (non-separable),
320
+ `[filter_taps]` (separable), or
321
+ `None` (identity).
322
+ up: Integer upsampling factor. Can be a single int or a list/tuple
323
+ `[x, y]` (default: 1).
324
+ padding: Padding with respect to the output. Can be a single number or a
325
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
326
+ (default: 0).
327
+ flip_filter: False = convolution, True = correlation (default: False).
328
+ gain: Overall scaling factor for signal magnitude (default: 1).
329
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
330
+
331
+ Returns:
332
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
333
+ """
334
+ upx, upy = _parse_scaling(up)
335
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
336
+ fw, fh = _get_filter_size(f)
337
+ p = [
338
+ padx0 + (fw + upx - 1) // 2,
339
+ padx1 + (fw - upx) // 2,
340
+ pady0 + (fh + upy - 1) // 2,
341
+ pady1 + (fh - upy) // 2,
342
+ ]
343
+ return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain*upx*upy, impl=impl)
344
+
345
+ #----------------------------------------------------------------------------
346
+
347
+ def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
348
+ r"""Downsample a batch of 2D images using the given 2D FIR filter.
349
+
350
+ By default, the result is padded so that its shape is a fraction of the input.
351
+ User-specified padding is applied on top of that, with negative values
352
+ indicating cropping. Pixels outside the image are assumed to be zero.
353
+
354
+ Args:
355
+ x: Float32/float64/float16 input tensor of the shape
356
+ `[batch_size, num_channels, in_height, in_width]`.
357
+ f: Float32 FIR filter of the shape
358
+ `[filter_height, filter_width]` (non-separable),
359
+ `[filter_taps]` (separable), or
360
+ `None` (identity).
361
+ down: Integer downsampling factor. Can be a single int or a list/tuple
362
+ `[x, y]` (default: 1).
363
+ padding: Padding with respect to the input. Can be a single number or a
364
+ list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]`
365
+ (default: 0).
366
+ flip_filter: False = convolution, True = correlation (default: False).
367
+ gain: Overall scaling factor for signal magnitude (default: 1).
368
+ impl: Implementation to use. Can be `'ref'` or `'cuda'` (default: `'cuda'`).
369
+
370
+ Returns:
371
+ Tensor of the shape `[batch_size, num_channels, out_height, out_width]`.
372
+ """
373
+ downx, downy = _parse_scaling(down)
374
+ padx0, padx1, pady0, pady1 = _parse_padding(padding)
375
+ fw, fh = _get_filter_size(f)
376
+ p = [
377
+ padx0 + (fw - downx + 1) // 2,
378
+ padx1 + (fw - downx) // 2,
379
+ pady0 + (fh - downy + 1) // 2,
380
+ pady1 + (fh - downy) // 2,
381
+ ]
382
+ return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
383
+
384
+ #----------------------------------------------------------------------------