"""U-Net and SA-UNet (Spatial-Attention U-Net) segmentation networks.

Provides DropBlock regularization, the shared double-convolution and
up-convolution building blocks, a plain ``unet``, a channel-wise spatial
attention gate, and the attention-augmented ``SAunet``.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class DropBlock(nn.Module):
    """DropBlock regularization (Ghiasi et al., NeurIPS 2018).

    During training, zeroes out contiguous ``block_size`` x ``block_size``
    regions of the feature map and rescales surviving activations so the
    expected magnitude is preserved. Acts as identity in eval mode.

    NOTE: ``p`` here is the KEEP probability (callers in this file pass
    0.9), so the effective drop rate is ``1 - p`` — the gamma formula
    below uses ``1 - p`` accordingly.
    """

    def __init__(self, block_size: int = 5, p: float = 0.1):
        super().__init__()
        self.block_size = block_size  # side length of each dropped square
        self.p = p                    # keep probability (see class docstring)

    def calculate_gamma(self, x: Tensor) -> float:
        """Return the Bernoulli rate for block seed positions.

        gamma = (1 - p) / block_size^2 * feat^2 / (feat - block_size + 1)^2

        Assumes a square feature map — only the last spatial dim is used.
        """
        invalid = (1 - self.p) / (self.block_size ** 2)
        valid = (x.shape[-1] ** 2) / ((x.shape[-1] - self.block_size + 1) ** 2)
        return invalid * valid

    def forward(self, x: Tensor) -> Tensor:
        # Identity outside of training, mirroring nn.Dropout semantics.
        if not self.training:
            return x
        N, C, H, W = x.size()
        gamma = self.calculate_gamma(x)
        # Sample seed points only where a full block fits inside the map,
        # then zero-pad the seed mask back to the input resolution.
        mask_shape = (N, C, H - self.block_size + 1, W - self.block_size + 1)
        mask = torch.bernoulli(torch.full(mask_shape, gamma, device=x.device))
        mask = F.pad(mask, [self.block_size // 2] * 4, value=0)
        # Grow each seed into a block_size x block_size square with a
        # stride-1 max-pool, then invert: kept positions -> 1, dropped -> 0.
        mask_block = 1 - F.max_pool2d(
            mask,
            kernel_size=(self.block_size, self.block_size),
            stride=(1, 1),
            padding=(self.block_size // 2, self.block_size // 2),
        )
        # Rescale so E[output] matches the input. Guard the (unlikely)
        # all-dropped case, which previously divided by zero -> inf/nan.
        kept = mask_block.sum()
        if kept == 0:
            return torch.zeros_like(x)
        return mask_block * x * (mask_block.numel() / kept)


class double_conv(nn.Module):
    """Two (Conv3x3 -> DropBlock -> BatchNorm -> ReLU) stages.

    intc: number of input channels.
    outc: number of output channels (both stages emit ``outc``).
    """

    def __init__(self, intc, outc):
        super().__init__()
        self.conv1 = nn.Conv2d(intc, outc, kernel_size=3, padding=1, stride=1)
        self.drop1 = DropBlock(7, 0.9)  # 7x7 blocks, keep-prob 0.9
        self.bn1 = nn.BatchNorm2d(outc)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(outc, outc, kernel_size=3, padding=1, stride=1)
        self.drop2 = DropBlock(7, 0.9)
        self.bn2 = nn.BatchNorm2d(outc)
        self.relu2 = nn.ReLU()

    def forward(self, input):
        x = self.relu1(self.bn1(self.drop1(self.conv1(input))))
        x = self.relu2(self.bn2(self.drop2(self.conv2(x))))
        return x


class upconv(nn.Module):
    """2x spatial upsampling via a transposed convolution."""

    def __init__(self, intc, outc) -> None:
        super().__init__()
        self.up = nn.ConvTranspose2d(intc, outc, kernel_size=2, stride=2, padding=0)

    def forward(self, input):
        return self.up(input)


class unet(nn.Module):
    """Standard U-Net with four down/up levels and skip connections.

    int: number of input image channels.
    out: number of output channels (logits — no activation is applied).

    NOTE(review): the parameter name ``int`` shadows the builtin; kept
    unchanged for backward compatibility with existing callers.
    """

    def __init__(self, int, out) -> None:
        super().__init__()
        # --- encoder ---
        self.convlayer1 = double_conv(int, 64)
        self.down1 = nn.MaxPool2d((2, 2))
        self.convlayer2 = double_conv(64, 128)
        self.down2 = nn.MaxPool2d((2, 2))
        self.convlayer3 = double_conv(128, 256)
        self.down3 = nn.MaxPool2d((2, 2))
        self.convlayer4 = double_conv(256, 512)
        self.down4 = nn.MaxPool2d((2, 2))
        # --- bridge ---
        self.bridge = double_conv(512, 1024)
        # --- decoder (input channels double after each skip concat) ---
        self.up1 = upconv(1024, 512)
        self.convlayer5 = double_conv(1024, 512)
        self.up2 = upconv(512, 256)
        self.convlayer6 = double_conv(512, 256)
        self.up3 = upconv(256, 128)
        self.convlayer7 = double_conv(256, 128)
        self.up4 = upconv(128, 64)
        self.convlayer8 = double_conv(128, 64)
        # --- output head (sigmoid kept as a module but not applied) ---
        self.outputs = nn.Conv2d(64, out, kernel_size=1, padding=0)
        self.sig = nn.Sigmoid()

    def forward(self, input):
        # Encoder: keep each level's features for the skip connections.
        l1 = self.convlayer1(input)
        d1 = self.down1(l1)
        l2 = self.convlayer2(d1)
        d2 = self.down2(l2)
        l3 = self.convlayer3(d2)
        d3 = self.down3(l3)
        l4 = self.convlayer4(d3)
        d4 = self.down4(l4)
        # Bridge.
        bridge = self.bridge(d4)
        # Decoder: upsample, concat the matching encoder level, convolve.
        up1 = torch.cat([self.up1(bridge), l4], dim=1)
        l5 = self.convlayer5(up1)
        up2 = torch.cat([self.up2(l5), l3], dim=1)
        l6 = self.convlayer6(up2)
        up3 = torch.cat([self.up3(l6), l2], dim=1)
        l7 = self.convlayer7(up3)
        up4 = torch.cat([self.up4(l7), l1], dim=1)
        l8 = self.convlayer8(up4)
        # Raw logits; apply an activation (e.g. self.sig) externally.
        return self.outputs(l8)


class spatialAttention(nn.Module):
    """Spatial attention gate (CBAM-style).

    Concatenates channel-wise max and mean maps, passes them through a
    7x7 conv + sigmoid to produce a per-pixel gate, and multiplies the
    input by it.
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv77 = nn.Conv2d(2, 1, kernel_size=7, padding=3)
        self.sig = nn.Sigmoid()

    def forward(self, input):
        # (N, C, H, W) -> (N, 2, H, W): stack max and mean over channels.
        pooled = torch.cat(
            (
                torch.max(input, 1)[0].unsqueeze(1),
                torch.mean(input, 1).unsqueeze(1),
            ),
            dim=1,
        )
        gate = self.sig(self.conv77(pooled))
        return input * gate


