import torch import torch.nn as nn import torch.nn.functional as F import torchvision.models as models from torch import Tensor # __all__ = [ # "ResidualConvBlock", # "Discriminator", "Generator", # ] class ResidualConvBlock(nn.Module): """Implements residual conv function. Args: channels (int): Number of channels in the input image. """ def __init__(self, channels: int) -> None: super(ResidualConvBlock, self).__init__() self.rcb = nn.Sequential( nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False), nn.BatchNorm2d(channels), nn.PReLU(), nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False), nn.BatchNorm2d(channels), ) def forward(self, x: Tensor) -> Tensor: identity = x out = self.rcb(x) out = torch.add(out, identity) return out class Discriminator(nn.Module): def __init__(self) -> None: super(Discriminator, self).__init__() self.features = nn.Sequential( # input size. (3) x 96 x 96 nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), bias=False), nn.LeakyReLU(0.2, True), # state size. (64) x 48 x 48 nn.Conv2d(64, 64, (3, 3), (2, 2), (1, 1), bias=False), nn.BatchNorm2d(64), nn.LeakyReLU(0.2, True), nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1), bias=False), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, True), # state size. (128) x 24 x 24 nn.Conv2d(128, 128, (3, 3), (2, 2), (1, 1), bias=False), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, True), nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1), bias=False), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, True), # state size. (256) x 12 x 12 nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), bias=False), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, True), nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1), bias=False), nn.BatchNorm2d(512), nn.LeakyReLU(0.2, True), # state size. (512) x 6 x 6 nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), bias=False), nn.BatchNorm2d(512), nn.LeakyReLU(0.2, True), ) self.classifier = nn.Sequential( nn.Linear(512 * 6 * 6, 1024), nn.LeakyReLU(0.2, True), nn.Linear(1024, 1), ) def forward(self, x: Tensor) -> Tensor: out = self.features(x) out = torch.flatten(out, 1) out = self.classifier(out) return out class Generator(nn.Module): def __init__(self) -> None: super(Generator, self).__init__() # First conv layer. self.conv_block1 = nn.Sequential( nn.Conv2d(3, 64, (9, 9), (1, 1), (4, 4)), nn.PReLU(), ) # Features trunk blocks. trunk = [] for _ in range(16): trunk.append(ResidualConvBlock(64)) self.trunk = nn.Sequential(*trunk) # Second conv layer. self.conv_block2 = nn.Sequential( nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1), bias=False), nn.BatchNorm2d(64), ) # Upscale conv block. self.upsampling = nn.Sequential( nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)), nn.PixelShuffle(2), nn.PReLU(), nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)), nn.PixelShuffle(2), nn.PReLU(), ) # Output layer. self.conv_block3 = nn.Conv2d(64, 3, (9, 9), (1, 1), (4, 4)) # Initialize neural network weights. self._initialize_weights() def forward(self, x: Tensor, dop=None) -> Tensor: if not dop: return self._forward_impl(x) else: return self._forward_w_dop_impl(x, dop) # Support torch.script function. def _forward_impl(self, x: Tensor) -> Tensor: out1 = self.conv_block1(x) out = self.trunk(out1) out2 = self.conv_block2(out) out = torch.add(out1, out2) out = self.upsampling(out) out = self.conv_block3(out) return out def _forward_w_dop_impl(self, x: Tensor, dop) -> Tensor: out1 = self.conv_block1(x) out = self.trunk(out1) out2 = F.dropout2d(self.conv_block2(out), p=dop) out = torch.add(out1, out2) out = self.upsampling(out) out = self.conv_block3(out) return out def _initialize_weights(self) -> None: for module in self.modules(): if isinstance(module, nn.Conv2d): nn.init.kaiming_normal_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) elif isinstance(module, nn.BatchNorm2d): nn.init.constant_(module.weight, 1) #### BayesCap class BayesCap(nn.Module): def __init__(self, in_channels=3, out_channels=3) -> None: super(BayesCap, self).__init__() # First conv layer. self.conv_block1 = nn.Sequential( nn.Conv2d( in_channels, 64, kernel_size=9, stride=1, padding=4 ), nn.PReLU(), ) # Features trunk blocks. trunk = [] for _ in range(16): trunk.append(ResidualConvBlock(64)) self.trunk = nn.Sequential(*trunk) # Second conv layer. self.conv_block2 = nn.Sequential( nn.Conv2d( 64, 64, kernel_size=3, stride=1, padding=1, bias=False ), nn.BatchNorm2d(64), ) # Output layer. self.conv_block3_mu = nn.Conv2d( 64, out_channels=out_channels, kernel_size=9, stride=1, padding=4 ) self.conv_block3_alpha = nn.Sequential( nn.Conv2d( 64, 64, kernel_size=9, stride=1, padding=4 ), nn.PReLU(), nn.Conv2d( 64, 64, kernel_size=9, stride=1, padding=4 ), nn.PReLU(), nn.Conv2d( 64, 1, kernel_size=9, stride=1, padding=4 ), nn.ReLU(), ) self.conv_block3_beta = nn.Sequential( nn.Conv2d( 64, 64, kernel_size=9, stride=1, padding=4 ), nn.PReLU(), nn.Conv2d( 64, 64, kernel_size=9, stride=1, padding=4 ), nn.PReLU(), nn.Conv2d( 64, 1, kernel_size=9, stride=1, padding=4 ), nn.ReLU(), ) # Initialize neural network weights. self._initialize_weights() def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) # Support torch.script function. def _forward_impl(self, x: Tensor) -> Tensor: out1 = self.conv_block1(x) out = self.trunk(out1) out2 = self.conv_block2(out) out = out1 + out2 out_mu = self.conv_block3_mu(out) out_alpha = self.conv_block3_alpha(out) out_beta = self.conv_block3_beta(out) return out_mu, out_alpha, out_beta def _initialize_weights(self) -> None: for module in self.modules(): if isinstance(module, nn.Conv2d): nn.init.kaiming_normal_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) elif isinstance(module, nn.BatchNorm2d): nn.init.constant_(module.weight, 1) class BayesCap_noID(nn.Module): def __init__(self, in_channels=3, out_channels=3) -> None: super(BayesCap_noID, self).__init__() # First conv layer. self.conv_block1 = nn.Sequential( nn.Conv2d( in_channels, 64, kernel_size=9, stride=1, padding=4 ), nn.PReLU(), ) # Features trunk blocks. trunk = [] for _ in range(16): trunk.append(ResidualConvBlock(64)) self.trunk = nn.Sequential(*trunk) # Second conv layer. self.conv_block2 = nn.Sequential( nn.Conv2d( 64, 64, kernel_size=3, stride=1, padding=1, bias=False ), nn.BatchNorm2d(64), ) # Output layer. # self.conv_block3_mu = nn.Conv2d( # 64, out_channels=out_channels, # kernel_size=9, stride=1, padding=4 # ) self.conv_block3_alpha = nn.Sequential( nn.Conv2d( 64, 64, kernel_size=9, stride=1, padding=4 ), nn.PReLU(), nn.Conv2d( 64, 64, kernel_size=9, stride=1, padding=4 ), nn.PReLU(), nn.Conv2d( 64, 1, kernel_size=9, stride=1, padding=4 ), nn.ReLU(), ) self.conv_block3_beta = nn.Sequential( nn.Conv2d( 64, 64, kernel_size=9, stride=1, padding=4 ), nn.PReLU(), nn.Conv2d( 64, 64, kernel_size=9, stride=1, padding=4 ), nn.PReLU(), nn.Conv2d( 64, 1, kernel_size=9, stride=1, padding=4 ), nn.ReLU(), ) # Initialize neural network weights. self._initialize_weights() def forward(self, x: Tensor) -> Tensor: return self._forward_impl(x) # Support torch.script function. def _forward_impl(self, x: Tensor) -> Tensor: out1 = self.conv_block1(x) out = self.trunk(out1) out2 = self.conv_block2(out) out = out1 + out2 # out_mu = self.conv_block3_mu(out) out_alpha = self.conv_block3_alpha(out) out_beta = self.conv_block3_beta(out) return out_alpha, out_beta def _initialize_weights(self) -> None: for module in self.modules(): if isinstance(module, nn.Conv2d): nn.init.kaiming_normal_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) elif isinstance(module, nn.BatchNorm2d): nn.init.constant_(module.weight, 1)