from typing import Optional

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

from networks import UnetGenerator, PatchGAN
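# `networks` is the repository's local module that defines the UnetGenerator
# and PatchGAN architectures instantiated below.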
class Pix2Pix(nn.Module, PyTorchModelHubMixin):
    """Pix2Pix model for image-to-image translation tasks.
    By default, the generator is a U-Net that upsamples with transposed
    convolutions, and the discriminator is a 70x70 PatchGAN.
    """
def __init__(self,
c_in: int = 3,
c_out: int = 3,
is_train: bool = True,
netD: str = 'patch',
lambda_L1: float = 100.0,
is_CGAN: bool = True,
use_upsampling: bool = False,
mode: str = 'nearest',
c_hid: int = 64,
n_layers: int = 3,
lr: float = 0.0002,
beta1: float = 0.5,
beta2: float = 0.999
):
"""Constructs the Pix2Pix class.
Args:
c_in: Number of input channels
c_out: Number of output channels
is_train: Whether the model is in training mode
netD: Type of discriminator ('patch' or 'pixel')
lambda_L1: Weight for L1 loss
is_CGAN: If True, use conditional GAN architecture
use_upsampling: If True, use upsampling in generator instead of transpose conv
mode: Upsampling mode ('nearest', 'bilinear', 'bicubic')
c_hid: Number of base filters in discriminator
n_layers: Number of layers in discriminator
lr: Learning rate
beta1: Beta1 parameter for Adam optimizer
beta2: Beta2 parameter for Adam optimizer
"""
        super().__init__()
self.is_CGAN = is_CGAN
self.lambda_L1 = lambda_L1
self.is_train = is_train
self.gen = UnetGenerator(c_in=c_in, c_out=c_out, use_upsampling=use_upsampling, mode=mode)
self.gen = self.gen.apply(self.weights_init)
if self.is_train:
# Conditional GANs need both input and output together, the total input channel is c_in+c_out
disc_in = c_in + c_out if is_CGAN else c_out
self.disc = PatchGAN(c_in=disc_in, c_hid=c_hid, mode=netD, n_layers=n_layers)
self.disc = self.disc.apply(self.weights_init)
# Initialize optimizers
self.gen_optimizer = torch.optim.Adam(
self.gen.parameters(), lr=lr, betas=(beta1, beta2))
self.disc_optimizer = torch.optim.Adam(
self.disc.parameters(), lr=lr, betas=(beta1, beta2))
# Initialize loss functions
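            # BCEWithLogitsLoss fuses a sigmoid with binary cross-entropy,
            # so the discriminator is expected to output raw (unnormalized) logits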
self.criterion = nn.BCEWithLogitsLoss()
self.criterion_L1 = nn.L1Loss()
def forward(self, x: torch.Tensor):
return self.gen(x)
@staticmethod
def weights_init(m):
"""Initialize network weights.
Args:
m: network module
"""
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
nn.init.normal_(m.weight, 0.0, 0.02)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias, 0.0)
if isinstance(m, nn.BatchNorm2d):
nn.init.normal_(m.weight, 1.0, 0.02)
nn.init.constant_(m.bias, 0)
def _get_disc_inputs(self,
real_images: torch.Tensor,
target_images: torch.Tensor,
fake_images: torch.Tensor
):
"""Prepare discriminator inputs based on conditional/unconditional setup."""
if self.is_CGAN:
            # Conditional GANs see the input and output together,
            # so the discriminator input has c_in + c_out channels
real_AB = torch.cat([real_images, target_images], dim=1)
fake_AB = torch.cat([real_images,
fake_images.detach()],
dim=1)
else:
real_AB = target_images
fake_AB = fake_images.detach()
return real_AB, fake_AB
def _get_gen_inputs(self,
real_images: torch.Tensor,
fake_images: torch.Tensor
):
"""Prepare discriminator inputs based on conditional/unconditional setup."""
if self.is_CGAN:
# Conditional GANs need both input and output together,
# Therefore, the total input channel is c_in+c_out
fake_AB = torch.cat([real_images,
fake_images],
dim=1)
else:
fake_AB = fake_images
return fake_AB
def step_discriminator(self,
real_images: torch.Tensor,
target_images: torch.Tensor,
fake_images: torch.Tensor
):
"""Discriminator forward/backward pass.
Args:
real_images: Input images
target_images: Ground truth images
fake_images: Generated images
Returns:
Discriminator loss value
"""
# Prepare inputs
real_AB, fake_AB = self._get_disc_inputs(real_images, target_images,
fake_images)
# Forward pass through the discriminator
pred_real = self.disc(real_AB) # D(x, y)
pred_fake = self.disc(fake_AB) # D(x, G(x))
# Compute the losses
        lossD_real = self.criterion(pred_real, torch.ones_like(pred_real))   # (D(x, y), 1)
        lossD_fake = self.criterion(pred_fake, torch.zeros_like(pred_fake))  # (D(x, G(x)), 0)
        lossD = (lossD_real + lossD_fake) * 0.5  # Combined loss, halved to slow D relative to G (as in the pix2pix paper)
return lossD
def step_generator(self,
real_images: torch.Tensor,
target_images: torch.Tensor,
fake_images: torch.Tensor
):
"""Discriminator forward/backward pass.
Args:
real_images: Input images
target_images: Ground truth images
fake_images: Generated images
Returns:
Discriminator loss value
"""
# Prepare input
fake_AB = self._get_gen_inputs(real_images, fake_images)
# Forward pass through the discriminator
pred_fake = self.disc(fake_AB)
# Compute the losses
        lossG_GAN = self.criterion(pred_fake, torch.ones_like(pred_fake))  # GAN loss
        lossG_L1 = self.criterion_L1(fake_images, target_images)           # L1 loss
        lossG = lossG_GAN + self.lambda_L1 * lossG_L1                      # Combined loss
        # Return total loss and individual components
        return lossG, {
            'loss_G': lossG.item(),
            'loss_G_GAN': lossG_GAN.item(),
            'loss_G_L1': lossG_L1.item()
        }
def train_step(self,
real_images: torch.Tensor,
target_images: torch.Tensor
):
"""Performs a single training step.
Args:
real_images: Input images
target_images: Ground truth images
Returns:
Dictionary containing all loss values from this step
"""
# Forward pass through the generator
fake_images = self.forward(real_images)
# Update discriminator
self.disc_optimizer.zero_grad() # Reset the gradients for D
lossD = self.step_discriminator(real_images, target_images, fake_images) # Compute the loss
lossD.backward()
self.disc_optimizer.step() # Update D
# Update generator
        self.gen_optimizer.zero_grad()  # Reset the gradients for G
lossG, G_losses = self.step_generator(real_images, target_images, fake_images) # Compute the loss
lossG.backward()
        self.gen_optimizer.step()  # Update G
# Return all losses
return {
'loss_D': lossD.item(),
**G_losses
}
def validation_step(self,
real_images: torch.Tensor,
target_images: torch.Tensor
):
"""Performs a single validation step.
Args:
real_images: Input images
target_images: Ground truth images
Returns:
Dictionary containing all loss values from this step
"""
with torch.no_grad():
# Forward pass through the generator
fake_images = self.forward(real_images)
# Compute the loss for D
lossD = self.step_discriminator(real_images, target_images, fake_images)
# Compute the loss for G
_, G_losses = self.step_generator(real_images, target_images, fake_images)
# Return all losses
return {
'loss_D': lossD.item(),
**G_losses
}
    def generate(self,
                 real_images: torch.Tensor,
                 is_scaled: bool = False,
                 to_uint8: bool = False
                 ):
        """Generates images with the trained generator.
        Args:
            real_images: Input images
            is_scaled: If True, inputs are already scaled to [-1, 1]; otherwise they
                are treated as [0, 255] values and normalized here
            to_uint8: If True, rescale outputs to [0, 255] and convert to uint8
        Returns:
            Generated images in [0, 1], or as uint8 in [0, 255] if to_uint8 is True
        """
        if not is_scaled:
            real_images = real_images.to(dtype=torch.float32)  # Make sure it's a float tensor
            real_images = real_images / 255.0                  # Normalize to [0, 1]
            real_images = (real_images - 0.5) / 0.5            # Scale to [-1, 1]
        with torch.no_grad():  # Generate without tracking gradients
            generated_images = self.forward(real_images)
        generated_images = (generated_images + 1) / 2  # Rescale from [-1, 1] to [0, 1]
        if to_uint8:
            generated_images = (generated_images * 255).to(dtype=torch.uint8)  # Scale to [0, 255] and convert to uint8
        return generated_images
    def save_model(self, gen_path: str, disc_path: Optional[str] = None):
"""
Saves the generator model's state dictionary to the specified path.
If in training mode and a discriminator path is provided, saves the
discriminator model's state dictionary as well.
Args:
gen_path (str): The file path where the generator model's state dictionary will be saved.
disc_path (str, optional): The file path where the discriminator model's state dictionary will be saved. Defaults to None.
"""
torch.save(self.gen.state_dict(), gen_path)
if self.is_train and disc_path is not None:
torch.save(self.disc.state_dict(), disc_path)
    def load_model(self, gen_path: str, disc_path: Optional[str] = None, device: Optional[str] = None):
        """
        Loads the generator and optionally the discriminator model from the specified file paths.
        Args:
            gen_path (str): Path to the generator model file.
            disc_path (str, optional): Path to the discriminator model file. Defaults to None.
            device (str, optional): The device on which to load the models. If None, the device
                of the generator's parameters is used. Defaults to None.
        Returns:
            None
        """
        device = device if device else next(self.gen.parameters()).device
        self.gen.load_state_dict(torch.load(gen_path, map_location=device, weights_only=True), strict=False)
        if disc_path is not None and self.is_train:
            self.disc.load_state_dict(torch.load(disc_path, map_location=device, weights_only=True), strict=False)
def get_current_visuals(self,
real_images: torch.Tensor,
target_images: torch.Tensor
):
"""Return visualization images.
Args:
real_images: Input images
target_images: Ground truth images
Returns:
Dictionary containing input, target and generated images
"""
with torch.no_grad():
fake_images = self.gen(real_images)
return {
'real': real_images,
'fake': fake_images,
'target': target_images
        }
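

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes that
# `networks.UnetGenerator` accepts 256x256 inputs (the standard pix2pix
# resolution); adjust shapes, channels, and paths for your own setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = Pix2Pix(c_in=3, c_out=3, is_train=True)

    # Dummy batch of input/target pairs already scaled to [-1, 1]
    real = torch.randn(2, 3, 256, 256)
    target = torch.randn(2, 3, 256, 256)

    # One optimization step for D and G; returns a dict of scalar losses
    losses = model.train_step(real, target)
    print(losses)  # {'loss_D': ..., 'loss_G': ..., 'loss_G_GAN': ..., 'loss_G_L1': ...}

    # Inference on raw [0, 255] images
    raw = torch.randint(0, 256, (1, 3, 256, 256), dtype=torch.uint8)
    out = model.generate(raw, is_scaled=False, to_uint8=True)
    print(out.shape, out.dtype)  # torch.Size([1, 3, 256, 256]) torch.uint8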