# Source: EscherNet repository, mini_dust3r/post_process.py
# (commit 5ca3a35 by kxhit, "cuda reinit?"; 2.48 kB)
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# utilities for interpreting the DUST3R output
# --------------------------------------------------------
import numpy as np
import torch
from mini_dust3r.utils.geometry import xy_grid
def estimate_focal_knowing_depth(pts3d, pp, focal_mode='median', min_focal=0., max_focal=np.inf, *, n_iters=10):
    """ Estimate one focal length per image from a pointmap whose depth is known.

    Each pixel (u, v), taken relative to the principal point ``pp``, is treated as the
    pinhole projection of its 3D point (x, y, z): u = f * x / z and v = f * y / z.
    Pixels are assumed square, so a single focal ``f`` is fitted for both axes.

    Args:
        pts3d: tensor of shape (B, H, W, 3), one 3D point per pixel.
        pp: principal point(s). It is reshaped to (-1, 1, 2), so it may be a single (x, y)
            pair shared by the whole batch or one pair per image, shape (B, 2).
        focal_mode: 'median' takes the median of the per-pixel votes f = u*z/x and
            f = v*z/y, computed under ``torch.no_grad()``. 'weiszfeld' runs an iteratively
            re-weighted least-squares fit of argmin_f sum_i |pixel_i - f * (x_i, y_i) / z_i|
            and is not wrapped in ``no_grad``.
        min_focal, max_focal: bounds on the result, expressed as multiples of the focal
            of a 60-degree field of view spanning max(H, W).
        n_iters: number of re-weighting iterations in 'weiszfeld' mode (ignored otherwise).
            The default of 10 matches the previous hard-coded value.

    Returns:
        tensor of shape (B,): the estimated focal of each image, in the units of the pixel
        grid returned by ``xy_grid`` (presumably pixels).

    Raises:
        ValueError: if ``focal_mode`` is neither 'median' nor 'weiszfeld'.
    """
    B, H, W, THREE = pts3d.shape
    assert THREE == 3, f'expected pts3d of shape (B, H, W, 3), got {tuple(pts3d.shape)}'

    # pixel grid centered on the principal point
    pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(-1, 1, 2)  # B,HW,2
    pts3d = pts3d.flatten(1, 2)  # (B, HW, 3)

    if focal_mode == 'median':
        with torch.no_grad():
            # direct estimation: every pixel casts one focal vote per axis
            u, v = pixels.unbind(dim=-1)
            x, y, z = pts3d.unbind(dim=-1)
            fx_votes = (u * z) / x
            fy_votes = (v * z) / y
            # square pixels, hence the same focal for X and Y: pool both sets of votes
            f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1)
            # nanmedian skips the NaN votes produced by 0/0 divisions
            focal = torch.nanmedian(f_votes, dim=-1).values
    elif focal_mode == 'weiszfeld':
        # normalized rays (x/z, y/z). Points with z == 0 yield inf/NaN here; zeroing them
        # makes both dot products below 0, so those points add nothing to the fit.
        xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0)
        dot_xy_px = (xy_over_z * pixels).sum(dim=-1)
        dot_xy_xy = xy_over_z.square().sum(dim=-1)
        # initialize with the closed-form (unweighted) least-squares solution
        focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1)
        # Weiszfeld iterations. Re-weighting each residual by the inverse of its current
        # distance moves the fit toward minimizing the sum of unsquared distances, which
        # is less sensitive to outliers than plain least squares.
        for _ in range(n_iters):
            dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1)
            w = dis.clip(min=1e-8).reciprocal()  # clip avoids 1/0 on exact fits
            focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1)
    else:
        raise ValueError(f'bad {focal_mode=}')

    # bounds are relative to the focal of a 60-degree FoV over the longest image side:
    # max(H, W) / (2 * tan(30 deg)) == max(H, W) / 1.1547005383792515
    focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2))
    focal = focal.clip(min=min_focal * focal_base, max=max_focal * focal_base)
    return focal