yuulind commited on
Commit
a664a45
1 Parent(s): 53ce89b

Initial commit

Browse files
Files changed (5) hide show
  1. layers.py +98 -0
  2. networks.py +226 -0
  3. pix2pix.py +212 -0
  4. pix2pix_disc_ckpt_200.pt +3 -0
  5. pix2pix_gen_ckpt_200.pt +3 -0
layers.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class DownsamplingBlock(nn.Module):
7
+ """Defines the Unet downsampling block.
8
+
9
+ Consists of Convolution-BatchNorm-ReLU layer with k filters.
10
+ """
11
+ def __init__(self, c_in, c_out, kernel_size=4, stride=2,
12
+ padding=1, negative_slope=0.2, use_norm=True):
13
+ """
14
+ Initializes the UnetDownsamplingBlock.
15
+
16
+ Args:
17
+ c_in (int): The number of input channels.
18
+ c_out (int): The number of output channels.
19
+ kernel_size (int, optional): The size of the convolving kernel. Default is 4.
20
+ stride (int, optional): Stride of the convolution. Default is 2.
21
+ padding (int, optional): Zero-padding added to both sides of the input. Default is 0.
22
+ negative_slope (float, optional): Negative slope for the LeakyReLU activation function. Default is 0.2.
23
+ use_norm (bool, optinal): If use norm layer. If True add a BatchNorm layer after Conv. Default is True.
24
+ """
25
+ super(DownsamplingBlock, self).__init__()
26
+ block = []
27
+ block += [nn.Conv2d(in_channels=c_in, out_channels=c_out,
28
+ kernel_size=kernel_size, stride=stride, padding=padding,
29
+ bias=(not use_norm) # No need to use a bias if there is a batchnorm layer after conv
30
+ )]
31
+ if use_norm:
32
+ block += [nn.BatchNorm2d(num_features=c_out)]
33
+
34
+ block += [nn.LeakyReLU(negative_slope=negative_slope)]
35
+
36
+ self.conv_block = nn.Sequential(*block)
37
+
38
+ def forward(self, x):
39
+ return self.conv_block(x)
40
+
41
+
42
+ class UpsamplingBlock(nn.Module):
43
+ """Defines the Unet upsampling block.
44
+ """
45
+ def __init__(self, c_in, c_out, kernel_size=4, stride=2,
46
+ padding=1, use_dropout=False, use_upsampling=False, mode='nearest'):
47
+
48
+ """
49
+ Initializes the Unet Upsampling Block.
50
+
51
+ Args:
52
+ c_in (int): The number of input channels.
53
+ c_out (int): The number of output channels.
54
+ kernel_size (int, optional): Size of the convolving kernel. Default is 4.
55
+ stride (int, optional): Stride of the convolution. Default is 2.
56
+ padding (int, optional): Zero-padding added to both sides of the input. Default is 0.
57
+ use_dropout (bool, optional): if use dropout layers. Default is False.
58
+ upsample (bool, optinal): if use upsampling rather than transpose convolution. Default is False.
59
+ mode (str, optional): the upsampling algorithm: one of 'nearest',
60
+ 'bilinear', 'bicubic'. Default: 'nearest'
61
+ """
62
+ super(UpsamplingBlock, self).__init__()
63
+ block = []
64
+ if use_upsampling:
65
+ # Transpose convolution causes checkerboard artifacts. Upsampling
66
+ # followed by a regular convolutions produces better results appearantly
67
+ # Please check for further reading: https://distill.pub/2016/deconv-checkerboard/
68
+ # Odena, et al., "Deconvolution and Checkerboard Artifacts", Distill, 2016. http://doi.org/10.23915/distill.00003
69
+
70
+ mode = mode if mode in ('nearest', 'bilinear', 'bicubic') else 'nearest'
71
+
72
+ block += [nn.Sequential(
73
+ nn.Upsample(scale_factor=2, mode=mode),
74
+ nn.Conv2d(in_channels=c_in, out_channels=c_out,
75
+ kernel_size=3, stride=1, padding=padding,
76
+ bias=False
77
+ )
78
+ )]
79
+ else:
80
+ block += [nn.ConvTranspose2d(in_channels=c_in,
81
+ out_channels=c_out,
82
+ kernel_size=kernel_size,
83
+ stride=stride,
84
+ padding=padding, bias=False
85
+ )
86
+ ]
87
+
88
+ block += [nn.BatchNorm2d(num_features=c_out)]
89
+
90
+ if use_dropout:
91
+ block += [nn.Dropout(0.5)]
92
+
93
+ block += [nn.ReLU()]
94
+
95
+ self.conv_block = nn.Sequential(*block)
96
+
97
+ def forward(self, x):
98
+ return self.conv_block(x)
networks.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .layers import DownsamplingBlock, UpsamplingBlock
6
+
7
+ class UnetEncoder(nn.Module):
8
+ """Create the Unet Encoder Network.
9
+
10
+ C64-C128-C256-C512-C512-C512-C512-C512
11
+ """
12
+ def __init__(self, c_in=3, c_out=512):
13
+ """
14
+ Constructs the Unet Encoder Network.
15
+
16
+ Ck denote a Convolution-BatchNorm-ReLU layer with k filters.
17
+ C64-C128-C256-C512-C512-C512-C512-C512
18
+ Args:
19
+ c_in (int, optional): Number of input channels.
20
+ c_out (int, optional): Number of output channels. Default is 512.
21
+ """
22
+ super(UnetEncoder, self).__init__()
23
+ self.enc1 = DownsamplingBlock(c_in, 64, use_norm=False) # C64
24
+ self.enc2 = DownsamplingBlock(64, 128) # C128
25
+ self.enc3 = DownsamplingBlock(128, 256) # C256
26
+ self.enc4 = DownsamplingBlock(256, 512) # C512
27
+ self.enc5 = DownsamplingBlock(512, 512) # C512
28
+ self.enc6 = DownsamplingBlock(512, 512) # C512
29
+ self.enc7 = DownsamplingBlock(512, 512) # C512
30
+ self.enc8 = DownsamplingBlock(512, c_out) # C512
31
+
32
+ def forward(self, x):
33
+ x1 = self.enc1(x)
34
+ x2 = self.enc2(x1)
35
+ x3 = self.enc3(x2)
36
+ x4 = self.enc4(x3)
37
+ x5 = self.enc5(x4)
38
+ x6 = self.enc6(x5)
39
+ x7 = self.enc7(x6)
40
+ x8 = self.enc8(x7)
41
+ out = [x8, x7, x6, x5, x4, x3, x2, x1] # latest activation is the first element
42
+ return out
43
+
44
+
45
+ class UnetDecoder(nn.Module):
46
+ """Creates the Unet Decoder Network.
47
+ """
48
+ def __init__(self, c_in=512, c_out=64, use_upsampling=False, mode='nearest'):
49
+ """
50
+ Constructs the Unet Decoder Network.
51
+
52
+ Ck denote a Convolution-BatchNorm-ReLU layer with k filters.
53
+
54
+ CDk denotes a Convolution-BatchNorm-Dropout-ReLU layer with a dropout rate of 50%.
55
+ CD512-CD1024-CD1024-C1024-C1024-C512-C256-C128
56
+ Args:
57
+ c_in (int): Number of input channels.
58
+ c_out (int, optional): Number of output channels. Default is 512.
59
+ use_upsampling (bool, optional): Upsampling method for decoder.
60
+ If True, use upsampling layer followed regular convolution layer.
61
+ If False, use transpose convolution. Default is False
62
+ mode (str, optional): the upsampling algorithm: one of 'nearest',
63
+ 'bilinear', 'bicubic'. Default: 'nearest'
64
+ """
65
+ super(UnetDecoder, self).__init__()
66
+ self.dec1 = UpsamplingBlock(c_in, 512, use_dropout=True, use_upsampling=use_upsampling, mode=mode) # CD512
67
+ self.dec2 = UpsamplingBlock(1024, 512, use_dropout=True, use_upsampling=use_upsampling, mode=mode) # CD1024
68
+ self.dec3 = UpsamplingBlock(1024, 512, use_dropout=True, use_upsampling=use_upsampling, mode=mode) # CD1024
69
+ self.dec4 = UpsamplingBlock(1024, 512, use_upsampling=use_upsampling, mode=mode) # C1024
70
+ self.dec5 = UpsamplingBlock(1024, 256, use_upsampling=use_upsampling, mode=mode) # C1024
71
+ self.dec6 = UpsamplingBlock(512, 128, use_upsampling=use_upsampling, mode=mode) # C512
72
+ self.dec7 = UpsamplingBlock(256, 64, use_upsampling=use_upsampling, mode=mode) # C256
73
+ self.dec8 = UpsamplingBlock(128, c_out, use_upsampling=use_upsampling, mode=mode) # C128
74
+
75
+
76
+ def forward(self, x):
77
+ x9 = torch.cat([x[1], self.dec1(x[0])], 1) # (N,1024,H,W)
78
+ x10 = torch.cat([x[2], self.dec2(x9)], 1) # (N,1024,H,W)
79
+ x11 = torch.cat([x[3], self.dec3(x10)], 1) # (N,1024,H,W)
80
+ x12 = torch.cat([x[4], self.dec4(x11)], 1) # (N,1024,H,W)
81
+ x13 = torch.cat([x[5], self.dec5(x12)], 1) # (N,512,H,W)
82
+ x14 = torch.cat([x[6], self.dec6(x13)], 1) # (N,256,H,W)
83
+ x15 = torch.cat([x[7], self.dec7(x14)], 1) # (N,128,H,W)
84
+ out = self.dec8(x15) # (N,64,H,W)
85
+ return out
86
+
87
+
88
+ class UnetGenerator(nn.Module):
89
+ """Create a Unet-based generator"""
90
+ def __init__(self, c_in=3, c_out=3, use_upsampling=False, mode='nearest'):
91
+ """
92
+ Constructs a Unet generator
93
+ Args:
94
+ c_in (int): The number of input channels.
95
+ c_out (int): The number of output channels.
96
+ use_upsampling (bool, optional): Upsampling method for decoder.
97
+ If True, use upsampling layer followed regular convolution layer.
98
+ If False, use transpose convolution. Default is False
99
+ mode (str, optional): the upsampling algorithm: one of 'nearest',
100
+ 'bilinear', 'bicubic'. Default: 'nearest'
101
+ """
102
+ super(UnetGenerator, self).__init__()
103
+ self.encoder = UnetEncoder(c_in=c_in)
104
+ self.decoder = UnetDecoder(use_upsampling=use_upsampling, mode=mode)
105
+ # In the paper, the authors state:
106
+ # """
107
+ # After the last layer in the decoder, a convolution is applied
108
+ # to map to the number of output channels (3 in general, except
109
+ # in colorization, where it is 2), followed by a Tanh function.
110
+ # """
111
+ # However, in the official Lua implementation, only a Tanh layer is applied.
112
+ # Therefore, I took the liberty of adding a convolutional layer with a
113
+ # kernel size of 3.
114
+ # For more information please check the paper and official github repo:
115
+ # https://github.com/phillipi/pix2pix
116
+ # https://arxiv.org/abs/1611.07004
117
+ self.head = nn.Sequential(
118
+ nn.Conv2d(in_channels=64, out_channels=c_out,
119
+ kernel_size=3, stride=1, padding=1,
120
+ bias=True
121
+ ),
122
+ nn.Tanh()
123
+ )
124
+
125
+ def forward(self, x):
126
+ outE = self.encoder(x)
127
+ outD = self.decoder(outE)
128
+ out = self.head(outD)
129
+ return out
130
+
131
+
132
+ class PatchDiscriminator(nn.Module):
133
+ """Create a PatchGAN discriminator"""
134
+ def __init__(self, c_in=3, c_hid=64, n_layers=3):
135
+ """Constructs a PatchGAN discriminator
136
+
137
+ Args:
138
+ c_in (int, optional): The number of input channels. Defaults to 3.
139
+ c_hid (int, optional): The number of channels after first conv layer.
140
+ Defaults to 64.
141
+ n_layers (int, optional): the number of convolution blocks in the
142
+ discriminator. Defaults to 3.
143
+ """
144
+ super(PatchDiscriminator, self).__init__()
145
+ model = [DownsamplingBlock(c_in, c_hid, use_norm=False)]
146
+
147
+ n_p = 1 # multiplier for previous channel
148
+ n_c = 1 # multiplier for current channel
149
+ # last block is with stride of 1, therefore iterate (n_layers-1) times
150
+ for n in range(1, n_layers):
151
+ n_p = n_c
152
+ n_c = min(2**n, 8) # The number of channels is 512 at most
153
+
154
+ model += [DownsamplingBlock(c_hid*n_p, c_hid*n_c)]
155
+
156
+ n_p = n_c
157
+ n_c = min(2**n_layers, 8)
158
+ model += [DownsamplingBlock(c_hid*n_p, c_hid*n_c, stride=1)] # last block is with stride of 1
159
+
160
+ # last layer is a convolution followed by a Sigmoid function.
161
+ model += [nn.Conv2d(in_channels=c_hid*n_c, out_channels=1,
162
+ kernel_size=4, stride=1, padding=1, bias=True
163
+ )]
164
+ # Normally, there should be a sigmoid layer at the end of discriminator.
165
+ # However, nn.BCEWithLogitsLoss combines the sigmoid layer with BCE loss,
166
+ # providing greater numerical stability. Therefore, the discriminator outputs
167
+ # logits to take advantage of this stability.
168
+ self.model = nn.Sequential(*model)
169
+
170
+ def forward(self, x):
171
+ return self.model(x)
172
+
173
+
174
+ class PixelDiscriminator(nn.Module):
175
+ """Create a PixelGAN discriminator (1x1 PatchGAN discriminator)"""
176
+ def __init__(self, c_in=3, c_hid=64):
177
+ """Constructs a PixelGAN discriminator, a special form of PatchGAN Discriminator.
178
+ All convolutions are 1x1 spatial filters
179
+
180
+ Args:
181
+ c_in (int, optional): The number of input channels. Defaults to 3.
182
+ c_hid (int, optional): The number of channels after first conv layer.
183
+ Defaults to 64.
184
+ """
185
+ super(PixelDiscriminator, self).__init__()
186
+ self.model = nn.Sequential(
187
+ DownsamplingBlock(c_in, c_hid, kernel_size=1, stride=1, padding=0, use_norm=False),
188
+ DownsamplingBlock(c_hid, c_hid*2, kernel_size=1, stride=1, padding=0),
189
+ nn.Conv2d(in_channels=c_hid*2, out_channels=1, kernel_size=1)
190
+ )
191
+ # Similar to PatchDiscriminator, there should be a sigmoid layer at the end of discriminator.
192
+ # However, nn.BCEWithLogitsLoss combines the sigmoid layer with BCE loss,
193
+ # providing greater numerical stability. Therefore, the discriminator outputs
194
+ # logits to take advantage of this stability.
195
+
196
+ def forward(self, x):
197
+ return self.model(x)
198
+
199
+
200
+ class PatchGAN(nn.Module):
201
+ """Create a PatchGAN discriminator"""
202
+ def __init__(self, c_in=3, c_hid=64, mode='patch', n_layers=3):
203
+ """Constructs a PatchGAN discriminator.
204
+
205
+ Args:
206
+ c_in (int, optional): The number of input channels. Defaults to 3.
207
+ c_hid (int, optional): The number of channels after first
208
+ convolutional layer. Defaults to 64.
209
+ mode (str, optional): PatchGAN type. Use 'pixel' for PixelGAN, and
210
+ 'patch' for other types. Defaults to 'patch'.
211
+ n_layers (int, optional): PatchGAN number of layers. Defaults to 3.
212
+ - 16x16 PatchGAN if n=1
213
+ - 34x34 PatchGAN if n=2
214
+ - 70x70 PatchGAN if n=3
215
+ - 142x142 PatchGAN if n=4
216
+ - 286x286 PatchGAN if n=5
217
+ - 574x574 PatchGAN if n=6
218
+ """
219
+ super(PatchGAN, self).__init__()
220
+ if mode == 'pixel':
221
+ self.model = PixelDiscriminator(c_in, c_hid)
222
+ else:
223
+ self.model = PatchDiscriminator(c_in, c_hid, n_layers)
224
+
225
+ def forward(self, x):
226
+ return self.model(x)
pix2pix.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .networks import UnetGenerator, PatchGAN
6
+
7
+ class Pix2Pix(nn.Module):
8
+ """Create a Pix2Pix class. It is a model for image to image translation tasks.
9
+ By default, the model uses a Unet architecture for generator with transposed
10
+ convolution. The discriminator is 70x70 PatchGAN discriminator, by default.
11
+ """
12
+ def __init__(self,
13
+ c_in: int = 3,
14
+ c_out: int = 3,
15
+ is_train: bool = True,
16
+ netD: str = 'patch',
17
+ lambda_L1: float = 100.0,
18
+ is_CGAN: bool = True,
19
+ use_upsampling: bool = False,
20
+ mode: str = 'nearest',
21
+ c_hid: int = 64,
22
+ n_layers: int = 3,
23
+ lr: float = 0.0002,
24
+ beta1: float = 0.5,
25
+ beta2: float = 0.999
26
+ ):
27
+ """Constructs the Pix2Pix class.
28
+
29
+ Args:
30
+ c_in: Number of input channels
31
+ c_out: Number of output channels
32
+ is_train: Whether the model is in training mode
33
+ netD: Type of discriminator ('patch' or 'pixel')
34
+ lambda_L1: Weight for L1 loss
35
+ is_CGAN: If True, use conditional GAN architecture
36
+ use_upsampling: If True, use upsampling in generator instead of transpose conv
37
+ mode: Upsampling mode ('nearest', 'bilinear', 'bicubic')
38
+ c_hid: Number of base filters in discriminator
39
+ n_layers: Number of layers in discriminator
40
+ lr: Learning rate
41
+ beta1: Beta1 parameter for Adam optimizer
42
+ beta2: Beta2 parameter for Adam optimizer
43
+ """
44
+ super(Pix2Pix, self).__init__()
45
+ self.is_CGAN = is_CGAN
46
+ self.lambda_L1 = lambda_L1
47
+
48
+ self.gen = UnetGenerator(c_in=c_in, c_out=c_out, use_upsampling=use_upsampling, mode=mode)
49
+ self.gen = self.gen.apply(self.weights_init)
50
+
51
+ if is_train:
52
+ # Conditional GANs need both input and output together, the total input channel is c_in+c_out
53
+ disc_in = c_in + c_out if is_CGAN else c_out
54
+ self.disc = PatchGAN(c_in=disc_in, c_hid=c_hid, mode=netD, n_layers=n_layers)
55
+ self.disc = self.disc.apply(self.weights_init)
56
+
57
+ # Initialize optimizers
58
+ self.gen_optimizer = torch.optim.Adam(
59
+ self.gen.parameters(), lr=lr, betas=(beta1, beta2))
60
+ self.disc_optimizer = torch.optim.Adam(
61
+ self.disc.parameters(), lr=lr, betas=(beta1, beta2))
62
+
63
+ # Initialize loss functions
64
+ self.criterion = nn.BCEWithLogitsLoss()
65
+ self.criterion_L1 = nn.L1Loss()
66
+
67
+ def forward(self, x):
68
+ return self.gen(x)
69
+
70
+ @staticmethod
71
+ def weights_init(m):
72
+ """Initialize network weights.
73
+
74
+ Args:
75
+ m: network module
76
+ """
77
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
78
+ nn.init.normal_(m.weight, 0.0, 0.02)
79
+ if hasattr(m, 'bias') and m.bias is not None:
80
+ nn.init.constant_(m.bias, 0.0)
81
+ if isinstance(m, nn.BatchNorm2d):
82
+ nn.init.normal_(m.weight, 1.0, 0.02)
83
+ nn.init.constant_(m.bias, 0)
84
+
85
+ def _get_disc_inputs(self, real_images, target_images, fake_images):
86
+ """Prepare discriminator inputs based on conditional/unconditional setup."""
87
+ if self.is_CGAN:
88
+ # Conditional GANs need both input and output together,
89
+ # Therefore, the total input channel is c_in+c_out
90
+ real_AB = torch.cat([real_images, target_images], dim=1)
91
+ fake_AB = torch.cat([real_images,
92
+ fake_images.detach()],
93
+ dim=1)
94
+ else:
95
+ real_AB = target_images
96
+ fake_AB = fake_images.detach()
97
+ return real_AB, fake_AB
98
+
99
+ def _get_gen_inputs(self, real_images, fake_images):
100
+ """Prepare discriminator inputs based on conditional/unconditional setup."""
101
+ if self.is_CGAN:
102
+ # Conditional GANs need both input and output together,
103
+ # Therefore, the total input channel is c_in+c_out
104
+ fake_AB = torch.cat([real_images,
105
+ fake_images],
106
+ dim=1)
107
+ else:
108
+ fake_AB = fake_images
109
+ return fake_AB
110
+
111
+
112
+ def step_discriminator(self, real_images, target_images, fake_images):
113
+ """Discriminator forward/backward pass.
114
+
115
+ Args:
116
+ real_images: Input images
117
+ target_images: Ground truth images
118
+ fake_images: Generated images
119
+
120
+ Returns:
121
+ Discriminator loss value
122
+ """
123
+ # Prepare inputs
124
+ real_AB, fake_AB = self._get_disc_inputs(real_images, target_images,
125
+ fake_images)
126
+
127
+ # Forward pass through the discriminator
128
+ pred_real = self.disc(real_AB) # D(x, y)
129
+ pred_fake = self.disc(fake_AB) # D(x, G(x))
130
+
131
+ # Compute the losses
132
+ lossD_real = self.criterion(pred_real, torch.ones_like(pred_real)) # (D(x, y), 1)
133
+ lossD_fake = self.criterion(pred_fake, torch.zeros_like(pred_fake)) # (D(x, y), 0)
134
+ lossD = (lossD_real + lossD_fake) * 0.5 # Combined Loss
135
+ return lossD
136
+
137
+ def step_generator(self, real_images, target_images, fake_images):
138
+ """Discriminator forward/backward pass.
139
+
140
+ Args:
141
+ real_images: Input images
142
+ target_images: Ground truth images
143
+ fake_images: Generated images
144
+
145
+ Returns:
146
+ Discriminator loss value
147
+ """
148
+ # Prepare input
149
+ fake_AB = self._get_gen_inputs(real_images, fake_images)
150
+
151
+ # Forward pass through the discriminator
152
+ pred_fake = self.disc(fake_AB)
153
+
154
+ # Compute the losses
155
+ lossG_GaN = self.criterion(pred_fake, torch.ones_like(pred_fake)) # GAN Loss
156
+ lossG_L1 = self.criterion_L1(fake_images, target_images) # L1 Loss
157
+ lossG = lossG_GaN + self.lambda_L1 * lossG_L1 # Combined Loss
158
+ # Return total loss and individual components
159
+ return lossG, {
160
+ 'loss_G': lossG.item(),
161
+ 'loss_G_GAN': lossG_GaN.item(),
162
+ 'loss_G_L1': lossG_L1.item()
163
+ }
164
+
165
+ def train_step(self, real_images, target_images):
166
+ """Performs a single training step.
167
+
168
+ Args:
169
+ real_images: Input images
170
+ target_images: Ground truth images
171
+
172
+ Returns:
173
+ Dictionary containing all loss values from this step
174
+ """
175
+ # Forward pass through the generator
176
+ fake_images = self.forward(real_images)
177
+
178
+ # Update discriminator
179
+ self.disc_optimizer.zero_grad() # Reset the gradients for D
180
+ lossD = self.stepD(real_images, target_images, fake_images) # Compute the loss
181
+ lossD.backward()
182
+ self.disc_optimizer.step() # Update D
183
+
184
+ # Update generator
185
+ self.gen_optimizer.zero_grad() # Reset the gradients for D
186
+ lossG, G_losses = self.stepG(real_images, target_images, fake_images) # Compute the loss
187
+ lossG.backward()
188
+ self.gen_optimizer.step() # Update D
189
+
190
+ # Return all losses
191
+ return {
192
+ 'loss_D': lossD.item(),
193
+ **G_losses
194
+ }
195
+
196
+ def get_current_visuals(self, real_images, target_images):
197
+ """Return visualization images.
198
+
199
+ Args:
200
+ real_images: Input images
201
+ target_images: Ground truth images
202
+
203
+ Returns:
204
+ Dictionary containing input, target and generated images
205
+ """
206
+ with torch.no_grad():
207
+ fake_images = self.gen(real_images)
208
+ return {
209
+ 'real': real_images,
210
+ 'fake': fake_images,
211
+ 'target': target_images
212
+ }
pix2pix_disc_ckpt_200.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:914a7a2152fabd46a7bcc7aad3fb3e642cd0432df7151b4f04c9cf792fdc831b
3
+ size 11090624
pix2pix_gen_ckpt_200.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:db2fda865233203ba13e7c10220bfbf46fa6a92ecb462fb5ad39bf22bbad15af
3
+ size 218246966