# test_space / model.py
# Uploaded by Rounak28 ("Upload 13 files", commit b3b6dd4).
import torch
import torch.nn as nn
import torchvision.transforms as T
class ConvBlock(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages.

    Each convolution uses kernel 3, stride 1 and padding 1, so H and W are
    preserved; only the channel count changes (``in_channels`` ->
    ``out_channels`` in the first stage, then ``out_channels`` ->
    ``out_channels`` in the second).

    Args:
        in_channels: number of channels of the incoming tensor.
        out_channels: number of channels produced by both stages.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        stages = []
        # First stage maps in->out channels, second stage keeps out channels.
        for stage_in in (in_channels, out_channels):
            stages.extend(
                [
                    nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                ]
            )
        # Kept as a single Sequential named `block` so parameter names
        # (block.0.weight, block.1.running_mean, ...) stay stable.
        self.block = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor):
        """Run ``x`` through both conv/BN/ReLU stages."""
        return self.block(x)
class CopyAndCrop(nn.Module):
    """U-Net skip connection: center-crop the encoder map and concatenate.

    The skip tensor ``encoded`` is center-cropped to the spatial size of the
    decoder tensor ``x`` and the two are concatenated along dim 1, decoder
    channels first. (Previously ``x`` was cropped and concatenated with
    itself, so the encoder features were silently discarded.)
    """

    def forward(self, x: torch.Tensor, encoded: torch.Tensor) -> torch.Tensor:
        """Return ``torch.cat((x, center_crop(encoded)), 1)``.

        Args:
            x: upsampled decoder tensor, indexed as (N, C, H, W).
            encoded: skip tensor from the matching encoder level, indexed as
                (N, C', H', W') with H' >= H and W' >= W.

        Returns:
            Tensor of shape (N, C + C', H, W).

        Raises:
            ValueError: if ``encoded`` is spatially smaller than ``x``.
        """
        _, _, h, w = x.shape
        _, _, enc_h, enc_w = encoded.shape
        if enc_h < h or enc_w < w:
            raise ValueError(
                f"skip tensor {tuple(encoded.shape)} is smaller than "
                f"decoder tensor {tuple(x.shape)}; cannot center-crop"
            )
        # Max-pooling floors odd sizes, so after upsampling the encoder map can
        # be larger than x; trim it symmetrically (extra pixel on bottom/right).
        top = (enc_h - h) // 2
        left = (enc_w - w) // 2
        crop = encoded[:, :, top:top + h, left:left + w]
        return torch.cat((x, crop), 1)
class UNet(nn.Module):
    """Encoder/decoder U-Net built from ``ConvBlock`` stages.

    Each encoder level applies a ``ConvBlock`` and then a 2x2 max-pool. A
    bottleneck ``ConvBlock`` doubles the deepest width. Each decoder level
    upsamples with a stride-2 transposed convolution, merges the matching
    encoder output via ``CopyAndCrop``, and applies a ``ConvBlock``. A final
    1x1 convolution maps to ``out_channels``.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map.
        features: per-level encoder widths, shallowest first. The default
            ``(64, 128, 256, 512)`` reproduces the original architecture
            (same module names and parameter shapes, so existing checkpoints
            still load).

    Raises:
        ValueError: if ``features`` is empty.
    """

    def __init__(self, in_channels: int, out_channels: int, *, features=(64, 128, 256, 512)):
        super().__init__()
        features = tuple(features)
        if not features:
            raise ValueError("features must contain at least one channel width")
        reversed_features = features[::-1]

        # Encoder: in_channels -> features[0] -> features[1] -> ...
        self.encoders = nn.ModuleList([
            ConvBlock(c_in, c_out)
            for c_in, c_out in zip((in_channels,) + features[:-1], features)
        ])
        self.down_sample = nn.MaxPool2d(2)
        self.copyAndCrop = CopyAndCrop()
        # Decoder blocks take 2*f channels: f upsampled + f from the skip,
        # assuming CopyAndCrop concatenates two f-channel tensors.
        self.decoders = nn.ModuleList([ConvBlock(2 * f, f) for f in reversed_features])
        # Each up-sampler consumes the previous stage's output: the bottleneck
        # (2 * deepest width) first, then each preceding decoder's width.
        up_in = (2 * features[-1],) + reversed_features[:-1]
        self.up_samples = nn.ModuleList([
            nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2)
            for c_in, c_out in zip(up_in, reversed_features)
        ])
        self.bottleneck = ConvBlock(features[-1], 2 * features[-1])
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1, stride=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the full encode -> bottleneck -> decode pass.

        Each encoder level halves H and W (max-pool, floor) and each decoder
        level doubles them, so when H and W are divisible by
        ``2 ** len(self.encoders)`` the output has the input's spatial size.
        """
        # encode, remembering each level's output for the skip connections
        skips = []
        for encoder in self.encoders:
            x = encoder(x)
            skips.append(x)
            x = self.down_sample(x)
        x = self.bottleneck(x)
        # decode, consuming skips deepest-first
        for up_sample, decoder, skip in zip(self.up_samples, self.decoders, reversed(skips)):
            x = up_sample(x)
            x = self.copyAndCrop(x, skip)
            x = decoder(x)
        return self.final_conv(x)