import torch |
import torch.nn as nn |
import torch.nn.functional as F |
from torch import Tensor |
class DropBlock(nn.Module):
    """DropBlock regularisation for 4-D feature maps (Ghiasi et al., 2018).

    In training mode, square ``block_size x block_size`` regions of each
    channel are zeroed at random. The surviving activations are rescaled by
    ``total_units / kept_units``. In eval mode the input is returned unchanged.

    Args:
        block_size: side length of each dropped square; any value >= 1 (odd
            or even) is supported.
        p: *keep* probability. ``calculate_gamma`` uses ``1 - p`` as the
            target drop rate, so the default of ``0.1`` aims to drop roughly
            90% of activations.
    """

    def __init__(self, block_size: int = 5, p: float = 0.1):
        super().__init__()
        self.block_size = block_size
        self.p = p

    def calculate_gamma(self, x: Tensor) -> float:
        """Return the Bernoulli rate used to sample block seeds for ``x``.

        gamma = (1 - p) / bs**2 * (H * W) / ((H - bs + 1) * (W - bs + 1))

        The second factor compensates for seeds only being sampled at
        positions where a whole block fits inside the H x W map. Both spatial
        dimensions are used, so non-square maps get the correct rate. Returns
        0.0 when a block does not fit in the map at all.
        """
        height, width = x.shape[-2], x.shape[-1]
        bs = self.block_size
        if height < bs or width < bs:
            return 0.0
        invalid = (1 - self.p) / (bs ** 2)
        valid = (height * width) / ((height - bs + 1) * (width - bs + 1))
        return invalid * valid

    def forward(self, x: Tensor) -> Tensor:
        """Apply DropBlock to an (N, C, H, W) tensor; identity in eval mode."""
        N, C, H, W = x.size()
        bs = self.block_size
        # Nothing to drop at inference time, or when no block fits in the map.
        if not self.training or H < bs or W < bs:
            return x
        gamma = self.calculate_gamma(x)
        # One seed per position where a full block fits.
        mask_shape = (N, C, H - bs + 1, W - bs + 1)
        mask = torch.bernoulli(torch.full(mask_shape, gamma, device=x.device))
        # Zero-pad by bs-1 on every side, then apply a stride-1 max-pool with
        # kernel bs. Output (i, j) is 1 iff a seed lies in the bs x bs window
        # ending at (i, j), so each seed expands to the block it anchors. The
        # result is exactly H x W for odd AND even block sizes.
        mask = F.pad(mask, [bs - 1] * 4, value=0)
        block_mask = F.max_pool2d(mask, kernel_size=(bs, bs), stride=(1, 1))
        keep = 1 - block_mask
        # Rescale survivors. The clamp avoids 0/0 -> NaN when every unit is dropped.
        scale = keep.numel() / keep.sum().clamp(min=1.0)
        return keep * x * scale
class double_conv(nn.Module):
    """Two stacked (conv3x3 -> DropBlock -> BatchNorm -> ReLU) stages.

    Each 3x3 convolution uses padding 1 and stride 1, so H and W are
    preserved. The first convolution maps ``intc`` -> ``outc`` channels and
    the second maps ``outc`` -> ``outc``.

    Args:
        intc: number of input channels.
        outc: number of output channels of both convolutions.
        block_size: DropBlock block size (default 7, unchanged from before).
        keep_prob: value passed as DropBlock's ``p``. DropBlock derives its
            drop rate from ``1 - p``, so this is a keep probability
            (default 0.9, unchanged from before).
    """

    def __init__(self, intc, outc, block_size: int = 7, keep_prob: float = 0.9):
        super().__init__()
        self.conv1 = nn.Conv2d(intc, outc, kernel_size=3, padding=1, stride=1)
        self.drop1 = DropBlock(block_size, keep_prob)
        self.bn1 = nn.BatchNorm2d(outc)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(outc, outc, kernel_size=3, padding=1, stride=1)
        self.drop2 = DropBlock(block_size, keep_prob)
        self.bn2 = nn.BatchNorm2d(outc)
        self.relu2 = nn.ReLU()

    def forward(self, input):
        """Run both conv stages; the output has ``outc`` channels."""
        x = self.relu1(self.bn1(self.drop1(self.conv1(input))))
        return self.relu2(self.bn2(self.drop2(self.conv2(x))))
