Initial commit
- layers.py +98 -0
- networks.py +226 -0
- pix2pix.py +212 -0
- pix2pix_disc_ckpt_200.pt +3 -0
- pix2pix_gen_ckpt_200.pt +3 -0
layers.py
ADDED
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class DownsamplingBlock(nn.Module):
    """Defines the Unet downsampling block.

    Consists of a Convolution-BatchNorm-LeakyReLU layer with k filters.
    """
    def __init__(self, c_in, c_out, kernel_size=4, stride=2,
                 padding=1, negative_slope=0.2, use_norm=True):
        """
        Initializes the DownsamplingBlock.

        Args:
            c_in (int): The number of input channels.
            c_out (int): The number of output channels.
            kernel_size (int, optional): The size of the convolving kernel. Default is 4.
            stride (int, optional): Stride of the convolution. Default is 2.
            padding (int, optional): Zero-padding added to both sides of the input. Default is 1.
            negative_slope (float, optional): Negative slope for the LeakyReLU activation. Default is 0.2.
            use_norm (bool, optional): Whether to add a BatchNorm layer after the convolution. Default is True.
        """
        super(DownsamplingBlock, self).__init__()
        block = []
        block += [nn.Conv2d(in_channels=c_in, out_channels=c_out,
                            kernel_size=kernel_size, stride=stride, padding=padding,
                            bias=(not use_norm)  # no need for a bias when a batchnorm layer follows the conv
                            )]
        if use_norm:
            block += [nn.BatchNorm2d(num_features=c_out)]

        block += [nn.LeakyReLU(negative_slope=negative_slope)]

        self.conv_block = nn.Sequential(*block)

    def forward(self, x):
        return self.conv_block(x)


class UpsamplingBlock(nn.Module):
    """Defines the Unet upsampling block."""
    def __init__(self, c_in, c_out, kernel_size=4, stride=2,
                 padding=1, use_dropout=False, use_upsampling=False, mode='nearest'):
        """
        Initializes the UpsamplingBlock.

        Args:
            c_in (int): The number of input channels.
            c_out (int): The number of output channels.
            kernel_size (int, optional): Size of the convolving kernel. Default is 4.
            stride (int, optional): Stride of the convolution. Default is 2.
            padding (int, optional): Zero-padding added to both sides of the input. Default is 1.
            use_dropout (bool, optional): Whether to add a dropout layer. Default is False.
            use_upsampling (bool, optional): Whether to use upsampling followed by a regular
                convolution rather than a transposed convolution. Default is False.
            mode (str, optional): The upsampling algorithm: one of 'nearest',
                'bilinear', 'bicubic'. Default: 'nearest'.
        """
        super(UpsamplingBlock, self).__init__()
        block = []
        if use_upsampling:
            # Transposed convolution can cause checkerboard artifacts. Upsampling
            # followed by a regular convolution apparently produces better results.
            # For further reading: https://distill.pub/2016/deconv-checkerboard/
            # Odena, et al., "Deconvolution and Checkerboard Artifacts", Distill, 2016. http://doi.org/10.23915/distill.00003

            mode = mode if mode in ('nearest', 'bilinear', 'bicubic') else 'nearest'

            block += [nn.Sequential(
                nn.Upsample(scale_factor=2, mode=mode),
                nn.Conv2d(in_channels=c_in, out_channels=c_out,
                          kernel_size=3, stride=1, padding=padding,
                          bias=False
                          )
            )]
        else:
            block += [nn.ConvTranspose2d(in_channels=c_in,
                                         out_channels=c_out,
                                         kernel_size=kernel_size,
                                         stride=stride,
                                         padding=padding, bias=False
                                         )
                      ]

        block += [nn.BatchNorm2d(num_features=c_out)]

        if use_dropout:
            block += [nn.Dropout(0.5)]

        block += [nn.ReLU()]

        self.conv_block = nn.Sequential(*block)

    def forward(self, x):
        return self.conv_block(x)
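A minimal shape-check sketch for the two blocks (the flat import path and the tensor sizes are assumptions for illustration, not part of this commit):

# Shape-check sketch: a stride-2 downsampling block halves H and W,
# a stride-2 upsampling block doubles them back.
import torch
from layers import DownsamplingBlock, UpsamplingBlock  # assumes a flat import path

x = torch.randn(1, 3, 256, 256)      # (N, C, H, W)
down = DownsamplingBlock(3, 64)      # Conv(4, stride 2, pad 1) + BatchNorm + LeakyReLU
up = UpsamplingBlock(64, 3)          # ConvTranspose(4, stride 2, pad 1) + BatchNorm + ReLU

y = down(x)
print(y.shape)       # torch.Size([1, 64, 128, 128])
print(up(y).shape)   # torch.Size([1, 3, 256, 256])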
networks.py
ADDED
@@ -0,0 +1,226 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .layers import DownsamplingBlock, UpsamplingBlock


class UnetEncoder(nn.Module):
    """Creates the Unet Encoder Network.

    C64-C128-C256-C512-C512-C512-C512-C512
    """
    def __init__(self, c_in=3, c_out=512):
        """
        Constructs the Unet Encoder Network.

        Ck denotes a Convolution-BatchNorm-ReLU layer with k filters.
        C64-C128-C256-C512-C512-C512-C512-C512
        Args:
            c_in (int, optional): Number of input channels. Default is 3.
            c_out (int, optional): Number of output channels. Default is 512.
        """
        super(UnetEncoder, self).__init__()
        self.enc1 = DownsamplingBlock(c_in, 64, use_norm=False)  # C64
        self.enc2 = DownsamplingBlock(64, 128)     # C128
        self.enc3 = DownsamplingBlock(128, 256)    # C256
        self.enc4 = DownsamplingBlock(256, 512)    # C512
        self.enc5 = DownsamplingBlock(512, 512)    # C512
        self.enc6 = DownsamplingBlock(512, 512)    # C512
        self.enc7 = DownsamplingBlock(512, 512)    # C512
        self.enc8 = DownsamplingBlock(512, c_out)  # C512

    def forward(self, x):
        x1 = self.enc1(x)
        x2 = self.enc2(x1)
        x3 = self.enc3(x2)
        x4 = self.enc4(x3)
        x5 = self.enc5(x4)
        x6 = self.enc6(x5)
        x7 = self.enc7(x6)
        x8 = self.enc8(x7)
        out = [x8, x7, x6, x5, x4, x3, x2, x1]  # the latest activation is the first element
        return out


class UnetDecoder(nn.Module):
    """Creates the Unet Decoder Network."""
    def __init__(self, c_in=512, c_out=64, use_upsampling=False, mode='nearest'):
        """
        Constructs the Unet Decoder Network.

        Ck denotes a Convolution-BatchNorm-ReLU layer with k filters.

        CDk denotes a Convolution-BatchNorm-Dropout-ReLU layer with a dropout rate of 50%.
        CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128
        Args:
            c_in (int, optional): Number of input channels. Default is 512.
            c_out (int, optional): Number of output channels. Default is 64.
            use_upsampling (bool, optional): Upsampling method for the decoder.
                If True, use an upsampling layer followed by a regular convolution.
                If False, use a transposed convolution. Default is False.
            mode (str, optional): The upsampling algorithm: one of 'nearest',
                'bilinear', 'bicubic'. Default: 'nearest'.
        """
        super(UnetDecoder, self).__init__()
        self.dec1 = UpsamplingBlock(c_in, 512, use_dropout=True, use_upsampling=use_upsampling, mode=mode)  # CD512
        self.dec2 = UpsamplingBlock(1024, 512, use_dropout=True, use_upsampling=use_upsampling, mode=mode)  # CD1024
        self.dec3 = UpsamplingBlock(1024, 512, use_dropout=True, use_upsampling=use_upsampling, mode=mode)  # CD1024
        self.dec4 = UpsamplingBlock(1024, 512, use_upsampling=use_upsampling, mode=mode)   # C1024
        self.dec5 = UpsamplingBlock(1024, 256, use_upsampling=use_upsampling, mode=mode)   # C1024
        self.dec6 = UpsamplingBlock(512, 128, use_upsampling=use_upsampling, mode=mode)    # C512
        self.dec7 = UpsamplingBlock(256, 64, use_upsampling=use_upsampling, mode=mode)     # C256
        self.dec8 = UpsamplingBlock(128, c_out, use_upsampling=use_upsampling, mode=mode)  # C128

    def forward(self, x):
        x9 = torch.cat([x[1], self.dec1(x[0])], 1)   # (N,1024,H,W)
        x10 = torch.cat([x[2], self.dec2(x9)], 1)    # (N,1024,H,W)
        x11 = torch.cat([x[3], self.dec3(x10)], 1)   # (N,1024,H,W)
        x12 = torch.cat([x[4], self.dec4(x11)], 1)   # (N,1024,H,W)
        x13 = torch.cat([x[5], self.dec5(x12)], 1)   # (N,512,H,W)
        x14 = torch.cat([x[6], self.dec6(x13)], 1)   # (N,256,H,W)
        x15 = torch.cat([x[7], self.dec7(x14)], 1)   # (N,128,H,W)
        out = self.dec8(x15)                         # (N,64,H,W)
        return out


class UnetGenerator(nn.Module):
    """Creates a Unet-based generator."""
    def __init__(self, c_in=3, c_out=3, use_upsampling=False, mode='nearest'):
        """
        Constructs a Unet generator.
        Args:
            c_in (int): The number of input channels.
            c_out (int): The number of output channels.
            use_upsampling (bool, optional): Upsampling method for the decoder.
                If True, use an upsampling layer followed by a regular convolution.
                If False, use a transposed convolution. Default is False.
            mode (str, optional): The upsampling algorithm: one of 'nearest',
                'bilinear', 'bicubic'. Default: 'nearest'.
        """
        super(UnetGenerator, self).__init__()
        self.encoder = UnetEncoder(c_in=c_in)
        self.decoder = UnetDecoder(use_upsampling=use_upsampling, mode=mode)
        # In the paper, the authors state:
        # """
        # After the last layer in the decoder, a convolution is applied
        # to map to the number of output channels (3 in general, except
        # in colorization, where it is 2), followed by a Tanh function.
        # """
        # However, in the official Lua implementation, only a Tanh layer is applied.
        # Therefore, I took the liberty of adding a convolutional layer with a
        # kernel size of 3.
        # For more information, please check the paper and the official GitHub repo:
        # https://github.com/phillipi/pix2pix
        # https://arxiv.org/abs/1611.07004
        self.head = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=c_out,
                      kernel_size=3, stride=1, padding=1,
                      bias=True
                      ),
            nn.Tanh()
        )

    def forward(self, x):
        outE = self.encoder(x)
        outD = self.decoder(outE)
        out = self.head(outD)
        return out


class PatchDiscriminator(nn.Module):
    """Creates a PatchGAN discriminator."""
    def __init__(self, c_in=3, c_hid=64, n_layers=3):
        """Constructs a PatchGAN discriminator.

        Args:
            c_in (int, optional): The number of input channels. Defaults to 3.
            c_hid (int, optional): The number of channels after the first conv layer.
                Defaults to 64.
            n_layers (int, optional): The number of convolution blocks in the
                discriminator. Defaults to 3.
        """
        super(PatchDiscriminator, self).__init__()
        model = [DownsamplingBlock(c_in, c_hid, use_norm=False)]

        n_p = 1  # channel multiplier of the previous layer
        n_c = 1  # channel multiplier of the current layer
        # The last block has a stride of 1, therefore iterate (n_layers-1) times
        for n in range(1, n_layers):
            n_p = n_c
            n_c = min(2**n, 8)  # the number of channels is 512 at most

            model += [DownsamplingBlock(c_hid*n_p, c_hid*n_c)]

        n_p = n_c
        n_c = min(2**n_layers, 8)
        model += [DownsamplingBlock(c_hid*n_p, c_hid*n_c, stride=1)]  # the last block has a stride of 1

        # The last layer is a convolution that maps to a single output channel.
        model += [nn.Conv2d(in_channels=c_hid*n_c, out_channels=1,
                            kernel_size=4, stride=1, padding=1, bias=True
                            )]
        # Normally, there would be a sigmoid layer at the end of the discriminator.
        # However, nn.BCEWithLogitsLoss combines the sigmoid layer with BCE loss,
        # providing greater numerical stability. Therefore, the discriminator outputs
        # logits to take advantage of this stability.
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


class PixelDiscriminator(nn.Module):
    """Creates a PixelGAN discriminator (a 1x1 PatchGAN discriminator)."""
    def __init__(self, c_in=3, c_hid=64):
        """Constructs a PixelGAN discriminator, a special form of the PatchGAN discriminator.
        All convolutions are 1x1 spatial filters.

        Args:
            c_in (int, optional): The number of input channels. Defaults to 3.
            c_hid (int, optional): The number of channels after the first conv layer.
                Defaults to 64.
        """
        super(PixelDiscriminator, self).__init__()
        self.model = nn.Sequential(
            DownsamplingBlock(c_in, c_hid, kernel_size=1, stride=1, padding=0, use_norm=False),
            DownsamplingBlock(c_hid, c_hid*2, kernel_size=1, stride=1, padding=0),
            nn.Conv2d(in_channels=c_hid*2, out_channels=1, kernel_size=1)
        )
        # As in PatchDiscriminator, there would normally be a sigmoid layer at the end.
        # However, nn.BCEWithLogitsLoss combines the sigmoid layer with BCE loss,
        # providing greater numerical stability. Therefore, the discriminator outputs
        # logits to take advantage of this stability.

    def forward(self, x):
        return self.model(x)


class PatchGAN(nn.Module):
    """Creates a PatchGAN discriminator."""
    def __init__(self, c_in=3, c_hid=64, mode='patch', n_layers=3):
        """Constructs a PatchGAN discriminator.

        Args:
            c_in (int, optional): The number of input channels. Defaults to 3.
            c_hid (int, optional): The number of channels after the first
                convolutional layer. Defaults to 64.
            mode (str, optional): PatchGAN type. Use 'pixel' for PixelGAN, and
                'patch' for other types. Defaults to 'patch'.
            n_layers (int, optional): Number of PatchGAN layers. Defaults to 3.
                - 16x16 PatchGAN if n=1
                - 34x34 PatchGAN if n=2
                - 70x70 PatchGAN if n=3
                - 142x142 PatchGAN if n=4
                - 286x286 PatchGAN if n=5
                - 574x574 PatchGAN if n=6
        """
        super(PatchGAN, self).__init__()
        if mode == 'pixel':
            self.model = PixelDiscriminator(c_in, c_hid)
        else:
            self.model = PatchDiscriminator(c_in, c_hid, n_layers)

    def forward(self, x):
        return self.model(x)
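A minimal sketch checking the generator and the default 70x70 PatchGAN output shapes (the flat import path, the 256x256 input size, and the batch size of 2 are assumptions for illustration):

# Shape-check sketch for UnetGenerator and the default 70x70 PatchGAN.
import torch
from networks import UnetGenerator, PatchGAN  # assumes a flat import path

x = torch.randn(2, 3, 256, 256)   # batch of 2 so BatchNorm at the 1x1 bottleneck has statistics to compute
G = UnetGenerator(c_in=3, c_out=3)
D = PatchGAN(c_in=6, n_layers=3)  # conditional setup: input and output stacked channel-wise

fake = G(x)
print(fake.shape)                               # torch.Size([2, 3, 256, 256])

logits = D(torch.cat([x, fake], dim=1))
print(logits.shape)                             # torch.Size([2, 1, 30, 30]) -- one logit per 70x70 patch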
pix2pix.py
ADDED
@@ -0,0 +1,212 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .networks import UnetGenerator, PatchGAN


class Pix2Pix(nn.Module):
    """Creates the Pix2Pix class, a model for image-to-image translation tasks.
    By default, the model uses a Unet generator with transposed convolutions and
    a 70x70 PatchGAN discriminator.
    """
    def __init__(self,
                 c_in: int = 3,
                 c_out: int = 3,
                 is_train: bool = True,
                 netD: str = 'patch',
                 lambda_L1: float = 100.0,
                 is_CGAN: bool = True,
                 use_upsampling: bool = False,
                 mode: str = 'nearest',
                 c_hid: int = 64,
                 n_layers: int = 3,
                 lr: float = 0.0002,
                 beta1: float = 0.5,
                 beta2: float = 0.999
                 ):
        """Constructs the Pix2Pix class.

        Args:
            c_in: Number of input channels
            c_out: Number of output channels
            is_train: Whether the model is in training mode
            netD: Type of discriminator ('patch' or 'pixel')
            lambda_L1: Weight for the L1 loss
            is_CGAN: If True, use a conditional GAN architecture
            use_upsampling: If True, use upsampling in the generator instead of transposed convolutions
            mode: Upsampling mode ('nearest', 'bilinear', 'bicubic')
            c_hid: Number of base filters in the discriminator
            n_layers: Number of layers in the discriminator
            lr: Learning rate
            beta1: Beta1 parameter for the Adam optimizer
            beta2: Beta2 parameter for the Adam optimizer
        """
        super(Pix2Pix, self).__init__()
        self.is_CGAN = is_CGAN
        self.lambda_L1 = lambda_L1

        self.gen = UnetGenerator(c_in=c_in, c_out=c_out, use_upsampling=use_upsampling, mode=mode)
        self.gen = self.gen.apply(self.weights_init)

        if is_train:
            # Conditional GANs need both the input and the output together, so the total number of input channels is c_in+c_out
            disc_in = c_in + c_out if is_CGAN else c_out
            self.disc = PatchGAN(c_in=disc_in, c_hid=c_hid, mode=netD, n_layers=n_layers)
            self.disc = self.disc.apply(self.weights_init)

            # Initialize optimizers
            self.gen_optimizer = torch.optim.Adam(
                self.gen.parameters(), lr=lr, betas=(beta1, beta2))
            self.disc_optimizer = torch.optim.Adam(
                self.disc.parameters(), lr=lr, betas=(beta1, beta2))

            # Initialize loss functions
            self.criterion = nn.BCEWithLogitsLoss()
            self.criterion_L1 = nn.L1Loss()

    def forward(self, x):
        return self.gen(x)

    @staticmethod
    def weights_init(m):
        """Initializes network weights.

        Args:
            m: network module
        """
        if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        if isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight, 1.0, 0.02)
            nn.init.constant_(m.bias, 0)

    def _get_disc_inputs(self, real_images, target_images, fake_images):
        """Prepares the discriminator inputs based on the conditional/unconditional setup."""
        if self.is_CGAN:
            # Conditional GANs need both the input and the output together,
            # therefore the total number of input channels is c_in+c_out
            real_AB = torch.cat([real_images, target_images], dim=1)
            fake_AB = torch.cat([real_images,
                                 fake_images.detach()],
                                dim=1)
        else:
            real_AB = target_images
            fake_AB = fake_images.detach()
        return real_AB, fake_AB

    def _get_gen_inputs(self, real_images, fake_images):
        """Prepares the discriminator input for the generator update based on the conditional/unconditional setup."""
        if self.is_CGAN:
            # Conditional GANs need both the input and the output together,
            # therefore the total number of input channels is c_in+c_out
            fake_AB = torch.cat([real_images,
                                 fake_images],
                                dim=1)
        else:
            fake_AB = fake_images
        return fake_AB

    def step_discriminator(self, real_images, target_images, fake_images):
        """Discriminator forward pass and loss computation.

        Args:
            real_images: Input images
            target_images: Ground truth images
            fake_images: Generated images

        Returns:
            Discriminator loss value
        """
        # Prepare inputs
        real_AB, fake_AB = self._get_disc_inputs(real_images, target_images,
                                                 fake_images)

        # Forward pass through the discriminator
        pred_real = self.disc(real_AB)  # D(x, y)
        pred_fake = self.disc(fake_AB)  # D(x, G(x))

        # Compute the losses
        lossD_real = self.criterion(pred_real, torch.ones_like(pred_real))   # (D(x, y), 1)
        lossD_fake = self.criterion(pred_fake, torch.zeros_like(pred_fake))  # (D(x, G(x)), 0)
        lossD = (lossD_real + lossD_fake) * 0.5  # combined loss
        return lossD

    def step_generator(self, real_images, target_images, fake_images):
        """Generator forward pass and loss computation.

        Args:
            real_images: Input images
            target_images: Ground truth images
            fake_images: Generated images

        Returns:
            Generator loss value and its individual components
        """
        # Prepare input
        fake_AB = self._get_gen_inputs(real_images, fake_images)

        # Forward pass through the discriminator
        pred_fake = self.disc(fake_AB)

        # Compute the losses
        lossG_GAN = self.criterion(pred_fake, torch.ones_like(pred_fake))  # GAN loss
        lossG_L1 = self.criterion_L1(fake_images, target_images)           # L1 loss
        lossG = lossG_GAN + self.lambda_L1 * lossG_L1                      # combined loss
        # Return the total loss and the individual components
        return lossG, {
            'loss_G': lossG.item(),
            'loss_G_GAN': lossG_GAN.item(),
            'loss_G_L1': lossG_L1.item()
        }

    def train_step(self, real_images, target_images):
        """Performs a single training step.

        Args:
            real_images: Input images
            target_images: Ground truth images

        Returns:
            Dictionary containing all loss values from this step
        """
        # Forward pass through the generator
        fake_images = self.forward(real_images)

        # Update the discriminator
        self.disc_optimizer.zero_grad()  # Reset the gradients for D
        lossD = self.step_discriminator(real_images, target_images, fake_images)  # Compute the loss
        lossD.backward()
        self.disc_optimizer.step()  # Update D

        # Update the generator
        self.gen_optimizer.zero_grad()  # Reset the gradients for G
        lossG, G_losses = self.step_generator(real_images, target_images, fake_images)  # Compute the loss
        lossG.backward()
        self.gen_optimizer.step()  # Update G

        # Return all losses
        return {
            'loss_D': lossD.item(),
            **G_losses
        }

    def get_current_visuals(self, real_images, target_images):
        """Returns visualization images.

        Args:
            real_images: Input images
            target_images: Ground truth images

        Returns:
            Dictionary containing the input, target, and generated images
        """
        with torch.no_grad():
            fake_images = self.gen(real_images)
        return {
            'real': real_images,
            'fake': fake_images,
            'target': target_images
        }
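A minimal training-step sketch around Pix2Pix.train_step; the random tensors stand in for a real paired dataset, and the flat import path and batch size are assumptions for illustration:

# Training-step sketch using random tensors in place of real (input, ground-truth) pairs.
import torch
from pix2pix import Pix2Pix  # assumes a flat import path

model = Pix2Pix(c_in=3, c_out=3, is_train=True, netD='patch', lambda_L1=100.0)

real_images = torch.randn(2, 3, 256, 256)    # stand-in for the input images
target_images = torch.randn(2, 3, 256, 256)  # stand-in for the ground-truth images

losses = model.train_step(real_images, target_images)
print(losses)  # {'loss_D': ..., 'loss_G': ..., 'loss_G_GAN': ..., 'loss_G_L1': ...}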
pix2pix_disc_ckpt_200.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:914a7a2152fabd46a7bcc7aad3fb3e642cd0432df7151b4f04c9cf792fdc831b
size 11090624
pix2pix_gen_ckpt_200.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:db2fda865233203ba13e7c10220bfbf46fa6a92ecb462fb5ad39bf22bbad15af
size 218246966
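The two .pt files are Git LFS pointers to discriminator and generator weights from epoch 200. A hedged loading sketch, assuming each file stores a plain state_dict for the corresponding network built with the default settings (the checkpoint format is not documented in this commit):

# Loading sketch; assumes plain state_dicts and the default transposed-convolution generator.
import torch
from networks import UnetGenerator, PatchGAN  # assumes a flat import path

gen = UnetGenerator(c_in=3, c_out=3)
disc = PatchGAN(c_in=6, n_layers=3)  # conditional setup: 3 input + 3 output channels

gen.load_state_dict(torch.load('pix2pix_gen_ckpt_200.pt', map_location='cpu'))
disc.load_state_dict(torch.load('pix2pix_disc_ckpt_200.pt', map_location='cpu'))
gen.eval()  # inference mode for generating translations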