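"""Inference utilities for mini-dust3r: run DUSt3R on a set of images,
globally align the pairwise predictions, and convert the optimized scene
into camera intrinsics/poses, depth maps, a fused point cloud, and a mesh,
with helpers for logging everything to the rerun viewer."""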
import copy
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import numpy as np
import rerun as rr
import torch
import trimesh
from jaxtyping import Bool, Float32
from tqdm import tqdm

from mini_dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from mini_dust3r.cloud_opt.base_opt import BasePCOptimizer
from mini_dust3r.image_pairs import make_pairs
from mini_dust3r.inference import inference, Dust3rResult
from mini_dust3r.model import AsymmetricCroCo3DStereo
from mini_dust3r.utils.image import load_images, ImageDict
from mini_dust3r.viz import pts3d_to_trimesh, cat_meshes
@dataclass
class OptimizedResult:
    K_b33: Float32[np.ndarray, "b 3 3"]
    world_T_cam_b44: Float32[np.ndarray, "b 4 4"]
    rgb_hw3_list: list[Float32[np.ndarray, "h w 3"]]
    depth_hw_list: list[Float32[np.ndarray, "h w"]]
    conf_hw_list: list[Float32[np.ndarray, "h w"]]
    masks_list: list[Bool[np.ndarray, "h w"]]
    point_cloud: trimesh.PointCloud
    mesh: trimesh.Trimesh
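# Field names follow a shape-suffix convention (a reading aid, nothing more):
# `K_b33` is a batch of b 3x3 intrinsics, `world_T_cam_b44` a batch of b 4x4
# camera-to-world transforms. An illustrative sketch (`result` is a
# hypothetical OptimizedResult instance) of how a camera-frame point maps
# into the world frame:
#
#   cam_xyz1 = np.array([0.0, 0.0, 1.0, 1.0])  # homogeneous point 1 m ahead of camera 0
#   world_xyz1 = result.world_T_cam_b44[0] @ cam_xyz1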
def log_optimized_result(
    optimized_result: OptimizedResult, parent_log_path: Path
) -> None:
    """Log the fused point cloud, mesh, and per-camera pose/RGB/depth to rerun."""
    rr.log(f"{parent_log_path}", rr.ViewCoordinates.RDF, timeless=True)
    # log pointcloud
    rr.log(
        f"{parent_log_path}/pointcloud",
        rr.Points3D(
            positions=optimized_result.point_cloud.vertices,
            colors=optimized_result.point_cloud.colors,
        ),
        timeless=True,
    )
    mesh = optimized_result.mesh
    rr.log(
        f"{parent_log_path}/mesh",
        rr.Mesh3D(
            vertex_positions=mesh.vertices,
            vertex_colors=mesh.visual.vertex_colors,
            indices=mesh.faces,
        ),
        timeless=True,
    )
    pbar = tqdm(
        zip(
            optimized_result.rgb_hw3_list,
            optimized_result.depth_hw_list,
            optimized_result.K_b33,
            optimized_result.world_T_cam_b44,
        ),
        total=len(optimized_result.rgb_hw3_list),
    )
    for i, (rgb_hw3, depth_hw, k_33, world_T_cam_44) in enumerate(pbar):
        camera_log_path = f"{parent_log_path}/camera_{i}"
        height, width, _ = rgb_hw3.shape
        # pose of the camera in the world frame (camera-to-world, hence from_parent=False)
        rr.log(
            f"{camera_log_path}",
            rr.Transform3D(
                translation=world_T_cam_44[:3, 3],
                mat3x3=world_T_cam_44[:3, :3],
                from_parent=False,
            ),
        )
        rr.log(
            f"{camera_log_path}/pinhole",
            rr.Pinhole(
                image_from_camera=k_33,
                height=height,
                width=width,
                camera_xyz=rr.ViewCoordinates.RDF,
            ),
        )
        rr.log(
            f"{camera_log_path}/pinhole/rgb",
            rr.Image(rgb_hw3),
        )
        rr.log(
            f"{camera_log_path}/pinhole/depth",
            rr.DepthImage(depth_hw),
        )
def scene_to_results(scene: BasePCOptimizer, min_conf_thr: float) -> OptimizedResult:
    """Convert an optimized scene into an OptimizedResult: camera intrinsics
    and poses, RGB/depth/confidence maps, and a fused point cloud and mesh
    masked by `min_conf_thr`."""
    ### get camera parameters K and T
    K_b33: Float32[np.ndarray, "b 3 3"] = scene.get_intrinsics().numpy(force=True)
    world_T_cam_b44: Float32[np.ndarray, "b 4 4"] = scene.get_im_poses().numpy(
        force=True
    )
    ### image, confidence, depths
    rgb_hw3_list: list[Float32[np.ndarray, "h w 3"]] = scene.imgs
    depth_hw_list: list[Float32[np.ndarray, "h w"]] = [
        depth.numpy(force=True) for depth in scene.get_depthmaps()
    ]
    # optional: normalize depth
    # depth_hw_list = [depth_hw / depth_hw.max() for depth_hw in depth_hw_list]
    conf_hw_list: list[Float32[np.ndarray, "h w"]] = [
        c.numpy(force=True) for c in scene.im_conf
    ]
    # optional: normalize confidence
    # conf_hw_list = [conf_hw / conf_hw.max() for conf_hw in conf_hw_list]
    ### point cloud and mesh
    pts3d_list: list[Float32[np.ndarray, "h w 3"]] = [
        pt3d.numpy(force=True) for pt3d in scene.get_pts3d()
    ]
    # map the threshold through the scene's confidence transform (log by default)
    log_conf_trf: Float32[torch.Tensor, ""] = scene.conf_trf(torch.tensor(min_conf_thr))
    # set the minimum confidence threshold used by get_masks()
    scene.min_conf_thr = float(log_conf_trf)
    masks_list: list[Bool[np.ndarray, "h w"]] = [
        mask.numpy(force=True) for mask in scene.get_masks()
    ]
    points_n3: Float32[np.ndarray, "num_points 3"] = np.concatenate(
        [p[m] for p, m in zip(pts3d_list, masks_list)]
    )
    colors_n3: Float32[np.ndarray, "num_points 3"] = np.concatenate(
        [p[m] for p, m in zip(rgb_hw3_list, masks_list)]
    )
    point_cloud = trimesh.PointCloud(
        points_n3.reshape(-1, 3), colors=colors_n3.reshape(-1, 3)
    )
    meshes = []
    pbar = tqdm(zip(rgb_hw3_list, pts3d_list, masks_list), total=len(rgb_hw3_list))
    for rgb_hw3, pts3d, mask in pbar:
        meshes.append(pts3d_to_trimesh(rgb_hw3, pts3d, mask))
    mesh = trimesh.Trimesh(**cat_meshes(meshes))
    optimized_result = OptimizedResult(
        K_b33=K_b33,
        world_T_cam_b44=world_T_cam_b44,
        rgb_hw3_list=rgb_hw3_list,
        depth_hw_list=depth_hw_list,
        conf_hw_list=conf_hw_list,
        masks_list=masks_list,
        point_cloud=point_cloud,
        mesh=mesh,
    )
    return optimized_result
def inferece_dust3r(
    image_dir_or_list: Path | list[Path],
    model: AsymmetricCroCo3DStereo,
    device: Literal["cpu", "cuda", "mps"],
    batch_size: int = 1,
    image_size: Literal[224, 512] = 512,
    niter: int = 100,
    schedule: Literal["linear", "cosine"] = "linear",
    min_conf_thr: float = 10,
) -> OptimizedResult:
    """
    Perform inference using the Dust3r algorithm.

    Args:
        image_dir_or_list (Path | list[Path]): Path to a directory of images, or a list of image paths.
        model (AsymmetricCroCo3DStereo): The Dust3r model to use for inference.
        device (Literal["cpu", "cuda", "mps"]): The device to run inference on.
        batch_size (int, optional): The batch size for inference. Defaults to 1.
        image_size (Literal[224, 512], optional): The size the input images are resized to. Defaults to 512.
        niter (int, optional): The number of iterations for the global alignment optimization. Defaults to 100.
        schedule (Literal["linear", "cosine"], optional): The learning rate schedule for the global alignment optimization. Defaults to "linear".
        min_conf_thr (float, optional): The minimum confidence threshold used to mask the optimized result. Defaults to 10.

    Returns:
        OptimizedResult: The optimized result containing camera parameters, the RGB, depth, and confidence images, and the fused point cloud and mesh.

    Raises:
        ValueError: If `image_dir_or_list` is neither a list of paths nor a path.
    """
    if isinstance(image_dir_or_list, list):
        imgs: list[ImageDict] = load_images(
            folder_or_list=image_dir_or_list, size=image_size, verbose=True
        )
    elif isinstance(image_dir_or_list, Path):
        imgs: list[ImageDict] = load_images(
            folder_or_list=str(image_dir_or_list), size=image_size, verbose=True
        )
    else:
        raise ValueError("image_dir_or_list should be a list of paths or a path")

    # if only one image was loaded, duplicate it to feed into the stereo network
    if len(imgs) == 1:
        imgs = [imgs[0], copy.deepcopy(imgs[0])]
        imgs[1]["idx"] = 1

    pairs: list[tuple[ImageDict, ImageDict]] = make_pairs(
        imgs, scene_graph="complete", prefilter=None, symmetrize=True
    )
    output: Dust3rResult = inference(pairs, model, device, batch_size=batch_size)

    # a single pair can be aligned directly; three or more views need global optimization
    mode = (
        GlobalAlignerMode.PointCloudOptimizer
        if len(imgs) > 2
        else GlobalAlignerMode.PairViewer
    )
    scene: BasePCOptimizer = global_aligner(
        dust3r_output=output, device=device, mode=mode
    )

    lr = 0.01
    if mode == GlobalAlignerMode.PointCloudOptimizer:
        scene.compute_global_alignment(
            init="mst", niter=niter, schedule=schedule, lr=lr
        )

    # extract the optimized result from the scene
    optimized_result: OptimizedResult = scene_to_results(scene, min_conf_thr)
    return optimized_result
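

# ---------------------------------------------------------------------------
# Usage sketch: run the full pipeline on a folder of images and stream the
# result to the rerun viewer. The checkpoint id below is an assumption taken
# from the mini-dust3r README; the image folder is hypothetical — substitute
# your own paths and checkpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AsymmetricCroCo3DStereo.from_pretrained(
        "nielsr/DUSt3R_ViTLarge_BaseDecoder_512_dpt"  # assumed checkpoint id
    ).to(device)
    rr.init("mini-dust3r", spawn=True)  # spawn a local rerun viewer
    optimized_result = inferece_dust3r(
        image_dir_or_list=Path("data/my_images"),  # hypothetical image folder
        model=model,
        device=device,
        batch_size=1,
    )
    log_optimized_result(optimized_result, Path("world"))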