commit a1b524b
diff --git a/models/stylegan2/op/fused_act.py b/models/stylegan2/op/fused_act.py
index 973a84f..6854b97 100644
--- a/models/stylegan2/op/fused_act.py
+++ b/models/stylegan2/op/fused_act.py
@@ -2,17 +2,18 @@ import os
 
 import torch
 from torch import nn
+from torch.nn import functional as F
 from torch.autograd import Function
 from torch.utils.cpp_extension import load
 
-module_path = os.path.dirname(__file__)
-fused = load(
-    'fused',
-    sources=[
-        os.path.join(module_path, 'fused_bias_act.cpp'),
-        os.path.join(module_path, 'fused_bias_act_kernel.cu'),
-    ],
-)
+#module_path = os.path.dirname(__file__)
+#fused = load(
+#    'fused',
+#    sources=[
+#        os.path.join(module_path, 'fused_bias_act.cpp'),
+#        os.path.join(module_path, 'fused_bias_act_kernel.cu'),
+#    ],
+#)
 
 
 class FusedLeakyReLUFunctionBackward(Function):
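Note: commenting out load() is the simplest way to make the import succeed on machines without a CUDA toolchain, but it also disables the fused kernel everywhere. A gentler variant (a sketch of an alternative, not what this patch does) would attempt the build and fall back only on failure:

    import os
    import torch
    from torch.utils.cpp_extension import load

    fused = None
    if torch.cuda.is_available():
        try:
            module_path = os.path.dirname(__file__)
            fused = load(
                'fused',
                sources=[
                    os.path.join(module_path, 'fused_bias_act.cpp'),
                    os.path.join(module_path, 'fused_bias_act_kernel.cu'),
                ],
            )
        except Exception:
            fused = None  # no compiler or CUDA headers; use the CPU path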
@@ -82,4 +83,18 @@ class FusedLeakyReLU(nn.Module):
 
 
 def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
-    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
+    if input.device.type == "cpu":
+        if bias is not None:
+            rest_dim = [1] * (input.ndim - bias.ndim - 1)
+            return (
+                F.leaky_relu(
+                    input + bias.view(1, bias.shape[0], *rest_dim),
+                    negative_slope=negative_slope,
+                )
+                * scale
+            )
+
+        else:
+            return F.leaky_relu(input, negative_slope=negative_slope) * scale
+
+    else:
+        return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
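For reference, the CPU branch reproduces what the fused CUDA op computes: add the per-channel bias (broadcast across the trailing spatial dims via rest_dim), apply leaky ReLU, then multiply by scale (sqrt(2) by default). A minimal self-check of that formula, with illustrative shapes:

    import torch
    from torch.nn import functional as F

    x = torch.randn(2, 8, 16, 16)    # NCHW activations
    bias = torch.randn(8)            # one bias per channel
    scale = 2 ** 0.5

    rest_dim = [1] * (x.ndim - bias.ndim - 1)  # -> [1, 1]
    out = F.leaky_relu(x + bias.view(1, 8, *rest_dim), negative_slope=0.2) * scale
    assert out.shape == x.shape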
diff --git a/models/stylegan2/op/upfirdn2d.py b/models/stylegan2/op/upfirdn2d.py
index 7bc5a1e..5465d1a 100644
--- a/models/stylegan2/op/upfirdn2d.py
+++ b/models/stylegan2/op/upfirdn2d.py
@@ -1,17 +1,18 @@
 import os
 
 import torch
+from torch.nn import functional as F
 from torch.autograd import Function
 from torch.utils.cpp_extension import load
 
-module_path = os.path.dirname(__file__)
-upfirdn2d_op = load(
-    'upfirdn2d',
-    sources=[
-        os.path.join(module_path, 'upfirdn2d.cpp'),
-        os.path.join(module_path, 'upfirdn2d_kernel.cu'),
-    ],
-)
+#module_path = os.path.dirname(__file__)
+#upfirdn2d_op = load(
+#    'upfirdn2d',
+#    sources=[
+#        os.path.join(module_path, 'upfirdn2d.cpp'),
+#        os.path.join(module_path, 'upfirdn2d_kernel.cu'),
+#    ],
+#)
 
 
 class UpFirDn2dBackward(Function):
@@ -97,8 +98,8 @@ class UpFirDn2d(Function):
 
         ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
 
-        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
-        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
+        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
         ctx.out_size = (out_h, out_w)
 
         ctx.up = (up_x, up_y)
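Both forms of the output-size formula are equivalent under Python's floor division, since (a + d) // d == a // d + 1 for any integer a and d > 0. For example, with in_h = 4, up_y = 2, pad_y0 = pad_y1 = 1, kernel_h = 3, down_y = 2, the old form gives (8 + 2 - 3) // 2 + 1 = 3 + 1 = 4 and the new form gives (8 + 2 - 3 + 2) // 2 = 9 // 2 = 4. The rewrite only aligns this code with the same expression used in upfirdn2d_native below.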
@@ -140,9 +141,13 @@
 
 
 def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
-    out = UpFirDn2d.apply(
-        input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
-    )
+    if input.device.type == "cpu":
+        out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
+
+    else:
+        out = UpFirDn2d.apply(
+            input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
+        )
 
     return out
@@ -150,6 +155,9 @@ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
 def upfirdn2d_native(
     input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
 ):
+    _, channel, in_h, in_w = input.shape
+    input = input.reshape(-1, in_h, in_w, 1)
+
     _, in_h, in_w, minor = input.shape
     kernel_h, kernel_w = kernel.shape
@@ -180,5 +188,9 @@ def upfirdn2d_native(
in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
)
out = out.permute(0, 2, 3, 1)
+ out = out[:, ::down_y, ::down_x, :]
+
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
- return out[:, ::down_y, ::down_x, :]
+ return out.view(-1, channel, out_h, out_w)
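Taken together, the native-path changes let the HWC-style reference implementation accept and return the NCHW tensors the rest of StyleGAN2 uses: channels are folded into the batch dimension on entry, the downsampling stride and output size are applied explicitly, and the result is viewed back to NCHW. A shape walk-through with illustrative sizes, assuming the patched module is importable as models.stylegan2.op.upfirdn2d:

    import torch
    from models.stylegan2.op.upfirdn2d import upfirdn2d

    x = torch.randn(1, 3, 8, 8)                    # NCHW input on CPU
    k = torch.ones(4, 4) / 16                      # 4x4 box filter
    y = upfirdn2d(x, k, up=2, down=1, pad=(1, 2))  # CPU tensor -> native path

    # Inside upfirdn2d_native:
    #   (1, 3, 8, 8) --reshape--> (3, 8, 8, 1)     fold channels into batch
    #   upsample by 2, pad by (1, 2), filter, stride by 1
    #   out_h = out_w = (8*2 + 1 + 2 - 4 + 1) // 1 = 16
    #   (3, 16, 16, 1) --view--> (1, 3, 16, 16)    restore NCHW
    print(y.shape)  # torch.Size([1, 3, 16, 16])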