# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------
# @GonzaloMartinGarcia
# This file is a modified version of the original Marigold pipeline file.
# Based on GeoWizard, we added the option to sample surface normals, marked with # add.
from typing import Dict, Union

import numpy as np
import torch
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LCMScheduler,
    UNet2DConditionModel,
    DDPMScheduler,
)
from diffusers.utils import BaseOutput
from PIL import Image
from torchvision.transforms.functional import resize, pil_to_tensor
from torchvision.transforms import InterpolationMode
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from .util.batchsize import find_batch_size
from .util.ensemble import ensemble_depths
from .util.image_util import (
    chw2hwc,
    colorize_depth_maps,
    get_tv_resample_method,
    resize_max_res,
)

# add
import random
# add
# Surface normals ensemble from the GeoWizard repository (https://github.com/fuxiao0719/GeoWizard)
def ensemble_normals(input_images: torch.Tensor):
    normal_preds = input_images
    bsz, d, h, w = normal_preds.shape
    # Normalize each prediction to unit length
    normal_preds = normal_preds / (torch.norm(normal_preds, p=2, dim=1).unsqueeze(1) + 1e-5)
    # Average the predictions in spherical coordinates
    phi = torch.atan2(normal_preds[:, 1, :, :], normal_preds[:, 0, :, :]).mean(dim=0)
    theta = torch.atan2(torch.norm(normal_preds[:, :2, :, :], p=2, dim=1), normal_preds[:, 2, :, :]).mean(dim=0)
    normal_pred = torch.zeros((d, h, w)).to(normal_preds)
    normal_pred[0, :, :] = torch.sin(theta) * torch.cos(phi)
    normal_pred[1, :, :] = torch.sin(theta) * torch.sin(phi)
    normal_pred[2, :, :] = torch.cos(theta)
    # Return the ensemble member with the smallest total angular error to the mean direction
    angle_error = torch.acos(torch.clip(torch.cosine_similarity(normal_pred[None], normal_preds, dim=1), -0.999, 0.999))
    normal_idx = torch.argmin(angle_error.reshape(bsz, -1).sum(-1))
    return normal_preds[normal_idx], None
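
# A minimal sketch of calling ensemble_normals (shapes are illustrative, not from the source):
#   preds = torch.randn(10, 3, 64, 64)        # ensemble of 10 normal maps, [N, 3, H, W]
#   normal, uncert = ensemble_normals(preds)  # normal: [3, H, W]; uncert is always None
# Note that it returns an actual member of the ensemble (the one closest to the spherical
# mean), so the result is a consistent unit-length normal map rather than a blurred average.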
# add
# Pyramid noise from
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2?s=31
def pyramid_noise_like(x, discount=0.9):
    b, c, w, h = x.shape
    u = torch.nn.Upsample(size=(w, h), mode="bilinear")
    noise = torch.randn_like(x)
    for i in range(10):
        r = random.random() * 2 + 2  # random scale factor in [2, 4)
        w, h = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
        # Draw noise at a reduced resolution, upsample to full size, add with geometric decay
        noise += u(torch.randn(b, c, w, h).to(x)) * discount**i
        if w == 1 or h == 1:
            break  # lowest resolution reached
    return noise / noise.std()  # rescale to unit standard deviation
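
# A minimal sketch of the helper above (shapes are illustrative, not from the source):
#   latent = torch.randn(2, 4, 96, 96)   # a [B, 4, h, w] latent
#   n = pyramid_noise_like(latent)       # same shape; multi-resolution noise, std ~ 1
# Unlike torch.randn_like, the result mixes in correlated low-frequency components,
# which the linked report found beneficial for diffusion models.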
class MarigoldDepthOutput(BaseOutput):
    """
    Output class for Marigold monocular depth prediction pipeline.

    Args:
        depth_np (`None` or `np.ndarray`):
            Predicted depth map, with depth values in the range of [0, 1]. None if `normals` is True.
        depth_colored (`None` or `PIL.Image.Image`):
            Colorized depth map, with the shape of [3, H, W] and values in [0, 1]. None if `normals` is True.
        uncertainty (`None` or `np.ndarray`):
            Uncalibrated uncertainty (MAD, median absolute deviation) coming from ensembling.
        normal_np (`None` or `np.ndarray`):
            Predicted normal map, with normal vectors in the range of [-1, 1]. None if `normals` is False.
        normal_colored (`None` or `PIL.Image.Image`):
            Colorized normal map. None if `normals` is False.
    """

    depth_np: Union[None, np.ndarray]
    depth_colored: Union[None, Image.Image]
    uncertainty: Union[None, np.ndarray]
    # add
    normal_np: Union[None, np.ndarray]
    normal_colored: Union[None, Image.Image]
class MarigoldPipeline(DiffusionPipeline):
    """
    Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Args:
        unet (`UNet2DConditionModel`):
            Conditional U-Net to denoise the depth latent, conditioned on image latent.
        vae (`AutoencoderKL`):
            Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps
            to and from latent representations.
        scheduler (`DDIMScheduler`):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        text_encoder (`CLIPTextModel`):
            Text-encoder, for empty text embedding.
        tokenizer (`CLIPTokenizer`):
            CLIP tokenizer.
    """

    rgb_latent_scale_factor = 0.18215
    depth_latent_scale_factor = 0.18215

    def __init__(
        self,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        scheduler: Union[DDIMScheduler, DDPMScheduler, LCMScheduler],
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
    ):
        super().__init__()
        self.register_modules(
            unet=unet,
            vae=vae,
            scheduler=scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )
        self.empty_text_embed = None
    def __call__(
        self,
        input_image: Union[Image.Image, torch.Tensor],
        denoising_steps: int = 10,
        ensemble_size: int = 10,
        processing_res: int = 768,
        match_input_res: bool = True,
        resample_method: str = "bilinear",
        batch_size: int = 0,
        color_map: str = "Spectral",
        show_progress_bar: bool = True,
        ensemble_kwargs: Dict = None,
        # add
        noise: str = "gaussian",
        normals: bool = False,
    ) -> MarigoldDepthOutput:
        """
        Function invoked when calling the pipeline.

        Args:
            input_image (`Image`):
                Input RGB (or gray-scale) image.
            denoising_steps (`int`, *optional*, defaults to `10`):
                Number of diffusion denoising steps (DDIM) during inference.
            ensemble_size (`int`, *optional*, defaults to `10`):
                Number of predictions to be ensembled.
            processing_res (`int`, *optional*, defaults to `768`):
                Maximum resolution of processing.
                If set to 0: will not resize at all.
            match_input_res (`bool`, *optional*, defaults to `True`):
                Resize depth prediction to match input resolution.
                Only valid if `processing_res` > 0.
            resample_method (`str`, *optional*, defaults to `bilinear`):
                Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`.
            batch_size (`int`, *optional*, defaults to `0`):
                Inference batch size, no bigger than `ensemble_size`.
                If set to 0, the script will automatically decide the proper batch size.
            color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
                Colormap used to colorize the depth map.
            show_progress_bar (`bool`, *optional*, defaults to `True`):
                Display a progress bar of diffusion denoising.
            ensemble_kwargs (`dict`, *optional*, defaults to `None`):
                Arguments for detailed ensembling settings.
            noise (`str`, *optional*, defaults to `gaussian`):
                Type of noise to be used for the initial depth map.
                Can be one of `gaussian`, `pyramid`, `zeros`.
            normals (`bool`, *optional*, defaults to `False`):
                If `True`, the pipeline will predict surface normals instead of depth maps.

        Returns:
            `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
            - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
            - **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`
            - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty (MAD, median absolute deviation)
              coming from ensembling. None if `ensemble_size = 1`
            - **normal_np** (`np.ndarray`) Predicted normal map, with normal vectors in the range of [-1, 1]
            - **normal_colored** (`PIL.Image.Image`) Colorized normal map
        """
        assert processing_res >= 0
        assert ensemble_size >= 1

        resample_method: InterpolationMode = get_tv_resample_method(resample_method)

        # ----------------- Image Preprocess -----------------
        # Convert to torch tensor
        if isinstance(input_image, Image.Image):
            input_image = input_image.convert("RGB")
            rgb = pil_to_tensor(input_image)  # [H, W, rgb] -> [rgb, H, W]
        elif isinstance(input_image, torch.Tensor):
            rgb = input_image.squeeze()
        else:
            raise TypeError(f"Unknown input type: {type(input_image) = }")
        input_size = rgb.shape
        assert (
            3 == rgb.dim() and 3 == input_size[0]
        ), f"Wrong input shape {input_size}, expected [rgb, H, W]"

        # Resize image
        if processing_res > 0:
            rgb = resize_max_res(
                rgb,
                max_edge_resolution=processing_res,
                resample_method=resample_method,
            )

        # Normalize rgb values
        rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0  # [0, 255] -> [-1, 1]
        rgb_norm = rgb_norm.to(self.dtype)
        assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0

        # ----------------- Predicting depth/normal -----------------
        # Batch repeated input image
        duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
        single_rgb_dataset = TensorDataset(duplicated_rgb)
        if batch_size > 0:
            _bs = batch_size
        else:
            _bs = find_batch_size(
                ensemble_size=ensemble_size,
                input_res=max(rgb_norm.shape[1:]),
                dtype=self.dtype,
            )
        single_rgb_loader = DataLoader(
            single_rgb_dataset, batch_size=_bs, shuffle=False
        )

        # load iterator
        pred_ls = []
        if show_progress_bar:
            iterable = tqdm(
                single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False
            )
        else:
            iterable = single_rgb_loader

        # inference (batched)
        for batch in iterable:
            (batched_img,) = batch
            pred_raw = self.single_infer(
                rgb_in=batched_img,
                num_inference_steps=denoising_steps,
                show_pbar=show_progress_bar,
                # add
                noise=noise,
                normals=normals,
            )
            pred_ls.append(pred_raw.detach())
        preds = torch.concat(pred_ls, dim=0).squeeze()
        torch.cuda.empty_cache()  # clear vram cache for ensembling

        # ----------------- Test-time ensembling -----------------
        if ensemble_size > 1:  # add
            pred, pred_uncert = (
                ensemble_normals(preds)
                if normals
                else ensemble_depths(preds, **(ensemble_kwargs or {}))
            )
        else:
            pred = preds
            pred_uncert = None

        # ----------------- Post processing -----------------
        if normals:
            # add
            # Normalize normal vectors to unit length
            pred /= torch.norm(pred, p=2, dim=0, keepdim=True) + 1e-5
        else:
            # Scale relative prediction to [0, 1]
            min_d = torch.min(pred)
            max_d = torch.max(pred)
            if max_d == min_d:
                pred = torch.zeros_like(pred)
            else:
                pred = (pred - min_d) / (max_d - min_d)

        # Resize back to original resolution
        if match_input_res:
            pred = resize(
                pred if normals else pred.unsqueeze(0),
                (input_size[-2], input_size[-1]),
                interpolation=resample_method,
                antialias=True,
            ).squeeze()

        # Convert to numpy
        pred = pred.cpu().numpy()

        # Process prediction for visualization
        if not normals:
            # add
            pred = pred.clip(0, 1)
            if color_map is not None:
                colored = colorize_depth_maps(
                    pred, 0, 1, cmap=color_map
                ).squeeze()  # [3, H, W], value in (0, 1)
                colored = (colored * 255).astype(np.uint8)
                colored_hwc = chw2hwc(colored)
                colored_img = Image.fromarray(colored_hwc)
            else:
                colored_img = None
        else:
            pred = pred.clip(-1.0, 1.0)
            colored = (((pred + 1) / 2) * 255).astype(np.uint8)  # [-1, 1] -> [0, 255]
            colored_hwc = chw2hwc(colored)
            colored_img = Image.fromarray(colored_hwc)

        return MarigoldDepthOutput(
            depth_np=pred if not normals else None,
            depth_colored=colored_img if not normals else None,
            uncertainty=pred_uncert,
            # add
            normal_np=pred if normals else None,
            normal_colored=colored_img if normals else None,
        )
    def encode_empty_text(self):
        """
        Encode text embedding for empty prompt.
        """
        prompt = ""
        text_inputs = self.tokenizer(
            prompt,
            padding="do_not_pad",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(self.text_encoder.device)
        self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)
    def single_infer(
        self,
        rgb_in: torch.Tensor,
        num_inference_steps: int,
        show_pbar: bool,
        # add
        noise: str = "gaussian",
        normals: bool = False,
    ) -> torch.Tensor:
        """
        Perform an individual depth prediction without ensembling.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image.
            num_inference_steps (`int`):
                Number of diffusion denoising steps (DDIM) during inference.
            show_pbar (`bool`):
                Display a progress bar of diffusion denoising.
            noise (`str`, *optional*, defaults to `gaussian`):
                Type of noise to be used for the initial depth map.
                Can be one of `gaussian`, `pyramid`, `zeros`.
            normals (`bool`, *optional*, defaults to `False`):
                If `True`, predict a surface normal map instead of a depth map.

        Returns:
            `torch.Tensor`: Predicted depth map (or normal map if `normals` is `True`).
        """
        device = self.device
        rgb_in = rgb_in.to(device)

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps  # [T]

        # Encode image
        rgb_latent = self.encode_rgb(rgb_in)

        # add
        # Initial prediction
        latent_shape = rgb_latent.shape
        if noise == "gaussian":
            latent = torch.randn(
                latent_shape,
                device=device,
                dtype=self.dtype,
            )
        elif noise == "pyramid":
            latent = pyramid_noise_like(rgb_latent).to(device)  # [B, 4, h, w]
        elif noise == "zeros":
            latent = torch.zeros(
                latent_shape,
                device=device,
                dtype=self.dtype,
            )
        else:
            raise ValueError(f"Unknown noise type: {noise}")

        # Batched empty text embedding
        if self.empty_text_embed is None:
            self.encode_empty_text()
        batch_empty_text_embed = self.empty_text_embed.repeat(
            (rgb_latent.shape[0], 1, 1)
        )  # [B, 2, 1024]

        # Denoising loop
        if show_pbar:
            iterable = tqdm(
                enumerate(timesteps),
                total=len(timesteps),
                leave=False,
                desc=" " * 4 + "Diffusion denoising",
            )
        else:
            iterable = enumerate(timesteps)
        for i, t in iterable:
            unet_input = torch.cat(
                [rgb_latent, latent], dim=1
            )  # this order is important

            # predict the noise residual
            noise_pred = self.unet(
                unet_input, t, encoder_hidden_states=batch_empty_text_embed
            ).sample  # [B, 4, h, w]

            # compute the previous noisy sample x_t -> x_t-1
            scheduler_step = self.scheduler.step(
                noise_pred, t, latent
            )
            latent = scheduler_step.prev_sample

        if normals:
            # add
            # decode and normalize normal vectors
            normal = self.decode_normal(latent)
            normal /= torch.norm(normal, p=2, dim=1, keepdim=True) + 1e-5
            return normal
        else:
            # decode and normalize depth map
            depth = self.decode_depth(latent)
            depth = torch.clip(depth, -1.0, 1.0)
            depth = (depth + 1.0) / 2.0  # [-1, 1] -> [0, 1]
            return depth
    def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
        """
        Encode RGB image into latent.

        Args:
            rgb_in (`torch.Tensor`):
                Input RGB image to be encoded.

        Returns:
            `torch.Tensor`: Image latent.
        """
        # encode
        h = self.vae.encoder(rgb_in)
        moments = self.vae.quant_conv(h)
        mean, logvar = torch.chunk(moments, 2, dim=1)
        # scale latent
        rgb_latent = mean * self.rgb_latent_scale_factor
        return rgb_latent
    def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode depth latent into depth map.

        Args:
            depth_latent (`torch.Tensor`):
                Depth latent to be decoded.

        Returns:
            `torch.Tensor`: Decoded depth map.
        """
        # scale latent
        depth_latent = depth_latent / self.depth_latent_scale_factor
        # decode
        z = self.vae.post_quant_conv(depth_latent)
        stacked = self.vae.decoder(z)
        # mean of output channels
        depth_mean = stacked.mean(dim=1, keepdim=True)
        return depth_mean
    # add
    def decode_normal(self, normal_latent: torch.Tensor) -> torch.Tensor:
        """
        Decode normal latent into normal map.

        Args:
            normal_latent (`torch.Tensor`):
                Normal latent to be decoded.

        Returns:
            `torch.Tensor`: Decoded normal map.
        """
        # scale latent
        normal_latent = normal_latent / self.depth_latent_scale_factor
        # decode
        z = self.vae.post_quant_conv(normal_latent)
        normal = self.vae.decoder(z)
        return normal
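

# Minimal usage sketch (an illustration, not part of the original pipeline). It assumes a
# Marigold-style checkpoint compatible with `DiffusionPipeline.from_pretrained`; the
# checkpoint id and file paths below are illustrative. Since this module uses relative
# imports, run it through the package rather than as a standalone script.
if __name__ == "__main__":
    pipe = MarigoldPipeline.from_pretrained("prs-eth/marigold-v1-0")  # assumed checkpoint id
    pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")
    image = Image.open("input.png")  # illustrative input path
    output: MarigoldDepthOutput = pipe(
        image,
        denoising_steps=10,
        ensemble_size=10,
        noise="gaussian",  # or "pyramid" / "zeros"
        normals=False,     # set True to predict surface normals instead of depth
    )
    output.depth_colored.save("depth_colored.png")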