|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
import torch.nn.functional as F |
|
from torchsummary import summary |
|
from torch.utils.data import TensorDataset, DataLoader |
|
|
|
class recon_encoder(nn.Module):
    """Convolutional encoder: 1-channel 256x256 image -> latent vector.

    Three conv stages (two Conv2d+Dropout+ReLU each, followed by MaxPool),
    then a Flatten + Linear bottleneck down to ``latent_size``.

    NOTE(review): ``self.decoder1`` is built but never used by ``forward``.
    It is kept so existing checkpoints/state_dicts still load, but its
    parameters are dead weight for a pure encoder — confirm before removing.
    """

    def __init__(self, latent_size, nconv=16, pool=4, drop=0.05):
        """
        Args:
            latent_size: dimensionality of the latent vector.
            nconv: base channel count; doubled at each encoder stage.
            pool: pooling (and upsampling) factor per stage.
            drop: dropout probability applied after every conv.
        """
        super(recon_encoder, self).__init__()

        self.encoder = nn.Sequential(
            # Stage 1: 1 -> nconv channels, spatial /pool
            nn.Conv2d(in_channels=1, out_channels=nconv, kernel_size=3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.Conv2d(nconv, nconv, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.MaxPool2d((pool, pool)),

            # Stage 2: nconv -> nconv*2 channels, spatial /pool
            nn.Conv2d(nconv, nconv * 2, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.Conv2d(nconv * 2, nconv * 2, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.MaxPool2d((pool, pool)),

            # Stage 3: nconv*2 -> nconv*4 channels, spatial /pool
            nn.Conv2d(nconv * 2, nconv * 4, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.Conv2d(nconv * 4, nconv * 4, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.MaxPool2d((pool, pool)),
        )

        # Bottleneck input size computed from the architecture instead of the
        # previous hard-coded 1024 (only valid for nconv=16, pool=4). Assumes
        # a 256x256 input image, matching calc_fc_shape().
        bneck_side = 256 // pool ** 3                      # spatial side after three pools
        bneck_feats = nconv * 4 * bneck_side * bneck_side  # 1024 at defaults

        self.bottleneck = nn.Sequential(
            nn.Flatten(),
            nn.Linear(bneck_feats, latent_size),
            nn.ReLU(),
        )

        # Unused by forward() — retained for checkpoint compatibility (see class docstring).
        self.decoder1 = nn.Sequential(
            nn.Conv2d(nconv * 4, nconv * 4, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.Conv2d(nconv * 4, nconv * 4, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.Upsample(scale_factor=pool, mode='bilinear'),

            nn.Conv2d(nconv * 4, nconv * 4, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.Conv2d(nconv * 4, nconv * 4, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.Upsample(scale_factor=pool, mode='bilinear'),

            nn.Conv2d(nconv * 4, nconv * 2, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.Conv2d(nconv * 2, nconv * 2, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.Upsample(scale_factor=pool, mode='bilinear'),

            nn.Conv2d(nconv * 2, 1, 3, stride=1, padding=(1, 1)),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Encode ``x`` of shape (N, 1, 256, 256) into an (N, latent_size) tensor."""
        # Mixed-precision forward; torch.cuda.amp.autocast degrades to a no-op
        # (with a warning) on CPU-only machines.
        with torch.cuda.amp.autocast():
            x1 = self.encoder(x)
            x1 = self.bottleneck(x1)
        return x1

    def calc_fc_shape(self):
        """Probe the encoder with a dummy 256x256 image and record its output shape.

        Returns:
            int: number of features per sample after the encoder (flattened).
        """
        # Conv2d requires a 4-D (N, C, H, W) input on torch < 1.12; the old
        # single unsqueeze produced a 3-D tensor that crashed there.
        x0 = torch.zeros([256, 256]).unsqueeze(0).unsqueeze(0)
        x0 = self.encoder(x0)

        # Drop the batch dim so the stored per-sample shape matches prior behavior.
        # (Attribute name typo "bock" kept for backward compatibility.)
        self.conv_bock_output_shape = x0.shape[1:]

        # Batch size is 1, so total numel == per-sample flattened size.
        self.flattened_size = x0.numel()

        return self.flattened_size
|
|
|
class recon_model(nn.Module):
    """Convolutional autoencoder for 1-channel 256x256 images.

    Encoder: three conv stages (two Conv2d+Dropout+ReLU each, then MaxPool).
    Bottleneck: Flatten -> Linear -> ReLU -> Linear -> ReLU -> Unflatten.
    Decoder: three conv stages with bilinear upsampling, ending in a
    Sigmoid so outputs lie in [0, 1].
    """

    def __init__(self, latent_size, nconv=16, pool=4, drop=0.05):
        """
        Args:
            latent_size: dimensionality of the latent vector.
            nconv: base channel count; doubled at each encoder stage.
            pool: pooling (and upsampling) factor per stage.
            drop: dropout probability applied after every conv.
        """
        super(recon_model, self).__init__()

        self.encoder = nn.Sequential(
            # Stage 1: 1 -> nconv channels, spatial /pool
            nn.Conv2d(in_channels=1, out_channels=nconv, kernel_size=3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.Conv2d(nconv, nconv, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.MaxPool2d((pool, pool)),

            # Stage 2: nconv -> nconv*2 channels, spatial /pool
            nn.Conv2d(nconv, nconv * 2, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.Conv2d(nconv * 2, nconv * 2, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.MaxPool2d((pool, pool)),

            # Stage 3: nconv*2 -> nconv*4 channels, spatial /pool
            nn.Conv2d(nconv * 2, nconv * 4, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.Conv2d(nconv * 4, nconv * 4, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.MaxPool2d((pool, pool)),
        )

        # Bottleneck sizes computed from the architecture instead of the
        # previous hard-coded 1024 / (64, 4, 4), which were only valid for
        # nconv=16, pool=4 and crashed for any other configuration. Assumes
        # a 256x256 input image, matching calc_fc_shape().
        bneck_side = 256 // pool ** 3                      # spatial side after three pools
        bneck_ch = nconv * 4                               # channels out of the encoder
        bneck_feats = bneck_ch * bneck_side * bneck_side   # 1024 at defaults

        self.bottleneck = nn.Sequential(
            nn.Flatten(),
            nn.Linear(bneck_feats, latent_size),
            nn.ReLU(),
            nn.Linear(latent_size, bneck_feats),
            nn.ReLU(),
            nn.Unflatten(1, (bneck_ch, bneck_side, bneck_side))
        )

        self.decoder1 = nn.Sequential(
            # Stage 1: keep nconv*4 channels, spatial *pool
            nn.Conv2d(nconv * 4, nconv * 4, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.Conv2d(nconv * 4, nconv * 4, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.Upsample(scale_factor=pool, mode='bilinear'),

            # Stage 2: keep nconv*4 channels, spatial *pool
            nn.Conv2d(nconv * 4, nconv * 4, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.Conv2d(nconv * 4, nconv * 4, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.Upsample(scale_factor=pool, mode='bilinear'),

            # Stage 3: nconv*4 -> nconv*2 channels, spatial *pool
            nn.Conv2d(nconv * 4, nconv * 2, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.Conv2d(nconv * 2, nconv * 2, 3, stride=1, padding=(1, 1)),
            nn.Dropout(drop),
            nn.ReLU(),
            nn.Upsample(scale_factor=pool, mode='bilinear'),

            # Project back to a single channel in [0, 1]
            nn.Conv2d(nconv * 2, 1, 3, stride=1, padding=(1, 1)),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Reconstruct ``x`` of shape (N, 1, 256, 256); output has the same shape."""
        # Mixed-precision forward; torch.cuda.amp.autocast degrades to a no-op
        # (with a warning) on CPU-only machines.
        with torch.cuda.amp.autocast():
            x1 = self.encoder(x)
            x1 = self.bottleneck(x1)
            return self.decoder1(x1)

    def calc_fc_shape(self):
        """Probe the encoder with a dummy 256x256 image and record its output shape.

        Returns:
            int: number of features per sample after the encoder (flattened).
        """
        # Conv2d requires a 4-D (N, C, H, W) input on torch < 1.12; the old
        # single unsqueeze produced a 3-D tensor that crashed there.
        x0 = torch.zeros([256, 256]).unsqueeze(0).unsqueeze(0)
        x0 = self.encoder(x0)

        # Drop the batch dim so the stored per-sample shape matches prior behavior.
        # (Attribute name typo "bock" kept for backward compatibility.)
        self.conv_bock_output_shape = x0.shape[1:]

        # Batch size is 1, so total numel == per-sample flattened size.
        self.flattened_size = x0.numel()

        return self.flattened_size