# DocuGAN / model.py (author: Thomas.Chaigneau)
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
from typing import Dict, List, Tuple


class Discriminator(nn.Module):
    def __init__(
        self,
        hidden_size: int = 64,
        channels: int = 3,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
        negative_slope: float = 0.2,
        bias: bool = False,
    ):
"""
Initializes the discriminator.
Parameters
----------
        hidden_size : int, optional
            The base number of feature maps in the first conv layer. (default: 64)
channels : int, optional
The number of channels. (default: 3)
kernel_size : int, optional
            The kernel size. (default: 4)
stride : int, optional
The stride. (default: 2)
padding : int, optional
The padding. (default: 1)
negative_slope : float, optional
The negative slope. (default: 0.2)
bias : bool, optional
Whether to use bias. (default: False)
"""
super().__init__()
self.hidden_size = hidden_size
self.channels = channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.negative_slope = negative_slope
self.bias = bias
self.model = nn.Sequential(
nn.utils.spectral_norm(
nn.Conv2d(
self.channels, self.hidden_size,
kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
),
),
nn.LeakyReLU(self.negative_slope, inplace=True),
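            # -> (N, hidden_size, 32, 32), assuming 64x64 inputs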
nn.utils.spectral_norm(
nn.Conv2d(
                    self.hidden_size, self.hidden_size * 2,
kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
),
),
            nn.BatchNorm2d(self.hidden_size * 2),
nn.LeakyReLU(self.negative_slope, inplace=True),
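            # -> (N, hidden_size * 2, 16, 16)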
nn.utils.spectral_norm(
nn.Conv2d(
                    self.hidden_size * 2, self.hidden_size * 4,
kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
),
),
            nn.BatchNorm2d(self.hidden_size * 4),
nn.LeakyReLU(self.negative_slope, inplace=True),
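            # -> (N, hidden_size * 4, 8, 8)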
nn.utils.spectral_norm(
nn.Conv2d(
                    self.hidden_size * 4, self.hidden_size * 8,
kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
),
),
            nn.BatchNorm2d(self.hidden_size * 8),
nn.LeakyReLU(self.negative_slope, inplace=True),
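            # -> (N, hidden_size * 8, 4, 4)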
nn.utils.spectral_norm(
                nn.Conv2d(self.hidden_size * 8, 1, kernel_size=4, stride=1, padding=0, bias=self.bias),  # output size: (1, 1, 1)
),
nn.Flatten(),
nn.Sigmoid(),
)
def forward(self, input_img: torch.Tensor) -> torch.Tensor:
"""
Forward propagation.
Parameters
----------
input_img : torch.Tensor
The input image.
Returns
-------
torch.Tensor
            The predicted probability that each input image is real,
            of shape (batch_size, 1).
        """
        preds = self.model(input_img)
        return preds


class Generator(nn.Module):
    def __init__(
        self,
        hidden_size: int = 64,
        latent_size: int = 128,
        channels: int = 3,
        kernel_size: int = 4,
        stride: int = 2,
        padding: int = 1,
        bias: bool = False,
    ):
"""
Initializes the generator.
Parameters
----------
hidden_size : int, optional
The hidden size. (default: 64)
latent_size : int, optional
The latent size. (default: 128)
channels : int, optional
The number of channels. (default: 3)
kernel_size : int, optional
The kernel size. (default: 4)
stride : int, optional
The stride. (default: 2)
padding : int, optional
The padding. (default: 1)
bias : bool, optional
Whether to use bias. (default: False)
"""
super().__init__()
self.hidden_size = hidden_size
self.latent_size = latent_size
self.channels = channels
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.bias = bias
self.model = nn.Sequential(
nn.ConvTranspose2d(
self.latent_size, self.hidden_size * 8,
kernel_size=self.kernel_size, stride=1, padding=0, bias=self.bias
),
nn.BatchNorm2d(self.hidden_size * 8),
nn.ReLU(inplace=True),
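            # -> (N, hidden_size * 8, 4, 4)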
nn.ConvTranspose2d(
self.hidden_size * 8, self.hidden_size * 4,
kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
),
nn.BatchNorm2d(self.hidden_size * 4),
nn.ReLU(inplace=True),
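            # -> (N, hidden_size * 4, 8, 8)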
nn.ConvTranspose2d(
self.hidden_size * 4, self.hidden_size * 2,
kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
),
nn.BatchNorm2d(self.hidden_size * 2),
nn.ReLU(inplace=True),
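            # -> (N, hidden_size * 2, 16, 16)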
nn.ConvTranspose2d(
self.hidden_size * 2, self.hidden_size,
kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
),
nn.BatchNorm2d(self.hidden_size),
nn.ReLU(inplace=True),
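            # -> (N, hidden_size, 32, 32)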
nn.ConvTranspose2d(
self.hidden_size, self.channels,
kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias
),
nn.Tanh() # output size: (channels, 64, 64)
)
def forward(self, input_noise: torch.Tensor) -> torch.Tensor:
"""
Forward propagation.
Parameters
----------
input_noise : torch.Tensor
            The input noise (latent vector) of shape (batch_size, latent_size, 1, 1).
Returns
-------
torch.Tensor
            The generated image batch of shape (batch_size, channels, 64, 64).
        """
        fake_img = self.model(input_noise)
        return fake_img


class DocuGAN(pl.LightningModule):
    def __init__(
        self,
        hidden_size: int = 64,
        latent_size: int = 128,
        num_channel: int = 3,
        learning_rate: float = 0.0002,
        batch_size: int = 128,
        bias1: float = 0.5,
        bias2: float = 0.999,
    ):
"""
        Initializes the DocuGAN LightningModule.
Parameters
----------
hidden_size : int, optional
The hidden size. (default: 64)
latent_size : int, optional
The latent size. (default: 128)
num_channel : int, optional
The number of channels. (default: 3)
learning_rate : float, optional
The learning rate. (default: 0.0002)
batch_size : int, optional
The batch size. (default: 128)
        bias1 : float, optional
            The first Adam beta coefficient (beta1). (default: 0.5)
        bias2 : float, optional
            The second Adam beta coefficient (beta2). (default: 0.999)
"""
super().__init__()
self.hidden_size = hidden_size
self.latent_size = latent_size
self.num_channel = num_channel
self.learning_rate = learning_rate
self.batch_size = batch_size
self.bias1 = bias1
self.bias2 = bias2
self.criterion = nn.BCELoss()
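        # Fixed latent noise for generating comparable validation samples.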
self.validation = torch.randn(self.batch_size, self.latent_size, 1, 1)
self.save_hyperparameters()
self.generator = Generator(
latent_size=self.latent_size, channels=self.num_channel, hidden_size=self.hidden_size
)
self.generator.apply(self.weights_init)
self.discriminator = Discriminator(channels=self.num_channel, hidden_size=self.hidden_size)
self.discriminator.apply(self.weights_init)
# self.model = InceptionV3() # For FID metric
def weights_init(self, m: nn.Module) -> None:
"""
Initializes the weights.
Parameters
----------
m : nn.Module
The module.
"""
classname = m.__class__.__name__
if classname.find("Conv") != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find("BatchNorm") != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List]:
"""
Configures the optimizers.
Returns
-------
Tuple[List[torch.optim.Optimizer], List]
The optimizers and the LR schedulers.
"""
opt_generator = torch.optim.Adam(
self.generator.parameters(), lr=self.learning_rate, betas=(self.bias1, self.bias2)
)
opt_discriminator = torch.optim.Adam(
self.discriminator.parameters(), lr=self.learning_rate, betas=(self.bias1, self.bias2)
)
return [opt_generator, opt_discriminator], []
def forward(self, z: torch.Tensor) -> torch.Tensor:
"""
Forward propagation.
Parameters
----------
        z : torch.Tensor
            The latent vector.
        Returns
        -------
        torch.Tensor
            The generated image batch.
"""
return self.generator(z)
def training_step(
        self, batch: Dict[str, torch.Tensor], batch_idx: int, optimizer_idx: int
) -> Dict:
"""
Training step.
Parameters
----------
        batch : Dict[str, torch.Tensor]
            The batch; real images are stored under the "tr_image" key.
batch_idx : int
The batch index.
optimizer_idx : int
The optimizer index.
Returns
-------
Dict
The training loss.
"""
real_images = batch["tr_image"]
if optimizer_idx == 0: # Only train the generator
            fake_random_noise = torch.randn(real_images.size(0), self.latent_size, 1, 1)
fake_random_noise = fake_random_noise.type_as(real_images)
fake_images = self(fake_random_noise)
# Try to fool the discriminator
preds = self.discriminator(fake_images)
loss = self.criterion(preds, torch.ones_like(preds))
self.log("g_loss", loss, on_step=False, on_epoch=True)
tqdm_dict = {"g_loss": loss}
output = OrderedDict({"loss": loss, "progress_bar": tqdm_dict, "log": tqdm_dict})
return output
elif optimizer_idx == 1: # Only train the discriminator
real_preds = self.discriminator(real_images)
real_loss = self.criterion(real_preds, torch.ones_like(real_preds))
# Generate fake images
            random_noise = torch.randn(real_images.size(0), self.latent_size, 1, 1)
            random_noise = random_noise.type_as(real_images)
            fake_images = self(random_noise)
            # Pass fake images through the discriminator
fake_preds = self.discriminator(fake_images)
fake_loss = self.criterion(fake_preds, torch.zeros_like(fake_preds))
# Update discriminator weights
loss = real_loss + fake_loss
self.log("d_loss", loss, on_step=False, on_epoch=True)
tqdm_dict = {"d_loss": loss}
output = OrderedDict({"loss": loss, "progress_bar": tqdm_dict, "log": tqdm_dict})
return output
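

# Minimal smoke test sketch: assumes the default 3x64x64 configuration and
# simply checks that generator and discriminator output shapes line up.
if __name__ == "__main__":
    model = DocuGAN()
    noise = torch.randn(4, model.latent_size, 1, 1)
    fakes = model(noise)  # expected: torch.Size([4, 3, 64, 64])
    probs = model.discriminator(fakes)  # expected: torch.Size([4, 1])
    print(fakes.shape, probs.shape)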