Spaces: Running on Zero
Delete Marigold
- Marigold/README.md +0 -1
- Marigold/marigold/__init__.py +0 -21
- Marigold/marigold/marigold_pipeline.py +0 -534
- Marigold/marigold/util/__init__.py +0 -0
- Marigold/marigold/util/batchsize.py +0 -81
- Marigold/marigold/util/ensemble.py +0 -132
- Marigold/marigold/util/image_util.py +0 -121
Marigold/README.md
DELETED
@@ -1 +0,0 @@
Code is copied from https://github.com/prs-eth/Marigold. Modifications are indicated within the code.
Marigold/marigold/__init__.py
DELETED
@@ -1,21 +0,0 @@
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------


from .marigold_pipeline import MarigoldPipeline, MarigoldDepthOutput  # noqa: F401
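
Not part of the deleted file: the package __init__ simply re-exported the pipeline and its output class, so callers could write, for example:

from Marigold.marigold import MarigoldPipeline, MarigoldDepthOutput  # instead of importing marigold_pipeline directly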
Marigold/marigold/marigold_pipeline.py
DELETED
@@ -1,534 +0,0 @@
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------

# @GonzaloMartinGarcia
# This file is a modified version of the original Marigold pipeline file.
# Based on GeoWizard, we added the option to sample surface normals, marked with # add.

from typing import Dict, Union

import numpy as np
import torch
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LCMScheduler,
    UNet2DConditionModel,
    DDPMScheduler,
)
from diffusers.utils import BaseOutput
from PIL import Image
from torchvision.transforms.functional import resize, pil_to_tensor
from torchvision.transforms import InterpolationMode
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from .util.batchsize import find_batch_size
from .util.ensemble import ensemble_depths
from .util.image_util import (
    chw2hwc,
    colorize_depth_maps,
    get_tv_resample_method,
    resize_max_res,
)

# add
import random

# add
# Surface Normals Ensemble from the GeoWizard github repository (https://github.com/fuxiao0719/GeoWizard)
def ensemble_normals(input_images: torch.Tensor):
    normal_preds = input_images
    bsz, d, h, w = normal_preds.shape
    normal_preds = normal_preds / (torch.norm(normal_preds, p=2, dim=1).unsqueeze(1) + 1e-5)
    phi = torch.atan2(normal_preds[:, 1, :, :], normal_preds[:, 0, :, :]).mean(dim=0)
    theta = torch.atan2(torch.norm(normal_preds[:, :2, :, :], p=2, dim=1), normal_preds[:, 2, :, :]).mean(dim=0)
    normal_pred = torch.zeros((d, h, w)).to(normal_preds)
    normal_pred[0, :, :] = torch.sin(theta) * torch.cos(phi)
    normal_pred[1, :, :] = torch.sin(theta) * torch.sin(phi)
    normal_pred[2, :, :] = torch.cos(theta)
    angle_error = torch.acos(torch.clip(torch.cosine_similarity(normal_pred[None], normal_preds, dim=1), -0.999, 0.999))
    normal_idx = torch.argmin(angle_error.reshape(bsz, -1).sum(-1))
    return normal_preds[normal_idx], None

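Not part of the deleted file: the GeoWizard-style ensembling above averages the predictions in spherical coordinates (phi, theta), then returns the single input prediction with the smallest total angular error to that average; the second return value is always None. A minimal sketch with illustrative shapes:

preds = torch.nn.functional.normalize(torch.randn(10, 3, 64, 64), dim=1)  # 10 unit-length normal maps
best_normals, _ = ensemble_normals(preds)                                 # [3, 64, 64], one of the ten inputs
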
# add
# Pyramid noise from
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2?s=31
def pyramid_noise_like(x, discount=0.9):
    b, c, w, h = x.shape
    u = torch.nn.Upsample(size=(w, h), mode='bilinear')
    noise = torch.randn_like(x)
    for i in range(10):
        r = random.random() * 2 + 2
        w, h = max(1, int(w / (r ** i))), max(1, int(h / (r ** i)))
        noise += u(torch.randn(b, c, w, h).to(x)) * discount ** i
        if w == 1 or h == 1:
            break
    return noise / noise.std()

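Not part of the deleted file: pyramid (multi-resolution) noise sums Gaussian noise drawn at progressively coarser resolutions and upsampled back, so the initial latent carries more low-frequency structure than plain Gaussian noise. A minimal sketch, assuming a Stable-Diffusion-sized 4-channel latent:

latent = torch.randn(2, 4, 96, 72)                  # [B, 4, h, w], sizes are illustrative
init_noise = pyramid_noise_like(latent, discount=0.9)
print(init_noise.shape, float(init_noise.std()))    # same shape, std normalized to about 1
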
class MarigoldDepthOutput(BaseOutput):
    """
    Output class for Marigold monocular depth prediction pipeline.

    Args:
        depth_np (`np.ndarray`):
            Predicted depth map, with depth values in the range of [0, 1].
        depth_colored (`PIL.Image.Image`):
            Colorized depth map, with the shape of [3, H, W] and values in [0, 1].
        uncertainty (`None` or `np.ndarray`):
            Uncalibrated uncertainty (MAD, median absolute deviation) coming from ensembling.
        normal_np (`np.ndarray`):
            Predicted normal map, with normal vectors in the range of [-1, 1].
        normal_colored (`PIL.Image.Image`):
            Colorized normal map.
    """

    depth_np: np.ndarray
    depth_colored: Union[None, Image.Image]
    uncertainty: Union[None, np.ndarray]
    # add
    normal_np: np.ndarray
    normal_colored: Union[None, Image.Image]

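Not part of the deleted file: because the class derives from diffusers' BaseOutput, its fields can be read either as attributes or as dictionary keys, e.g.:

out = MarigoldDepthOutput(depth_np=np.zeros((4, 4)), depth_colored=None,
                          uncertainty=None, normal_np=None, normal_colored=None)
depth = out.depth_np       # attribute access
depth = out["depth_np"]    # dict-style access
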
class MarigoldPipeline(DiffusionPipeline):
    """
    Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        unet (`UNet2DConditionModel`):
            Conditional U-Net to denoise the depth latent, conditioned on image latent.
        vae (`AutoencoderKL`):
            Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
            to and from latent representations.
        scheduler (`DDIMScheduler`):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        text_encoder (`CLIPTextModel`):
            Text-encoder, for empty text embedding.
        tokenizer (`CLIPTokenizer`):
            CLIP tokenizer.
    """

    rgb_latent_scale_factor = 0.18215
    depth_latent_scale_factor = 0.18215

    def __init__(
        self,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        scheduler: Union[DDIMScheduler, DDPMScheduler, LCMScheduler],
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
    ):
        super().__init__()

        self.register_modules(
            unet=unet,
            vae=vae,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )

        self.empty_text_embed = None

    @torch.no_grad()
    def __call__(
        self,
        input_image: Union[Image.Image, torch.Tensor],
        denoising_steps: int = 10,
        ensemble_size: int = 10,
        processing_res: int = 768,
        match_input_res: bool = True,
        resample_method: str = "bilinear",
        batch_size: int = 0,
        color_map: str = "Spectral",
        show_progress_bar: bool = True,
        ensemble_kwargs: Dict = None,
        # add
        noise="gaussian",
        normals=False,
    ) -> MarigoldDepthOutput:
        """
        Function invoked when calling the pipeline.

        Args:
            input_image (`Image`):
                Input RGB (or gray-scale) image.
            processing_res (`int`, *optional*, defaults to `768`):
                Maximum resolution of processing.
                If set to 0: will not resize at all.
            match_input_res (`bool`, *optional*, defaults to `True`):
                Resize depth prediction to match input resolution.
                Only valid if `processing_res` > 0.
            resample_method (`str`, *optional*, defaults to `bilinear`):
                Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`.
            denoising_steps (`int`, *optional*, defaults to `10`):
                Number of diffusion denoising steps (DDIM) during inference.
            ensemble_size (`int`, *optional*, defaults to `10`):
                Number of predictions to be ensembled.
            batch_size (`int`, *optional*, defaults to `0`):
                Inference batch size, no bigger than `ensemble_size`.
                If set to 0, the script will automatically decide the proper batch size.
            show_progress_bar (`bool`, *optional*, defaults to `True`):
                Display a progress bar of diffusion denoising.
            color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
                Colormap used to colorize the depth map.
            ensemble_kwargs (`dict`, *optional*, defaults to `None`):
                Arguments for detailed ensembling settings.
            noise (`str`, *optional*, defaults to `gaussian`):
                Type of noise to be used for the initial depth map.
                Can be one of `gaussian`, `pyramid`, `zeros`.
            normals (`bool`, *optional*, defaults to `False`):
                If `True`, the pipeline will predict surface normals instead of depth maps.

        Returns:
            `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
            - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
            - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`
            - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty (MAD, median absolute deviation)
              coming from ensembling. None if `ensemble_size = 1`
            - **normal_np** (`np.ndarray`) Predicted normal map, with normal vectors in the range of [-1, 1]
            - **normal_colored** (`PIL.Image.Image`) Colorized normal map
        """

        assert processing_res >= 0
        assert ensemble_size >= 1

        resample_method: InterpolationMode = get_tv_resample_method(resample_method)

        # ----------------- Image Preprocess -----------------

        # Convert to torch tensor
        if isinstance(input_image, Image.Image):
            input_image = input_image.convert("RGB")
            rgb = pil_to_tensor(input_image)  # [H, W, rgb] -> [rgb, H, W]
        elif isinstance(input_image, torch.Tensor):
            rgb = input_image.squeeze()
        else:
            raise TypeError(f"Unknown input type: {type(input_image) = }")
        input_size = rgb.shape
        assert (
            3 == rgb.dim() and 3 == input_size[0]
        ), f"Wrong input shape {input_size}, expected [rgb, H, W]"

        # Resize image
        if processing_res > 0:
            rgb = resize_max_res(
                rgb,
                max_edge_resolution=processing_res,
                resample_method=resample_method,
            )

        # Normalize rgb values
        rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
        rgb_norm = rgb_norm.to(self.dtype)
        assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0

        # ----------------- Predicting depth/normal -----------------

        # Batch repeated input image
        duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
        single_rgb_dataset = TensorDataset(duplicated_rgb)
        if batch_size > 0:
            _bs = batch_size
        else:
            _bs = find_batch_size(
                ensemble_size=ensemble_size,
                input_res=max(rgb_norm.shape[1:]),
                dtype=self.dtype,
            )

        single_rgb_loader = DataLoader(
            single_rgb_dataset, batch_size=_bs, shuffle=False
        )

        # load iterator
        pred_ls = []
        if show_progress_bar:
            iterable = tqdm(
                single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
            )
        else:
            iterable = single_rgb_loader

        # inference (batched)
        for batch in iterable:
            (batched_img,) = batch
            pred_raw = self.single_infer(
                rgb_in=batched_img,
                num_inference_steps=denoising_steps,
                show_pbar=show_progress_bar,
                # add
                noise=noise,
                normals=normals,
            )
            pred_ls.append(pred_raw.detach())
        preds = torch.concat(pred_ls, dim=0).squeeze()
        torch.cuda.empty_cache()  # clear vram cache for ensembling

        # ----------------- Test-time ensembling -----------------

        if ensemble_size > 1:  # add
            pred, pred_uncert = ensemble_normals(preds) if normals else ensemble_depths(preds, **(ensemble_kwargs or {}))
        else:
            pred = preds
            pred_uncert = None

        # ----------------- Post processing -----------------

        if normals:
            # add
            # Normalize normal vectors to unit length
            pred /= (torch.norm(pred, p=2, dim=0, keepdim=True) + 1e-5)
        else:
            # Scale relative prediction to [0, 1]
            min_d = torch.min(pred)
            max_d = torch.max(pred)
            if max_d == min_d:
                pred = torch.zeros_like(pred)
            else:
                pred = (pred - min_d) / (max_d - min_d)

        # Resize back to original resolution
        if match_input_res:
            pred = resize(
                pred if normals else pred.unsqueeze(0),
                (input_size[-2], input_size[-1]),
                interpolation=resample_method,
                antialias=True,
            ).squeeze()

        # Convert to numpy
        pred = pred.cpu().numpy()

        # Process prediction for visualization
        if not normals:
            # add
            pred = pred.clip(0, 1)
            if color_map is not None:
                colored = colorize_depth_maps(
                    pred, 0, 1, cmap=color_map
                ).squeeze()  # [3, H, W], value in (0, 1)
                colored = (colored * 255).astype(np.uint8)
                colored_hwc = chw2hwc(colored)
                colored_img = Image.fromarray(colored_hwc)
            else:
                colored_img = None
        else:
            pred = pred.clip(-1.0, 1.0)
            colored = (((pred + 1) / 2) * 255).astype(np.uint8)
            colored_hwc = chw2hwc(colored)
            colored_img = Image.fromarray(colored_hwc)

        return MarigoldDepthOutput(
            depth_np=pred if not normals else None,
            depth_colored=colored_img if not normals else None,
            uncertainty=pred_uncert,
            # add
            normal_np=pred if normals else None,
            normal_colored=colored_img if normals else None,
        )

    def encode_empty_text(self):
        """
        Encode text embedding for empty prompt.
        """
        prompt = ""
        text_inputs = self.tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
        self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)

    @torch.no_grad()
    def single_infer(
        self,
        rgb_in: torch.Tensor,
        num_inference_steps: int,
        show_pbar: bool,
        # add
        noise="gaussian",
        normals=False,
    ) -> torch.Tensor:
        """
        Perform an individual depth prediction without ensembling.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image.
            num_inference_steps (`int`):
                Number of diffusion denoising steps (DDIM) during inference.
            show_pbar (`bool`):
                Display a progress bar of diffusion denoising.
            noise (`str`, *optional*, defaults to `gaussian`):
                Type of noise to be used for the initial depth map.
                Can be one of `gaussian`, `pyramid`, `zeros`.

        Returns:
            `torch.Tensor`: Predicted depth map.
        """
        device = self.device
        rgb_in = rgb_in.to(device)

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps  # [T]

        # Encode image
        rgb_latent = self.encode_rgb(rgb_in)

        # add
        # Initial prediction
        latent_shape = rgb_latent.shape
        if noise == "gaussian":
            latent = torch.randn(
                latent_shape,
                device=device,
                dtype=self.dtype,
            )
        elif noise == "pyramid":
            latent = pyramid_noise_like(rgb_latent).to(device)  # [B, 4, h, w]
        elif noise == "zeros":
            latent = torch.zeros(
                latent_shape,
                device=device,
                dtype=self.dtype,
            )
        else:
            raise ValueError(f"Unknown noise type: {noise}")

        # Batched empty text embedding
        if self.empty_text_embed is None:
            self.encode_empty_text()
        batch_empty_text_embed = self.empty_text_embed.repeat(
            (rgb_latent.shape[0], 1, 1)
        )  # [B, 2, 1024]

        # Denoising loop
        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising",
            )
        else:
            iterable = enumerate(timesteps)

        for i, t in iterable:
            unet_input = torch.cat(
                [rgb_latent, latent], dim=1
            )  # this order is important

            # predict the noise residual
            noise_pred = self.unet(
                unet_input, t, encoder_hidden_states=batch_empty_text_embed
            ).sample  # [B, 4, h, w]

            # compute the previous noisy sample x_t -> x_t-1
            scheduler_step = self.scheduler.step(
                noise_pred, t, latent
            )

            latent = scheduler_step.prev_sample

        if normals:
            # add
            # decode and normalize normal vectors
            normal = self.decode_normal(latent)
            normal /= (torch.norm(normal, p=2, dim=1, keepdim=True) + 1e-5)
            return normal
        else:
            # decode and normalize depth map
            depth = self.decode_depth(latent)
            depth = torch.clip(depth, -1.0, 1.0)
            depth = (depth + 1.0) / 2.0
            return depth

    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
        """
        Encode RGB image into latent.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image to be encoded.

        Returns:
            `torch.Tensor`: Image latent.
        """
        # encode
        h = self.vae.encoder(rgb_in)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        # scale latent
        rgb_latent = mean * self.rgb_latent_scale_factor
        return rgb_latent

    def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode depth latent into depth map.

        Args:
            depth_latent (`torch.Tensor`):
                Depth latent to be decoded.

        Returns:
            `torch.Tensor`: Decoded depth map.
        """
        # scale latent
        depth_latent = depth_latent / self.depth_latent_scale_factor
        # decode
        z = self.vae.post_quant_conv(depth_latent)
        stacked = self.vae.decoder(z)
        # mean of output channels
        depth_mean = stacked.mean(dim=1, keepdim=True)
        return depth_mean

    # add
    def decode_normal(self, normal_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode normal latent into normal map.

        Args:
            normal_latent (`torch.Tensor`):
                Normal latent to be decoded.

        Returns:
            `torch.Tensor`: Decoded normal map.
        """
        # scale latent
        normal_latent = normal_latent / self.depth_latent_scale_factor
        # decode
        z = self.vae.post_quant_conv(normal_latent)
        normal = self.vae.decoder(z)
        return normal
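
Not part of the deleted file: a minimal end-to-end sketch of how this pipeline was typically driven. The checkpoint name and image path are assumptions, not taken from this commit:

import torch
from PIL import Image
from Marigold.marigold import MarigoldPipeline

pipe = MarigoldPipeline.from_pretrained("prs-eth/marigold-v1-0")   # hypothetical checkpoint
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

image = Image.open("input.jpg")                                    # hypothetical input image
out = pipe(image, denoising_steps=10, ensemble_size=10)            # depth branch
depth, depth_vis = out.depth_np, out.depth_colored                 # [H, W] in [0, 1], PIL image

out_n = pipe(image, normals=True, noise="pyramid")                 # GeoWizard-style normals branch
normal_map = out_n.normal_np                                       # [3, H, W] in [-1, 1]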
Marigold/marigold/util/__init__.py
DELETED
File without changes
Marigold/marigold/util/batchsize.py
DELETED
@@ -1,81 +0,0 @@
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------


import torch
import math


# Search table for suggested max. inference batch size
bs_search_table = [
    # tested on A100-PCIE-80GB
    {"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32},
    {"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32},
    # tested on A100-PCIE-40GB
    {"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32},
    {"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32},
    {"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16},
    {"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16},
    # tested on RTX3090, RTX4090
    {"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32},
    {"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32},
    {"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32},
    {"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16},
    {"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16},
    {"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16},
    # tested on GTX1080Ti
    {"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32},
    {"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32},
    {"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16},
    {"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16},
    {"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16},
]


def find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int:
    """
    Automatically search for suitable operating batch size.

    Args:
        ensemble_size (`int`):
            Number of predictions to be ensembled.
        input_res (`int`):
            Operating resolution of the input image.

    Returns:
        `int`: Operating batch size.
    """
    if not torch.cuda.is_available():
        return 1

    total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3
    filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype]
    for settings in sorted(
        filtered_bs_search_table,
        key=lambda k: (k["res"], -k["total_vram"]),
    ):
        if input_res <= settings["res"] and total_vram >= settings["total_vram"]:
            bs = settings["bs"]
            if bs > ensemble_size:
                bs = ensemble_size
            elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size:
                bs = math.ceil(ensemble_size / 2)
            return bs

    return 1
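
Not part of the deleted file: the lookup filters the table by dtype, walks it from small resolutions and large VRAM tiers, returns the first row the current GPU satisfies, and caps the result at the ensemble size (falling back to 1 without CUDA). A minimal sketch:

bs = find_batch_size(ensemble_size=10, input_res=768, dtype=torch.float16)
print(bs)  # 1 on a CPU-only machine, otherwise a table-derived value no larger than 10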
Marigold/marigold/util/ensemble.py
DELETED
@@ -1,132 +0,0 @@
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------


import numpy as np
import torch

from scipy.optimize import minimize


def inter_distances(tensors: torch.Tensor):
    """
    Calculate the distance between each pair of depth maps.
    """
    distances = []
    for i, j in torch.combinations(torch.arange(tensors.shape[0])):
        arr1 = tensors[i : i + 1]
        arr2 = tensors[j : j + 1]
        distances.append(arr1 - arr2)
    dist = torch.concatenate(distances, dim=0)
    return dist


def ensemble_depths(
    input_images: torch.Tensor,
    regularizer_strength: float = 0.02,
    max_iter: int = 2,
    tol: float = 1e-3,
    reduction: str = "median",
    max_res: int = None,
):
    """
    Ensemble multiple affine-invariant depth images (defined only up to scale and shift)
    by estimating a per-image scale and shift that aligns them.
    """
    device = input_images.device
    dtype = input_images.dtype
    np_dtype = np.float32

    original_input = input_images.clone()
    n_img = input_images.shape[0]
    ori_shape = input_images.shape

    if max_res is not None:
        scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:]))
        if scale_factor < 1:
            downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest")
            input_images = downscaler(input_images)

    # init guess
    _min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
    _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1)
    s_init = 1.0 / (_max - _min).reshape((-1, 1, 1))
    t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1))
    x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype)

    input_images = input_images.to(device)

    # objective function
    def closure(x):
        len_x = len(x)
        s = x[: int(len_x / 2)]
        t = x[int(len_x / 2) :]
        s = torch.from_numpy(s).to(dtype=dtype).to(device)
        t = torch.from_numpy(t).to(dtype=dtype).to(device)

        transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1))
        dists = inter_distances(transformed_arrays)
        sqrt_dist = torch.sqrt(torch.mean(dists**2))

        if "mean" == reduction:
            pred = torch.mean(transformed_arrays, dim=0)
        elif "median" == reduction:
            pred = torch.median(transformed_arrays, dim=0).values
        else:
            raise ValueError

        near_err = torch.sqrt((0 - torch.min(pred)) ** 2)
        far_err = torch.sqrt((1 - torch.max(pred)) ** 2)

        err = sqrt_dist + (near_err + far_err) * regularizer_strength
        err = err.detach().cpu().numpy().astype(np_dtype)
        return err

    res = minimize(
        closure, x, method="BFGS", tol=tol, options={"maxiter": max_iter, "disp": False}
    )
    x = res.x
    len_x = len(x)
    s = x[: int(len_x / 2)]
    t = x[int(len_x / 2) :]

    # Prediction
    s = torch.from_numpy(s).to(dtype=dtype).to(device)
    t = torch.from_numpy(t).to(dtype=dtype).to(device)
    transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1)
    if "mean" == reduction:
        aligned_images = torch.mean(transformed_arrays, dim=0)
        std = torch.std(transformed_arrays, dim=0)
        uncertainty = std
    elif "median" == reduction:
        aligned_images = torch.median(transformed_arrays, dim=0).values
        # MAD (median absolute deviation) as uncertainty indicator
        abs_dev = torch.abs(transformed_arrays - aligned_images)
        mad = torch.median(abs_dev, dim=0).values
        uncertainty = mad
    else:
        raise ValueError(f"Unknown reduction method: {reduction}")

    # Scale and shift to [0, 1]
    _min = torch.min(aligned_images)
    _max = torch.max(aligned_images)
    aligned_images = (aligned_images - _min) / (_max - _min)
    uncertainty /= _max - _min

    return aligned_images, uncertainty
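
Not part of the deleted file: the ensembling fits one scale and one shift per prediction with BFGS so that pairwise differences shrink (plus a weak regularizer pulling the result toward the [0, 1] range), reduces the aligned stack with the mean or median, and reports the per-pixel std or MAD as uncertainty. A minimal sketch with random inputs:

preds = torch.rand(10, 64, 64)                                   # 10 affine-invariant depth maps
depth, uncertainty = ensemble_depths(preds, reduction="median")
print(depth.shape, uncertainty.shape)                            # both [64, 64], depth rescaled to [0, 1]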
Marigold/marigold/util/image_util.py
DELETED
@@ -1,121 +0,0 @@
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
# Last modified: 2024-04-16
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------


import matplotlib
import numpy as np
import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize


def colorize_depth_maps(
    depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
):
    """
    Colorize depth maps.
    """
    assert len(depth_map.shape) >= 2, "Invalid dimension"

    if isinstance(depth_map, torch.Tensor):
        depth = depth_map.detach().squeeze().numpy()
    elif isinstance(depth_map, np.ndarray):
        depth = depth_map.copy().squeeze()
    # reshape to [ (B,) H, W ]
    if depth.ndim < 3:
        depth = depth[np.newaxis, :, :]

    # colorize
    cm = matplotlib.colormaps[cmap]
    depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
    img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3]  # value from 0 to 1
    img_colored_np = np.rollaxis(img_colored_np, 3, 1)

    if valid_mask is not None:
        if isinstance(depth_map, torch.Tensor):
            valid_mask = valid_mask.detach().numpy()
        valid_mask = valid_mask.squeeze()  # [H, W] or [B, H, W]
        if valid_mask.ndim < 3:
            valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
        else:
            valid_mask = valid_mask[:, np.newaxis, :, :]
        valid_mask = np.repeat(valid_mask, 3, axis=1)
        img_colored_np[~valid_mask] = 0

    if isinstance(depth_map, torch.Tensor):
        img_colored = torch.from_numpy(img_colored_np).float()
    elif isinstance(depth_map, np.ndarray):
        img_colored = img_colored_np

    return img_colored


def chw2hwc(chw):
    assert 3 == len(chw.shape)
    if isinstance(chw, torch.Tensor):
        hwc = torch.permute(chw, (1, 2, 0))
    elif isinstance(chw, np.ndarray):
        hwc = np.moveaxis(chw, 0, -1)
    return hwc


def resize_max_res(
    img: torch.Tensor,
    max_edge_resolution: int,
    resample_method: InterpolationMode = InterpolationMode.BILINEAR,
) -> torch.Tensor:
    """
    Resize image to limit maximum edge length while keeping aspect ratio.

    Args:
        img (`torch.Tensor`):
            Image tensor to be resized.
        max_edge_resolution (`int`):
            Maximum edge length (pixel).
        resample_method (`InterpolationMode`):
            Resampling method used to resize images.

    Returns:
        `torch.Tensor`: Resized image.
    """
    assert 3 == img.dim()
    _, original_height, original_width = img.shape
    downscale_factor = min(
        max_edge_resolution / original_width, max_edge_resolution / original_height
    )

    new_width = int(original_width * downscale_factor)
    new_height = int(original_height * downscale_factor)

    resized_img = resize(img, (new_height, new_width), resample_method, antialias=True)
    return resized_img


def get_tv_resample_method(method_str: str) -> InterpolationMode:
    resample_method_dict = {
        "bilinear": InterpolationMode.BILINEAR,
        "bicubic": InterpolationMode.BICUBIC,
        "nearest": InterpolationMode.NEAREST_EXACT,
    }
    resample_method = resample_method_dict.get(method_str, None)
    if resample_method is None:
        raise ValueError(f"Unknown resampling method: {method_str}")
    else:
        return resample_method
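
Not part of the deleted file: a minimal sketch combining these helpers, rescaling an RGB tensor so its longer edge becomes 768 px and colorizing a depth map for display (shapes illustrative):

rgb = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)                  # [3, H, W] image tensor
rgb_resized = resize_max_res(rgb, 768, get_tv_resample_method("bilinear"))     # aspect ratio preserved
depth = np.random.rand(240, 320).astype(np.float32)                            # depth values in [0, 1]
vis = chw2hwc((colorize_depth_maps(depth, 0, 1).squeeze() * 255).astype(np.uint8))  # [H, W, 3] uint8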