Spaces: First commit

Files changed:
- app.py +84 -0
- datasets/rg_masks.py +326 -0
- models/layers.py +86 -0
- models/tiramisu.py +121 -0
- requirements.txt +0 -0
app.py
ADDED
@@ -0,0 +1,84 @@
import streamlit as st
from PIL import Image
import numpy as np
from datasets.rg_masks import get_transforms
from models import tiramisu
from torchvision.transforms.functional import to_pil_image
import torch
from astropy.io import fits


def load_fits(path):
    array = fits.getdata(path).astype(np.float32)
    array = np.expand_dims(array, 2)
    return array

def load_image(path):
    image = Image.open(path)
    array = np.array(image)
    array = np.expand_dims(array[:, :, 0], 2)

    return array

def load_weights(model, fpath, device="cuda"):
    print("loading weights '{}'".format(fpath))
    weights = torch.load(fpath, map_location=torch.device(device))
    model.load_state_dict(weights['state_dict'])


# Function to apply color overlay to the input image based on the segmentation mask
def apply_color_overlay(input_image, segmentation_mask, alpha=0.5):
    r = (segmentation_mask == 1).float()
    g = (segmentation_mask == 2).float()
    b = (segmentation_mask == 3).float()
    overlay = torch.cat([r, g, b], dim=0)
    overlay = to_pil_image(overlay)
    output = Image.blend(input_image, overlay, alpha=alpha)
    return output

# Streamlit app
def main():
    st.title("Tiramisu for semantic segmentation of radio astronomy images")
    st.write("Upload an image and see the segmentation result!")

    uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png", "fits"])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = tiramisu.FCDenseNet67(n_classes=4).to(device)
    # pass the selected device so CPU-only hosts do not fall back to the CUDA default
    load_weights(model, "weights/real.pth", device=device)
    model.eval()

    st.markdown(
        """
        Category Legend:
        - :blue[Extended]
        - :green[Compact]
        - :red[Spurious]
        """
    )
    if uploaded_image is not None:
        # Load the uploaded image
        if uploaded_image.name.endswith(".fits"):
            input_array = load_fits(uploaded_image)
        else:
            input_array = load_image(uploaded_image)

        input_array = input_array.transpose(2, 0, 1)
        transforms = get_transforms(input_array.shape[1])
        image = transforms(input_array)
        image = image.to(device)

        with torch.no_grad():
            output = model(image)
            preds = output.argmax(1)

        pil_image = to_pil_image(image[0])
        # Apply color overlay to the input image
        segmented_image = apply_color_overlay(pil_image, preds)

        # Display the input image and the segmented output
        st.image([pil_image, segmented_image], caption=["Input Image", "Segmented Output"], width=300)

if __name__ == "__main__":
    main()
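
The overlay helper above only depends on PIL and torch, so it can be exercised without launching the Streamlit UI. A minimal sketch (not part of the commit), assuming the repository root is on the Python path and the app's dependencies are installed; the grey base image and the random prediction map are illustrative only:

    import torch
    from PIL import Image

    from app import apply_color_overlay

    # hypothetical inputs: a grey 64x64 RGB image and a random per-pixel class map (0-3)
    base = Image.new("RGB", (64, 64), color=(128, 128, 128))
    preds = torch.randint(0, 4, (1, 64, 64))

    # classes 1/2/3 are painted into the red/green/blue channels and blended onto the base
    blended = apply_color_overlay(base, preds, alpha=0.5)
    blended.save("overlay_demo.png")

The full demo itself is launched the usual Streamlit way (streamlit run app.py), provided the checkpoint weights/real.pth referenced in main() is available.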
datasets/rg_masks.py
ADDED
@@ -0,0 +1,326 @@
import json
import math
import random
import warnings
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.data
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from astropy.io import fits
from astropy.io.fits.verify import VerifyWarning
from astropy.stats import sigma_clip
from astropy.visualization import ZScaleInterval
from einops import rearrange
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms.functional import to_pil_image
from torchvision.utils import make_grid, save_image

warnings.simplefilter('ignore', category=VerifyWarning)


CLASSES = ['background', 'spurious', 'compact', 'extended']
COLORS = [[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]


def get_transforms(img_size):
    return T.Compose([
        RemoveNaNs(),
        ZScale(),
        SigmaClip(),
        ToTensor(),
        torch.nn.Tanh(),
        MinMaxNormalize(),
        Unsqueeze(),
        T.Resize((img_size, img_size)),
        RepeatChannels(3)
    ])


class RemoveNaNs(object):
    def __init__(self):
        pass

    def __call__(self, img):
        img[np.isnan(img)] = 0
        return img


class ZScale(object):
    def __init__(self, contrast=0.15):
        self.contrast = contrast

    def __call__(self, img):
        interval = ZScaleInterval(contrast=self.contrast)
        vmin, vmax = interval.get_limits(img)

        img = (img - vmin) / (vmax - vmin)
        return img


class SigmaClip(object):
    def __init__(self, sigma=3, masked=True):
        self.sigma = sigma
        self.masked = masked

    def __call__(self, img):
        img = sigma_clip(img, sigma=self.sigma, masked=self.masked)
        return img


class MinMaxNormalize(object):
    def __init__(self):
        pass

    def __call__(self, img):
        img = (img - img.min()) / (img.max() - img.min())
        return img


class ToTensor(object):
    def __init__(self):
        pass

    def __call__(self, img):
        return torch.tensor(img, dtype=torch.float32)


class RepeatChannels(object):
    def __init__(self, ch):
        self.ch = ch

    def __call__(self, img):
        return img.repeat(1, self.ch, 1, 1)


class FromNumpy(object):
    def __init__(self):
        pass

    def __call__(self, img):
        return torch.from_numpy(img.astype(np.float32)).type(torch.float32)


class Unsqueeze(object):
    def __init__(self):
        pass

    def __call__(self, img):
        return img.unsqueeze(0)


def mask_to_rgb(mask):
    rgb_mask = torch.zeros_like(mask, device=mask.device).repeat(1, 3, 1, 1)
    for i, c in enumerate(COLORS):
        color_mask = torch.tensor(c, device=mask.device).unsqueeze(
            1).unsqueeze(2) * (mask == i)
        rgb_mask += color_mask
    return rgb_mask


def get_data_loader(dataset, batch_size, split="train"):
    workers = min(8, batch_size)
    is_train = split == "train"
    return DataLoader(dataset, shuffle=is_train, batch_size=batch_size,
                      num_workers=workers, persistent_workers=True,
                      drop_last=is_train)


def rgb_to_tensor(mask):
    r, g, b = mask
    r *= 1
    g *= 2
    b *= 3
    mask, _ = torch.max(torch.stack([r, g, b]), dim=0, keepdim=True)
    return mask


def rand_horizontal_flip(img, mask):
    if random.random() < 0.5:
        img = TF.hflip(img)
        mask = TF.hflip(mask)
    return img, mask


class RGDataset(Dataset):
    def __init__(self, data_dir, img_paths, img_size=128):
        super().__init__()
        data_dir = Path(data_dir)
        with open(img_paths) as f:
            self.img_paths = f.read().splitlines()
        self.img_paths = [data_dir / p for p in self.img_paths]

        self.transforms = T.Compose([
            RemoveNaNs(),
            ZScale(),
            SigmaClip(),
            ToTensor(),
            torch.nn.Tanh(),
            MinMaxNormalize(),
            # T.Resize((img_size),
            #          interpolation=T.InterpolationMode.NEAREST),
            Unsqueeze(),
            T.Resize((img_size, img_size)),
            RepeatChannels(3)
        ])
        self.img_size = img_size

        self.mask_transforms = T.Compose([
            FromNumpy(),
            Unsqueeze(),
            T.Resize((img_size, img_size),
                     interpolation=T.InterpolationMode.NEAREST),
        ])

    def get_mask(self, img_path, type):
        assert type in ["real", "synthetic"], f"Type {type} not supported"
        if type == "real":
            ann_path = str(img_path).replace(
                'imgs', 'masks').replace('.fits', '.json')
            ann_dir = Path(ann_path).parent
            ann_path = ann_dir / f'mask_{ann_path.split("/")[-1]}'
            with open(ann_path) as j:
                mask_info = json.load(j)

            masks = []

            for obj in mask_info['objs']:
                seg_path = ann_dir / obj['mask']

                mask = fits.getdata(seg_path)

                mask = self.mask_transforms(mask.astype(np.float32))
                masks.append(mask)
            mask, _ = torch.max(torch.stack(masks), dim=0)

        elif type == "synthetic":
            mask_path = str(img_path).replace("gen_fits", "cond_fits")
            mask = fits.getdata(mask_path)
            mask = self.mask_transforms(mask)
            mask = mask.squeeze()
            if mask.shape[0] == 3:
                mask = rgb_to_tensor(mask)
        return mask

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        image_path = self.img_paths[idx]
        img = fits.getdata(image_path)
        img = self.transforms(img)

        if "synthetic" in str(image_path):
            mask = self.get_mask(image_path, type='synthetic')
        else:
            mask = self.get_mask(image_path, type='real')

        # ann_path = str(image_path).replace(
        #     'imgs', 'masks').replace('.fits', '.json')
        # ann_dir = Path(ann_path).parent
        # ann_path = ann_dir / f'mask_{ann_path.split("/")[-1]}'
        # with open(ann_path) as j:
        #     mask_info = json.load(j)

        # masks = []

        # for obj in mask_info['objs']:
        #     seg_path = ann_dir / obj['mask']

        #     mask = fits.getdata(seg_path)

        #     mask = self.mask_transforms(mask.astype(np.float32))
        #     masks.append(mask)

        # if 'bkg' in str(image_path):
        #     mask = torch.zeros_like(img)
        #     masks.append(mask)

        # mask, _ = torch.max(torch.stack(masks), dim=0)
        mask = mask.long()
        return img.squeeze(), mask.squeeze()


class SyntheticRGDataset(Dataset):
    def __init__(self, data_dir, img_paths, img_size=128):
        super().__init__()
        data_dir = Path(data_dir)
        with open(img_paths) as f:
            self.img_paths = f.read().splitlines()
        self.img_paths = [data_dir / p for p in self.img_paths]

        self.transforms = T.Compose([
            RemoveNaNs(),
            ZScale(),
            SigmaClip(),
            ToTensor(),
            torch.nn.Tanh(),
            MinMaxNormalize(),
            # T.Resize((img_size),
            #          interpolation=T.InterpolationMode.NEAREST),
            Unsqueeze(),
            T.Resize((img_size, img_size)),
            RepeatChannels(3)
        ])
        self.img_size = img_size

        self.mask_transforms = T.Compose([
            FromNumpy(),
            Unsqueeze(),
            T.Resize((img_size, img_size),
                     interpolation=T.InterpolationMode.NEAREST),
        ])

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        image_path = self.img_paths[idx]
        img = fits.getdata(image_path)
        img = self.transforms(img)
        img = img.squeeze()

        mask_path = str(image_path).replace("gen_fits", "cond_fits")
        mask = fits.getdata(mask_path)
        mask = self.mask_transforms(mask)

        img, mask = rand_horizontal_flip(img, mask)

        mask = mask.squeeze().long()
        return img, mask


if __name__ == '__main__':
    rgtrain = SyntheticRGDataset('data/rg-dataset/data',
                                 'data/rg-dataset/val_w_bg.txt')
    # the dataset yields (image, mask) pairs
    image, mask = next(iter(rgtrain))
    to_pil_image(image).save('image.png')
    rgb_mask = mask_to_rgb(mask)[0]
    to_pil_image(rgb_mask.float()).save('mask.png')

    bs = 256

    loader = torch.utils.data.DataLoader(
        rgtrain, batch_size=bs, shuffle=False, num_workers=16)
    for i, batch in enumerate(loader):
        image, mask = batch
        # mask_to_rgb expects a (N, 1, H, W) tensor of class indices
        rgb_mask = mask_to_rgb(mask.unsqueeze(1)).float()
        nrow = int(math.sqrt(bs))
        # nrow = bs // 2
        grid = make_grid(rgb_mask, nrow=nrow, padding=0)
        save_image(grid, f'mask_{nrow}x{nrow}.png')
        break
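
The CLASSES/COLORS convention ties the integer masks to RGB planes: rgb_to_tensor collapses a 3-channel {0, 1} mask into class indices 1-3, and mask_to_rgb inverts that mapping. A small round-trip sketch (not part of the commit; the 4x4 toy mask is hypothetical), assuming the repository root is on the Python path:

    import torch

    from datasets.rg_masks import mask_to_rgb, rgb_to_tensor

    # toy mask: 0 = background, 1 = spurious, 2 = compact, 3 = extended
    mask = torch.zeros(1, 1, 4, 4)
    mask[0, 0, 0, 0] = 1
    mask[0, 0, 1, 1] = 2
    mask[0, 0, 2, 2] = 3

    rgb = mask_to_rgb(mask)              # (1, 3, 4, 4): one colour plane per foreground class
    recovered = rgb_to_tensor(rgb[0])    # (1, 4, 4): class indices rebuilt from the colour planes
    assert torch.equal(recovered, mask[0])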
models/layers.py
ADDED
@@ -0,0 +1,86 @@
import torch
import torch.nn as nn


class DenseLayer(nn.Sequential):
    def __init__(self, in_channels, growth_rate):
        super().__init__()
        self.add_module('norm', nn.BatchNorm2d(in_channels))
        self.add_module('relu', nn.ReLU(True))
        self.add_module('conv', nn.Conv2d(in_channels, growth_rate, kernel_size=3,
                                          stride=1, padding=1, bias=True))
        self.add_module('drop', nn.Dropout2d(0.2))

    def forward(self, x):
        return super().forward(x)


class DenseBlock(nn.Module):
    def __init__(self, in_channels, growth_rate, n_layers, upsample=False):
        super().__init__()
        self.upsample = upsample
        self.layers = nn.ModuleList([DenseLayer(
            in_channels + i*growth_rate, growth_rate)
            for i in range(n_layers)])

    def forward(self, x):
        if self.upsample:
            new_features = []
            # we pass all previous activations into each dense layer normally
            # But we only store each dense layer's output in the new_features array
            for layer in self.layers:
                out = layer(x)
                x = torch.cat([x, out], 1)
                new_features.append(out)
            return torch.cat(new_features, 1)
        else:
            for layer in self.layers:
                out = layer(x)
                x = torch.cat([x, out], 1)  # 1 = channel axis
            return x


class TransitionDown(nn.Sequential):
    def __init__(self, in_channels):
        super().__init__()
        self.add_module('norm', nn.BatchNorm2d(num_features=in_channels))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(in_channels, in_channels,
                                          kernel_size=1, stride=1,
                                          padding=0, bias=True))
        self.add_module('drop', nn.Dropout2d(0.2))
        self.add_module('maxpool', nn.MaxPool2d(2))

    def forward(self, x):
        return super().forward(x)


class TransitionUp(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.convTrans = nn.ConvTranspose2d(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=3, stride=2, padding=0, bias=True)

    def forward(self, x, skip):
        out = self.convTrans(x)
        out = center_crop(out, skip.size(2), skip.size(3))
        out = torch.cat([out, skip], 1)
        return out


class Bottleneck(nn.Sequential):
    def __init__(self, in_channels, growth_rate, n_layers):
        super().__init__()
        self.add_module('bottleneck', DenseBlock(
            in_channels, growth_rate, n_layers, upsample=True))

    def forward(self, x):
        return super().forward(x)


def center_crop(layer, max_height, max_width):
    _, _, h, w = layer.size()
    xy1 = (w - max_width) // 2
    xy2 = (h - max_height) // 2
    return layer[:, :, xy2:(xy2 + max_height), xy1:(xy1 + max_width)]
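
As a quick sanity check of the channel bookkeeping: a DenseBlock with growth rate g and n layers adds n*g channels on the downsampling path, and emits only the n*g newly computed channels when upsample=True. A sketch on random tensors (not part of the commit), assuming the repository root is on the Python path:

    import torch

    from models.layers import DenseBlock, TransitionDown

    x = torch.randn(1, 48, 32, 32)                       # dummy feature map
    down = DenseBlock(48, growth_rate=16, n_layers=5)    # concatenative block used on the down path
    print(down(x).shape)                                 # torch.Size([1, 128, 32, 32]) = 48 + 5*16 channels

    up = DenseBlock(48, growth_rate=16, n_layers=5, upsample=True)
    print(up(x).shape)                                   # torch.Size([1, 80, 32, 32]) = 5*16 new channels only

    pool = TransitionDown(128)
    print(pool(down(x)).shape)                           # torch.Size([1, 128, 16, 16]) after 1x1 conv + max-pool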
models/tiramisu.py
ADDED
@@ -0,0 +1,121 @@
import torch
import torch.nn as nn

from .layers import *


class FCDenseNet(nn.Module):
    def __init__(self, in_channels=3, down_blocks=(5, 5, 5, 5, 5),
                 up_blocks=(5, 5, 5, 5, 5), bottleneck_layers=5,
                 growth_rate=16, out_chans_first_conv=48, n_classes=12):
        super().__init__()
        self.down_blocks = down_blocks
        self.up_blocks = up_blocks
        cur_channels_count = 0
        skip_connection_channel_counts = []

        ## First Convolution ##

        self.add_module('firstconv', nn.Conv2d(in_channels=in_channels,
                        out_channels=out_chans_first_conv, kernel_size=3,
                        stride=1, padding=1, bias=True))
        cur_channels_count = out_chans_first_conv

        #####################
        # Downsampling path #
        #####################

        self.denseBlocksDown = nn.ModuleList([])
        self.transDownBlocks = nn.ModuleList([])
        for i in range(len(down_blocks)):
            self.denseBlocksDown.append(
                DenseBlock(cur_channels_count, growth_rate, down_blocks[i]))
            cur_channels_count += (growth_rate*down_blocks[i])
            skip_connection_channel_counts.insert(0, cur_channels_count)
            self.transDownBlocks.append(TransitionDown(cur_channels_count))

        #####################
        #    Bottleneck     #
        #####################

        self.add_module('bottleneck', Bottleneck(cur_channels_count,
                                                 growth_rate, bottleneck_layers))
        prev_block_channels = growth_rate*bottleneck_layers
        cur_channels_count += prev_block_channels

        #######################
        #   Upsampling path   #
        #######################

        self.transUpBlocks = nn.ModuleList([])
        self.denseBlocksUp = nn.ModuleList([])
        for i in range(len(up_blocks)-1):
            self.transUpBlocks.append(TransitionUp(
                prev_block_channels, prev_block_channels))
            cur_channels_count = prev_block_channels + \
                skip_connection_channel_counts[i]

            self.denseBlocksUp.append(DenseBlock(
                cur_channels_count, growth_rate, up_blocks[i],
                upsample=True))
            prev_block_channels = growth_rate*up_blocks[i]
            cur_channels_count += prev_block_channels

        ## Final DenseBlock ##

        self.transUpBlocks.append(TransitionUp(
            prev_block_channels, prev_block_channels))
        cur_channels_count = prev_block_channels + \
            skip_connection_channel_counts[-1]

        self.denseBlocksUp.append(DenseBlock(
            cur_channels_count, growth_rate, up_blocks[-1],
            upsample=False))
        cur_channels_count += growth_rate*up_blocks[-1]

        ## Softmax ##

        self.finalConv = nn.Conv2d(in_channels=cur_channels_count,
                                   out_channels=n_classes, kernel_size=1, stride=1,
                                   padding=0, bias=True)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        out = self.firstconv(x)

        skip_connections = []
        for i in range(len(self.down_blocks)):
            out = self.denseBlocksDown[i](out)
            skip_connections.append(out)
            out = self.transDownBlocks[i](out)

        out = self.bottleneck(out)
        for i in range(len(self.up_blocks)):
            skip = skip_connections.pop()
            out = self.transUpBlocks[i](out, skip)
            out = self.denseBlocksUp[i](out)

        out = self.finalConv(out)
        out = self.softmax(out)
        return out


def FCDenseNet57(n_classes):
    return FCDenseNet(
        in_channels=3, down_blocks=(4, 4, 4, 4, 4),
        up_blocks=(4, 4, 4, 4, 4), bottleneck_layers=4,
        growth_rate=12, out_chans_first_conv=48, n_classes=n_classes)


def FCDenseNet67(n_classes):
    return FCDenseNet(
        in_channels=3, down_blocks=(5, 5, 5, 5, 5),
        up_blocks=(5, 5, 5, 5, 5), bottleneck_layers=5,
        growth_rate=16, out_chans_first_conv=48, n_classes=n_classes)


def FCDenseNet103(n_classes):
    return FCDenseNet(
        in_channels=3, down_blocks=(4, 5, 7, 10, 12),
        up_blocks=(12, 10, 7, 5, 4), bottleneck_layers=15,
        growth_rate=16, out_chans_first_conv=48, n_classes=n_classes)
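
FCDenseNet67 is the variant the Streamlit app instantiates. A minimal forward-pass sketch on a dummy input (not part of the commit); 128x128 matches the default img_size used in datasets/rg_masks.py:

    import torch

    from models import tiramisu

    model = tiramisu.FCDenseNet67(n_classes=4).eval()    # background, spurious, compact, extended
    x = torch.randn(1, 3, 128, 128)                      # dummy 3-channel input, e.g. a resized FITS cutout

    with torch.no_grad():
        out = model(x)            # (1, 4, 128, 128) per-pixel log-probabilities (LogSoftmax)
        preds = out.argmax(1)     # (1, 128, 128) predicted class index per pixel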
requirements.txt
ADDED
Binary file (60 Bytes).