GonzaloMG committed on
Commit 5c08b12
1 Parent(s): 0db2418

Delete Marigold

Marigold/README.md DELETED
@@ -1 +0,0 @@
- Code is copied from https://github.com/prs-eth/Marigold. Modifications are indicated within the code.
 
 
Marigold/marigold/__init__.py DELETED
@@ -1,21 +0,0 @@
- # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # --------------------------------------------------------------------------
- # If you find this code useful, we kindly ask you to cite our paper in your work.
- # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
- # More information about the method can be found at https://marigoldmonodepth.github.io
- # --------------------------------------------------------------------------
-
-
- from .marigold_pipeline import MarigoldPipeline, MarigoldDepthOutput  # noqa: F401
 
Marigold/marigold/marigold_pipeline.py DELETED
@@ -1,534 +0,0 @@
- # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # --------------------------------------------------------------------------
- # If you find this code useful, we kindly ask you to cite our paper in your work.
- # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
- # More information about the method can be found at https://marigoldmonodepth.github.io
- # --------------------------------------------------------------------------
-
- # @GonzaloMartinGarcia
- # This file is a modified version of the original Marigold pipeline file.
- # Based on GeoWizard, we added the option to sample surface normals, marked with # add.
-
- from typing import Dict, Union
-
- import numpy as np
- import torch
- from diffusers import (
-     AutoencoderKL,
-     DDIMScheduler,
-     DiffusionPipeline,
-     LCMScheduler,
-     UNet2DConditionModel,
-     DDPMScheduler,
- )
- from diffusers.utils import BaseOutput
- from PIL import Image
- from torchvision.transforms.functional import resize, pil_to_tensor
- from torchvision.transforms import InterpolationMode
- from torch.utils.data import DataLoader, TensorDataset
- from tqdm.auto import tqdm
- from transformers import CLIPTextModel, CLIPTokenizer
-
- from .util.batchsize import find_batch_size
- from .util.ensemble import ensemble_depths
- from .util.image_util import (
-     chw2hwc,
-     colorize_depth_maps,
-     get_tv_resample_method,
-     resize_max_res,
- )
-
- # add
- import random
-
-
- # add
- # Surface normals ensemble from the GeoWizard GitHub repository (https://github.com/fuxiao0719/GeoWizard)
- def ensemble_normals(input_images: torch.Tensor):
-     normal_preds = input_images
-     bsz, d, h, w = normal_preds.shape
-     normal_preds = normal_preds / (torch.norm(normal_preds, p=2, dim=1).unsqueeze(1) + 1e-5)
-     phi = torch.atan2(normal_preds[:, 1, :, :], normal_preds[:, 0, :, :]).mean(dim=0)
-     theta = torch.atan2(torch.norm(normal_preds[:, :2, :, :], p=2, dim=1), normal_preds[:, 2, :, :]).mean(dim=0)
-     normal_pred = torch.zeros((d, h, w)).to(normal_preds)
-     normal_pred[0, :, :] = torch.sin(theta) * torch.cos(phi)
-     normal_pred[1, :, :] = torch.sin(theta) * torch.sin(phi)
-     normal_pred[2, :, :] = torch.cos(theta)
-     angle_error = torch.acos(torch.clip(torch.cosine_similarity(normal_pred[None], normal_preds, dim=1), -0.999, 0.999))
-     normal_idx = torch.argmin(angle_error.reshape(bsz, -1).sum(-1))
-     return normal_preds[normal_idx], None
-
- # add
- # Pyramid noise from
- # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2?s=31
- def pyramid_noise_like(x, discount=0.9):
-     b, c, w, h = x.shape
-     u = torch.nn.Upsample(size=(w, h), mode='bilinear')
-     noise = torch.randn_like(x)
-     for i in range(10):
-         r = random.random() * 2 + 2
-         w, h = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
-         noise += u(torch.randn(b, c, w, h).to(x)) * discount**i
-         if w == 1 or h == 1:
-             break
-     return noise / noise.std()
-
- class MarigoldDepthOutput(BaseOutput):
-     """
-     Output class for Marigold monocular depth prediction pipeline.
-
-     Args:
-         depth_np (`np.ndarray`):
-             Predicted depth map, with depth values in the range of [0, 1].
-         depth_colored (`PIL.Image.Image`):
-             Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
-         uncertainty (`None` or `np.ndarray`):
-             Uncalibrated uncertainty (MAD, median absolute deviation) coming from ensembling.
-         normal_np (`np.ndarray`):
-             Predicted normal map, with normal vectors in the range of [-1, 1].
-         normal_colored (`PIL.Image.Image`):
-             Colorized normal map.
-     """
-
-     depth_np: np.ndarray
-     depth_colored: Union[None, Image.Image]
-     uncertainty: Union[None, np.ndarray]
-     # add
-     normal_np: np.ndarray
-     normal_colored: Union[None, Image.Image]
-
-
- class MarigoldPipeline(DiffusionPipeline):
-     """
-     Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
-
-     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
-     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
-
-     Args:
-         unet (`UNet2DConditionModel`):
-             Conditional U-Net to denoise the depth latent, conditioned on the image latent.
-         vae (`AutoencoderKL`):
-             Variational Auto-Encoder (VAE) model to encode and decode images and depth maps
-             to and from latent representations.
-         scheduler (`DDIMScheduler`):
-             A scheduler to be used in combination with `unet` to denoise the encoded image latents.
-         text_encoder (`CLIPTextModel`):
-             Text-encoder, for empty text embedding.
-         tokenizer (`CLIPTokenizer`):
-             CLIP tokenizer.
-     """
-
-     rgb_latent_scale_factor = 0.18215
-     depth_latent_scale_factor = 0.18215
-
-     def __init__(
-         self,
-         unet: UNet2DConditionModel,
-         vae: AutoencoderKL,
-         scheduler: Union[DDIMScheduler, DDPMScheduler, LCMScheduler],
-         text_encoder: CLIPTextModel,
-         tokenizer: CLIPTokenizer,
-     ):
-         super().__init__()
-
-         self.register_modules(
-             unet=unet,
-             vae=vae,
-             scheduler=scheduler,
-             text_encoder=text_encoder,
-             tokenizer=tokenizer,
-         )
-
-         self.empty_text_embed = None
-
-     @torch.no_grad()
-     def __call__(
-         self,
-         input_image: Union[Image.Image, torch.Tensor],
-         denoising_steps: int = 10,
-         ensemble_size: int = 10,
-         processing_res: int = 768,
-         match_input_res: bool = True,
-         resample_method: str = "bilinear",
-         batch_size: int = 0,
-         color_map: str = "Spectral",
-         show_progress_bar: bool = True,
-         ensemble_kwargs: Dict = None,
-         # add
-         noise="gaussian",
-         normals=False,
-     ) -> MarigoldDepthOutput:
-         """
-         Function invoked when calling the pipeline.
-
-         Args:
-             input_image (`Image`):
-                 Input RGB (or gray-scale) image.
-             processing_res (`int`, *optional*, defaults to `768`):
-                 Maximum resolution of processing.
-                 If set to 0: will not resize at all.
-             match_input_res (`bool`, *optional*, defaults to `True`):
-                 Resize depth prediction to match input resolution.
-                 Only valid if `processing_res` > 0.
-             resample_method (`str`, *optional*, defaults to `bilinear`):
-                 Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`.
-             denoising_steps (`int`, *optional*, defaults to `10`):
-                 Number of diffusion denoising steps (DDIM) during inference.
-             ensemble_size (`int`, *optional*, defaults to `10`):
-                 Number of predictions to be ensembled.
-             batch_size (`int`, *optional*, defaults to `0`):
-                 Inference batch size, no bigger than `ensemble_size`.
-                 If set to 0, the script will automatically decide the proper batch size.
-             show_progress_bar (`bool`, *optional*, defaults to `True`):
-                 Display a progress bar of diffusion denoising.
-             color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
-                 Colormap used to colorize the depth map.
-             ensemble_kwargs (`dict`, *optional*, defaults to `None`):
-                 Arguments for detailed ensembling settings.
-             noise (`str`, *optional*, defaults to `gaussian`):
-                 Type of noise to be used for the initial depth map.
-                 Can be one of `gaussian`, `pyramid`, `zeros`.
-             normals (`bool`, *optional*, defaults to `False`):
-                 If `True`, the pipeline will predict surface normals instead of depth maps.
-         Returns:
-             `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
-             - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
-             - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`
-             - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty (MAD, median absolute deviation)
-               coming from ensembling. None if `ensemble_size = 1`
-             - **normal_np** (`np.ndarray`) Predicted normal map, with normal vectors in the range of [-1, 1]
-             - **normal_colored** (`PIL.Image.Image`) Colorized normal map
-         """
-
-         assert processing_res >= 0
-         assert ensemble_size >= 1
-
-         resample_method: InterpolationMode = get_tv_resample_method(resample_method)
-
-         # ----------------- Image Preprocess -----------------
-
-         # Convert to torch tensor
-         if isinstance(input_image, Image.Image):
-             input_image = input_image.convert("RGB")
-             rgb = pil_to_tensor(input_image)  # [H, W, rgb] -> [rgb, H, W]
-         elif isinstance(input_image, torch.Tensor):
-             rgb = input_image.squeeze()
-         else:
-             raise TypeError(f"Unknown input type: {type(input_image) = }")
-         input_size = rgb.shape
-         assert (
-             3 == rgb.dim() and 3 == input_size[0]
-         ), f"Wrong input shape {input_size}, expected [rgb, H, W]"
-
-         # Resize image
-         if processing_res > 0:
-             rgb = resize_max_res(
-                 rgb,
-                 max_edge_resolution=processing_res,
-                 resample_method=resample_method,
-             )
-
-         # Normalize rgb values
-         rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
-         rgb_norm = rgb_norm.to(self.dtype)
-         assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
-
-         # ----------------- Predicting depth/normal --------------
-
-         # Batch repeated input image
-         duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
-         single_rgb_dataset = TensorDataset(duplicated_rgb)
-         if batch_size > 0:
-             _bs = batch_size
-         else:
-             _bs = find_batch_size(
-                 ensemble_size=ensemble_size,
-                 input_res=max(rgb_norm.shape[1:]),
-                 dtype=self.dtype,
-             )
-
-         single_rgb_loader = DataLoader(
-             single_rgb_dataset, batch_size=_bs, shuffle=False
-         )
-
-         # load iterator
-         pred_ls = []
-         if show_progress_bar:
-             iterable = tqdm(
-                 single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
-             )
-         else:
-             iterable = single_rgb_loader
-
-         # inference (batched)
-         for batch in iterable:
-             (batched_img,) = batch
-             pred_raw = self.single_infer(
-                 rgb_in=batched_img,
-                 num_inference_steps=denoising_steps,
-                 show_pbar=show_progress_bar,
-                 # add
-                 noise=noise,
-                 normals=normals,
-             )
-             pred_ls.append(pred_raw.detach())
-         preds = torch.concat(pred_ls, dim=0).squeeze()
-         torch.cuda.empty_cache()  # clear vram cache for ensembling
-
-         # ----------------- Test-time ensembling -----------------
-
-         if ensemble_size > 1:  # add
-             pred, pred_uncert = ensemble_normals(preds) if normals else ensemble_depths(preds, **(ensemble_kwargs or {}))
-         else:
-             pred = preds
-             pred_uncert = None
-
-         # ----------------- Post processing -----------------
-
-         if normals:
-             # add
-             # Normalize normal vectors to unit length
-             pred /= (torch.norm(pred, p=2, dim=0, keepdim=True) + 1e-5)
-         else:
-             # Scale relative prediction to [0, 1]
-             min_d = torch.min(pred)
-             max_d = torch.max(pred)
-             if max_d == min_d:
-                 pred = torch.zeros_like(pred)
-             else:
-                 pred = (pred - min_d) / (max_d - min_d)
-
-         # Resize back to original resolution
-         if match_input_res:
-             pred = resize(
-                 pred if normals else pred.unsqueeze(0),
-                 (input_size[-2], input_size[-1]),
-                 interpolation=resample_method,
-                 antialias=True,
-             ).squeeze()
-
-         # Convert to numpy
-         pred = pred.cpu().numpy()
-
-         # Process prediction for visualization
-         if not normals:
-             # add
-             pred = pred.clip(0, 1)
-             if color_map is not None:
-                 colored = colorize_depth_maps(
-                     pred, 0, 1, cmap=color_map
-                 ).squeeze()  # [3, H, W], value in (0, 1)
-                 colored = (colored * 255).astype(np.uint8)
-                 colored_hwc = chw2hwc(colored)
-                 colored_img = Image.fromarray(colored_hwc)
-             else:
-                 colored_img = None
-         else:
-             pred = pred.clip(-1.0, 1.0)
-             colored = (((pred + 1) / 2) * 255).astype(np.uint8)
-             colored_hwc = chw2hwc(colored)
-             colored_img = Image.fromarray(colored_hwc)
-
-         return MarigoldDepthOutput(
-             depth_np = pred if not normals else None,
-             depth_colored = colored_img if not normals else None,
-             uncertainty = pred_uncert,
-             # add
-             normal_np = pred if normals else None,
-             normal_colored = colored_img if normals else None,
-         )
-
-     def encode_empty_text(self):
-         """
-         Encode text embedding for empty prompt.
-         """
-         prompt = ""
-         text_inputs = self.tokenizer(
-             prompt,
-             padding="do_not_pad",
-             max_length=self.tokenizer.model_max_length,
-             truncation=True,
-             return_tensors="pt",
-         )
-         text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
-         self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
-
-     @torch.no_grad()
-     def single_infer(
-         self,
-         rgb_in: torch.Tensor,
-         num_inference_steps: int,
-         show_pbar: bool,
-         # add
-         noise="gaussian",
-         normals=False,
-     ) -> torch.Tensor:
-         """
-         Perform an individual depth prediction without ensembling.
-
-         Args:
-             rgb_in (`torch.Tensor`):
-                 Input RGB image.
-             num_inference_steps (`int`):
-                 Number of diffusion denoising steps (DDIM) during inference.
-             show_pbar (`bool`):
-                 Display a progress bar of diffusion denoising.
-             noise (`str`, *optional*, defaults to `gaussian`):
-                 Type of noise to be used for the initial depth map.
-                 Can be one of `gaussian`, `pyramid`, `zeros`.
-         Returns:
-             `torch.Tensor`: Predicted depth map.
-         """
-         device = self.device
-         rgb_in = rgb_in.to(device)
-
-         # Set timesteps
-         self.scheduler.set_timesteps(num_inference_steps, device=device)
-         timesteps = self.scheduler.timesteps  # [T]
-
-         # Encode image
-         rgb_latent = self.encode_rgb(rgb_in)
-
-         # add
-         # Initial prediction
-         latent_shape = rgb_latent.shape
-         if noise == "gaussian":
-             latent = torch.randn(
-                 latent_shape,
-                 device=device,
-                 dtype=self.dtype,
-             )
-         elif noise == "pyramid":
-             latent = pyramid_noise_like(rgb_latent).to(device)  # [B, 4, h, w]
-         elif noise == "zeros":
-             latent = torch.zeros(
-                 latent_shape,
-                 device=device,
-                 dtype=self.dtype,
-             )
-         else:
-             raise ValueError(f"Unknown noise type: {noise}")
-
-         # Batched empty text embedding
-         if self.empty_text_embed is None:
-             self.encode_empty_text()
-         batch_empty_text_embed = self.empty_text_embed.repeat(
-             (rgb_latent.shape[0], 1, 1)
-         )  # [B, 2, 1024]
-
-         # Denoising loop
-         if show_pbar:
-             iterable = tqdm(
-                 enumerate(timesteps),
-                 total=len(timesteps),
-                 leave=False,
-                 desc=" " * 4 + "Diffusion denoising",
-             )
-         else:
-             iterable = enumerate(timesteps)
-
-         for i, t in iterable:
-             unet_input = torch.cat(
-                 [rgb_latent, latent], dim=1
-             )  # this order is important
-
-             # predict the noise residual
-             noise_pred = self.unet(
-                 unet_input, t, encoder_hidden_states=batch_empty_text_embed
-             ).sample  # [B, 4, h, w]
-
-             # compute the previous noisy sample x_t -> x_t-1
-             scheduler_step = self.scheduler.step(
-                 noise_pred, t, latent
-             )
-
-             latent = scheduler_step.prev_sample
-
-         if normals:
-             # add
-             # decode and normalize normal vectors
-             normal = self.decode_normal(latent)
-             normal /= (torch.norm(normal, p=2, dim=1, keepdim=True) + 1e-5)
-             return normal
-         else:
-             # decode and normalize depth map
-             depth = self.decode_depth(latent)
-             depth = torch.clip(depth, -1.0, 1.0)
-             depth = (depth + 1.0) / 2.0
-             return depth
-
-     def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
-         """
-         Encode RGB image into latent.
-
-         Args:
-             rgb_in (`torch.Tensor`):
-                 Input RGB image to be encoded.
-
-         Returns:
-             `torch.Tensor`: Image latent.
-         """
-         # encode
-         h = self.vae.encoder(rgb_in)
-         moments = self.vae.quant_conv(h)
-         mean, logvar = torch.chunk(moments, 2, dim=1)
-         # scale latent
-         rgb_latent = mean * self.rgb_latent_scale_factor
-         return rgb_latent
-
-     def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
-         """
-         Decode depth latent into depth map.
-
-         Args:
-             depth_latent (`torch.Tensor`):
-                 Depth latent to be decoded.
-
-         Returns:
-             `torch.Tensor`: Decoded depth map.
-         """
-         # scale latent
-         depth_latent = depth_latent / self.depth_latent_scale_factor
-         # decode
-         z = self.vae.post_quant_conv(depth_latent)
-         stacked = self.vae.decoder(z)
-         # mean of output channels
-         depth_mean = stacked.mean(dim=1, keepdim=True)
-         return depth_mean
-
-     # add
-     def decode_normal(self, normal_latent: torch.Tensor) -> torch.Tensor:
-         """
-         Decode normal latent into normal map.
-
-         Args:
-             normal_latent (`torch.Tensor`):
-                 Normal latent to be decoded.
-
-         Returns:
-             `torch.Tensor`: Decoded normal map.
-         """
-         # scale latent
-         normal_latent = normal_latent / self.depth_latent_scale_factor
-         # decode
-         z = self.vae.post_quant_conv(normal_latent)
-         normal = self.vae.decoder(z)
-         return normal
 
Marigold/marigold/util/__init__.py DELETED
File without changes
Marigold/marigold/util/batchsize.py DELETED
@@ -1,81 +0,0 @@
- # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # --------------------------------------------------------------------------
- # If you find this code useful, we kindly ask you to cite our paper in your work.
- # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
- # More information about the method can be found at https://marigoldmonodepth.github.io
- # --------------------------------------------------------------------------
-
-
- import torch
- import math
-
-
- # Search table for suggested max. inference batch size
- bs_search_table = [
-     # tested on A100-PCIE-80GB
-     {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
-     {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
-     # tested on A100-PCIE-40GB
-     {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
-     {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
-     {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
-     {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
-     # tested on RTX3090, RTX4090
-     {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
-     {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
-     {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
-     {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
-     {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
-     {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
-     # tested on GTX1080Ti
-     {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
-     {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
-     {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
-     {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
-     {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
- ]
-
-
- def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
-     """
-     Automatically search for a suitable operating batch size.
-
-     Args:
-         ensemble_size (`int`):
-             Number of predictions to be ensembled.
-         input_res (`int`):
-             Operating resolution of the input image.
-         dtype (`torch.dtype`):
-             Operating data type.
-
-     Returns:
-         `int`: Operating batch size.
-     """
-     if not torch.cuda.is_available():
-         return 1
-
-     total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
-     filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
-     for settings in sorted(
-         filtered_bs_search_table,
-         key=lambda k: (k["res"], -k["total_vram"]),
-     ):
-         if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
-             bs = settings["bs"]
-             if bs > ensemble_size:
-                 bs = ensemble_size
-             elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
-                 bs = math.ceil(ensemble_size / 2)
-             return bs
-
-     return 1
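
A short sketch of how `find_batch_size` is consulted; this mirrors the call inside `__call__` above, and the numbers are illustrative:

import torch
from Marigold.marigold.util.batchsize import find_batch_size  # assumes the repo layout above

# Falls back to 1 when CUDA is unavailable; otherwise looks up the VRAM table,
# capping the result at the ensemble size.
bs = find_batch_size(ensemble_size=10, input_res=768, dtype=torch.float16)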
 
Marigold/marigold/util/ensemble.py DELETED
@@ -1,132 +0,0 @@
- # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # --------------------------------------------------------------------------
- # If you find this code useful, we kindly ask you to cite our paper in your work.
- # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
- # More information about the method can be found at https://marigoldmonodepth.github.io
- # --------------------------------------------------------------------------
-
-
- import numpy as np
- import torch
-
- from scipy.optimize import minimize
-
-
- def inter_distances(tensors: torch.Tensor):
-     """
-     Calculate the distance between every pair of depth maps.
-     """
-     distances = []
-     for i, j in torch.combinations(torch.arange(tensors.shape[0])):
-         arr1 = tensors[i : i + 1]
-         arr2 = tensors[j : j + 1]
-         distances.append(arr1 - arr2)
-     dist = torch.concatenate(distances, dim=0)
-     return dist
-
-
- def ensemble_depths(
-     input_images: torch.Tensor,
-     regularizer_strength: float = 0.02,
-     max_iter: int = 2,
-     tol: float = 1e-3,
-     reduction: str = "median",
-     max_res: int = None,
- ):
-     """
-     Ensemble multiple affine-invariant depth images (up to scale and shift)
-     by jointly estimating a per-image scale and shift.
-     """
-     device = input_images.device
-     dtype = input_images.dtype
-     np_dtype = np.float32
-
-     original_input = input_images.clone()
-     n_img = input_images.shape[0]
-     ori_shape = input_images.shape
-
-     if max_res is not None:
-         scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:]))
-         if scale_factor < 1:
-             downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
-             input_images = downscaler(input_images)
-
-     # init guess
-     _min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
-     _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
-     s_init = 1.0 / (_max - _min).reshape((-1, 1, 1))
-     t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1))
-     x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype)
-
-     input_images = input_images.to(device)
-
-     # objective function
-     def closure(x):
-         len_x = len(x)
-         s = x[: int(len_x / 2)]
-         t = x[int(len_x / 2) :]
-         s = torch.from_numpy(s).to(dtype=dtype).to(device)
-         t = torch.from_numpy(t).to(dtype=dtype).to(device)
-
-         transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1))
-         dists = inter_distances(transformed_arrays)
-         sqrt_dist = torch.sqrt(torch.mean(dists**2))
-
-         if "mean" == reduction:
-             pred = torch.mean(transformed_arrays, dim=0)
-         elif "median" == reduction:
-             pred = torch.median(transformed_arrays, dim=0).values
-         else:
-             raise ValueError
-
-         near_err = torch.sqrt((0 - torch.min(pred)) ** 2)
-         far_err = torch.sqrt((1 - torch.max(pred)) ** 2)
-
-         err = sqrt_dist + (near_err + far_err) * regularizer_strength
-         err = err.detach().cpu().numpy().astype(np_dtype)
-         return err
-
-     res = minimize(
-         closure, x, method="BFGS", tol=tol, options={"maxiter": max_iter, "disp": False}
-     )
-     x = res.x
-     len_x = len(x)
-     s = x[: int(len_x / 2)]
-     t = x[int(len_x / 2) :]
-
-     # Prediction
-     s = torch.from_numpy(s).to(dtype=dtype).to(device)
-     t = torch.from_numpy(t).to(dtype=dtype).to(device)
-     transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1)
-     if "mean" == reduction:
-         aligned_images = torch.mean(transformed_arrays, dim=0)
-         std = torch.std(transformed_arrays, dim=0)
-         uncertainty = std
-     elif "median" == reduction:
-         aligned_images = torch.median(transformed_arrays, dim=0).values
-         # MAD (median absolute deviation) as uncertainty indicator
-         abs_dev = torch.abs(transformed_arrays - aligned_images)
-         mad = torch.median(abs_dev, dim=0).values
-         uncertainty = mad
-     else:
-         raise ValueError(f"Unknown reduction method: {reduction}")
-
-     # Scale and shift to [0, 1]
-     _min = torch.min(aligned_images)
-     _max = torch.max(aligned_images)
-     aligned_images = (aligned_images - _min) / (_max - _min)
-     uncertainty /= _max - _min
-
-     return aligned_images, uncertainty
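
A minimal sketch of what `ensemble_depths` consumes and returns, using synthetic affine-shifted copies of one depth map; the shapes and values are illustrative assumptions:

import torch
from Marigold.marigold.util.ensemble import ensemble_depths  # assumes the repo layout above

base = torch.rand(64, 64)
# Three affine-invariant predictions of the same scene, each with its own scale/shift.
preds = torch.stack([0.5 * base + 0.1, 2.0 * base - 0.3, base])  # [N, H, W]

aligned, uncertainty = ensemble_depths(preds, reduction="median")
# aligned: [H, W] depth rescaled to [0, 1]; uncertainty: per-pixel MAD of the same shape.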
 
Marigold/marigold/util/image_util.py DELETED
@@ -1,121 +0,0 @@
- # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
- # Last modified: 2024-04-16
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # --------------------------------------------------------------------------
- # If you find this code useful, we kindly ask you to cite our paper in your work.
- # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
- # More information about the method can be found at https://marigoldmonodepth.github.io
- # --------------------------------------------------------------------------
-
-
- import matplotlib
- import numpy as np
- import torch
- from torchvision.transforms import InterpolationMode
- from torchvision.transforms.functional import resize
-
-
- def colorize_depth_maps(
-     depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
- ):
-     """
-     Colorize depth maps.
-     """
-     assert len(depth_map.shape) >= 2, "Invalid dimension"
-
-     if isinstance(depth_map, torch.Tensor):
-         depth = depth_map.detach().squeeze().numpy()
-     elif isinstance(depth_map, np.ndarray):
-         depth = depth_map.copy().squeeze()
-     # reshape to [ (B,) H, W ]
-     if depth.ndim < 3:
-         depth = depth[np.newaxis, :, :]
-
-     # colorize
-     cm = matplotlib.colormaps[cmap]
-     depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
-     img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3]  # value from 0 to 1
-     img_colored_np = np.rollaxis(img_colored_np, 3, 1)
-
-     if valid_mask is not None:
-         if isinstance(depth_map, torch.Tensor):
-             valid_mask = valid_mask.detach().numpy()
-         valid_mask = valid_mask.squeeze()  # [H, W] or [B, H, W]
-         if valid_mask.ndim < 3:
-             valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
-         else:
-             valid_mask = valid_mask[:, np.newaxis, :, :]
-         valid_mask = np.repeat(valid_mask, 3, axis=1)
-         img_colored_np[~valid_mask] = 0
-
-     if isinstance(depth_map, torch.Tensor):
-         img_colored = torch.from_numpy(img_colored_np).float()
-     elif isinstance(depth_map, np.ndarray):
-         img_colored = img_colored_np
-
-     return img_colored
-
-
- def chw2hwc(chw):
-     assert 3 == len(chw.shape)
-     if isinstance(chw, torch.Tensor):
-         hwc = torch.permute(chw, (1, 2, 0))
-     elif isinstance(chw, np.ndarray):
-         hwc = np.moveaxis(chw, 0, -1)
-     return hwc
-
-
- def resize_max_res(
-     img: torch.Tensor,
-     max_edge_resolution: int,
-     resample_method: InterpolationMode = InterpolationMode.BILINEAR,
- ) -> torch.Tensor:
-     """
-     Resize image to limit maximum edge length while keeping aspect ratio.
-
-     Args:
-         img (`torch.Tensor`):
-             Image tensor to be resized.
-         max_edge_resolution (`int`):
-             Maximum edge length (pixel).
-         resample_method (`InterpolationMode`):
-             Resampling method used to resize images.
-
-     Returns:
-         `torch.Tensor`: Resized image.
-     """
-     assert 3 == img.dim()
-     _, original_height, original_width = img.shape
-     downscale_factor = min(
-         max_edge_resolution / original_width, max_edge_resolution / original_height
-     )
-
-     new_width = int(original_width * downscale_factor)
-     new_height = int(original_height * downscale_factor)
-
-     resized_img = resize(img, (new_height, new_width), resample_method, antialias=True)
-     return resized_img
-
-
- def get_tv_resample_method(method_str: str) -> InterpolationMode:
-     resample_method_dict = {
-         "bilinear": InterpolationMode.BILINEAR,
-         "bicubic": InterpolationMode.BICUBIC,
-         "nearest": InterpolationMode.NEAREST_EXACT,
-     }
-     resample_method = resample_method_dict.get(method_str, None)
-     if resample_method is None:
-         raise ValueError(f"Unknown resampling method: {method_str}")
-     else:
-         return resample_method
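
Finally, a small sketch of the visualization helpers above, matching how the pipeline itself uses them; the random input is illustrative:

import numpy as np
from PIL import Image
from Marigold.marigold.util.image_util import chw2hwc, colorize_depth_maps

depth = np.random.rand(48, 64)  # depth map with values in [0, 1]
colored = colorize_depth_maps(depth, 0, 1, cmap="Spectral").squeeze()  # [3, H, W] in [0, 1]
img = Image.fromarray(chw2hwc((colored * 255).astype(np.uint8)))       # [H, W, 3] uint8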