rsortino committed on
Commit 80ab65e
1 Parent(s): e316221

First commit

Files changed (5)
  1. app.py +84 -0
  2. datasets/rg_masks.py +326 -0
  3. models/layers.py +86 -0
  4. models/tiramisu.py +121 -0
  5. requirements.txt +0 -0
app.py ADDED
@@ -0,0 +1,84 @@
+ import streamlit as st
+ from PIL import Image
+ import numpy as np
+ from datasets.rg_masks import get_transforms
+ from models import tiramisu
+ from torchvision.transforms.functional import to_pil_image
+ import torch
+ from astropy.io import fits
+
+
+ def load_fits(path):
+     # Read a FITS image and add a trailing channel axis: (H, W) -> (H, W, 1)
+     array = fits.getdata(path).astype(np.float32)
+     array = np.expand_dims(array, 2)
+     return array
+
+
+ def load_image(path):
+     # Read a standard image and keep a single channel: (H, W, C) -> (H, W, 1)
+     image = Image.open(path)
+     array = np.array(image)
+     array = np.expand_dims(array[:, :, 0], 2)
+     return array
+
+
+ def load_weights(model, fpath, device="cuda"):
+     print("loading weights '{}'".format(fpath))
+     weights = torch.load(fpath, map_location=torch.device(device))
+     model.load_state_dict(weights['state_dict'])
+
+
+ # Apply a color overlay to the input image based on the segmentation mask
+ # (1 = spurious -> red, 2 = compact -> green, 3 = extended -> blue)
+ def apply_color_overlay(input_image, segmentation_mask, alpha=0.5):
+     r = (segmentation_mask == 1).float()
+     g = (segmentation_mask == 2).float()
+     b = (segmentation_mask == 3).float()
+     overlay = torch.cat([r, g, b], dim=0)
+     overlay = to_pil_image(overlay)
+     output = Image.blend(input_image, overlay, alpha=alpha)
+     return output
+
+
+ # Streamlit app
+ def main():
+     st.title("Tiramisu for semantic segmentation of radio astronomy images")
+     st.write("Upload an image and see the segmentation result!")
+
+     uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "fits"])
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     model = tiramisu.FCDenseNet67(n_classes=4).to(device)
+     # map the checkpoint onto whichever device is actually available
+     load_weights(model, "weights/real.pth", device=device.type)
+     model.eval()
+
+     st.markdown(
+         """
+         Category Legend:
+         - :blue[Extended]
+         - :green[Compact]
+         - :red[Spurious]
+         """
+     )
+     if uploaded_image is not None:
+         # Load the uploaded image
+         if uploaded_image.name.endswith(".fits"):
+             input_array = load_fits(uploaded_image)
+         else:
+             input_array = load_image(uploaded_image)
+
+         # (H, W, 1) -> (1, H, W), then apply the same preprocessing used for training
+         input_array = input_array.transpose(2, 0, 1)
+         transforms = get_transforms(input_array.shape[1])
+         image = transforms(input_array)
+         image = image.to(device)
+
+         with torch.no_grad():
+             output = model(image)
+             preds = output.argmax(1)
+
+         pil_image = to_pil_image(image[0])
+         # Apply color overlay to the input image
+         segmented_image = apply_color_overlay(pil_image, preds)
+
+         # Display the input image and the segmented output
+         st.image([pil_image, segmented_image], caption=["Input Image", "Segmented Output"], width=300)
+
+
+ if __name__ == "__main__":
+     main()
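For reference, the same inference path can be exercised without the Streamlit UI. A minimal sketch, assuming a trained checkpoint at `weights/real.pth` and a FITS cutout at `example.fits` (both paths are placeholders):

```python
import numpy as np
import torch
from astropy.io import fits

from datasets.rg_masks import get_transforms
from models import tiramisu

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = tiramisu.FCDenseNet67(n_classes=4).to(device).eval()
state = torch.load("weights/real.pth", map_location=device)  # placeholder checkpoint path
model.load_state_dict(state["state_dict"])

array = fits.getdata("example.fits").astype(np.float32)[None]  # (1, H, W); placeholder input
image = get_transforms(array.shape[1])(array).to(device)       # (1, 3, S, S) after resize/repeat
with torch.no_grad():
    preds = model(image).argmax(1)                              # (1, S, S) per-pixel class indices
```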
datasets/rg_masks.py ADDED
@@ -0,0 +1,326 @@
+ import json
+ import math
+ import random
+ import warnings
+ from pathlib import Path
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.data
+ import torchvision.transforms as T
+ import torchvision.transforms.functional as TF
+ from astropy.io import fits
+ from astropy.io.fits.verify import VerifyWarning
+ from astropy.stats import sigma_clip
+ from astropy.visualization import ZScaleInterval
+ from einops import rearrange
+ from torch.utils.data import DataLoader, Dataset
+ from torchvision.transforms.functional import to_pil_image
+ from torchvision.utils import make_grid, save_image
+
+ warnings.simplefilter('ignore', category=VerifyWarning)
+
+
+ CLASSES = ['background', 'spurious', 'compact', 'extended']
+ COLORS = [[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]
+
+
+ def get_transforms(img_size):
+     return T.Compose([
+         RemoveNaNs(),
+         ZScale(),
+         SigmaClip(),
+         ToTensor(),
+         torch.nn.Tanh(),
+         MinMaxNormalize(),
+         Unsqueeze(),
+         T.Resize((img_size, img_size)),
+         RepeatChannels(3)
+     ])
+
+
+ class RemoveNaNs(object):
+     def __call__(self, img):
+         img[np.isnan(img)] = 0
+         return img
+
+
+ class ZScale(object):
+     def __init__(self, contrast=0.15):
+         self.contrast = contrast
+
+     def __call__(self, img):
+         interval = ZScaleInterval(contrast=self.contrast)
+         vmin, vmax = interval.get_limits(img)
+         img = (img - vmin) / (vmax - vmin)
+         return img
+
+
+ class SigmaClip(object):
+     def __init__(self, sigma=3, masked=True):
+         self.sigma = sigma
+         self.masked = masked
+
+     def __call__(self, img):
+         img = sigma_clip(img, sigma=self.sigma, masked=self.masked)
+         return img
+
+
+ class MinMaxNormalize(object):
+     def __call__(self, img):
+         img = (img - img.min()) / (img.max() - img.min())
+         return img
+
+
+ class ToTensor(object):
+     def __call__(self, img):
+         return torch.tensor(img, dtype=torch.float32)
+
+
+ class RepeatChannels(object):
+     def __init__(self, ch):
+         self.ch = ch
+
+     def __call__(self, img):
+         return img.repeat(1, self.ch, 1, 1)
+
+
+ class FromNumpy(object):
+     def __call__(self, img):
+         return torch.from_numpy(img.astype(np.float32)).type(torch.float32)
+
+
+ class Unsqueeze(object):
+     def __call__(self, img):
+         return img.unsqueeze(0)
+
+
+ def mask_to_rgb(mask):
+     # Map integer class labels to RGB colors according to COLORS
+     rgb_mask = torch.zeros_like(mask, device=mask.device).repeat(1, 3, 1, 1)
+     for i, c in enumerate(COLORS):
+         color_mask = torch.tensor(c, device=mask.device).unsqueeze(
+             1).unsqueeze(2) * (mask == i)
+         rgb_mask += color_mask
+     return rgb_mask
+
+
+ def get_data_loader(dataset, batch_size, split="train"):
+     workers = min(8, batch_size)
+     is_train = split == "train"
+     return DataLoader(dataset, shuffle=is_train, batch_size=batch_size,
+                       num_workers=workers, persistent_workers=True,
+                       drop_last=is_train)
+
+
+ def rgb_to_tensor(mask):
+     # Collapse an RGB-encoded mask into a single-channel label map
+     # (red -> 1, green -> 2, blue -> 3)
+     r, g, b = mask
+     r *= 1
+     g *= 2
+     b *= 3
+     mask, _ = torch.max(torch.stack([r, g, b]), dim=0, keepdim=True)
+     return mask
+
+
+ def rand_horizontal_flip(img, mask):
+     if random.random() < 0.5:
+         img = TF.hflip(img)
+         mask = TF.hflip(mask)
+     return img, mask
+
+
+ class RGDataset(Dataset):
+     def __init__(self, data_dir, img_paths, img_size=128):
+         super().__init__()
+         data_dir = Path(data_dir)
+         with open(img_paths) as f:
+             self.img_paths = f.read().splitlines()
+         self.img_paths = [data_dir / p for p in self.img_paths]
+
+         self.transforms = T.Compose([
+             RemoveNaNs(),
+             ZScale(),
+             SigmaClip(),
+             ToTensor(),
+             torch.nn.Tanh(),
+             MinMaxNormalize(),
+             Unsqueeze(),
+             T.Resize((img_size, img_size)),
+             RepeatChannels(3)
+         ])
+         self.img_size = img_size
+
+         self.mask_transforms = T.Compose([
+             FromNumpy(),
+             Unsqueeze(),
+             T.Resize((img_size, img_size),
+                      interpolation=T.InterpolationMode.NEAREST),
+         ])
+
+     def get_mask(self, img_path, type):
+         assert type in ["real", "synthetic"], f"Type {type} not supported"
+         if type == "real":
+             # Real images: a JSON annotation lists one FITS mask per object
+             ann_path = str(img_path).replace(
+                 'imgs', 'masks').replace('.fits', '.json')
+             ann_dir = Path(ann_path).parent
+             ann_path = ann_dir / f'mask_{ann_path.split("/")[-1]}'
+             with open(ann_path) as j:
+                 mask_info = json.load(j)
+
+             masks = []
+             for obj in mask_info['objs']:
+                 seg_path = ann_dir / obj['mask']
+                 mask = fits.getdata(seg_path)
+                 mask = self.mask_transforms(mask.astype(np.float32))
+                 masks.append(mask)
+             mask, _ = torch.max(torch.stack(masks), dim=0)
+
+         elif type == "synthetic":
+             # Synthetic images: the conditioning FITS file is the mask
+             mask_path = str(img_path).replace("gen_fits", "cond_fits")
+             mask = fits.getdata(mask_path)
+             mask = self.mask_transforms(mask)
+             mask = mask.squeeze()
+             if mask.shape[0] == 3:
+                 mask = rgb_to_tensor(mask)
+         return mask
+
+     def __len__(self):
+         return len(self.img_paths)
+
+     def __getitem__(self, idx):
+         image_path = self.img_paths[idx]
+         img = fits.getdata(image_path)
+         img = self.transforms(img)
+
+         if "synthetic" in str(image_path):
+             mask = self.get_mask(image_path, type='synthetic')
+         else:
+             mask = self.get_mask(image_path, type='real')
+
+         mask = mask.long()
+         return img.squeeze(), mask.squeeze()
+
+
+ class SyntheticRGDataset(Dataset):
+     def __init__(self, data_dir, img_paths, img_size=128):
+         super().__init__()
+         data_dir = Path(data_dir)
+         with open(img_paths) as f:
+             self.img_paths = f.read().splitlines()
+         self.img_paths = [data_dir / p for p in self.img_paths]
+
+         self.transforms = T.Compose([
+             RemoveNaNs(),
+             ZScale(),
+             SigmaClip(),
+             ToTensor(),
+             torch.nn.Tanh(),
+             MinMaxNormalize(),
+             Unsqueeze(),
+             T.Resize((img_size, img_size)),
+             RepeatChannels(3)
+         ])
+         self.img_size = img_size
+
+         self.mask_transforms = T.Compose([
+             FromNumpy(),
+             Unsqueeze(),
+             T.Resize((img_size, img_size),
+                      interpolation=T.InterpolationMode.NEAREST),
+         ])
+
+     def __len__(self):
+         return len(self.img_paths)
+
+     def __getitem__(self, idx):
+         image_path = self.img_paths[idx]
+         img = fits.getdata(image_path)
+         img = self.transforms(img)
+         img = img.squeeze()
+
+         mask_path = str(image_path).replace("gen_fits", "cond_fits")
+         mask = fits.getdata(mask_path)
+         mask = self.mask_transforms(mask)
+
+         img, mask = rand_horizontal_flip(img, mask)
+
+         mask = mask.squeeze().long()
+         return img, mask
+
+
+ if __name__ == '__main__':
+     rgtrain = SyntheticRGDataset('data/rg-dataset/data',
+                                  'data/rg-dataset/val_w_bg.txt')
+     image, mask = next(iter(rgtrain))
+     to_pil_image(image).save('image.png')
+     rgb_mask = mask_to_rgb(mask)[0]
+     to_pil_image(rgb_mask.float()).save('mask.png')
+
+     bs = 256
+
+     loader = torch.utils.data.DataLoader(
+         rgtrain, batch_size=bs, shuffle=False, num_workers=16)
+     for i, batch in enumerate(loader):
+         image, mask = batch
+         # add a channel axis so mask_to_rgb broadcasts over the batch
+         rgb_mask = mask_to_rgb(mask.unsqueeze(1)).float()
+         nrow = int(math.sqrt(bs))
+         grid = make_grid(rgb_mask, nrow=nrow, padding=0)
+         save_image(grid, f'mask_{nrow}x{nrow}.png')
+         break
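The dataset and loader utilities above can be combined as in the following sketch, assuming a directory layout like the one in the `__main__` block (the list file contains one FITS path per line, relative to the data directory; the paths here are placeholders):

```python
from datasets.rg_masks import RGDataset, get_data_loader, mask_to_rgb

# Placeholder paths; the list file holds one image path per line.
dataset = RGDataset("data/rg-dataset/data", "data/rg-dataset/train.txt", img_size=128)
loader = get_data_loader(dataset, batch_size=8, split="train")

img, mask = dataset[0]        # img: (3, 128, 128) float in [0, 1], mask: (128, 128) long
rgb = mask_to_rgb(mask)[0]    # (3, 128, 128) color-coded mask following COLORS

for imgs, masks in loader:    # imgs: (8, 3, 128, 128), masks: (8, 128, 128)
    break
```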
models/layers.py ADDED
@@ -0,0 +1,86 @@
+ import torch
+ import torch.nn as nn
+
+
+ class DenseLayer(nn.Sequential):
+     def __init__(self, in_channels, growth_rate):
+         super().__init__()
+         self.add_module('norm', nn.BatchNorm2d(in_channels))
+         self.add_module('relu', nn.ReLU(True))
+         self.add_module('conv', nn.Conv2d(in_channels, growth_rate, kernel_size=3,
+                                           stride=1, padding=1, bias=True))
+         self.add_module('drop', nn.Dropout2d(0.2))
+
+     def forward(self, x):
+         return super().forward(x)
+
+
+ class DenseBlock(nn.Module):
+     def __init__(self, in_channels, growth_rate, n_layers, upsample=False):
+         super().__init__()
+         self.upsample = upsample
+         self.layers = nn.ModuleList([DenseLayer(
+             in_channels + i*growth_rate, growth_rate)
+             for i in range(n_layers)])
+
+     def forward(self, x):
+         if self.upsample:
+             new_features = []
+             # we pass all previous activations into each dense layer normally
+             # but we only store each dense layer's output in the new_features array
+             for layer in self.layers:
+                 out = layer(x)
+                 x = torch.cat([x, out], 1)
+                 new_features.append(out)
+             return torch.cat(new_features, 1)
+         else:
+             for layer in self.layers:
+                 out = layer(x)
+                 x = torch.cat([x, out], 1)  # 1 = channel axis
+             return x
+
+
+ class TransitionDown(nn.Sequential):
+     def __init__(self, in_channels):
+         super().__init__()
+         self.add_module('norm', nn.BatchNorm2d(num_features=in_channels))
+         self.add_module('relu', nn.ReLU(inplace=True))
+         self.add_module('conv', nn.Conv2d(in_channels, in_channels,
+                                           kernel_size=1, stride=1,
+                                           padding=0, bias=True))
+         self.add_module('drop', nn.Dropout2d(0.2))
+         self.add_module('maxpool', nn.MaxPool2d(2))
+
+     def forward(self, x):
+         return super().forward(x)
+
+
+ class TransitionUp(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.convTrans = nn.ConvTranspose2d(
+             in_channels=in_channels, out_channels=out_channels,
+             kernel_size=3, stride=2, padding=0, bias=True)
+
+     def forward(self, x, skip):
+         out = self.convTrans(x)
+         out = center_crop(out, skip.size(2), skip.size(3))
+         out = torch.cat([out, skip], 1)
+         return out
+
+
+ class Bottleneck(nn.Sequential):
+     def __init__(self, in_channels, growth_rate, n_layers):
+         super().__init__()
+         self.add_module('bottleneck', DenseBlock(
+             in_channels, growth_rate, n_layers, upsample=True))
+
+     def forward(self, x):
+         return super().forward(x)
+
+
+ def center_crop(layer, max_height, max_width):
+     _, _, h, w = layer.size()
+     xy1 = (w - max_width) // 2
+     xy2 = (h - max_height) // 2
+     return layer[:, :, xy2:(xy2 + max_height), xy1:(xy1 + max_width)]
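One detail worth noting: with `upsample=False` a `DenseBlock` returns its input concatenated with every layer's output (`in_channels + n_layers * growth_rate` channels), while with `upsample=True` it returns only the newly produced features (`n_layers * growth_rate` channels). A small sketch to confirm the channel arithmetic:

```python
import torch
from models.layers import DenseBlock

x = torch.randn(2, 48, 32, 32)

down = DenseBlock(in_channels=48, growth_rate=16, n_layers=5, upsample=False)
print(down(x).shape)  # torch.Size([2, 128, 32, 32]) -> 48 + 5 * 16

up = DenseBlock(in_channels=48, growth_rate=16, n_layers=5, upsample=True)
print(up(x).shape)    # torch.Size([2, 80, 32, 32]) -> 5 * 16
```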
models/tiramisu.py ADDED
@@ -0,0 +1,121 @@
+ import torch
+ import torch.nn as nn
+
+ from .layers import *
+
+
+ class FCDenseNet(nn.Module):
+     def __init__(self, in_channels=3, down_blocks=(5, 5, 5, 5, 5),
+                  up_blocks=(5, 5, 5, 5, 5), bottleneck_layers=5,
+                  growth_rate=16, out_chans_first_conv=48, n_classes=12):
+         super().__init__()
+         self.down_blocks = down_blocks
+         self.up_blocks = up_blocks
+         cur_channels_count = 0
+         skip_connection_channel_counts = []
+
+         ## First Convolution ##
+
+         self.add_module('firstconv', nn.Conv2d(in_channels=in_channels,
+                         out_channels=out_chans_first_conv, kernel_size=3,
+                         stride=1, padding=1, bias=True))
+         cur_channels_count = out_chans_first_conv
+
+         #####################
+         # Downsampling path #
+         #####################
+
+         self.denseBlocksDown = nn.ModuleList([])
+         self.transDownBlocks = nn.ModuleList([])
+         for i in range(len(down_blocks)):
+             self.denseBlocksDown.append(
+                 DenseBlock(cur_channels_count, growth_rate, down_blocks[i]))
+             cur_channels_count += (growth_rate*down_blocks[i])
+             skip_connection_channel_counts.insert(0, cur_channels_count)
+             self.transDownBlocks.append(TransitionDown(cur_channels_count))
+
+         #####################
+         #    Bottleneck     #
+         #####################
+
+         self.add_module('bottleneck', Bottleneck(cur_channels_count,
+                         growth_rate, bottleneck_layers))
+         prev_block_channels = growth_rate*bottleneck_layers
+         cur_channels_count += prev_block_channels
+
+         #######################
+         #   Upsampling path   #
+         #######################
+
+         self.transUpBlocks = nn.ModuleList([])
+         self.denseBlocksUp = nn.ModuleList([])
+         for i in range(len(up_blocks)-1):
+             self.transUpBlocks.append(TransitionUp(
+                 prev_block_channels, prev_block_channels))
+             cur_channels_count = prev_block_channels + \
+                 skip_connection_channel_counts[i]
+
+             self.denseBlocksUp.append(DenseBlock(
+                 cur_channels_count, growth_rate, up_blocks[i],
+                 upsample=True))
+             prev_block_channels = growth_rate*up_blocks[i]
+             cur_channels_count += prev_block_channels
+
+         ## Final DenseBlock ##
+
+         self.transUpBlocks.append(TransitionUp(
+             prev_block_channels, prev_block_channels))
+         cur_channels_count = prev_block_channels + \
+             skip_connection_channel_counts[-1]
+
+         self.denseBlocksUp.append(DenseBlock(
+             cur_channels_count, growth_rate, up_blocks[-1],
+             upsample=False))
+         cur_channels_count += growth_rate*up_blocks[-1]
+
+         ## Softmax ##
+
+         self.finalConv = nn.Conv2d(in_channels=cur_channels_count,
+                                    out_channels=n_classes, kernel_size=1, stride=1,
+                                    padding=0, bias=True)
+         self.softmax = nn.LogSoftmax(dim=1)
+
+     def forward(self, x):
+         out = self.firstconv(x)
+
+         skip_connections = []
+         for i in range(len(self.down_blocks)):
+             out = self.denseBlocksDown[i](out)
+             skip_connections.append(out)
+             out = self.transDownBlocks[i](out)
+
+         out = self.bottleneck(out)
+         for i in range(len(self.up_blocks)):
+             skip = skip_connections.pop()
+             out = self.transUpBlocks[i](out, skip)
+             out = self.denseBlocksUp[i](out)
+
+         out = self.finalConv(out)
+         out = self.softmax(out)
+         return out
+
+
+ def FCDenseNet57(n_classes):
+     return FCDenseNet(
+         in_channels=3, down_blocks=(4, 4, 4, 4, 4),
+         up_blocks=(4, 4, 4, 4, 4), bottleneck_layers=4,
+         growth_rate=12, out_chans_first_conv=48, n_classes=n_classes)
+
+
+ def FCDenseNet67(n_classes):
+     return FCDenseNet(
+         in_channels=3, down_blocks=(5, 5, 5, 5, 5),
+         up_blocks=(5, 5, 5, 5, 5), bottleneck_layers=5,
+         growth_rate=16, out_chans_first_conv=48, n_classes=n_classes)
+
+
+ def FCDenseNet103(n_classes):
+     return FCDenseNet(
+         in_channels=3, down_blocks=(4, 5, 7, 10, 12),
+         up_blocks=(12, 10, 7, 5, 4), bottleneck_layers=15,
+         growth_rate=16, out_chans_first_conv=48, n_classes=n_classes)
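A quick shape check for the configuration used in `app.py` (FCDenseNet67 with 4 classes): the network returns a per-pixel log-probability map at the input resolution, from which class indices are taken with `argmax`.

```python
import torch
from models import tiramisu

model = tiramisu.FCDenseNet67(n_classes=4).eval()
x = torch.randn(1, 3, 128, 128)
with torch.no_grad():
    out = model(x)         # (1, 4, 128, 128) log-probabilities from LogSoftmax
    preds = out.argmax(1)  # (1, 128, 128) predicted class index per pixel
print(out.shape, preds.shape)
```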
requirements.txt ADDED
Binary file (60 Bytes).