class SAunet(nn.Module):
    """SA-UNet: U-Net with a spatial attention module in the bridge.

    int: number of input image channels.
    out: number of output channels (logits — no activation is applied).

    NOTE(review): the parameter name ``int`` shadows the builtin; kept
    unchanged for backward compatibility with existing callers.
    """

    def __init__(self, int, out) -> None:
        super().__init__()
        # --- encoder ---
        self.convlayer1 = double_conv(int, 64)
        self.down1 = nn.MaxPool2d((2, 2))
        self.convlayer2 = double_conv(64, 128)
        self.down2 = nn.MaxPool2d((2, 2))
        self.convlayer3 = double_conv(128, 256)
        self.down3 = nn.MaxPool2d((2, 2))
        self.convlayer4 = double_conv(256, 512)
        self.down4 = nn.MaxPool2d((2, 2))
        # --- bridge: conv -> attention -> conv ---
        self.attmodule = spatialAttention()
        self.bridge1 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(1024)
        self.act1 = nn.ReLU()
        self.bridge2 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(1024)
        self.act2 = nn.ReLU()
        # --- decoder (input channels double after each skip concat) ---
        self.up1 = upconv(1024, 512)
        self.convlayer5 = double_conv(1024, 512)
        self.up2 = upconv(512, 256)
        self.convlayer6 = double_conv(512, 256)
        self.up3 = upconv(256, 128)
        self.convlayer7 = double_conv(256, 128)
        self.up4 = upconv(128, 64)
        self.convlayer8 = double_conv(128, 64)
        # --- output head (sigmoid kept as a module but not applied) ---
        self.outputs = nn.Conv2d(64, out, kernel_size=1, padding=0)
        self.sig = nn.Sigmoid()

    def forward(self, input):
        # Encoder: keep each level's features for the skip connections.
        l1 = self.convlayer1(input)
        d1 = self.down1(l1)
        l2 = self.convlayer2(d1)
        d2 = self.down2(l2)
        l3 = self.convlayer3(d2)
        d3 = self.down3(l3)
        l4 = self.convlayer4(d3)
        d4 = self.down4(l4)
        # Bridge with spatial attention between the two convs.
        b1 = self.act1(self.bn1(self.bridge1(d4)))
        att = self.attmodule(b1)
        b2 = self.act2(self.bn2(self.bridge2(att)))
        # Decoder: upsample, concat the matching encoder level, convolve.
        up1 = torch.cat([self.up1(b2), l4], dim=1)
        l5 = self.convlayer5(up1)
        up2 = torch.cat([self.up2(l5), l3], dim=1)
        l6 = self.convlayer6(up2)
        up3 = torch.cat([self.up3(l6), l2], dim=1)
        l7 = self.convlayer7(up3)
        up4 = torch.cat([self.up4(l7), l1], dim=1)
        l8 = self.convlayer8(up4)
        # Raw logits; apply an activation (e.g. self.sig) externally.
        return self.outputs(l8)