umuthopeyildirim committed on
Commit
bd86ed9
1 Parent(s): db5b5dc

here we go

.DS_Store ADDED
Binary file (6.15 kB)
 
app.py ADDED
@@ -0,0 +1,109 @@
+ import gradio as gr
+ import cv2
+ import numpy as np
+ import os
+ from PIL import Image
+ import spaces
+ import torch
+ import torch.nn.functional as F
+ from torchvision.transforms import Compose
+ import tempfile
+ from gradio_imageslider import ImageSlider
+
+ from iebins.networks.NewCRFDepth import NewCRFDepth
+ from iebins.utils.transfrom import Resize, NormalizeImage, PrepareForNet
+
+ css = """
+ #img-display-container {
+     max-height: 100vh;
+ }
+ #img-display-input {
+     max-height: 80vh;
+ }
+ #img-display-output {
+     max-height: 80vh;
+ }
+ """
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+ model = NewCRFDepth(version="large07", inv_depth=False,
+                     max_depth=10, pretrained=None).to(DEVICE).eval()
+ model.load_state_dict(torch.load('checkpoints/nyu_L.pth'))
+
+ title = "# IEBins: Iterative Elastic Bins for Monocular Depth Estimation"
+ description = """Demo for **IEBins: Iterative Elastic Bins for Monocular Depth Estimation**.
+ Please refer to the [paper](https://arxiv.org/abs/2309.14137), [github](https://github.com/ShuweiShao/IEBins), or [poster](https://nips.cc/media/PosterPDFs/NeurIPS%202023/70695.png?t=1701662442.5228624) for more details."""
+
+ transform = Compose([
+     Resize(
+         width=518,
+         height=518,
+         resize_target=False,
+         keep_aspect_ratio=True,
+         ensure_multiple_of=14,
+         resize_method='lower_bound',
+         image_interpolation_method=cv2.INTER_CUBIC,
+     ),
+     NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     PrepareForNet(),
+ ])
+
+
+ @spaces.GPU
+ @torch.no_grad()
+ def predict_depth(model, image):
+     return model(image)
+
+
+ with gr.Blocks(css=css) as demo:
+     gr.Markdown(title)
+     gr.Markdown(description)
+     gr.Markdown("### Depth Prediction demo")
+     gr.Markdown(
+         "You can slide the output to compare the depth prediction with the input image.")
+
+     with gr.Row():
+         input_image = gr.Image(label="Input Image",
+                                type='numpy', elem_id='img-display-input')
+         depth_image_slider = ImageSlider(
+             label="Depth Map with Slider View", elem_id='img-display-output', position=0.5,)
+     raw_file = gr.File(
+         label="16-bit raw depth (can be considered as disparity)")
+     submit = gr.Button("Submit")
+
+     def on_submit(image):
+         original_image = image.copy()
+
+         h, w = image.shape[:2]
+
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
+         image = transform({'image': image})['image']
+         image = torch.from_numpy(image).unsqueeze(0).to(DEVICE)
+
+         depth = predict_depth(model, image)
+         depth = F.interpolate(depth[None], (h, w),
+                               mode='bilinear', align_corners=False)[0, 0]
+
+         raw_depth = Image.fromarray(depth.cpu().numpy().astype('uint16'))
+         tmp = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
+         raw_depth.save(tmp.name)
+
+         depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
+         depth = depth.cpu().numpy().astype(np.uint8)
+         colored_depth = cv2.applyColorMap(
+             depth, cv2.COLORMAP_INFERNO)[:, :, ::-1]
+
+         return [(original_image, colored_depth), tmp.name]
+
+     submit.click(on_submit, inputs=[input_image], outputs=[
+                  depth_image_slider, raw_file])
+
+     example_files = os.listdir('examples')
+     example_files.sort()
+     example_files = [os.path.join('examples', filename)
+                      for filename in example_files]
+     examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[
+                            depth_image_slider, raw_file], fn=on_submit, cache_examples=False)
+
+
+ if __name__ == '__main__':
+     demo.queue().launch()
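
Note on the demo wiring: `NewCRFDepth.forward` (added below in iebins/networks/NewCRFDepth.py) returns three lists (refined depths, classified depths, uncertainty maps) rather than a single tensor, so `on_submit` above may need the final refined map unwrapped before interpolation. A minimal sketch of such a wrapper, assuming the last entry of `pred_depths_r_list` is the map to display (the helper name is illustrative and not part of this commit):

@torch.no_grad()
def predict_depth_single(model, image):
    # NewCRFDepth.forward returns (pred_depths_r_list, pred_depths_c_list, uncertainty_maps_list);
    # keep only the final refined prediction.
    pred_depths_r_list, _, _ = model(image)
    # Return shape [1, H, W] so the existing depth[None] / [0, 0] indexing in on_submit still applies.
    return pred_depths_r_list[-1][:, 0]
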
checkpoints/kittieigen_L.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bf10549a615b19b96ffdddc82e639662c421fe0cd30008cc3cf3e7d4bffa5f55
+ size 3276188594
checkpoints/nyu_L.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:81d95d5f26f5d01b7e8b060467eef77ea6efea4ddf100d60f5fad87e6c0daae7
+ size 3276188594
iebins/dataloaders/__init__.py ADDED
File without changes
iebins/dataloaders/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (173 Bytes)
 
iebins/dataloaders/__pycache__/dataloader.cpython-38.pyc ADDED
Binary file (9.15 kB)
 
iebins/dataloaders/__pycache__/dataloader_sun.cpython-38.pyc ADDED
Binary file (8.93 kB)
 
iebins/dataloaders/dataloader.py ADDED
@@ -0,0 +1,343 @@
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ import torch.utils.data.distributed
+ from torchvision import transforms
+
+ import numpy as np
+ from PIL import Image
+ import os
+ import random
+ import copy
+
+ from utils import DistributedSamplerNoEvenlyDivisible
+
+
+ def _is_pil_image(img):
+     return isinstance(img, Image.Image)
+
+
+ def _is_numpy_image(img):
+     return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
+
+
+ def preprocessing_transforms(mode):
+     return transforms.Compose([
+         ToTensor(mode=mode)
+     ])
+
+
+ class NewDataLoader(object):
+     def __init__(self, args, mode):
+         if mode == 'train':
+             self.training_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
+             if args.distributed:
+                 self.train_sampler = torch.utils.data.distributed.DistributedSampler(self.training_samples)
+             else:
+                 self.train_sampler = None
+
+             self.data = DataLoader(self.training_samples, args.batch_size,
+                                    shuffle=(self.train_sampler is None),
+                                    num_workers=args.num_threads,
+                                    pin_memory=True,
+                                    sampler=self.train_sampler)
+
+         elif mode == 'online_eval':
+             self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
+             if args.distributed:
+                 # self.eval_sampler = torch.utils.data.distributed.DistributedSampler(self.testing_samples, shuffle=False)
+                 self.eval_sampler = DistributedSamplerNoEvenlyDivisible(self.testing_samples, shuffle=False)
+             else:
+                 self.eval_sampler = None
+             self.data = DataLoader(self.testing_samples, 1,
+                                    shuffle=False,
+                                    num_workers=1,
+                                    pin_memory=True,
+                                    sampler=self.eval_sampler)
+
+         elif mode == 'test':
+             self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
+             self.data = DataLoader(self.testing_samples, 1, shuffle=False, num_workers=1)
+
+         else:
+             print('mode should be one of \'train, test, online_eval\'. Got {}'.format(mode))
+
+
+ class DataLoadPreprocess(Dataset):
+     def __init__(self, args, mode, transform=None, is_for_online_eval=False):
+         self.args = args
+         if mode == 'online_eval':
+             with open(args.filenames_file_eval, 'r') as f:
+                 self.filenames = f.readlines()
+         else:
+             with open(args.filenames_file, 'r') as f:
+                 self.filenames = f.readlines()
+
+         self.mode = mode
+         self.transform = transform
+         self.to_tensor = ToTensor
+         self.is_for_online_eval = is_for_online_eval
+
+     def __getitem__(self, idx):
+         sample_path = self.filenames[idx]
+         # focal = float(sample_path.split()[2])
+         focal = 518.8579
+
+         if self.mode == 'train':
+             if self.args.dataset == 'kitti':
+                 rgb_file = sample_path.split()[0]
+                 depth_file = os.path.join(sample_path.split()[0].split('/')[0], sample_path.split()[1])
+                 if self.args.use_right is True and random.random() > 0.5:
+                     rgb_file = rgb_file.replace('image_02', 'image_03')
+                     depth_file = depth_file.replace('image_02', 'image_03')
+             else:
+                 rgb_file = sample_path.split()[0]
+                 depth_file = sample_path.split()[1]
+
+             image_path = os.path.join(self.args.data_path, rgb_file)
+             depth_path = os.path.join(self.args.gt_path, depth_file)
+
+             image = Image.open(image_path)
+             depth_gt = Image.open(depth_path)
+
+             if self.args.do_kb_crop is True:
+                 height = image.height
+                 width = image.width
+                 top_margin = int(height - 352)
+                 left_margin = int((width - 1216) / 2)
+                 depth_gt = depth_gt.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
+                 image = image.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
+
+             # To avoid blank boundaries due to pixel registration
+             if self.args.dataset == 'nyu':
+                 if self.args.input_height == 480:
+                     depth_gt = np.array(depth_gt)
+                     valid_mask = np.zeros_like(depth_gt)
+                     valid_mask[45:472, 43:608] = 1
+                     depth_gt[valid_mask==0] = 0
+                     depth_gt = Image.fromarray(depth_gt)
+                 else:
+                     depth_gt = depth_gt.crop((43, 45, 608, 472))
+                     image = image.crop((43, 45, 608, 472))
+
+             if self.args.do_random_rotate is True:
+                 random_angle = (random.random() - 0.5) * 2 * self.args.degree
+                 image = self.rotate_image(image, random_angle)
+                 depth_gt = self.rotate_image(depth_gt, random_angle, flag=Image.NEAREST)
+
+             image = np.asarray(image, dtype=np.float32) / 255.0
+             depth_gt = np.asarray(depth_gt, dtype=np.float32)
+             depth_gt = np.expand_dims(depth_gt, axis=2)
+
+             if self.args.dataset == 'nyu':
+                 depth_gt = depth_gt / 1000.0
+             else:
+                 depth_gt = depth_gt / 256.0
+
+             if image.shape[0] != self.args.input_height or image.shape[1] != self.args.input_width:
+                 image, depth_gt = self.random_crop(image, depth_gt, self.args.input_height, self.args.input_width)
+             image, depth_gt = self.train_preprocess(image, depth_gt)
+             # https://github.com/ShuweiShao/URCDC-Depth
+             image, depth_gt = self.Cut_Flip(image, depth_gt)
+             sample = {'image': image, 'depth': depth_gt, 'focal': focal}
+
+         else:
+             if self.mode == 'online_eval':
+                 data_path = self.args.data_path_eval
+             else:
+                 data_path = self.args.data_path
+
+             image_path = os.path.join(data_path, "./" + sample_path.split()[0])
+             image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
+
+             if self.mode == 'online_eval':
+                 gt_path = self.args.gt_path_eval
+                 depth_path = os.path.join(gt_path, "./" + sample_path.split()[1])
+                 if self.args.dataset == 'kitti':
+                     depth_path = os.path.join(gt_path, sample_path.split()[0].split('/')[0], sample_path.split()[1])
+                 has_valid_depth = False
+                 try:
+                     depth_gt = Image.open(depth_path)
+                     has_valid_depth = True
+                 except IOError:
+                     depth_gt = False
+                     # print('Missing gt for {}'.format(image_path))
+
+                 if has_valid_depth:
+                     depth_gt = np.asarray(depth_gt, dtype=np.float32)
+                     depth_gt = np.expand_dims(depth_gt, axis=2)
+                     if self.args.dataset == 'nyu':
+                         depth_gt = depth_gt / 1000.0
+                     else:
+                         depth_gt = depth_gt / 256.0
+
+             if self.args.do_kb_crop is True:
+                 height = image.shape[0]
+                 width = image.shape[1]
+                 top_margin = int(height - 352)
+                 left_margin = int((width - 1216) / 2)
+                 image = image[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
+                 if self.mode == 'online_eval' and has_valid_depth:
+                     depth_gt = depth_gt[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
+
+             if self.mode == 'online_eval':
+                 sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth}
+             else:
+                 sample = {'image': image, 'focal': focal}
+
+         if self.transform:
+             sample = self.transform([sample, self.args.dataset])
+
+         return sample
+
+     def rotate_image(self, image, angle, flag=Image.BILINEAR):
+         result = image.rotate(angle, resample=flag)
+         return result
+
+     def random_crop(self, img, depth, height, width):
+         assert img.shape[0] >= height
+         assert img.shape[1] >= width
+         assert img.shape[0] == depth.shape[0]
+         assert img.shape[1] == depth.shape[1]
+         x = random.randint(0, img.shape[1] - width)
+         y = random.randint(0, img.shape[0] - height)
+         img = img[y:y + height, x:x + width, :]
+         depth = depth[y:y + height, x:x + width, :]
+         return img, depth
+
+     def train_preprocess(self, image, depth_gt):
+         # Random flipping
+         do_flip = random.random()
+         if do_flip > 0.5:
+             image = (image[:, ::-1, :]).copy()
+             depth_gt = (depth_gt[:, ::-1, :]).copy()
+
+         # Random gamma, brightness, color augmentation
+         do_augment = random.random()
+         if do_augment > 0.5:
+             image = self.augment_image(image)
+
+         return image, depth_gt
+
+     def augment_image(self, image):
+         # gamma augmentation
+         gamma = random.uniform(0.9, 1.1)
+         image_aug = image ** gamma
+
+         # brightness augmentation
+         if self.args.dataset == 'nyu':
+             brightness = random.uniform(0.75, 1.25)
+         else:
+             brightness = random.uniform(0.9, 1.1)
+         image_aug = image_aug * brightness
+
+         # color augmentation
+         colors = np.random.uniform(0.9, 1.1, size=3)
+         white = np.ones((image.shape[0], image.shape[1]))
+         color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
+         image_aug *= color_image
+         image_aug = np.clip(image_aug, 0, 1)
+
+         return image_aug
+
+     def Cut_Flip(self, image, depth):
+
+         p = random.random()
+         if p < 0.5:
+             return image, depth
+         image_copy = copy.deepcopy(image)
+         depth_copy = copy.deepcopy(depth)
+         h, w, c = image.shape
+
+         N = 2
+         h_list = []
+         h_interval_list = []  # height interval
+         for i in range(N-1):
+             h_list.append(random.randint(int(0.2*h), int(0.8*h)))
+         h_list.append(h)
+         h_list.append(0)
+         h_list.sort()
+         h_list_inv = np.array([h]*(N+1))-np.array(h_list)
+         for i in range(len(h_list)-1):
+             h_interval_list.append(h_list[i+1]-h_list[i])
+         for i in range(N):
+             image[h_list[i]:h_list[i+1], :, :] = image_copy[h_list_inv[i]-h_interval_list[i]:h_list_inv[i], :, :]
+             depth[h_list[i]:h_list[i+1], :, :] = depth_copy[h_list_inv[i]-h_interval_list[i]:h_list_inv[i], :, :]
+
+         return image, depth
+
+
+     def __len__(self):
+         return len(self.filenames)
+
+
+ class ToTensor(object):
+     def __init__(self, mode):
+         self.mode = mode
+         self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+     def __call__(self, sample_dataset):
+
+         sample = sample_dataset[0]
+         dataset = sample_dataset[1]
+
+         image, focal = sample['image'], sample['focal']
+         image = self.to_tensor(image)
+         image = self.normalize(image)
+
+         if dataset == 'kitti':
+             K_p = np.array([[716.88, 0, 596.5593, 0],
+                             [0, 716.88, 149.854, 0],
+                             [0, 0, 1, 0],
+                             [0, 0, 0, 1]], dtype=np.float32)
+             inv_K_p = np.linalg.pinv(K_p)
+             inv_K_p = torch.from_numpy(inv_K_p)
+
+         elif dataset == 'nyu':
+             K_p = np.array([[518.8579, 0, 325.5824, 0],
+                             [0, 518.8579, 253.7362, 0],
+                             [0, 0, 1, 0],
+                             [0, 0, 0, 1]], dtype=np.float32)
+             inv_K_p = np.linalg.pinv(K_p)
+             inv_K_p = torch.from_numpy(inv_K_p)
+
+         if self.mode == 'test':
+             return {'image': image, 'inv_K_p': inv_K_p, 'focal': focal}
+
+         depth = sample['depth']
+         if self.mode == 'train':
+             depth = self.to_tensor(depth)
+             return {'image': image, 'depth': depth, 'focal': focal}
+         else:
+             has_valid_depth = sample['has_valid_depth']
+             return {'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth}
+
+     def to_tensor(self, pic):
+         if not (_is_pil_image(pic) or _is_numpy_image(pic)):
+             raise TypeError(
+                 'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
+
+         if isinstance(pic, np.ndarray):
+             img = torch.from_numpy(pic.transpose((2, 0, 1)))
+             return img
+
+         # handle PIL Image
+         if pic.mode == 'I':
+             img = torch.from_numpy(np.array(pic, np.int32, copy=False))
+         elif pic.mode == 'I;16':
+             img = torch.from_numpy(np.array(pic, np.int16, copy=False))
+         else:
+             img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
+         # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
+         if pic.mode == 'YCbCr':
+             nchannel = 3
+         elif pic.mode == 'I;16':
+             nchannel = 1
+         else:
+             nchannel = len(pic.mode)
+         img = img.view(pic.size[1], pic.size[0], nchannel)
+
+         img = img.transpose(0, 1).transpose(0, 2).contiguous()
+         if isinstance(img, torch.ByteTensor):
+             return img.float()
+         else:
+             return img
iebins/dataloaders/dataloader_sun.py ADDED
@@ -0,0 +1,326 @@
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ import torch.utils.data.distributed
+ from torchvision import transforms
+
+ import numpy as np
+ from PIL import Image
+ import os
+ import random
+ import copy
+ import cv2
+
+ from utils import DistributedSamplerNoEvenlyDivisible
+
+
+ def _is_pil_image(img):
+     return isinstance(img, Image.Image)
+
+
+ def _is_numpy_image(img):
+     return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
+
+
+ def preprocessing_transforms(mode):
+     return transforms.Compose([
+         ToTensor(mode=mode)
+     ])
+
+
+ class NewDataLoader(object):
+     def __init__(self, args, mode):
+         if mode == 'train':
+             self.training_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
+             if args.distributed:
+                 self.train_sampler = torch.utils.data.distributed.DistributedSampler(self.training_samples)
+             else:
+                 self.train_sampler = None
+
+             self.data = DataLoader(self.training_samples, args.batch_size,
+                                    shuffle=(self.train_sampler is None),
+                                    num_workers=args.num_threads,
+                                    pin_memory=True,
+                                    sampler=self.train_sampler)
+
+         elif mode == 'online_eval':
+             self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
+             if args.distributed:
+                 # self.eval_sampler = torch.utils.data.distributed.DistributedSampler(self.testing_samples, shuffle=False)
+                 self.eval_sampler = DistributedSamplerNoEvenlyDivisible(self.testing_samples, shuffle=False)
+             else:
+                 self.eval_sampler = None
+             self.data = DataLoader(self.testing_samples, 1,
+                                    shuffle=False,
+                                    num_workers=1,
+                                    pin_memory=True,
+                                    sampler=self.eval_sampler)
+
+         elif mode == 'test':
+             self.testing_samples = DataLoadPreprocess(args, mode, transform=preprocessing_transforms(mode))
+             self.data = DataLoader(self.testing_samples, 1, shuffle=False, num_workers=1)
+
+         else:
+             print('mode should be one of \'train, test, online_eval\'. Got {}'.format(mode))
+
+
+ class DataLoadPreprocess(Dataset):
+     def __init__(self, args, mode, transform=None, is_for_online_eval=False):
+         self.args = args
+         if mode == 'online_eval':
+             with open(args.filenames_file_eval, 'r') as f:
+                 self.filenames = f.readlines()
+         else:
+             with open(args.filenames_file, 'r') as f:
+                 self.filenames = f.readlines()
+
+         self.mode = mode
+         self.transform = transform
+         self.to_tensor = ToTensor
+         self.is_for_online_eval = is_for_online_eval
+
+     def __getitem__(self, idx):
+         sample_path = self.filenames[idx]
+         # focal = float(sample_path.split()[2])
+         focal = 518.8579
+
+         if self.mode == 'train':
+             if self.args.dataset == 'kitti':
+                 rgb_file = sample_path.split()[0]
+                 depth_file = os.path.join(sample_path.split()[0].split('/')[0], sample_path.split()[1])
+                 if self.args.use_right is True and random.random() > 0.5:
+                     rgb_file = rgb_file.replace('image_02', 'image_03')
+                     depth_file = depth_file.replace('image_02', 'image_03')
+             else:
+                 rgb_file = sample_path.split()[0]
+                 depth_file = sample_path.split()[1]
+
+             image_path = os.path.join(self.args.data_path, rgb_file)
+             depth_path = os.path.join(self.args.gt_path, depth_file)
+
+             image = Image.open(image_path)
+             depth_gt = Image.open(depth_path)
+
+             if self.args.do_kb_crop is True:
+                 height = image.height
+                 width = image.width
+                 top_margin = int(height - 352)
+                 left_margin = int((width - 1216) / 2)
+                 depth_gt = depth_gt.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
+                 image = image.crop((left_margin, top_margin, left_margin + 1216, top_margin + 352))
+
+             # To avoid blank boundaries due to pixel registration
+             if self.args.dataset == 'nyu':
+                 if self.args.input_height == 480:
+                     depth_gt = np.array(depth_gt)
+                     valid_mask = np.zeros_like(depth_gt)
+                     valid_mask[45:472, 43:608] = 1
+                     depth_gt[valid_mask==0] = 0
+                     depth_gt = Image.fromarray(depth_gt)
+                 else:
+                     depth_gt = depth_gt.crop((43, 45, 608, 472))
+                     image = image.crop((43, 45, 608, 472))
+
+             if self.args.do_random_rotate is True:
+                 random_angle = (random.random() - 0.5) * 2 * self.args.degree
+                 image = self.rotate_image(image, random_angle)
+                 depth_gt = self.rotate_image(depth_gt, random_angle, flag=Image.NEAREST)
+
+             image = np.asarray(image, dtype=np.float32) / 255.0
+             depth_gt = np.asarray(depth_gt, dtype=np.float32)
+             depth_gt = np.expand_dims(depth_gt, axis=2)
+
+             if self.args.dataset == 'nyu':
+                 depth_gt = depth_gt / 1000.0
+             else:
+                 depth_gt = depth_gt / 256.0
+
+             if image.shape[0] != self.args.input_height or image.shape[1] != self.args.input_width:
+                 image, depth_gt = self.random_crop(image, depth_gt, self.args.input_height, self.args.input_width)
+             image, depth_gt = self.train_preprocess(image, depth_gt)
+             image, depth_gt = self.Cut_Flip(image, depth_gt)
+             sample = {'image': image, 'depth': depth_gt, 'focal': focal}
+
+         else:
+             if self.mode == 'online_eval':
+                 data_path = self.args.data_path_eval
+             else:
+                 data_path = self.args.data_path
+
+             image_path = os.path.join(data_path, "./" + sample_path.split()[0])
+             image = np.asarray(Image.open(image_path), dtype=np.float32) / 255.0
+             image = cv2.resize(image, (640, 480))
+
+             if self.mode == 'online_eval':
+                 gt_path = self.args.gt_path_eval
+                 depth_path = os.path.join(gt_path, "./" + sample_path.split()[1])
+                 if self.args.dataset == 'kitti':
+                     depth_path = os.path.join(gt_path, sample_path.split()[0].split('/')[0], sample_path.split()[1])
+                 has_valid_depth = False
+                 try:
+                     depth_gt = Image.open(depth_path)
+                     has_valid_depth = True
+                 except IOError:
+                     depth_gt = False
+                     # print('Missing gt for {}'.format(image_path))
+
+                 if has_valid_depth:
+                     depth_gt = np.asarray(depth_gt, dtype=np.uint16)  # 2
+                     depth_gt = np.bitwise_or(np.right_shift(depth_gt, 3), np.left_shift(depth_gt, 16 - 3))  # 3
+                     depth_gt = np.expand_dims(depth_gt, axis=2)
+                     if self.args.dataset == 'nyu':
+                         depth_gt = depth_gt.astype(np.single) / 1000  # 4
+                         depth_gt = depth_gt.astype(np.float32)  # 5
+                     else:
+                         depth_gt = depth_gt / 256.0
+
+             if self.args.do_kb_crop is True:
+                 height = image.shape[0]
+                 width = image.shape[1]
+                 top_margin = int(height - 352)
+                 left_margin = int((width - 1216) / 2)
+                 image = image[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
+                 if self.mode == 'online_eval' and has_valid_depth:
+                     depth_gt = depth_gt[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
+
+             if self.mode == 'online_eval':
+                 sample = {'image': image, 'depth': depth_gt, 'focal': focal, 'has_valid_depth': has_valid_depth}
+             else:
+                 sample = {'image': image, 'focal': focal}
+
+         if self.transform:
+             sample = self.transform(sample)
+
+         return sample
+
+     def rotate_image(self, image, angle, flag=Image.BILINEAR):
+         result = image.rotate(angle, resample=flag)
+         return result
+
+     def random_crop(self, img, depth, height, width):
+         assert img.shape[0] >= height
+         assert img.shape[1] >= width
+         assert img.shape[0] == depth.shape[0]
+         assert img.shape[1] == depth.shape[1]
+         x = random.randint(0, img.shape[1] - width)
+         y = random.randint(0, img.shape[0] - height)
+         img = img[y:y + height, x:x + width, :]
+         depth = depth[y:y + height, x:x + width, :]
+         return img, depth
+
+     def train_preprocess(self, image, depth_gt):
+         # Random flipping
+         do_flip = random.random()
+         if do_flip > 0.5:
+             image = (image[:, ::-1, :]).copy()
+             depth_gt = (depth_gt[:, ::-1, :]).copy()
+
+         # Random gamma, brightness, color augmentation
+         do_augment = random.random()
+         if do_augment > 0.5:
+             image = self.augment_image(image)
+
+         return image, depth_gt
+
+     def augment_image(self, image):
+         # gamma augmentation
+         gamma = random.uniform(0.9, 1.1)
+         image_aug = image ** gamma
+
+         # brightness augmentation
+         if self.args.dataset == 'nyu':
+             brightness = random.uniform(0.75, 1.25)
+         else:
+             brightness = random.uniform(0.9, 1.1)
+         image_aug = image_aug * brightness
+
+         # color augmentation
+         colors = np.random.uniform(0.9, 1.1, size=3)
+         white = np.ones((image.shape[0], image.shape[1]))
+         color_image = np.stack([white * colors[i] for i in range(3)], axis=2)
+         image_aug *= color_image
+         image_aug = np.clip(image_aug, 0, 1)
+
+         return image_aug
+
+     def Cut_Flip(self, image, depth):
+
+         p = random.random()
+         if p < 0.5:
+             return image, depth
+         image_copy = copy.deepcopy(image)
+         depth_copy = copy.deepcopy(depth)
+         h, w, c = image.shape
+
+         N = 2
+         h_list = []
+         h_interval_list = []  # height interval
+         for i in range(N-1):
+             h_list.append(random.randint(int(0.2*h), int(0.8*h)))
+         h_list.append(h)
+         h_list.append(0)
+         h_list.sort()
+         h_list_inv = np.array([h]*(N+1))-np.array(h_list)
+         for i in range(len(h_list)-1):
+             h_interval_list.append(h_list[i+1]-h_list[i])
+         for i in range(N):
+             image[h_list[i]:h_list[i+1], :, :] = image_copy[h_list_inv[i]-h_interval_list[i]:h_list_inv[i], :, :]
+             depth[h_list[i]:h_list[i+1], :, :] = depth_copy[h_list_inv[i]-h_interval_list[i]:h_list_inv[i], :, :]
+
+         return image, depth
+
+
+     def __len__(self):
+         return len(self.filenames)
+
+
+ class ToTensor(object):
+     def __init__(self, mode):
+         self.mode = mode
+         self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+     def __call__(self, sample):
+         image, focal = sample['image'], sample['focal']
+         image = self.to_tensor(image)
+         image = self.normalize(image)
+
+         if self.mode == 'test':
+             return {'image': image, 'focal': focal}
+
+         depth = sample['depth']
+         if self.mode == 'train':
+             depth = self.to_tensor(depth)
+             return {'image': image, 'depth': depth, 'focal': focal}
+         else:
+             has_valid_depth = sample['has_valid_depth']
+             return {'image': image, 'depth': depth, 'focal': focal, 'has_valid_depth': has_valid_depth}
+
+     def to_tensor(self, pic):
+         if not (_is_pil_image(pic) or _is_numpy_image(pic)):
+             raise TypeError(
+                 'pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
+
+         if isinstance(pic, np.ndarray):
+             img = torch.from_numpy(pic.transpose((2, 0, 1)))
+             return img
+
+         # handle PIL Image
+         if pic.mode == 'I':
+             img = torch.from_numpy(np.array(pic, np.int32, copy=False))
+         elif pic.mode == 'I;16':
+             img = torch.from_numpy(np.array(pic, np.int16, copy=False))
+         else:
+             img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
+         # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
+         if pic.mode == 'YCbCr':
+             nchannel = 3
+         elif pic.mode == 'I;16':
+             nchannel = 1
+         else:
+             nchannel = len(pic.mode)
+         img = img.view(pic.size[1], pic.size[0], nchannel)
+
+         img = img.transpose(0, 1).transpose(0, 2).contiguous()
+         if isinstance(img, torch.ByteTensor):
+             return img.float()
+         else:
+             return img
iebins/eval.py ADDED
@@ -0,0 +1,177 @@
+ import torch
+ import torch.backends.cudnn as cudnn
+
+ import os, sys
+ import argparse
+ import numpy as np
+ from tqdm import tqdm
+
+ from utils import post_process_depth, flip_lr, compute_errors
+ from networks.NewCRFDepth import NewCRFDepth
+
+
+ def convert_arg_line_to_args(arg_line):
+     for arg in arg_line.split():
+         if not arg.strip():
+             continue
+         yield arg
+
+
+ parser = argparse.ArgumentParser(description='IEBins PyTorch implementation.', fromfile_prefix_chars='@')
+ parser.convert_arg_line_to_args = convert_arg_line_to_args
+
+ parser.add_argument('--model_name', type=str, help='model name', default='iebins')
+ parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07, tiny07', default='large07')
+ parser.add_argument('--checkpoint_path', type=str, help='path to a checkpoint to load', default='')
+
+ # Dataset
+ parser.add_argument('--dataset', type=str, help='dataset to train on, kitti or nyu', default='nyu')
+ parser.add_argument('--input_height', type=int, help='input height', default=480)
+ parser.add_argument('--input_width', type=int, help='input width', default=640)
+ parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10)
+
+ # Preprocessing
+ parser.add_argument('--do_random_rotate', help='if set, will perform random rotation for augmentation', action='store_true')
+ parser.add_argument('--degree', type=float, help='random rotation maximum degree', default=2.5)
+ parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
+ parser.add_argument('--use_right', help='if set, will randomly use right images when train on KITTI', action='store_true')
+
+ # Eval
+ parser.add_argument('--data_path_eval', type=str, help='path to the data for evaluation', required=False)
+ parser.add_argument('--gt_path_eval', type=str, help='path to the groundtruth data for evaluation', required=False)
+ parser.add_argument('--filenames_file_eval', type=str, help='path to the filenames text file for evaluation', required=False)
+ parser.add_argument('--min_depth_eval', type=float, help='minimum depth for evaluation', default=1e-3)
+ parser.add_argument('--max_depth_eval', type=float, help='maximum depth for evaluation', default=80)
+ parser.add_argument('--eigen_crop', help='if set, crops according to Eigen NIPS14', action='store_true')
+ parser.add_argument('--garg_crop', help='if set, crops according to Garg ECCV16', action='store_true')
+
+
+ if sys.argv.__len__() == 2:
+     arg_filename_with_prefix = '@' + sys.argv[1]
+     args = parser.parse_args([arg_filename_with_prefix])
+ else:
+     args = parser.parse_args()
+
+ if args.dataset == 'kitti' or args.dataset == 'nyu':
+     from dataloaders.dataloader import NewDataLoader
+
+
+ def eval(model, dataloader_eval, post_process=False):
+     eval_measures = torch.zeros(10).cuda()
+
+     for _, eval_sample_batched in enumerate(tqdm(dataloader_eval.data)):
+         with torch.no_grad():
+             image = torch.autograd.Variable(eval_sample_batched['image'].cuda())
+             gt_depth = eval_sample_batched['depth']
+             has_valid_depth = eval_sample_batched['has_valid_depth']
+             if not has_valid_depth:
+                 # print('Invalid depth. continue.')
+                 continue
+
+             pred_depths_r_list, _, _ = model(image)
+             if post_process:
+                 image_flipped = flip_lr(image)
+                 pred_depths_r_list_flipped, _, _ = model(image_flipped)
+                 pred_depth = post_process_depth(pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])
+
+             pred_depth = pred_depth.cpu().numpy().squeeze()
+             gt_depth = gt_depth.cpu().numpy().squeeze()
+
+         if args.do_kb_crop:
+             height, width = gt_depth.shape
+             top_margin = int(height - 352)
+             left_margin = int((width - 1216) / 2)
+             pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
+             pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
+             pred_depth = pred_depth_uncropped
+
+         pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval
+         pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval
+         pred_depth[np.isinf(pred_depth)] = args.max_depth_eval
+         pred_depth[np.isnan(pred_depth)] = args.min_depth_eval
+
+         valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval)
+
+         if args.garg_crop or args.eigen_crop:
+             gt_height, gt_width = gt_depth.shape
+             eval_mask = np.zeros(valid_mask.shape)
+
+             if args.garg_crop:
+                 eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
+
+             elif args.eigen_crop:
+                 if args.dataset == 'kitti':
+                     eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
+                 elif args.dataset == 'nyu':
+                     eval_mask[45:471, 41:601] = 1
+
+             valid_mask = np.logical_and(valid_mask, eval_mask)
+
+         measures = compute_errors(gt_depth[valid_mask], pred_depth[valid_mask])
+
+         eval_measures[:9] += torch.tensor(measures).cuda()
+         eval_measures[9] += 1
+
+     eval_measures_cpu = eval_measures.cpu()
+     cnt = eval_measures_cpu[9].item()
+     eval_measures_cpu /= cnt
+     print('Computing errors for {} eval samples'.format(int(cnt)), ', post_process: ', post_process)
+     print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('silog', 'abs_rel', 'log10', 'rms',
+                                                                                  'sq_rel', 'log_rms', 'd1', 'd2',
+                                                                                  'd3'))
+     for i in range(8):
+         print('{:7.4f}, '.format(eval_measures_cpu[i]), end='')
+     print('{:7.4f}'.format(eval_measures_cpu[8]))
+     return eval_measures_cpu
+
+
+ def main_worker(args):
+
+     # CRF model
+     model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=None)
+     model.train()
+
+     num_params = sum([np.prod(p.size()) for p in model.parameters()])
+     print("== Total number of parameters: {}".format(num_params))
+
+     num_params_update = sum([np.prod(p.shape) for p in model.parameters() if p.requires_grad])
+     print("== Total number of learning parameters: {}".format(num_params_update))
+
+     model = torch.nn.DataParallel(model)
+     model.cuda()
+
+     print("== Model Initialized")
+
+     if args.checkpoint_path != '':
+         if os.path.isfile(args.checkpoint_path):
+             print("== Loading checkpoint '{}'".format(args.checkpoint_path))
+             checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
+             model.load_state_dict(checkpoint['model'])
+             print("== Loaded checkpoint '{}'".format(args.checkpoint_path))
+             del checkpoint
+         else:
+             print("== No checkpoint found at '{}'".format(args.checkpoint_path))
+
+     cudnn.benchmark = True
+
+     dataloader_eval = NewDataLoader(args, 'online_eval')
+
+     # ===== Evaluation ======
+     model.eval()
+     with torch.no_grad():
+         eval_measures = eval(model, dataloader_eval, post_process=True)
+
+
+ def main():
+     torch.cuda.empty_cache()
+     args.distributed = False
+     ngpus_per_node = torch.cuda.device_count()
+     if ngpus_per_node > 1:
+         print("This machine has more than 1 gpu. Please set \'CUDA_VISIBLE_DEVICES=0\'")
+         return -1
+
+     main_worker(args)
+
+
+ if __name__ == '__main__':
+     main()
iebins/eval_sun.py ADDED
@@ -0,0 +1,179 @@
+ import torch
+ import torch.backends.cudnn as cudnn
+ import torch.nn.functional as F
+ import os, sys
+ import argparse
+ import numpy as np
+ from tqdm import tqdm
+
+ from utils import post_process_depth, flip_lr, compute_errors
+ from networks.NewCRFDepth import NewCRFDepth
+
+
+ def convert_arg_line_to_args(arg_line):
+     for arg in arg_line.split():
+         if not arg.strip():
+             continue
+         yield arg
+
+
+ parser = argparse.ArgumentParser(description='IEBins PyTorch implementation.', fromfile_prefix_chars='@')
+ parser.convert_arg_line_to_args = convert_arg_line_to_args
+
+ parser.add_argument('--model_name', type=str, help='model name', default='iebins')
+ parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07, tiny07', default='large07')
+ parser.add_argument('--checkpoint_path', type=str, help='path to a checkpoint to load', default='')
+
+ # Dataset
+ parser.add_argument('--dataset', type=str, help='dataset to train on, kitti or nyu', default='nyu')
+ parser.add_argument('--input_height', type=int, help='input height', default=480)
+ parser.add_argument('--input_width', type=int, help='input width', default=640)
+ parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10)
+
+ # Preprocessing
+ parser.add_argument('--do_random_rotate', help='if set, will perform random rotation for augmentation', action='store_true')
+ parser.add_argument('--degree', type=float, help='random rotation maximum degree', default=2.5)
+ parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
+ parser.add_argument('--use_right', help='if set, will randomly use right images when train on KITTI', action='store_true')
+
+ # Eval
+ parser.add_argument('--data_path_eval', type=str, help='path to the data for evaluation', required=False)
+ parser.add_argument('--gt_path_eval', type=str, help='path to the groundtruth data for evaluation', required=False)
+ parser.add_argument('--filenames_file_eval', type=str, help='path to the filenames text file for evaluation', required=False)
+ parser.add_argument('--min_depth_eval', type=float, help='minimum depth for evaluation', default=1e-3)
+ parser.add_argument('--max_depth_eval', type=float, help='maximum depth for evaluation', default=80)
+ parser.add_argument('--eigen_crop', help='if set, crops according to Eigen NIPS14', action='store_true')
+ parser.add_argument('--garg_crop', help='if set, crops according to Garg ECCV16', action='store_true')
+
+
+ if sys.argv.__len__() == 2:
+     arg_filename_with_prefix = '@' + sys.argv[1]
+     args = parser.parse_args([arg_filename_with_prefix])
+ else:
+     args = parser.parse_args()
+
+ if args.dataset == 'nyu':
+     from dataloaders.dataloader_sun import NewDataLoader
+
+
+ def eval(model, dataloader_eval, post_process=False):
+     eval_measures = torch.zeros(10).cuda()
+     for _, eval_sample_batched in enumerate(tqdm(dataloader_eval.data)):
+         with torch.no_grad():
+             image = torch.autograd.Variable(eval_sample_batched['image'].cuda())
+             gt_depth = eval_sample_batched['depth']
+             has_valid_depth = eval_sample_batched['has_valid_depth']
+             if not has_valid_depth:
+                 # print('Invalid depth. continue.')
+                 continue
+             _, hh, ww, _ = gt_depth.shape
+             pred_depths_r_list, _, _ = model(image)
+             if post_process:
+                 image_flipped = flip_lr(image)
+                 pred_depths_r_list_flipped, _, _ = model(image_flipped)
+                 pred_depth = post_process_depth(pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])
+             pred_depth = F.interpolate(pred_depth, [hh, ww], mode="bilinear", align_corners=False)
+
+             pred_depth = pred_depth.cpu().numpy().squeeze()
+             gt_depth = gt_depth.cpu().numpy().squeeze()
+
+         if args.do_kb_crop:
+             height, width = gt_depth.shape
+             top_margin = int(height - 352)
+             left_margin = int((width - 1216) / 2)
+             pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
+             pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
+             pred_depth = pred_depth_uncropped
+
+         pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval
+         pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval
+         pred_depth[np.isinf(pred_depth)] = args.max_depth_eval
+         pred_depth[np.isnan(pred_depth)] = args.min_depth_eval
+         pred_depth[pred_depth > 8] = 8
+         gt_depth[gt_depth > 8] = 8
+
+         valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval)
+
+         if args.garg_crop or args.eigen_crop:
+             gt_height, gt_width = gt_depth.shape
+             eval_mask = np.zeros(valid_mask.shape)
+
+             if args.garg_crop:
+                 eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
+
+             elif args.eigen_crop:
+                 if args.dataset == 'kitti':
+                     eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
+                 elif args.dataset == 'nyu':
+                     eval_mask[45:471, 41:601] = 1
+
+             valid_mask = np.logical_and(valid_mask, eval_mask)
+
+         measures = compute_errors(gt_depth[valid_mask], pred_depth[valid_mask])
+
+         eval_measures[:9] += torch.tensor(measures).cuda()
+         eval_measures[9] += 1
+
+     eval_measures_cpu = eval_measures.cpu()
+     cnt = eval_measures_cpu[9].item()
+     eval_measures_cpu /= cnt
+     print('Computing errors for {} eval samples'.format(int(cnt)), ', post_process: ', post_process)
+     print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('silog', 'abs_rel', 'log10', 'rms',
+                                                                                  'sq_rel', 'log_rms', 'd1', 'd2',
+                                                                                  'd3'))
+     for i in range(8):
+         print('{:7.4f}, '.format(eval_measures_cpu[i]), end='')
+     print('{:7.4f}'.format(eval_measures_cpu[8]))
+     return eval_measures_cpu
+
+
+ def main_worker(args):
+
+     # CRF model
+     model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=None)
+     model.train()
+
+     num_params = sum([np.prod(p.size()) for p in model.parameters()])
+     print("== Total number of parameters: {}".format(num_params))
+
+     num_params_update = sum([np.prod(p.shape) for p in model.parameters() if p.requires_grad])
+     print("== Total number of learning parameters: {}".format(num_params_update))
+
+     model = torch.nn.DataParallel(model)
+     model.cuda()
+
+     print("== Model Initialized")
+
+     if args.checkpoint_path != '':
+         if os.path.isfile(args.checkpoint_path):
+             print("== Loading checkpoint '{}'".format(args.checkpoint_path))
+             checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
+             model.load_state_dict(checkpoint['model'])
+             print("== Loaded checkpoint '{}'".format(args.checkpoint_path))
+             del checkpoint
+         else:
+             print("== No checkpoint found at '{}'".format(args.checkpoint_path))
+
+     cudnn.benchmark = True
+
+     dataloader_eval = NewDataLoader(args, 'online_eval')
+
+     # ===== Evaluation ======
+     model.eval()
+     with torch.no_grad():
+         eval_measures = eval(model, dataloader_eval, post_process=True)
+
+
+ def main():
+     torch.cuda.empty_cache()
+     args.distributed = False
+     ngpus_per_node = torch.cuda.device_count()
+     if ngpus_per_node > 1:
+         print("This machine has more than 1 gpu. Please set \'CUDA_VISIBLE_DEVICES=0\'")
+         return -1
+
+     main_worker(args)
+
+
+ if __name__ == '__main__':
+     main()
iebins/inference_single_image.py ADDED
@@ -0,0 +1,117 @@
+ import torch
+ import torch.backends.cudnn as cudnn
+
+ import os, sys
+ import argparse
+ import numpy as np
+ from tqdm import tqdm
+
+ from utils import post_process_depth, flip_lr, compute_errors
+ from networks.NewCRFDepth import NewCRFDepth
+ from PIL import Image
+ from torchvision import transforms
+ import matplotlib.pyplot as plt
+
+
+ def convert_arg_line_to_args(arg_line):
+     for arg in arg_line.split():
+         if not arg.strip():
+             continue
+         yield arg
+
+
+ parser = argparse.ArgumentParser(description='IEBins PyTorch implementation.', fromfile_prefix_chars='@')
+ parser.convert_arg_line_to_args = convert_arg_line_to_args
+
+ parser.add_argument('--model_name', type=str, help='model name', default='iebins')
+ parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07', default='large07')
+ parser.add_argument('--checkpoint_path', type=str, help='path to a checkpoint to load', default='')
+ parser.add_argument('--dataset', type=str, help='dataset to train on, kitti or nyu', default='nyu')
+ parser.add_argument('--image_path', type=str, help='path to the image for inference', required=False)
+ parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10)
+
+
+ if sys.argv.__len__() == 2:
+     arg_filename_with_prefix = '@' + sys.argv[1]
+     args = parser.parse_args([arg_filename_with_prefix])
+ else:
+     args = parser.parse_args()
+
+
+ def inference(model, post_process=False):
+
+     image = np.asarray(Image.open(args.image_path), dtype=np.float32) / 255.0
+
+     if args.dataset == 'kitti':
+         height = image.shape[0]
+         width = image.shape[1]
+         top_margin = int(height - 352)
+         left_margin = int((width - 1216) / 2)
+         image = image[top_margin:top_margin + 352, left_margin:left_margin + 1216, :]
+
+     image = torch.from_numpy(image.transpose((2, 0, 1)))
+     image = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image)
+
+     with torch.no_grad():
+         image = torch.autograd.Variable(image.unsqueeze(0).cuda())
+
+         pred_depths_r_list, _, _ = model(image)
+         if post_process:
+             image_flipped = flip_lr(image)
+             pred_depths_r_list_flipped, _, _ = model(image_flipped)
+             pred_depth = post_process_depth(pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])
+
+         pred_depth = pred_depth.cpu().numpy().squeeze()
+
+     if args.dataset == 'kitti':
+         plt.imsave('depth.png', np.log10(pred_depth), cmap='magma')
+     else:
+         plt.imsave('depth.png', pred_depth, cmap='jet')
+
+
+ def main_worker(args):
+
+     model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=None)
+     model.train()
+
+     num_params = sum([np.prod(p.size()) for p in model.parameters()])
+     print("== Total number of parameters: {}".format(num_params))
+
+     num_params_update = sum([np.prod(p.shape) for p in model.parameters() if p.requires_grad])
+     print("== Total number of learning parameters: {}".format(num_params_update))
+
+     model = torch.nn.DataParallel(model)
+     model.cuda()
+
+     print("== Model Initialized")
+
+     if args.checkpoint_path != '':
+         if os.path.isfile(args.checkpoint_path):
+             checkpoint = torch.load(args.checkpoint_path, map_location='cpu')
+             model.load_state_dict(checkpoint['model'])
+             print("== Loaded checkpoint '{}'".format(args.checkpoint_path))
+             del checkpoint
+         else:
+             print("== No checkpoint found at '{}'".format(args.checkpoint_path))
+
+     cudnn.benchmark = True
+
+     # ===== Inference ======
+     model.eval()
+     with torch.no_grad():
+         inference(model, post_process=True)
+
+
+ def main():
+     torch.cuda.empty_cache()
+     args.distributed = False
+     ngpus_per_node = torch.cuda.device_count()
+     if ngpus_per_node > 1:
+         print("This machine has more than 1 gpu. Please set \'CUDA_VISIBLE_DEVICES=0\'")
+         return -1
+
+     main_worker(args)
+
+
+ if __name__ == '__main__':
+     main()
iebins/networks/NewCRFDepth.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from .swin_transformer import SwinTransformer
6
+ from .newcrf_layers import NewCRF
7
+ from .uper_crf_head import PSP
8
+ from .depth_update import *
9
+ ########################################################################################################################
10
+
11
+
12
+ class NewCRFDepth(nn.Module):
13
+ """
14
+ Depth network based on neural window FC-CRFs architecture.
15
+ """
16
+ def __init__(self, version=None, inv_depth=False, pretrained=None,
17
+ frozen_stages=-1, min_depth=0.1, max_depth=100.0, **kwargs):
18
+ super().__init__()
19
+
20
+ self.inv_depth = inv_depth
21
+ self.with_auxiliary_head = False
22
+ self.with_neck = False
23
+
24
+ norm_cfg = dict(type='BN', requires_grad=True)
25
+
26
+ window_size = int(version[-2:])
27
+
28
+ if version[:-2] == 'base':
29
+ embed_dim = 128
30
+ depths = [2, 2, 18, 2]
31
+ num_heads = [4, 8, 16, 32]
32
+ in_channels = [128, 256, 512, 1024]
33
+ self.update = BasicUpdateBlockDepth(hidden_dim=128, context_dim=128)
34
+ elif version[:-2] == 'large':
35
+ embed_dim = 192
36
+ depths = [2, 2, 18, 2]
37
+ num_heads = [6, 12, 24, 48]
38
+ in_channels = [192, 384, 768, 1536]
39
+ self.update = BasicUpdateBlockDepth(hidden_dim=128, context_dim=192)
40
+ elif version[:-2] == 'tiny':
41
+ embed_dim = 96
42
+ depths = [2, 2, 6, 2]
43
+ num_heads = [3, 6, 12, 24]
44
+ in_channels = [96, 192, 384, 768]
45
+ self.update = BasicUpdateBlockDepth(hidden_dim=128, context_dim=96)
46
+
47
+ backbone_cfg = dict(
48
+ embed_dim=embed_dim,
49
+ depths=depths,
50
+ num_heads=num_heads,
51
+ window_size=window_size,
52
+ ape=False,
53
+ drop_path_rate=0.3,
54
+ patch_norm=True,
55
+ use_checkpoint=False,
56
+ frozen_stages=frozen_stages
57
+ )
58
+
59
+ embed_dim = 512
60
+ decoder_cfg = dict(
61
+ in_channels=in_channels,
62
+ in_index=[0, 1, 2, 3],
63
+ pool_scales=(1, 2, 3, 6),
64
+ channels=embed_dim,
65
+ dropout_ratio=0.0,
66
+ num_classes=32,
67
+ norm_cfg=norm_cfg,
68
+ align_corners=False
69
+ )
70
+
71
+ self.backbone = SwinTransformer(**backbone_cfg)
72
+ v_dim = decoder_cfg['num_classes']*4
73
+ win = 7
74
+ crf_dims = [128, 256, 512, 1024]
75
+ v_dims = [64, 128, 256, embed_dim]
76
+ self.crf3 = NewCRF(input_dim=in_channels[3], embed_dim=crf_dims[3], window_size=win, v_dim=v_dims[3], num_heads=32)
77
+ self.crf2 = NewCRF(input_dim=in_channels[2], embed_dim=crf_dims[2], window_size=win, v_dim=v_dims[2], num_heads=16)
78
+ self.crf1 = NewCRF(input_dim=in_channels[1], embed_dim=crf_dims[1], window_size=win, v_dim=v_dims[1], num_heads=8)
79
+
80
+ self.decoder = PSP(**decoder_cfg)
81
+ self.disp_head1 = DispHead(input_dim=crf_dims[0])
82
+
83
+ self.up_mode = 'bilinear'
84
+ if self.up_mode == 'mask':
85
+ self.mask_head = nn.Sequential(
86
+ nn.Conv2d(v_dims[0], 64, 3, padding=1),
87
+ nn.ReLU(inplace=True),
88
+ nn.Conv2d(64, 16*9, 1, padding=0))
89
+
90
+ self.min_depth = min_depth
91
+ self.max_depth = max_depth
92
+ self.depth_num = 16
93
+ self.hidden_dim = 128
94
+ self.project = Projection(v_dims[0], self.hidden_dim)
95
+
96
+ self.init_weights(pretrained=pretrained)
97
+
98
+ def init_weights(self, pretrained=None):
99
+ """Initialize the weights in backbone and heads.
100
+
101
+ Args:
102
+ pretrained (str, optional): Path to pre-trained weights.
103
+ Defaults to None.
104
+ """
105
+ print(f'== Load encoder backbone from: {pretrained}')
106
+ self.backbone.init_weights(pretrained=pretrained)
107
+ self.decoder.init_weights()
108
+ if self.with_auxiliary_head:
109
+ if isinstance(self.auxiliary_head, nn.ModuleList):
110
+ for aux_head in self.auxiliary_head:
111
+ aux_head.init_weights()
112
+ else:
113
+ self.auxiliary_head.init_weights()
114
+
115
+ def upsample_mask(self, disp, mask):
116
+ """ Upsample disp [H/4, W/4, 1] -> [H, W, 1] using convex combination """
117
+ N, C, H, W = disp.shape
118
+ mask = mask.view(N, 1, 9, 4, 4, H, W)
119
+ mask = torch.softmax(mask, dim=2)
120
+
121
+ up_disp = F.unfold(disp, kernel_size=3, padding=1)
122
+ up_disp = up_disp.view(N, C, 9, 1, 1, H, W)
123
+
124
+ up_disp = torch.sum(mask * up_disp, dim=2)
125
+ up_disp = up_disp.permute(0, 1, 4, 2, 5, 3)
126
+ return up_disp.reshape(N, C, 4*H, 4*W)
127
+
128
+ def forward(self, imgs, epoch=1, step=100):
129
+
130
+ feats = self.backbone(imgs)
131
+ ppm_out = self.decoder(feats)
132
+
133
+ e3 = self.crf3(feats[3], ppm_out)
134
+ e3 = nn.PixelShuffle(2)(e3)
135
+ e2 = self.crf2(feats[2], e3)
136
+ e2 = nn.PixelShuffle(2)(e2)
137
+ e1 = self.crf1(feats[1], e2)
138
+ e1 = nn.PixelShuffle(2)(e1)
139
+
140
+ # iterative bins
141
+ if epoch == 0 and step < 80:
142
+ max_tree_depth = 3
143
+ else:
144
+ max_tree_depth = 6
145
+
146
+ if self.up_mode == 'mask':
147
+ mask = self.mask_head(e1)
148
+
149
+ b, c, h, w = e1.shape
150
+ device = e1.device
151
+
152
+ depth = torch.zeros([b, 1, h, w]).to(device)
153
+ context = feats[0]
154
+ gru_hidden = torch.tanh(self.project(e1))
155
+ pred_depths_r_list, pred_depths_c_list, uncertainty_maps_list = self.update(depth, context, gru_hidden, max_tree_depth, self.depth_num, self.min_depth, self.max_depth)
156
+
157
+ if self.up_mode == 'mask':
158
+ for i in range(len(pred_depths_r_list)):
159
+ pred_depths_r_list[i] = self.upsample_mask(pred_depths_r_list[i], mask)
160
+ for i in range(len(pred_depths_c_list)):
161
+ pred_depths_c_list[i] = self.upsample_mask(pred_depths_c_list[i], mask.detach())
162
+ for i in range(len(uncertainty_maps_list)):
163
+ uncertainty_maps_list[i] = self.upsample_mask(uncertainty_maps_list[i], mask.detach())
164
+ else:
165
+ for i in range(len(pred_depths_r_list)):
166
+ pred_depths_r_list[i] = upsample(pred_depths_r_list[i], scale_factor=4)
167
+ for i in range(len(pred_depths_c_list)):
168
+ pred_depths_c_list[i] = upsample(pred_depths_c_list[i], scale_factor=4)
169
+ for i in range(len(uncertainty_maps_list)):
170
+ uncertainty_maps_list[i] = upsample(uncertainty_maps_list[i], scale_factor=4)
171
+
172
+ return pred_depths_r_list, pred_depths_c_list, uncertainty_maps_list
173
+
174
+ class DispHead(nn.Module):
175
+ def __init__(self, input_dim=100):
176
+ super(DispHead, self).__init__()
177
+ # self.norm1 = nn.BatchNorm2d(input_dim)
178
+ self.conv1 = nn.Conv2d(input_dim, 1, 3, padding=1)
179
+ # self.relu = nn.ReLU(inplace=True)
180
+ self.sigmoid = nn.Sigmoid()
181
+
182
+ def forward(self, x, scale):
183
+ # x = self.relu(self.norm1(x))
184
+ x = self.sigmoid(self.conv1(x))
185
+ if scale > 1:
186
+ x = upsample(x, scale_factor=scale)
187
+ return x
188
+
189
+ class BasicUpdateBlockDepth(nn.Module):
190
+ def __init__(self, hidden_dim=128, context_dim=192):
191
+ super(BasicUpdateBlockDepth, self).__init__()
192
+
193
+ self.encoder = ProjectionInputDepth(hidden_dim=hidden_dim, out_chs=hidden_dim * 2)
194
+ self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=self.encoder.out_chs+context_dim)
195
+ self.p_head = PHead(hidden_dim, hidden_dim)
196
+
197
+ def forward(self, depth, context, gru_hidden, seq_len, depth_num, min_depth, max_depth):
198
+
199
+ pred_depths_r_list = []
200
+ pred_depths_c_list = []
201
+ uncertainty_maps_list = []
202
+
203
+ b, _, h, w = depth.shape
204
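+ # Start from a uniform partition of [min_depth, max_depth] into depth_num
+ # bins: bin_edges come from a cumulative sum of equal intervals and the
+ # candidate depths are the bin centres.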
+ depth_range = max_depth - min_depth
205
+ interval = depth_range / depth_num
206
+ interval = interval * torch.ones_like(depth)
207
+ interval = interval.repeat(1, depth_num, 1, 1)
208
+ interval = torch.cat([torch.ones_like(depth) * min_depth, interval], 1)
209
+
210
+ bin_edges = torch.cumsum(interval, 1)
211
+ current_depths = 0.5 * (bin_edges[:, :-1] + bin_edges[:, 1:])
212
+ index_iter = 0
213
+
214
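+ # Each iteration: encode the current candidate depths, update the GRU hidden
+ # state, predict a per-bin probability, take its expectation as the regressed
+ # depth and its spread as an uncertainty map, then elastically re-sample the
+ # bins around the prediction for the next iteration.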
+ for i in range(seq_len):
215
+ input_features = self.encoder(current_depths.detach())
216
+ input_c = torch.cat([input_features, context], dim=1)
217
+
218
+ gru_hidden = self.gru(gru_hidden, input_c)
219
+ pred_prob = self.p_head(gru_hidden)
220
+
221
+ depth_r = (pred_prob * current_depths.detach()).sum(1, keepdim=True)
222
+ pred_depths_r_list.append(depth_r)
223
+
224
+ uncertainty_map = torch.sqrt((pred_prob * ((current_depths.detach() - depth_r.repeat(1, depth_num, 1, 1))**2)).sum(1, keepdim=True))
225
+ uncertainty_maps_list.append(uncertainty_map)
226
+
227
+ index_iter = index_iter + 1
228
+
229
+ pred_label = get_label(torch.squeeze(depth_r, 1), bin_edges, depth_num).unsqueeze(1)
230
+ depth_c = torch.gather(current_depths.detach(), 1, pred_label.detach())
231
+ pred_depths_c_list.append(depth_c)
232
+
233
+ label_target_bin_left = pred_label
234
+ target_bin_left = torch.gather(bin_edges, 1, label_target_bin_left)
235
+ label_target_bin_right = (pred_label.float() + 1).long()
236
+ target_bin_right = torch.gather(bin_edges, 1, label_target_bin_right)
237
+
238
+ bin_edges, current_depths = update_sample(bin_edges, target_bin_left, target_bin_right, depth_r.detach(), pred_label.detach(), depth_num, min_depth, max_depth, uncertainty_map)
239
+
240
+ return pred_depths_r_list, pred_depths_c_list, uncertainty_maps_list
241
+
242
+ class PHead(nn.Module):
243
+ def __init__(self, input_dim=128, hidden_dim=128):
244
+ super(PHead, self).__init__()
245
+ self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
246
+ self.conv2 = nn.Conv2d(hidden_dim, 16, 3, padding=1)
247
+
248
+ def forward(self, x):
249
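+ # per-pixel probability distribution over the 16 depth bins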
+ out = torch.softmax(self.conv2(F.relu(self.conv1(x))), 1)
250
+ return out
251
+
252
+ class SepConvGRU(nn.Module):
253
+ def __init__(self, hidden_dim=128, input_dim=128+192):
254
+ super(SepConvGRU, self).__init__()
255
+
256
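+ # ConvGRU with separable kernels: the gates are applied first with 1x5
+ # convolutions (horizontal pass) and then with 5x1 convolutions (vertical pass).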
+ self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
257
+ self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
258
+ self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
259
+ self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
260
+ self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
261
+ self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
262
+
263
+ def forward(self, h, x):
264
+ # horizontal
265
+ hx = torch.cat([h, x], dim=1)
266
+ z = torch.sigmoid(self.convz1(hx))
267
+ r = torch.sigmoid(self.convr1(hx))
268
+ q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
269
+
270
+ h = (1-z) * h + z * q
271
+
272
+ # vertical
273
+ hx = torch.cat([h, x], dim=1)
274
+ z = torch.sigmoid(self.convz2(hx))
275
+ r = torch.sigmoid(self.convr2(hx))
276
+ q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
277
+ h = (1-z) * h + z * q
278
+
279
+ return h
280
+
281
+ class ProjectionInputDepth(nn.Module):
282
+ def __init__(self, hidden_dim, out_chs):
283
+ super().__init__()
284
+ self.out_chs = out_chs
285
+ self.convd1 = nn.Conv2d(16, hidden_dim, 7, padding=3)
286
+ self.convd2 = nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1)
287
+ self.convd3 = nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1)
288
+ self.convd4 = nn.Conv2d(hidden_dim, out_chs, 3, padding=1)
289
+
290
+ def forward(self, depth):
291
+ d = F.relu(self.convd1(depth))
292
+ d = F.relu(self.convd2(d))
293
+ d = F.relu(self.convd3(d))
294
+ d = F.relu(self.convd4(d))
295
+
296
+ return d
297
+
298
+ class Projection(nn.Module):
299
+ def __init__(self, in_chs, out_chs):
300
+ super().__init__()
301
+ self.conv = nn.Conv2d(in_chs, out_chs, 3, padding=1)
302
+
303
+ def forward(self, x):
304
+ out = self.conv(x)
305
+
306
+ return out
307
+
308
+ def upsample(x, scale_factor=2, mode="bilinear", align_corners=False):
309
+ """Upsample input tensor by a factor of 2
310
+ """
311
+ return F.interpolate(x, scale_factor=scale_factor, mode=mode, align_corners=align_corners)
312
+
313
+ def upsample1(x, scale_factor=2, mode="bilinear"):
314
+ """Upsample input tensor by a factor of 2
315
+ """
316
+ return F.interpolate(x, scale_factor=scale_factor, mode=mode)
317
+
318
+
iebins/networks/__init__.py ADDED
File without changes
iebins/networks/depth_update.py ADDED
@@ -0,0 +1,39 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import copy
4
+
5
+ def update_sample(bin_edges, target_bin_left, target_bin_right, depth_r, pred_label, depth_num, min_depth, max_depth, uncertainty_range):
6
+
7
+ with torch.no_grad():
8
+ b, _, h, w = bin_edges.shape
9
+
10
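+ # Re-sample the bins for the next iteration: centre a window of width
+ # uncertainty_range on the regressed depth ('direct' mode), clamp it to the
+ # valid depth range, and split it uniformly into depth_num new bins.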
+ mode = 'direct'
11
+ if mode == 'direct':
12
+ depth_range = uncertainty_range
13
+ depth_start_update = torch.clamp_min(depth_r - 0.5 * depth_range, min_depth)
14
+ else:
15
+ depth_range = uncertainty_range + (target_bin_right - target_bin_left).abs()
16
+ depth_start_update = torch.clamp_min(target_bin_left - 0.5 * uncertainty_range, min_depth)
17
+
18
+ interval = depth_range / depth_num
19
+ interval = interval.repeat(1, depth_num, 1, 1)
20
+ interval = torch.cat([torch.ones([b, 1, h, w], device=bin_edges.device) * depth_start_update, interval], 1)
21
+
22
+ bin_edges = torch.cumsum(interval, 1).clamp(min_depth, max_depth)
23
+ curr_depth = 0.5 * (bin_edges[:, :-1] + bin_edges[:, 1:])
24
+
25
+ return bin_edges.detach(), curr_depth.detach()
26
+
27
+ def get_label(gt_depth_img, bin_edges, depth_num):
28
+
29
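+ # Assign each depth value the index of the bin whose edges enclose it.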
+ with torch.no_grad():
30
+ gt_label = torch.zeros(gt_depth_img.size(), dtype=torch.int64, device=gt_depth_img.device)
31
+ for i in range(depth_num):
32
+ bin_mask = torch.ge(gt_depth_img, bin_edges[:, i])
33
+ bin_mask = torch.logical_and(bin_mask,
34
+ torch.lt(gt_depth_img, bin_edges[:, i + 1]))
35
+ gt_label[bin_mask] = i
36
+
37
+ return gt_label
38
+
39
+
iebins/networks/newcrf_layers.py ADDED
@@ -0,0 +1,433 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.utils.checkpoint as checkpoint
5
+ import numpy as np
6
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7
+
8
+
9
+ class Mlp(nn.Module):
10
+ """ Multilayer perceptron."""
11
+
12
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
13
+ super().__init__()
14
+ out_features = out_features or in_features
15
+ hidden_features = hidden_features or in_features
16
+ self.fc1 = nn.Linear(in_features, hidden_features)
17
+ self.act = act_layer()
18
+ self.fc2 = nn.Linear(hidden_features, out_features)
19
+ self.drop = nn.Dropout(drop)
20
+
21
+ def forward(self, x):
22
+ x = self.fc1(x)
23
+ x = self.act(x)
24
+ x = self.drop(x)
25
+ x = self.fc2(x)
26
+ x = self.drop(x)
27
+ return x
28
+
29
+
30
+ def window_partition(x, window_size):
31
+ """
32
+ Args:
33
+ x: (B, H, W, C)
34
+ window_size (int): window size
35
+
36
+ Returns:
37
+ windows: (num_windows*B, window_size, window_size, C)
38
+ """
39
+ B, H, W, C = x.shape
40
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
41
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
42
+ return windows
43
+
44
+
45
+ def window_reverse(windows, window_size, H, W):
46
+ """
47
+ Args:
48
+ windows: (num_windows*B, window_size, window_size, C)
49
+ window_size (int): Window size
50
+ H (int): Height of image
51
+ W (int): Width of image
52
+
53
+ Returns:
54
+ x: (B, H, W, C)
55
+ """
56
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
57
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
58
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
59
+ return x
60
+
61
+
62
+ class WindowAttention(nn.Module):
63
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
64
+ It supports both shifted and non-shifted windows.
65
+
66
+ Args:
67
+ dim (int): Number of input channels.
68
+ window_size (tuple[int]): The height and width of the window.
69
+ num_heads (int): Number of attention heads.
70
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
71
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
72
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
73
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
74
+ """
75
+
76
+ def __init__(self, dim, window_size, num_heads, v_dim, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
77
+
78
+ super().__init__()
79
+ self.dim = dim
80
+ self.window_size = window_size # Wh, Ww
81
+ self.num_heads = num_heads
82
+ head_dim = dim // num_heads
83
+ self.scale = qk_scale or head_dim ** -0.5
84
+
85
+ # define a parameter table of relative position bias
86
+ self.relative_position_bias_table = nn.Parameter(
87
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
88
+
89
+ # get pair-wise relative position index for each token inside the window
90
+ coords_h = torch.arange(self.window_size[0])
91
+ coords_w = torch.arange(self.window_size[1])
92
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
93
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
94
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
95
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
96
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
97
+ relative_coords[:, :, 1] += self.window_size[1] - 1
98
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
99
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
100
+ self.register_buffer("relative_position_index", relative_position_index)
101
+
102
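+ # Unlike standard Swin attention, only queries and keys are projected from x;
+ # the values v are supplied externally (see forward).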
+ self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
103
+ self.attn_drop = nn.Dropout(attn_drop)
104
+ self.proj = nn.Linear(v_dim, v_dim)
105
+ self.proj_drop = nn.Dropout(proj_drop)
106
+
107
+ trunc_normal_(self.relative_position_bias_table, std=.02)
108
+ self.softmax = nn.Softmax(dim=-1)
109
+
110
+ def forward(self, x, v, mask=None):
111
+ """ Forward function.
112
+
113
+ Args:
114
+ x: input features with shape of (num_windows*B, N, C)
115
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
116
+ """
117
+ B_, N, C = x.shape
118
+ qk = self.qk(x).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
119
+ q, k = qk[0], qk[1] # make torchscript happy (cannot use tensor as tuple)
120
+
121
+ q = q * self.scale
122
+ attn = (q @ k.transpose(-2, -1))
123
+
124
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
125
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
126
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
127
+ attn = attn + relative_position_bias.unsqueeze(0)
128
+
129
+ if mask is not None:
130
+ nW = mask.shape[0]
131
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
132
+ attn = attn.view(-1, self.num_heads, N, N)
133
+ attn = self.softmax(attn)
134
+ else:
135
+ attn = self.softmax(attn)
136
+
137
+ attn = self.attn_drop(attn)
138
+
139
+ # assert self.dim % v.shape[-1] == 0, "self.dim % v.shape[-1] != 0"
140
+ # repeat_num = self.dim // v.shape[-1]
141
+ # v = v.view(B_, N, self.num_heads // repeat_num, -1).transpose(1, 2).repeat(1, repeat_num, 1, 1)
142
+
143
+ assert self.dim == v.shape[-1], "self.dim != v.shape[-1]"
144
+ v = v.view(B_, N, self.num_heads, -1).transpose(1, 2)
145
+
146
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
147
+ x = self.proj(x)
148
+ x = self.proj_drop(x)
149
+ return x
150
+
151
+
152
+ class CRFBlock(nn.Module):
153
+ """ CRF Block.
154
+
155
+ Args:
156
+ dim (int): Number of input channels.
157
+ num_heads (int): Number of attention heads.
158
+ window_size (int): Window size.
159
+ shift_size (int): Shift size for SW-MSA.
160
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
161
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
162
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
163
+ drop (float, optional): Dropout rate. Default: 0.0
164
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
165
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
166
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
167
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
168
+ """
169
+
170
+ def __init__(self, dim, num_heads, v_dim, window_size=7, shift_size=0,
171
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
172
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
173
+ super().__init__()
174
+ self.dim = dim
175
+ self.num_heads = num_heads
176
+ self.v_dim = v_dim
177
+ self.window_size = window_size
178
+ self.shift_size = shift_size
179
+ self.mlp_ratio = mlp_ratio
180
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"
181
+
182
+ self.norm1 = norm_layer(dim)
183
+ self.attn = WindowAttention(
184
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, v_dim=v_dim,
185
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
186
+
187
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
188
+ self.norm2 = norm_layer(v_dim)
189
+ mlp_hidden_dim = int(v_dim * mlp_ratio)
190
+ self.mlp = Mlp(in_features=v_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
191
+
192
+ self.H = None
193
+ self.W = None
194
+
195
+ def forward(self, x, v, mask_matrix):
196
+ """ Forward function.
197
+
198
+ Args:
199
+ x: Input feature, tensor size (B, H*W, C).
200
+ H, W: Spatial resolution of the input feature.
201
+ mask_matrix: Attention mask for cyclic shift.
202
+ """
203
+ B, L, C = x.shape
204
+ H, W = self.H, self.W
205
+ assert L == H * W, "input feature has wrong size"
206
+
207
+ shortcut = x
208
+ x = self.norm1(x)
209
+ x = x.view(B, H, W, C)
210
+
211
+ # pad feature maps to multiples of window size
212
+ pad_l = pad_t = 0
213
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
214
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
215
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
216
+ v = F.pad(v, (0, 0, pad_l, pad_r, pad_t, pad_b))
217
+ _, Hp, Wp, _ = x.shape
218
+
219
+ # cyclic shift
220
+ if self.shift_size > 0:
221
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
222
+ shifted_v = torch.roll(v, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
223
+ attn_mask = mask_matrix
224
+ else:
225
+ shifted_x = x
226
+ shifted_v = v
227
+ attn_mask = None
228
+
229
+ # partition windows
230
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
231
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
232
+ v_windows = window_partition(shifted_v, self.window_size) # nW*B, window_size, window_size, C
233
+ v_windows = v_windows.view(-1, self.window_size * self.window_size, v_windows.shape[-1]) # nW*B, window_size*window_size, C
234
+
235
+ # W-MSA/SW-MSA
236
+ attn_windows = self.attn(x_windows, v_windows, mask=attn_mask) # nW*B, window_size*window_size, C
237
+
238
+ # merge windows
239
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.v_dim)
240
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
241
+
242
+ # reverse cyclic shift
243
+ if self.shift_size > 0:
244
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
245
+ else:
246
+ x = shifted_x
247
+
248
+ if pad_r > 0 or pad_b > 0:
249
+ x = x[:, :H, :W, :].contiguous()
250
+
251
+ x = x.view(B, H * W, self.v_dim)
252
+
253
+ # FFN
254
+ x = shortcut + self.drop_path(x)
255
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
256
+
257
+ return x
258
+
259
+
260
+ class BasicCRFLayer(nn.Module):
261
+ """ A basic NeWCRFs layer for one stage.
262
+
263
+ Args:
264
+ dim (int): Number of feature channels
265
+ depth (int): Depths of this stage.
266
+ num_heads (int): Number of attention head.
267
+ window_size (int): Local window size. Default: 7.
268
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
269
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
270
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
271
+ drop (float, optional): Dropout rate. Default: 0.0
272
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
273
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
274
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
275
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
276
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
277
+ """
278
+
279
+ def __init__(self,
280
+ dim,
281
+ depth,
282
+ num_heads,
283
+ v_dim,
284
+ window_size=7,
285
+ mlp_ratio=4.,
286
+ qkv_bias=True,
287
+ qk_scale=None,
288
+ drop=0.,
289
+ attn_drop=0.,
290
+ drop_path=0.,
291
+ norm_layer=nn.LayerNorm,
292
+ downsample=None,
293
+ use_checkpoint=False):
294
+ super().__init__()
295
+ self.window_size = window_size
296
+ self.shift_size = window_size // 2
297
+ self.depth = depth
298
+ self.use_checkpoint = use_checkpoint
299
+
300
+ # build blocks
301
+ self.blocks = nn.ModuleList([
302
+ CRFBlock(
303
+ dim=dim,
304
+ num_heads=num_heads,
305
+ v_dim=v_dim,
306
+ window_size=window_size,
307
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
308
+ mlp_ratio=mlp_ratio,
309
+ qkv_bias=qkv_bias,
310
+ qk_scale=qk_scale,
311
+ drop=drop,
312
+ attn_drop=attn_drop,
313
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
314
+ norm_layer=norm_layer)
315
+ for i in range(depth)])
316
+
317
+ # patch merging layer
318
+ if downsample is not None:
319
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
320
+ else:
321
+ self.downsample = None
322
+
323
+ def forward(self, x, v, H, W):
324
+ """ Forward function.
325
+
326
+ Args:
327
+ x: Input feature, tensor size (B, H*W, C).
328
+ H, W: Spatial resolution of the input feature.
329
+ """
330
+
331
+ # calculate attention mask for SW-MSA
332
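+ # Token pairs that fall in the same window but come from different image
+ # regions after the cyclic shift get a -100 bias so the softmax effectively
+ # ignores them.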
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
333
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
334
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
335
+ h_slices = (slice(0, -self.window_size),
336
+ slice(-self.window_size, -self.shift_size),
337
+ slice(-self.shift_size, None))
338
+ w_slices = (slice(0, -self.window_size),
339
+ slice(-self.window_size, -self.shift_size),
340
+ slice(-self.shift_size, None))
341
+ cnt = 0
342
+ for h in h_slices:
343
+ for w in w_slices:
344
+ img_mask[:, h, w, :] = cnt
345
+ cnt += 1
346
+
347
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
348
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
349
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
350
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
351
+
352
+ for blk in self.blocks:
353
+ blk.H, blk.W = H, W
354
+ if self.use_checkpoint:
355
+ x = checkpoint.checkpoint(blk, x, v, attn_mask)  # pass v as well; CRFBlock.forward expects (x, v, mask_matrix)
356
+ else:
357
+ x = blk(x, v, attn_mask)
358
+ if self.downsample is not None:
359
+ x_down = self.downsample(x, H, W)
360
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
361
+ return x, H, W, x_down, Wh, Ww
362
+ else:
363
+ return x, H, W, x, H, W
364
+
365
+
366
+ class NewCRF(nn.Module):
367
+ def __init__(self,
368
+ input_dim=96,
369
+ embed_dim=96,
370
+ v_dim=64,
371
+ window_size=7,
372
+ num_heads=4,
373
+ depth=2,
374
+ patch_size=4,
375
+ in_chans=3,
376
+ norm_layer=nn.LayerNorm,
377
+ patch_norm=True):
378
+ super().__init__()
379
+
380
+ self.embed_dim = embed_dim
381
+ self.patch_norm = patch_norm
382
+
383
+ if input_dim != embed_dim:
384
+ self.proj_x = nn.Conv2d(input_dim, embed_dim, 3, padding=1)
385
+ else:
386
+ self.proj_x = None
387
+
388
+ if v_dim != embed_dim:
389
+ self.proj_v = nn.Conv2d(v_dim, embed_dim, 3, padding=1)
390
+ elif embed_dim % v_dim == 0:
391
+ self.proj_v = None
392
+
393
+ # For now, v_dim needs to be equal to embed_dim, because the output of window-attn is the input of shift-window-attn
394
+ v_dim = embed_dim
395
+ assert v_dim == embed_dim
396
+
397
+ self.crf_layer = BasicCRFLayer(
398
+ dim=embed_dim,
399
+ depth=depth,
400
+ num_heads=num_heads,
401
+ v_dim=v_dim,
402
+ window_size=window_size,
403
+ mlp_ratio=4.,
404
+ qkv_bias=True,
405
+ qk_scale=None,
406
+ drop=0.,
407
+ attn_drop=0.,
408
+ drop_path=0.,
409
+ norm_layer=norm_layer,
410
+ downsample=None,
411
+ use_checkpoint=False)
412
+
413
+ layer = norm_layer(embed_dim)
414
+ layer_name = 'norm_crf'
415
+ self.add_module(layer_name, layer)
416
+
417
+
418
+ def forward(self, x, v):
419
+ if self.proj_x is not None:
420
+ x = self.proj_x(x)
421
+ if self.proj_v is not None:
422
+ v = self.proj_v(v)
423
+
424
+ Wh, Ww = x.size(2), x.size(3)
425
+ x = x.flatten(2).transpose(1, 2)
426
+ v = v.transpose(1, 2).transpose(2, 3)
427
+
428
+ x_out, H, W, x, Wh, Ww = self.crf_layer(x, v, Wh, Ww)
429
+ norm_layer = getattr(self, 'norm_crf')
430
+ x_out = norm_layer(x_out)
431
+ out = x_out.view(-1, H, W, self.embed_dim).permute(0, 3, 1, 2).contiguous()
432
+
433
+ return out
iebins/networks/newcrf_utils.py ADDED
@@ -0,0 +1,264 @@
1
+ import warnings
2
+ import os
3
+ import os.path as osp
4
+ import pkgutil
5
+ import warnings
6
+ from collections import OrderedDict
7
+ from importlib import import_module
8
+
9
+ import torch
10
+ import torchvision
11
+ import torch.nn as nn
12
+ from torch.utils import model_zoo
13
+ from torch.nn import functional as F
14
+ from torch.nn.parallel import DataParallel, DistributedDataParallel
15
+ from torch import distributed as dist
16
+
17
+ TORCH_VERSION = torch.__version__
18
+
19
+
20
+ def resize(input,
21
+ size=None,
22
+ scale_factor=None,
23
+ mode='nearest',
24
+ align_corners=None,
25
+ warning=True):
26
+ if warning:
27
+ if size is not None and align_corners:
28
+ input_h, input_w = tuple(int(x) for x in input.shape[2:])
29
+ output_h, output_w = tuple(int(x) for x in size)
30
+ if output_h > input_h or output_w > input_w:
31
+ if ((output_h > 1 and output_w > 1 and input_h > 1
32
+ and input_w > 1) and (output_h - 1) % (input_h - 1)
33
+ and (output_w - 1) % (input_w - 1)):
34
+ warnings.warn(
35
+ f'When align_corners={align_corners}, '
36
+ 'the output would be more aligned if '
37
+ f'input size {(input_h, input_w)} is `x+1` and '
38
+ f'out size {(output_h, output_w)} is `nx+1`')
39
+ if isinstance(size, torch.Size):
40
+ size = tuple(int(x) for x in size)
41
+ return F.interpolate(input, size, scale_factor, mode, align_corners)
42
+
43
+
44
+ def normal_init(module, mean=0, std=1, bias=0):
45
+ if hasattr(module, 'weight') and module.weight is not None:
46
+ nn.init.normal_(module.weight, mean, std)
47
+ if hasattr(module, 'bias') and module.bias is not None:
48
+ nn.init.constant_(module.bias, bias)
49
+
50
+
51
+ def is_module_wrapper(module):
52
+ module_wrappers = (DataParallel, DistributedDataParallel)
53
+ return isinstance(module, module_wrappers)
54
+
55
+
56
+ def get_dist_info():
57
+ if TORCH_VERSION < '1.0':
58
+ initialized = dist._initialized
59
+ else:
60
+ if dist.is_available():
61
+ initialized = dist.is_initialized()
62
+ else:
63
+ initialized = False
64
+ if initialized:
65
+ rank = dist.get_rank()
66
+ world_size = dist.get_world_size()
67
+ else:
68
+ rank = 0
69
+ world_size = 1
70
+ return rank, world_size
71
+
72
+
73
+ def load_state_dict(module, state_dict, strict=False, logger=None):
74
+ """Load state_dict to a module.
75
+
76
+ This method is modified from :meth:`torch.nn.Module.load_state_dict`.
77
+ Default value for ``strict`` is set to ``False`` and the message for
78
+ param mismatch will be shown even if strict is False.
79
+
80
+ Args:
81
+ module (Module): Module that receives the state_dict.
82
+ state_dict (OrderedDict): Weights.
83
+ strict (bool): whether to strictly enforce that the keys
84
+ in :attr:`state_dict` match the keys returned by this module's
85
+ :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
86
+ logger (:obj:`logging.Logger`, optional): Logger to log the error
87
+ message. If not specified, print function will be used.
88
+ """
89
+ unexpected_keys = []
90
+ all_missing_keys = []
91
+ err_msg = []
92
+
93
+ metadata = getattr(state_dict, '_metadata', None)
94
+ state_dict = state_dict.copy()
95
+ if metadata is not None:
96
+ state_dict._metadata = metadata
97
+
98
+ # use _load_from_state_dict to enable checkpoint version control
99
+ def load(module, prefix=''):
100
+ # recursively check parallel module in case that the model has a
101
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
102
+ if is_module_wrapper(module):
103
+ module = module.module
104
+ local_metadata = {} if metadata is None else metadata.get(
105
+ prefix[:-1], {})
106
+ module._load_from_state_dict(state_dict, prefix, local_metadata, True,
107
+ all_missing_keys, unexpected_keys,
108
+ err_msg)
109
+ for name, child in module._modules.items():
110
+ if child is not None:
111
+ load(child, prefix + name + '.')
112
+
113
+ load(module)
114
+ load = None # break load->load reference cycle
115
+
116
+ # ignore "num_batches_tracked" of BN layers
117
+ missing_keys = [
118
+ key for key in all_missing_keys if 'num_batches_tracked' not in key
119
+ ]
120
+
121
+ if unexpected_keys:
122
+ err_msg.append('unexpected key in source '
123
+ f'state_dict: {", ".join(unexpected_keys)}\n')
124
+ if missing_keys:
125
+ err_msg.append(
126
+ f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
127
+
128
+ rank, _ = get_dist_info()
129
+ if len(err_msg) > 0 and rank == 0:
130
+ err_msg.insert(
131
+ 0, 'The model and loaded state dict do not match exactly\n')
132
+ err_msg = '\n'.join(err_msg)
133
+ if strict:
134
+ raise RuntimeError(err_msg)
135
+ elif logger is not None:
136
+ logger.warning(err_msg)
137
+ else:
138
+ print(err_msg)
139
+
140
+
141
+ def load_url_dist(url, model_dir=None):
142
+ """In distributed setting, this function only download checkpoint at local
143
+ rank 0."""
144
+ rank, world_size = get_dist_info()
145
+ rank = int(os.environ.get('LOCAL_RANK', rank))
146
+ if rank == 0:
147
+ checkpoint = model_zoo.load_url(url, model_dir=model_dir)
148
+ if world_size > 1:
149
+ torch.distributed.barrier()
150
+ if rank > 0:
151
+ checkpoint = model_zoo.load_url(url, model_dir=model_dir)
152
+ return checkpoint
153
+
154
+
155
+ def get_torchvision_models():
156
+ model_urls = dict()
157
+ for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
158
+ if ispkg:
159
+ continue
160
+ _zoo = import_module(f'torchvision.models.{name}')
161
+ if hasattr(_zoo, 'model_urls'):
162
+ _urls = getattr(_zoo, 'model_urls')
163
+ model_urls.update(_urls)
164
+ return model_urls
165
+
166
+
167
+ def _load_checkpoint(filename, map_location=None):
168
+ """Load checkpoint from somewhere (modelzoo, file, url).
169
+
170
+ Args:
171
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
172
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
173
+ details.
174
+ map_location (str | None): Same as :func:`torch.load`. Default: None.
175
+
176
+ Returns:
177
+ dict | OrderedDict: The loaded checkpoint. It can be either an
178
+ OrderedDict storing model weights or a dict containing other
179
+ information, which depends on the checkpoint.
180
+ """
181
+ if filename.startswith('modelzoo://'):
182
+ warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
183
+ 'use "torchvision://" instead')
184
+ model_urls = get_torchvision_models()
185
+ model_name = filename[11:]
186
+ checkpoint = load_url_dist(model_urls[model_name])
187
+ else:
188
+ if not osp.isfile(filename):
189
+ raise IOError(f'{filename} is not a checkpoint file')
190
+ checkpoint = torch.load(filename, map_location=map_location)
191
+ return checkpoint
192
+
193
+
194
+ def load_checkpoint(model,
195
+ filename,
196
+ map_location='cpu',
197
+ strict=False,
198
+ logger=None):
199
+ """Load checkpoint from a file or URI.
200
+
201
+ Args:
202
+ model (Module): Module to load checkpoint.
203
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
204
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
205
+ details.
206
+ map_location (str): Same as :func:`torch.load`.
207
+ strict (bool): Whether to allow different params for the model and
208
+ checkpoint.
209
+ logger (:mod:`logging.Logger` or None): The logger for error message.
210
+
211
+ Returns:
212
+ dict or OrderedDict: The loaded checkpoint.
213
+ """
214
+ checkpoint = _load_checkpoint(filename, map_location)
215
+ # OrderedDict is a subclass of dict
216
+ if not isinstance(checkpoint, dict):
217
+ raise RuntimeError(
218
+ f'No state_dict found in checkpoint file {filename}')
219
+ # get state_dict from checkpoint
220
+ if 'state_dict' in checkpoint:
221
+ state_dict = checkpoint['state_dict']
222
+ elif 'model' in checkpoint:
223
+ state_dict = checkpoint['model']
224
+ else:
225
+ state_dict = checkpoint
226
+ # strip prefix of state_dict
227
+ if list(state_dict.keys())[0].startswith('module.'):
228
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
229
+
230
+ # for MoBY, load model of online branch
231
+ if sorted(list(state_dict.keys()))[0].startswith('encoder'):
232
+ state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
233
+
234
+ # reshape absolute position embedding
235
+ if state_dict.get('absolute_pos_embed') is not None:
236
+ absolute_pos_embed = state_dict['absolute_pos_embed']
237
+ N1, L, C1 = absolute_pos_embed.size()
238
+ N2, C2, H, W = model.absolute_pos_embed.size()
239
+ if N1 != N2 or C1 != C2 or L != H*W:
240
+ logger.warning("Error in loading absolute_pos_embed, pass")
241
+ else:
242
+ state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)
243
+
244
+ # interpolate position bias table if needed
245
+ relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
246
+ for table_key in relative_position_bias_table_keys:
247
+ table_pretrained = state_dict[table_key]
248
+ table_current = model.state_dict()[table_key]
249
+ L1, nH1 = table_pretrained.size()
250
+ L2, nH2 = table_current.size()
251
+ if nH1 != nH2:
252
+ logger.warning(f"Error in loading {table_key}, pass")
253
+ else:
254
+ if L1 != L2:
255
+ S1 = int(L1 ** 0.5)
256
+ S2 = int(L2 ** 0.5)
257
+ table_pretrained_resized = F.interpolate(
258
+ table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
259
+ size=(S2, S2), mode='bicubic')
260
+ state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
261
+
262
+ # load state_dict
263
+ load_state_dict(model, state_dict, strict, logger)
264
+ return checkpoint
iebins/networks/resize.py ADDED
@@ -0,0 +1,51 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import warnings
3
+
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def resize(input,
9
+ size=None,
10
+ scale_factor=None,
11
+ mode='nearest',
12
+ align_corners=None,
13
+ warning=False):
14
+ if warning:
15
+ if size is not None and align_corners:
16
+ input_h, input_w = tuple(int(x) for x in input.shape[2:])
17
+ output_h, output_w = tuple(int(x) for x in size)
18
+ if output_h > input_h or output_w > input_w:
19
+ if ((output_h > 1 and output_w > 1 and input_h > 1
20
+ and input_w > 1) and (output_h - 1) % (input_h - 1)
21
+ and (output_w - 1) % (input_w - 1)):
22
+ warnings.warn(
23
+ f'When align_corners={align_corners}, '
24
+ 'the output would be more aligned if '
25
+ f'input size {(input_h, input_w)} is `x+1` and '
26
+ f'out size {(output_h, output_w)} is `nx+1`')
27
+ return F.interpolate(input, size, scale_factor, mode, align_corners)
28
+
29
+
30
+ class Upsample(nn.Module):
31
+
32
+ def __init__(self,
33
+ size=None,
34
+ scale_factor=None,
35
+ mode='nearest',
36
+ align_corners=None):
37
+ super(Upsample, self).__init__()
38
+ self.size = size
39
+ if isinstance(scale_factor, tuple):
40
+ self.scale_factor = tuple(float(factor) for factor in scale_factor)
41
+ else:
42
+ self.scale_factor = float(scale_factor) if scale_factor else None
43
+ self.mode = mode
44
+ self.align_corners = align_corners
45
+
46
+ def forward(self, x):
47
+ if not self.size:
48
+ size = [int(t * self.scale_factor) for t in x.shape[-2:]]
49
+ else:
50
+ size = self.size
51
+ return resize(x, size, None, self.mode, self.align_corners)
iebins/networks/swin_transformer.py ADDED
@@ -0,0 +1,620 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.utils.checkpoint as checkpoint
5
+ import numpy as np
6
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7
+
8
+ from .newcrf_utils import load_checkpoint
9
+
10
+
11
+ class Mlp(nn.Module):
12
+ """ Multilayer perceptron."""
13
+
14
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
15
+ super().__init__()
16
+ out_features = out_features or in_features
17
+ hidden_features = hidden_features or in_features
18
+ self.fc1 = nn.Linear(in_features, hidden_features)
19
+ self.act = act_layer()
20
+ self.fc2 = nn.Linear(hidden_features, out_features)
21
+ self.drop = nn.Dropout(drop)
22
+
23
+ def forward(self, x):
24
+ x = self.fc1(x)
25
+ x = self.act(x)
26
+ x = self.drop(x)
27
+ x = self.fc2(x)
28
+ x = self.drop(x)
29
+ return x
30
+
31
+
32
+ def window_partition(x, window_size):
33
+ """
34
+ Args:
35
+ x: (B, H, W, C)
36
+ window_size (int): window size
37
+
38
+ Returns:
39
+ windows: (num_windows*B, window_size, window_size, C)
40
+ """
41
+ B, H, W, C = x.shape
42
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
43
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
44
+ return windows
45
+
46
+
47
+ def window_reverse(windows, window_size, H, W):
48
+ """
49
+ Args:
50
+ windows: (num_windows*B, window_size, window_size, C)
51
+ window_size (int): Window size
52
+ H (int): Height of image
53
+ W (int): Width of image
54
+
55
+ Returns:
56
+ x: (B, H, W, C)
57
+ """
58
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
59
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
60
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
61
+ return x
62
+
63
+
64
+ class WindowAttention(nn.Module):
65
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
66
+ It supports both shifted and non-shifted windows.
67
+
68
+ Args:
69
+ dim (int): Number of input channels.
70
+ window_size (tuple[int]): The height and width of the window.
71
+ num_heads (int): Number of attention heads.
72
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
73
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
74
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
75
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
76
+ """
77
+
78
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
79
+
80
+ super().__init__()
81
+ self.dim = dim
82
+ self.window_size = window_size # Wh, Ww
83
+ self.num_heads = num_heads
84
+ head_dim = dim // num_heads
85
+ self.scale = qk_scale or head_dim ** -0.5
86
+
87
+ # define a parameter table of relative position bias
88
+ self.relative_position_bias_table = nn.Parameter(
89
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
90
+
91
+ # get pair-wise relative position index for each token inside the window
92
+ coords_h = torch.arange(self.window_size[0])
93
+ coords_w = torch.arange(self.window_size[1])
94
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
95
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
96
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
97
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
98
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
99
+ relative_coords[:, :, 1] += self.window_size[1] - 1
100
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
101
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
102
+ self.register_buffer("relative_position_index", relative_position_index)
103
+
104
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
105
+ self.attn_drop = nn.Dropout(attn_drop)
106
+ self.proj = nn.Linear(dim, dim)
107
+ self.proj_drop = nn.Dropout(proj_drop)
108
+
109
+ trunc_normal_(self.relative_position_bias_table, std=.02)
110
+ self.softmax = nn.Softmax(dim=-1)
111
+
112
+ def forward(self, x, mask=None):
113
+ """ Forward function.
114
+
115
+ Args:
116
+ x: input features with shape of (num_windows*B, N, C)
117
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
118
+ """
119
+ B_, N, C = x.shape
120
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
121
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
122
+
123
+ q = q * self.scale
124
+ attn = (q @ k.transpose(-2, -1))
125
+
126
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
127
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
128
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
129
+ attn = attn + relative_position_bias.unsqueeze(0)
130
+
131
+ if mask is not None:
132
+ nW = mask.shape[0]
133
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
134
+ attn = attn.view(-1, self.num_heads, N, N)
135
+ attn = self.softmax(attn)
136
+ else:
137
+ attn = self.softmax(attn)
138
+
139
+ attn = self.attn_drop(attn)
140
+
141
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
142
+ x = self.proj(x)
143
+ x = self.proj_drop(x)
144
+ return x
145
+
146
+
147
+ class SwinTransformerBlock(nn.Module):
148
+ """ Swin Transformer Block.
149
+
150
+ Args:
151
+ dim (int): Number of input channels.
152
+ num_heads (int): Number of attention heads.
153
+ window_size (int): Window size.
154
+ shift_size (int): Shift size for SW-MSA.
155
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
156
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
157
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
158
+ drop (float, optional): Dropout rate. Default: 0.0
159
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
160
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
161
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
162
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
163
+ """
164
+
165
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
166
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
167
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
168
+ super().__init__()
169
+ self.dim = dim
170
+ self.num_heads = num_heads
171
+ self.window_size = window_size
172
+ self.shift_size = shift_size
173
+ self.mlp_ratio = mlp_ratio
174
+ assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"
175
+
176
+ self.norm1 = norm_layer(dim)
177
+ self.attn = WindowAttention(
178
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
179
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
180
+
181
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
182
+ self.norm2 = norm_layer(dim)
183
+ mlp_hidden_dim = int(dim * mlp_ratio)
184
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
185
+
186
+ self.H = None
187
+ self.W = None
188
+
189
+ def forward(self, x, mask_matrix):
190
+ """ Forward function.
191
+
192
+ Args:
193
+ x: Input feature, tensor size (B, H*W, C).
194
+ H, W: Spatial resolution of the input feature.
195
+ mask_matrix: Attention mask for cyclic shift.
196
+ """
197
+ B, L, C = x.shape
198
+ H, W = self.H, self.W
199
+ assert L == H * W, "input feature has wrong size"
200
+
201
+ shortcut = x
202
+ x = self.norm1(x)
203
+ x = x.view(B, H, W, C)
204
+
205
+ # pad feature maps to multiples of window size
206
+ pad_l = pad_t = 0
207
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
208
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
209
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
210
+ _, Hp, Wp, _ = x.shape
211
+
212
+ # cyclic shift
213
+ if self.shift_size > 0:
214
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
215
+ attn_mask = mask_matrix
216
+ else:
217
+ shifted_x = x
218
+ attn_mask = None
219
+
220
+ # partition windows
221
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
222
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
223
+
224
+ # W-MSA/SW-MSA
225
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
226
+
227
+ # merge windows
228
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
229
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
230
+
231
+ # reverse cyclic shift
232
+ if self.shift_size > 0:
233
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
234
+ else:
235
+ x = shifted_x
236
+
237
+ if pad_r > 0 or pad_b > 0:
238
+ x = x[:, :H, :W, :].contiguous()
239
+
240
+ x = x.view(B, H * W, C)
241
+
242
+ # FFN
243
+ x = shortcut + self.drop_path(x)
244
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
245
+
246
+ return x
247
+
248
+
249
+ class PatchMerging(nn.Module):
250
+ """ Patch Merging Layer
251
+
252
+ Args:
253
+ dim (int): Number of input channels.
254
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
255
+ """
256
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
257
+ super().__init__()
258
+ self.dim = dim
259
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
260
+ self.norm = norm_layer(4 * dim)
261
+
262
+ def forward(self, x, H, W):
263
+ """ Forward function.
264
+
265
+ Args:
266
+ x: Input feature, tensor size (B, H*W, C).
267
+ H, W: Spatial resolution of the input feature.
268
+ """
269
+ B, L, C = x.shape
270
+ assert L == H * W, "input feature has wrong size"
271
+
272
+ x = x.view(B, H, W, C)
273
+
274
+ # padding
275
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
276
+ if pad_input:
277
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
278
+
279
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
280
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
281
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
282
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
283
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
284
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
285
+
286
+ x = self.norm(x)
287
+ x = self.reduction(x)
288
+
289
+ return x
290
+
291
+
292
+ class BasicLayer(nn.Module):
293
+ """ A basic Swin Transformer layer for one stage.
294
+
295
+ Args:
296
+ dim (int): Number of feature channels
297
+ depth (int): Depths of this stage.
298
+ num_heads (int): Number of attention head.
299
+ window_size (int): Local window size. Default: 7.
300
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
301
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
302
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
303
+ drop (float, optional): Dropout rate. Default: 0.0
304
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
305
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
306
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
307
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
308
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
309
+ """
310
+
311
+ def __init__(self,
312
+ dim,
313
+ depth,
314
+ num_heads,
315
+ window_size=7,
316
+ mlp_ratio=4.,
317
+ qkv_bias=True,
318
+ qk_scale=None,
319
+ drop=0.,
320
+ attn_drop=0.,
321
+ drop_path=0.,
322
+ norm_layer=nn.LayerNorm,
323
+ downsample=None,
324
+ use_checkpoint=False):
325
+ super().__init__()
326
+ self.window_size = window_size
327
+ self.shift_size = window_size // 2
328
+ self.depth = depth
329
+ self.use_checkpoint = use_checkpoint
330
+
331
+ # build blocks
332
+ self.blocks = nn.ModuleList([
333
+ SwinTransformerBlock(
334
+ dim=dim,
335
+ num_heads=num_heads,
336
+ window_size=window_size,
337
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
338
+ mlp_ratio=mlp_ratio,
339
+ qkv_bias=qkv_bias,
340
+ qk_scale=qk_scale,
341
+ drop=drop,
342
+ attn_drop=attn_drop,
343
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
344
+ norm_layer=norm_layer)
345
+ for i in range(depth)])
346
+
347
+ # patch merging layer
348
+ if downsample is not None:
349
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
350
+ else:
351
+ self.downsample = None
352
+
353
+ def forward(self, x, H, W):
354
+ """ Forward function.
355
+
356
+ Args:
357
+ x: Input feature, tensor size (B, H*W, C).
358
+ H, W: Spatial resolution of the input feature.
359
+ """
360
+
361
+ # calculate attention mask for SW-MSA
362
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
363
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
364
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
365
+ h_slices = (slice(0, -self.window_size),
366
+ slice(-self.window_size, -self.shift_size),
367
+ slice(-self.shift_size, None))
368
+ w_slices = (slice(0, -self.window_size),
369
+ slice(-self.window_size, -self.shift_size),
370
+ slice(-self.shift_size, None))
371
+ cnt = 0
372
+ for h in h_slices:
373
+ for w in w_slices:
374
+ img_mask[:, h, w, :] = cnt
375
+ cnt += 1
376
+
377
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
378
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
379
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
380
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
381
+
382
+ for blk in self.blocks:
383
+ blk.H, blk.W = H, W
384
+ if self.use_checkpoint:
385
+ x = checkpoint.checkpoint(blk, x, attn_mask)
386
+ else:
387
+ x = blk(x, attn_mask)
388
+ if self.downsample is not None:
389
+ x_down = self.downsample(x, H, W)
390
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
391
+ return x, H, W, x_down, Wh, Ww
392
+ else:
393
+ return x, H, W, x, H, W
394
+
395
+
396
+ class PatchEmbed(nn.Module):
397
+ """ Image to Patch Embedding
398
+
399
+ Args:
400
+ patch_size (int): Patch token size. Default: 4.
401
+ in_chans (int): Number of input image channels. Default: 3.
402
+ embed_dim (int): Number of linear projection output channels. Default: 96.
403
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
404
+ """
405
+
406
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
407
+ super().__init__()
408
+ patch_size = to_2tuple(patch_size)
409
+ self.patch_size = patch_size
410
+
411
+ self.in_chans = in_chans
412
+ self.embed_dim = embed_dim
413
+
414
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
415
+ if norm_layer is not None:
416
+ self.norm = norm_layer(embed_dim)
417
+ else:
418
+ self.norm = None
419
+
420
+ def forward(self, x):
421
+ """Forward function."""
422
+ # padding
423
+ _, _, H, W = x.size()
424
+ if W % self.patch_size[1] != 0:
425
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
426
+ if H % self.patch_size[0] != 0:
427
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
428
+
429
+ x = self.proj(x) # B C Wh Ww
430
+ if self.norm is not None:
431
+ Wh, Ww = x.size(2), x.size(3)
432
+ x = x.flatten(2).transpose(1, 2)
433
+ x = self.norm(x)
434
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
435
+
436
+ return x
437
+
438
+
439
+ class SwinTransformer(nn.Module):
440
+ """ Swin Transformer backbone.
441
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
442
+ https://arxiv.org/pdf/2103.14030
443
+
444
+ Args:
445
+ pretrain_img_size (int): Input image size for training the pretrained model,
446
+ used in absolute postion embedding. Default 224.
447
+ patch_size (int | tuple(int)): Patch size. Default: 4.
448
+ in_chans (int): Number of input image channels. Default: 3.
449
+ embed_dim (int): Number of linear projection output channels. Default: 96.
450
+ depths (tuple[int]): Depths of each Swin Transformer stage.
451
+ num_heads (tuple[int]): Number of attention head of each stage.
452
+ window_size (int): Window size. Default: 7.
453
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
454
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
455
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
456
+ drop_rate (float): Dropout rate.
457
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
458
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
459
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
460
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
461
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
462
+ out_indices (Sequence[int]): Output from which stages.
463
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
464
+ -1 means not freezing any parameters.
465
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
466
+ """
467
+
468
+ def __init__(self,
469
+ pretrain_img_size=224,
470
+ patch_size=4,
471
+ in_chans=3,
472
+ embed_dim=96,
473
+ depths=[2, 2, 6, 2],
474
+ num_heads=[3, 6, 12, 24],
475
+ window_size=7,
476
+ mlp_ratio=4.,
477
+ qkv_bias=True,
478
+ qk_scale=None,
479
+ drop_rate=0.,
480
+ attn_drop_rate=0.,
481
+ drop_path_rate=0.2,
482
+ norm_layer=nn.LayerNorm,
483
+ ape=False,
484
+ patch_norm=True,
485
+ out_indices=(0, 1, 2, 3),
486
+ frozen_stages=-1,
487
+ use_checkpoint=False):
488
+ super().__init__()
489
+
490
+ self.pretrain_img_size = pretrain_img_size
491
+ self.num_layers = len(depths)
492
+ self.embed_dim = embed_dim
493
+ self.ape = ape
494
+ self.patch_norm = patch_norm
495
+ self.out_indices = out_indices
496
+ self.frozen_stages = frozen_stages
497
+
498
+ # split image into non-overlapping patches
499
+ self.patch_embed = PatchEmbed(
500
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
501
+ norm_layer=norm_layer if self.patch_norm else None)
502
+
503
+ # absolute position embedding
504
+ if self.ape:
505
+ pretrain_img_size = to_2tuple(pretrain_img_size)
506
+ patch_size = to_2tuple(patch_size)
507
+ patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
508
+
509
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
510
+ trunc_normal_(self.absolute_pos_embed, std=.02)
511
+
512
+ self.pos_drop = nn.Dropout(p=drop_rate)
513
+
514
+ # stochastic depth
515
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
516
+
517
+ # build layers
518
+ self.layers = nn.ModuleList()
519
+ for i_layer in range(self.num_layers):
520
+ layer = BasicLayer(
521
+ dim=int(embed_dim * 2 ** i_layer),
522
+ depth=depths[i_layer],
523
+ num_heads=num_heads[i_layer],
524
+ window_size=window_size,
525
+ mlp_ratio=mlp_ratio,
526
+ qkv_bias=qkv_bias,
527
+ qk_scale=qk_scale,
528
+ drop=drop_rate,
529
+ attn_drop=attn_drop_rate,
530
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
531
+ norm_layer=norm_layer,
532
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
533
+ use_checkpoint=use_checkpoint)
534
+ self.layers.append(layer)
535
+
536
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
537
+ self.num_features = num_features
538
+
539
+ # add a norm layer for each output
540
+ for i_layer in out_indices:
541
+ layer = norm_layer(num_features[i_layer])
542
+ layer_name = f'norm{i_layer}'
543
+ self.add_module(layer_name, layer)
544
+
545
+ self._freeze_stages()
546
+
547
+ def _freeze_stages(self):
548
+ if self.frozen_stages >= 0:
549
+ self.patch_embed.eval()
550
+ for param in self.patch_embed.parameters():
551
+ param.requires_grad = False
552
+
553
+ if self.frozen_stages >= 1 and self.ape:
554
+ self.absolute_pos_embed.requires_grad = False
555
+
556
+ if self.frozen_stages >= 2:
557
+ self.pos_drop.eval()
558
+ for i in range(0, self.frozen_stages - 1):
559
+ m = self.layers[i]
560
+ m.eval()
561
+ for param in m.parameters():
562
+ param.requires_grad = False
563
+
564
+ def init_weights(self, pretrained=None):
565
+ """Initialize the weights in backbone.
566
+
567
+ Args:
568
+ pretrained (str, optional): Path to pre-trained weights.
569
+ Defaults to None.
570
+ """
571
+
572
+ def _init_weights(m):
573
+ if isinstance(m, nn.Linear):
574
+ trunc_normal_(m.weight, std=.02)
575
+ if isinstance(m, nn.Linear) and m.bias is not None:
576
+ nn.init.constant_(m.bias, 0)
577
+ elif isinstance(m, nn.LayerNorm):
578
+ nn.init.constant_(m.bias, 0)
579
+ nn.init.constant_(m.weight, 1.0)
580
+
581
+ if isinstance(pretrained, str):
582
+ self.apply(_init_weights)
583
+ # logger = get_root_logger()
584
+ load_checkpoint(self, pretrained, strict=False)
585
+ elif pretrained is None:
586
+ self.apply(_init_weights)
587
+ else:
588
+ raise TypeError('pretrained must be a str or None')
589
+
590
+ def forward(self, x):
591
+ """Forward function."""
592
+ x = self.patch_embed(x)
593
+
594
+ Wh, Ww = x.size(2), x.size(3)
595
+ if self.ape:
596
+ # interpolate the position embedding to the corresponding size
597
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
598
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
599
+ else:
600
+ x = x.flatten(2).transpose(1, 2)
601
+ x = self.pos_drop(x)
602
+
603
+ outs = []
604
+ for i in range(self.num_layers):
605
+ layer = self.layers[i]
606
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
607
+
608
+ if i in self.out_indices:
609
+ norm_layer = getattr(self, f'norm{i}')
610
+ x_out = norm_layer(x_out)
611
+
612
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
613
+ outs.append(out)
614
+
615
+ return tuple(outs)
616
+
617
+ def train(self, mode=True):
618
+ """Convert the model into training mode while keeping the frozen stages frozen."""
619
+ super(SwinTransformer, self).train(mode)
620
+ self._freeze_stages()
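
As a quick sanity check on the backbone above, the sketch below instantiates it with the default Swin-T configuration and inspects the four output scales. The import path and the standalone usage are assumptions for illustration only; this snippet is not part of the commit.

    # Minimal sketch (assumed import path), not part of this commit.
    import torch
    from iebins.networks.swin_transformer import SwinTransformer  # assumed module name

    backbone = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2],
                               num_heads=[3, 6, 12, 24], window_size=7)
    backbone.init_weights(pretrained=None)
    backbone.eval()

    with torch.no_grad():
        feats = backbone(torch.randn(1, 3, 224, 224))

    # patch_size=4 plus three PatchMerging steps give strides 4/8/16/32, so:
    # (1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)
    for f in feats:
        print(f.shape)
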
iebins/networks/uper_crf_head.py ADDED
@@ -0,0 +1,364 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from mmcv.cnn import ConvModule
6
+ from .newcrf_utils import resize, normal_init
7
+
8
+
9
+ class PPM(nn.ModuleList):
10
+ """Pooling Pyramid Module used in PSPNet.
11
+
12
+ Args:
13
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
14
+ Module.
15
+ in_channels (int): Input channels.
16
+ channels (int): Channels after modules, before conv_seg.
17
+ conv_cfg (dict|None): Config of conv layers.
18
+ norm_cfg (dict|None): Config of norm layers.
19
+ act_cfg (dict): Config of activation layers.
20
+ align_corners (bool): align_corners argument of F.interpolate.
21
+ """
22
+
23
+ def __init__(self, pool_scales, in_channels, channels, conv_cfg, norm_cfg,
24
+ act_cfg, align_corners):
25
+ super(PPM, self).__init__()
26
+ self.pool_scales = pool_scales
27
+ self.align_corners = align_corners
28
+ self.in_channels = in_channels
29
+ self.channels = channels
30
+ self.conv_cfg = conv_cfg
31
+ self.norm_cfg = norm_cfg
32
+ self.act_cfg = act_cfg
33
+ for pool_scale in pool_scales:
34
+ # == if batch size = 1, BN is not supported, change to GN
35
+ if pool_scale == 1: norm_cfg = dict(type='GN', requires_grad=True, num_groups=256)
36
+ self.append(
37
+ nn.Sequential(
38
+ nn.AdaptiveAvgPool2d(pool_scale),
39
+ ConvModule(
40
+ self.in_channels,
41
+ self.channels,
42
+ 1,
43
+ conv_cfg=self.conv_cfg,
44
+ norm_cfg=norm_cfg,
45
+ act_cfg=self.act_cfg)))
46
+
47
+ def forward(self, x):
48
+ """Forward function."""
49
+ ppm_outs = []
50
+ for ppm in self:
51
+ ppm_out = ppm(x)
52
+ upsampled_ppm_out = resize(
53
+ ppm_out,
54
+ size=x.size()[2:],
55
+ mode='bilinear',
56
+ align_corners=self.align_corners)
57
+ ppm_outs.append(upsampled_ppm_out)
58
+ return ppm_outs
59
+
60
+
61
+ class BaseDecodeHead(nn.Module):
62
+ """Base class for BaseDecodeHead.
63
+
64
+ Args:
65
+ in_channels (int|Sequence[int]): Input channels.
66
+ channels (int): Channels after modules, before conv_seg.
67
+ num_classes (int): Number of classes.
68
+ dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
69
+ conv_cfg (dict|None): Config of conv layers. Default: None.
70
+ norm_cfg (dict|None): Config of norm layers. Default: None.
71
+ act_cfg (dict): Config of activation layers.
72
+ Default: dict(type='ReLU')
73
+ in_index (int|Sequence[int]): Input feature index. Default: -1
74
+ input_transform (str|None): Transformation type of input features.
75
+ Options: 'resize_concat', 'multiple_select', None.
76
+ 'resize_concat': Multiple feature maps will be resized to the
77
+ same size as the first one and then concatenated together.
78
+ Usually used in the FCN head of HRNet.
79
+ 'multiple_select': Multiple feature maps will be bundled into
80
+ a list and passed into the decode head.
81
+ None: Only one selected feature map is allowed.
82
+ Default: None.
83
+ loss_decode (dict): Config of decode loss.
84
+ Default: dict(type='CrossEntropyLoss').
85
+ ignore_index (int | None): The label index to be ignored. When using
86
+ masked BCE loss, ignore_index should be set to None. Default: 255
87
+ sampler (dict|None): The config of segmentation map sampler.
88
+ Default: None.
89
+ align_corners (bool): align_corners argument of F.interpolate.
90
+ Default: False.
91
+ """
92
+
93
+ def __init__(self,
94
+ in_channels,
95
+ channels,
96
+ *,
97
+ num_classes,
98
+ dropout_ratio=0.1,
99
+ conv_cfg=None,
100
+ norm_cfg=None,
101
+ act_cfg=dict(type='ReLU'),
102
+ in_index=-1,
103
+ input_transform=None,
104
+ loss_decode=dict(
105
+ type='CrossEntropyLoss',
106
+ use_sigmoid=False,
107
+ loss_weight=1.0),
108
+ ignore_index=255,
109
+ sampler=None,
110
+ align_corners=False):
111
+ super(BaseDecodeHead, self).__init__()
112
+ self._init_inputs(in_channels, in_index, input_transform)
113
+ self.channels = channels
114
+ self.num_classes = num_classes
115
+ self.dropout_ratio = dropout_ratio
116
+ self.conv_cfg = conv_cfg
117
+ self.norm_cfg = norm_cfg
118
+ self.act_cfg = act_cfg
119
+ self.in_index = in_index
120
+ # self.loss_decode = build_loss(loss_decode)
121
+ self.ignore_index = ignore_index
122
+ self.align_corners = align_corners
123
+ # if sampler is not None:
124
+ # self.sampler = build_pixel_sampler(sampler, context=self)
125
+ # else:
126
+ # self.sampler = None
127
+
128
+ # self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
129
+ # self.conv1 = nn.Conv2d(channels, num_classes, 3, padding=1)
130
+ if dropout_ratio > 0:
131
+ self.dropout = nn.Dropout2d(dropout_ratio)
132
+ else:
133
+ self.dropout = None
134
+ self.fp16_enabled = False
135
+
136
+ def extra_repr(self):
137
+ """Extra repr."""
138
+ s = f'input_transform={self.input_transform}, ' \
139
+ f'ignore_index={self.ignore_index}, ' \
140
+ f'align_corners={self.align_corners}'
141
+ return s
142
+
143
+ def _init_inputs(self, in_channels, in_index, input_transform):
144
+ """Check and initialize input transforms.
145
+
146
+ The in_channels, in_index and input_transform must match.
147
+ Specifically, when input_transform is None, only a single feature map
148
+ will be selected, so in_channels and in_index must be of type int.
149
+ When input_transform is not None, in_channels and in_index must be sequences of the same length.
150
+
151
+ Args:
152
+ in_channels (int|Sequence[int]): Input channels.
153
+ in_index (int|Sequence[int]): Input feature index.
154
+ input_transform (str|None): Transformation type of input features.
155
+ Options: 'resize_concat', 'multiple_select', None.
156
+ 'resize_concat': Multiple feature maps will be resized to the
157
+ same size as the first one and then concatenated together.
158
+ Usually used in the FCN head of HRNet.
159
+ 'multiple_select': Multiple feature maps will be bundled into
160
+ a list and passed into the decode head.
161
+ None: Only one selected feature map is allowed.
162
+ """
163
+
164
+ if input_transform is not None:
165
+ assert input_transform in ['resize_concat', 'multiple_select']
166
+ self.input_transform = input_transform
167
+ self.in_index = in_index
168
+ if input_transform is not None:
169
+ assert isinstance(in_channels, (list, tuple))
170
+ assert isinstance(in_index, (list, tuple))
171
+ assert len(in_channels) == len(in_index)
172
+ if input_transform == 'resize_concat':
173
+ self.in_channels = sum(in_channels)
174
+ else:
175
+ self.in_channels = in_channels
176
+ else:
177
+ assert isinstance(in_channels, int)
178
+ assert isinstance(in_index, int)
179
+ self.in_channels = in_channels
180
+
181
+ def init_weights(self):
182
+ """Initialize weights of classification layer."""
183
+ # normal_init(self.conv_seg, mean=0, std=0.01)
184
+ # normal_init(self.conv1, mean=0, std=0.01)
185
+
186
+ def _transform_inputs(self, inputs):
187
+ """Transform inputs for decoder.
188
+
189
+ Args:
190
+ inputs (list[Tensor]): List of multi-level img features.
191
+
192
+ Returns:
193
+ Tensor: The transformed inputs
194
+ """
195
+
196
+ if self.input_transform == 'resize_concat':
197
+ inputs = [inputs[i] for i in self.in_index]
198
+ upsampled_inputs = [
199
+ resize(
200
+ input=x,
201
+ size=inputs[0].shape[2:],
202
+ mode='bilinear',
203
+ align_corners=self.align_corners) for x in inputs
204
+ ]
205
+ inputs = torch.cat(upsampled_inputs, dim=1)
206
+ elif self.input_transform == 'multiple_select':
207
+ inputs = [inputs[i] for i in self.in_index]
208
+ else:
209
+ inputs = inputs[self.in_index]
210
+
211
+ return inputs
212
+
213
+ def forward(self, inputs):
214
+ """Placeholder of forward function."""
215
+ pass
216
+
217
+ def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
218
+ """Forward function for training.
219
+ Args:
220
+ inputs (list[Tensor]): List of multi-level img features.
221
+ img_metas (list[dict]): List of image info dict where each dict
222
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
223
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
224
+ For details on the values of these keys see
225
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
226
+ gt_semantic_seg (Tensor): Semantic segmentation masks
227
+ used if the architecture supports semantic segmentation task.
228
+ train_cfg (dict): The training config.
229
+
230
+ Returns:
231
+ dict[str, Tensor]: a dictionary of loss components
232
+ """
233
+ seg_logits = self.forward(inputs)
234
+ losses = self.losses(seg_logits, gt_semantic_seg)
235
+ return losses
236
+
237
+ def forward_test(self, inputs, img_metas, test_cfg):
238
+ """Forward function for testing.
239
+
240
+ Args:
241
+ inputs (list[Tensor]): List of multi-level img features.
242
+ img_metas (list[dict]): List of image info dict where each dict
243
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
244
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
245
+ For details on the values of these keys see
246
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
247
+ test_cfg (dict): The testing config.
248
+
249
+ Returns:
250
+ Tensor: Output segmentation map.
251
+ """
252
+ return self.forward(inputs)
253
+
254
+
255
+ class UPerHead(BaseDecodeHead):
256
+ def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
257
+ super(UPerHead, self).__init__(
258
+ input_transform='multiple_select', **kwargs)
259
+ # FPN Module
260
+ self.lateral_convs = nn.ModuleList()
261
+ self.fpn_convs = nn.ModuleList()
262
+ for in_channels in self.in_channels: # skip the top layer
263
+ l_conv = ConvModule(
264
+ in_channels,
265
+ self.channels,
266
+ 1,
267
+ conv_cfg=self.conv_cfg,
268
+ norm_cfg=self.norm_cfg,
269
+ act_cfg=self.act_cfg,
270
+ inplace=True)
271
+ fpn_conv = ConvModule(
272
+ self.channels,
273
+ self.channels,
274
+ 3,
275
+ padding=1,
276
+ conv_cfg=self.conv_cfg,
277
+ norm_cfg=self.norm_cfg,
278
+ act_cfg=self.act_cfg,
279
+ inplace=True)
280
+ self.lateral_convs.append(l_conv)
281
+ self.fpn_convs.append(fpn_conv)
282
+
283
+ def forward(self, inputs):
284
+ """Forward function."""
285
+
286
+ inputs = self._transform_inputs(inputs)
287
+
288
+ # build laterals
289
+ laterals = [
290
+ lateral_conv(inputs[i])
291
+ for i, lateral_conv in enumerate(self.lateral_convs)
292
+ ]
293
+
294
+ # laterals.append(self.psp_forward(inputs))
295
+
296
+ # build top-down path
297
+ used_backbone_levels = len(laterals)
298
+ for i in range(used_backbone_levels - 1, 0, -1):
299
+ prev_shape = laterals[i - 1].shape[2:]
300
+ laterals[i - 1] += resize(
301
+ laterals[i],
302
+ size=prev_shape,
303
+ mode='bilinear',
304
+ align_corners=self.align_corners)
305
+
306
+ # build outputs
307
+ fpn_outs = [
308
+ self.fpn_convs[i](laterals[i])
309
+ for i in range(used_backbone_levels - 1)
310
+ ]
311
+ # append psp feature
312
+ fpn_outs.append(laterals[-1])
313
+
314
+ return fpn_outs[0]
315
+
316
+
317
+
318
+ class PSP(BaseDecodeHead):
319
+ """Unified Perceptual Parsing for Scene Understanding.
320
+
321
+ This head is the implementation of `UPerNet
322
+ <https://arxiv.org/abs/1807.10221>`_.
323
+
324
+ Args:
325
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
326
+ Module applied on the last feature. Default: (1, 2, 3, 6).
327
+ """
328
+
329
+ def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
330
+ super(PSP, self).__init__(
331
+ input_transform='multiple_select', **kwargs)
332
+ # PSP Module
333
+ self.psp_modules = PPM(
334
+ pool_scales,
335
+ self.in_channels[-1],
336
+ self.channels,
337
+ conv_cfg=self.conv_cfg,
338
+ norm_cfg=self.norm_cfg,
339
+ act_cfg=self.act_cfg,
340
+ align_corners=self.align_corners)
341
+ self.bottleneck = ConvModule(
342
+ self.in_channels[-1] + len(pool_scales) * self.channels,
343
+ self.channels,
344
+ 3,
345
+ padding=1,
346
+ conv_cfg=self.conv_cfg,
347
+ norm_cfg=self.norm_cfg,
348
+ act_cfg=self.act_cfg)
349
+
350
+ def psp_forward(self, inputs):
351
+ """Forward function of PSP module."""
352
+ x = inputs[-1]
353
+ psp_outs = [x]
354
+ psp_outs.extend(self.psp_modules(x))
355
+ psp_outs = torch.cat(psp_outs, dim=1)
356
+ output = self.bottleneck(psp_outs)
357
+
358
+ return output
359
+
360
+ def forward(self, inputs):
361
+ """Forward function."""
362
+ inputs = self._transform_inputs(inputs)
363
+
364
+ return self.psp_forward(inputs)
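
For orientation, a hedged sketch of how these heads consume the backbone's four feature maps. The channel widths mirror the Swin-T defaults; the import path and the norm config are assumptions, and the real wiring inside NewCRFDepth may differ.

    import torch
    from iebins.networks.uper_crf_head import PSP, UPerHead  # assumed import path

    feats = [torch.randn(1, 96, 56, 56), torch.randn(1, 192, 28, 28),
             torch.randn(1, 384, 14, 14), torch.randn(1, 768, 7, 7)]

    common = dict(in_channels=[96, 192, 384, 768], channels=512, num_classes=1,
                  in_index=[0, 1, 2, 3], norm_cfg=dict(type='BN', requires_grad=True))
    psp, uper = PSP(**common).eval(), UPerHead(**common).eval()

    with torch.no_grad():
        print(psp(feats).shape)   # PPM bottleneck over the coarsest map: (1, 512, 7, 7)
        print(uper(feats).shape)  # finest FPN level after the top-down path: (1, 512, 56, 56)
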
iebins/sum_depth.py ADDED
@@ -0,0 +1,22 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+
5
+ class Sum_depth(nn.Module):
6
+ def __init__(self):
7
+ super(Sum_depth, self).__init__()
8
+ self.sum_conv = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
9
+ sum_k = np.array([[1, 1, 1], [1, 1, 1], [1, 1, 1]])
10
+
11
+ sum_k = torch.from_numpy(sum_k).float().view(1, 1, 3, 3)
12
+ self.sum_conv.weight = nn.Parameter(sum_k)
13
+
14
+ for param in self.parameters():
15
+ param.requires_grad = False
16
+
17
+ def forward(self, x):
18
+ out = self.sum_conv(x)
19
+ out = out.contiguous().view(-1, 1, x.size(2), x.size(3))
20
+
21
+ return out
22
+
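
In effect, Sum_depth is a frozen 3x3 all-ones convolution, i.e. a local box sum over the depth map. A small sketch (the import path is an assumption):

    import torch
    from iebins.sum_depth import Sum_depth  # assumed import path

    box_sum = Sum_depth()
    d = torch.ones(2, 1, 4, 4)          # dummy single-channel depth maps
    out = box_sum(d)
    print(out.shape)                    # torch.Size([2, 1, 4, 4]): padding=1 keeps the size
    print(out[0, 0, 1, 1].item())       # 9.0: an interior pixel sums nine ones
    print(out[0, 0, 0, 0].item())       # 4.0: a corner only sees four ones (zero padding)
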
iebins/test.py ADDED
@@ -0,0 +1,209 @@
1
+ from __future__ import absolute_import, division, print_function
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.autograd import Variable
6
+
7
+ import os, sys, errno
8
+ import argparse
9
+ import time
10
+ import numpy as np
11
+ import cv2
12
+ import matplotlib.pyplot as plt
13
+ from tqdm import tqdm
14
+ import open3d as o3d
15
+
16
+ from utils import post_process_depth, D_to_cloud, flip_lr, inv_normalize
17
+
18
+ from networks.NewCRFDepth import NewCRFDepth
19
+
20
+
21
+ def convert_arg_line_to_args(arg_line):
22
+ for arg in arg_line.split():
23
+ if not arg.strip():
24
+ continue
25
+ yield arg
26
+
27
+
28
+ parser = argparse.ArgumentParser(description='IEBins PyTorch implementation.', fromfile_prefix_chars='@')
29
+ parser.convert_arg_line_to_args = convert_arg_line_to_args
30
+
31
+ parser.add_argument('--model_name', type=str, help='model name', default='iebins')
32
+ parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07, tiny07', default='large07')
33
+ parser.add_argument('--data_path', type=str, help='path to the data', required=True)
34
+ parser.add_argument('--filenames_file', type=str, help='path to the filenames text file', required=True)
35
+ parser.add_argument('--input_height', type=int, help='input height', default=480)
36
+ parser.add_argument('--input_width', type=int, help='input width', default=640)
37
+ parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10)
38
+ parser.add_argument('--checkpoint_path', type=str, help='path to a specific checkpoint to load', default='')
39
+ parser.add_argument('--dataset', type=str, help='dataset to train on', default='nyu')
40
+ parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
41
+ parser.add_argument('--pred_clouds', help='if set, pred cloud points', action='store_true')
42
+ parser.add_argument('--save_viz', help='if set, save visualization of the outputs', action='store_true')
43
+
44
+ if sys.argv.__len__() == 2:
45
+ arg_filename_with_prefix = '@' + sys.argv[1]
46
+ args = parser.parse_args([arg_filename_with_prefix])
47
+ else:
48
+ args = parser.parse_args()
49
+
50
+ if args.dataset == 'kitti' or args.dataset == 'nyu':
51
+ from dataloaders.dataloader import NewDataLoader
52
+
53
+ model_dir = os.path.dirname(args.checkpoint_path)
54
+ sys.path.append(model_dir)
55
+
56
+
57
+ def get_num_lines(file_path):
58
+ f = open(file_path, 'r')
59
+ lines = f.readlines()
60
+ f.close()
61
+ return len(lines)
62
+
63
+
64
+ def test(params):
65
+ """Test function."""
66
+ args.mode = 'test'
67
+ dataloader = NewDataLoader(args, 'test')
68
+
69
+ model = NewCRFDepth(version='large07', inv_depth=False, max_depth=args.max_depth)
70
+ model = torch.nn.DataParallel(model)
71
+
72
+ checkpoint = torch.load(args.checkpoint_path)
73
+ model.load_state_dict(checkpoint['model'])
74
+ model.eval()
75
+ model.cuda()
76
+
77
+ num_params = sum([np.prod(p.size()) for p in model.parameters()])
78
+ print("Total number of parameters: {}".format(num_params))
79
+
80
+ num_test_samples = get_num_lines(args.filenames_file)
81
+
82
+ with open(args.filenames_file) as f:
83
+ lines = f.readlines()
84
+
85
+ print('now testing {} files with {}'.format(num_test_samples, args.checkpoint_path))
86
+
87
+ pred_depths = []
88
+ pred_clouds = []
89
+ start_time = time.time()
90
+ with torch.no_grad():
91
+ for _, sample in enumerate(tqdm(dataloader.data)):
92
+ image = Variable(sample['image'].cuda())
93
+ inv_K_p = Variable(sample['inv_K_p'].cuda())
94
+ b, _, h, w = image.shape
95
+ depth_to_cloud = D_to_cloud(b, h, w).cuda()
96
+
97
+ # Predict
98
+ pred_depths_r_list, _, _ = model(image)
99
+ post_process = True
100
+ if post_process:
101
+ image_flipped = flip_lr(image)
102
+ pred_depths_r_list_flipped, _, _ = model(image_flipped)
103
+ pred_depth = post_process_depth(pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])
104
+
105
+ if args.pred_clouds:
106
+ if args.dataset == 'nyu':
107
+ color = inv_normalize(image[0, :, :, :]).permute(1, 2, 0)[45:472, 43:608, :].reshape(-1, 3).cpu().numpy()
108
+ points = depth_to_cloud(pred_depth, inv_K_p).reshape(1, h, w, 3)[:, 45:472, 43:608, :].reshape(1, -1, 3)
109
+ points = points.cpu().numpy().squeeze()
110
+ else:
111
+ color = inv_normalize(image[0, :, :, :]).permute(1, 2, 0).reshape(-1, 3).cpu().numpy()
112
+ points = depth_to_cloud(pred_depth, inv_K_p)
113
+ points = points.cpu().numpy().squeeze()
114
+ pc = o3d.geometry.PointCloud()
115
+ pc.points = o3d.utility.Vector3dVector(points)
116
+ pc.colors = o3d.utility.Vector3dVector(color)
117
+
118
+ pred_clouds.append(pc)
119
+
120
+ pred_depth = pred_depth.cpu().numpy().squeeze()
121
+
122
+ if args.do_kb_crop:
123
+ height, width = 352, 1216
124
+ top_margin = int(height - 352)
125
+ left_margin = int((width - 1216) / 2)
126
+ pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
127
+ pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
128
+ pred_depth = pred_depth_uncropped
129
+
130
+ pred_depths.append(pred_depth)
131
+
132
+ elapsed_time = time.time() - start_time
133
+ print('Elapesed time: %s' % str(elapsed_time))
134
+ print('Done.')
135
+
136
+ save_name = 'models/result_' + args.model_name
137
+
138
+ print('Saving result pngs..')
139
+ if not os.path.exists(save_name):
140
+ try:
141
+ os.mkdir(save_name)
142
+ os.mkdir(save_name + '/raw')
143
+ os.mkdir(save_name + '/cmap')
144
+ os.mkdir(save_name + '/rgb')
145
+ os.mkdir(save_name + '/gt')
146
+ os.mkdir(save_name + '/cloud')
147
+ except OSError as e:
148
+ if e.errno != errno.EEXIST:
149
+ raise
150
+
151
+ for s in tqdm(range(num_test_samples)):
152
+ if args.dataset == 'kitti':
153
+ date_drive = lines[s].split('/')[1]
154
+ filename_pred_png = save_name + '/raw/' + date_drive + '_' + lines[s].split()[0].split('/')[-1].replace(
155
+ '.jpg', '.png')
156
+ filename_pred_ply = save_name + '/cloud/' + date_drive + '_' + lines[s].split()[0].split('/')[-1][:-4] + '_' + 'iebins' + '.ply'
157
+ filename_cmap_png = save_name + '/cmap/' + date_drive + '_' + lines[s].split()[0].split('/')[
158
+ -1].replace('.jpg', '.png')
159
+ filename_image_png = save_name + '/rgb/' + date_drive + '_' + lines[s].split()[0].split('/')[-1]
160
+ elif args.dataset == 'kittipred':
161
+ filename_pred_png = save_name + '/raw/' + lines[s].split()[0].split('/')[-1].replace('.jpg', '.png')
162
+ filename_cmap_png = save_name + '/cmap/' + lines[s].split()[0].split('/')[-1].replace('.jpg', '.png')
163
+ filename_image_png = save_name + '/rgb/' + lines[s].split()[0].split('/')[-1]
164
+ else:
165
+ scene_name = lines[s].split()[0].split('/')[0]
166
+ filename_pred_png = save_name + '/raw/' + scene_name + '_' + lines[s].split()[0].split('/')[1].replace(
167
+ '.jpg', '.png')
168
+ filename_pred_ply = save_name + '/cloud/' + scene_name + '_' + lines[s].split()[0].split('/')[1][:-4] + '_' + 'iebins' + '.ply'
169
+ filename_cmap_png = save_name + '/cmap/' + scene_name + '_' + lines[s].split()[0].split('/rgb_')[1].replace(
170
+ '.jpg', '.png')
171
+ filename_gt_png = save_name + '/gt/' + scene_name + '_' + lines[s].split()[0].split('/rgb_')[1].replace(
172
+ '.jpg', '_gt.png')
173
+ filename_image_png = save_name + '/rgb/' + scene_name + '_' + lines[s].split()[0].split('/rgb_')[1]
174
+
175
+ rgb_path = os.path.join(args.data_path, './' + lines[s].split()[0])
176
+ image = cv2.imread(rgb_path)
177
+ if args.dataset == 'nyu':
178
+ gt_path = os.path.join(args.data_path, './' + lines[s].split()[1])
179
+ gt = cv2.imread(gt_path, -1).astype(np.float32) / 1000.0 # Visualization purpose only
180
+ gt[gt == 0] = np.amax(gt)
181
+
182
+ pred_depth = pred_depths[s]
183
+
184
+ if args.dataset == 'kitti' or args.dataset == 'kittipred':
185
+ pred_depth_scaled = pred_depth * 256.0
186
+ else:
187
+ pred_depth_scaled = pred_depth * 1000.0
188
+
189
+ pred_depth_scaled = pred_depth_scaled.astype(np.uint16)
190
+ cv2.imwrite(filename_pred_png, pred_depth_scaled, [cv2.IMWRITE_PNG_COMPRESSION, 0])
191
+
192
+ if args.save_viz:
193
+ cv2.imwrite(filename_image_png, image[10:-1 - 9, 10:-1 - 9, :])
194
+ if args.dataset == 'nyu':
195
+ plt.imsave(filename_gt_png, (10 - gt) / 10, cmap='jet')
196
+ pred_depth_cropped = pred_depth[10:-1 - 9, 10:-1 - 9]
197
+ plt.imsave(filename_cmap_png, (10 - pred_depth) / 10, cmap='jet')
198
+ else:
199
+ plt.imsave(filename_cmap_png, np.log10(pred_depth), cmap='magma')
200
+
201
+ if args.pred_clouds:
202
+ pred_cloud = pred_clouds[s]
203
+ o3d.io.write_point_cloud(filename_pred_ply, pred_cloud)
204
+
205
+ return
206
+
207
+
208
+ if __name__ == '__main__':
209
+ test(args)
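
The raw PNGs written above store depth as 16-bit integers, scaled by 1000 for NYU (metres to millimetres) and by 256 for KITTI. A hedged sketch of reading one back (the file name is a placeholder):

    import cv2
    import numpy as np

    raw = cv2.imread('models/result_iebins/raw/example.png', cv2.IMREAD_UNCHANGED)
    depth_m = raw.astype(np.float32) / 1000.0   # divide by 256.0 instead for KITTI outputs
    print(depth_m.min(), depth_m.max())
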
iebins/train.py ADDED
@@ -0,0 +1,499 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.utils as utils
4
+ import torch.backends.cudnn as cudnn
5
+ import torch.distributed as dist
6
+ import torch.multiprocessing as mp
7
+
8
+ import os, sys, time
9
+ from telnetlib import IP  # unused stray import; nothing below references IP
10
+ import argparse
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+
14
+ from tensorboardX import SummaryWriter
15
+
16
+ from utils import post_process_depth, flip_lr, silog_loss, compute_errors, eval_metrics, entropy_loss, colormap, \
17
+ block_print, enable_print, normalize_result, inv_normalize, convert_arg_line_to_args, colormap_magma
18
+ from networks.NewCRFDepth import NewCRFDepth
19
+ from networks.depth_update import *
20
+ from datetime import datetime
21
+ from sum_depth import Sum_depth
22
+
23
+
24
+ parser = argparse.ArgumentParser(description='IEBins PyTorch implementation.', fromfile_prefix_chars='@')
25
+ parser.convert_arg_line_to_args = convert_arg_line_to_args
26
+
27
+ parser.add_argument('--mode', type=str, help='train or test', default='train')
28
+ parser.add_argument('--model_name', type=str, help='model name', default='iebins')
29
+ parser.add_argument('--encoder', type=str, help='type of encoder, base07, large07, tiny07', default='large07')
30
+ parser.add_argument('--pretrain', type=str, help='path of pretrained encoder', default=None)
31
+
32
+ # Dataset
33
+ parser.add_argument('--dataset', type=str, help='dataset to train on, kitti or nyu', default='nyu')
34
+ parser.add_argument('--data_path', type=str, help='path to the data', required=True)
35
+ parser.add_argument('--gt_path', type=str, help='path to the groundtruth data', required=True)
36
+ parser.add_argument('--filenames_file', type=str, help='path to the filenames text file', required=True)
37
+ parser.add_argument('--input_height', type=int, help='input height', default=480)
38
+ parser.add_argument('--input_width', type=int, help='input width', default=640)
39
+ parser.add_argument('--max_depth', type=float, help='maximum depth in estimation', default=10)
40
+ parser.add_argument('--min_depth', type=float, help='minimum depth in estimation', default=0.1)
41
+
42
+ # Log and save
43
+ parser.add_argument('--log_directory', type=str, help='directory to save checkpoints and summaries', default='')
44
+ parser.add_argument('--checkpoint_path', type=str, help='path to a checkpoint to load', default='')
45
+ parser.add_argument('--log_freq', type=int, help='Logging frequency in global steps', default=100)
46
+ parser.add_argument('--save_freq', type=int, help='Checkpoint saving frequency in global steps', default=5000)
47
+
48
+ # Training
49
+ parser.add_argument('--weight_decay', type=float, help='weight decay factor for optimization', default=1e-2)
50
+ parser.add_argument('--retrain', help='if used with checkpoint_path, will restart training from step zero', action='store_true')
51
+ parser.add_argument('--adam_eps', type=float, help='epsilon in Adam optimizer', default=1e-6)
52
+ parser.add_argument('--batch_size', type=int, help='batch size', default=4)
53
+ parser.add_argument('--num_epochs', type=int, help='number of epochs', default=50)
54
+ parser.add_argument('--learning_rate', type=float, help='initial learning rate', default=1e-4)
55
+ parser.add_argument('--end_learning_rate', type=float, help='end learning rate', default=-1)
56
+ parser.add_argument('--variance_focus', type=float, help='lambda in paper: [0, 1], higher value more focus on minimizing variance of error', default=0.85)
57
+
58
+ # Preprocessing
59
+ parser.add_argument('--do_random_rotate', help='if set, will perform random rotation for augmentation', action='store_true')
60
+ parser.add_argument('--degree', type=float, help='random rotation maximum degree', default=2.5)
61
+ parser.add_argument('--do_kb_crop', help='if set, crop input images as kitti benchmark images', action='store_true')
62
+ parser.add_argument('--use_right', help='if set, will randomly use right images when train on KITTI', action='store_true')
63
+
64
+ # Multi-gpu training
65
+ parser.add_argument('--num_threads', type=int, help='number of threads to use for data loading', default=1)
66
+ parser.add_argument('--world_size', type=int, help='number of nodes for distributed training', default=1)
67
+ parser.add_argument('--rank', type=int, help='node rank for distributed training', default=0)
68
+ parser.add_argument('--dist_url', type=str, help='url used to set up distributed training', default='tcp://127.0.0.1:1234')
69
+ parser.add_argument('--dist_backend', type=str, help='distributed backend', default='nccl')
70
+ parser.add_argument('--gpu', type=int, help='GPU id to use.', default=None)
71
+ parser.add_argument('--multiprocessing_distributed', help='Use multi-processing distributed training to launch '
72
+ 'N processes per node, which has N GPUs. This is the '
73
+ 'fastest way to use PyTorch for either single node or '
74
+ 'multi node data parallel training', action='store_true',)
75
+ # Online eval
76
+ parser.add_argument('--do_online_eval', help='if set, perform online eval in every eval_freq steps', action='store_true')
77
+ parser.add_argument('--data_path_eval', type=str, help='path to the data for online evaluation', required=False)
78
+ parser.add_argument('--gt_path_eval', type=str, help='path to the groundtruth data for online evaluation', required=False)
79
+ parser.add_argument('--filenames_file_eval', type=str, help='path to the filenames text file for online evaluation', required=False)
80
+ parser.add_argument('--min_depth_eval', type=float, help='minimum depth for evaluation', default=1e-3)
81
+ parser.add_argument('--max_depth_eval', type=float, help='maximum depth for evaluation', default=80)
82
+ parser.add_argument('--eigen_crop', help='if set, crops according to Eigen NIPS14', action='store_true')
83
+ parser.add_argument('--garg_crop', help='if set, crops according to Garg ECCV16', action='store_true')
84
+ parser.add_argument('--eval_freq', type=int, help='Online evaluation frequency in global steps', default=500)
85
+ parser.add_argument('--eval_summary_directory', type=str, help='output directory for eval summary,'
86
+ 'if empty outputs to checkpoint folder', default='')
87
+
88
+ if sys.argv.__len__() == 2:
89
+ arg_filename_with_prefix = '@' + sys.argv[1]
90
+ args = parser.parse_args([arg_filename_with_prefix])
91
+ else:
92
+ args = parser.parse_args()
93
+
94
+ if args.dataset == 'kitti' or args.dataset == 'nyu':
95
+ from dataloaders.dataloader import NewDataLoader
96
+
97
+
98
+ def online_eval(model, dataloader_eval, gpu, epoch, ngpus, group, post_process=False):
99
+ eval_measures = torch.zeros(10).cuda(device=gpu)
100
+ for _, eval_sample_batched in enumerate(tqdm(dataloader_eval.data)):
101
+ with torch.no_grad():
102
+ image = torch.autograd.Variable(eval_sample_batched['image'].cuda(gpu, non_blocking=True))
103
+ gt_depth = eval_sample_batched['depth']
104
+ has_valid_depth = eval_sample_batched['has_valid_depth']
105
+ if not has_valid_depth:
106
+ # print('Invalid depth. continue.')
107
+ continue
108
+
109
+ pred_depths_r_list, _, _ = model(image)
110
+ if post_process:
111
+ image_flipped = flip_lr(image)
112
+ pred_depths_r_list_flipped, _, _ = model(image_flipped)
113
+ pred_depth = post_process_depth(pred_depths_r_list[-1], pred_depths_r_list_flipped[-1])
114
+
115
+ pred_depth = pred_depth.cpu().numpy().squeeze()
116
+ gt_depth = gt_depth.cpu().numpy().squeeze()
117
+
118
+ if args.do_kb_crop:
119
+ height, width = gt_depth.shape
120
+ top_margin = int(height - 352)
121
+ left_margin = int((width - 1216) / 2)
122
+ pred_depth_uncropped = np.zeros((height, width), dtype=np.float32)
123
+ pred_depth_uncropped[top_margin:top_margin + 352, left_margin:left_margin + 1216] = pred_depth
124
+ pred_depth = pred_depth_uncropped
125
+
126
+ pred_depth[pred_depth < args.min_depth_eval] = args.min_depth_eval
127
+ pred_depth[pred_depth > args.max_depth_eval] = args.max_depth_eval
128
+ pred_depth[np.isinf(pred_depth)] = args.max_depth_eval
129
+ pred_depth[np.isnan(pred_depth)] = args.min_depth_eval
130
+
131
+ valid_mask = np.logical_and(gt_depth > args.min_depth_eval, gt_depth < args.max_depth_eval)
132
+
133
+ if args.garg_crop or args.eigen_crop:
134
+ gt_height, gt_width = gt_depth.shape
135
+ eval_mask = np.zeros(valid_mask.shape)
136
+
137
+ if args.garg_crop:
138
+ eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height), int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
139
+
140
+ elif args.eigen_crop:
141
+ if args.dataset == 'kitti':
142
+ eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height), int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
143
+ elif args.dataset == 'nyu':
144
+ eval_mask[45:471, 41:601] = 1
145
+
146
+ valid_mask = np.logical_and(valid_mask, eval_mask)
147
+
148
+ measures = compute_errors(gt_depth[valid_mask], pred_depth[valid_mask])
149
+
150
+ eval_measures[:9] += torch.tensor(measures).cuda(device=gpu)
151
+ eval_measures[9] += 1
152
+
153
+ if args.multiprocessing_distributed:
154
+ # group = dist.new_group([i for i in range(ngpus)])
155
+ dist.all_reduce(tensor=eval_measures, op=dist.ReduceOp.SUM, group=group)
156
+
157
+ if not args.multiprocessing_distributed or gpu == 0:
158
+ eval_measures_cpu = eval_measures.cpu()
159
+ cnt = eval_measures_cpu[9].item()
160
+ eval_measures_cpu /= cnt
161
+ print('Computing errors for {} eval samples'.format(int(cnt)), ', post_process: ', post_process)
162
+ print("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}".format('silog', 'abs_rel', 'log10', 'rms',
163
+ 'sq_rel', 'log_rms', 'd1', 'd2',
164
+ 'd3'))
165
+ for i in range(8):
166
+ print('{:7.4f}, '.format(eval_measures_cpu[i]), end='')
167
+ print('{:7.4f}'.format(eval_measures_cpu[8]))
168
+ return eval_measures_cpu
169
+
170
+ return None
171
+
172
+
173
+ def main_worker(gpu, ngpus_per_node, args):
174
+ args.gpu = gpu
175
+
176
+ if args.gpu is not None:
177
+ print("== Use GPU: {} for training".format(args.gpu))
178
+
179
+ if args.distributed:
180
+ if args.dist_url == "env://" and args.rank == -1:
181
+ args.rank = int(os.environ["RANK"])
182
+ if args.multiprocessing_distributed:
183
+ args.rank = args.rank * ngpus_per_node + gpu
184
+ dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank)
185
+
186
+ # model
187
+ model = NewCRFDepth(version=args.encoder, inv_depth=False, max_depth=args.max_depth, pretrained=args.pretrain)
188
+ model.train()
189
+
190
+ num_params = sum([np.prod(p.size()) for p in model.parameters()])
191
+ print("== Total number of parameters: {}".format(num_params))
192
+
193
+ num_params_update = sum([np.prod(p.shape) for p in model.parameters() if p.requires_grad])
194
+ print("== Total number of learning parameters: {}".format(num_params_update))
195
+
196
+ if args.distributed:
197
+ if args.gpu is not None:
198
+ torch.cuda.set_device(args.gpu)
199
+ model.cuda(args.gpu)
200
+ args.batch_size = int(args.batch_size / ngpus_per_node)
201
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
202
+ else:
203
+ model.cuda()
204
+ model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
205
+ else:
206
+ model = torch.nn.DataParallel(model)
207
+ model.cuda()
208
+
209
+ if args.distributed:
210
+ print("== Model Initialized on GPU: {}".format(args.gpu))
211
+ else:
212
+ print("== Model Initialized")
213
+
214
+ global_step = 0
215
+ best_eval_measures_lower_better = torch.zeros(6).cpu() + 1e3
216
+ best_eval_measures_higher_better = torch.zeros(3).cpu()
217
+ best_eval_steps = np.zeros(9, dtype=np.int32)
218
+
219
+ # Training parameters
220
+ optimizer = torch.optim.Adam([{'params': model.module.parameters()}],
221
+ lr=args.learning_rate)
222
+
223
+ model_just_loaded = False
224
+ if args.checkpoint_path != '':
225
+ if os.path.isfile(args.checkpoint_path):
226
+ print("== Loading checkpoint '{}'".format(args.checkpoint_path))
227
+ if args.gpu is None:
228
+ checkpoint = torch.load(args.checkpoint_path)
229
+ else:
230
+ loc = 'cuda:{}'.format(args.gpu)
231
+ checkpoint = torch.load(args.checkpoint_path, map_location=loc)
232
+ model.load_state_dict(checkpoint['model'])
233
+ optimizer.load_state_dict(checkpoint['optimizer'])
234
+ if not args.retrain:
235
+ try:
236
+ global_step = checkpoint['global_step']
237
+ best_eval_measures_higher_better = checkpoint['best_eval_measures_higher_better'].cpu()
238
+ best_eval_measures_lower_better = checkpoint['best_eval_measures_lower_better'].cpu()
239
+ best_eval_steps = checkpoint['best_eval_steps']
240
+ except KeyError:
241
+ print("Could not load values for online evaluation")
242
+
243
+ print("== Loaded checkpoint '{}' (global_step {})".format(args.checkpoint_path, checkpoint['global_step']))
244
+ else:
245
+ print("== No checkpoint found at '{}'".format(args.checkpoint_path))
246
+ model_just_loaded = True
247
+ del checkpoint
248
+
249
+ cudnn.benchmark = True
250
+
251
+ dataloader = NewDataLoader(args, 'train')
252
+ dataloader_eval = NewDataLoader(args, 'online_eval')
253
+
254
+ # ===== Evaluation before training ======
255
+ # model.eval()
256
+ # with torch.no_grad():
257
+ # eval_measures = online_eval(model, dataloader_eval, gpu, ngpus_per_node, post_process=True)
258
+
259
+ # Logging
260
+ if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
261
+ writer = SummaryWriter(args.log_directory + '/' + args.model_name + '/summaries', flush_secs=30)
262
+ if args.do_online_eval:
263
+ if args.eval_summary_directory != '':
264
+ eval_summary_path = os.path.join(args.eval_summary_directory, args.model_name)
265
+ else:
266
+ eval_summary_path = os.path.join(args.log_directory, args.model_name, 'eval')
267
+ eval_summary_writer = SummaryWriter(eval_summary_path, flush_secs=30)
268
+
269
+ silog_criterion = silog_loss(variance_focus=args.variance_focus)
270
+ sum_localdepth = Sum_depth().cuda(args.gpu)
271
+
272
+ start_time = time.time()
273
+ duration = 0
274
+
275
+ num_log_images = args.batch_size
276
+ end_learning_rate = args.end_learning_rate if args.end_learning_rate != -1 else 0.1 * args.learning_rate
277
+
278
+ var_sum = [var.sum().item() for var in model.parameters() if var.requires_grad]
279
+ var_cnt = len(var_sum)
280
+ var_sum = np.sum(var_sum)
281
+
282
+ print("== Initial variables' sum: {:.3f}, avg: {:.3f}".format(var_sum, var_sum/var_cnt))
283
+
284
+ steps_per_epoch = len(dataloader.data)
285
+ num_total_steps = args.num_epochs * steps_per_epoch
286
+ epoch = global_step // steps_per_epoch
287
+
288
+ group = dist.new_group([i for i in range(ngpus_per_node)])
289
+ while epoch < args.num_epochs:
290
+ if args.distributed:
291
+ dataloader.train_sampler.set_epoch(epoch)
292
+
293
+ for step, sample_batched in enumerate(dataloader.data):
294
+ optimizer.zero_grad()
295
+ before_op_time = time.time()
296
+ si_loss = 0
297
+
298
+ image = torch.autograd.Variable(sample_batched['image'].cuda(args.gpu, non_blocking=True))
299
+ depth_gt = torch.autograd.Variable(sample_batched['depth'].cuda(args.gpu, non_blocking=True))
300
+
301
+ pred_depths_r_list, pred_depths_c_list, uncertainty_maps_list = model(image, epoch, step)
302
+
303
+ if args.dataset == 'nyu':
304
+ mask = depth_gt > 0.1
305
+ else:
306
+ mask = depth_gt > 1.0
307
+
308
+ max_tree_depth = len(pred_depths_r_list)
309
+ for curr_tree_depth in range(max_tree_depth):
310
+
311
+ si_loss += silog_criterion.forward(pred_depths_r_list[curr_tree_depth], depth_gt, mask.to(torch.bool))
312
+
313
+ loss = si_loss
314
+
315
+ loss.backward()
316
+ for param_group in optimizer.param_groups:
317
+ current_lr = (args.learning_rate - end_learning_rate) * (1 - global_step / num_total_steps) ** 0.9 + end_learning_rate
318
+ param_group['lr'] = current_lr
319
+
320
+ optimizer.step()
321
+
322
+ if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
323
+ print('[epoch][s/s_per_e/gs]: [{}][{}/{}/{}], lr: {:.12f}, loss: {:.12f}'.format(epoch, step, steps_per_epoch, global_step, current_lr, loss))
324
+ # if np.isnan(loss.cpu().item()):
325
+ # print('NaN in loss occurred. Aborting training.')
326
+ # return -1
327
+
328
+ duration += time.time() - before_op_time
329
+ if global_step and global_step % args.log_freq == 0 and not model_just_loaded:
330
+ var_sum = [var.sum().item() for var in model.parameters() if var.requires_grad]
331
+ var_cnt = len(var_sum)
332
+ var_sum = np.sum(var_sum)
333
+ examples_per_sec = args.batch_size / duration * args.log_freq
334
+ duration = 0
335
+ time_sofar = (time.time() - start_time) / 3600
336
+ training_time_left = (num_total_steps / global_step - 1.0) * time_sofar
337
+ if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
338
+ print("{}".format(args.model_name))
339
+ print_string = 'GPU: {} | examples/s: {:4.2f} | loss: {:.5f} | var sum: {:.3f} avg: {:.3f} | time elapsed: {:.2f}h | time left: {:.2f}h'
340
+ print(print_string.format(args.gpu, examples_per_sec, loss, var_sum.item(), var_sum.item()/var_cnt, time_sofar, training_time_left))
341
+
342
+ if not args.multiprocessing_distributed or (args.multiprocessing_distributed
343
+ and args.rank % ngpus_per_node == 0):
344
+ writer.add_scalar('silog_loss', si_loss, global_step)
345
+ # writer.add_scalar('var_loss', var_loss, global_step)
346
+ writer.add_scalar('learning_rate', current_lr, global_step)
347
+ writer.add_scalar('var average', var_sum.item()/var_cnt, global_step)
348
+ depth_gt = torch.where(depth_gt < 1e-3, depth_gt * 0 + 1e-3, depth_gt)
349
+ for i in range(num_log_images):
350
+ if args.dataset == 'nyu':
351
+ writer.add_image('depth_gt/image/{}'.format(i), colormap(depth_gt[i, :, :, :].data), global_step)
352
+ writer.add_image('image/image/{}'.format(i), inv_normalize(image[i, :, :, :]).data, global_step)
353
+ writer.add_image('depth_r_est0/image/{}'.format(i), colormap(pred_depths_r_list[0][i, :, :, :].data), global_step)
354
+ writer.add_image('depth_r_est1/image/{}'.format(i), colormap(pred_depths_r_list[1][i, :, :, :].data), global_step)
355
+ writer.add_image('depth_r_est2/image/{}'.format(i), colormap(pred_depths_r_list[2][i, :, :, :].data), global_step)
356
+ writer.add_image('depth_r_est3/image/{}'.format(i), colormap(pred_depths_r_list[3][i, :, :, :].data), global_step)
357
+ writer.add_image('depth_r_est4/image/{}'.format(i), colormap(pred_depths_r_list[4][i, :, :, :].data), global_step)
358
+ writer.add_image('depth_r_est5/image/{}'.format(i), colormap(pred_depths_r_list[5][i, :, :, :].data), global_step)
359
+ writer.add_image('depth_c_est0/image/{}'.format(i), colormap(pred_depths_c_list[0][i, :, :, :].data), global_step)
360
+ writer.add_image('depth_c_est1/image/{}'.format(i), colormap(pred_depths_c_list[1][i, :, :, :].data), global_step)
361
+ writer.add_image('depth_c_est2/image/{}'.format(i), colormap(pred_depths_c_list[2][i, :, :, :].data), global_step)
362
+ writer.add_image('depth_c_est3/image/{}'.format(i), colormap(pred_depths_c_list[3][i, :, :, :].data), global_step)
363
+ writer.add_image('depth_c_est4/image/{}'.format(i), colormap(pred_depths_c_list[4][i, :, :, :].data), global_step)
364
+ writer.add_image('depth_c_est5/image/{}'.format(i), colormap(pred_depths_c_list[5][i, :, :, :].data), global_step)
365
+ else:
366
+ writer.add_image('depth_gt/image/{}'.format(i), colormap_magma(torch.log10(depth_gt[i, :, :, :].data)), global_step)
367
+ writer.add_image('image/image/{}'.format(i), inv_normalize(image[i, :, :, :]).data, global_step)
368
+ writer.add_image('depth_r_est0/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_r_list[0][i, :, :, :].data)), global_step)
369
+ writer.add_image('depth_r_est1/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_r_list[1][i, :, :, :].data)), global_step)
370
+ writer.add_image('depth_r_est2/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_r_list[2][i, :, :, :].data)), global_step)
371
+ writer.add_image('depth_r_est3/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_r_list[3][i, :, :, :].data)), global_step)
372
+ writer.add_image('depth_r_est4/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_r_list[4][i, :, :, :].data)), global_step)
373
+ writer.add_image('depth_r_est5/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_r_list[5][i, :, :, :].data)), global_step)
374
+ writer.add_image('depth_c_est0/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_c_list[0][i, :, :, :].data)), global_step)
375
+ writer.add_image('depth_c_est1/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_c_list[1][i, :, :, :].data)), global_step)
376
+ writer.add_image('depth_c_est2/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_c_list[2][i, :, :, :].data)), global_step)
377
+ writer.add_image('depth_c_est3/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_c_list[3][i, :, :, :].data)), global_step)
378
+ writer.add_image('depth_c_est4/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_c_list[4][i, :, :, :].data)), global_step)
379
+ writer.add_image('depth_c_est5/image/{}'.format(i), colormap_magma(torch.log10(pred_depths_c_list[5][i, :, :, :].data)), global_step)
380
+
381
+ writer.add_image('uncer_est0/image/{}'.format(i), colormap(uncertainty_maps_list[0][i, :, :, :].data), global_step)
382
+ writer.add_image('uncer_est1/image/{}'.format(i), colormap(uncertainty_maps_list[1][i, :, :, :].data), global_step)
383
+ writer.add_image('uncer_est2/image/{}'.format(i), colormap(uncertainty_maps_list[2][i, :, :, :].data), global_step)
384
+ writer.add_image('uncer_est3/image/{}'.format(i), colormap(uncertainty_maps_list[3][i, :, :, :].data), global_step)
385
+ writer.add_image('uncer_est4/image/{}'.format(i), colormap(uncertainty_maps_list[4][i, :, :, :].data), global_step)
386
+ writer.add_image('uncer_est5/image/{}'.format(i), colormap(uncertainty_maps_list[5][i, :, :, :].data), global_step)
387
+
388
+ if args.do_online_eval and global_step and global_step % args.eval_freq == 0 and not model_just_loaded:
389
+ time.sleep(0.1)
390
+ model.eval()
391
+ with torch.no_grad():
392
+ eval_measures = online_eval(model, dataloader_eval, gpu, epoch, ngpus_per_node, group, post_process=True)
393
+ if eval_measures is not None:
394
+ exp_name = '%s'%(datetime.now().strftime('%m%d'))
395
+ log_txt = os.path.join(args.log_directory + '/' + args.model_name, exp_name+'_logs.txt')
396
+ with open(log_txt, 'a') as txtfile:
397
+ txtfile.write(">>>>>>>>>>>>>>>>>>>>>>>>>Step:%d>>>>>>>>>>>>>>>>>>>>>>>>>\n"%(int(global_step)))
398
+ txtfile.write("{:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}, {:>7}\n".format('silog',
399
+ 'abs_rel', 'log10', 'rms', 'sq_rel', 'log_rms', 'd1', 'd2','d3'))
400
+ txtfile.write("depth estimation\n")
401
+ line = ''
402
+ for i in range(9):
403
+ line +='{:7.4f}, '.format(eval_measures[i])
404
+ txtfile.write(line+'\n')
405
+
406
+ for i in range(9):
407
+ eval_summary_writer.add_scalar(eval_metrics[i], eval_measures[i].cpu(), int(global_step))
408
+ measure = eval_measures[i]
409
+ is_best = False
410
+ if i < 6 and measure < best_eval_measures_lower_better[i]:
411
+ old_best = best_eval_measures_lower_better[i].item()
412
+ best_eval_measures_lower_better[i] = measure.item()
413
+ is_best = True
414
+ elif i >= 6 and measure > best_eval_measures_higher_better[i-6]:
415
+ old_best = best_eval_measures_higher_better[i-6].item()
416
+ best_eval_measures_higher_better[i-6] = measure.item()
417
+ is_best = True
418
+ if is_best:
419
+ old_best_step = best_eval_steps[i]
420
+ old_best_name = '/model-{}-best_{}_{:.5f}'.format(old_best_step, eval_metrics[i], old_best)
421
+ model_path = args.log_directory + '/' + args.model_name + old_best_name
422
+ if os.path.exists(model_path):
423
+ command = 'rm {}'.format(model_path)
424
+ os.system(command)
425
+ best_eval_steps[i] = global_step
426
+ model_save_name = '/model-{}-best_{}_{:.5f}'.format(global_step, eval_metrics[i], measure)
427
+ print('New best for {}. Saving model: {}'.format(eval_metrics[i], model_save_name))
428
+ checkpoint = {'global_step': global_step,
429
+ 'model': model.state_dict(),
430
+ 'optimizer': optimizer.state_dict(),
431
+ 'best_eval_measures_higher_better': best_eval_measures_higher_better,
432
+ 'best_eval_measures_lower_better': best_eval_measures_lower_better,
433
+ 'best_eval_steps': best_eval_steps
434
+ }
435
+ torch.save(checkpoint, args.log_directory + '/' + args.model_name + model_save_name)
436
+ eval_summary_writer.flush()
437
+ model.train()
438
+ block_print()
439
+ enable_print()
440
+
441
+ model_just_loaded = False
442
+ global_step += 1
443
+
444
+ epoch += 1
445
+
446
+ if not args.multiprocessing_distributed or (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
447
+ writer.close()
448
+ if args.do_online_eval:
449
+ eval_summary_writer.close()
450
+
451
+
452
+ def main():
453
+ if args.mode != 'train':
454
+ print('train.py is only for training.')
455
+ return -1
456
+
457
+ exp_name = '%s'%(datetime.now().strftime('%m%d'))
458
+ args.log_directory = os.path.join(args.log_directory,exp_name)
459
+ command = 'mkdir ' + os.path.join(args.log_directory, args.model_name)
460
+ os.system(command)
461
+
462
+ args_out_path = os.path.join(args.log_directory, args.model_name)
463
+ command = 'cp ' + sys.argv[1] + ' ' + args_out_path
464
+ os.system(command)
465
+
466
+ save_files = True
467
+ if save_files:
468
+ aux_out_path = os.path.join(args.log_directory, args.model_name)
469
+ networks_savepath = os.path.join(aux_out_path, 'networks')
470
+ dataloaders_savepath = os.path.join(aux_out_path, 'dataloaders')
471
+ command = 'cp iebins/train.py ' + aux_out_path
472
+ os.system(command)
473
+ command = 'mkdir -p ' + networks_savepath + ' && cp iebins/networks/*.py ' + networks_savepath
474
+ os.system(command)
475
+ command = 'mkdir -p ' + dataloaders_savepath + ' && cp iebins/dataloaders/*.py ' + dataloaders_savepath
476
+ os.system(command)
477
+
478
+ torch.cuda.empty_cache()
479
+ args.distributed = args.world_size > 1 or args.multiprocessing_distributed
480
+
481
+ ngpus_per_node = torch.cuda.device_count()
482
+ if ngpus_per_node > 1 and not args.multiprocessing_distributed:
483
+ print("This machine has more than 1 gpu. Please specify --multiprocessing_distributed, or set \'CUDA_VISIBLE_DEVICES=0\'")
484
+ return -1
485
+
486
+ if args.do_online_eval:
487
+ print("You have specified --do_online_eval.")
488
+ print("This will evaluate the model every eval_freq {} steps and save best models for individual eval metrics."
489
+ .format(args.eval_freq))
490
+
491
+ if args.multiprocessing_distributed:
492
+ args.world_size = ngpus_per_node * args.world_size
493
+ mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
494
+ else:
495
+ main_worker(args.gpu, ngpus_per_node, args)
496
+
497
+
498
+ if __name__ == '__main__':
499
+ main()
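
The inner loop above anneals the learning rate per step with a power-0.9 polynomial decay from learning_rate down to end_learning_rate (which defaults to one tenth of the initial rate). A standalone sketch of that schedule:

    def poly_lr(step, total_steps, base_lr=1e-4, end_lr=1e-5, power=0.9):
        # mirrors: (lr - end_lr) * (1 - global_step / num_total_steps) ** 0.9 + end_lr
        return (base_lr - end_lr) * (1 - step / total_steps) ** power + end_lr

    print(poly_lr(0, 100_000))        # 1.0e-4 at the first step
    print(poly_lr(50_000, 100_000))   # ~5.8e-5 halfway through
    print(poly_lr(100_000, 100_000))  # 1.0e-5 at the end
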
iebins/utils.py ADDED
@@ -0,0 +1,356 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.distributed as dist
5
+ from torch.utils.data import Sampler
6
+ from torchvision import transforms
7
+ import matplotlib.pyplot as plt
8
+ import os, sys
9
+ import numpy as np
10
+ import math
11
+ import torch
12
+
13
+
14
+ def convert_arg_line_to_args(arg_line):
15
+ for arg in arg_line.split():
16
+ if not arg.strip():
17
+ continue
18
+ yield arg
19
+
20
+
21
+ def block_print():
22
+ sys.stdout = open(os.devnull, 'w')
23
+
24
+
25
+ def enable_print():
26
+ sys.stdout = sys.__stdout__
27
+
28
+
29
+ def get_num_lines(file_path):
30
+ f = open(file_path, 'r')
31
+ lines = f.readlines()
32
+ f.close()
33
+ return len(lines)
34
+
35
+
36
+ def colorize(value, vmin=None, vmax=None, cmap='Greys'):
37
+ value = value.cpu().numpy()[:, :, :]
38
+ value = np.log10(value)
39
+
40
+ vmin = value.min() if vmin is None else vmin
41
+ vmax = value.max() if vmax is None else vmax
42
+
43
+ if vmin != vmax:
44
+ value = (value - vmin) / (vmax - vmin)
45
+ else:
46
+ value = value*0.
47
+
48
+ cmapper = plt.get_cmap(cmap)  # use pyplot here; the bare matplotlib module is not imported
49
+ value = cmapper(value, bytes=True)
50
+
51
+ img = value[:, :, :3]
52
+
53
+ return img.transpose((2, 0, 1))
54
+
55
+
56
+ def normalize_result(value, vmin=None, vmax=None):
57
+ value = value.cpu().numpy()[0, :, :]
58
+
59
+ vmin = value.min() if vmin is None else vmin
60
+ vmax = value.max() if vmax is None else vmax
61
+
62
+ if vmin != vmax:
63
+ value = (value - vmin) / (vmax - vmin)
64
+ else:
65
+ value = value * 0.
66
+
67
+ return np.expand_dims(value, 0)
68
+
69
+
70
+ inv_normalize = transforms.Normalize(
71
+ mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
72
+ std=[1/0.229, 1/0.224, 1/0.225]
73
+ )
74
+
75
+
76
+ eval_metrics = ['silog', 'abs_rel', 'log10', 'rms', 'sq_rel', 'log_rms', 'd1', 'd2', 'd3']
77
+
78
+
79
+ def compute_errors(gt, pred):
80
+ thresh = np.maximum((gt / pred), (pred / gt))
81
+ d1 = (thresh < 1.25).mean()
82
+ d2 = (thresh < 1.25 ** 2).mean()
83
+ d3 = (thresh < 1.25 ** 3).mean()
84
+
85
+ rms = (gt - pred) ** 2
86
+ rms = np.sqrt(rms.mean())
87
+
88
+ log_rms = (np.log(gt) - np.log(pred)) ** 2
89
+ log_rms = np.sqrt(log_rms.mean())
90
+
91
+ abs_rel = np.mean(np.abs(gt - pred) / gt)
92
+ sq_rel = np.mean(((gt - pred) ** 2) / gt)
93
+
94
+ err = np.log(pred) - np.log(gt)
95
+ silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100
96
+
97
+ err = np.abs(np.log10(pred) - np.log10(gt))
98
+ log10 = np.mean(err)
99
+
100
+ return [silog, abs_rel, log10, rms, sq_rel, log_rms, d1, d2, d3]
101
+
102
+
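A minimal usage sketch for compute_errors and eval_metrics, assuming the helpers above are importable as shown; the synthetic arrays, the 1e-3 floor and the 10 m cap are illustrative assumptions:

    import numpy as np

    # Assumed import path per this commit's layout.
    from iebins.utils import compute_errors, eval_metrics

    rng = np.random.default_rng(0)
    gt = rng.uniform(0.5, 10.0, size=(480, 640)).astype(np.float32)          # fake ground truth (metres)
    pred = gt * rng.uniform(0.9, 1.1, size=gt.shape).astype(np.float32)      # fake prediction

    valid = (gt > 1e-3) & (gt < 10.0)            # assumed valid-depth range (NYU-style cap)
    metrics = compute_errors(gt[valid], pred[valid])
    print(dict(zip(eval_metrics, metrics)))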
103
+ class silog_loss(nn.Module):
104
+ def __init__(self, variance_focus):
105
+ super(silog_loss, self).__init__()
106
+ self.variance_focus = variance_focus
107
+
108
+ def forward(self, depth_est, depth_gt, mask):
109
+ d = torch.log(depth_est[mask]) - torch.log(depth_gt[mask])
110
+ return torch.sqrt((d ** 2).mean() - self.variance_focus * (d.mean() ** 2)) * 10.0
111
+
112
+
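A similar sketch for silog_loss on tensors; variance_focus=0.85 and the tensor shapes are assumptions, not values taken from this commit:

    import torch

    # Assumed import path per this commit's layout.
    from iebins.utils import silog_loss

    criterion = silog_loss(variance_focus=0.85)                               # assumed setting
    depth_est = (torch.rand(2, 1, 120, 160) * 10 + 0.1).requires_grad_()      # stand-in prediction
    depth_gt = torch.rand(2, 1, 120, 160) * 10 + 0.1
    mask = depth_gt > 0.1                                                     # supervise only valid pixels
    loss = criterion(depth_est, depth_gt, mask)
    loss.backward()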
113
+ def entropy_loss(preds, gt_label, mask):
114
+ # preds: B, C, H, W
115
+ # gt_label: B, H, W
116
+ # mask: B, H, W
117
+ mask = mask > 0.0 # B, H, W
118
+ preds = preds.permute(0, 2, 3, 1) # B, H, W, C
119
+ preds_mask = preds[mask] # N, C
120
+ gt_label_mask = gt_label[mask] # N
121
+ loss = F.cross_entropy(preds_mask, gt_label_mask, reduction='mean')
122
+ return loss
123
+
124
+
125
+ def colormap(inputs, normalize=True, torch_transpose=True):
126
+ if isinstance(inputs, torch.Tensor):
127
+ inputs = inputs.detach().cpu().numpy()
128
+ _DEPTH_COLORMAP = plt.get_cmap('jet', 256) # for plotting
129
+ vis = inputs
130
+ if normalize:
131
+ ma = float(vis.max())
132
+ mi = float(vis.min())
133
+ d = ma - mi if ma != mi else 1e5
134
+ vis = (vis - mi) / d
135
+
136
+ if vis.ndim == 4:
137
+ vis = vis.transpose([0, 2, 3, 1])
138
+ vis = _DEPTH_COLORMAP(vis)
139
+ vis = vis[:, :, :, 0, :3]
140
+ if torch_transpose:
141
+ vis = vis.transpose(0, 3, 1, 2)
142
+ elif vis.ndim == 3:
143
+ vis = _DEPTH_COLORMAP(vis)
144
+ vis = vis[:, :, :, :3]
145
+ if torch_transpose:
146
+ vis = vis.transpose(0, 3, 1, 2)
147
+ elif vis.ndim == 2:
148
+ vis = _DEPTH_COLORMAP(vis)
149
+ vis = vis[..., :3]
150
+ if torch_transpose:
151
+ vis = vis.transpose(2, 0, 1)
152
+
153
+ return vis[0,:,:,:]
154
+
155
+
156
+ def colormap_magma(inputs, normalize=True, torch_transpose=True):
157
+ if isinstance(inputs, torch.Tensor):
158
+ inputs = inputs.detach().cpu().numpy()
159
+ _DEPTH_COLORMAP = plt.get_cmap('magma', 256) # for plotting
160
+ vis = inputs
161
+ if normalize:
162
+ ma = float(vis.max())
163
+ mi = float(vis.min())
164
+ d = ma - mi if ma != mi else 1e5
165
+ vis = (vis - mi) / d
166
+
167
+ if vis.ndim == 4:
168
+ vis = vis.transpose([0, 2, 3, 1])
169
+ vis = _DEPTH_COLORMAP(vis)
170
+ vis = vis[:, :, :, 0, :3]
171
+ if torch_transpose:
172
+ vis = vis.transpose(0, 3, 1, 2)
173
+ elif vis.ndim == 3:
174
+ vis = _DEPTH_COLORMAP(vis)
175
+ vis = vis[:, :, :, :3]
176
+ if torch_transpose:
177
+ vis = vis.transpose(0, 3, 1, 2)
178
+ elif vis.ndim == 2:
179
+ vis = _DEPTH_COLORMAP(vis)
180
+ vis = vis[..., :3]
181
+ if torch_transpose:
182
+ vis = vis.transpose(2, 0, 1)
183
+
184
+ return vis[0,:,:,:]
185
+
186
+
187
+ def flip_lr(image):
188
+ """
189
+ Flip image horizontally
190
+
191
+ Parameters
192
+ ----------
193
+ image : torch.Tensor [B,3,H,W]
194
+ Image to be flipped
195
+
196
+ Returns
197
+ -------
198
+ image_flipped : torch.Tensor [B,3,H,W]
199
+ Flipped image
200
+ """
201
+ assert image.dim() == 4, 'You need to provide a [B,C,H,W] image to flip'
202
+ return torch.flip(image, [3])
203
+
204
+
205
+ def fuse_inv_depth(inv_depth, inv_depth_hat, method='mean'):
206
+ """
207
+ Fuse inverse depth and flipped inverse depth maps
208
+
209
+ Parameters
210
+ ----------
211
+ inv_depth : torch.Tensor [B,1,H,W]
212
+ Inverse depth map
213
+ inv_depth_hat : torch.Tensor [B,1,H,W]
214
+ Flipped inverse depth map produced from a flipped image
215
+ method : str
216
+ Method that will be used to fuse the inverse depth maps
217
+
218
+ Returns
219
+ -------
220
+ fused_inv_depth : torch.Tensor [B,1,H,W]
221
+ Fused inverse depth map
222
+ """
223
+ if method == 'mean':
224
+ return 0.5 * (inv_depth + inv_depth_hat)
225
+ elif method == 'max':
226
+ return torch.max(inv_depth, inv_depth_hat)
227
+ elif method == 'min':
228
+ return torch.min(inv_depth, inv_depth_hat)
229
+ else:
230
+ raise ValueError('Unknown post-process method {}'.format(method))
231
+
232
+
233
+ def post_process_depth(depth, depth_flipped, method='mean'):
234
+ """
235
+ Post-process a depth map together with the prediction from the horizontally flipped input
+
+ Parameters
+ ----------
+ depth : torch.Tensor [B,1,H,W]
+ Depth map predicted from the original image
+ depth_flipped : torch.Tensor [B,1,H,W]
+ Depth map predicted from the horizontally flipped image
+ method : str
+ Method used to fuse the two depth maps
+
+ Returns
+ -------
+ depth_pp : torch.Tensor [B,1,H,W]
+ Post-processed depth map
250
+ """
251
+ B, C, H, W = depth.shape
252
+ inv_depth_hat = flip_lr(depth_flipped)
253
+ inv_depth_fused = fuse_inv_depth(depth, inv_depth_hat, method=method)
254
+ xs = torch.linspace(0., 1., W, device=depth.device,
255
+ dtype=depth.dtype).repeat(B, C, H, 1)
256
+ mask = 1.0 - torch.clamp(20. * (xs - 0.05), 0., 1.)
257
+ mask_hat = flip_lr(mask)
258
+ return mask_hat * depth + mask * inv_depth_hat + \
259
+ (1.0 - mask - mask_hat) * inv_depth_fused
260
+
261
+
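A hedged sketch of flip test-time augmentation with flip_lr and post_process_depth: predict on the image and on its horizontal flip, then fuse the two predictions. fake_model is a stand-in for the real depth network:

    import torch

    # Assumed import path per this commit's layout.
    from iebins.utils import flip_lr, post_process_depth

    def fake_model(x):
        # Placeholder for the real network: returns a positive B,1,H,W map.
        return x.mean(dim=1, keepdim=True) + 1.0

    image = torch.rand(1, 3, 480, 640)
    pred = fake_model(image)                          # prediction on the original image
    pred_flipped = fake_model(flip_lr(image))         # prediction on the flipped image
    pred_tta = post_process_depth(pred, pred_flipped, method='mean')
    print(pred_tta.shape)                             # torch.Size([1, 1, 480, 640])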
262
+ class DistributedSamplerNoEvenlyDivisible(Sampler):
263
+ """Sampler that restricts data loading to a subset of the dataset.
264
+
265
+ It is especially useful in conjunction with
266
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
267
+ process can pass a DistributedSampler instance as a DataLoader sampler,
268
+ and load a subset of the original dataset that is exclusive to it.
269
+
270
+ .. note::
271
+ Dataset is assumed to be of constant size.
272
+
273
+ Arguments:
274
+ dataset: Dataset used for sampling.
275
+ num_replicas (optional): Number of processes participating in
276
+ distributed training.
277
+ rank (optional): Rank of the current process within num_replicas.
278
+ shuffle (optional): If true (default), sampler will shuffle the indices
279
+ """
280
+
281
+ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
282
+ if num_replicas is None:
283
+ if not dist.is_available():
284
+ raise RuntimeError("Requires distributed package to be available")
285
+ num_replicas = dist.get_world_size()
286
+ if rank is None:
287
+ if not dist.is_available():
288
+ raise RuntimeError("Requires distributed package to be available")
289
+ rank = dist.get_rank()
290
+ self.dataset = dataset
291
+ self.num_replicas = num_replicas
292
+ self.rank = rank
293
+ self.epoch = 0
294
+ num_samples = int(math.floor(len(self.dataset) * 1.0 / self.num_replicas))
295
+ rest = len(self.dataset) - num_samples * self.num_replicas
296
+ if self.rank < rest:
297
+ num_samples += 1
298
+ self.num_samples = num_samples
299
+ self.total_size = len(dataset)
300
+ # self.total_size = self.num_samples * self.num_replicas
301
+ self.shuffle = shuffle
302
+
303
+ def __iter__(self):
304
+ # deterministically shuffle based on epoch
305
+ g = torch.Generator()
306
+ g.manual_seed(self.epoch)
307
+ if self.shuffle:
308
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
309
+ else:
310
+ indices = list(range(len(self.dataset)))
311
+
312
+ # add extra samples to make it evenly divisible
313
+ # indices += indices[:(self.total_size - len(indices))]
314
+ # assert len(indices) == self.total_size
315
+
316
+ # subsample
317
+ indices = indices[self.rank:self.total_size:self.num_replicas]
318
+ self.num_samples = len(indices)
319
+ # assert len(indices) == self.num_samples
320
+
321
+ return iter(indices)
322
+
323
+ def __len__(self):
324
+ return self.num_samples
325
+
326
+ def set_epoch(self, epoch):
327
+ self.epoch = epoch
328
+
329
+
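A minimal sketch of plugging DistributedSamplerNoEvenlyDivisible into a DataLoader. Here num_replicas and rank are passed explicitly instead of coming from an initialized process group, and the dataset is a placeholder; set_epoch should be called once per epoch so the shuffle order changes across epochs:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Assumed import path per this commit's layout.
    from iebins.utils import DistributedSamplerNoEvenlyDivisible

    dataset = TensorDataset(torch.arange(10).float())     # placeholder dataset

    sampler = DistributedSamplerNoEvenlyDivisible(dataset, num_replicas=4, rank=1, shuffle=False)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)

    for epoch in range(2):
        sampler.set_epoch(epoch)
        for batch in loader:
            pass  # training step would go here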
330
+ class D_to_cloud(nn.Module):
331
+ """Layer to transform depth into point cloud
332
+ """
333
+ def __init__(self, batch_size, height, width):
334
+ super(D_to_cloud, self).__init__()
335
+
336
+ self.batch_size = batch_size
337
+ self.height = height
338
+ self.width = width
339
+
340
+ meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
341
+ self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32) # 2, H, W
342
+ self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords), requires_grad=False) # 2, H, W
343
+
344
+ self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
345
+ requires_grad=False) # B, 1, H, W
346
+
347
+ self.pix_coords = torch.unsqueeze(torch.stack(
348
+ [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0) # 1, 2, L
349
+ self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1) # B, 2, L
350
+ self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1), requires_grad=False) # B, 3, L
351
+
352
+ def forward(self, depth, inv_K):
353
+ cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
354
+ cam_points = depth.view(self.batch_size, 1, -1) * cam_points
355
+
356
+ return cam_points.permute(0, 2, 1)
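A hedged sketch of back-projecting a depth map with D_to_cloud; the intrinsics are assumed NYU-style pinhole values, and only the top-left 3x3 block of inv_K is used by the layer:

    import torch

    # Assumed import path per this commit's layout.
    from iebins.utils import D_to_cloud

    B, H, W = 1, 480, 640
    K = torch.tensor([[[518.8579, 0.0, 325.5824],
                       [0.0, 519.4696, 253.7362],
                       [0.0, 0.0, 1.0]]])              # assumed NYU-style intrinsics
    inv_K = torch.inverse(K)                           # B, 3, 3

    backproject = D_to_cloud(batch_size=B, height=H, width=W)
    depth = torch.rand(B, 1, H, W) * 10.0
    points = backproject(depth, inv_K)                 # B, H*W, 3 camera-frame points
    print(points.shape)                                # torch.Size([1, 307200, 3])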
iebins/utils/transfrom.py ADDED
@@ -0,0 +1,250 @@
1
+ import random
2
+ from PIL import Image, ImageOps, ImageFilter
3
+ import torch
4
+ from torchvision import transforms
5
+ import torch.nn.functional as F
6
+
7
+ import numpy as np
8
+ import cv2
9
+ import math
10
+
11
+
12
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
13
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
14
+ Args:
15
+ sample (dict): sample
16
+ size (tuple): image size
17
+ Returns:
18
+ tuple: new size
19
+ """
20
+ shape = list(sample["disparity"].shape)
21
+
22
+ if shape[0] >= size[0] and shape[1] >= size[1]:
23
+ return sample
24
+
25
+ scale = [0, 0]
26
+ scale[0] = size[0] / shape[0]
27
+ scale[1] = size[1] / shape[1]
28
+
29
+ scale = max(scale)
30
+
31
+ shape[0] = math.ceil(scale * shape[0])
32
+ shape[1] = math.ceil(scale * shape[1])
33
+
34
+ # resize
35
+ sample["image"] = cv2.resize(
36
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
37
+ )
38
+
39
+ sample["disparity"] = cv2.resize(
40
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
41
+ )
42
+ sample["mask"] = cv2.resize(
43
+ sample["mask"].astype(np.float32),
44
+ tuple(shape[::-1]),
45
+ interpolation=cv2.INTER_NEAREST,
46
+ )
47
+ sample["mask"] = sample["mask"].astype(bool)
48
+
49
+ return tuple(shape)
50
+
51
+
52
+ class Resize(object):
53
+ """Resize sample to given size (width, height).
54
+ """
55
+
56
+ def __init__(
57
+ self,
58
+ width,
59
+ height,
60
+ resize_target=True,
61
+ keep_aspect_ratio=False,
62
+ ensure_multiple_of=1,
63
+ resize_method="lower_bound",
64
+ image_interpolation_method=cv2.INTER_AREA,
65
+ ):
66
+ """Init.
67
+ Args:
68
+ width (int): desired output width
69
+ height (int): desired output height
70
+ resize_target (bool, optional):
71
+ True: Resize the full sample (image, mask, target).
72
+ False: Resize image only.
73
+ Defaults to True.
74
+ keep_aspect_ratio (bool, optional):
75
+ True: Keep the aspect ratio of the input sample.
76
+ Output sample might not have the given width and height, and
77
+ resize behaviour depends on the parameter 'resize_method'.
78
+ Defaults to False.
79
+ ensure_multiple_of (int, optional):
80
+ Output width and height is constrained to be multiple of this parameter.
81
+ Defaults to 1.
82
+ resize_method (str, optional):
83
+ "lower_bound": Output will be at least as large as the given size.
84
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
85
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
86
+ Defaults to "lower_bound".
87
+ """
88
+ self.__width = width
89
+ self.__height = height
90
+
91
+ self.__resize_target = resize_target
92
+ self.__keep_aspect_ratio = keep_aspect_ratio
93
+ self.__multiple_of = ensure_multiple_of
94
+ self.__resize_method = resize_method
95
+ self.__image_interpolation_method = image_interpolation_method
96
+
97
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
98
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
99
+
100
+ if max_val is not None and y > max_val:
101
+ y = (np.floor(x / self.__multiple_of)
102
+ * self.__multiple_of).astype(int)
103
+
104
+ if y < min_val:
105
+ y = (np.ceil(x / self.__multiple_of)
106
+ * self.__multiple_of).astype(int)
107
+
108
+ return y
109
+
110
+ def get_size(self, width, height):
111
+ # determine new height and width
112
+ scale_height = self.__height / height
113
+ scale_width = self.__width / width
114
+
115
+ if self.__keep_aspect_ratio:
116
+ if self.__resize_method == "lower_bound":
117
+ # scale such that output size is lower bound
118
+ if scale_width > scale_height:
119
+ # fit width
120
+ scale_height = scale_width
121
+ else:
122
+ # fit height
123
+ scale_width = scale_height
124
+ elif self.__resize_method == "upper_bound":
125
+ # scale such that output size is upper bound
126
+ if scale_width < scale_height:
127
+ # fit width
128
+ scale_height = scale_width
129
+ else:
130
+ # fit height
131
+ scale_width = scale_height
132
+ elif self.__resize_method == "minimal":
133
+ # scale as little as possible
134
+ if abs(1 - scale_width) < abs(1 - scale_height):
135
+ # fit width
136
+ scale_height = scale_width
137
+ else:
138
+ # fit height
139
+ scale_width = scale_height
140
+ else:
141
+ raise ValueError(
142
+ f"resize_method {self.__resize_method} not implemented"
143
+ )
144
+
145
+ if self.__resize_method == "lower_bound":
146
+ new_height = self.constrain_to_multiple_of(
147
+ scale_height * height, min_val=self.__height
148
+ )
149
+ new_width = self.constrain_to_multiple_of(
150
+ scale_width * width, min_val=self.__width
151
+ )
152
+ elif self.__resize_method == "upper_bound":
153
+ new_height = self.constrain_to_multiple_of(
154
+ scale_height * height, max_val=self.__height
155
+ )
156
+ new_width = self.constrain_to_multiple_of(
157
+ scale_width * width, max_val=self.__width
158
+ )
159
+ elif self.__resize_method == "minimal":
160
+ new_height = self.constrain_to_multiple_of(scale_height * height)
161
+ new_width = self.constrain_to_multiple_of(scale_width * width)
162
+ else:
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
165
+
166
+ return (new_width, new_height)
167
+
168
+ def __call__(self, sample):
169
+ width, height = self.get_size(
170
+ sample["image"].shape[1], sample["image"].shape[0]
171
+ )
172
+
173
+ # resize sample
174
+ sample["image"] = cv2.resize(
175
+ sample["image"],
176
+ (width, height),
177
+ interpolation=self.__image_interpolation_method,
178
+ )
179
+
180
+ if self.__resize_target:
181
+ if "disparity" in sample:
182
+ sample["disparity"] = cv2.resize(
183
+ sample["disparity"],
184
+ (width, height),
185
+ interpolation=cv2.INTER_NEAREST,
186
+ )
187
+
188
+ if "depth" in sample:
189
+ sample["depth"] = cv2.resize(
190
+ sample["depth"], (width,
191
+ height), interpolation=cv2.INTER_NEAREST
192
+ )
193
+
194
+ if "semseg_mask" in sample:
195
+ # sample["semseg_mask"] = cv2.resize(
196
+ # sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
197
+ # )
198
+ sample["semseg_mask"] = F.interpolate(torch.from_numpy(sample["semseg_mask"]).float()[
199
+ None, None, ...], (height, width), mode='nearest').numpy()[0, 0]
200
+
201
+ if "mask" in sample:
202
+ sample["mask"] = cv2.resize(
203
+ sample["mask"].astype(np.float32),
204
+ (width, height),
205
+ interpolation=cv2.INTER_NEAREST,
206
+ )
207
+ # sample["mask"] = sample["mask"].astype(bool)
208
+
209
+ # print(sample['image'].shape, sample['depth'].shape)
210
+ return sample
211
+
212
+
213
+ class NormalizeImage(object):
214
+ """Normlize image by given mean and std.
215
+ """
216
+
217
+ def __init__(self, mean, std):
218
+ self.__mean = mean
219
+ self.__std = std
220
+
221
+ def __call__(self, sample):
222
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
223
+
224
+ return sample
225
+
226
+
227
+ class PrepareForNet(object):
228
+ """Prepare sample for usage as network input.
229
+ """
230
+
231
+ def __init__(self):
232
+ pass
233
+
234
+ def __call__(self, sample):
235
+ image = np.transpose(sample["image"], (2, 0, 1))
236
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
237
+
238
+ if "mask" in sample:
239
+ sample["mask"] = sample["mask"].astype(np.float32)
240
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
241
+
242
+ if "depth" in sample:
243
+ depth = sample["depth"].astype(np.float32)
244
+ sample["depth"] = np.ascontiguousarray(depth)
245
+
246
+ if "semseg_mask" in sample:
247
+ sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
248
+ sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])
249
+
250
+ return sample
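To make the Resize geometry concrete, a small worked example of get_size alone, using an assumed 518x518, multiple-of-14, lower_bound configuration: a 640x480 (W x H) input is scaled by 518/480 so its shorter side covers the target, then both sides are rounded to multiples of 14, giving (686, 518):

    import cv2

    # Import path matches this commit's layout (iebins/utils/transfrom.py).
    from iebins.utils.transfrom import Resize

    resize = Resize(width=518, height=518, resize_target=False, keep_aspect_ratio=True,
                    ensure_multiple_of=14, resize_method='lower_bound',
                    image_interpolation_method=cv2.INTER_CUBIC)

    print(resize.get_size(640, 480))   # -> (686, 518)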
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ torch>=1.10.0
+ torchvision
+ matplotlib
+ tqdm
+ tensorboardX
+ timm
+ mmcv
+ open3d
+ gradio_imageslider
+ opencv-python
+ # Note: the original environment pinned conda packages (pytorch=1.10.0, cudatoolkit=11.1),
+ # which pip cannot install; a CUDA-enabled torch build is expected from the base environment instead.