import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
from typing import Dict, List, Tuple


class Discriminator(nn.Module):
    def __init__(
        self,
        hidden_size: int = 64,
        channels: int = 3,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
        negative_slope: float = 0.2,
        bias: bool = False,
    ):
        """
        Initializes the discriminator.

        Parameters
        ----------
        hidden_size : int, optional
            The base number of feature maps. (default: 64)
        channels : int, optional
            The number of input channels. (default: 3)
        kernel_size : int, optional
            The kernel size. (default: 4)
        stride : int, optional
            The stride. (default: 2)
        padding : int, optional
            The padding. (default: 1)
        negative_slope : float, optional
            The negative slope of the LeakyReLU activations. (default: 0.2)
        bias : bool, optional
            Whether the convolutions use a bias term. (default: False)
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.channels = channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.negative_slope = negative_slope
        self.bias = bias

        self.model = nn.Sequential(
            # (channels, 64, 64) -> (hidden_size, 32, 32), assuming 64x64 inputs
            nn.utils.spectral_norm(
                nn.Conv2d(
                    self.channels, self.hidden_size,
                    kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias,
                ),
            ),
            nn.LeakyReLU(self.negative_slope, inplace=True),
            # (hidden_size, 32, 32) -> (hidden_size * 2, 16, 16)
            nn.utils.spectral_norm(
                nn.Conv2d(
                    self.hidden_size, self.hidden_size * 2,
                    kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias,
                ),
            ),
            nn.BatchNorm2d(self.hidden_size * 2),
            nn.LeakyReLU(self.negative_slope, inplace=True),
            # (hidden_size * 2, 16, 16) -> (hidden_size * 4, 8, 8)
            nn.utils.spectral_norm(
                nn.Conv2d(
                    self.hidden_size * 2, self.hidden_size * 4,
                    kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias,
                ),
            ),
            nn.BatchNorm2d(self.hidden_size * 4),
            nn.LeakyReLU(self.negative_slope, inplace=True),
            # (hidden_size * 4, 8, 8) -> (hidden_size * 8, 4, 4)
            nn.utils.spectral_norm(
                nn.Conv2d(
                    self.hidden_size * 4, self.hidden_size * 8,
                    kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias,
                ),
            ),
            nn.BatchNorm2d(self.hidden_size * 8),
            nn.LeakyReLU(self.negative_slope, inplace=True),
            # (hidden_size * 8, 4, 4) -> (1, 1, 1): one real/fake probability per image
            nn.utils.spectral_norm(
                nn.Conv2d(self.hidden_size * 8, 1, kernel_size=4, stride=1, padding=0, bias=self.bias),
            ),
            nn.Flatten(),
            nn.Sigmoid(),
        )

    def forward(self, input_img: torch.Tensor) -> torch.Tensor:
        """
        Forward propagation.

        Parameters
        ----------
        input_img : torch.Tensor
            The input image.

        Returns
        -------
        torch.Tensor
            The predicted probability that each input image is real.
        """
        preds = self.model(input_img)
        return preds
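
# A minimal usage sketch (assuming 64x64 RGB inputs, which the strides above
# are sized for):
#
#   disc = Discriminator()
#   imgs = torch.randn(8, 3, 64, 64)
#   probs = disc(imgs)  # shape (8, 1), values in (0, 1) due to the Sigmoid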


class Generator(nn.Module):
    def __init__(
        self,
        hidden_size: int = 64,
        latent_size: int = 128,
        channels: int = 3,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
        bias: bool = False,
    ):
        """
        Initializes the generator.

        Parameters
        ----------
        hidden_size : int, optional
            The base number of feature maps. (default: 64)
        latent_size : int, optional
            The size of the latent vector. (default: 128)
        channels : int, optional
            The number of output channels. (default: 3)
        kernel_size : int, optional
            The kernel size. (default: 4)
        stride : int, optional
            The stride. (default: 2)
        padding : int, optional
            The padding. (default: 1)
        bias : bool, optional
            Whether the transposed convolutions use a bias term. (default: False)
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.channels = channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.bias = bias

        self.model = nn.Sequential(
            # (latent_size, 1, 1) -> (hidden_size * 8, 4, 4)
            nn.ConvTranspose2d(
                self.latent_size, self.hidden_size * 8,
                kernel_size=self.kernel_size, stride=1, padding=0, bias=self.bias,
            ),
            nn.BatchNorm2d(self.hidden_size * 8),
            nn.ReLU(inplace=True),
            # (hidden_size * 8, 4, 4) -> (hidden_size * 4, 8, 8)
            nn.ConvTranspose2d(
                self.hidden_size * 8, self.hidden_size * 4,
                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias,
            ),
            nn.BatchNorm2d(self.hidden_size * 4),
            nn.ReLU(inplace=True),
            # (hidden_size * 4, 8, 8) -> (hidden_size * 2, 16, 16)
            nn.ConvTranspose2d(
                self.hidden_size * 4, self.hidden_size * 2,
                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias,
            ),
            nn.BatchNorm2d(self.hidden_size * 2),
            nn.ReLU(inplace=True),
            # (hidden_size * 2, 16, 16) -> (hidden_size, 32, 32)
            nn.ConvTranspose2d(
                self.hidden_size * 2, self.hidden_size,
                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias,
            ),
            nn.BatchNorm2d(self.hidden_size),
            nn.ReLU(inplace=True),
            # (hidden_size, 32, 32) -> (channels, 64, 64), squashed to (-1, 1)
            nn.ConvTranspose2d(
                self.hidden_size, self.channels,
                kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias,
            ),
            nn.Tanh(),
        )

    def forward(self, input_noise: torch.Tensor) -> torch.Tensor:
        """
        Forward propagation.

        Parameters
        ----------
        input_noise : torch.Tensor
            The input noise (latent vector).

        Returns
        -------
        torch.Tensor
            The generated image.
        """
        fake_img = self.model(input_noise)
        return fake_img
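
# A minimal usage sketch: from a (N, latent_size, 1, 1) noise tensor, the
# transposed convolutions grow the spatial size 4 -> 8 -> 16 -> 32 -> 64:
#
#   gen = Generator()
#   z = torch.randn(8, 128, 1, 1)
#   fake = gen(z)  # shape (8, 3, 64, 64), values in (-1, 1) due to the Tanh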


class DocuGAN(pl.LightningModule):
    def __init__(
        self,
        hidden_size: int = 64,
        latent_size: int = 128,
        num_channel: int = 3,
        learning_rate: float = 0.0002,
        batch_size: int = 128,
        bias1: float = 0.5,
        bias2: float = 0.999,
    ):
        """
        Initializes the DocuGAN LightningModule.

        Parameters
        ----------
        hidden_size : int, optional
            The hidden size. (default: 64)
        latent_size : int, optional
            The latent size. (default: 128)
        num_channel : int, optional
            The number of channels. (default: 3)
        learning_rate : float, optional
            The learning rate. (default: 0.0002)
        batch_size : int, optional
            The batch size. (default: 128)
        bias1 : float, optional
            The first beta coefficient of the Adam optimizers. (default: 0.5)
        bias2 : float, optional
            The second beta coefficient of the Adam optimizers. (default: 0.999)
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.num_channel = num_channel
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.bias1 = bias1
        self.bias2 = bias2
        self.criterion = nn.BCELoss()
        # Fixed noise vector, e.g. for generating comparable samples at validation time.
        self.validation = torch.randn(self.batch_size, self.latent_size, 1, 1)
        self.save_hyperparameters()

        self.generator = Generator(
            latent_size=self.latent_size, channels=self.num_channel, hidden_size=self.hidden_size
        )
        self.generator.apply(self.weights_init)

        self.discriminator = Discriminator(channels=self.num_channel, hidden_size=self.hidden_size)
        self.discriminator.apply(self.weights_init)

    def weights_init(self, m: nn.Module) -> None:
        """
        Initializes the weights, following the DCGAN scheme: N(0, 0.02) for
        convolutions, N(1, 0.02) with zero bias for batch norms.

        Parameters
        ----------
        m : nn.Module
            The module.
        """
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find("BatchNorm") != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List]:
        """
        Configures the optimizers.

        Returns
        -------
        Tuple[List[torch.optim.Optimizer], List]
            The optimizers and the LR schedulers.
        """
        opt_generator = torch.optim.Adam(
            self.generator.parameters(), lr=self.learning_rate, betas=(self.bias1, self.bias2)
        )
        opt_discriminator = torch.optim.Adam(
            self.discriminator.parameters(), lr=self.learning_rate, betas=(self.bias1, self.bias2)
        )
        return [opt_generator, opt_discriminator], []
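
    # Note: under PyTorch Lightning 1.x automatic optimization with multiple
    # optimizers, `training_step` runs once per optimizer per batch, and
    # `optimizer_idx` follows the list order returned above:
    # 0 -> generator, 1 -> discriminator.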

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Forward propagation.

        Parameters
        ----------
        z : torch.Tensor
            The latent vector.

        Returns
        -------
        torch.Tensor
            The generated image.
        """
        return self.generator(z)

    def training_step(
        self, batch: Dict[str, torch.Tensor], batch_idx: int, optimizer_idx: int
    ) -> Dict:
        """
        Training step.

        Parameters
        ----------
        batch : Dict[str, torch.Tensor]
            The batch, with the real images under the "tr_image" key.
        batch_idx : int
            The batch index.
        optimizer_idx : int
            The optimizer index (0: generator, 1: discriminator).

        Returns
        -------
        Dict
            The training loss.
        """
        real_images = batch["tr_image"]

        if optimizer_idx == 0:
            # Generator step: try to make the discriminator label fakes as real.
            # Use the actual batch size so the last, possibly smaller, batch works.
            fake_random_noise = torch.randn(real_images.size(0), self.latent_size, 1, 1)
            fake_random_noise = fake_random_noise.type_as(real_images)
            fake_images = self(fake_random_noise)

            preds = self.discriminator(fake_images)
            loss = self.criterion(preds, torch.ones_like(preds))
            self.log("g_loss", loss, on_step=False, on_epoch=True)

            tqdm_dict = {"g_loss": loss}
            output = OrderedDict({"loss": loss, "progress_bar": tqdm_dict, "log": tqdm_dict})
            return output

        elif optimizer_idx == 1:
            # Discriminator step: label real images as real...
            real_preds = self.discriminator(real_images)
            real_loss = self.criterion(real_preds, torch.ones_like(real_preds))

            # ...and freshly generated images as fake.
            real_random_noise = torch.randn(real_images.size(0), self.latent_size, 1, 1)
            real_random_noise = real_random_noise.type_as(real_images)
            fake_images = self(real_random_noise)

            fake_preds = self.discriminator(fake_images)
            fake_loss = self.criterion(fake_preds, torch.zeros_like(fake_preds))

            loss = real_loss + fake_loss
            self.log("d_loss", loss, on_step=False, on_epoch=True)

            tqdm_dict = {"d_loss": loss}
            output = OrderedDict({"loss": loss, "progress_bar": tqdm_dict, "log": tqdm_dict})
            return output
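

if __name__ == "__main__":
    # A minimal smoke test (a sketch, not part of the training pipeline):
    # check that the generator and discriminator agree on shapes.
    gan = DocuGAN(batch_size=4)
    noise = torch.randn(4, gan.latent_size, 1, 1)
    fakes = gan(noise)                 # expected: (4, 3, 64, 64)
    scores = gan.discriminator(fakes)  # expected: (4, 1), values in (0, 1)
    print(fakes.shape, scores.shape)

    # To actually train (a sketch: `MyDocumentDataModule` is hypothetical and
    # must yield dict batches with a "tr_image" key, as training_step expects):
    #
    #   trainer = pl.Trainer(max_epochs=50, accelerator="auto")
    #   trainer.fit(gan, datamodule=MyDocumentDataModule())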