import torch import torch.nn as nn import torchvision.transforms as T class ConvBlock(nn.Module): def __init__(self, in_channels: int, out_channels: int): super(ConvBlock, self).__init__() self.block = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(inplace=True) ) def forward(self, x: torch.Tensor): return self.block(x) class CopyAndCrop(nn.Module): def forward(self, x: torch.Tensor, encoded: torch.Tensor): _, _, h, w = encoded.shape crop = T.CenterCrop((h, w))(x) output = torch.cat((x, crop), 1) return output class UNet(nn.Module): def __init__(self, in_channels: int, out_channels: int): super(UNet, self).__init__() self.encoders = nn.ModuleList([ ConvBlock(in_channels, 64), ConvBlock(64, 128), ConvBlock(128, 256), ConvBlock(256, 512), ]) self.down_sample = nn.MaxPool2d(2) self.copyAndCrop = CopyAndCrop() self.decoders = nn.ModuleList([ ConvBlock(1024, 512), ConvBlock(512, 256), ConvBlock(256, 128), ConvBlock(128, 64), ]) self.up_samples = nn.ModuleList([ nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2), nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2), nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2), nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2) ]) self.bottleneck = ConvBlock(512, 1024) self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1, stride=1) def forward(self, x: torch.Tensor): # encode encoded_features = [] for enc in self.encoders: x = enc(x) encoded_features.append(x) x = self.down_sample(x) x = self.bottleneck(x) # decode for idx, denc in enumerate(self.decoders): x = self.up_samples[idx](x) encoded = encoded_features.pop() x = self.copyAndCrop(x, encoded) x = denc(x) output = self.final_conv(x) return output