class upconv(nn.Module):
    """Learned 2x spatial upsampling with a transposed convolution.

    The kernel is 2 with stride 2 and no padding, so H and W are doubled
    while channels go from ``intc`` to ``outc``.
    """

    def __init__(self, intc, outc) -> None:
        super().__init__()
        self.up = nn.ConvTranspose2d(
            in_channels=intc,
            out_channels=outc,
            kernel_size=2,
            stride=2,
            padding=0,
        )

    def forward(self, input):
        return self.up(input)
class unet(nn.Module):
    """Four-level U-Net built from ``double_conv`` and ``upconv`` blocks.

    Encoder channels are 64 -> 128 -> 256 -> 512 with 2x2 max-pooling between
    levels, and the bridge has 1024 channels. The decoder mirrors the encoder
    and concatenates each upsampled map with the matching encoder output.
    ``forward`` returns the raw output of the final 1x1 convolution; no
    activation is applied.
    """

    def __init__(self, int, out) -> None:
        """Build the network.

        Args:
            int: number of input image channels. The name shadows the builtin
                inside this method; it is kept for keyword-argument callers.
            out: number of channels produced by the final 1x1 convolution.
        """
        super().__init__()
        # encoder
        self.convlayer1 = double_conv(int, 64)
        self.down1 = nn.MaxPool2d((2, 2))
        self.convlayer2 = double_conv(64, 128)
        self.down2 = nn.MaxPool2d((2, 2))
        self.convlayer3 = double_conv(128, 256)
        self.down3 = nn.MaxPool2d((2, 2))
        self.convlayer4 = double_conv(256, 512)
        self.down4 = nn.MaxPool2d((2, 2))
        # bridge
        self.bridge = double_conv(512, 1024)
        # decoder: each convlayer receives upsampled + skip channels concatenated
        self.up1 = upconv(1024, 512)
        self.convlayer5 = double_conv(1024, 512)
        self.up2 = upconv(512, 256)
        self.convlayer6 = double_conv(512, 256)
        self.up3 = upconv(256, 128)
        self.convlayer7 = double_conv(256, 128)
        self.up4 = upconv(128, 64)
        self.convlayer8 = double_conv(128, 64)
        # output
        self.outputs = nn.Conv2d(64, out, kernel_size=1, padding=0)
        # NOTE: not applied in forward(), which returns un-activated outputs;
        # kept so the attribute remains available for backward compatibility.
        self.sig = nn.Sigmoid()

    @staticmethod
    def _match_size(x: Tensor, ref: Tensor) -> Tensor:
        """Zero-pad ``x`` on H/W so its spatial size equals ``ref``'s.

        Max-pooling floors odd sizes, so a 2x-upsampled map can be one pixel
        smaller than its skip connection. This is a no-op when sizes already
        match.
        """
        dh = ref.shape[-2] - x.shape[-2]
        dw = ref.shape[-1] - x.shape[-1]
        if dh == 0 and dw == 0:
            return x
        return F.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

    def _decode(self, x: Tensor, skip: Tensor, up: nn.Module, conv: nn.Module) -> Tensor:
        """Upsample ``x``, concatenate it with ``skip`` on channels, then apply ``conv``."""
        x = self._match_size(up(x), skip)
        return conv(torch.cat([x, skip], dim=1))

    def forward(self, input):
        """Run the U-Net; the output has the input's H and W and ``out`` channels."""
        # encoder
        l1 = self.convlayer1(input)
        l2 = self.convlayer2(self.down1(l1))
        l3 = self.convlayer3(self.down2(l2))
        l4 = self.convlayer4(self.down3(l3))
        # bridge
        bridge = self.bridge(self.down4(l4))
        # decoder
        l5 = self._decode(bridge, l4, self.up1, self.convlayer5)
        l6 = self._decode(l5, l3, self.up2, self.convlayer6)
        l7 = self._decode(l6, l2, self.up3, self.convlayer7)
        l8 = self._decode(l7, l1, self.up4, self.convlayer8)
        return self.outputs(l8)
