# mini_dust3r/api/inference.py
import copy
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import numpy as np
import rerun as rr
import torch
import trimesh
from jaxtyping import Bool, Float32
from tqdm import tqdm

from mini_dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from mini_dust3r.cloud_opt.base_opt import BasePCOptimizer
from mini_dust3r.image_pairs import make_pairs
from mini_dust3r.inference import inference, Dust3rResult
from mini_dust3r.model import AsymmetricCroCo3DStereo
from mini_dust3r.utils.image import load_images, ImageDict
from mini_dust3r.viz import pts3d_to_trimesh, cat_meshes

@dataclass
class OptimizedResult:
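    """Outputs of DUSt3R inference plus global alignment for a batch of b
    views: per-view intrinsics and world-from-camera poses, RGB / depth /
    confidence maps, confidence masks, and the fused point cloud and mesh."""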
K_b33: Float32[np.ndarray, "b 3 3"]
world_T_cam_b44: Float32[np.ndarray, "b 4 4"]
rgb_hw3_list: list[Float32[np.ndarray, "h w 3"]]
depth_hw_list: list[Float32[np.ndarray, "h w"]]
conf_hw_list: list[Float32[np.ndarray, "h w"]]
    masks_list: list[Bool[np.ndarray, "h w"]]
point_cloud: trimesh.PointCloud
mesh: trimesh.Trimesh


def log_optimized_result(
    optimized_result: OptimizedResult, parent_log_path: Path
) -> None:
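    """Log the optimized reconstruction (point cloud, mesh, and each camera
    with its pinhole model, RGB, and depth image) to the active rerun
    recording under `parent_log_path`."""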
rr.log(f"{parent_log_path}", rr.ViewCoordinates.RDF, timeless=True)
# log pointcloud
rr.log(
f"{parent_log_path}/pointcloud",
rr.Points3D(
positions=optimized_result.point_cloud.vertices,
colors=optimized_result.point_cloud.colors,
),
timeless=True,
)
mesh = optimized_result.mesh
rr.log(
f"{parent_log_path}/mesh",
rr.Mesh3D(
vertex_positions=mesh.vertices,
vertex_colors=mesh.visual.vertex_colors,
indices=mesh.faces,
),
timeless=True,
)
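    # log each camera: world-space pose, pinhole intrinsics, RGB, and depth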
pbar = tqdm(
zip(
optimized_result.rgb_hw3_list,
optimized_result.depth_hw_list,
optimized_result.K_b33,
optimized_result.world_T_cam_b44,
),
total=len(optimized_result.rgb_hw3_list),
)
for i, (rgb_hw3, depth_hw, k_33, world_T_cam_44) in enumerate(pbar):
camera_log_path = f"{parent_log_path}/camera_{i}"
height, width, _ = rgb_hw3.shape
rr.log(
f"{camera_log_path}",
rr.Transform3D(
translation=world_T_cam_44[:3, 3],
mat3x3=world_T_cam_44[:3, :3],
from_parent=False,
),
)
rr.log(
f"{camera_log_path}/pinhole",
rr.Pinhole(
image_from_camera=k_33,
height=height,
width=width,
camera_xyz=rr.ViewCoordinates.RDF,
),
)
rr.log(
f"{camera_log_path}/pinhole/rgb",
rr.Image(rgb_hw3),
)
rr.log(
f"{camera_log_path}/pinhole/depth",
rr.DepthImage(depth_hw),
)


def scene_to_results(scene: BasePCOptimizer, min_conf_thr: float) -> OptimizedResult:
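    """Extract intrinsics, poses, per-view RGB/depth/confidence maps, and
    confidence masks from an optimized scene, then fuse the masked points
    into a colored point cloud and mesh."""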
### get camera parameters K and T
K_b33: Float32[np.ndarray, "b 3 3"] = scene.get_intrinsics().numpy(force=True)
world_T_cam_b44: Float32[np.ndarray, "b 4 4"] = scene.get_im_poses().numpy(
force=True
)
### image, confidence, depths
rgb_hw3_list: list[Float32[np.ndarray, "h w 3"]] = scene.imgs
depth_hw_list: list[Float32[np.ndarray, "h w"]] = [
depth.numpy(force=True) for depth in scene.get_depthmaps()
]
# normalized depth
# depth_hw_list = [depth_hw / depth_hw.max() for depth_hw in depth_hw_list]
conf_hw_list: list[Float32[np.ndarray, "h w"]] = [
c.numpy(force=True) for c in scene.im_conf
]
# normalize confidence
# conf_hw_list = [conf_hw / conf_hw.max() for conf_hw in conf_hw_list]
# point cloud, mesh
pts3d_list: list[Float32[np.ndarray, "h w 3"]] = [
pt3d.numpy(force=True) for pt3d in scene.get_pts3d()
]
    # map the raw threshold through the scene's confidence transform
    # (log by default) so it matches the space used by get_masks()
    log_conf_trf: Float32[torch.Tensor, ""] = scene.conf_trf(torch.tensor(min_conf_thr))
    scene.min_conf_thr = float(log_conf_trf)
    masks_list: list[Bool[np.ndarray, "h w"]] = [
mask.numpy(force=True) for mask in scene.get_masks()
]
    points_n3: Float32[np.ndarray, "num_points 3"] = np.concatenate(
        [p[m] for p, m in zip(pts3d_list, masks_list)]
    )
    colors_n3: Float32[np.ndarray, "num_points 3"] = np.concatenate(
        [rgb[m] for rgb, m in zip(rgb_hw3_list, masks_list)]
    )
    point_cloud = trimesh.PointCloud(
        points_n3.reshape(-1, 3), colors=colors_n3.reshape(-1, 3)
    )
meshes = []
pbar = tqdm(zip(rgb_hw3_list, pts3d_list, masks_list), total=len(rgb_hw3_list))
for rgb_hw3, pts3d, mask in pbar:
meshes.append(pts3d_to_trimesh(rgb_hw3, pts3d, mask))
mesh = trimesh.Trimesh(**cat_meshes(meshes))
    optimized_result = OptimizedResult(
K_b33=K_b33,
world_T_cam_b44=world_T_cam_b44,
rgb_hw3_list=rgb_hw3_list,
depth_hw_list=depth_hw_list,
conf_hw_list=conf_hw_list,
masks_list=masks_list,
point_cloud=point_cloud,
mesh=mesh,
)
    return optimized_result


def inference_dust3r(
    image_dir_or_list: Path | list[Path],
    model: AsymmetricCroCo3DStereo,
    device: Literal["cpu", "cuda", "mps"],
    batch_size: int = 1,
    image_size: Literal[224, 512] = 512,
    niter: int = 100,
    schedule: Literal["linear", "cosine"] = "linear",
    min_conf_thr: float = 10.0,
) -> OptimizedResult:
"""
Perform inference using the Dust3r algorithm.
Args:
image_dir_or_list (Union[Path, List[Path]]): Path to the directory containing images or a list of image paths.
model (AsymmetricCroCo3DStereo): The Dust3r model to use for inference.
device (Literal["cpu", "cuda", "mps"]): The device to use for inference ("cpu", "cuda", or "mps").
batch_size (int, optional): The batch size for inference. Defaults to 1.
image_size (Literal[224, 512], optional): The size of the input images. Defaults to 512.
niter (int, optional): The number of iterations for the global alignment optimization. Defaults to 100.
schedule (Literal["linear", "cosine"], optional): The learning rate schedule for the global alignment optimization. Defaults to "linear".
min_conf_thr (float, optional): The minimum confidence threshold for the optimized result. Defaults to 10.
Returns:
OptimizedResult: The optimized result containing the RGB, depth, and confidence images.
Raises:
ValueError: If `image_dir_or_list` is neither a list of paths nor a path.
"""
if isinstance(image_dir_or_list, list):
imgs: list[ImageDict] = load_images(
folder_or_list=image_dir_or_list, size=image_size, verbose=True
)
elif isinstance(image_dir_or_list, Path):
imgs: list[ImageDict] = load_images(
folder_or_list=str(image_dir_or_list), size=image_size, verbose=True
)
else:
raise ValueError("image_dir_or_list should be a list of paths or a path")
# if only one image was loaded, duplicate it to feed into stereo network
if len(imgs) == 1:
imgs = [imgs[0], copy.deepcopy(imgs[0])]
imgs[1]["idx"] = 1
pairs: list[tuple[ImageDict, ImageDict]] = make_pairs(
imgs, scene_graph="complete", prefilter=None, symmetrize=True
)
output: Dust3rResult = inference(pairs, model, device, batch_size=batch_size)
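    # a single stereo pair can be read out directly via PairViewer;
    # three or more views require the full point-cloud optimizer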
mode = (
GlobalAlignerMode.PointCloudOptimizer
if len(imgs) > 2
else GlobalAlignerMode.PairViewer
)
scene: BasePCOptimizer = global_aligner(
dust3r_output=output, device=device, mode=mode
)
    lr = 0.01
    if mode == GlobalAlignerMode.PointCloudOptimizer:
        scene.compute_global_alignment(
            init="mst", niter=niter, schedule=schedule, lr=lr
        )
# get the optimized result from the scene
optimized_result: OptimizedResult = scene_to_results(scene, min_conf_thr)
return optimized_result
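

# Example usage sketch: assumes an `images/` directory of input photos and
# that the DUSt3R ViT-L/512 checkpoint named below is reachable on the
# Hugging Face Hub; adjust both to your setup.
if __name__ == "__main__":
    # spawn a rerun viewer and stream the reconstruction into it
    rr.init("mini_dust3r_inference", spawn=True)
    model = AsymmetricCroCo3DStereo.from_pretrained(
        "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"  # assumed checkpoint name
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    optimized_result = inference_dust3r(
        image_dir_or_list=Path("images"),  # hypothetical input directory
        model=model,
        device=device,
        batch_size=1,
    )
    log_optimized_result(optimized_result, parent_log_path=Path("world"))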