Spaces: Running on Zero
Upload 8 files
- .gitattributes +1 -0
- Marigold/README.md +1 -0
- Marigold/marigold/__init__.py +21 -0
- Marigold/marigold/marigold_pipeline.py +534 -0
- Marigold/marigold/util/__init__.py +0 -0
- Marigold/marigold/util/batchsize.py +81 -0
- Marigold/marigold/util/ensemble.py +132 -0
- Marigold/marigold/util/image_util.py +121 -0
- assets/examples/mushrooms.png +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/examples/mushrooms.png filter=lfs diff=lfs merge=lfs -text
Marigold/README.md
ADDED
@@ -0,0 +1 @@
Code is copied from https://github.com/prs-eth/Marigold. Modifications are indicated within the code.
Marigold/marigold/__init__.py
ADDED
@@ -0,0 +1,21 @@
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------


from .marigold_pipeline import MarigoldPipeline, MarigoldDepthOutput  # noqa: F401
Marigold/marigold/marigold_pipeline.py
ADDED
@@ -0,0 +1,534 @@
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------

# @GonzaloMartinGarcia
# This file is a modified version of the original Marigold pipeline file.
# Based on GeoWizard, we added the option to sample surface normals, marked with # add.
from typing import Dict, Union

import numpy as np
import torch
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LCMScheduler,
    UNet2DConditionModel,
    DDPMScheduler,
)
from diffusers.utils import BaseOutput
from PIL import Image
from torchvision.transforms.functional import resize, pil_to_tensor
from torchvision.transforms import InterpolationMode
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from .util.batchsize import find_batch_size
from .util.ensemble import ensemble_depths
from .util.image_util import (
    chw2hwc,
    colorize_depth_maps,
    get_tv_resample_method,
    resize_max_res,
)

# add
import random


# add
# Surface normals ensemble from the GeoWizard GitHub repository (https://github.com/fuxiao0719/GeoWizard)
def ensemble_normals(input_images: torch.Tensor):
    normal_preds = input_images
    bsz, d, h, w = normal_preds.shape
    normal_preds = normal_preds / (torch.norm(normal_preds, p=2, dim=1).unsqueeze(1) + 1e-5)
    phi = torch.atan2(normal_preds[:, 1, :, :], normal_preds[:, 0, :, :]).mean(dim=0)
    theta = torch.atan2(torch.norm(normal_preds[:, :2, :, :], p=2, dim=1), normal_preds[:, 2, :, :]).mean(dim=0)
    normal_pred = torch.zeros((d, h, w)).to(normal_preds)
    normal_pred[0, :, :] = torch.sin(theta) * torch.cos(phi)
    normal_pred[1, :, :] = torch.sin(theta) * torch.sin(phi)
    normal_pred[2, :, :] = torch.cos(theta)
    angle_error = torch.acos(torch.clip(torch.cosine_similarity(normal_pred[None], normal_preds, dim=1), -0.999, 0.999))
    normal_idx = torch.argmin(angle_error.reshape(bsz, -1).sum(-1))
    return normal_preds[normal_idx], None

# add
# Pyramid noise from
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2?s=31
def pyramid_noise_like(x, discount=0.9):
    b, c, w, h = x.shape
    u = torch.nn.Upsample(size=(w, h), mode='bilinear')
    noise = torch.randn_like(x)
    for i in range(10):
        r = random.random() * 2 + 2
        w, h = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
        noise += u(torch.randn(b, c, w, h).to(x)) * discount**i
        if w == 1 or h == 1:
            break
    return noise / noise.std()
class MarigoldDepthOutput(BaseOutput):
    """
    Output class for Marigold monocular depth prediction pipeline.

    Args:
        depth_np (`np.ndarray`):
            Predicted depth map, with depth values in the range of [0, 1].
        depth_colored (`PIL.Image.Image`):
            Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
        uncertainty (`None` or `np.ndarray`):
            Uncalibrated uncertainty (MAD, median absolute deviation) coming from ensembling.
        normal_np (`np.ndarray`):
            Predicted normal map, with normal vectors in the range of [-1, 1].
        normal_colored (`PIL.Image.Image`):
            Colorized normal map.
    """

    depth_np: np.ndarray
    depth_colored: Union[None, Image.Image]
    uncertainty: Union[None, np.ndarray]
    # add
    normal_np: np.ndarray
    normal_colored: Union[None, Image.Image]

class MarigoldPipeline(DiffusionPipeline):
    """
    Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        unet (`UNet2DConditionModel`):
            Conditional U-Net to denoise the depth latent, conditioned on image latent.
        vae (`AutoencoderKL`):
            Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
            to and from latent representations.
        scheduler (`DDIMScheduler`):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        text_encoder (`CLIPTextModel`):
            Text-encoder, for empty text embedding.
        tokenizer (`CLIPTokenizer`):
            CLIP tokenizer.
    """

    rgb_latent_scale_factor = 0.18215
    depth_latent_scale_factor = 0.18215

    def __init__(
        self,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        scheduler: Union[DDIMScheduler, DDPMScheduler, LCMScheduler],
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
    ):
        super().__init__()

        self.register_modules(
            unet=unet,
            vae=vae,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )

        self.empty_text_embed = None

+
@torch.no_grad()
|
158 |
+
def __call__(
|
159 |
+
self,
|
160 |
+
input_image: Union[Image.Image, torch.Tensor],
|
161 |
+
denoising_steps: int = 10,
|
162 |
+
ensemble_size: int = 10,
|
163 |
+
processing_res: int = 768,
|
164 |
+
match_input_res: bool = True,
|
165 |
+
resample_method: str = "bilinear",
|
166 |
+
batch_size: int = 0,
|
167 |
+
color_map: str = "Spectral",
|
168 |
+
show_progress_bar: bool = True,
|
169 |
+
ensemble_kwargs: Dict = None,
|
170 |
+
# add
|
171 |
+
noise="gaussian",
|
172 |
+
normals=False,
|
173 |
+
) -> MarigoldDepthOutput:
|
174 |
+
"""
|
175 |
+
Function invoked when calling the pipeline.
|
176 |
+
|
177 |
+
Args:
|
178 |
+
input_image (`Image`):
|
179 |
+
Input RGB (or gray-scale) image.
|
180 |
+
processing_res (`int`, *optional*, defaults to `768`):
|
181 |
+
Maximum resolution of processing.
|
182 |
+
If set to 0: will not resize at all.
|
183 |
+
match_input_res (`bool`, *optional*, defaults to `True`):
|
184 |
+
Resize depth prediction to match input resolution.
|
185 |
+
Only valid if `processing_res` > 0.
|
186 |
+
resample_method: (`str`, *optional*, defaults to `bilinear`):
|
187 |
+
Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`.
|
188 |
+
denoising_steps (`int`, *optional*, defaults to `10`):
|
189 |
+
Number of diffusion denoising steps (DDIM) during inference.
|
190 |
+
ensemble_size (`int`, *optional*, defaults to `10`):
|
191 |
+
Number of predictions to be ensembled.
|
192 |
+
batch_size (`int`, *optional*, defaults to `0`):
|
193 |
+
Inference batch size, no bigger than `num_ensemble`.
|
194 |
+
If set to 0, the script will automatically decide the proper batch size.
|
195 |
+
show_progress_bar (`bool`, *optional*, defaults to `True`):
|
196 |
+
Display a progress bar of diffusion denoising.
|
197 |
+
color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
|
198 |
+
Colormap used to colorize the depth map.
|
199 |
+
ensemble_kwargs (`dict`, *optional*, defaults to `None`):
|
200 |
+
Arguments for detailed ensembling settings.
|
201 |
+
noise (`str`, *optional*, defaults to `gaussian`):
|
202 |
+
Type of noise to be used for the initial depth map.
|
203 |
+
Can be one of `gaussian`, `pyramid`, `zeros`.
|
204 |
+
normals (`bool`, *optional*, defaults to `False`):
|
205 |
+
If `True`, the pipeline will predict surface normals instead of depth maps.
|
206 |
+
Returns:
|
207 |
+
`MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
|
208 |
+
- **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
|
209 |
+
- **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`
|
210 |
+
- **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
|
211 |
+
coming from ensembling. None if `ensemble_size = 1`
|
212 |
+
- **normal_np** (`np.ndarray`) Predicted normal map, with normal vectors in the range of [-1, 1]
|
213 |
+
- **normal_colored** (`PIL.Image.Image`) Colorized normal map
|
214 |
+
"""
|
215 |
+
|
        assert processing_res >= 0
        assert ensemble_size >= 1

        resample_method: InterpolationMode = get_tv_resample_method(resample_method)

        # ----------------- Image Preprocess -----------------

        # Convert to torch tensor
        if isinstance(input_image, Image.Image):
            input_image = input_image.convert("RGB")
            rgb = pil_to_tensor(input_image)  # [H, W, rgb] -> [rgb, H, W]
        elif isinstance(input_image, torch.Tensor):
            rgb = input_image.squeeze()
        else:
            raise TypeError(f"Unknown input type: {type(input_image) = }")
        input_size = rgb.shape
        assert (
            3 == rgb.dim() and 3 == input_size[0]
        ), f"Wrong input shape {input_size}, expected [rgb, H, W]"

        # Resize image
        if processing_res > 0:
            rgb = resize_max_res(
                rgb,
                max_edge_resolution=processing_res,
                resample_method=resample_method,
            )

        # Normalize rgb values
        rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
        rgb_norm = rgb_norm.to(self.dtype)
        assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0

        # ----------------- Predicting depth/normal --------------

        # Batch repeated input image
        duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
        single_rgb_dataset = TensorDataset(duplicated_rgb)
        if batch_size > 0:
            _bs = batch_size
        else:
            _bs = find_batch_size(
                ensemble_size=ensemble_size,
                input_res=max(rgb_norm.shape[1:]),
                dtype=self.dtype,
            )

        single_rgb_loader = DataLoader(
            single_rgb_dataset, batch_size=_bs, shuffle=False
        )

        # load iterator
        pred_ls = []
        if show_progress_bar:
            iterable = tqdm(
                single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
            )
        else:
            iterable = single_rgb_loader

        # inference (batched)
        for batch in iterable:
            (batched_img,) = batch
            pred_raw = self.single_infer(
                rgb_in=batched_img,
                num_inference_steps=denoising_steps,
                show_pbar=show_progress_bar,
                # add
                noise=noise,
                normals=normals,
            )
            pred_ls.append(pred_raw.detach())
        preds = torch.concat(pred_ls, dim=0).squeeze()
        torch.cuda.empty_cache()  # clear vram cache for ensembling

        # ----------------- Test-time ensembling -----------------

        if ensemble_size > 1:  # add
            pred, pred_uncert = ensemble_normals(preds) if normals else ensemble_depths(preds, **(ensemble_kwargs or {}))
        else:
            pred = preds
            pred_uncert = None

        # ----------------- Post processing -----------------

        if normals:
            # add
            # Normalize normal vectors to unit length
            pred /= (torch.norm(pred, p=2, dim=0, keepdim=True) + 1e-5)
        else:
            # Scale relative prediction to [0, 1]
            min_d = torch.min(pred)
            max_d = torch.max(pred)
            if max_d == min_d:
                pred = torch.zeros_like(pred)
            else:
                pred = (pred - min_d) / (max_d - min_d)

        # Resize back to original resolution
        if match_input_res:
            pred = resize(
                pred if normals else pred.unsqueeze(0),
                (input_size[-2], input_size[-1]),
                interpolation=resample_method,
                antialias=True,
            ).squeeze()

        # Convert to numpy
        pred = pred.cpu().numpy()

        # Process prediction for visualization
        if not normals:
            # add
            pred = pred.clip(0, 1)
            if color_map is not None:
                colored = colorize_depth_maps(
                    pred, 0, 1, cmap=color_map
                ).squeeze()  # [3, H, W], value in (0, 1)
                colored = (colored * 255).astype(np.uint8)
                colored_hwc = chw2hwc(colored)
                colored_img = Image.fromarray(colored_hwc)
            else:
                colored_img = None
        else:
            pred = pred.clip(-1.0, 1.0)
            colored = (((pred + 1) / 2) * 255).astype(np.uint8)
            colored_hwc = chw2hwc(colored)
            colored_img = Image.fromarray(colored_hwc)


        return MarigoldDepthOutput(
            depth_np = pred if not normals else None,
            depth_colored = colored_img if not normals else None,
            uncertainty = pred_uncert,
            # add
            normal_np = pred if normals else None,
            normal_colored = colored_img if normals else None,
        )

    def encode_empty_text(self):
        """
        Encode text embedding for empty prompt
        """
        prompt = ""
        text_inputs = self.tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
        self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)

    @torch.no_grad()
    def single_infer(
        self,
        rgb_in: torch.Tensor,
        num_inference_steps: int,
        show_pbar: bool,
        # add
        noise="gaussian",
        normals=False,
    ) -> torch.Tensor:
        """
        Perform an individual depth prediction without ensembling.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image.
            num_inference_steps (`int`):
                Number of diffusion denoising steps (DDIM) during inference.
            show_pbar (`bool`):
                Display a progress bar of diffusion denoising.
            noise (`str`, *optional*, defaults to `gaussian`):
                Type of noise to be used for the initial depth map.
                Can be one of `gaussian`, `pyramid`, `zeros`.
        Returns:
            `torch.Tensor`: Predicted depth map.
        """
        device = self.device
        rgb_in = rgb_in.to(device)

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps  # [T]

        # Encode image
        rgb_latent = self.encode_rgb(rgb_in)

        # add
        # Initial prediction
        latent_shape = rgb_latent.shape
        if noise == "gaussian":
            latent = torch.randn(
                latent_shape,
                device=device,
                dtype=self.dtype,
            )
        elif noise == "pyramid":
            latent = pyramid_noise_like(rgb_latent).to(device)  # [B, 4, h, w]
        elif noise == "zeros":
            latent = torch.zeros(
                latent_shape,
                device=device,
                dtype=self.dtype,
            )
        else:
            raise ValueError(f"Unknown noise type: {noise}")

        # Batched empty text embedding
        if self.empty_text_embed is None:
            self.encode_empty_text()
        batch_empty_text_embed = self.empty_text_embed.repeat(
            (rgb_latent.shape[0], 1, 1)
        )  # [B, 2, 1024]

        # Denoising loop
        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising",
            )
        else:
            iterable = enumerate(timesteps)

        for i, t in iterable:

            unet_input = torch.cat(
                [rgb_latent, latent], dim=1
            )  # this order is important

            # predict the noise residual
            noise_pred = self.unet(
                unet_input, t, encoder_hidden_states=batch_empty_text_embed
            ).sample  # [B, 4, h, w]

            # compute the previous noisy sample x_t -> x_t-1
            scheduler_step = self.scheduler.step(
                noise_pred, t, latent
            )

            latent = scheduler_step.prev_sample

        if normals:
            # add
            # decode and normalize normal vectors
            normal = self.decode_normal(latent)
            normal /= (torch.norm(normal, p=2, dim=1, keepdim=True) + 1e-5)
            return normal
        else:
            # decode and normalize depth map
            depth = self.decode_depth(latent)
            depth = torch.clip(depth, -1.0, 1.0)
            depth = (depth + 1.0) / 2.0
            return depth

    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
        """
        Encode RGB image into latent.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image to be encoded.

        Returns:
            `torch.Tensor`: Image latent.
        """
        # encode
        h = self.vae.encoder(rgb_in)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        # scale latent
        rgb_latent = mean * self.rgb_latent_scale_factor
        return rgb_latent

    def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode depth latent into depth map.

        Args:
            depth_latent (`torch.Tensor`):
                Depth latent to be decoded.

        Returns:
            `torch.Tensor`: Decoded depth map.
        """
        # scale latent
        depth_latent = depth_latent / self.depth_latent_scale_factor
        # decode
        z = self.vae.post_quant_conv(depth_latent)
        stacked = self.vae.decoder(z)
        # mean of output channels
        depth_mean = stacked.mean(dim=1, keepdim=True)
        return depth_mean

    # add
    def decode_normal(self, normal_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode normal latent into normal map.

        Args:
            normal_latent (`torch.Tensor`):
                Normal latent to be decoded.

        Returns:
            `torch.Tensor`: Decoded normal map.
        """
        # scale latent
        normal_latent = normal_latent / self.depth_latent_scale_factor
        # decode
        z = self.vae.post_quant_conv(normal_latent)
        normal = self.vae.decoder(z)
        return normal
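For reference, a minimal usage sketch of the pipeline above (not part of this commit). The checkpoint id and the import path are assumptions: `prs-eth/marigold-v1-0` is the public Marigold checkpoint, and the import assumes the script runs from the repository root so that `Marigold/marigold` is importable; any checkpoint with compatible unet/vae/scheduler/text_encoder/tokenizer components should work.

# Illustrative sketch only; not part of the uploaded files.
import torch
from PIL import Image
from Marigold.marigold import MarigoldPipeline  # assumes repo root on sys.path

pipe = MarigoldPipeline.from_pretrained("prs-eth/marigold-v1-0", torch_dtype=torch.float32)
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

image = Image.open("assets/examples/mushrooms.png")
out = pipe(
    image,
    denoising_steps=10,
    ensemble_size=5,
    processing_res=768,
    noise="gaussian",   # or "pyramid" / "zeros"
    normals=False,      # set True to sample surface normals instead of depth
)
depth = out.depth_np                         # np.ndarray with values in [0, 1]
out.depth_colored.save("depth_colored.png")  # PIL image, None when normals=True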
Marigold/marigold/util/__init__.py
ADDED
File without changes
Marigold/marigold/util/batchsize.py
ADDED
@@ -0,0 +1,81 @@
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------


import torch
import math


# Search table for suggested max. inference batch size
bs_search_table = [
    # tested on A100-PCIE-80GB
    {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
    {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
    # tested on A100-PCIE-40GB
    {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
    {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
    {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
    {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
    # tested on RTX3090, RTX4090
    {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
    {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
    {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
    {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
    {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
    {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
    # tested on GTX1080Ti
    {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
    {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
    {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
    {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
    {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
]


def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
    """
    Automatically search for suitable operating batch size.

    Args:
        ensemble_size (`int`):
            Number of predictions to be ensembled.
        input_res (`int`):
            Operating resolution of the input image.

    Returns:
        `int`: Operating batch size.
    """
    if not torch.cuda.is_available():
        return 1

    total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
    filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
    for settings in sorted(
        filtered_bs_search_table,
        key=lambda k: (k["res"], -k["total_vram"]),
    ):
        if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
            bs = settings["bs"]
            if bs > ensemble_size:
                bs = ensemble_size
            elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
                bs = math.ceil(ensemble_size / 2)
            return bs

    return 1
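For reference, a minimal sketch of how the pipeline uses this helper when `batch_size=0` is passed to `__call__` (not part of this commit; the import path assumes the repository root is on sys.path).

# Illustrative sketch only; not part of the uploaded files.
import torch
from Marigold.marigold.util.batchsize import find_batch_size

bs = find_batch_size(ensemble_size=10, input_res=768, dtype=torch.float16)
print(bs)  # 1 on a machine without CUDA; otherwise the matching table entry, capped at ensemble_size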
Marigold/marigold/util/ensemble.py
ADDED
@@ -0,0 +1,132 @@
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------


import numpy as np
import torch

from scipy.optimize import minimize


def inter_distances(tensors: torch.Tensor):
    """
    Calculate the distance between each pair of depth maps.
    """
    distances = []
    for i, j in torch.combinations(torch.arange(tensors.shape[0])):
        arr1 = tensors[i : i + 1]
        arr2 = tensors[j : j + 1]
        distances.append(arr1 - arr2)
    dist = torch.concatenate(distances, dim=0)
    return dist


def ensemble_depths(
    input_images: torch.Tensor,
    regularizer_strength: float = 0.02,
    max_iter: int = 2,
    tol: float = 1e-3,
    reduction: str = "median",
    max_res: int = None,
):
    """
    Ensemble multiple affine-invariant depth images (valid up to scale and shift)
    by jointly estimating the scale and shift that align them.
    """
    device = input_images.device
    dtype = input_images.dtype
    np_dtype = np.float32

    original_input = input_images.clone()
    n_img = input_images.shape[0]
    ori_shape = input_images.shape

    if max_res is not None:
        scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:]))
        if scale_factor < 1:
            downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
            input_images = downscaler(input_images)

    # init guess
    _min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
    _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
    s_init = 1.0 / (_max - _min).reshape((-1, 1, 1))
    t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1))
    x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype)

    input_images = input_images.to(device)

    # objective function
    def closure(x):
        len_x = len(x)
        s = x[: int(len_x / 2)]
        t = x[int(len_x / 2) :]
        s = torch.from_numpy(s).to(dtype=dtype).to(device)
        t = torch.from_numpy(t).to(dtype=dtype).to(device)

        transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1))
        dists = inter_distances(transformed_arrays)
        sqrt_dist = torch.sqrt(torch.mean(dists**2))

        if "mean" == reduction:
            pred = torch.mean(transformed_arrays, dim=0)
        elif "median" == reduction:
            pred = torch.median(transformed_arrays, dim=0).values
        else:
            raise ValueError

        near_err = torch.sqrt((0 - torch.min(pred)) ** 2)
        far_err = torch.sqrt((1 - torch.max(pred)) ** 2)

        err = sqrt_dist + (near_err + far_err) * regularizer_strength
        err = err.detach().cpu().numpy().astype(np_dtype)
        return err

    res = minimize(
        closure, x, method="BFGS", tol=tol, options={"maxiter": max_iter, "disp": False}
    )
    x = res.x
    len_x = len(x)
    s = x[: int(len_x / 2)]
    t = x[int(len_x / 2) :]

    # Prediction
    s = torch.from_numpy(s).to(dtype=dtype).to(device)
    t = torch.from_numpy(t).to(dtype=dtype).to(device)
    transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1)
    if "mean" == reduction:
        aligned_images = torch.mean(transformed_arrays, dim=0)
        std = torch.std(transformed_arrays, dim=0)
        uncertainty = std
    elif "median" == reduction:
        aligned_images = torch.median(transformed_arrays, dim=0).values
        # MAD (median absolute deviation) as uncertainty indicator
        abs_dev = torch.abs(transformed_arrays - aligned_images)
        mad = torch.median(abs_dev, dim=0).values
        uncertainty = mad
    else:
        raise ValueError(f"Unknown reduction method: {reduction}")

    # Scale and shift to [0, 1]
    _min = torch.min(aligned_images)
    _max = torch.max(aligned_images)
    aligned_images = (aligned_images - _min) / (_max - _min)
    uncertainty /= _max - _min

    return aligned_images, uncertainty
Marigold/marigold/util/image_util.py
ADDED
@@ -0,0 +1,121 @@
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
# Last modified: 2024-04-16
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------


import matplotlib
import numpy as np
import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize


def colorize_depth_maps(
    depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
):
    """
    Colorize depth maps.
    """
    assert len(depth_map.shape) >= 2, "Invalid dimension"

    if isinstance(depth_map, torch.Tensor):
        depth = depth_map.detach().squeeze().numpy()
    elif isinstance(depth_map, np.ndarray):
        depth = depth_map.copy().squeeze()
    # reshape to [ (B,) H, W ]
    if depth.ndim < 3:
        depth = depth[np.newaxis, :, :]

    # colorize
    cm = matplotlib.colormaps[cmap]
    depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
    img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3]  # value from 0 to 1
    img_colored_np = np.rollaxis(img_colored_np, 3, 1)

    if valid_mask is not None:
        if isinstance(depth_map, torch.Tensor):
            valid_mask = valid_mask.detach().numpy()
        valid_mask = valid_mask.squeeze()  # [H, W] or [B, H, W]
        if valid_mask.ndim < 3:
            valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
        else:
            valid_mask = valid_mask[:, np.newaxis, :, :]
        valid_mask = np.repeat(valid_mask, 3, axis=1)
        img_colored_np[~valid_mask] = 0

    if isinstance(depth_map, torch.Tensor):
        img_colored = torch.from_numpy(img_colored_np).float()
    elif isinstance(depth_map, np.ndarray):
        img_colored = img_colored_np

    return img_colored


def chw2hwc(chw):
    assert 3 == len(chw.shape)
    if isinstance(chw, torch.Tensor):
        hwc = torch.permute(chw, (1, 2, 0))
    elif isinstance(chw, np.ndarray):
        hwc = np.moveaxis(chw, 0, -1)
    return hwc


def resize_max_res(
    img: torch.Tensor,
    max_edge_resolution: int,
    resample_method: InterpolationMode = InterpolationMode.BILINEAR,
) -> torch.Tensor:
    """
    Resize image to limit maximum edge length while keeping aspect ratio.

    Args:
        img (`torch.Tensor`):
            Image tensor to be resized.
        max_edge_resolution (`int`):
            Maximum edge length (pixel).
        resample_method (`PIL.Image.Resampling`):
            Resampling method used to resize images.

    Returns:
        `torch.Tensor`: Resized image.
    """
    assert 3 == img.dim()
    _, original_height, original_width = img.shape
    downscale_factor = min(
        max_edge_resolution / original_width, max_edge_resolution / original_height
    )

    new_width = int(original_width * downscale_factor)
    new_height = int(original_height * downscale_factor)

    resized_img = resize(img, (new_height, new_width), resample_method, antialias=True)
    return resized_img


def get_tv_resample_method(method_str: str) -> InterpolationMode:
    resample_method_dict = {
        "bilinear": InterpolationMode.BILINEAR,
        "bicubic": InterpolationMode.BICUBIC,
        "nearest": InterpolationMode.NEAREST_EXACT,
    }
    resample_method = resample_method_dict.get(method_str, None)
    if resample_method is None:
        raise ValueError(f"Unknown resampling method: {method_str}")
    else:
        return resample_method
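For reference, a minimal sketch of typical use of the helpers above (not part of this commit). The input tensor is a random stand-in image, and the import path assumes the repository root is on sys.path.

# Illustrative sketch only; not part of the uploaded files.
import numpy as np
import torch
from PIL import Image
from Marigold.marigold.util.image_util import (
    chw2hwc,
    colorize_depth_maps,
    get_tv_resample_method,
    resize_max_res,
)

rgb = torch.randint(0, 256, (3, 480, 640)).float()                     # stand-in image [C, H, W]
small = resize_max_res(rgb, 256, get_tv_resample_method("bilinear"))   # longest edge becomes 256
depth = np.random.rand(small.shape[-2], small.shape[-1])               # stand-in depth in [0, 1]
colored = colorize_depth_maps(depth, 0, 1, cmap="Spectral").squeeze()  # [3, H, W], values in [0, 1]
Image.fromarray((chw2hwc(colored) * 255).astype(np.uint8)).save("depth_vis.png")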
assets/examples/mushrooms.png
ADDED
Git LFS Details