class spatialAttention(nn.Module):
    """Spatial attention gate in the style of CBAM.

    The per-pixel maximum and mean over the channel axis are stacked into a
    2-channel descriptor and passed through a 7x7 convolution (padding 3
    keeps H and W). A sigmoid turns the result into a 1-channel weight map,
    which rescales every channel of the input.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv77 = nn.Conv2d(2, 1, kernel_size=7, padding=3)
        self.sig = nn.Sigmoid()

    def forward(self, input):
        max_map = torch.max(input, dim=1, keepdim=True).values
        mean_map = torch.mean(input, dim=1, keepdim=True)
        descriptor = torch.cat([max_map, mean_map], dim=1)
        weights = self.sig(self.conv77(descriptor))
        return input * weights
class SAunet(nn.Module):
    """U-Net variant with spatial attention in the bridge.

    The encoder and decoder match ``unet`` (``double_conv``/``upconv`` blocks,
    channels 64 -> 128 -> 256 -> 512). The bridge is two conv3x3 -> BatchNorm
    -> ReLU stages (512 -> 1024 -> 1024, no DropBlock) with a
    ``spatialAttention`` module between them. ``forward`` returns the raw
    output of the final 1x1 convolution; no activation is applied.
    """

    def __init__(self, int, out) -> None:
        """Build the network.

        Args:
            int: number of input image channels. The name shadows the builtin
                inside this method; it is kept for keyword-argument callers.
            out: number of channels produced by the final 1x1 convolution.
        """
        super().__init__()
        # encoder
        self.convlayer1 = double_conv(int, 64)
        self.down1 = nn.MaxPool2d((2, 2))
        self.convlayer2 = double_conv(64, 128)
        self.down2 = nn.MaxPool2d((2, 2))
        self.convlayer3 = double_conv(128, 256)
        self.down3 = nn.MaxPool2d((2, 2))
        self.convlayer4 = double_conv(256, 512)
        self.down4 = nn.MaxPool2d((2, 2))
        # bridge: conv-BN-ReLU -> spatial attention -> conv-BN-ReLU
        self.attmodule = spatialAttention()
        self.bridge1 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(1024)
        self.act1 = nn.ReLU()
        self.bridge2 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(1024)
        self.act2 = nn.ReLU()
        # decoder: each convlayer receives upsampled + skip channels concatenated
        self.up1 = upconv(1024, 512)
        self.convlayer5 = double_conv(1024, 512)
        self.up2 = upconv(512, 256)
        self.convlayer6 = double_conv(512, 256)
        self.up3 = upconv(256, 128)
        self.convlayer7 = double_conv(256, 128)
        self.up4 = upconv(128, 64)
        self.convlayer8 = double_conv(128, 64)
        # output
        self.outputs = nn.Conv2d(64, out, kernel_size=1, padding=0)
        # NOTE: not applied in forward(), which returns un-activated outputs;
        # kept so the attribute remains available for backward compatibility.
        self.sig = nn.Sigmoid()

    @staticmethod
    def _match_size(x: Tensor, ref: Tensor) -> Tensor:
        """Zero-pad ``x`` on H/W so its spatial size equals ``ref``'s.

        Max-pooling floors odd sizes, so a 2x-upsampled map can be one pixel
        smaller than its skip connection. This is a no-op when sizes already
        match.
        """
        dh = ref.shape[-2] - x.shape[-2]
        dw = ref.shape[-1] - x.shape[-1]
        if dh == 0 and dw == 0:
            return x
        return F.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

    def _decode(self, x: Tensor, skip: Tensor, up: nn.Module, conv: nn.Module) -> Tensor:
        """Upsample ``x``, concatenate it with ``skip`` on channels, then apply ``conv``."""
        x = self._match_size(up(x), skip)
        return conv(torch.cat([x, skip], dim=1))

    def forward(self, input):
        """Run the network; the output has the input's H and W and ``out`` channels."""
        # encoder
        l1 = self.convlayer1(input)
        l2 = self.convlayer2(self.down1(l1))
        l3 = self.convlayer3(self.down2(l2))
        l4 = self.convlayer4(self.down3(l3))
        d4 = self.down4(l4)
        # bridge with spatial attention between the two conv stages
        b1 = self.act1(self.bn1(self.bridge1(d4)))
        att = self.attmodule(b1)
        b2 = self.act2(self.bn2(self.bridge2(att)))
        # decoder
        l5 = self._decode(b2, l4, self.up1, self.convlayer5)
        l6 = self._decode(l5, l3, self.up2, self.convlayer6)
        l7 = self._decode(l6, l2, self.up3, self.convlayer7)
        l8 = self._decode(l7, l1, self.up4, self.convlayer8)
        return self.outputs(l8)