yasinelh commited on
Commit
8a33342
1 Parent(s): 22d9822

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +93 -0
  2. models.py +210 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from PIL import Image
3
+ import cv2
4
+ import numpy as np
5
+ import time
6
+ import models
7
+ import torch
8
+
9
+ from torchvision import transforms
10
+ from torchvision import transforms
11
+
12
+ def load_model(path, model):
13
+ model.load_state_dict(torch.load(path, map_location=torch.device('cpu')))
14
+ return model
15
+
16
+ def predict(img):
17
+ model = models.unet(3, 1)
18
+ model = load_model('model.pth',model)
19
+
20
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
21
+ img = cv2.resize(img, (512, 512))
22
+ convert_tensor = transforms.ToTensor()
23
+ img = convert_tensor(img).float()
24
+ img = normalize(img)
25
+ img = torch.unsqueeze(img, dim=0)
26
+
27
+ output = model(img)
28
+ result = torch.sigmoid(output)
29
+
30
+ threshold = 0.5
31
+ result = (result >= threshold).float()
32
+ prediction = result[0].cpu() # Move tensor to CPU if it's on GPU
33
+ # Convert tensor to a numpy array
34
+ prediction_array = prediction.numpy()
35
+ # Rescale values to the range [0, 255]
36
+ prediction_array = (prediction_array * 255).astype('uint8').transpose(1, 2, 0)
37
+ cv2.imwrite("test.png",prediction_array)
38
+ return prediction_array
39
+
40
+ def predicjt(img):
41
+ model1 = models.SAunet(3, 1)
42
+ model1 = load_model('saunet.pth',model1)
43
+
44
+ normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
45
+ img = cv2.resize(img, (512, 512))
46
+ convert_tensor = transforms.ToTensor()
47
+ img = convert_tensor(img).float()
48
+ img = normalize(img)
49
+ img = torch.unsqueeze(img, dim=0)
50
+
51
+ output = model1(img)
52
+ result = torch.sigmoid(output)
53
+
54
+ threshold = 0.5
55
+ result = (result >= threshold).float()
56
+ prediction = result[0].cpu() # Move tensor to CPU if it's on GPU
57
+ # Convert tensor to a numpy array
58
+ prediction_array = prediction.numpy()
59
+ # Rescale values to the range [0, 255]
60
+ prediction_array = (prediction_array * 255).astype('uint8').transpose(1, 2, 0)
61
+ cv2.imwrite("test1.png",prediction_array)
62
+ return prediction_array
63
+ def main():
64
+ st.title("Image Segmentation Demo")
65
+
66
+ # Predefined list of image names
67
+ image_names = ["01_test.tif", "02_test.tif", "03_test.tif"]
68
+
69
+ # Create a selection box for the images
70
+ selected_image_name = st.selectbox("Select an Image", image_names)
71
+
72
+ # Load the selected image
73
+ selected_image = cv2.imread(selected_image_name)
74
+
75
+ # Display the selected image
76
+ st.image(selected_image, channels="RGB")
77
+
78
+ # Create a button for segmentation
79
+ if st.button("Segment"):
80
+ # Perform segmentation on the selected image
81
+ segmented_image = predict(selected_image)
82
+ segmented_image1 = predicjt(selected_image)
83
+
84
+
85
+ # Display the segmented image
86
+ st.image(segmented_image, channels="RGB",caption='U-Net segmentation')
87
+ st.image(segmented_image1, channels="RGB",caption='Spatial Attention U-Net segmentation ')
88
+
89
+ # Function to perform segmentation on the selected image
90
+
91
+
92
+ if __name__ == "__main__":
93
+ main()
models.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import Tensor
5
+
6
+
7
+ class DropBlock(nn.Module):
8
+ def __init__(self, block_size: int = 5, p: float = 0.1):
9
+ super().__init__()
10
+ self.block_size = block_size
11
+ self.p = p
12
+
13
+ def calculate_gamma(self, x: Tensor) -> float:
14
+
15
+
16
+ invalid = (1 - self.p) / (self.block_size ** 2)
17
+ valid = (x.shape[-1] ** 2) / ((x.shape[-1] - self.block_size + 1) ** 2)
18
+ return invalid * valid
19
+
20
+ def forward(self, x: Tensor) -> Tensor:
21
+ N, C, H, W = x.size()
22
+ if self.training:
23
+ gamma = self.calculate_gamma(x)
24
+ mask_shape = (N, C, H - self.block_size + 1, W - self.block_size + 1)
25
+ mask = torch.bernoulli(torch.full(mask_shape, gamma, device=x.device))
26
+ mask = F.pad(mask, [self.block_size // 2] * 4, value=0)
27
+ mask_block = 1 - F.max_pool2d(
28
+ mask,
29
+ kernel_size=(self.block_size, self.block_size),
30
+ stride=(1, 1),
31
+ padding=(self.block_size // 2, self.block_size // 2),
32
+ )
33
+ x = mask_block * x * (mask_block.numel() / mask_block.sum())
34
+ return x
35
+
36
+
37
+ class double_conv(nn.Module):
38
+ def __init__(self,intc,outc):
39
+ super().__init__()
40
+ self.conv1=nn.Conv2d(intc,outc,kernel_size=3,padding=1,stride=1)
41
+ self.drop1= DropBlock(7, 0.9)
42
+ self.bn1=nn.BatchNorm2d(outc)
43
+ self.relu1=nn.ReLU()
44
+ self.conv2=nn.Conv2d(outc,outc,kernel_size=3,padding=1,stride=1)
45
+ self.drop2=DropBlock(7, 0.9)
46
+ self.bn2=nn.BatchNorm2d(outc)
47
+ self.relu2=nn.ReLU()
48
+
49
+ def forward(self,input):
50
+ x=self.relu1(self.bn1(self.drop1(self.conv1(input))))
51
+ x=self.relu2(self.bn2(self.drop2(self.conv2(x))))
52
+
53
+ return x
54
+ class upconv(nn.Module):
55
+ def __init__(self,intc,outc) -> None:
56
+ super().__init__()
57
+ self.up=nn.ConvTranspose2d(intc, outc, kernel_size=2, stride=2, padding=0)
58
+ # self.relu=nn.ReLU()
59
+
60
+ def forward(self,input):
61
+ x=self.up(input)
62
+ #x=self.relu(self.up(input))
63
+ return x
64
+ class unet(nn.Module):
65
+ def __init__(self,int,out) -> None:
66
+ 'int: represent the number of image channels'
67
+ 'out: number of the desired final channels'
68
+
69
+ super().__init__()
70
+ 'encoder'
71
+ self.convlayer1=double_conv(int,64)
72
+ self.down1=nn.MaxPool2d((2, 2))
73
+ self.convlayer2=double_conv(64,128)
74
+ self.down2=nn.MaxPool2d((2, 2))
75
+ self.convlayer3=double_conv(128,256)
76
+ self.down3=nn.MaxPool2d((2, 2))
77
+ self.convlayer4=double_conv(256,512)
78
+ self.down4=nn.MaxPool2d((2, 2))
79
+
80
+ 'bridge'
81
+ self.bridge=double_conv(512,1024)
82
+ 'decoder'
83
+ self.up1=upconv(1024,512)
84
+ self.convlayer5=double_conv(1024,512)
85
+ self.up2=upconv(512,256)
86
+ self.convlayer6=double_conv(512,256)
87
+ self.up3=upconv(256,128)
88
+ self.convlayer7=double_conv(256,128)
89
+ self.up4=upconv(128,64)
90
+ self.convlayer8=double_conv(128,64)
91
+ 'output'
92
+ self.outputs = nn.Conv2d(64, out, kernel_size=1, padding=0)
93
+ self.sig=nn.Sigmoid()
94
+ def forward(self,input):
95
+ 'encoder'
96
+ l1=self.convlayer1(input)
97
+ d1=self.down1(l1)
98
+ l2=self.convlayer2(d1)
99
+ d2=self.down2(l2)
100
+ l3=self.convlayer3(d2)
101
+ d3=self.down3(l3)
102
+ l4=self.convlayer4(d3)
103
+ d4=self.down4(l4)
104
+ 'bridge'
105
+ bridge=self.bridge(d4)
106
+ 'decoder'
107
+ up1=self.up1(bridge)
108
+ up1 = torch.cat([up1, l4], axis=1)
109
+ l5=self.convlayer5(up1)
110
+
111
+ up2=self.up2(l5)
112
+ up2 = torch.cat([up2, l3], axis=1)
113
+ l6=self.convlayer6(up2)
114
+
115
+ up3=self.up3(l6)
116
+ up3= torch.cat([up3, l2], axis=1)
117
+ l7=self.convlayer7(up3)
118
+
119
+ up4=self.up4(l7)
120
+ up4 = torch.cat([up4, l1], axis=1)
121
+ l8=self.convlayer8(up4)
122
+ out=self.outputs(l8)
123
+
124
+ #out=self.sig(self.outputs(l8))
125
+ return out
126
+ class spatialAttention(nn.Module):
127
+ def __init__(self) -> None:
128
+ super().__init__()
129
+
130
+ self.conv77=nn.Conv2d(2,1,kernel_size=7,padding=3)
131
+ self.sig=nn.Sigmoid()
132
+ def forward(self,input):
133
+ x=torch.cat( (torch.max(input,1)[0].unsqueeze(1), torch.mean(input,1).unsqueeze(1)), dim=1 )
134
+ x=self.sig(self.conv77(x))
135
+ x=input*x
136
+ return x
137
+ class SAunet(nn.Module):
138
+ def __init__(self,int,out) -> None:
139
+ 'int: represent the number of image channels'
140
+ 'out: number of the desired final channels'
141
+
142
+ super().__init__()
143
+ 'encoder'
144
+ self.convlayer1=double_conv(int,64)
145
+ self.down1=nn.MaxPool2d((2, 2))
146
+ self.convlayer2=double_conv(64,128)
147
+ self.down2=nn.MaxPool2d((2, 2))
148
+ self.convlayer3=double_conv(128,256)
149
+ self.down3=nn.MaxPool2d((2, 2))
150
+ self.convlayer4=double_conv(256,512)
151
+ self.down4=nn.MaxPool2d((2, 2))
152
+
153
+ 'bridge'
154
+ self.attmodule=spatialAttention()
155
+ self.bridge1=nn.Conv2d(512,1024,kernel_size=3,stride=1,padding=1)
156
+ self.bn1=nn.BatchNorm2d(1024)
157
+ self.act1=nn.ReLU()
158
+ self.bridge2=nn.Conv2d(1024,1024,kernel_size=3,stride=1,padding=1)
159
+ self.bn2=nn.BatchNorm2d(1024)
160
+ self.act2=nn.ReLU()
161
+ 'decoder'
162
+ self.up1=upconv(1024,512)
163
+ self.convlayer5=double_conv(1024,512)
164
+ self.up2=upconv(512,256)
165
+ self.convlayer6=double_conv(512,256)
166
+ self.up3=upconv(256,128)
167
+ self.convlayer7=double_conv(256,128)
168
+ self.up4=upconv(128,64)
169
+ self.convlayer8=double_conv(128,64)
170
+ 'output'
171
+ self.outputs = nn.Conv2d(64, out, kernel_size=1, padding=0)
172
+ self.sig=nn.Sigmoid()
173
+ def forward(self,input):
174
+ 'encoder'
175
+ l1=self.convlayer1(input)
176
+ d1=self.down1(l1)
177
+ l2=self.convlayer2(d1)
178
+ d2=self.down2(l2)
179
+ l3=self.convlayer3(d2)
180
+ d3=self.down3(l3)
181
+ l4=self.convlayer4(d3)
182
+ d4=self.down4(l4)
183
+ 'bridge'
184
+ b1=self.act1(self.bn1(self.bridge1(d4)))
185
+ att=self.attmodule(b1)
186
+ b2=self.act2(self.bn2(self.bridge2(att)))
187
+ 'decoder'
188
+ up1=self.up1(b2)
189
+ up1 = torch.cat([up1, l4], axis=1)
190
+ l5=self.convlayer5(up1)
191
+
192
+ up2=self.up2(l5)
193
+ up2 = torch.cat([up2, l3], axis=1)
194
+ l6=self.convlayer6(up2)
195
+
196
+ up3=self.up3(l6)
197
+ up3= torch.cat([up3, l2], axis=1)
198
+ l7=self.convlayer7(up3)
199
+
200
+ up4=self.up4(l7)
201
+ up4 = torch.cat([up4, l1], axis=1)
202
+ l8=self.convlayer8(up4)
203
+ out=self.outputs(l8)
204
+
205
+ #out=self.sig(self.outputs(l8))
206
+ return out
207
+
208
+
209
+
210
+