import torch
import torch.nn as nn
import torchvision.transforms as T
class ConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by batch norm and ReLU."""

    def __init__(self, in_channels: int, out_channels: int):
        super(ConvBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor):
        return self.block(x)
class CopyAndCrop(nn.Module):
    """U-Net skip connection: crop the encoder feature map to the decoder's
    spatial size, then concatenate the two along the channel dimension."""

    def forward(self, x: torch.Tensor, encoded: torch.Tensor):
        _, _, h, w = x.shape
        crop = T.CenterCrop((h, w))(encoded)
        output = torch.cat((x, crop), dim=1)
        return output
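# Illustrative sketch (not part of the original module): the skip connection
# concatenates channels, so a 512-channel decoder map and a 512-channel encoder
# map of the same spatial size yield a 1024-channel input for the next ConvBlock:
#
#   dec = torch.randn(1, 512, 28, 28)    # upsampled decoder feature
#   enc = torch.randn(1, 512, 28, 28)    # stored encoder feature
#   CopyAndCrop()(dec, enc).shape        # torch.Size([1, 1024, 28, 28])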
class UNet(nn.Module):
    """4-level U-Net: ConvBlock encoders with max-pool downsampling, a bottleneck,
    and ConvBlock decoders with transposed-convolution upsampling joined to the
    encoder features via CopyAndCrop skip connections."""

    def __init__(self, in_channels: int, out_channels: int):
        super(UNet, self).__init__()
        self.encoders = nn.ModuleList([
            ConvBlock(in_channels, 64),
            ConvBlock(64, 128),
            ConvBlock(128, 256),
            ConvBlock(256, 512),
        ])
        self.down_sample = nn.MaxPool2d(2)
        self.copyAndCrop = CopyAndCrop()
        self.decoders = nn.ModuleList([
            ConvBlock(1024, 512),
            ConvBlock(512, 256),
            ConvBlock(256, 128),
            ConvBlock(128, 64),
        ])
        self.up_samples = nn.ModuleList([
            nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2),
            nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
        ])
        self.bottleneck = ConvBlock(512, 1024)
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1, stride=1)
    def forward(self, x: torch.Tensor):
        # encode: keep each block's output for the skip connections, then downsample
        encoded_features = []
        for enc in self.encoders:
            x = enc(x)
            encoded_features.append(x)
            x = self.down_sample(x)

        x = self.bottleneck(x)

        # decode: upsample, merge with the matching encoder feature, then convolve
        for idx, decoder in enumerate(self.decoders):
            x = self.up_samples[idx](x)
            encoded = encoded_features.pop()
            x = self.copyAndCrop(x, encoded)
            x = decoder(x)

        output = self.final_conv(x)
        return output
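

# Minimal smoke test (a sketch, not part of the original module): the 3-channel
# 256x256 input and single output channel are assumptions chosen so that every
# downsampling step divides evenly and the skip connections need no cropping.
if __name__ == "__main__":
    model = UNet(in_channels=3, out_channels=1)
    dummy = torch.randn(1, 3, 256, 256)
    out = model(dummy)
    print(out.shape)  # expected: torch.Size([1, 1, 256, 256])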