import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class DropBlock(nn.Module):
    """DropBlock regularization: zeroes contiguous block_size x block_size regions.

    `p` is used as the keep probability (the drop rate is 1 - p); the call sites
    below pass p=0.9.
    """

    def __init__(self, block_size: int = 5, p: float = 0.1):
        super().__init__()
        self.block_size = block_size
        self.p = p

    def calculate_gamma(self, x: Tensor) -> float:
        """Bernoulli rate for seeding block centres, following the DropBlock paper's formula."""
        invalid = (1 - self.p) / (self.block_size ** 2)
        valid = (x.shape[-1] ** 2) / ((x.shape[-1] - self.block_size + 1) ** 2)
        return invalid * valid

    def forward(self, x: Tensor) -> Tensor:
        N, C, H, W = x.size()
        if self.training:
            gamma = self.calculate_gamma(x)
            # Sample block centres only where a full block fits, then zero-pad back to H x W.
            mask_shape = (N, C, H - self.block_size + 1, W - self.block_size + 1)
            mask = torch.bernoulli(torch.full(mask_shape, gamma, device=x.device))
            mask = F.pad(mask, [self.block_size // 2] * 4, value=0)
            # Expand each centre into a full block via max-pooling, then invert to get the keep mask.
            mask_block = 1 - F.max_pool2d(
                mask,
                kernel_size=(self.block_size, self.block_size),
                stride=(1, 1),
                padding=(self.block_size // 2, self.block_size // 2),
            )
            # Rescale kept activations so the expected activation magnitude is preserved.
            x = mask_block * x * (mask_block.numel() / mask_block.sum())
        return x


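# A minimal sanity check for DropBlock (an illustrative sketch, not part of the
# original module): in training mode whole blocks of activations are zeroed, while
# in eval mode the input passes through unchanged.
if __name__ == "__main__":
    _db = DropBlock(block_size=5, p=0.9)
    _t = torch.randn(1, 1, 32, 32)
    _db.train()
    print("train-mode zeroed elements:", (_db(_t) == 0).sum().item())
    _db.eval()
    print("eval-mode unchanged:", torch.equal(_db(_t), _t))

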
class double_conv(nn.Module):
    """Two 3x3 conv blocks, each followed by DropBlock, batch norm and ReLU."""

    def __init__(self, intc, outc):
        super().__init__()
        self.conv1 = nn.Conv2d(intc, outc, kernel_size=3, padding=1, stride=1)
        self.drop1 = DropBlock(7, 0.9)
        self.bn1 = nn.BatchNorm2d(outc)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(outc, outc, kernel_size=3, padding=1, stride=1)
        self.drop2 = DropBlock(7, 0.9)
        self.bn2 = nn.BatchNorm2d(outc)
        self.relu2 = nn.ReLU()

    def forward(self, input):
        x = self.relu1(self.bn1(self.drop1(self.conv1(input))))
        x = self.relu2(self.bn2(self.drop2(self.conv2(x))))
        return x


class upconv(nn.Module):
    """2x2 transposed convolution that doubles the spatial resolution."""

    def __init__(self, intc, outc) -> None:
        super().__init__()
        self.up = nn.ConvTranspose2d(intc, outc, kernel_size=2, stride=2, padding=0)

    def forward(self, input):
        x = self.up(input)
        return x


class unet(nn.Module):
    def __init__(self, in_c, out) -> None:
        """in_c: number of input image channels.
        out: number of output channels."""
        super().__init__()
        # encoder
        self.convlayer1 = double_conv(in_c, 64)
        self.down1 = nn.MaxPool2d((2, 2))
        self.convlayer2 = double_conv(64, 128)
        self.down2 = nn.MaxPool2d((2, 2))
        self.convlayer3 = double_conv(128, 256)
        self.down3 = nn.MaxPool2d((2, 2))
        self.convlayer4 = double_conv(256, 512)
        self.down4 = nn.MaxPool2d((2, 2))

        # bridge
        self.bridge = double_conv(512, 1024)
        # decoder
        self.up1 = upconv(1024, 512)
        self.convlayer5 = double_conv(1024, 512)
        self.up2 = upconv(512, 256)
        self.convlayer6 = double_conv(512, 256)
        self.up3 = upconv(256, 128)
        self.convlayer7 = double_conv(256, 128)
        self.up4 = upconv(128, 64)
        self.convlayer8 = double_conv(128, 64)
        # output
        self.outputs = nn.Conv2d(64, out, kernel_size=1, padding=0)
        self.sig = nn.Sigmoid()  # defined but not applied in forward; the model returns raw logits

    def forward(self, input):
        # encoder
        l1 = self.convlayer1(input)
        d1 = self.down1(l1)
        l2 = self.convlayer2(d1)
        d2 = self.down2(l2)
        l3 = self.convlayer3(d2)
        d3 = self.down3(l3)
        l4 = self.convlayer4(d3)
        d4 = self.down4(l4)
        # bridge
        bridge = self.bridge(d4)
        # decoder: upsample, concatenate with the matching encoder feature map, then double conv
        up1 = self.up1(bridge)
        up1 = torch.cat([up1, l4], dim=1)
        l5 = self.convlayer5(up1)

        up2 = self.up2(l5)
        up2 = torch.cat([up2, l3], dim=1)
        l6 = self.convlayer6(up2)

        up3 = self.up3(l6)
        up3 = torch.cat([up3, l2], dim=1)
        l7 = self.convlayer7(up3)

        up4 = self.up4(l7)
        up4 = torch.cat([up4, l1], dim=1)
        l8 = self.convlayer8(up4)
        out = self.outputs(l8)
        return out


class spatialAttention(nn.Module):
    """Spatial attention: a 7x7 conv over channel-wise max and mean maps produces a
    sigmoid gate that reweights the input."""

    def __init__(self) -> None:
        super().__init__()
        self.conv77 = nn.Conv2d(2, 1, kernel_size=7, padding=3)
        self.sig = nn.Sigmoid()

    def forward(self, input):
        # Stack the channel-wise max and mean maps into a 2-channel descriptor.
        x = torch.cat((torch.max(input, 1)[0].unsqueeze(1), torch.mean(input, 1).unsqueeze(1)), dim=1)
        # Single-channel attention map in [0, 1], broadcast over the input channels.
        x = self.sig(self.conv77(x))
        x = input * x
        return x


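# A quick shape check for spatialAttention (an illustrative sketch, not part of the
# original module): the single-channel gate is broadcast over channels, so the output
# shape matches the input shape.
if __name__ == "__main__":
    _sa = spatialAttention()
    _f = torch.randn(2, 512, 32, 32)
    print("spatialAttention keeps shape:", _sa(_f).shape == _f.shape)

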
class SAunet(nn.Module):
    def __init__(self, in_c, out) -> None:
        """in_c: number of input image channels.
        out: number of output channels."""
        super().__init__()
        # encoder
        self.convlayer1 = double_conv(in_c, 64)
        self.down1 = nn.MaxPool2d((2, 2))
        self.convlayer2 = double_conv(64, 128)
        self.down2 = nn.MaxPool2d((2, 2))
        self.convlayer3 = double_conv(128, 256)
        self.down3 = nn.MaxPool2d((2, 2))
        self.convlayer4 = double_conv(256, 512)
        self.down4 = nn.MaxPool2d((2, 2))

        # bridge: two conv blocks with the spatial attention module between them
        self.attmodule = spatialAttention()
        self.bridge1 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(1024)
        self.act1 = nn.ReLU()
        self.bridge2 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(1024)
        self.act2 = nn.ReLU()
        # decoder
        self.up1 = upconv(1024, 512)
        self.convlayer5 = double_conv(1024, 512)
        self.up2 = upconv(512, 256)
        self.convlayer6 = double_conv(512, 256)
        self.up3 = upconv(256, 128)
        self.convlayer7 = double_conv(256, 128)
        self.up4 = upconv(128, 64)
        self.convlayer8 = double_conv(128, 64)
        # output
        self.outputs = nn.Conv2d(64, out, kernel_size=1, padding=0)
        self.sig = nn.Sigmoid()  # defined but not applied in forward; the model returns raw logits

    def forward(self, input):
        # encoder
        l1 = self.convlayer1(input)
        d1 = self.down1(l1)
        l2 = self.convlayer2(d1)
        d2 = self.down2(l2)
        l3 = self.convlayer3(d2)
        d3 = self.down3(l3)
        l4 = self.convlayer4(d3)
        d4 = self.down4(l4)
        # bridge with spatial attention
        b1 = self.act1(self.bn1(self.bridge1(d4)))
        att = self.attmodule(b1)
        b2 = self.act2(self.bn2(self.bridge2(att)))
        # decoder: upsample, concatenate with the matching encoder feature map, then double conv
        up1 = self.up1(b2)
        up1 = torch.cat([up1, l4], dim=1)
        l5 = self.convlayer5(up1)

        up2 = self.up2(l5)
        up2 = torch.cat([up2, l3], dim=1)
        l6 = self.convlayer6(up2)

        up3 = self.up3(l6)
        up3 = torch.cat([up3, l2], dim=1)
        l7 = self.convlayer7(up3)

        up4 = self.up4(l7)
        up4 = torch.cat([up4, l1], dim=1)
        l8 = self.convlayer8(up4)
        out = self.outputs(l8)
        return out
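# Minimal smoke test (an illustrative sketch, not part of the original module): push a
# dummy batch through both variants and confirm the output spatial size matches the
# input. The 3-channel 256x256 input is an assumption for illustration; any spatial
# size divisible by 16 works with the four 2x2 poolings.
if __name__ == "__main__":
    dummy = torch.randn(1, 3, 256, 256)
    base = unet(3, 1)
    sa = SAunet(3, 1)
    base.eval()
    sa.eval()
    with torch.no_grad():
        print("unet output:", base(dummy).shape)    # expected: torch.Size([1, 1, 256, 256])
        print("SAunet output:", sa(dummy).shape)    # expected: torch.Size([1, 1, 256, 256])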