Spaces: Running on Zero
kxhit committed
Commit • 5ca3a35
1 Parent(s): d161cfd
cuda reinit?
- app.py +2 -2
- app_bk.py +786 -0
- mini_dust3r/__init__.py +0 -0
- mini_dust3r/api/__init__.py +3 -0
- mini_dust3r/api/inference.py +225 -0
- mini_dust3r/cloud_opt/__init__.py +44 -0
- mini_dust3r/cloud_opt/base_opt.py +390 -0
- mini_dust3r/cloud_opt/commons.py +90 -0
- mini_dust3r/cloud_opt/init_im_poses.py +316 -0
- mini_dust3r/cloud_opt/modular_optimizer.py +145 -0
- mini_dust3r/cloud_opt/optimizer.py +248 -0
- mini_dust3r/cloud_opt/pair_viewer.py +127 -0
- mini_dust3r/croco/blocks.py +241 -0
- mini_dust3r/croco/croco.py +249 -0
- mini_dust3r/croco/dpt_block.py +450 -0
- mini_dust3r/croco/masking.py +25 -0
- mini_dust3r/croco/pos_embed.py +159 -0
- mini_dust3r/heads/__init__.py +19 -0
- mini_dust3r/heads/dpt_head.py +114 -0
- mini_dust3r/heads/linear_head.py +41 -0
- mini_dust3r/heads/postprocess.py +58 -0
- mini_dust3r/image_pairs.py +85 -0
- mini_dust3r/inference.py +204 -0
- mini_dust3r/model.py +259 -0
- mini_dust3r/optim_factory.py +14 -0
- mini_dust3r/patch_embed.py +69 -0
- mini_dust3r/post_process.py +60 -0
- mini_dust3r/utils/device.py +76 -0
- mini_dust3r/utils/geometry.py +361 -0
- mini_dust3r/utils/image.py +141 -0
- mini_dust3r/utils/misc.py +121 -0
- mini_dust3r/viz.py +320 -0
app.py
CHANGED
@@ -268,7 +268,7 @@ from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
 from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
 import math
 
-@spaces.GPU(duration=120)
+# @spaces.GPU(duration=120)
 def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
                                  cam_color=None, as_pointcloud=False,
                                  transparent_cams=False, silent=False, same_focals=False):
@@ -321,7 +321,7 @@ def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world,
     scene.export(file_obj=outfile)
     return outfile
 
-@spaces.GPU(duration=120)
+# @spaces.GPU(duration=120)
 def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
                             clean_depth=False, transparent_cams=False, cam_size=0.05, same_focals=False):
     """
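The only functional change in app.py is that these two DUSt3R visualisation helpers no longer request a ZeroGPU slot of their own. For orientation, a minimal sketch of the decorator pattern involved, assuming the standard `spaces` package available on Hugging Face ZeroGPU Spaces (the function below is illustrative, not from this repo):

import spaces
import torch

@spaces.GPU(duration=120)  # a GPU is attached only while this call runs
def gpu_step(x: torch.Tensor) -> torch.Tensor:
    # on ZeroGPU, CUDA work is expected to happen inside the decorated call,
    # not at import time in the main process
    return (x.to("cuda") * 2).cpu()

Dropping the decorator from pure post-processing helpers is one common workaround when nested decorated calls hit CUDA re-initialisation errors in the forked GPU worker, which the commit message "cuda reinit?" appears to be probing.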
app_bk.py
ADDED
@@ -0,0 +1,786 @@
import spaces
import torch
print("cuda is available: ", torch.cuda.is_available())

import gradio as gr
import os
import shutil
import rembg
import numpy as np
import math
import open3d as o3d
from PIL import Image
import torchvision
import trimesh
from skimage.io import imsave
import imageio
import cv2
import matplotlib.pyplot as pl
pl.ion()

CaPE_TYPE = "6DoF"
device = 'cuda'  # if torch.cuda.is_available() else 'cpu'
weight_dtype = torch.float16
torch.backends.cuda.matmul.allow_tf32 = True  # for gpu >= Ampere and pytorch >= 1.12

# EscherNet
# create angles in archimedean spiral with N steps
def get_archimedean_spiral(sphere_radius, num_steps=250):
    # x-z plane, around upper y
    '''
    https://en.wikipedia.org/wiki/Spiral, section "Spherical spiral". c = a / pi
    '''
    a = 40
    r = sphere_radius

    translations = []
    angles = []

    # i = a / 2
    i = 0.01
    while i < a:
        theta = i / a * math.pi
        x = r * math.sin(theta) * math.cos(-i)
        z = r * math.sin(-theta + math.pi) * math.sin(-i)
        y = r * - math.cos(theta)

        # translations.append((x, y, z))  # origin
        translations.append((x, z, -y))
        angles.append([np.rad2deg(-i), np.rad2deg(theta)])

        # i += a / (2 * num_steps)
        i += a / (1 * num_steps)

    return np.array(translations), np.stack(angles)

def look_at(origin, target, up):
    forward = (target - origin)
    forward = forward / np.linalg.norm(forward)
    right = np.cross(up, forward)
    right = right / np.linalg.norm(right)
    new_up = np.cross(forward, right)
    rotation_matrix = np.column_stack((right, new_up, -forward, target))
    matrix = np.row_stack((rotation_matrix, [0, 0, 0, 1]))
    return matrix

import einops
import sys

sys.path.insert(0, "./6DoF/")  # TODO change it when deploying
# use the customized diffusers modules
from diffusers import DDIMScheduler
from dataset import get_pose
from CN_encoder import CN_encoder
from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline
from segment_anything import sam_model_registry, SamPredictor

# import rembg
from carvekit.api.high import HiInterface


pretrained_model_name_or_path = "kxic/EscherNet_demo"
resolution = 256
h, w = resolution, resolution
guidance_scale = 3.0
radius = 2.2
bg_color = [1., 1., 1., 1.]
image_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((resolution, resolution)),  # 256, 256
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.5], [0.5])
    ]
)
xyzs_spiral, angles_spiral = get_archimedean_spiral(1.5, 200)
# only half toop
xyzs_spiral = xyzs_spiral[:100]
angles_spiral = angles_spiral[:100]

# Init pipeline
scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler", revision=None)
image_encoder = CN_encoder.from_pretrained(pretrained_model_name_or_path, subfolder="image_encoder", revision=None)
pipeline = Zero1to3StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path,
    revision=None,
    scheduler=scheduler,
    image_encoder=None,
    safety_checker=None,
    feature_extractor=None,
    torch_dtype=weight_dtype,
)
pipeline.image_encoder = image_encoder.to(weight_dtype)

pipeline.set_progress_bar_config(disable=False)

pipeline = pipeline.to(device)

# pipeline.enable_xformers_memory_efficient_attention()
# enable vae slicing
pipeline.enable_vae_slicing()
# pipeline.enable_xformers_memory_efficient_attention()


#### object segmentation
def sam_init():
    sam_checkpoint = os.path.join("./sam_pt/sam_vit_h_4b8939.pth")
    if os.path.exists(sam_checkpoint) is False:
        os.system("wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P ./sam_pt/")
    model_type = "vit_h"

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)
    predictor = SamPredictor(sam)
    return predictor

def create_carvekit_interface():
    # Check doc strings for more information
    interface = HiInterface(object_type="object",  # Can be "object" or "hairs-like".
                            batch_size_seg=6,
                            batch_size_matting=1,
                            device="cpu",
                            seg_mask_size=640,  # Use 640 for Tracer B7 and 320 for U2Net
                            matting_mask_size=2048,
                            trimap_prob_threshold=231,
                            trimap_dilation=30,
                            trimap_erosion_iters=5,
                            fp16=True)

    return interface


# rembg_session = rembg.new_session()
rembg_session = create_carvekit_interface()
predictor = sam_init()



@spaces.GPU(duration=120)
def run_eschernet(eschernet_input_dict, sample_steps, sample_seed, nvs_num, nvs_mode):
    # set the random seed
    generator = torch.Generator(device=device).manual_seed(sample_seed)
    # generator = None
    T_out = nvs_num
    T_in = len(eschernet_input_dict['imgs'])
    ####### output pose
    # TODO choose T_out number of poses sequentially from the spiral
    xyzs = xyzs_spiral[::(len(xyzs_spiral) // T_out)]
    angles_out = angles_spiral[::(len(xyzs_spiral) // T_out)]

    ####### input's max radius for translation scaling
    radii = eschernet_input_dict['radii']
    max_t = np.max(radii)
    min_t = np.min(radii)

    ####### input pose
    pose_in = []
    for T_in_index in range(T_in):
        pose = get_pose(np.linalg.inv(eschernet_input_dict['poses'][T_in_index]))
        pose[1:3, :] *= -1  # coordinate system conversion
        pose[3, 3] *= 1. / max_t * radius  # scale radius to [1.5, 2.2]
        pose_in.append(torch.from_numpy(pose))

    ####### input image
    img = eschernet_input_dict['imgs'] / 255.
    img[img[:, :, :, -1] == 0.] = bg_color
    # TODO batch image_transforms
    input_image = [image_transforms(Image.fromarray(np.uint8(im[:, :, :3] * 255.)).convert("RGB")) for im in img]

    ####### nvs pose
    pose_out = []
    for T_out_index in range(T_out):
        azimuth, polar = angles_out[T_out_index]
        if CaPE_TYPE == "4DoF":
            pose_out.append(torch.tensor([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.]))
        elif CaPE_TYPE == "6DoF":
            pose = look_at(origin=np.array([0, 0, 0]), target=xyzs[T_out_index], up=np.array([0, 0, 1]))
            pose = np.linalg.inv(pose)
            pose[2, :] *= -1
            pose_out.append(torch.from_numpy(get_pose(pose)))



    # [B, T, C, H, W]
    input_image = torch.stack(input_image, dim=0).to(device).to(weight_dtype).unsqueeze(0)
    # [B, T, 4]
    pose_in = np.stack(pose_in)
    pose_out = np.stack(pose_out)

    if CaPE_TYPE == "6DoF":
        pose_in_inv = np.linalg.inv(pose_in).transpose([0, 2, 1])
        pose_out_inv = np.linalg.inv(pose_out).transpose([0, 2, 1])
        pose_in_inv = torch.from_numpy(pose_in_inv).to(device).to(weight_dtype).unsqueeze(0)
        pose_out_inv = torch.from_numpy(pose_out_inv).to(device).to(weight_dtype).unsqueeze(0)

    pose_in = torch.from_numpy(pose_in).to(device).to(weight_dtype).unsqueeze(0)
    pose_out = torch.from_numpy(pose_out).to(device).to(weight_dtype).unsqueeze(0)

    input_image = einops.rearrange(input_image, "b t c h w -> (b t) c h w")
    assert T_in == input_image.shape[0]
    assert T_in == pose_in.shape[1]
    assert T_out == pose_out.shape[1]

    # run inference
    # pipeline.to(device)
    pipeline.enable_xformers_memory_efficient_attention()
    image = pipeline(input_imgs=input_image, prompt_imgs=input_image,
                     poses=[[pose_out, pose_out_inv], [pose_in, pose_in_inv]],
                     height=h, width=w, T_in=T_in, T_out=T_out,
                     guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
                     output_type="numpy").images

    # save output image
    output_dir = os.path.join(tmpdirname, "eschernet")
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir, exist_ok=True)
    # # save to N imgs
    # for i in range(T_out):
    #     imsave(os.path.join(output_dir, f'{i}.png'), (image[i] * 255).astype(np.uint8))
    # make a gif
    frames = [Image.fromarray((image[i] * 255).astype(np.uint8)) for i in range(T_out)]
    # frame_one = frames[0]
    # frame_one.save(os.path.join(output_dir, "output.gif"), format="GIF", append_images=frames,
    #                save_all=True, duration=50, loop=1)

    # get a video
    video_path = os.path.join(output_dir, "output.mp4")
    imageio.mimwrite(video_path, np.stack(frames), fps=10, codec='h264')


    return video_path

# TODO mesh it
@spaces.GPU(duration=120)
def make3d():
    pass



############################ Dust3r as Pose Estimation ############################
from scipy.spatial.transform import Rotation
import copy

from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.image_pairs import make_pairs
from dust3r.utils.image import load_images, rgb
from dust3r.utils.device import to_numpy
from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
import math

@spaces.GPU(duration=120)
def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
                                 cam_color=None, as_pointcloud=False,
                                 transparent_cams=False, silent=False, same_focals=False):
    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world)
    if not same_focals:
        assert (len(cams2world) == len(focals))
    pts3d = to_numpy(pts3d)
    imgs = to_numpy(imgs)
    focals = to_numpy(focals)
    cams2world = to_numpy(cams2world)

    scene = trimesh.Scene()

    # add axes
    scene.add_geometry(trimesh.creation.axis(axis_length=0.5, axis_radius=0.001))

    # full pointcloud
    if as_pointcloud:
        pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
        col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
        pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
        scene.add_geometry(pct)
    else:
        meshes = []
        for i in range(len(imgs)):
            meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]))
        mesh = trimesh.Trimesh(**cat_meshes(meshes))
        scene.add_geometry(mesh)

    # add each camera
    for i, pose_c2w in enumerate(cams2world):
        if isinstance(cam_color, list):
            camera_edge_color = cam_color[i]
        else:
            camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
        if same_focals:
            focal = focals[0]
        else:
            focal = focals[i]
        add_scene_cam(scene, pose_c2w, camera_edge_color,
                      None if transparent_cams else imgs[i], focal,
                      imsize=imgs[i].shape[1::-1], screen_width=cam_size)

    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
    outfile = os.path.join(outdir, 'scene.glb')
    if not silent:
        print('(exporting 3D scene to', outfile, ')')
    scene.export(file_obj=outfile)
    return outfile

@spaces.GPU(duration=120)
def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
                            clean_depth=False, transparent_cams=False, cam_size=0.05, same_focals=False):
    """
    extract 3D_model (glb file) from a reconstructed scene
    """
    if scene is None:
        return None
    # post processes
    if clean_depth:
        scene = scene.clean_pointcloud()
    if mask_sky:
        scene = scene.mask_sky()

    # get optimized values from scene
    rgbimg = to_numpy(scene.imgs)
    focals = to_numpy(scene.get_focals().cpu())
    # cams2world = to_numpy(scene.get_im_poses().cpu())
    # TODO use the vis_poses
    cams2world = scene.vis_poses

    # 3D pointcloud from depthmap, poses and intrinsics
    # pts3d = to_numpy(scene.get_pts3d())
    # TODO use the vis_poses
    pts3d = scene.vis_pts3d
    scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
    msk = to_numpy(scene.get_masks())

    return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
                                        transparent_cams=transparent_cams, cam_size=cam_size, silent=silent,
                                        same_focals=same_focals)

@spaces.GPU(duration=120)
def get_reconstructed_scene(filelist, schedule, niter, min_conf_thr,
                            as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
                            scenegraph_type, winsize, refid, same_focals):
    """
    from a list of images, run dust3r inference, global aligner.
    then run get_3D_model_from_scene
    """
    silent = False
    image_size = 224
    # remove the directory if it already exists
    outdir = tmpdirname
    if os.path.exists(outdir):
        shutil.rmtree(outdir)
    os.makedirs(outdir, exist_ok=True)
    imgs, imgs_rgba = load_images(filelist, size=image_size, verbose=not silent, do_remove_background=True, rembg_session=rembg_session, predictor=predictor)
    if len(imgs) == 1:
        imgs = [imgs[0], copy.deepcopy(imgs[0])]
        imgs[1]['idx'] = 1
    if scenegraph_type == "swin":
        scenegraph_type = scenegraph_type + "-" + str(winsize)
    elif scenegraph_type == "oneref":
        scenegraph_type = scenegraph_type + "-" + str(refid)

    pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
    output = inference(pairs, model, device, batch_size=1, verbose=not silent)

    mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
    scene = global_aligner(output, device=device, mode=mode, verbose=not silent, same_focals=same_focals)
    lr = 0.01

    if mode == GlobalAlignerMode.PointCloudOptimizer:
        loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)

    # outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
    #                                   clean_depth, transparent_cams, cam_size, same_focals=same_focals)

    # also return rgb, depth and confidence imgs
    # depth is normalized with the max value for all images
    # we apply the jet colormap on the confidence maps
    rgbimg = scene.imgs
    # depths = to_numpy(scene.get_depthmaps())
    # confs = to_numpy([c for c in scene.im_conf])
    # cmap = pl.get_cmap('jet')
    # depths_max = max([d.max() for d in depths])
    # depths = [d / depths_max for d in depths]
    # confs_max = max([d.max() for d in confs])
    # confs = [cmap(d / confs_max) for d in confs]

    imgs = []
    rgbaimg = []
    for i in range(len(rgbimg)):  # when only 1 image, scene.imgs is two
        imgs.append(rgbimg[i])
        # imgs.append(rgb(depths[i]))
        # imgs.append(rgb(confs[i]))
        # imgs.append(imgs_rgba[i])
        if len(imgs_rgba) == 1 and i == 1:
            imgs.append(imgs_rgba[0])
            rgbaimg.append(np.array(imgs_rgba[0]))
        else:
            imgs.append(imgs_rgba[i])
            rgbaimg.append(np.array(imgs_rgba[i]))

    rgbaimg = np.array(rgbaimg)

    # for eschernet
    # get optimized values from scene
    rgbimg = to_numpy(scene.imgs)
    # focals = to_numpy(scene.get_focals().cpu())
    cams2world = to_numpy(scene.get_im_poses().cpu())

    # 3D pointcloud from depthmap, poses and intrinsics
    pts3d = to_numpy(scene.get_pts3d())
    scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
    msk = to_numpy(scene.get_masks())
    obj_mask = rgbaimg[..., 3] > 0

    # TODO set global coordinate system at the center of the scene, z-axis is up
    pts = np.concatenate([p[m] for p, m in zip(pts3d, msk)]).reshape(-1, 3)
    pts_obj = np.concatenate([p[m & obj_m] for p, m, obj_m in zip(pts3d, msk, obj_mask)]).reshape(-1, 3)
    centroid = np.mean(pts_obj, axis=0)  # obj center
    obj2world = np.eye(4)
    obj2world[:3, 3] = -centroid  # T_wc

    # get z_up vector
    # TODO fit a plane and get the normal vector
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(pts)
    plane_model, inliers = pcd.segment_plane(distance_threshold=0.01, ransac_n=3, num_iterations=1000)
    # get the normalised normal vector dim = 3
    normal = plane_model[:3] / np.linalg.norm(plane_model[:3])
    # the normal direction should be pointing up
    if normal[1] < 0:
        normal = -normal
    # print("normal", normal)

    # # TODO z-up 180
    # z_up = np.array([[1,0,0,0],
    #                  [0,-1,0,0],
    #                  [0,0,-1,0],
    #                  [0,0,0,1]])
    # obj2world = z_up @ obj2world

    # # avg the y
    # z_up_avg = cams2world[:,:3,3].sum(0) / np.linalg.norm(cams2world[:,:3,3].sum(0), axis=-1)  # average direction in cam coordinate
    # # import pdb; pdb.set_trace()
    # rot_axis = np.cross(np.array([0, 0, 1]), z_up_avg)
    # rot_angle = np.arccos(np.dot(np.array([0, 0, 1]), z_up_avg) / (np.linalg.norm(z_up_avg) + 1e-6))
    # rot = Rotation.from_rotvec(rot_angle * rot_axis)
    # z_up = np.eye(4)
    # z_up[:3, :3] = rot.as_matrix()

    # get the rotation matrix from normal to z-axis
    z_axis = np.array([0, 0, 1])
    rot_axis = np.cross(normal, z_axis)
    rot_angle = np.arccos(np.dot(normal, z_axis) / (np.linalg.norm(normal) + 1e-6))
    rot = Rotation.from_rotvec(rot_angle * rot_axis)
    z_up = np.eye(4)
    z_up[:3, :3] = rot.as_matrix()
    obj2world = z_up @ obj2world
    # flip 180
    flip_rot = np.array([[1, 0, 0, 0],
                         [0, -1, 0, 0],
                         [0, 0, -1, 0],
                         [0, 0, 0, 1]])
    obj2world = flip_rot @ obj2world

    # get new cams2obj
    cams2obj = []
    for i, cam2world in enumerate(cams2world):
        cams2obj.append(obj2world @ cam2world)
    # TODO transform pts3d to the new coordinate system
    for i, pts in enumerate(pts3d):
        pts3d[i] = (obj2world @ np.concatenate([pts, np.ones_like(pts)[..., :1]], axis=-1).transpose(2, 0, 1).reshape(4, -1)) \
            .reshape(4, pts.shape[0], pts.shape[1]).transpose(1, 2, 0)[..., :3]
    cams2world = np.array(cams2obj)
    # TODO rewrite hack
    scene.vis_poses = cams2world.copy()
    scene.vis_pts3d = pts3d.copy()

    # TODO save cams2world and rgbimg to each file, file name "000.npy", "001.npy", ... and "000.png", "001.png", ...
    for i, (img, img_rgba, pose) in enumerate(zip(rgbimg, rgbaimg, cams2world)):
        np.save(os.path.join(outdir, f"{i:03d}.npy"), pose)
        pl.imsave(os.path.join(outdir, f"{i:03d}.png"), img)
        pl.imsave(os.path.join(outdir, f"{i:03d}_rgba.png"), img_rgba)
        # np.save(os.path.join(outdir, f"{i:03d}_focal.npy"), to_numpy(focal))
    # save the min/max radius of camera
    radii = np.linalg.norm(np.linalg.inv(cams2world)[..., :3, 3])
    np.save(os.path.join(outdir, "radii.npy"), radii)

    eschernet_input = {"poses": cams2world,
                       "radii": radii,
                       "imgs": rgbaimg}
    print("got eschernet input")
    outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
                                      clean_depth, transparent_cams, cam_size, same_focals=same_focals)

    return scene, outfile, imgs, eschernet_input


def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
    num_files = len(inputfiles) if inputfiles is not None else 1
    max_winsize = max(1, math.ceil((num_files - 1) / 2))
    if scenegraph_type == "swin":
        winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
                            minimum=1, maximum=max_winsize, step=1, visible=True)
        refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
                          maximum=num_files - 1, step=1, visible=False)
    elif scenegraph_type == "oneref":
        winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
                            minimum=1, maximum=max_winsize, step=1, visible=False)
        refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
                          maximum=num_files - 1, step=1, visible=True)
    else:
        winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
                            minimum=1, maximum=max_winsize, step=1, visible=False)
        refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
                          maximum=num_files - 1, step=1, visible=False)
    return winsize, refid


def get_examples(path):
    objs = []
    for obj_name in sorted(os.listdir(path)):
        img_files = []
        for img_file in sorted(os.listdir(os.path.join(path, obj_name))):
            img_files.append(os.path.join(path, obj_name, img_file))
        objs.append([img_files])
    print("objs = ", objs)
    return objs

def preview_input(inputfiles):
    if inputfiles is None:
        return None
    imgs = []
    for img_file in inputfiles:
        img = pl.imread(img_file)
        imgs.append(img)
    return imgs

# def main():
# dustr init
silent = False
image_size = 224
weights_path = 'checkpoints/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth'
model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(device)
# dust3r will write the 3D model inside tmpdirname
# with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname:
tmpdirname = os.path.join('logs/user_object')
# remove the directory if it already exists
if os.path.exists(tmpdirname):
    shutil.rmtree(tmpdirname)
os.makedirs(tmpdirname, exist_ok=True)
if not silent:
    print('Outputing stuff in', tmpdirname)

_HEADER_ = '''
<h2><b>[CVPR'24 Oral] EscherNet: A Generative Model for Scalable View Synthesis</b></h2>
<b>EscherNet</b> is a multiview diffusion model for scalable generative any-to-any number/pose novel view synthesis.

Image views are treated as tokens and the camera pose is encoded by <b>CaPE (Camera Positional Encoding)</b>.

<a href='https://kxhit.github.io/EscherNet' target='_blank'>Project</a> <b>|</b>
<a href='https://github.com/kxhit/EscherNet' target='_blank'>GitHub</a> <b>|</b>
<a href='https://arxiv.org/abs/2402.03908' target='_blank'>ArXiv</a>

<h4><b>Tips:</b></h4>

- Our model can take <b>any number input images</b>. The more images you provide <b>(>=3 for this demo)</b>, the better the results.

- Our model can generate <b>any number and any pose</b> novel views. You can specify the number of views you want to generate. In this demo, we set novel views on an <b>archemedian spiral</b> for simplicity.

- The pose estimation is done using <a href='https://github.com/naver/dust3r' target='_blank'>DUSt3R</a>. You can also provide your own poses or get pose via any SLAM system.

- The current checkpoint supports 6DoF camera pose and is trained on 30k 3D <a href='https://objaverse.allenai.org/' target='_blank'>Objaverse</a> objects for demo. Scaling is on the roadmap!

'''

_CITE_ = r"""
📝 <b>Citation</b>:
```bibtex
@article{kong2024eschernet,
    title={EscherNet: A Generative Model for Scalable View Synthesis},
    author={Kong, Xin and Liu, Shikun and Lyu, Xiaoyang and Taher, Marwan and Qi, Xiaojuan and Davison, Andrew J},
    journal={arXiv preprint arXiv:2402.03908},
    year={2024}
}
```
"""

with gr.Blocks() as demo:
    gr.Markdown(_HEADER_)
    # mv_images = gr.State()
    scene = gr.State(None)
    eschernet_input = gr.State(None)
    with gr.Row(variant="panel"):
        # left column
        with gr.Column():
            with gr.Row():
                input_image = gr.File(file_count="multiple")
            with gr.Row():
                run_dust3r = gr.Button("Get Pose!", elem_id="dust3r")
            with gr.Row():
                processed_image = gr.Gallery(label='Input Views', columns=2, height="100%")
            with gr.Row(variant="panel"):
                # input examples under "examples" folder
                gr.Examples(
                    examples=get_examples('examples'),
                    inputs=[input_image],
                    label="Examples (click one set of images to start!)",
                    examples_per_page=20
                )

        # right column
        with gr.Column():

            with gr.Row():
                outmodel = gr.Model3D()

            with gr.Row():
                gr.Markdown('''
                <h4><b>Check if the pose (blue is axis is estimated z-up direction) and segmentation looks correct. If not, remove the incorrect images and try again.</b></h4>
                ''')

            with gr.Row():
                with gr.Group():
                    do_remove_background = gr.Checkbox(
                        label="Remove Background", value=True
                    )
                    sample_seed = gr.Number(value=42, label="Seed Value", precision=0)

                    sample_steps = gr.Slider(
                        label="Sample Steps",
                        minimum=30,
                        maximum=75,
                        value=50,
                        step=5,
                        visible=False
                    )

                    nvs_num = gr.Slider(
                        label="Number of Novel Views",
                        minimum=5,
                        maximum=100,
                        value=30,
                        step=1
                    )

                    nvs_mode = gr.Dropdown(["archimedes circle"],  # "fixed 4 views", "fixed 8 views"
                                           value="archimedes circle", label="Novel Views Pose Chosen", visible=True)

            with gr.Row():
                gr.Markdown('''
                <h4><b>Choose your desired novel view poses number and generate! The more output images the longer it takes.</b></h4>
                ''')

            with gr.Row():
                submit = gr.Button("Submit", elem_id="eschernet", variant="primary")

            with gr.Row():
                with gr.Column():
                    output_video = gr.Video(
                        label="video", format="mp4",
                        width=379,
                        autoplay=True,
                        interactive=False
                    )

            with gr.Row():
                gr.Markdown('''
                <h4><b>The novel views are generated on an archimedean spiral (rotating around z-up axis and looking at the object center). You can download the video.</b></h4>
                ''')

            gr.Markdown(_CITE_)

    # set dust3r parameter invisible to be clean
    with gr.Column():
        with gr.Row():
            schedule = gr.Dropdown(["linear", "cosine"],
                                   value='linear', label="schedule", info="For global alignment!", visible=False)
            niter = gr.Number(value=300, precision=0, minimum=0, maximum=5000,
                              label="num_iterations", info="For global alignment!", visible=False)
            scenegraph_type = gr.Dropdown(["complete", "swin", "oneref"],
                                          value='complete', label="Scenegraph",
                                          info="Define how to make pairs",
                                          interactive=True, visible=False)
            same_focals = gr.Checkbox(value=True, label="Focal", info="Use the same focal for all cameras", visible=False)
            winsize = gr.Slider(label="Scene Graph: Window Size", value=1,
                                minimum=1, maximum=1, step=1, visible=False)
            refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)

        with gr.Row():
            # adjust the confidence threshold
            min_conf_thr = gr.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1, visible=False)
            # adjust the camera size in the output pointcloud
            cam_size = gr.Slider(label="cam_size", value=0.05, minimum=0.01, maximum=0.5, step=0.001, visible=False)
        with gr.Row():
            as_pointcloud = gr.Checkbox(value=False, label="As pointcloud", visible=False)
            # two post process implemented
            mask_sky = gr.Checkbox(value=False, label="Mask sky", visible=False)
            clean_depth = gr.Checkbox(value=True, label="Clean-up depthmaps", visible=False)
            transparent_cams = gr.Checkbox(value=False, label="Transparent cameras", visible=False)

    # events
    # scenegraph_type.change(set_scenegraph_options,
    #                        inputs=[input_image, winsize, refid, scenegraph_type],
    #                        outputs=[winsize, refid])
    # min_conf_thr.release(fn=model_from_scene_fun,
    #                      inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
    #                              clean_depth, transparent_cams, cam_size, same_focals],
    #                      outputs=outmodel)
    # cam_size.change(fn=model_from_scene_fun,
    #                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
    #                         clean_depth, transparent_cams, cam_size, same_focals],
    #                 outputs=outmodel)
    # as_pointcloud.change(fn=model_from_scene_fun,
    #                      inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
    #                              clean_depth, transparent_cams, cam_size, same_focals],
    #                      outputs=outmodel)
    # mask_sky.change(fn=model_from_scene_fun,
    #                 inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
    #                         clean_depth, transparent_cams, cam_size, same_focals],
    #                 outputs=outmodel)
    # clean_depth.change(fn=model_from_scene_fun,
    #                    inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
    #                            clean_depth, transparent_cams, cam_size, same_focals],
    #                    outputs=outmodel)
    # transparent_cams.change(model_from_scene_fun,
    #                         inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
    #                                 clean_depth, transparent_cams, cam_size, same_focals],
    #                         outputs=outmodel)
    # run_dust3r.click(fn=recon_fun,
    #                  inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
    #                          mask_sky, clean_depth, transparent_cams, cam_size,
    #                          scenegraph_type, winsize, refid, same_focals],
    #                  outputs=[scene, outmodel, processed_image, eschernet_input])

    # events
    input_image.change(set_scenegraph_options,
                       inputs=[input_image, winsize, refid, scenegraph_type],
                       outputs=[winsize, refid])
    run_dust3r.click(fn=get_reconstructed_scene,
                     inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
                             mask_sky, clean_depth, transparent_cams, cam_size,
                             scenegraph_type, winsize, refid, same_focals],
                     outputs=[scene, outmodel, processed_image, eschernet_input])


    # events
    input_image.change(fn=preview_input,
                       inputs=[input_image],
                       outputs=[processed_image])

    submit.click(fn=run_eschernet,
                 inputs=[eschernet_input, sample_steps, sample_seed,
                         nvs_num, nvs_mode],
                 outputs=[output_video])


# demo.queue(max_size=10)
# demo.launch(share=True, server_name="0.0.0.0", server_port=None)
demo.queue(max_size=10).launch()

# if __name__ == '__main__':
#     main()
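The trickiest step in get_reconstructed_scene above is re-orienting the reconstruction so the fitted ground-plane normal becomes the +z axis before poses are handed to EscherNet. A small self-contained sketch of that rotation, assuming only NumPy and SciPy (normal_to_z_up is a hypothetical helper, not part of this file); note that Rotation.from_rotvec expects angle times a unit axis, so the axis from np.cross is explicitly normalised here:

import numpy as np
from scipy.spatial.transform import Rotation

def normal_to_z_up(normal: np.ndarray) -> np.ndarray:
    """Return a 4x4 transform whose rotation takes the plane normal onto +z."""
    normal = normal / np.linalg.norm(normal)
    z_axis = np.array([0.0, 0.0, 1.0])
    axis = np.cross(normal, z_axis)
    s = np.linalg.norm(axis)                      # sin of the rotation angle
    if s < 1e-8:                                  # normal already (anti-)parallel to z
        return np.eye(4) if normal[2] > 0 else np.diag([1.0, -1.0, -1.0, 1.0])
    angle = np.arccos(np.clip(np.dot(normal, z_axis), -1.0, 1.0))
    T = np.eye(4)
    T[:3, :3] = Rotation.from_rotvec(angle * axis / s).as_matrix()
    return T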
mini_dust3r/__init__.py
ADDED
File without changes
mini_dust3r/api/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .inference import inferece_dust3r, OptimizedResult, log_optimized_result

__all__ = ["inferece_dust3r", "OptimizedResult", "log_optimized_result"]
mini_dust3r/api/inference.py
ADDED
@@ -0,0 +1,225 @@
import rerun as rr
from pathlib import Path
from typing import Literal
import copy
import torch
import numpy as np
from jaxtyping import Float32, Bool
import trimesh
from tqdm import tqdm

from mini_dust3r.utils.image import load_images, ImageDict
from mini_dust3r.inference import inference, Dust3rResult
from mini_dust3r.model import AsymmetricCroCo3DStereo
from mini_dust3r.image_pairs import make_pairs
from mini_dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from mini_dust3r.cloud_opt.base_opt import BasePCOptimizer
from mini_dust3r.viz import pts3d_to_trimesh, cat_meshes
from dataclasses import dataclass


@dataclass
class OptimizedResult:
    K_b33: Float32[np.ndarray, "b 3 3"]
    world_T_cam_b44: Float32[np.ndarray, "b 4 4"]
    rgb_hw3_list: list[Float32[np.ndarray, "h w 3"]]
    depth_hw_list: list[Float32[np.ndarray, "h w"]]
    conf_hw_list: list[Float32[np.ndarray, "h w"]]
    masks_list: Bool[np.ndarray, "h w"]
    point_cloud: trimesh.PointCloud
    mesh: trimesh.Trimesh


def log_optimized_result(
    optimized_result: OptimizedResult, parent_log_path: Path
) -> None:
    rr.log(f"{parent_log_path}", rr.ViewCoordinates.RDF, timeless=True)
    # log pointcloud
    rr.log(
        f"{parent_log_path}/pointcloud",
        rr.Points3D(
            positions=optimized_result.point_cloud.vertices,
            colors=optimized_result.point_cloud.colors,
        ),
        timeless=True,
    )

    mesh = optimized_result.mesh
    rr.log(
        f"{parent_log_path}/mesh",
        rr.Mesh3D(
            vertex_positions=mesh.vertices,
            vertex_colors=mesh.visual.vertex_colors,
            indices=mesh.faces,
        ),
        timeless=True,
    )
    pbar = tqdm(
        zip(
            optimized_result.rgb_hw3_list,
            optimized_result.depth_hw_list,
            optimized_result.K_b33,
            optimized_result.world_T_cam_b44,
        ),
        total=len(optimized_result.rgb_hw3_list),
    )
    for i, (rgb_hw3, depth_hw, k_33, world_T_cam_44) in enumerate(pbar):
        camera_log_path = f"{parent_log_path}/camera_{i}"
        height, width, _ = rgb_hw3.shape
        rr.log(
            f"{camera_log_path}",
            rr.Transform3D(
                translation=world_T_cam_44[:3, 3],
                mat3x3=world_T_cam_44[:3, :3],
                from_parent=False,
            ),
        )
        rr.log(
            f"{camera_log_path}/pinhole",
            rr.Pinhole(
                image_from_camera=k_33,
                height=height,
                width=width,
                camera_xyz=rr.ViewCoordinates.RDF,
            ),
        )
        rr.log(
            f"{camera_log_path}/pinhole/rgb",
            rr.Image(rgb_hw3),
        )
        rr.log(
            f"{camera_log_path}/pinhole/depth",
            rr.DepthImage(depth_hw),
        )


def scene_to_results(scene: BasePCOptimizer, min_conf_thr: int) -> OptimizedResult:
    ### get camera parameters K and T
    K_b33: Float32[np.ndarray, "b 3 3"] = scene.get_intrinsics().numpy(force=True)
    world_T_cam_b44: Float32[np.ndarray, "b 4 4"] = scene.get_im_poses().numpy(
        force=True
    )
    ### image, confidence, depths
    rgb_hw3_list: list[Float32[np.ndarray, "h w 3"]] = scene.imgs
    depth_hw_list: list[Float32[np.ndarray, "h w"]] = [
        depth.numpy(force=True) for depth in scene.get_depthmaps()
    ]
    # normalized depth
    # depth_hw_list = [depth_hw / depth_hw.max() for depth_hw in depth_hw_list]

    conf_hw_list: list[Float32[np.ndarray, "h w"]] = [
        c.numpy(force=True) for c in scene.im_conf
    ]
    # normalize confidence
    # conf_hw_list = [conf_hw / conf_hw.max() for conf_hw in conf_hw_list]

    # point cloud, mesh
    pts3d_list: list[Float32[np.ndarray, "h w 3"]] = [
        pt3d.numpy(force=True) for pt3d in scene.get_pts3d()
    ]
    # get log confidence
    log_conf_trf: Float32[torch.Tensor, ""] = scene.conf_trf(torch.tensor(min_conf_thr))
    # set the minimum confidence threshold
    scene.min_conf_thr = float(log_conf_trf)
    masks_list: Bool[np.ndarray, "h w"] = [
        mask.numpy(force=True) for mask in scene.get_masks()
    ]

    point_cloud: Float32[np.ndarray, "num_points 3"] = np.concatenate(
        [p[m] for p, m in zip(pts3d_list, masks_list)]
    )
    colors: Float32[np.ndarray, "num_points 3"] = np.concatenate(
        [p[m] for p, m in zip(rgb_hw3_list, masks_list)]
    )
    point_cloud = trimesh.PointCloud(
        point_cloud.reshape(-1, 3), colors=colors.reshape(-1, 3)
    )

    meshes = []
    pbar = tqdm(zip(rgb_hw3_list, pts3d_list, masks_list), total=len(rgb_hw3_list))
    for rgb_hw3, pts3d, mask in pbar:
        meshes.append(pts3d_to_trimesh(rgb_hw3, pts3d, mask))

    mesh = trimesh.Trimesh(**cat_meshes(meshes))
    optimised_result = OptimizedResult(
        K_b33=K_b33,
        world_T_cam_b44=world_T_cam_b44,
        rgb_hw3_list=rgb_hw3_list,
        depth_hw_list=depth_hw_list,
        conf_hw_list=conf_hw_list,
        masks_list=masks_list,
        point_cloud=point_cloud,
        mesh=mesh,
    )
    return optimised_result


def inferece_dust3r(
    image_dir_or_list: Path | list[Path],
    model: AsymmetricCroCo3DStereo,
    device: Literal["cpu", "cuda", "mps"],
    batch_size: int = 1,
    image_size: Literal[224, 512] = 512,
    niter: int = 100,
    schedule: Literal["linear", "cosine"] = "linear",
    min_conf_thr: float = 10,
) -> OptimizedResult:
    """
    Perform inference using the Dust3r algorithm.

    Args:
        image_dir_or_list (Union[Path, List[Path]]): Path to the directory containing images or a list of image paths.
        model (AsymmetricCroCo3DStereo): The Dust3r model to use for inference.
        device (Literal["cpu", "cuda", "mps"]): The device to use for inference ("cpu", "cuda", or "mps").
        batch_size (int, optional): The batch size for inference. Defaults to 1.
        image_size (Literal[224, 512], optional): The size of the input images. Defaults to 512.
        niter (int, optional): The number of iterations for the global alignment optimization. Defaults to 100.
        schedule (Literal["linear", "cosine"], optional): The learning rate schedule for the global alignment optimization. Defaults to "linear".
        min_conf_thr (float, optional): The minimum confidence threshold for the optimized result. Defaults to 10.

    Returns:
        OptimizedResult: The optimized result containing the RGB, depth, and confidence images.

    Raises:
        ValueError: If `image_dir_or_list` is neither a list of paths nor a path.
    """
    if isinstance(image_dir_or_list, list):
        imgs: list[ImageDict] = load_images(
            folder_or_list=image_dir_or_list, size=image_size, verbose=True
        )
    elif isinstance(image_dir_or_list, Path):
        imgs: list[ImageDict] = load_images(
            folder_or_list=str(image_dir_or_list), size=image_size, verbose=True
        )
    else:
        raise ValueError("image_dir_or_list should be a list of paths or a path")

    # if only one image was loaded, duplicate it to feed into stereo network
    if len(imgs) == 1:
        imgs = [imgs[0], copy.deepcopy(imgs[0])]
        imgs[1]["idx"] = 1

    pairs: list[tuple[ImageDict, ImageDict]] = make_pairs(
        imgs, scene_graph="complete", prefilter=None, symmetrize=True
    )
    output: Dust3rResult = inference(pairs, model, device, batch_size=batch_size)

    mode = (
        GlobalAlignerMode.PointCloudOptimizer
        if len(imgs) > 2
        else GlobalAlignerMode.PairViewer
    )
    scene: BasePCOptimizer = global_aligner(
        dust3r_output=output, device=device, mode=mode
    )

    lr = 0.01

    if mode == GlobalAlignerMode.PointCloudOptimizer:
        loss = scene.compute_global_alignment(
            init="mst", niter=niter, schedule=schedule, lr=lr
        )

    # get the optimized result from the scene
    optimized_result: OptimizedResult = scene_to_results(scene, min_conf_thr)
    return optimized_result
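A hedged usage sketch for the API above; the pretrained checkpoint id and example directory are illustrative assumptions, not part of this commit:

from pathlib import Path

import rerun as rr
from mini_dust3r.api import inferece_dust3r, log_optimized_result
from mini_dust3r.model import AsymmetricCroCo3DStereo

model = AsymmetricCroCo3DStereo.from_pretrained(
    "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"  # assumed checkpoint id
).to("cuda")
rr.init("mini_dust3r_demo", spawn=True)          # viewer for log_optimized_result
result = inferece_dust3r(
    image_dir_or_list=Path("examples/my_object"),  # assumed image folder
    model=model,
    device="cuda",
    niter=100,
)
log_optimized_result(result, parent_log_path=Path("world"))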
mini_dust3r/cloud_opt/__init__.py
ADDED
@@ -0,0 +1,44 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# global alignment optimization wrapper function
# --------------------------------------------------------
from enum import Enum

from .optimizer import PointCloudOptimizer
from .modular_optimizer import ModularPointCloudOptimizer
from .pair_viewer import PairViewer
from mini_dust3r.inference import Dust3rResult
from typing import Literal


class GlobalAlignerMode(Enum):
    PointCloudOptimizer = "PointCloudOptimizer"
    ModularPointCloudOptimizer = "ModularPointCloudOptimizer"
    PairViewer = "PairViewer"


def global_aligner(
    dust3r_output: Dust3rResult,
    device: Literal["cpu", "cuda", "mps"],
    mode: GlobalAlignerMode = GlobalAlignerMode.PointCloudOptimizer,
    **optim_kw,
):
    # extract all inputs
    view1, view2, pred1, pred2 = [
        dust3r_output[k] for k in "view1 view2 pred1 pred2".split()
    ]
    # build the optimizer
    if mode == GlobalAlignerMode.PointCloudOptimizer:
        net = PointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device)
    elif mode == GlobalAlignerMode.ModularPointCloudOptimizer:
        net = ModularPointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(
            device
        )
    elif mode == GlobalAlignerMode.PairViewer:
        net = PairViewer(view1, view2, pred1, pred2, **optim_kw).to(device)
    else:
        raise NotImplementedError(f"Unknown mode {mode}")

    return net
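For reference, the caller in mini_dust3r/api/inference.py above selects the aligner mode from the number of loaded images; condensed from that file:

mode = (GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2
        else GlobalAlignerMode.PairViewer)
scene = global_aligner(dust3r_output=output, device="cuda", mode=mode)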
mini_dust3r/cloud_opt/base_opt.py
ADDED
@@ -0,0 +1,390 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Base class for the global alignement procedure
+# --------------------------------------------------------
+from copy import deepcopy
+
+import numpy as np
+import torch
+import torch.nn as nn
+import roma
+from copy import deepcopy
+import tqdm
+
+from mini_dust3r.utils.geometry import inv, geotrf
+from mini_dust3r.utils.device import to_numpy
+from mini_dust3r.utils.image import rgb
+from mini_dust3r.viz import SceneViz, segment_sky, auto_cam_size
+from mini_dust3r.optim_factory import adjust_learning_rate_by_lr
+
+from mini_dust3r.cloud_opt.commons import (edge_str, ALL_DISTS, NoGradParamDict, get_imshapes, signed_expm1, signed_log1p,
+                                           cosine_schedule, linear_schedule, get_conf_trf)
+import mini_dust3r.cloud_opt.init_im_poses as init_fun
+
+
+class BasePCOptimizer (nn.Module):
+    """ Optimize a global scene, given a list of pairwise observations.
+    Graph node: images
+    Graph edges: observations = (pred1, pred2)
+    """
+
+    def __init__(self, *args, **kwargs):
+        if len(args) == 1 and len(kwargs) == 0:
+            other = deepcopy(args[0])
+            attrs = '''edges is_symmetrized dist n_imgs pred_i pred_j imshapes
+                        min_conf_thr conf_thr conf_i conf_j im_conf
+                        base_scale norm_pw_scale POSE_DIM pw_poses
+                        pw_adaptors pw_adaptors has_im_poses rand_pose imgs verbose'''.split()
+            self.__dict__.update({k: other[k] for k in attrs})
+        else:
+            self._init_from_views(*args, **kwargs)
+
+    def _init_from_views(self, view1, view2, pred1, pred2,
+                         dist='l1',
+                         conf='log',
+                         min_conf_thr=3,
+                         base_scale=0.5,
+                         allow_pw_adaptors=False,
+                         pw_break=20,
+                         rand_pose=torch.randn,
+                         iterationsCount=None,
+                         verbose=True):
+        super().__init__()
+        if not isinstance(view1['idx'], list):
+            view1['idx'] = view1['idx'].tolist()
+        if not isinstance(view2['idx'], list):
+            view2['idx'] = view2['idx'].tolist()
+        self.edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
+        self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges}
+        self.dist = ALL_DISTS[dist]
+        self.verbose = verbose
+
+        self.n_imgs = self._check_edges()
+
+        # input data
+        pred1_pts = pred1['pts3d']
+        pred2_pts = pred2['pts3d_in_other_view']
+        self.pred_i = NoGradParamDict({ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)})
+        self.pred_j = NoGradParamDict({ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)})
+        self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts)
+
+        # work in log-scale with conf
+        pred1_conf = pred1['conf']
+        pred2_conf = pred2['conf']
+        self.min_conf_thr = min_conf_thr
+        self.conf_trf = get_conf_trf(conf)
+
+        self.conf_i = NoGradParamDict({ij: pred1_conf[n] for n, ij in enumerate(self.str_edges)})
+        self.conf_j = NoGradParamDict({ij: pred2_conf[n] for n, ij in enumerate(self.str_edges)})
+        self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf)
+
+        # pairwise pose parameters
+        self.base_scale = base_scale
+        self.norm_pw_scale = True
+        self.pw_break = pw_break
+        self.POSE_DIM = 7
+        self.pw_poses = nn.Parameter(rand_pose((self.n_edges, 1+self.POSE_DIM)))  # pairwise poses
+        self.pw_adaptors = nn.Parameter(torch.zeros((self.n_edges, 2)))  # slight xy/z adaptation
+        self.pw_adaptors.requires_grad_(allow_pw_adaptors)
+        self.has_im_poses = False
+        self.rand_pose = rand_pose
+
+        # possibly store images for show_pointcloud
+        self.imgs = None
+        if 'img' in view1 and 'img' in view2:
+            imgs = [torch.zeros((3,)+hw) for hw in self.imshapes]
+            for v in range(len(self.edges)):
+                idx = view1['idx'][v]
+                imgs[idx] = view1['img'][v]
+                idx = view2['idx'][v]
+                imgs[idx] = view2['img'][v]
+            self.imgs = rgb(imgs)
+
+    @property
+    def n_edges(self):
+        return len(self.edges)
+
+    @property
+    def str_edges(self):
+        return [edge_str(i, j) for i, j in self.edges]
+
+    @property
+    def imsizes(self):
+        return [(w, h) for h, w in self.imshapes]
+
+    @property
+    def device(self):
+        return next(iter(self.parameters())).device
+
+    def state_dict(self, trainable=True):
+        all_params = super().state_dict()
+        return {k: v for k, v in all_params.items() if k.startswith(('_', 'pred_i.', 'pred_j.', 'conf_i.', 'conf_j.')) != trainable}
+
+    def load_state_dict(self, data):
+        return super().load_state_dict(self.state_dict(trainable=False) | data)
+
+    def _check_edges(self):
+        indices = sorted({i for edge in self.edges for i in edge})
+        assert indices == list(range(len(indices))), 'bad pair indices: missing values '
+        return len(indices)
+
+    @torch.no_grad()
+    def _compute_img_conf(self, pred1_conf, pred2_conf):
+        im_conf = nn.ParameterList([torch.zeros(hw, device=self.device) for hw in self.imshapes])
+        for e, (i, j) in enumerate(self.edges):
+            im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e])
+            im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e])
+        return im_conf
+
+    def get_adaptors(self):
+        adapt = self.pw_adaptors
+        adapt = torch.cat((adapt[:, 0:1], adapt), dim=-1)  # (scale_xy, scale_xy, scale_z)
+        if self.norm_pw_scale:  # normalize so that the product == 1
+            adapt = adapt - adapt.mean(dim=1, keepdim=True)
+        return (adapt / self.pw_break).exp()
+
+    def _get_poses(self, poses):
+        # normalize rotation
+        Q = poses[:, :4]
+        T = signed_expm1(poses[:, 4:7])
+        RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous()
+        return RT
+
+    def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
+        # all poses == cam-to-world
+        pose = poses[idx]
+        if not (pose.requires_grad or force):
+            return pose
+
+        if R.shape == (4, 4):
+            assert T is None
+            T = R[:3, 3]
+            R = R[:3, :3]
+
+        if R is not None:
+            pose.data[0:4] = roma.rotmat_to_unitquat(R)
+        if T is not None:
+            pose.data[4:7] = signed_log1p(T / (scale or 1))  # translation is function of scale
+
+        if scale is not None:
+            assert poses.shape[-1] in (8, 13)
+            pose.data[-1] = np.log(float(scale))
+        return pose
+
+    def get_pw_norm_scale_factor(self):
+        if self.norm_pw_scale:
+            # normalize scales so that things cannot go south
+            # we want that exp(scale) ~= self.base_scale
+            return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp()
+        else:
+            return 1  # don't norm scale for known poses
+
+    def get_pw_scale(self):
+        scale = self.pw_poses[:, -1].exp()  # (n_edges,)
+        scale = scale * self.get_pw_norm_scale_factor()
+        return scale
+
+    def get_pw_poses(self):  # cam to world
+        RT = self._get_poses(self.pw_poses)
+        scaled_RT = RT.clone()
+        scaled_RT[:, :3] *= self.get_pw_scale().view(-1, 1, 1)  # scale the rotation AND translation
+        return scaled_RT
+
+    def get_masks(self):
+        return [(conf > self.min_conf_thr) for conf in self.im_conf]
+
+    def depth_to_pts3d(self):
+        raise NotImplementedError()
+
+    def get_pts3d(self, raw=False):
+        res = self.depth_to_pts3d()
+        if not raw:
+            res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
+        return res
+
+    def _set_focal(self, idx, focal, force=False):
+        raise NotImplementedError()
+
+    def get_focals(self):
+        raise NotImplementedError()
+
+    def get_known_focal_mask(self):
+        raise NotImplementedError()
+
+    def get_principal_points(self):
+        raise NotImplementedError()
+
+    def get_conf(self, mode=None):
+        trf = self.conf_trf if mode is None else get_conf_trf(mode)
+        return [trf(c) for c in self.im_conf]
+
+    def get_im_poses(self):
+        raise NotImplementedError()
+
+    def _set_depthmap(self, idx, depth, force=False):
+        raise NotImplementedError()
+
+    def get_depthmaps(self, raw=False):
+        raise NotImplementedError()
+
+    @torch.no_grad()
+    def clean_pointcloud(self, tol=0.001, max_bad_conf=0):
+        """ Method:
+            1) express all 3d points in each camera coordinate frame
+            2) if they're in front of a depthmap --> then lower their confidence
+        """
+        assert 0 <= tol < 1
+        cams = inv(self.get_im_poses())
+        K = self.get_intrinsics()
+        depthmaps = self.get_depthmaps()
+        res = deepcopy(self)
+
+        for i, pts3d in enumerate(self.depth_to_pts3d()):
+            for j in range(self.n_imgs):
+                if i == j:
+                    continue
+
+                # project 3dpts in other view
+                Hi, Wi = self.imshapes[i]
+                Hj, Wj = self.imshapes[j]
+                proj = geotrf(cams[j], pts3d[:Hi*Wi]).reshape(Hi, Wi, 3)
+                proj_depth = proj[:, :, 2]
+                u, v = geotrf(K[j], proj, norm=1, ncol=2).round().long().unbind(-1)
+
+                # check which points are actually in the visible cone
+                msk_i = (proj_depth > 0) & (0 <= u) & (u < Wj) & (0 <= v) & (v < Hj)
+                msk_j = v[msk_i], u[msk_i]
+
+                # find bad points = those in front but less confident
+                bad_points = (proj_depth[msk_i] < (1-tol) * depthmaps[j][msk_j]
+                              ) & (res.im_conf[i][msk_i] < res.im_conf[j][msk_j])
+
+                bad_msk_i = msk_i.clone()
+                bad_msk_i[msk_i] = bad_points
+                res.im_conf[i][bad_msk_i] = res.im_conf[i][bad_msk_i].clip_(max=max_bad_conf)
+
+        return res
+
+    def forward(self, ret_details=False):
+        pw_poses = self.get_pw_poses()  # cam-to-world
+        pw_adapt = self.get_adaptors()
+        proj_pts3d = self.get_pts3d()
+        # pre-compute pixel weights
+        weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()}
+        weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()}
+
+        loss = 0
+        if ret_details:
+            details = -torch.ones((self.n_imgs, self.n_imgs))
+
+        for e, (i, j) in enumerate(self.edges):
+            i_j = edge_str(i, j)
+            # distance in image i and j
+            aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j])
+            aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j])
+            li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean()
+            lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean()
+            loss = loss + li + lj
+
+            if ret_details:
+                details[i, j] = li + lj
+        loss /= self.n_edges  # average over all pairs
+
+        if ret_details:
+            return loss, details
+        return loss
+
+    @torch.cuda.amp.autocast(enabled=False)
+    def compute_global_alignment(self, init=None, niter_PnP=10, **kw):
+        if init is None:
+            pass
+        elif init == 'msp' or init == 'mst':
+            init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
+        elif init == 'known_poses':
+            init_fun.init_from_known_poses(self, min_conf_thr=self.min_conf_thr,
+                                           niter_PnP=niter_PnP)
+        else:
+            raise ValueError(f'bad value for {init=}')
+
+        return global_alignment_loop(self, **kw)
+
+    @torch.no_grad()
+    def mask_sky(self):
+        res = deepcopy(self)
+        for i in range(self.n_imgs):
+            sky = segment_sky(self.imgs[i])
+            res.im_conf[i][sky] = 0
+        return res
+
+    def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw):
+        viz = SceneViz()
+        if self.imgs is None:
+            colors = np.random.randint(0, 256, size=(self.n_imgs, 3))
+            colors = list(map(tuple, colors.tolist()))
+            for n in range(self.n_imgs):
+                viz.add_pointcloud(self.get_pts3d()[n], colors[n], self.get_masks()[n])
+        else:
+            viz.add_pointcloud(self.get_pts3d(), self.imgs, self.get_masks())
+            colors = np.random.randint(256, size=(self.n_imgs, 3))
+
+        # camera poses
+        im_poses = to_numpy(self.get_im_poses())
+        if cam_size is None:
+            cam_size = auto_cam_size(im_poses)
+        viz.add_cameras(im_poses, self.get_focals(), colors=colors,
+                        images=self.imgs, imsizes=self.imsizes, cam_size=cam_size)
+        if show_pw_cams:
+            pw_poses = self.get_pw_poses()
+            viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size)
+
+            if show_pw_pts3d:
+                pts = [geotrf(pw_poses[e], self.pred_i[edge_str(i, j)]) for e, (i, j) in enumerate(self.edges)]
+                viz.add_pointcloud(pts, (128, 0, 128))
+
+        viz.show(**kw)
+        return viz
+
+
+def global_alignment_loop(net, lr=0.01, niter=300, schedule='cosine', lr_min=1e-6):
+    params = [p for p in net.parameters() if p.requires_grad]
+    if not params:
+        return net
+
+    verbose = net.verbose
+    if verbose:
+        print('Global alignement - optimizing for:')
+        print([name for name, value in net.named_parameters() if value.requires_grad])
+
+    lr_base = lr
+    optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.9))
+
+    loss = float('inf')
+    if verbose:
+        with tqdm.tqdm(total=niter) as bar:
+            while bar.n < bar.total:
+                loss = global_alignment_iter(net, bar.n, niter, lr_base, lr_min, optimizer, schedule)
+                bar.set_postfix_str(f'{lr=:g} loss={loss:g}')
+                bar.update()
+    else:
+        for n in range(niter):
+            loss = global_alignment_iter(net, n, niter, lr_base, lr_min, optimizer, schedule)
+    return loss
+
+
+def global_alignment_iter(net, cur_iter, niter, lr_base, lr_min, optimizer, schedule):
+    t = cur_iter / niter
+    if schedule == 'cosine':
+        lr = cosine_schedule(t, lr_base, lr_min)
+    elif schedule == 'linear':
+        lr = linear_schedule(t, lr_base, lr_min)
+    else:
+        raise ValueError(f'bad lr {schedule=}')
+    adjust_learning_rate_by_lr(optimizer, lr)
+    optimizer.zero_grad()
+    loss = net()
+    loss.backward()
+    optimizer.step()
+
+    return float(loss)
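For orientation, here is a minimal usage sketch of the class added above. It is not part of the commit: `scene` is assumed to be an instance of a concrete subclass (e.g. `PointCloudOptimizer` from `optimizer.py` below) already built from pairwise predictions; only methods defined in this file are called.

```python
# Hedged sketch: assumes `scene` is a concrete BasePCOptimizer subclass.
loss = scene.compute_global_alignment(init='mst',        # MST-based pose initialization
                                      niter_PnP=10,      # RANSAC-PnP iterations during init
                                      niter=300,         # forwarded to global_alignment_loop
                                      schedule='cosine',
                                      lr=0.01)

pts3d = scene.get_pts3d()      # list of (H, W, 3) point maps in world frame
poses = scene.get_im_poses()   # (n_imgs, 4, 4) cam-to-world matrices
masks = scene.get_masks()      # per-image masks where conf > min_conf_thr
```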
mini_dust3r/cloud_opt/commons.py
ADDED
@@ -0,0 +1,90 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# utility functions for global alignment
+# --------------------------------------------------------
+import torch
+import torch.nn as nn
+import numpy as np
+
+
+def edge_str(i, j):
+    return f'{i}_{j}'
+
+
+def i_j_ij(ij):
+    return edge_str(*ij), ij
+
+
+def edge_conf(conf_i, conf_j, edge):
+    return float(conf_i[edge].mean() * conf_j[edge].mean())
+
+
+def compute_edge_scores(edges, conf_i, conf_j):
+    return {(i, j): edge_conf(conf_i, conf_j, e) for e, (i, j) in edges}
+
+
+def NoGradParamDict(x):
+    assert isinstance(x, dict)
+    return nn.ParameterDict(x).requires_grad_(False)
+
+
+def get_imshapes(edges, pred_i, pred_j):
+    n_imgs = max(max(e) for e in edges) + 1
+    imshapes = [None] * n_imgs
+    for e, (i, j) in enumerate(edges):
+        shape_i = tuple(pred_i[e].shape[0:2])
+        shape_j = tuple(pred_j[e].shape[0:2])
+        if imshapes[i]:
+            assert imshapes[i] == shape_i, f'incorrect shape for image {i}'
+        if imshapes[j]:
+            assert imshapes[j] == shape_j, f'incorrect shape for image {j}'
+        imshapes[i] = shape_i
+        imshapes[j] = shape_j
+    return imshapes
+
+
+def get_conf_trf(mode):
+    if mode == 'log':
+        def conf_trf(x): return x.log()
+    elif mode == 'sqrt':
+        def conf_trf(x): return x.sqrt()
+    elif mode == 'm1':
+        def conf_trf(x): return x-1
+    elif mode in ('id', 'none'):
+        def conf_trf(x): return x
+    else:
+        raise ValueError(f'bad mode for {mode=}')
+    return conf_trf
+
+
+def l2_dist(a, b, weight):
+    return ((a - b).square().sum(dim=-1) * weight)
+
+
+def l1_dist(a, b, weight):
+    return ((a - b).norm(dim=-1) * weight)
+
+
+ALL_DISTS = dict(l1=l1_dist, l2=l2_dist)
+
+
+def signed_log1p(x):
+    sign = torch.sign(x)
+    return sign * torch.log1p(torch.abs(x))
+
+
+def signed_expm1(x):
+    sign = torch.sign(x)
+    return sign * torch.expm1(torch.abs(x))
+
+
+def cosine_schedule(t, lr_start, lr_end):
+    assert 0 <= t <= 1
+    return lr_end + (lr_start - lr_end) * (1+np.cos(t * np.pi))/2
+
+
+def linear_schedule(t, lr_start, lr_end):
+    assert 0 <= t <= 1
+    return lr_start + (lr_end - lr_start) * t
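A small sanity-check sketch (not part of the commit) for the helpers defined above: the signed log/exp transforms are exact inverses of each other, and both schedules interpolate from `lr_start` at `t=0` to `lr_end` at `t=1`.

```python
import torch
from mini_dust3r.cloud_opt.commons import (signed_log1p, signed_expm1,
                                           cosine_schedule, linear_schedule)

x = torch.tensor([-5.0, -0.1, 0.0, 0.1, 5.0])
assert torch.allclose(signed_expm1(signed_log1p(x)), x)   # round-trip identity

print(cosine_schedule(0.0, 0.01, 1e-6))   # 0.01  (start of the schedule)
print(cosine_schedule(1.0, 0.01, 1e-6))   # 1e-06 (end of the schedule)
print(linear_schedule(0.5, 0.01, 1e-6))   # ~0.005 (midpoint)
```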
mini_dust3r/cloud_opt/init_im_poses.py
ADDED
@@ -0,0 +1,316 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Initialization functions for global alignment
+# --------------------------------------------------------
+from functools import cache
+
+import numpy as np
+import scipy.sparse as sp
+import torch
+import cv2
+import roma
+from tqdm import tqdm
+
+from mini_dust3r.utils.geometry import geotrf, inv, get_med_dist_between_poses
+from mini_dust3r.post_process import estimate_focal_knowing_depth
+from mini_dust3r.viz import to_numpy
+
+from mini_dust3r.cloud_opt.commons import edge_str, i_j_ij, compute_edge_scores
+
+
+@torch.no_grad()
+def init_from_known_poses(self, niter_PnP=10, min_conf_thr=3):
+    device = self.device
+
+    # indices of known poses
+    nkp, known_poses_msk, known_poses = get_known_poses(self)
+    assert nkp == self.n_imgs, 'not all poses are known'
+
+    # get all focals
+    nkf, _, im_focals = get_known_focals(self)
+    assert nkf == self.n_imgs
+    im_pp = self.get_principal_points()
+
+    best_depthmaps = {}
+    # init all pairwise poses
+    for e, (i, j) in enumerate(tqdm(self.edges, disable=not self.verbose)):
+        i_j = edge_str(i, j)
+
+        # find relative pose for this pair
+        P1 = torch.eye(4, device=device)
+        msk = self.conf_i[i_j] > min(min_conf_thr, self.conf_i[i_j].min() - 0.1)
+        _, P2 = fast_pnp(self.pred_j[i_j], float(im_focals[i].mean()),
+                         pp=im_pp[i], msk=msk, device=device, niter_PnP=niter_PnP)
+
+        # align the two predicted camera with the two gt cameras
+        s, R, T = align_multiple_poses(torch.stack((P1, P2)), known_poses[[i, j]])
+        # normally we have known_poses[i] ~= sRT_to_4x4(s,R,T,device) @ P1
+        # and geotrf(sRT_to_4x4(1,R,T,device), s*P2[:3,3])
+        self._set_pose(self.pw_poses, e, R, T, scale=s)
+
+        # remember if this is a good depthmap
+        score = float(self.conf_i[i_j].mean())
+        if score > best_depthmaps.get(i, (0,))[0]:
+            best_depthmaps[i] = score, i_j, s
+
+    # init all image poses
+    for n in range(self.n_imgs):
+        assert known_poses_msk[n]
+        _, i_j, scale = best_depthmaps[n]
+        depth = self.pred_i[i_j][:, :, 2]
+        self._set_depthmap(n, depth * scale)
+
+
+@torch.no_grad()
+def init_minimum_spanning_tree(self, **kw):
+    """ Init all camera poses (image-wise and pairwise poses) given
+        an initial set of pairwise estimations.
+    """
+    device = self.device
+    pts3d, _, im_focals, im_poses = minimum_spanning_tree(self.imshapes, self.edges,
+                                                          self.pred_i, self.pred_j, self.conf_i, self.conf_j, self.im_conf, self.min_conf_thr,
+                                                          device, has_im_poses=self.has_im_poses, verbose=self.verbose,
+                                                          **kw)
+
+    return init_from_pts3d(self, pts3d, im_focals, im_poses)
+
+
+def init_from_pts3d(self, pts3d, im_focals, im_poses):
+    # init poses
+    nkp, known_poses_msk, known_poses = get_known_poses(self)
+    if nkp == 1:
+        raise NotImplementedError("Would be simpler to just align everything afterwards on the single known pose")
+    elif nkp > 1:
+        # global rigid SE3 alignment
+        s, R, T = align_multiple_poses(im_poses[known_poses_msk], known_poses[known_poses_msk])
+        trf = sRT_to_4x4(s, R, T, device=known_poses.device)
+
+        # rotate everything
+        im_poses = trf @ im_poses
+        im_poses[:, :3, :3] /= s  # undo scaling on the rotation part
+        for img_pts3d in pts3d:
+            img_pts3d[:] = geotrf(trf, img_pts3d)
+
+    # set all pairwise poses
+    for e, (i, j) in enumerate(self.edges):
+        i_j = edge_str(i, j)
+        # compute transform that goes from cam to world
+        s, R, T = rigid_points_registration(self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j])
+        self._set_pose(self.pw_poses, e, R, T, scale=s)
+
+    # take into account the scale normalization
+    s_factor = self.get_pw_norm_scale_factor()
+    im_poses[:, :3, 3] *= s_factor  # apply downscaling factor
+    for img_pts3d in pts3d:
+        img_pts3d *= s_factor
+
+    # init all image poses
+    if self.has_im_poses:
+        for i in range(self.n_imgs):
+            cam2world = im_poses[i]
+            depth = geotrf(inv(cam2world), pts3d[i])[..., 2]
+            self._set_depthmap(i, depth)
+            self._set_pose(self.im_poses, i, cam2world)
+            if im_focals[i] is not None:
+                self._set_focal(i, im_focals[i])
+
+    if self.verbose:
+        print(' init loss =', float(self()))
+
+
+def minimum_spanning_tree(imshapes, edges, pred_i, pred_j, conf_i, conf_j, im_conf, min_conf_thr,
+                          device, has_im_poses=True, niter_PnP=10, verbose=True):
+    n_imgs = len(imshapes)
+    sparse_graph = -dict_to_sparse_graph(compute_edge_scores(map(i_j_ij, edges), conf_i, conf_j))
+    msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo()
+
+    # temp variable to store 3d points
+    pts3d = [None] * len(imshapes)
+
+    todo = sorted(zip(-msp.data, msp.row, msp.col))  # sorted edges
+    im_poses = [None] * n_imgs
+    im_focals = [None] * n_imgs
+
+    # init with strongest edge
+    score, i, j = todo.pop()
+    if verbose:
+        print(f' init edge ({i}*,{j}*) {score=}')
+    i_j = edge_str(i, j)
+    pts3d[i] = pred_i[i_j].clone()
+    pts3d[j] = pred_j[i_j].clone()
+    done = {i, j}
+    if has_im_poses:
+        im_poses[i] = torch.eye(4, device=device)
+        im_focals[i] = estimate_focal(pred_i[i_j])
+
+    # set initial pointcloud based on pairwise graph
+    msp_edges = [(i, j)]
+    while todo:
+        # each time, predict the next one
+        score, i, j = todo.pop()
+
+        if im_focals[i] is None:
+            im_focals[i] = estimate_focal(pred_i[i_j])
+
+        if i in done:
+            if verbose:
+                print(f' init edge ({i},{j}*) {score=}')
+            assert j not in done
+            # align pred[i] with pts3d[i], and then set j accordingly
+            i_j = edge_str(i, j)
+            s, R, T = rigid_points_registration(pred_i[i_j], pts3d[i], conf=conf_i[i_j])
+            trf = sRT_to_4x4(s, R, T, device)
+            pts3d[j] = geotrf(trf, pred_j[i_j])
+            done.add(j)
+            msp_edges.append((i, j))
+
+            if has_im_poses and im_poses[i] is None:
+                im_poses[i] = sRT_to_4x4(1, R, T, device)
+
+        elif j in done:
+            if verbose:
+                print(f' init edge ({i}*,{j}) {score=}')
+            assert i not in done
+            i_j = edge_str(i, j)
+            s, R, T = rigid_points_registration(pred_j[i_j], pts3d[j], conf=conf_j[i_j])
+            trf = sRT_to_4x4(s, R, T, device)
+            pts3d[i] = geotrf(trf, pred_i[i_j])
+            done.add(i)
+            msp_edges.append((i, j))
+
+            if has_im_poses and im_poses[i] is None:
+                im_poses[i] = sRT_to_4x4(1, R, T, device)
+        else:
+            # let's try again later
+            todo.insert(0, (score, i, j))
+
+    if has_im_poses:
+        # complete all missing informations
+        pair_scores = list(sparse_graph.values())  # already negative scores: less is best
+        edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[np.argsort(pair_scores)]
+        for i, j in edges_from_best_to_worse.tolist():
+            if im_focals[i] is None:
+                im_focals[i] = estimate_focal(pred_i[edge_str(i, j)])
+
+        for i in range(n_imgs):
+            if im_poses[i] is None:
+                msk = im_conf[i] > min_conf_thr
+                res = fast_pnp(pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP)
+                if res:
+                    im_focals[i], im_poses[i] = res
+            if im_poses[i] is None:
+                im_poses[i] = torch.eye(4, device=device)
+        im_poses = torch.stack(im_poses)
+    else:
+        im_poses = im_focals = None
+
+    return pts3d, msp_edges, im_focals, im_poses
+
+
+def dict_to_sparse_graph(dic):
+    n_imgs = max(max(e) for e in dic) + 1
+    res = sp.dok_array((n_imgs, n_imgs))
+    for edge, value in dic.items():
+        res[edge] = value
+    return res
+
+
+def rigid_points_registration(pts1, pts2, conf):
+    R, T, s = roma.rigid_points_registration(
+        pts1.reshape(-1, 3), pts2.reshape(-1, 3), weights=conf.ravel(), compute_scaling=True)
+    return s, R, T  # return un-scaled (R, T)
+
+
+def sRT_to_4x4(scale, R, T, device):
+    trf = torch.eye(4, device=device)
+    trf[:3, :3] = R * scale
+    trf[:3, 3] = T.ravel()  # doesn't need scaling
+    return trf
+
+
+def estimate_focal(pts3d_i, pp=None):
+    if pp is None:
+        H, W, THREE = pts3d_i.shape
+        assert THREE == 3
+        pp = torch.tensor((W/2, H/2), device=pts3d_i.device)
+    focal = estimate_focal_knowing_depth(pts3d_i.unsqueeze(0), pp.unsqueeze(0), focal_mode='weiszfeld').ravel()
+    return float(focal)
+
+
+@cache
+def pixel_grid(H, W):
+    return np.mgrid[:W, :H].T.astype(np.float32)
+
+
+def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
+    # extract camera poses and focals with RANSAC-PnP
+    if msk.sum() < 4:
+        return None  # we need at least 4 points for PnP
+    pts3d, msk = map(to_numpy, (pts3d, msk))
+
+    H, W, THREE = pts3d.shape
+    assert THREE == 3
+    pixels = pixel_grid(H, W)
+
+    if focal is None:
+        S = max(W, H)
+        tentative_focals = np.geomspace(S/2, S*3, 21)
+    else:
+        tentative_focals = [focal]
+
+    if pp is None:
+        pp = (W/2, H/2)
+    else:
+        pp = to_numpy(pp)
+
+    best = 0,
+    for focal in tentative_focals:
+        K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])
+
+        success, R, T, inliers = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
+                                                    iterationsCount=niter_PnP, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
+        if not success:
+            continue
+
+        score = len(inliers)
+        if success and score > best[0]:
+            best = score, R, T, focal
+
+    if not best[0]:
+        return None
+
+    _, R, T, best_focal = best
+    R = cv2.Rodrigues(R)[0]  # world to cam
+    R, T = map(torch.from_numpy, (R, T))
+    return best_focal, inv(sRT_to_4x4(1, R, T, device))  # cam to world
+
+
+def get_known_poses(self):
+    if self.has_im_poses:
+        known_poses_msk = torch.tensor([not (p.requires_grad) for p in self.im_poses])
+        known_poses = self.get_im_poses()
+        return known_poses_msk.sum(), known_poses_msk, known_poses
+    else:
+        return 0, None, None
+
+
+def get_known_focals(self):
+    if self.has_im_poses:
+        known_focal_msk = self.get_known_focal_mask()
+        known_focals = self.get_focals()
+        return known_focal_msk.sum(), known_focal_msk, known_focals
+    else:
+        return 0, None, None
+
+
+def align_multiple_poses(src_poses, target_poses):
+    N = len(src_poses)
+    assert src_poses.shape == target_poses.shape == (N, 4, 4)
+
+    def center_and_z(poses):
+        eps = get_med_dist_between_poses(poses) / 100
+        return torch.cat((poses[:, :3, 3], poses[:, :3, 3] + eps*poses[:, :3, 2]))
+    R, T, s = roma.rigid_points_registration(center_and_z(src_poses), center_and_z(target_poses), compute_scaling=True)
+    return s, R, T
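A toy check (not part of the commit) of the registration helpers added above: build a point cloud, transform it by a known similarity, and recover that similarity with `rigid_points_registration`, then pack it into a homogeneous matrix with `sRT_to_4x4`. The use of `roma.random_rotmat()` to draw a rotation is an assumption about the roma API used only for this sketch.

```python
import torch
import roma
from mini_dust3r.cloud_opt.init_im_poses import rigid_points_registration, sRT_to_4x4

src = torch.randn(100, 3)
R_gt = roma.random_rotmat()
dst = 2.0 * (src @ R_gt.T) + torch.tensor([1.0, -2.0, 0.5])   # scale 2, known R and T

s, R, T = rigid_points_registration(src, dst, conf=torch.ones(100))
trf = sRT_to_4x4(s, R, T, device=src.device)   # [[s*R, T], [0, 0, 0, 1]]
```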
mini_dust3r/cloud_opt/modular_optimizer.py
ADDED
@@ -0,0 +1,145 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Slower implementation of the global alignment that allows to freeze partial poses/intrinsics
+# --------------------------------------------------------
+import numpy as np
+import torch
+import torch.nn as nn
+
+from mini_dust3r.cloud_opt.base_opt import BasePCOptimizer
+from mini_dust3r.utils.geometry import geotrf
+from mini_dust3r.utils.device import to_cpu, to_numpy
+from mini_dust3r.utils.geometry import depthmap_to_pts3d
+
+
+class ModularPointCloudOptimizer (BasePCOptimizer):
+    """ Optimize a global scene, given a list of pairwise observations.
+    Unlike PointCloudOptimizer, you can fix parts of the optimization process (partial poses/intrinsics)
+    Graph node: images
+    Graph edges: observations = (pred1, pred2)
+    """
+
+    def __init__(self, *args, optimize_pp=False, fx_and_fy=False, focal_brake=20, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.has_im_poses = True  # by definition of this class
+        self.focal_brake = focal_brake
+
+        # adding thing to optimize
+        self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes)  # log(depth)
+        self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs))  # camera poses
+        default_focals = [self.focal_brake * np.log(max(H, W)) for H, W in self.imshapes]
+        self.im_focals = nn.ParameterList(torch.FloatTensor([f, f] if fx_and_fy else [
+            f]) for f in default_focals)  # camera intrinsics
+        self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs))  # camera intrinsics
+        self.im_pp.requires_grad_(optimize_pp)
+
+    def preset_pose(self, known_poses, pose_msk=None):  # cam-to-world
+        if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:
+            known_poses = [known_poses]
+        for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
+            if self.verbose:
+                print(f' (setting pose #{idx} = {pose[:3,3]})')
+            self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose), force=True))
+
+        # normalize scale if there's less than 1 known pose
+        n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
+        self.norm_pw_scale = (n_known_poses <= 1)
+
+    def preset_intrinsics(self, known_intrinsics, msk=None):
+        if isinstance(known_intrinsics, torch.Tensor) and known_intrinsics.ndim == 2:
+            known_intrinsics = [known_intrinsics]
+        for K in known_intrinsics:
+            assert K.shape == (3, 3)
+        self.preset_focal([K.diagonal()[:2].mean() for K in known_intrinsics], msk)
+        self.preset_principal_point([K[:2, 2] for K in known_intrinsics], msk)
+
+    def preset_focal(self, known_focals, msk=None):
+        for idx, focal in zip(self._get_msk_indices(msk), known_focals):
+            if self.verbose:
+                print(f' (setting focal #{idx} = {focal})')
+            self._no_grad(self._set_focal(idx, focal, force=True))
+
+    def preset_principal_point(self, known_pp, msk=None):
+        for idx, pp in zip(self._get_msk_indices(msk), known_pp):
+            if self.verbose:
+                print(f' (setting principal point #{idx} = {pp})')
+            self._no_grad(self._set_principal_point(idx, pp, force=True))
+
+    def _no_grad(self, tensor):
+        return tensor.requires_grad_(False)
+
+    def _get_msk_indices(self, msk):
+        if msk is None:
+            return range(self.n_imgs)
+        elif isinstance(msk, int):
+            return [msk]
+        elif isinstance(msk, (tuple, list)):
+            return self._get_msk_indices(np.array(msk))
+        elif msk.dtype in (bool, torch.bool, np.bool_):
+            assert len(msk) == self.n_imgs
+            return np.where(msk)[0]
+        elif np.issubdtype(msk.dtype, np.integer):
+            return msk
+        else:
+            raise ValueError(f'bad {msk=}')
+
+    def _set_focal(self, idx, focal, force=False):
+        param = self.im_focals[idx]
+        if param.requires_grad or force:  # can only init a parameter not already initialized
+            param.data[:] = self.focal_brake * np.log(focal)
+        return param
+
+    def get_focals(self):
+        log_focals = torch.stack(list(self.im_focals), dim=0)
+        return (log_focals / self.focal_brake).exp()
+
+    def _set_principal_point(self, idx, pp, force=False):
+        param = self.im_pp[idx]
+        H, W = self.imshapes[idx]
+        if param.requires_grad or force:  # can only init a parameter not already initialized
+            param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10
+        return param
+
+    def get_principal_points(self):
+        return torch.stack([pp.new((W/2, H/2))+10*pp for pp, (H, W) in zip(self.im_pp, self.imshapes)])
+
+    def get_intrinsics(self):
+        K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
+        focals = self.get_focals().view(self.n_imgs, -1)
+        K[:, 0, 0] = focals[:, 0]
+        K[:, 1, 1] = focals[:, -1]
+        K[:, :2, 2] = self.get_principal_points()
+        K[:, 2, 2] = 1
+        return K
+
+    def get_im_poses(self):  # cam to world
+        cam2world = self._get_poses(torch.stack(list(self.im_poses)))
+        return cam2world
+
+    def _set_depthmap(self, idx, depth, force=False):
+        param = self.im_depthmaps[idx]
+        if param.requires_grad or force:  # can only init a parameter not already initialized
+            param.data[:] = depth.log().nan_to_num(neginf=0)
+        return param
+
+    def get_depthmaps(self):
+        return [d.exp() for d in self.im_depthmaps]
+
+    def depth_to_pts3d(self):
+        # Get depths and projection params if not provided
+        focals = self.get_focals()
+        pp = self.get_principal_points()
+        im_poses = self.get_im_poses()
+        depth = self.get_depthmaps()
+
+        # convert focal to (1,2,H,W) constant field
+        def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *self.imshapes[i])
+        # get pointmaps in camera frame
+        rel_ptmaps = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[i:i+1])[0] for i in range(im_poses.shape[0])]
+        # project to world frame
+        return [geotrf(pose, ptmap) for pose, ptmap in zip(im_poses, rel_ptmaps)]
+
+    def get_pts3d(self):
+        return self.depth_to_pts3d()
mini_dust3r/cloud_opt/optimizer.py
ADDED
@@ -0,0 +1,248 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Main class for the implementation of the global alignment
+# --------------------------------------------------------
+import numpy as np
+import torch
+import torch.nn as nn
+
+from mini_dust3r.cloud_opt.base_opt import BasePCOptimizer
+from mini_dust3r.utils.geometry import xy_grid, geotrf
+from mini_dust3r.utils.device import to_cpu, to_numpy
+
+
+class PointCloudOptimizer(BasePCOptimizer):
+    """ Optimize a global scene, given a list of pairwise observations.
+    Graph node: images
+    Graph edges: observations = (pred1, pred2)
+    """
+
+    def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.has_im_poses = True  # by definition of this class
+        self.focal_break = focal_break
+
+        # adding thing to optimize
+        self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes)  # log(depth)
+        self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs))  # camera poses
+        self.im_focals = nn.ParameterList(torch.FloatTensor(
+            [self.focal_break*np.log(max(H, W))]) for H, W in self.imshapes)  # camera intrinsics
+        self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs))  # camera intrinsics
+        self.im_pp.requires_grad_(optimize_pp)
+
+        self.imshape = self.imshapes[0]
+        im_areas = [h*w for h, w in self.imshapes]
+        self.max_area = max(im_areas)
+
+        # adding thing to optimize
+        self.im_depthmaps = ParameterStack(self.im_depthmaps, is_param=True, fill=self.max_area)
+        self.im_poses = ParameterStack(self.im_poses, is_param=True)
+        self.im_focals = ParameterStack(self.im_focals, is_param=True)
+        self.im_pp = ParameterStack(self.im_pp, is_param=True)
+        self.register_buffer('_pp', torch.tensor([(w/2, h/2) for h, w in self.imshapes]))
+        self.register_buffer('_grid', ParameterStack(
+            [xy_grid(W, H, device=self.device) for H, W in self.imshapes], fill=self.max_area))
+
+        # pre-compute pixel weights
+        self.register_buffer('_weight_i', ParameterStack(
+            [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges], fill=self.max_area))
+        self.register_buffer('_weight_j', ParameterStack(
+            [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges], fill=self.max_area))
+
+        # precompute aa
+        self.register_buffer('_stacked_pred_i', ParameterStack(self.pred_i, self.str_edges, fill=self.max_area))
+        self.register_buffer('_stacked_pred_j', ParameterStack(self.pred_j, self.str_edges, fill=self.max_area))
+        self.register_buffer('_ei', torch.tensor([i for i, j in self.edges]))
+        self.register_buffer('_ej', torch.tensor([j for i, j in self.edges]))
+        self.total_area_i = sum([im_areas[i] for i, j in self.edges])
+        self.total_area_j = sum([im_areas[j] for i, j in self.edges])
+
+    def _check_all_imgs_are_selected(self, msk):
+        assert np.all(self._get_msk_indices(msk) == np.arange(self.n_imgs)), 'incomplete mask!'
+
+    def preset_pose(self, known_poses, pose_msk=None):  # cam-to-world
+        self._check_all_imgs_are_selected(pose_msk)
+
+        if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:
+            known_poses = [known_poses]
+        for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
+            if self.verbose:
+                print(f' (setting pose #{idx} = {pose[:3,3]})')
+            self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose)))
+
+        # normalize scale if there's less than 1 known pose
+        n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
+        self.norm_pw_scale = (n_known_poses <= 1)
+
+        self.im_poses.requires_grad_(False)
+        self.norm_pw_scale = False
+
+    def preset_focal(self, known_focals, msk=None):
+        self._check_all_imgs_are_selected(msk)
+
+        for idx, focal in zip(self._get_msk_indices(msk), known_focals):
+            if self.verbose:
+                print(f' (setting focal #{idx} = {focal})')
+            self._no_grad(self._set_focal(idx, focal))
+
+        self.im_focals.requires_grad_(False)
+
+    def preset_principal_point(self, known_pp, msk=None):
+        self._check_all_imgs_are_selected(msk)
+
+        for idx, pp in zip(self._get_msk_indices(msk), known_pp):
+            if self.verbose:
+                print(f' (setting principal point #{idx} = {pp})')
+            self._no_grad(self._set_principal_point(idx, pp))
+
+        self.im_pp.requires_grad_(False)
+
+    def _get_msk_indices(self, msk):
+        if msk is None:
+            return range(self.n_imgs)
+        elif isinstance(msk, int):
+            return [msk]
+        elif isinstance(msk, (tuple, list)):
+            return self._get_msk_indices(np.array(msk))
+        elif msk.dtype in (bool, torch.bool, np.bool_):
+            assert len(msk) == self.n_imgs
+            return np.where(msk)[0]
+        elif np.issubdtype(msk.dtype, np.integer):
+            return msk
+        else:
+            raise ValueError(f'bad {msk=}')
+
+    def _no_grad(self, tensor):
+        assert tensor.requires_grad, 'it must be True at this point, otherwise no modification occurs'
+
+    def _set_focal(self, idx, focal, force=False):
+        param = self.im_focals[idx]
+        if param.requires_grad or force:  # can only init a parameter not already initialized
+            param.data[:] = self.focal_break * np.log(focal)
+        return param
+
+    def get_focals(self):
+        log_focals = torch.stack(list(self.im_focals), dim=0)
+        return (log_focals / self.focal_break).exp()
+
+    def get_known_focal_mask(self):
+        return torch.tensor([not (p.requires_grad) for p in self.im_focals])
+
+    def _set_principal_point(self, idx, pp, force=False):
+        param = self.im_pp[idx]
+        H, W = self.imshapes[idx]
+        if param.requires_grad or force:  # can only init a parameter not already initialized
+            param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10
+        return param
+
+    def get_principal_points(self):
+        return self._pp + 10 * self.im_pp
+
+    def get_intrinsics(self):
+        K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
+        focals = self.get_focals().flatten()
+        K[:, 0, 0] = K[:, 1, 1] = focals
+        K[:, :2, 2] = self.get_principal_points()
+        K[:, 2, 2] = 1
+        return K
+
+    def get_im_poses(self):  # cam to world
+        cam2world = self._get_poses(self.im_poses)
+        return cam2world
+
+    def _set_depthmap(self, idx, depth, force=False):
+        depth = _ravel_hw(depth, self.max_area)
+
+        param = self.im_depthmaps[idx]
+        if param.requires_grad or force:  # can only init a parameter not already initialized
+            param.data[:] = depth.log().nan_to_num(neginf=0)
+        return param
+
+    def get_depthmaps(self, raw=False):
+        res = self.im_depthmaps.exp()
+        if not raw:
+            res = [dm[:h*w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)]
+        return res
+
+    def depth_to_pts3d(self):
+        # Get depths and projection params if not provided
+        focals = self.get_focals()
+        pp = self.get_principal_points()
+        im_poses = self.get_im_poses()
+        depth = self.get_depthmaps(raw=True)
+
+        # get pointmaps in camera frame
+        rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp)
+        # project to world frame
+        return geotrf(im_poses, rel_ptmaps)
+
+    def get_pts3d(self, raw=False):
+        res = self.depth_to_pts3d()
+        if not raw:
+            res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
+        return res
+
+    def forward(self):
+        pw_poses = self.get_pw_poses()  # cam-to-world
+        pw_adapt = self.get_adaptors().unsqueeze(1)
+        proj_pts3d = self.get_pts3d(raw=True)
+
+        # rotate pairwise prediction according to pw_poses
+        aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i)
+        aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j)
+
+        # compute the less
+        li = self.dist(proj_pts3d[self._ei], aligned_pred_i, weight=self._weight_i).sum() / self.total_area_i
+        lj = self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum() / self.total_area_j
+
+        return li + lj
+
+
+def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp):
+    pp = pp.unsqueeze(1)
+    focal = focal.unsqueeze(1)
+    assert focal.shape == (len(depth), 1, 1)
+    assert pp.shape == (len(depth), 1, 2)
+    assert pixel_grid.shape == depth.shape + (2,)
+    depth = depth.unsqueeze(-1)
+    return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1)
+
+
+def ParameterStack(params, keys=None, is_param=None, fill=0):
+    if keys is not None:
+        params = [params[k] for k in keys]
+
+    if fill > 0:
+        params = [_ravel_hw(p, fill) for p in params]
+
+    requires_grad = params[0].requires_grad
+    assert all(p.requires_grad == requires_grad for p in params)
+
+    params = torch.stack(list(params)).float().detach()
+    if is_param or requires_grad:
+        params = nn.Parameter(params)
+        params.requires_grad_(requires_grad)
+    return params
+
+
+def _ravel_hw(tensor, fill=0):
+    # ravel H,W
+    tensor = tensor.view((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
+
+    if len(tensor) < fill:
+        tensor = torch.cat((tensor, tensor.new_zeros((fill - len(tensor),)+tensor.shape[1:])))
+    return tensor
+
+
+def acceptable_focal_range(H, W, minf=0.5, maxf=3.5):
+    focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2))  # size / 1.1547005383792515
+    return minf*focal_base, maxf*focal_base
+
+
+def apply_mask(img, msk):
+    img = img.copy()
+    img[msk] = 0
+    return img
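A small illustration (not part of the commit) of the padding convention used by `ParameterStack`: `_ravel_hw` flattens each (H, W) map and zero-pads it to the largest image area so maps of different sizes can be stacked into a single tensor.

```python
import torch
from mini_dust3r.cloud_opt.optimizer import _ravel_hw

a = torch.arange(6.).reshape(2, 3)     # area 6
b = torch.arange(12.).reshape(3, 4)    # area 12 (the max area)
stacked = torch.stack([_ravel_hw(a, fill=12), _ravel_hw(b, fill=12)])
print(stacked.shape)                   # torch.Size([2, 12]); `a` is zero-padded
```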
mini_dust3r/cloud_opt/pair_viewer.py
ADDED
@@ -0,0 +1,127 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Dummy optimizer for visualizing pairs
+# --------------------------------------------------------
+import numpy as np
+import torch
+import torch.nn as nn
+import cv2
+
+from mini_dust3r.cloud_opt.base_opt import BasePCOptimizer
+from mini_dust3r.utils.geometry import inv, geotrf, depthmap_to_absolute_camera_coordinates
+from mini_dust3r.cloud_opt.commons import edge_str
+from mini_dust3r.post_process import estimate_focal_knowing_depth
+
+
+class PairViewer (BasePCOptimizer):
+    """
+    This a Dummy Optimizer.
+    To use only when the goal is to visualize the results for a pair of images (with is_symmetrized)
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.is_symmetrized and self.n_edges == 2
+        self.has_im_poses = True
+
+        # compute all parameters directly from raw input
+        self.focals = []
+        self.pp = []
+        rel_poses = []
+        confs = []
+        for i in range(self.n_imgs):
+            conf = float(self.conf_i[edge_str(i, 1-i)].mean() * self.conf_j[edge_str(i, 1-i)].mean())
+            if self.verbose:
+                print(f' - {conf=:.3} for edge {i}-{1-i}')
+            confs.append(conf)
+
+            H, W = self.imshapes[i]
+            pts3d = self.pred_i[edge_str(i, 1-i)]
+            pp = torch.tensor((W/2, H/2))
+            focal = float(estimate_focal_knowing_depth(pts3d[None], pp, focal_mode='weiszfeld'))
+            self.focals.append(focal)
+            self.pp.append(pp)
+
+            # estimate the pose of pts1 in image 2
+            pixels = np.mgrid[:W, :H].T.astype(np.float32)
+            pts3d = self.pred_j[edge_str(1-i, i)].numpy()
+            assert pts3d.shape[:2] == (H, W)
+            msk = self.get_masks()[i].numpy()
+            K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])
+
+            try:
+                res = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
+                                         iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
+                success, R, T, inliers = res
+                assert success
+
+                R = cv2.Rodrigues(R)[0]  # world to cam
+                pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]])  # cam to world
+            except:
+                pose = np.eye(4)
+            rel_poses.append(torch.from_numpy(pose.astype(np.float32)))
+
+        # let's use the pair with the most confidence
+        if confs[0] > confs[1]:
+            # ptcloud is expressed in camera1
+            self.im_poses = [torch.eye(4), rel_poses[1]]  # I, cam2-to-cam1
+            self.depth = [self.pred_i['0_1'][..., 2], geotrf(inv(rel_poses[1]), self.pred_j['0_1'])[..., 2]]
+        else:
+            # ptcloud is expressed in camera2
+            self.im_poses = [rel_poses[0], torch.eye(4)]  # I, cam1-to-cam2
+            self.depth = [geotrf(inv(rel_poses[0]), self.pred_j['1_0'])[..., 2], self.pred_i['1_0'][..., 2]]
+
+        self.im_poses = nn.Parameter(torch.stack(self.im_poses, dim=0), requires_grad=False)
+        self.focals = nn.Parameter(torch.tensor(self.focals), requires_grad=False)
+        self.pp = nn.Parameter(torch.stack(self.pp, dim=0), requires_grad=False)
+        self.depth = nn.ParameterList(self.depth)
+        for p in self.parameters():
+            p.requires_grad = False
+
+    def _set_depthmap(self, idx, depth, force=False):
+        if self.verbose:
+            print('_set_depthmap is ignored in PairViewer')
+        return
+
+    def get_depthmaps(self, raw=False):
+        depth = [d.to(self.device) for d in self.depth]
+        return depth
+
+    def _set_focal(self, idx, focal, force=False):
+        self.focals[idx] = focal
+
+    def get_focals(self):
+        return self.focals
+
+    def get_known_focal_mask(self):
+        return torch.tensor([not (p.requires_grad) for p in self.focals])
+
+    def get_principal_points(self):
+        return self.pp
+
+    def get_intrinsics(self):
+        focals = self.get_focals()
+        pps = self.get_principal_points()
+        K = torch.zeros((len(focals), 3, 3), device=self.device)
+        for i in range(len(focals)):
+            K[i, 0, 0] = K[i, 1, 1] = focals[i]
+            K[i, :2, 2] = pps[i]
+            K[i, 2, 2] = 1
+        return K
+
+    def get_im_poses(self):
+        return self.im_poses
+
+    def depth_to_pts3d(self):
+        pts3d = []
+        for d, intrinsics, im_pose in zip(self.depth, self.get_intrinsics(), self.get_im_poses()):
+            pts, _ = depthmap_to_absolute_camera_coordinates(d.cpu().numpy(),
+                                                             intrinsics.cpu().numpy(),
+                                                             im_pose.cpu().numpy())
+            pts3d.append(torch.from_numpy(pts).to(device=self.device))
+        return pts3d
+
+    def forward(self):
+        return float('nan')
mini_dust3r/croco/blocks.py
ADDED
@@ -0,0 +1,241 @@
+# Copyright (C) 2022-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+
+# --------------------------------------------------------
+# Main encoder/decoder blocks
+# --------------------------------------------------------
+# References:
+# timm
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py
+
+
+import torch
+import torch.nn as nn
+
+from itertools import repeat
+import collections.abc
+
+
+def _ntuple(n):
+    def parse(x):
+        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+            return x
+        return tuple(repeat(x, n))
+    return parse
+to_2tuple = _ntuple(2)
+
+def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    """
+    if drop_prob == 0. or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0 and scale_by_keep:
+        random_tensor.div_(keep_prob)
+    return x * random_tensor
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+    """
+    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+    def extra_repr(self):
+        return f'drop_prob={round(self.drop_prob,3):0.3f}'
+
+class Mlp(nn.Module):
+    """ MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        bias = to_2tuple(bias)
+        drop_probs = to_2tuple(drop)
+
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
+        self.act = act_layer()
+        self.drop1 = nn.Dropout(drop_probs[0])
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
+        self.drop2 = nn.Dropout(drop_probs[1])
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop1(x)
+        x = self.fc2(x)
+        x = self.drop2(x)
+        return x
+
+class Attention(nn.Module):
+
+    def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim ** -0.5
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+        self.rope = rope
+
+    def forward(self, x, xpos):
+        B, N, C = x.shape
+
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1,3)
+        q, k, v = [qkv[:,:,i] for i in range(3)]
+        # q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple)
+
+        if self.rope is not None:
+            q = self.rope(q, xpos)
+            k = self.rope(k, xpos)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
|
107 |
+
attn = self.attn_drop(attn)
|
108 |
+
|
109 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
110 |
+
x = self.proj(x)
|
111 |
+
x = self.proj_drop(x)
|
112 |
+
return x
|
113 |
+
|
114 |
+
class Block(nn.Module):
|
115 |
+
|
116 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
117 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, rope=None):
|
118 |
+
super().__init__()
|
119 |
+
self.norm1 = norm_layer(dim)
|
120 |
+
self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
121 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
122 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
123 |
+
self.norm2 = norm_layer(dim)
|
124 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
125 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
126 |
+
|
127 |
+
def forward(self, x, xpos):
|
128 |
+
x = x + self.drop_path(self.attn(self.norm1(x), xpos))
|
129 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
130 |
+
return x
|
131 |
+
|
132 |
+
class CrossAttention(nn.Module):
|
133 |
+
|
134 |
+
def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
|
135 |
+
super().__init__()
|
136 |
+
self.num_heads = num_heads
|
137 |
+
head_dim = dim // num_heads
|
138 |
+
self.scale = head_dim ** -0.5
|
139 |
+
|
140 |
+
self.projq = nn.Linear(dim, dim, bias=qkv_bias)
|
141 |
+
self.projk = nn.Linear(dim, dim, bias=qkv_bias)
|
142 |
+
self.projv = nn.Linear(dim, dim, bias=qkv_bias)
|
143 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
144 |
+
self.proj = nn.Linear(dim, dim)
|
145 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
146 |
+
|
147 |
+
self.rope = rope
|
148 |
+
|
149 |
+
def forward(self, query, key, value, qpos, kpos):
|
150 |
+
B, Nq, C = query.shape
|
151 |
+
Nk = key.shape[1]
|
152 |
+
Nv = value.shape[1]
|
153 |
+
|
154 |
+
q = self.projq(query).reshape(B,Nq,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
|
155 |
+
k = self.projk(key).reshape(B,Nk,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
|
156 |
+
v = self.projv(value).reshape(B,Nv,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
|
157 |
+
|
158 |
+
if self.rope is not None:
|
159 |
+
q = self.rope(q, qpos)
|
160 |
+
k = self.rope(k, kpos)
|
161 |
+
|
162 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
163 |
+
attn = attn.softmax(dim=-1)
|
164 |
+
attn = self.attn_drop(attn)
|
165 |
+
|
166 |
+
x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
|
167 |
+
x = self.proj(x)
|
168 |
+
x = self.proj_drop(x)
|
169 |
+
return x
|
170 |
+
|
171 |
+
class DecoderBlock(nn.Module):
|
172 |
+
|
173 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
174 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None):
|
175 |
+
super().__init__()
|
176 |
+
self.norm1 = norm_layer(dim)
|
177 |
+
self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
178 |
+
self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
179 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
180 |
+
self.norm2 = norm_layer(dim)
|
181 |
+
self.norm3 = norm_layer(dim)
|
182 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
183 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
184 |
+
self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
|
185 |
+
|
186 |
+
def forward(self, x, y, xpos, ypos):
|
187 |
+
x = x + self.drop_path(self.attn(self.norm1(x), xpos))
|
188 |
+
y_ = self.norm_y(y)
|
189 |
+
x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
|
190 |
+
x = x + self.drop_path(self.mlp(self.norm3(x)))
|
191 |
+
return x, y
|
192 |
+
|
193 |
+
|
194 |
+
# patch embedding
|
195 |
+
class PositionGetter(object):
|
196 |
+
""" return positions of patches """
|
197 |
+
|
198 |
+
def __init__(self):
|
199 |
+
self.cache_positions = {}
|
200 |
+
|
201 |
+
def __call__(self, b, h, w, device):
|
202 |
+
if not (h,w) in self.cache_positions:
|
203 |
+
x = torch.arange(w, device=device)
|
204 |
+
y = torch.arange(h, device=device)
|
205 |
+
self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h, w, 2)
|
206 |
+
pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone()
|
207 |
+
return pos
|
208 |
+
|
209 |
+
class PatchEmbed(nn.Module):
|
210 |
+
""" just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed"""
|
211 |
+
|
212 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
|
213 |
+
super().__init__()
|
214 |
+
img_size = to_2tuple(img_size)
|
215 |
+
patch_size = to_2tuple(patch_size)
|
216 |
+
self.img_size = img_size
|
217 |
+
self.patch_size = patch_size
|
218 |
+
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
|
219 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
220 |
+
self.flatten = flatten
|
221 |
+
|
222 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
223 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
224 |
+
|
225 |
+
self.position_getter = PositionGetter()
|
226 |
+
|
227 |
+
def forward(self, x):
|
228 |
+
B, C, H, W = x.shape
|
229 |
+
torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
|
230 |
+
torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
|
231 |
+
x = self.proj(x)
|
232 |
+
pos = self.position_getter(B, x.size(2), x.size(3), x.device)
|
233 |
+
if self.flatten:
|
234 |
+
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
235 |
+
x = self.norm(x)
|
236 |
+
return x, pos
|
237 |
+
|
238 |
+
def _init_weights(self):
|
239 |
+
w = self.proj.weight.data
|
240 |
+
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
241 |
+
|
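Aside (not part of the commit): a minimal sketch of how the pieces in blocks.py compose; the 224x224 input, 16-pixel patches and 768-dim embedding below are illustrative ViT-style assumptions, not values fixed by this file.

import torch
from mini_dust3r.croco.blocks import PatchEmbed, Block

patch_embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
block = Block(dim=768, num_heads=12, qkv_bias=True)   # rope=None here, so positions are carried but unused

img = torch.randn(2, 3, 224, 224)
tokens, pos = patch_embed(img)   # tokens: (2, 196, 768), pos: (2, 196, 2)
out = block(tokens, pos)         # self-attention + MLP with residuals, shape preserved
print(out.shape)                 # torch.Size([2, 196, 768])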
mini_dust3r/croco/croco.py
ADDED
@@ -0,0 +1,249 @@
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).


# --------------------------------------------------------
# CroCo model during pretraining
# --------------------------------------------------------


import torch
import torch.nn as nn
torch.backends.cuda.matmul.allow_tf32 = True  # for gpu >= Ampere and pytorch >= 1.12
from functools import partial

from mini_dust3r.croco.blocks import Block, DecoderBlock, PatchEmbed
from mini_dust3r.croco.pos_embed import get_2d_sincos_pos_embed, RoPE2D
from mini_dust3r.croco.masking import RandomMask


class CroCoNet(nn.Module):

    def __init__(self,
                 img_size=224,           # input image size
                 patch_size=16,          # patch_size
                 mask_ratio=0.9,         # ratios of masked tokens
                 enc_embed_dim=768,      # encoder feature dimension
                 enc_depth=12,           # encoder depth
                 enc_num_heads=12,       # encoder number of heads in the transformer block
                 dec_embed_dim=512,      # decoder feature dimension
                 dec_depth=8,            # decoder depth
                 dec_num_heads=16,       # decoder number of heads in the transformer block
                 mlp_ratio=4,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 norm_im2_in_dec=True,   # whether to apply normalization of the 'memory' = (second image) in the decoder
                 pos_embed='cosine',     # positional embedding (either cosine or RoPE100)
                 ):

        super(CroCoNet, self).__init__()

        # patch embeddings (with initialization done as in MAE)
        self._set_patch_embed(img_size, patch_size, enc_embed_dim)

        # mask generations
        self._set_mask_generator(self.patch_embed.num_patches, mask_ratio)

        self.pos_embed = pos_embed
        if pos_embed == 'cosine':
            # positional embedding of the encoder
            enc_pos_embed = get_2d_sincos_pos_embed(enc_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0)
            self.register_buffer('enc_pos_embed', torch.from_numpy(enc_pos_embed).float())
            # positional embedding of the decoder
            dec_pos_embed = get_2d_sincos_pos_embed(dec_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0)
            self.register_buffer('dec_pos_embed', torch.from_numpy(dec_pos_embed).float())
            # pos embedding in each block
            self.rope = None  # nothing for cosine
        elif pos_embed.startswith('RoPE'):  # eg RoPE100
            self.enc_pos_embed = None  # nothing to add in the encoder with RoPE
            self.dec_pos_embed = None  # nothing to add in the decoder with RoPE
            if RoPE2D is None: raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
            freq = float(pos_embed[len('RoPE'):])
            self.rope = RoPE2D(freq=freq)
        else:
            raise NotImplementedError('Unknown pos_embed '+pos_embed)

        # transformer for the encoder
        self.enc_depth = enc_depth
        self.enc_embed_dim = enc_embed_dim
        self.enc_blocks = nn.ModuleList([
            Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=self.rope)
            for i in range(enc_depth)])
        self.enc_norm = norm_layer(enc_embed_dim)

        # masked tokens
        self._set_mask_token(dec_embed_dim)

        # decoder
        self._set_decoder(enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec)

        # prediction head
        self._set_prediction_head(dec_embed_dim, patch_size)

        # initializer weights
        self.initialize_weights()

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
        self.patch_embed = PatchEmbed(img_size, patch_size, 3, enc_embed_dim)

    def _set_mask_generator(self, num_patches, mask_ratio):
        self.mask_generator = RandomMask(num_patches, mask_ratio)

    def _set_mask_token(self, dec_embed_dim):
        self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim))

    def _set_decoder(self, enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec):
        self.dec_depth = dec_depth
        self.dec_embed_dim = dec_embed_dim
        # transfer from encoder to decoder
        self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
        # transformer for the decoder
        self.dec_blocks = nn.ModuleList([
            DecoderBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope)
            for i in range(dec_depth)])
        # final norm layer
        self.dec_norm = norm_layer(dec_embed_dim)

    def _set_prediction_head(self, dec_embed_dim, patch_size):
        self.prediction_head = nn.Linear(dec_embed_dim, patch_size**2 * 3, bias=True)


    def initialize_weights(self):
        # patch embed
        self.patch_embed._init_weights()
        # mask tokens
        if self.mask_token is not None: torch.nn.init.normal_(self.mask_token, std=.02)
        # linears and layer norms
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _encode_image(self, image, do_mask=False, return_all_blocks=False):
        """
        image has B x 3 x img_size x img_size
        do_mask: whether to perform masking or not
        return_all_blocks: if True, return the features at the end of every block
                           instead of just the features from the last block (eg for some prediction heads)
        """
        # embed the image into patches  (x has size B x Npatches x C)
        # and get position if each return patch (pos has size B x Npatches x 2)
        x, pos = self.patch_embed(image)
        # add positional embedding without cls token
        if self.enc_pos_embed is not None:
            x = x + self.enc_pos_embed[None, ...]
        # apply masking
        B, N, C = x.size()
        if do_mask:
            masks = self.mask_generator(x)
            x = x[~masks].view(B, -1, C)
            posvis = pos[~masks].view(B, -1, 2)
        else:
            B, N, C = x.size()
            masks = torch.zeros((B, N), dtype=bool)
            posvis = pos
        # now apply the transformer encoder and normalization
        if return_all_blocks:
            out = []
            for blk in self.enc_blocks:
                x = blk(x, posvis)
                out.append(x)
            out[-1] = self.enc_norm(out[-1])
            return out, pos, masks
        else:
            for blk in self.enc_blocks:
                x = blk(x, posvis)
            x = self.enc_norm(x)
            return x, pos, masks

    def _decoder(self, feat1, pos1, masks1, feat2, pos2, return_all_blocks=False):
        """
        return_all_blocks: if True, return the features at the end of every block
                           instead of just the features from the last block (eg for some prediction heads)

        masks1 can be None => assume image1 fully visible
        """
        # encoder to decoder layer
        visf1 = self.decoder_embed(feat1)
        f2 = self.decoder_embed(feat2)
        # append masked tokens to the sequence
        B, Nenc, C = visf1.size()
        if masks1 is None:  # downstreams
            f1_ = visf1
        else:  # pretraining
            Ntotal = masks1.size(1)
            f1_ = self.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype)
            f1_[~masks1] = visf1.view(B * Nenc, C)
        # add positional embedding
        if self.dec_pos_embed is not None:
            f1_ = f1_ + self.dec_pos_embed
            f2 = f2 + self.dec_pos_embed
        # apply Transformer blocks
        out = f1_
        out2 = f2
        if return_all_blocks:
            _out, out = out, []
            for blk in self.dec_blocks:
                _out, out2 = blk(_out, out2, pos1, pos2)
                out.append(_out)
            out[-1] = self.dec_norm(out[-1])
        else:
            for blk in self.dec_blocks:
                out, out2 = blk(out, out2, pos1, pos2)
            out = self.dec_norm(out)
        return out

    def patchify(self, imgs):
        """
        imgs: (B, 3, H, W)
        x: (B, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))

        return x

    def unpatchify(self, x, channels=3):
        """
        x: (N, L, patch_size**2 *channels)
        imgs: (N, 3, H, W)
        """
        patch_size = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]
        x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, channels))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], channels, h * patch_size, h * patch_size))
        return imgs

    def forward(self, img1, img2):
        """
        img1: tensor of size B x 3 x img_size x img_size
        img2: tensor of size B x 3 x img_size x img_size

        out will be B x N x (3*patch_size*patch_size)
        masks are also returned as B x N just in case
        """
        # encoder of the masked first image
        feat1, pos1, mask1 = self._encode_image(img1, do_mask=True)
        # encoder of the second image
        feat2, pos2, _ = self._encode_image(img2, do_mask=False)
        # decoder
        decfeat = self._decoder(feat1, pos1, mask1, feat2, pos2)
        # prediction head
        out = self.prediction_head(decfeat)
        # get target
        target = self.patchify(img1)
        return out, mask1, target
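Aside (not part of the commit): with the defaults above, a CroCo pretraining forward pass looks roughly like the sketch below; the shapes follow from the assumed 224x224 inputs and 16x16 patches.

import torch
from mini_dust3r.croco.croco import CroCoNet

model = CroCoNet()                      # cosine positional embeddings, 90% masking of img1
img1 = torch.randn(1, 3, 224, 224)      # the masked view to reconstruct
img2 = torch.randn(1, 3, 224, 224)      # the fully visible reference view
out, mask, target = model(img1, img2)
# out, target: (1, 196, 768) per-patch RGB predictions / targets (16*16*3); mask: (1, 196) booleans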
mini_dust3r/croco/dpt_block.py
ADDED
@@ -0,0 +1,450 @@
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

# --------------------------------------------------------
# DPT head for ViTs
# --------------------------------------------------------
# References:
# https://github.com/isl-org/DPT
# https://github.com/EPFL-VILAB/MultiMAE/blob/main/multimae/output_adapters.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from typing import Union, Tuple, Iterable, List, Optional, Dict

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

def make_scratch(in_shape, out_shape, groups=1, expand=False):
    scratch = nn.Module()

    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    out_shape4 = out_shape
    if expand == True:
        out_shape1 = out_shape
        out_shape2 = out_shape * 2
        out_shape3 = out_shape * 4
        out_shape4 = out_shape * 8

    scratch.layer1_rn = nn.Conv2d(
        in_shape[0],
        out_shape1,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        groups=groups,
    )
    scratch.layer2_rn = nn.Conv2d(
        in_shape[1],
        out_shape2,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        groups=groups,
    )
    scratch.layer3_rn = nn.Conv2d(
        in_shape[2],
        out_shape3,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        groups=groups,
    )
    scratch.layer4_rn = nn.Conv2d(
        in_shape[3],
        out_shape4,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        groups=groups,
    )

    scratch.layer_rn = nn.ModuleList([
        scratch.layer1_rn,
        scratch.layer2_rn,
        scratch.layer3_rn,
        scratch.layer4_rn,
    ])

    return scratch

class ResidualConvUnit_custom(nn.Module):
    """Residual convolution module."""

    def __init__(self, features, activation, bn):
        """Init.
        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn

        self.groups = 1

        self.conv1 = nn.Conv2d(
            features,
            features,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=not self.bn,
            groups=self.groups,
        )

        self.conv2 = nn.Conv2d(
            features,
            features,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=not self.bn,
            groups=self.groups,
        )

        if self.bn == True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.
        Args:
            x (tensor): input
        Returns:
            tensor: output
        """

        out = self.activation(x)
        out = self.conv1(out)
        if self.bn == True:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn == True:
            out = self.bn2(out)

        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)

class FeatureFusionBlock_custom(nn.Module):
    """Feature fusion block."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        width_ratio=1,
    ):
        """Init.
        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock_custom, self).__init__()
        self.width_ratio = width_ratio

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups = 1

        self.expand = expand
        out_features = features
        if self.expand == True:
            out_features = features // 2

        self.out_conv = nn.Conv2d(
            features,
            out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
            groups=1,
        )

        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs):
        """Forward pass.
        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            if self.width_ratio != 1:
                res = F.interpolate(res, size=(output.shape[2], output.shape[3]), mode='bilinear')

            output = self.skip_add.add(output, res)
            # output += res

        output = self.resConfUnit2(output)

        if self.width_ratio != 1:
            # and output.shape[3] < self.width_ratio * output.shape[2]
            # size=(image.shape[])
            if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio:
                shape = 3 * output.shape[3]
            else:
                shape = int(self.width_ratio * 2 * output.shape[2])
            output = F.interpolate(output, size=(2 * output.shape[2], shape), mode='bilinear')
        else:
            output = nn.functional.interpolate(output, scale_factor=2,
                                               mode="bilinear", align_corners=self.align_corners)
        output = self.out_conv(output)
        return output

def make_fusion_block(features, use_bn, width_ratio=1):
    return FeatureFusionBlock_custom(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
        width_ratio=width_ratio,
    )

class Interpolate(nn.Module):
    """Interpolation module."""

    def __init__(self, scale_factor, mode, align_corners=False):
        """Init.
        Args:
            scale_factor (float): scaling
            mode (str): interpolation mode
        """
        super(Interpolate, self).__init__()

        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Forward pass.
        Args:
            x (tensor): input
        Returns:
            tensor: interpolated data
        """

        x = self.interp(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
        )

        return x

class DPTOutputAdapter(nn.Module):
    """DPT output adapter.

    :param num_cahnnels: Number of output channels
    :param stride_level: tride level compared to the full-sized image.
        E.g. 4 for 1/4th the size of the image.
    :param patch_size_full: Int or tuple of the patch size over the full image size.
        Patch size for smaller inputs will be computed accordingly.
    :param hooks: Index of intermediate layers
    :param layer_dims: Dimension of intermediate layers
    :param feature_dim: Feature dimension
    :param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression
    :param use_bn: If set to True, activates batch norm
    :param dim_tokens_enc: Dimension of tokens coming from encoder
    """

    def __init__(self,
                 num_channels: int = 1,
                 stride_level: int = 1,
                 patch_size: Union[int, Tuple[int, int]] = 16,
                 main_tasks: Iterable[str] = ('rgb',),
                 hooks: List[int] = [2, 5, 8, 11],
                 layer_dims: List[int] = [96, 192, 384, 768],
                 feature_dim: int = 256,
                 last_dim: int = 32,
                 use_bn: bool = False,
                 dim_tokens_enc: Optional[int] = None,
                 head_type: str = 'regression',
                 output_width_ratio=1,
                 **kwargs):
        super().__init__()
        self.num_channels = num_channels
        self.stride_level = stride_level
        self.patch_size = pair(patch_size)
        self.main_tasks = main_tasks
        self.hooks = hooks
        self.layer_dims = layer_dims
        self.feature_dim = feature_dim
        self.dim_tokens_enc = dim_tokens_enc * len(self.main_tasks) if dim_tokens_enc is not None else None
        self.head_type = head_type

        # Actual patch height and width, taking into account stride of input
        self.P_H = max(1, self.patch_size[0] // stride_level)
        self.P_W = max(1, self.patch_size[1] // stride_level)

        self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False)

        self.scratch.refinenet1 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
        self.scratch.refinenet2 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
        self.scratch.refinenet3 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
        self.scratch.refinenet4 = make_fusion_block(feature_dim, use_bn, output_width_ratio)

        if self.head_type == 'regression':
            # The "DPTDepthModel" head
            self.head = nn.Sequential(
                nn.Conv2d(feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1),
                Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
                nn.Conv2d(feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1),
                nn.ReLU(True),
                nn.Conv2d(last_dim, self.num_channels, kernel_size=1, stride=1, padding=0)
            )
        elif self.head_type == 'semseg':
            # The "DPTSegmentationModel" head
            self.head = nn.Sequential(
                nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(),
                nn.ReLU(True),
                nn.Dropout(0.1, False),
                nn.Conv2d(feature_dim, self.num_channels, kernel_size=1),
                Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            )
        else:
            raise ValueError('DPT head_type must be "regression" or "semseg".')

        if self.dim_tokens_enc is not None:
            self.init(dim_tokens_enc=dim_tokens_enc)

    def init(self, dim_tokens_enc=768):
        """
        Initialize parts of decoder that are dependent on dimension of encoder tokens.
        Should be called when setting up MultiMAE.

        :param dim_tokens_enc: Dimension of tokens coming from encoder
        """
        #print(dim_tokens_enc)

        # Set up activation postprocessing layers
        if isinstance(dim_tokens_enc, int):
            dim_tokens_enc = 4 * [dim_tokens_enc]

        self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc]

        self.act_1_postprocess = nn.Sequential(
            nn.Conv2d(
                in_channels=self.dim_tokens_enc[0],
                out_channels=self.layer_dims[0],
                kernel_size=1, stride=1, padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=self.layer_dims[0],
                out_channels=self.layer_dims[0],
                kernel_size=4, stride=4, padding=0,
                bias=True, dilation=1, groups=1,
            )
        )

        self.act_2_postprocess = nn.Sequential(
            nn.Conv2d(
                in_channels=self.dim_tokens_enc[1],
                out_channels=self.layer_dims[1],
                kernel_size=1, stride=1, padding=0,
            ),
            nn.ConvTranspose2d(
                in_channels=self.layer_dims[1],
                out_channels=self.layer_dims[1],
                kernel_size=2, stride=2, padding=0,
                bias=True, dilation=1, groups=1,
            )
        )

        self.act_3_postprocess = nn.Sequential(
            nn.Conv2d(
                in_channels=self.dim_tokens_enc[2],
                out_channels=self.layer_dims[2],
                kernel_size=1, stride=1, padding=0,
            )
        )

        self.act_4_postprocess = nn.Sequential(
            nn.Conv2d(
                in_channels=self.dim_tokens_enc[3],
                out_channels=self.layer_dims[3],
                kernel_size=1, stride=1, padding=0,
            ),
            nn.Conv2d(
                in_channels=self.layer_dims[3],
                out_channels=self.layer_dims[3],
                kernel_size=3, stride=2, padding=1,
            )
        )

        self.act_postprocess = nn.ModuleList([
            self.act_1_postprocess,
            self.act_2_postprocess,
            self.act_3_postprocess,
            self.act_4_postprocess
        ])

    def adapt_tokens(self, encoder_tokens):
        # Adapt tokens
        x = []
        x.append(encoder_tokens[:, :])
        x = torch.cat(x, dim=-1)
        return x

    def forward(self, encoder_tokens: List[torch.Tensor], image_size):
        # input_info: Dict):
        assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
        H, W = image_size

        # Number of patches in height and width
        N_H = H // (self.stride_level * self.P_H)
        N_W = W // (self.stride_level * self.P_W)

        # Hook decoder onto 4 layers from specified ViT layers
        layers = [encoder_tokens[hook] for hook in self.hooks]

        # Extract only task-relevant tokens and ignore global tokens.
        layers = [self.adapt_tokens(l) for l in layers]

        # Reshape tokens to spatial representation
        layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]

        layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
        # Project layers to chosen feature dim
        layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]

        # Fuse layers using refinement stages
        path_4 = self.scratch.refinenet4(layers[3])
        path_3 = self.scratch.refinenet3(path_4, layers[2])
        path_2 = self.scratch.refinenet2(path_3, layers[1])
        path_1 = self.scratch.refinenet1(path_2, layers[0])

        # Output head
        out = self.head(path_1)

        return out
mini_dust3r/croco/masking.py
ADDED
@@ -0,0 +1,25 @@
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).


# --------------------------------------------------------
# Masking utils
# --------------------------------------------------------

import torch
import torch.nn as nn

class RandomMask(nn.Module):
    """
    random masking
    """

    def __init__(self, num_patches, mask_ratio):
        super().__init__()
        self.num_patches = num_patches
        self.num_mask = int(mask_ratio * self.num_patches)

    def __call__(self, x):
        noise = torch.rand(x.size(0), self.num_patches, device=x.device)
        argsort = torch.argsort(noise, dim=1)
        return argsort < self.num_mask
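Aside (not part of the commit): RandomMask draws a fixed number of masked positions per sample; a small usage sketch with assumed shapes.

import torch
from mini_dust3r.croco.masking import RandomMask

mask_gen = RandomMask(num_patches=196, mask_ratio=0.9)
tokens = torch.randn(2, 196, 768)        # (B, num_patches, C)
mask = mask_gen(tokens)                  # (2, 196) boolean
print(mask.shape, mask.sum(dim=1))       # every row has exactly int(0.9 * 196) = 176 True entries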
mini_dust3r/croco/pos_embed.py
ADDED
@@ -0,0 +1,159 @@
# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).


# --------------------------------------------------------
# Position embedding utils
# --------------------------------------------------------


import numpy as np

import torch

# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if n_cls_token > 0:
        pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    if 'pos_embed' in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model['pos_embed']
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model['pos_embed'] = new_pos_embed


# ----------------------------------------------------------
# RoPE2D: RoPE implementation in 2D
# ----------------------------------------------------------

try:
    from mini_dust3r.croco.curope import cuRoPE2D
    RoPE2D = cuRoPE2D
except ImportError:
    print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead')

    class RoPE2D(torch.nn.Module):

        def __init__(self, freq=100.0, F0=1.0):
            super().__init__()
            self.base = freq
            self.F0 = F0
            self.cache = {}

        def get_cos_sin(self, D, seq_len, device, dtype):
            if (D, seq_len, device, dtype) not in self.cache:
                inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
                t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
                freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
                freqs = torch.cat((freqs, freqs), dim=-1)
                cos = freqs.cos()  # (Seq, Dim)
                sin = freqs.sin()
                self.cache[D, seq_len, device, dtype] = (cos, sin)
            return self.cache[D, seq_len, device, dtype]

        @staticmethod
        def rotate_half(x):
            x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
            return torch.cat((-x2, x1), dim=-1)

        def apply_rope1d(self, tokens, pos1d, cos, sin):
            assert pos1d.ndim == 2
            cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
            sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
            return (tokens * cos) + (self.rotate_half(tokens) * sin)

        def forward(self, tokens, positions):
            """
            input:
                * tokens: batch_size x nheads x ntokens x dim
                * positions: batch_size x ntokens x 2 (y and x position of each token)
            output:
                * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim)
            """
            assert tokens.size(3) % 2 == 0, "number of dimensions should be a multiple of two"
            D = tokens.size(3) // 2
            assert positions.ndim == 3 and positions.shape[-1] == 2  # Batch, Seq, 2
            cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype)
            # split features into two along the feature dimension, and apply rope1d on each half
            y, x = tokens.chunk(2, dim=-1)
            y = self.apply_rope1d(y, positions[:, :, 0], cos, sin)
            x = self.apply_rope1d(x, positions[:, :, 1], cos, sin)
            tokens = torch.cat((y, x), dim=-1)
            return tokens
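Aside (not part of the commit): the sine-cosine table is a fixed (non-learned) array with one row per patch; a quick shape check under the 14x14 patch grid and 768-dim embedding assumed in the examples above.

from mini_dust3r.croco.pos_embed import get_2d_sincos_pos_embed

pe = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, n_cls_token=0)
print(pe.shape)   # (196, 768): half the dimensions encode one grid axis, half the other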
mini_dust3r/heads/__init__.py
ADDED
@@ -0,0 +1,19 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# head factory
# --------------------------------------------------------
from .linear_head import LinearPts3d
from .dpt_head import create_dpt_head


def head_factory(head_type, output_mode, net, has_conf=False):
    """" build a prediction head for the decoder
    """
    if head_type == 'linear' and output_mode == 'pts3d':
        return LinearPts3d(net, has_conf)
    elif head_type == 'dpt' and output_mode == 'pts3d':
        return create_dpt_head(net, has_conf=has_conf)
    else:
        raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")
mini_dust3r/heads/dpt_head.py
ADDED
@@ -0,0 +1,114 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# dpt head implementation for DUST3R
# Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ;
# or if it takes as input the output at every layer, the attribute return_all_layers should be set to True
# the forward function also takes as input a dictionnary img_info with key "height" and "width"
# for PixelwiseTask, the output will be of dimension B x num_channels x H x W
# --------------------------------------------------------
from einops import rearrange
from typing import List
import torch
import torch.nn as nn
from mini_dust3r.heads.postprocess import postprocess
from mini_dust3r.croco.dpt_block import DPTOutputAdapter


class DPTOutputAdapter_fix(DPTOutputAdapter):
    """
    Adapt croco's DPTOutputAdapter implementation for dust3r:
    remove duplicated weigths, and fix forward for dust3r
    """

    def init(self, dim_tokens_enc=768):
        super().init(dim_tokens_enc)
        # these are duplicated weights
        del self.act_1_postprocess
        del self.act_2_postprocess
        del self.act_3_postprocess
        del self.act_4_postprocess

    def forward(self, encoder_tokens: List[torch.Tensor], image_size=None):
        assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
        # H, W = input_info['image_size']
        image_size = self.image_size if image_size is None else image_size
        H, W = image_size
        # Number of patches in height and width
        N_H = H // (self.stride_level * self.P_H)
        N_W = W // (self.stride_level * self.P_W)

        # Hook decoder onto 4 layers from specified ViT layers
        layers = [encoder_tokens[hook] for hook in self.hooks]

        # Extract only task-relevant tokens and ignore global tokens.
        layers = [self.adapt_tokens(l) for l in layers]

        # Reshape tokens to spatial representation
        layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]

        layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
        # Project layers to chosen feature dim
        layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]

        # Fuse layers using refinement stages
        path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]
        path_3 = self.scratch.refinenet3(path_4, layers[2])
        path_2 = self.scratch.refinenet2(path_3, layers[1])
        path_1 = self.scratch.refinenet1(path_2, layers[0])

        # Output head
        out = self.head(path_1)

        return out


class PixelwiseTaskWithDPT(nn.Module):
    """ DPT module for dust3r, can return 3D points + confidence for all pixels"""

    def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
                 output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs):
        super(PixelwiseTaskWithDPT, self).__init__()
        self.return_all_layers = True  # backbone needs to return all layers
        self.postprocess = postprocess
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode

        assert n_cls_token == 0, "Not implemented"
        dpt_args = dict(output_width_ratio=output_width_ratio,
                        num_channels=num_channels,
                        **kwargs)
        if hooks_idx is not None:
            dpt_args.update(hooks=hooks_idx)
        self.dpt = DPTOutputAdapter_fix(**dpt_args)
        dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens}
        self.dpt.init(**dpt_init_args)

    def forward(self, x, img_info):
        out = self.dpt(x, image_size=(img_info[0], img_info[1]))
        if self.postprocess:
            out = self.postprocess(out, self.depth_mode, self.conf_mode)
        return out


def create_dpt_head(net, has_conf=False):
    """
    return PixelwiseTaskWithDPT for given net params
    """
    assert net.dec_depth > 9
    l2 = net.dec_depth
    feature_dim = 256
    last_dim = feature_dim // 2
    out_nchan = 3
    ed = net.enc_embed_dim
    dd = net.dec_embed_dim
    return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf,
                                feature_dim=feature_dim,
                                last_dim=last_dim,
                                hooks_idx=[0, l2 * 2 // 4, l2 * 3 // 4, l2],
                                dim_tokens=[ed, dd, dd, dd],
                                postprocess=postprocess,
                                depth_mode=net.depth_mode,
                                conf_mode=net.conf_mode,
                                head_type='regression')
mini_dust3r/heads/linear_head.py
ADDED
@@ -0,0 +1,41 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# linear head implementation for DUST3R
# --------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F
from mini_dust3r.heads.postprocess import postprocess


class LinearPts3d (nn.Module):
    """
    Linear head for dust3r
    Each token outputs: - 16x16 3D points (+ confidence)
    """

    def __init__(self, net, has_conf=False):
        super().__init__()
        self.patch_size = net.patch_embed.patch_size[0]
        self.depth_mode = net.depth_mode
        self.conf_mode = net.conf_mode
        self.has_conf = has_conf

        self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2)

    def setup(self, croconet):
        pass

    def forward(self, decout, img_shape):
        H, W = img_shape
        tokens = decout[-1]
        B, S, D = tokens.shape

        # extract 3D points
        feat = self.proj(tokens)  # B,S,D
        feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size)
        feat = F.pixel_shuffle(feat, self.patch_size)  # B,3,H,W

        # permute + norm depth
        return postprocess(feat, self.depth_mode, self.conf_mode)
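Aside (not part of the commit): LinearPts3d only reads a handful of attributes from the network, so its shape behaviour can be checked with a small stand-in object; the depth/conf modes below mirror common DUSt3R settings but are assumptions, not values taken from this file.

import torch
from types import SimpleNamespace
from mini_dust3r.heads.linear_head import LinearPts3d

# Stand-in "net" carrying just the attributes LinearPts3d reads.
net = SimpleNamespace(
    patch_embed=SimpleNamespace(patch_size=(16, 16)),
    dec_embed_dim=512,
    depth_mode=('exp', -float('inf'), float('inf')),
    conf_mode=('exp', 1, float('inf')),
)
head = LinearPts3d(net, has_conf=True)

H, W = 224, 224
tokens = torch.randn(2, (H // 16) * (W // 16), 512)   # last decoder layer: (B, S, dec_embed_dim)
res = head([tokens], img_shape=(H, W))
print(res['pts3d'].shape, res['conf'].shape)           # (2, 224, 224, 3) and (2, 224, 224)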
mini_dust3r/heads/postprocess.py
ADDED
@@ -0,0 +1,58 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# post process function for all heads: extract 3D points/confidence from output
# --------------------------------------------------------
import torch


def postprocess(out, depth_mode, conf_mode):
    """
    extract 3D points/confidence from prediction head output
    """
    fmap = out.permute(0, 2, 3, 1)  # B,H,W,3
    res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode))

    if conf_mode is not None:
        res['conf'] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode)
    return res


def reg_dense_depth(xyz, mode):
    """
    extract 3D points from prediction head output
    """
    mode, vmin, vmax = mode

    no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))
    assert no_bounds

    if mode == 'linear':
        if no_bounds:
            return xyz  # [-inf, +inf]
        return xyz.clip(min=vmin, max=vmax)

    # distance to origin
    d = xyz.norm(dim=-1, keepdim=True)
    xyz = xyz / d.clip(min=1e-8)

    if mode == 'square':
        return xyz * d.square()

    if mode == 'exp':
        return xyz * torch.expm1(d)

    raise ValueError(f'bad {mode=}')


def reg_dense_conf(x, mode):
    """
    extract confidence from prediction head output
    """
    mode, vmin, vmax = mode
    if mode == 'exp':
        return vmin + x.exp().clip(max=vmax-vmin)
    if mode == 'sigmoid':
        return (vmax - vmin) * torch.sigmoid(x) + vmin
    raise ValueError(f'bad {mode=}')
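A small sketch of how these regression modes behave (illustrative only, not part of the commit; the tensor sizes are made up): with depth_mode=('exp', -inf, inf) the raw 3-channel output is renormalized to unit length and rescaled by expm1 of its norm, and conf_mode=('exp', 1, inf) shifts confidences to lie above 1.

import torch
from mini_dust3r.heads.postprocess import postprocess

raw = torch.randn(2, 4, 64, 64)    # B x (3 pts3d + 1 conf) x H x W, hypothetical sizes
out = postprocess(raw,
                  depth_mode=('exp', -float('inf'), float('inf')),
                  conf_mode=('exp', 1, float('inf')))
print(out['pts3d'].shape)              # torch.Size([2, 64, 64, 3])
print(float(out['conf'].min()) > 1)    # True: conf = 1 + exp(raw conf channel)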
mini_dust3r/image_pairs.py
ADDED
@@ -0,0 +1,85 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# utilities needed to load image pairs
# --------------------------------------------------------
import numpy as np
import torch
from mini_dust3r.utils.image import ImageDict


def make_pairs(
    imgs: list[ImageDict],
    scene_graph: str = "complete",
    prefilter=None,
    symmetrize=True,
) -> list[tuple[ImageDict, ImageDict]]:
    pairs = []
    if scene_graph == "complete":  # complete graph
        for i in range(len(imgs)):
            for j in range(i):
                pairs.append((imgs[i], imgs[j]))
    elif scene_graph.startswith("swin"):
        winsize = int(scene_graph.split("-")[1]) if "-" in scene_graph else 3
        pairsid = set()
        for i in range(len(imgs)):
            for j in range(1, winsize + 1):
                idx = (i + j) % len(imgs)  # explicit loop closure
                pairsid.add((i, idx) if i < idx else (idx, i))
        for i, j in pairsid:
            pairs.append((imgs[i], imgs[j]))
    elif scene_graph.startswith("oneref"):
        refid = int(scene_graph.split("-")[1]) if "-" in scene_graph else 0
        for j in range(len(imgs)):
            if j != refid:
                pairs.append((imgs[refid], imgs[j]))
    if symmetrize:
        pairs += [(img2, img1) for img1, img2 in pairs]

    # now, remove edges
    if isinstance(prefilter, str) and prefilter.startswith("seq"):
        pairs = filter_pairs_seq(pairs, int(prefilter[3:]))

    if isinstance(prefilter, str) and prefilter.startswith("cyc"):
        pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True)

    return pairs


def sel(x, kept):
    if isinstance(x, dict):
        return {k: sel(v, kept) for k, v in x.items()}
    if isinstance(x, (torch.Tensor, np.ndarray)):
        return x[kept]
    if isinstance(x, (tuple, list)):
        return type(x)([x[k] for k in kept])


def _filter_edges_seq(edges, seq_dis_thr, cyclic=False):
    # number of images
    n = max(max(e) for e in edges) + 1

    kept = []
    for e, (i, j) in enumerate(edges):
        dis = abs(i - j)
        if cyclic:
            dis = min(dis, abs(i + n - j), abs(i - n - j))
        if dis <= seq_dis_thr:
            kept.append(e)
    return kept


def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False):
    edges = [(img1["idx"], img2["idx"]) for img1, img2 in pairs]
    kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
    return [pairs[i] for i in kept]


def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False):
    edges = [(int(i), int(j)) for i, j in zip(view1["idx"], view2["idx"])]
    kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
    print(
        f">> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges"
    )
    return sel(view1, kept), sel(view2, kept), sel(pred1, kept), sel(pred2, kept)
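A quick sketch of the three scene-graph modes (illustrative, not part of the commit; the image dicts are minimal stand-ins carrying only the fields make_pairs touches):

from mini_dust3r.image_pairs import make_pairs

imgs = [dict(idx=i, instance=str(i)) for i in range(4)]   # stand-ins for ImageDict entries

complete = make_pairs(imgs, scene_graph="complete")   # every unordered pair, then symmetrized
sliding = make_pairs(imgs, scene_graph="swin-2")      # sliding window of 2 with loop closure
star = make_pairs(imgs, scene_graph="oneref-0")       # every image paired with image 0
print(len(complete), len(sliding), len(star))         # 12 12 6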
mini_dust3r/inference.py
ADDED
@@ -0,0 +1,204 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# utilities needed for the inference
# --------------------------------------------------------
import tqdm
import torch
from mini_dust3r.utils.device import to_cpu, collate_with_cat
from mini_dust3r.utils.misc import invalid_to_nans
from mini_dust3r.utils.geometry import depthmap_to_pts3d, geotrf
from mini_dust3r.utils.image import ImageDict
from mini_dust3r.model import AsymmetricCroCo3DStereo

from typing import Literal, TypedDict, Optional
from jaxtyping import Float32


class Dust3rPred1(TypedDict):
    pts3d: Float32[torch.Tensor, "b h w c"]
    conf: Float32[torch.Tensor, "b h w"]


class Dust3rPred2(TypedDict):
    pts3d_in_other_view: Float32[torch.Tensor, "b h w c"]
    conf: Float32[torch.Tensor, "b h w"]


class Dust3rResult(TypedDict):
    view1: ImageDict
    view2: ImageDict
    pred1: Dust3rPred1
    pred2: Dust3rPred2
    loss: Optional[int]


def _interleave_imgs(img1, img2):
    res = {}
    for key, value1 in img1.items():
        value2 = img2[key]
        if isinstance(value1, torch.Tensor):
            value = torch.stack((value1, value2), dim=1).flatten(0, 1)
        else:
            value = [x for pair in zip(value1, value2) for x in pair]
        res[key] = value
    return res


def make_batch_symmetric(batch):
    view1, view2 = batch
    view1, view2 = (_interleave_imgs(view1, view2), _interleave_imgs(view2, view1))
    return view1, view2


def loss_of_one_batch(
    batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None
):
    view1, view2 = batch
    for view in batch:
        for name in (
            "img pts3d valid_mask camera_pose camera_intrinsics F_matrix corres".split()
        ):  # pseudo_focal
            if name not in view:
                continue
            view[name] = view[name].to(device, non_blocking=True)

    if symmetrize_batch:
        view1, view2 = make_batch_symmetric(batch)

    with torch.cuda.amp.autocast(enabled=bool(use_amp)):
        pred1, pred2 = model(view1, view2)

        # loss is supposed to be symmetric
        with torch.cuda.amp.autocast(enabled=False):
            loss = (
                criterion(view1, view2, pred1, pred2) if criterion is not None else None
            )

    result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss)
    return result[ret] if ret else result


@torch.no_grad()
def inference(
    pairs: list[tuple[ImageDict, ImageDict]],
    model: AsymmetricCroCo3DStereo,
    device: Literal["cpu", "cuda", "mps"],
    batch_size: int = 8,
    verbose: bool = True,
) -> Dust3rResult:
    if verbose:
        print(f">> Inference with model on {len(pairs)} image pairs")
    result = []

    # first, check if all images have the same size
    multiple_shapes = not (check_if_same_size(pairs))
    if multiple_shapes:  # force bs=1
        batch_size = 1

    for i in tqdm.trange(0, len(pairs), batch_size, disable=not verbose):
        res: Dust3rResult = loss_of_one_batch(
            collate_with_cat(pairs[i : i + batch_size]), model, None, device
        )
        result.append(to_cpu(res))

    result = collate_with_cat(result, lists=multiple_shapes)

    return result


def check_if_same_size(pairs):
    shapes1 = [img1["img"].shape[-2:] for img1, img2 in pairs]
    shapes2 = [img2["img"].shape[-2:] for img1, img2 in pairs]
    return all(shapes1[0] == s for s in shapes1) and all(
        shapes2[0] == s for s in shapes2
    )


def get_pred_pts3d(gt, pred, use_pose=False):
    if "depth" in pred and "pseudo_focal" in pred:
        try:
            pp = gt["camera_intrinsics"][..., :2, 2]
        except KeyError:
            pp = None
        pts3d = depthmap_to_pts3d(**pred, pp=pp)

    elif "pts3d" in pred:
        # pts3d from my camera
        pts3d = pred["pts3d"]

    elif "pts3d_in_other_view" in pred:
        # pts3d from the other camera, already transformed
        assert use_pose is True
        return pred["pts3d_in_other_view"]  # return!

    if use_pose:
        camera_pose = pred.get("camera_pose")
        assert camera_pose is not None
        pts3d = geotrf(camera_pose, pts3d)

    return pts3d


def find_opt_scaling(
    gt_pts1,
    gt_pts2,
    pr_pts1,
    pr_pts2=None,
    fit_mode="weiszfeld_stop_grad",
    valid1=None,
    valid2=None,
):
    assert gt_pts1.ndim == pr_pts1.ndim == 4
    assert gt_pts1.shape == pr_pts1.shape
    if gt_pts2 is not None:
        assert gt_pts2.ndim == pr_pts2.ndim == 4
        assert gt_pts2.shape == pr_pts2.shape

    # concat the pointcloud
    nan_gt_pts1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2)
    nan_gt_pts2 = (
        invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None
    )

    pr_pts1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2)
    pr_pts2 = (
        invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None
    )

    all_gt = (
        torch.cat((nan_gt_pts1, nan_gt_pts2), dim=1)
        if gt_pts2 is not None
        else nan_gt_pts1
    )
    all_pr = torch.cat((pr_pts1, pr_pts2), dim=1) if pr_pts2 is not None else pr_pts1

    dot_gt_pr = (all_pr * all_gt).sum(dim=-1)
    dot_gt_gt = all_gt.square().sum(dim=-1)

    if fit_mode.startswith("avg"):
        # scaling = (all_pr / all_gt).view(B, -1).mean(dim=1)
        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
    elif fit_mode.startswith("median"):
        scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values
    elif fit_mode.startswith("weiszfeld"):
        # init scaling with l2 closed form
        scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
        # iterative re-weighted least-squares
        for iter in range(10):
            # re-weighting by inverse of distance
            dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1)
            # print(dis.nanmean(-1))
            w = dis.clip_(min=1e-8).reciprocal()
            # update the scaling with the new weights
            scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1)
    else:
        raise ValueError(f"bad {fit_mode=}")

    if fit_mode.endswith("stop_grad"):
        scaling = scaling.detach()

    scaling = scaling.clip(min=1e-3)
    # assert scaling.isfinite().all(), bb()
    return scaling
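find_opt_scaling solves for the single scale factor that best aligns a predicted pointmap with a ground-truth one. A tiny sanity-check sketch (illustrative, not part of the commit):

import torch
from mini_dust3r.inference import find_opt_scaling

gt = torch.randn(1, 8, 8, 3)   # (B, H, W, 3) ground-truth pointmap
pred = 2.0 * gt                # prediction at exactly twice the ground-truth scale
s = find_opt_scaling(gt, None, pred, fit_mode="median")
print(s)   # ~tensor([2.]): the recovered scale ratio between prediction and ground truth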
mini_dust3r/model.py
ADDED
@@ -0,0 +1,259 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# DUSt3R model class
# --------------------------------------------------------
from copy import deepcopy
import torch
import os
from packaging import version
import huggingface_hub

from .utils.misc import (
    fill_default_args,
    freeze_all_params,
    is_symmetrized,
    interleave,
    transpose_to_landscape,
)
from .heads import head_factory
from mini_dust3r.patch_embed import get_patch_embed

from mini_dust3r.croco.croco import CroCoNet

inf = float("inf")

hf_version_number = huggingface_hub.__version__
assert version.parse(hf_version_number) >= version.parse(
    "0.22.0"
), "Outdated huggingface_hub version, please reinstall requirements.txt"


def load_model(model_path, device, verbose=True):
    if verbose:
        print("... loading model from", model_path)
    ckpt = torch.load(model_path, map_location="cpu")
    args = ckpt["args"].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
    if "landscape_only" not in args:
        args = args[:-1] + ", landscape_only=False)"
    else:
        args = args.replace(" ", "").replace(
            "landscape_only=True", "landscape_only=False"
        )
    assert "landscape_only=False" in args
    if verbose:
        print(f"instantiating : {args}")
    net = eval(args)
    s = net.load_state_dict(ckpt["model"], strict=False)
    if verbose:
        print(s)
    return net.to(device)


class AsymmetricCroCo3DStereo(
    CroCoNet,
    huggingface_hub.PyTorchModelHubMixin,
    library_name="dust3r",
    repo_url="https://github.com/naver/dust3r",
    tags=["image-to-3d"],
):
    """Two siamese encoders, followed by two decoders.
    The goal is to output 3d points directly, both images in view1's frame
    (hence the asymmetry).
    """

    def __init__(
        self,
        output_mode="pts3d",
        head_type="linear",
        depth_mode=("exp", -inf, inf),
        conf_mode=("exp", 1, inf),
        freeze="none",
        landscape_only=True,
        patch_embed_cls="PatchEmbedDust3R",  # PatchEmbedDust3R or ManyAR_PatchEmbed
        **croco_kwargs,
    ):
        self.patch_embed_cls = patch_embed_cls
        self.croco_args = fill_default_args(croco_kwargs, super().__init__)
        super().__init__(**croco_kwargs)

        # dust3r specific initialization
        self.dec_blocks2 = deepcopy(self.dec_blocks)
        self.set_downstream_head(
            output_mode,
            head_type,
            landscape_only,
            depth_mode,
            conf_mode,
            **croco_kwargs,
        )
        self.set_freeze(freeze)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kw):
        if os.path.isfile(pretrained_model_name_or_path):
            return load_model(pretrained_model_name_or_path, device="cpu")
        else:
            return super(AsymmetricCroCo3DStereo, cls).from_pretrained(
                pretrained_model_name_or_path, **kw
            )

    def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
        self.patch_embed = get_patch_embed(
            self.patch_embed_cls, img_size, patch_size, enc_embed_dim
        )

    def load_state_dict(self, ckpt, **kw):
        # duplicate all weights for the second decoder if not present
        new_ckpt = dict(ckpt)
        if not any(k.startswith("dec_blocks2") for k in ckpt):
            for key, value in ckpt.items():
                if key.startswith("dec_blocks"):
                    new_ckpt[key.replace("dec_blocks", "dec_blocks2")] = value
        return super().load_state_dict(new_ckpt, **kw)

    def set_freeze(self, freeze):  # this is for use by downstream models
        self.freeze = freeze
        to_be_frozen = {
            "none": [],
            "mask": [self.mask_token],
            "encoder": [self.mask_token, self.patch_embed, self.enc_blocks],
        }
        freeze_all_params(to_be_frozen[freeze])

    def _set_prediction_head(self, *args, **kwargs):
        """No prediction head"""
        return

    def set_downstream_head(
        self,
        output_mode,
        head_type,
        landscape_only,
        depth_mode,
        conf_mode,
        patch_size,
        img_size,
        **kw,
    ):
        assert (
            img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0
        ), f"{img_size=} must be multiple of {patch_size=}"
        self.output_mode = output_mode
        self.head_type = head_type
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode
        # allocate heads
        self.downstream_head1 = head_factory(
            head_type, output_mode, self, has_conf=bool(conf_mode)
        )
        self.downstream_head2 = head_factory(
            head_type, output_mode, self, has_conf=bool(conf_mode)
        )
        # magic wrapper
        self.head1 = transpose_to_landscape(
            self.downstream_head1, activate=landscape_only
        )
        self.head2 = transpose_to_landscape(
            self.downstream_head2, activate=landscape_only
        )

    def _encode_image(self, image, true_shape):
        # embed the image into patches (x has size B x Npatches x C)
        x, pos = self.patch_embed(image, true_shape=true_shape)

        # add positional embedding without cls token
        assert self.enc_pos_embed is None

        # now apply the transformer encoder and normalization
        for blk in self.enc_blocks:
            x = blk(x, pos)

        x = self.enc_norm(x)
        return x, pos, None

    def _encode_image_pairs(self, img1, img2, true_shape1, true_shape2):
        if img1.shape[-2:] == img2.shape[-2:]:
            out, pos, _ = self._encode_image(
                torch.cat((img1, img2), dim=0),
                torch.cat((true_shape1, true_shape2), dim=0),
            )
            out, out2 = out.chunk(2, dim=0)
            pos, pos2 = pos.chunk(2, dim=0)
        else:
            out, pos, _ = self._encode_image(img1, true_shape1)
            out2, pos2, _ = self._encode_image(img2, true_shape2)
        return out, out2, pos, pos2

    def _encode_symmetrized(self, view1, view2):
        img1 = view1["img"]
        img2 = view2["img"]
        B = img1.shape[0]
        # Recover true_shape when available, otherwise assume that the img shape is the true one
        shape1 = view1.get(
            "true_shape", torch.tensor(img1.shape[-2:])[None].repeat(B, 1)
        )
        shape2 = view2.get(
            "true_shape", torch.tensor(img2.shape[-2:])[None].repeat(B, 1)
        )
        # warning! maybe the images have different portrait/landscape orientations

        if is_symmetrized(view1, view2):
            # computing half of forward pass!
            feat1, feat2, pos1, pos2 = self._encode_image_pairs(
                img1[::2], img2[::2], shape1[::2], shape2[::2]
            )
            feat1, feat2 = interleave(feat1, feat2)
            pos1, pos2 = interleave(pos1, pos2)
        else:
            feat1, feat2, pos1, pos2 = self._encode_image_pairs(
                img1, img2, shape1, shape2
            )

        return (shape1, shape2), (feat1, feat2), (pos1, pos2)

    def _decoder(self, f1, pos1, f2, pos2):
        final_output = [(f1, f2)]  # before projection

        # project to decoder dim
        f1 = self.decoder_embed(f1)
        f2 = self.decoder_embed(f2)

        final_output.append((f1, f2))
        for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):
            # img1 side
            f1, _ = blk1(*final_output[-1][::+1], pos1, pos2)
            # img2 side
            f2, _ = blk2(*final_output[-1][::-1], pos2, pos1)
            # store the result
            final_output.append((f1, f2))

        # normalize last output
        del final_output[1]  # duplicate with final_output[0]
        final_output[-1] = tuple(map(self.dec_norm, final_output[-1]))
        return zip(*final_output)

    def _downstream_head(self, head_num, decout, img_shape):
        B, S, D = decout[-1].shape
        # img_shape = tuple(map(int, img_shape))
        head = getattr(self, f"head{head_num}")
        return head(decout, img_shape)

    def forward(self, view1, view2):
        # encode the two images --> B,S,D
        (shape1, shape2), (feat1, feat2), (pos1, pos2) = self._encode_symmetrized(
            view1, view2
        )

        # combine all ref images into object-centric representation
        dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2)

        with torch.cuda.amp.autocast(enabled=False):
            res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1)
            res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2)

        res2["pts3d_in_other_view"] = res2.pop(
            "pts3d"
        )  # predict view2's pts3d in view1's frame
        return res1, res2
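A minimal end-to-end sketch of how this model class is meant to be driven (illustrative, not part of the commit; the folder path is a placeholder and the checkpoint id is assumed to be the published DUSt3R weights on the Hub):

import torch
from mini_dust3r.model import AsymmetricCroCo3DStereo
from mini_dust3r.utils.image import load_images
from mini_dust3r.image_pairs import make_pairs
from mini_dust3r.inference import inference

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AsymmetricCroCo3DStereo.from_pretrained(
    "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"   # assumed Hub checkpoint id
).to(device)

imgs = load_images("path/to/your/images", size=512)                 # placeholder folder
pairs = make_pairs(imgs, scene_graph="complete", symmetrize=True)
output = inference(pairs, model, device, batch_size=1)

# pred1["pts3d"] and pred2["pts3d_in_other_view"] are both expressed in view1's frame
print(output["pred1"]["pts3d"].shape, output["pred2"]["pts3d_in_other_view"].shape)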
mini_dust3r/optim_factory.py
ADDED
@@ -0,0 +1,14 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# optimization functions
# --------------------------------------------------------


def adjust_learning_rate_by_lr(optimizer, lr):
    for param_group in optimizer.param_groups:
        if "lr_scale" in param_group:
            param_group["lr"] = lr * param_group["lr_scale"]
        else:
            param_group["lr"] = lr
mini_dust3r/patch_embed.py
ADDED
@@ -0,0 +1,69 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# PatchEmbed implementation for DUST3R,
# in particular ManyAR_PatchEmbed that Handle images with non-square aspect ratio
# --------------------------------------------------------
import torch
from mini_dust3r.croco.blocks import PatchEmbed


def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim):
    assert patch_embed_cls in ['PatchEmbedDust3R', 'ManyAR_PatchEmbed']
    patch_embed = eval(patch_embed_cls)(img_size, patch_size, 3, enc_embed_dim)
    return patch_embed


class PatchEmbedDust3R(PatchEmbed):
    def forward(self, x, **kw):
        B, C, H, W = x.shape
        assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
        assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
        x = self.proj(x)
        pos = self.position_getter(B, x.size(2), x.size(3), x.device)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x, pos


class ManyAR_PatchEmbed (PatchEmbed):
    """ Handle images with non-square aspect ratio.
        All images in the same batch have the same aspect ratio.
        true_shape = [(height, width) ...] indicates the actual shape of each image.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
        self.embed_dim = embed_dim
        super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten)

    def forward(self, img, true_shape):
        B, C, H, W = img.shape
        assert W >= H, f'img should be in landscape mode, but got {W=} {H=}'
        assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
        assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
        assert true_shape.shape == (B, 2), f"true_shape has the wrong shape={true_shape.shape}"

        # size expressed in tokens
        W //= self.patch_size[0]
        H //= self.patch_size[1]
        n_tokens = H * W

        height, width = true_shape.T
        is_landscape = (width >= height)
        is_portrait = ~is_landscape

        # allocate result
        x = img.new_zeros((B, n_tokens, self.embed_dim))
        pos = img.new_zeros((B, n_tokens, 2), dtype=torch.int64)

        # linear projection, transposed if necessary
        x[is_landscape] = self.proj(img[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2).float()
        x[is_portrait] = self.proj(img[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2).float()

        pos[is_landscape] = self.position_getter(1, H, W, pos.device)
        pos[is_portrait] = self.position_getter(1, W, H, pos.device)

        x = self.norm(x)
        return x, pos
mini_dust3r/post_process.py
ADDED
@@ -0,0 +1,60 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# utilities for interpreting the DUST3R output
# --------------------------------------------------------
import numpy as np
import torch
from mini_dust3r.utils.geometry import xy_grid


def estimate_focal_knowing_depth(pts3d, pp, focal_mode='median', min_focal=0., max_focal=np.inf):
    """ Reprojection method, for when the absolute depth is known:
        1) estimate the camera focal using a robust estimator
        2) reproject points onto true rays, minimizing a certain error
    """
    B, H, W, THREE = pts3d.shape
    assert THREE == 3

    # centered pixel grid
    pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(-1, 1, 2)  # B,HW,2
    pts3d = pts3d.flatten(1, 2)  # (B, HW, 3)

    if focal_mode == 'median':
        with torch.no_grad():
            # direct estimation of focal
            u, v = pixels.unbind(dim=-1)
            x, y, z = pts3d.unbind(dim=-1)
            fx_votes = (u * z) / x
            fy_votes = (v * z) / y

            # assume square pixels, hence same focal for X and Y
            f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1)
            focal = torch.nanmedian(f_votes, dim=-1).values

    elif focal_mode == 'weiszfeld':
        # init focal with l2 closed form
        # we try to find focal = argmin Sum | pixel - focal * (x,y)/z|
        xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0)  # homogeneous (x,y,1)

        dot_xy_px = (xy_over_z * pixels).sum(dim=-1)
        dot_xy_xy = xy_over_z.square().sum(dim=-1)

        focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1)

        # iterative re-weighted least-squares
        for iter in range(10):
            # re-weighting by inverse of distance
            dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1)
            # print(dis.nanmean(-1))
            w = dis.clip(min=1e-8).reciprocal()
            # update the scaling with the new weights
            focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1)
    else:
        raise ValueError(f'bad {focal_mode=}')

    focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2))  # size / 1.1547005383792515
    focal = focal.clip(min=min_focal*focal_base, max=max_focal*focal_base)
    # print(focal)
    return focal
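A quick consistency check of the focal estimator on a synthetic pinhole pointmap (illustrative, not part of the commit; the resolution, depth, and focal length are made up):

import torch
from mini_dust3r.post_process import estimate_focal_knowing_depth

H, W, f = 48, 64, 300.0
pp = torch.tensor([[W / 2, H / 2]])                       # principal point at the image center
v, u = torch.meshgrid(torch.arange(H, dtype=torch.float32),
                      torch.arange(W, dtype=torch.float32), indexing="ij")
z = torch.full((H, W), 5.0)                               # constant depth of 5
pts3d = torch.stack(((u - W / 2) * z / f,
                     (v - H / 2) * z / f, z), dim=-1)[None]   # (1, H, W, 3)

focal = estimate_focal_knowing_depth(pts3d, pp, focal_mode="weiszfeld")
print(focal)   # ~tensor([300.])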
mini_dust3r/utils/device.py
ADDED
@@ -0,0 +1,76 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# utilitary functions for DUSt3R
# --------------------------------------------------------
import numpy as np
import torch


def todevice(batch, device, callback=None, non_blocking=False):
    ''' Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).

    batch: list, tuple, dict of tensors or other things
    device: pytorch device or 'numpy'
    callback: function that would be called on every sub-elements.
    '''
    if callback:
        batch = callback(batch)

    if isinstance(batch, dict):
        return {k: todevice(v, device) for k, v in batch.items()}

    if isinstance(batch, (tuple, list)):
        return type(batch)(todevice(x, device) for x in batch)

    x = batch
    if device == 'numpy':
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
    elif x is not None:
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        if torch.is_tensor(x):
            x = x.to(device, non_blocking=non_blocking)
    return x


to_device = todevice  # alias


def to_numpy(x): return todevice(x, 'numpy')
def to_cpu(x): return todevice(x, 'cpu')
def to_cuda(x): return todevice(x, 'cuda')


def collate_with_cat(whatever, lists=False):
    if isinstance(whatever, dict):
        return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()}

    elif isinstance(whatever, (tuple, list)):
        if len(whatever) == 0:
            return whatever
        elem = whatever[0]
        T = type(whatever)

        if elem is None:
            return None
        if isinstance(elem, (bool, float, int, str)):
            return whatever
        if isinstance(elem, tuple):
            return T(collate_with_cat(x, lists=lists) for x in zip(*whatever))
        if isinstance(elem, dict):
            return {k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem}

        if isinstance(elem, torch.Tensor):
            return listify(whatever) if lists else torch.cat(whatever)
        if isinstance(elem, np.ndarray):
            return listify(whatever) if lists else torch.cat([torch.from_numpy(x) for x in whatever])

        # otherwise, we just chain lists
        return sum(whatever, T())


def listify(elems):
    return [x for e in elems for x in e]
mini_dust3r/utils/geometry.py
ADDED
@@ -0,0 +1,361 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# geometry utilitary functions
# --------------------------------------------------------
import torch
import numpy as np
from scipy.spatial import cKDTree as KDTree

from mini_dust3r.utils.misc import invalid_to_zeros, invalid_to_nans
from mini_dust3r.utils.device import to_numpy


def xy_grid(W, H, device=None, origin=(0, 0), unsqueeze=None, cat_dim=-1, homogeneous=False, **arange_kw):
    """ Output a (H,W,2) array of int32
        with output[j,i,0] = i + origin[0]
             output[j,i,1] = j + origin[1]
    """
    if device is None:
        # numpy
        arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
    else:
        # torch
        arange = lambda *a, **kw: torch.arange(*a, device=device, **kw)
        meshgrid, stack = torch.meshgrid, torch.stack
        ones = lambda *a: torch.ones(*a, device=device)

    tw, th = [arange(o, o+s, **arange_kw) for s, o in zip((W, H), origin)]
    grid = meshgrid(tw, th, indexing='xy')
    if homogeneous:
        grid = grid + (ones((H, W)),)
    if unsqueeze is not None:
        grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))
    if cat_dim is not None:
        grid = stack(grid, cat_dim)
    return grid


def geotrf(Trf, pts, ncol=None, norm=False):
    """ Apply a geometric transformation to a list of 3-D points.

    H: 3x3 or 4x4 projection matrix (typically a Homography)
    p: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)

    ncol: int. number of columns of the result (2 or 3)
    norm: float. if != 0, the resut is projected on the z=norm plane.

    Returns an array of projected 2d points.
    """
    assert Trf.ndim >= 2
    if isinstance(Trf, np.ndarray):
        pts = np.asarray(pts)
    elif isinstance(Trf, torch.Tensor):
        pts = torch.as_tensor(pts, dtype=Trf.dtype)

    # adapt shape if necessary
    output_reshape = pts.shape[:-1]
    ncol = ncol or pts.shape[-1]

    # optimized code
    if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
            Trf.ndim == 3 and pts.ndim == 4):
        d = pts.shape[3]
        if Trf.shape[-1] == d:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
        elif Trf.shape[-1] == d+1:
            pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
        else:
            raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
    else:
        if Trf.ndim >= 3:
            n = Trf.ndim-2
            assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
            Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])

            if pts.ndim > Trf.ndim:
                # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
                pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
            elif pts.ndim == 2:
                # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
                pts = pts[:, None, :]

        if pts.shape[-1]+1 == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
        elif pts.shape[-1] == Trf.shape[-1]:
            Trf = Trf.swapaxes(-1, -2)  # transpose Trf
            pts = pts @ Trf
        else:
            pts = Trf @ pts.T
            if pts.ndim >= 2:
                pts = pts.swapaxes(-1, -2)

    if norm:
        pts = pts / pts[..., -1:]  # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
        if norm != 1:
            pts *= norm

    res = pts[..., :ncol].reshape(*output_reshape, ncol)
    return res


def inv(mat):
    """ Invert a torch or numpy matrix
    """
    if isinstance(mat, torch.Tensor):
        return torch.linalg.inv(mat)
    if isinstance(mat, np.ndarray):
        return np.linalg.inv(mat)
    raise ValueError(f'bad matrix type = {type(mat)}')


def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
    """
    Args:
        - depthmap (BxHxW array):
        - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W]
    Returns:
        pointmap of absolute coordinates (BxHxWx3 array)
    """

    if len(depth.shape) == 4:
        B, H, W, n = depth.shape
    else:
        B, H, W = depth.shape
        n = None

    if len(pseudo_focal.shape) == 3:  # [B,H,W]
        pseudo_focalx = pseudo_focaly = pseudo_focal
    elif len(pseudo_focal.shape) == 4:  # [B,2,H,W] or [B,1,H,W]
        pseudo_focalx = pseudo_focal[:, 0]
        if pseudo_focal.shape[1] == 2:
            pseudo_focaly = pseudo_focal[:, 1]
        else:
            pseudo_focaly = pseudo_focalx
    else:
        raise NotImplementedError("Error, unknown input focal shape format.")

    assert pseudo_focalx.shape == depth.shape[:3]
    assert pseudo_focaly.shape == depth.shape[:3]
    grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]

    # set principal point
    if pp is None:
        grid_x = grid_x - (W-1)/2
        grid_y = grid_y - (H-1)/2
    else:
        grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]
        grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]

    if n is None:
        pts3d = torch.empty((B, H, W, 3), device=depth.device)
        pts3d[..., 0] = depth * grid_x / pseudo_focalx
        pts3d[..., 1] = depth * grid_y / pseudo_focaly
        pts3d[..., 2] = depth
    else:
        pts3d = torch.empty((B, H, W, 3, n), device=depth.device)
        pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None]
        pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None]
        pts3d[..., 2, :] = depth
    return pts3d


def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
    """
    Args:
        - depthmap (HxW array):
        - camera_intrinsics: a 3x3 matrix
    Returns:
        pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
    """
    camera_intrinsics = np.float32(camera_intrinsics)
    H, W = depthmap.shape

    # Compute 3D ray associated with each pixel
    # Strong assumption: there are no skew terms
    assert camera_intrinsics[0, 1] == 0.0
    assert camera_intrinsics[1, 0] == 0.0
    if pseudo_focal is None:
        fu = camera_intrinsics[0, 0]
        fv = camera_intrinsics[1, 1]
    else:
        assert pseudo_focal.shape == (H, W)
        fu = fv = pseudo_focal
    cu = camera_intrinsics[0, 2]
    cv = camera_intrinsics[1, 2]

    u, v = np.meshgrid(np.arange(W), np.arange(H))
    z_cam = depthmap
    x_cam = (u - cu) * z_cam / fu
    y_cam = (v - cv) * z_cam / fv
    X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)

    # Mask for valid coordinates
    valid_mask = (depthmap > 0.0)
    return X_cam, valid_mask


def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, **kw):
    """
    Args:
        - depthmap (HxW array):
        - camera_intrinsics: a 3x3 matrix
        - camera_pose: a 4x3 or 4x4 cam2world matrix
    Returns:
        pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels."""
    X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)

    # R_cam2world = np.float32(camera_params["R_cam2world"])
    # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze()
    R_cam2world = camera_pose[:3, :3]
    t_cam2world = camera_pose[:3, 3]

    # Express in absolute coordinates (invalid depth values)
    X_world = np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
    return X_world, valid_mask


def colmap_to_opencv_intrinsics(K):
    """
    Modify camera intrinsics to follow a different convention.
    Coordinates of the center of the top-left pixels are by default:
    - (0.5, 0.5) in Colmap
    - (0,0) in OpenCV
    """
    K = K.copy()
    K[0, 2] -= 0.5
    K[1, 2] -= 0.5
    return K


def opencv_to_colmap_intrinsics(K):
    """
    Modify camera intrinsics to follow a different convention.
    Coordinates of the center of the top-left pixels are by default:
    - (0.5, 0.5) in Colmap
    - (0,0) in OpenCV
    """
    K = K.copy()
    K[0, 2] += 0.5
    K[1, 2] += 0.5
    return K


def normalize_pointcloud(pts1, pts2, norm_mode='avg_dis', valid1=None, valid2=None):
    """ renorm pointmaps pts1, pts2 with norm_mode
    """
    assert pts1.ndim >= 3 and pts1.shape[-1] == 3
    assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3)
    norm_mode, dis_mode = norm_mode.split('_')

    if norm_mode == 'avg':
        # gather all points together (joint normalization)
        nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3)
        nan_pts2, nnz2 = invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0)
        all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1

        # compute distance to origin
        all_dis = all_pts.norm(dim=-1)
        if dis_mode == 'dis':
            pass  # do nothing
        elif dis_mode == 'log1p':
            all_dis = torch.log1p(all_dis)
        elif dis_mode == 'warp-log1p':
            # actually warp input points before normalizing them
            log_dis = torch.log1p(all_dis)
            warp_factor = log_dis / all_dis.clip(min=1e-8)
            H1, W1 = pts1.shape[1:-1]
            pts1 = pts1 * warp_factor[:, :W1*H1].view(-1, H1, W1, 1)
            if pts2 is not None:
                H2, W2 = pts2.shape[1:-1]
                pts2 = pts2 * warp_factor[:, W1*H1:].view(-1, H2, W2, 1)
            all_dis = log_dis  # this is their true distance afterwards
        else:
            raise ValueError(f'bad {dis_mode=}')

        norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8)
    else:
        # gather all points together (joint normalization)
        nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3)
        nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None
        all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1

        # compute distance to origin
        all_dis = all_pts.norm(dim=-1)

        if norm_mode == 'avg':
            norm_factor = all_dis.nanmean(dim=1)
        elif norm_mode == 'median':
            norm_factor = all_dis.nanmedian(dim=1).values.detach()
        elif norm_mode == 'sqrt':
            norm_factor = all_dis.sqrt().nanmean(dim=1)**2
        else:
            raise ValueError(f'bad {norm_mode=}')

    norm_factor = norm_factor.clip(min=1e-8)
    while norm_factor.ndim < pts1.ndim:
        norm_factor.unsqueeze_(-1)

    res = pts1 / norm_factor
    if pts2 is not None:
        res = (res, pts2 / norm_factor)
    return res


@torch.no_grad()
def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5):
    # set invalid points to NaN
    _z1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1)
    _z2 = invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1) if z2 is not None else None
    _z = torch.cat((_z1, _z2), dim=-1) if z2 is not None else _z1

    # compute median depth overall (ignoring nans)
    if quantile == 0.5:
        shift_z = torch.nanmedian(_z, dim=-1).values
    else:
        shift_z = torch.nanquantile(_z, quantile, dim=-1)
    return shift_z  # (B,)


@torch.no_grad()
def get_joint_pointcloud_center_scale(pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True):
    # set invalid points to NaN
    _pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3)
    _pts2 = invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3) if pts2 is not None else None
    _pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1

    # compute median center
    _center = torch.nanmedian(_pts, dim=1, keepdim=True).values  # (B,1,3)
    if z_only:
        _center[..., :2] = 0  # do not center X and Y

    # compute median norm
    _norm = ((_pts - _center) if center else _pts).norm(dim=-1)
    scale = torch.nanmedian(_norm, dim=1).values
    return _center[:, None, :, :], scale[:, None, None, None]


def find_reciprocal_matches(P1, P2):
    """
    returns 3 values:
    1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match
    2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1
    3 - reciprocal_in_P2.sum(): the number of matches
    """
    tree1 = KDTree(P1)
    tree2 = KDTree(P2)

    _, nn1_in_P2 = tree2.query(P1, workers=8)
    _, nn2_in_P1 = tree1.query(P2, workers=8)

    reciprocal_in_P1 = (nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2)))
    reciprocal_in_P2 = (nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1)))
    assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum()
    return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum()


def get_med_dist_between_poses(poses):
    from scipy.spatial.distance import pdist
    return np.median(pdist([to_numpy(p[:3, 3]) for p in poses]))
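geotrf is the workhorse for moving points between frames. A small round-trip sketch with a rigid cam2world pose (illustrative, not part of the commit):

import numpy as np
from mini_dust3r.utils.geometry import geotrf, inv

pose = np.eye(4)                                             # cam2world: rotate 90 degrees about Z, shift along X
pose[:3, :3] = np.array([[0., -1., 0.], [1., 0., 0.], [0., 0., 1.]])
pose[:3, 3] = [2., 0., 0.]

pts_cam = np.array([[1., 0., 5.], [0., 1., 5.]])             # points in camera coordinates
pts_world = geotrf(pose, pts_cam)                            # into world coordinates
back = geotrf(inv(pose), pts_world)                          # and back again
print(np.allclose(back, pts_cam))                            # True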
mini_dust3r/utils/image.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
|
2 |
+
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
|
3 |
+
#
|
4 |
+
# --------------------------------------------------------
|
5 |
+
# utilitary functions about images (loading/converting...)
|
6 |
+
# --------------------------------------------------------
|
7 |
+
import os
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
import PIL.Image
|
11 |
+
from PIL.ImageOps import exif_transpose
|
12 |
+
import torchvision.transforms as tvf
|
13 |
+
|
14 |
+
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
15 |
+
import cv2 # noqa
|
16 |
+
from typing import Literal, TypedDict
|
17 |
+
from jaxtyping import Float32, Int32
|
18 |
+
|
19 |
+
try:
|
20 |
+
from pillow_heif import register_heif_opener # noqa
|
21 |
+
|
22 |
+
register_heif_opener()
|
23 |
+
heif_support_enabled = True
|
24 |
+
except ImportError:
|
25 |
+
heif_support_enabled = False
|
26 |
+
|
27 |
+
ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
|
28 |
+
|
29 |
+
|
30 |
+
class ImageDict(TypedDict):
|
31 |
+
img: Float32[torch.Tensor, "b c h w"]
|
32 |
+
true_shape: tuple[int, int] | Int32[torch.Tensor, "b 2"]
|
33 |
+
idx: int | list[int]
|
34 |
+
instance: str | list[str]
|
35 |
+
|
36 |
+
|
37 |
+
def imread_cv2(path, options=cv2.IMREAD_COLOR):
|
38 |
+
"""Open an image or a depthmap with opencv-python."""
|
39 |
+
if path.endswith((".exr", "EXR")):
|
40 |
+
options = cv2.IMREAD_ANYDEPTH
|
41 |
+
img = cv2.imread(path, options)
|
42 |
+
if img is None:
|
43 |
+
raise IOError(f"Could not load image={path} with {options=}")
|
44 |
+
if img.ndim == 3:
|
45 |
+
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    return img
+
+
+def rgb(ftensor, true_shape=None):
+    if isinstance(ftensor, list):
+        return [rgb(x, true_shape=true_shape) for x in ftensor]
+    if isinstance(ftensor, torch.Tensor):
+        ftensor = ftensor.detach().cpu().numpy()  # H,W,3
+    if ftensor.ndim == 3 and ftensor.shape[0] == 3:
+        ftensor = ftensor.transpose(1, 2, 0)
+    elif ftensor.ndim == 4 and ftensor.shape[1] == 3:
+        ftensor = ftensor.transpose(0, 2, 3, 1)
+    if true_shape is not None:
+        H, W = true_shape
+        ftensor = ftensor[:H, :W]
+    if ftensor.dtype == np.uint8:
+        img = np.float32(ftensor) / 255
+    else:
+        img = (ftensor * 0.5) + 0.5
+    return img.clip(min=0, max=1)
+
+
+def _resize_pil_image(img, long_edge_size):
+    S = max(img.size)
+    if S > long_edge_size:
+        interp = PIL.Image.LANCZOS
+    elif S <= long_edge_size:
+        interp = PIL.Image.BICUBIC
+    new_size = tuple(int(round(x * long_edge_size / S)) for x in img.size)
+    return img.resize(new_size, interp)
+
+
+def load_images(
+    folder_or_list: str | list,
+    size: Literal[224, 512],
+    square_ok: bool = False,
+    verbose: bool = True,
+) -> list[ImageDict]:
+    """open and convert all images in a list or folder to proper input format for DUSt3R"""
+    if isinstance(folder_or_list, str):
+        if verbose:
+            print(f">> Loading images from {folder_or_list}")
+        root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))
+
+    elif isinstance(folder_or_list, list):
+        if verbose:
+            print(f">> Loading a list of {len(folder_or_list)} images")
+        root, folder_content = "", folder_or_list
+
+    else:
+        raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})")
+
+    supported_images_extensions = [".jpg", ".jpeg", ".png"]
+    if heif_support_enabled:
+        supported_images_extensions += [".heic", ".heif"]
+    supported_images_extensions = tuple(supported_images_extensions)
+
+    imgs = []
+    for path in folder_content:
+        if not path.lower().endswith(supported_images_extensions):
+            continue
+        img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert("RGB")
+        W1, H1 = img.size
+        if size == 224:
+            # resize short side to 224 (then crop)
+            img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1)))
+        else:
+            # resize long side to 512
+            img = _resize_pil_image(img, size)
+        W, H = img.size
+        cx, cy = W // 2, H // 2
+        if size == 224:
+            half = min(cx, cy)
+            img = img.crop((cx - half, cy - half, cx + half, cy + half))
+        else:
+            halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
+            if not (square_ok) and W == H:
+                halfh = 3 * halfw / 4
+            img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh))
+
+        W2, H2 = img.size
+        if verbose:
+            print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}")
+        imgs.append(
+            dict(
+                img=ImgNorm(img)[None],
+                true_shape=np.int32([img.size[::-1]]),
+                idx=len(imgs),
+                instance=str(len(imgs)),
+            )
+        )
+
+    assert imgs, "no images found at " + root
+    if verbose:
+        print(f" (Found {len(imgs)} images)")
+    return imgs
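For reference, a minimal usage sketch of the loader above (not part of this commit; the folder path is a hypothetical placeholder and the import path simply mirrors the file location in this repository):

    from mini_dust3r.utils.image import load_images

    # Load every supported image in a folder, resized so the long edge is 512 px.
    imgs = load_images("examples/room/", size=512, verbose=True)
    # Each entry is a dict: "img" is a normalized (1, 3, H, W) tensor, "true_shape" is [[H, W]].
    print(len(imgs), imgs[0]["img"].shape, imgs[0]["true_shape"])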
mini_dust3r/utils/misc.py
ADDED
@@ -0,0 +1,121 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# utility functions for DUSt3R
+# --------------------------------------------------------
+import torch
+
+
+def fill_default_args(kwargs, func):
+    import inspect  # a bit hacky but it works reliably
+    signature = inspect.signature(func)
+
+    for k, v in signature.parameters.items():
+        if v.default is inspect.Parameter.empty:
+            continue
+        kwargs.setdefault(k, v.default)
+
+    return kwargs
+
+
+def freeze_all_params(modules):
+    for module in modules:
+        try:
+            for n, param in module.named_parameters():
+                param.requires_grad = False
+        except AttributeError:
+            # module is directly a parameter
+            module.requires_grad = False
+
+
+def is_symmetrized(gt1, gt2):
+    x = gt1['instance']
+    y = gt2['instance']
+    if len(x) == len(y) and len(x) == 1:
+        return False  # special case of batchsize 1
+    ok = True
+    for i in range(0, len(x), 2):
+        ok = ok and (x[i] == y[i+1]) and (x[i+1] == y[i])
+    return ok
+
+
+def flip(tensor):
+    """ flip so that tensor[0::2] <=> tensor[1::2] """
+    return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1)
+
+
+def interleave(tensor1, tensor2):
+    res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1)
+    res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1)
+    return res1, res2
+
+
+def transpose_to_landscape(head, activate=True):
+    """ Predict in the correct aspect-ratio,
+    then transpose the result in landscape
+    and stack everything back together.
+    """
+    def wrapper_no(decout, true_shape):
+        B = len(true_shape)
+        assert true_shape[0:1].allclose(true_shape), 'true_shape must be all identical'
+        H, W = true_shape[0].cpu().tolist()
+        res = head(decout, (H, W))
+        return res
+
+    def wrapper_yes(decout, true_shape):
+        B = len(true_shape)
+        # by definition, the batch is in landscape mode so W >= H
+        H, W = int(true_shape.min()), int(true_shape.max())
+
+        height, width = true_shape.T
+        is_landscape = (width >= height)
+        is_portrait = ~is_landscape
+
+        # true_shape = true_shape.cpu()
+        if is_landscape.all():
+            return head(decout, (H, W))
+        if is_portrait.all():
+            return transposed(head(decout, (W, H)))
+
+        # batch is a mix of both portrait & landscape
+        def selout(ar): return [d[ar] for d in decout]
+        l_result = head(selout(is_landscape), (H, W))
+        p_result = transposed(head(selout(is_portrait), (W, H)))
+
+        # allocate full result
+        result = {}
+        for k in l_result | p_result:
+            x = l_result[k].new(B, *l_result[k].shape[1:])
+            x[is_landscape] = l_result[k]
+            x[is_portrait] = p_result[k]
+            result[k] = x
+
+        return result
+
+    return wrapper_yes if activate else wrapper_no
+
+
+def transposed(dic):
+    return {k: v.swapaxes(1, 2) for k, v in dic.items()}
+
+
+def invalid_to_nans(arr, valid_mask, ndim=999):
+    if valid_mask is not None:
+        arr = arr.clone()
+        arr[~valid_mask] = float('nan')
+    if arr.ndim > ndim:
+        arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
+    return arr
+
+
+def invalid_to_zeros(arr, valid_mask, ndim=999):
+    if valid_mask is not None:
+        arr = arr.clone()
+        arr[~valid_mask] = 0
+        nnz = valid_mask.view(len(valid_mask), -1).sum(1)
+    else:
+        nnz = arr.numel() // len(arr) if len(arr) else 0  # number of points per image
+    if arr.ndim > ndim:
+        arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
+    return arr, nnz
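A quick standalone check of the pairing helpers above (illustrative only, not part of this commit; it assumes flip and interleave are imported from mini_dust3r.utils.misc as laid out in this file):

    import torch
    from mini_dust3r.utils.misc import flip, interleave

    a = torch.tensor([0, 1, 2, 3])
    b = torch.tensor([10, 11, 12, 13])

    r1, r2 = interleave(a, b)          # r1 = [0, 10, 1, 11, ...], r2 = [10, 0, 11, 1, ...]
    assert torch.equal(flip(r1), r2)   # flip swaps every consecutive pair of entries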
mini_dust3r/viz.py
ADDED
@@ -0,0 +1,320 @@
+# Copyright (C) 2024-present Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+#
+# --------------------------------------------------------
+# Visualization utilities using trimesh
+# --------------------------------------------------------
+import PIL.Image
+import numpy as np
+from scipy.spatial.transform import Rotation
+import torch
+
+from mini_dust3r.utils.geometry import geotrf, get_med_dist_between_poses
+from mini_dust3r.utils.device import to_numpy
+from mini_dust3r.utils.image import rgb
+
+try:
+    import trimesh
+except ImportError:
+    print('/!\\ module trimesh is not installed, cannot visualize results /!\\')
+
+
+def cat_3d(vecs):
+    if isinstance(vecs, (np.ndarray, torch.Tensor)):
+        vecs = [vecs]
+    return np.concatenate([p.reshape(-1, 3) for p in to_numpy(vecs)])
+
+
+def show_raw_pointcloud(pts3d, colors, point_size=2):
+    scene = trimesh.Scene()
+
+    pct = trimesh.PointCloud(cat_3d(pts3d), colors=cat_3d(colors))
+    scene.add_geometry(pct)
+
+    scene.show(line_settings={'point_size': point_size})
+
+
+def pts3d_to_trimesh(img, pts3d, valid=None):
+    H, W, THREE = img.shape
+    assert THREE == 3
+    assert img.shape == pts3d.shape
+
+    vertices = pts3d.reshape(-1, 3)
+
+    # make squares: each pixel == 2 triangles
+    idx = np.arange(len(vertices)).reshape(H, W)
+    idx1 = idx[:-1, :-1].ravel()  # top-left corner
+    idx2 = idx[:-1, +1:].ravel()  # top-right corner
+    idx3 = idx[+1:, :-1].ravel()  # bottom-left corner
+    idx4 = idx[+1:, +1:].ravel()  # bottom-right corner
+    faces = np.concatenate((
+        np.c_[idx1, idx2, idx3],
+        np.c_[idx3, idx2, idx1],  # same triangle, but backward (cheap solution to cancel face culling)
+        np.c_[idx2, idx3, idx4],
+        np.c_[idx4, idx3, idx2],  # same triangle, but backward (cheap solution to cancel face culling)
+    ), axis=0)
+
+    # prepare triangle colors
+    face_colors = np.concatenate((
+        img[:-1, :-1].reshape(-1, 3),
+        img[:-1, :-1].reshape(-1, 3),
+        img[+1:, +1:].reshape(-1, 3),
+        img[+1:, +1:].reshape(-1, 3)
+    ), axis=0)
+
+    # remove invalid faces
+    if valid is not None:
+        assert valid.shape == (H, W)
+        valid_idxs = valid.ravel()
+        valid_faces = valid_idxs[faces].all(axis=-1)
+        faces = faces[valid_faces]
+        face_colors = face_colors[valid_faces]
+
+    assert len(faces) == len(face_colors)
+    return dict(vertices=vertices, face_colors=face_colors, faces=faces)
+
+
+def cat_meshes(meshes):
+    vertices, faces, colors = zip(*[(m['vertices'], m['faces'], m['face_colors']) for m in meshes])
+    n_vertices = np.cumsum([0]+[len(v) for v in vertices])
+    for i in range(len(faces)):
+        faces[i][:] += n_vertices[i]
+
+    vertices = np.concatenate(vertices)
+    colors = np.concatenate(colors)
+    faces = np.concatenate(faces)
+    return dict(vertices=vertices, face_colors=colors, faces=faces)
+
+
+def show_duster_pairs(view1, view2, pred1, pred2):
+    import matplotlib.pyplot as pl
+    pl.ion()
+
+    for e in range(len(view1['instance'])):
+        i = view1['idx'][e]
+        j = view2['idx'][e]
+        img1 = rgb(view1['img'][e])
+        img2 = rgb(view2['img'][e])
+        conf1 = pred1['conf'][e].squeeze()
+        conf2 = pred2['conf'][e].squeeze()
+        score = conf1.mean()*conf2.mean()
+        print(f">> Showing pair #{e} {i}-{j} {score=:g}")
+        pl.clf()
+        pl.subplot(221).imshow(img1)
+        pl.subplot(223).imshow(img2)
+        pl.subplot(222).imshow(conf1, vmin=1, vmax=30)
+        pl.subplot(224).imshow(conf2, vmin=1, vmax=30)
+        pts1 = pred1['pts3d'][e]
+        pts2 = pred2['pts3d_in_other_view'][e]
+        pl.subplots_adjust(0, 0, 1, 1, 0, 0)
+        if input('show pointcloud? (y/n) ') == 'y':
+            show_raw_pointcloud(cat(pts1, pts2), cat(img1, img2), point_size=5)
+
+
+def auto_cam_size(im_poses):
+    return 0.1 * get_med_dist_between_poses(im_poses)
+
+
+class SceneViz:
+    def __init__(self):
+        self.scene = trimesh.Scene()
+
+    def add_pointcloud(self, pts3d, color, mask=None):
+        pts3d = to_numpy(pts3d)
+        mask = to_numpy(mask)
+        if mask is None:
+            mask = [slice(None)] * len(pts3d)
+        pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
+        pct = trimesh.PointCloud(pts.reshape(-1, 3))
+
+        if isinstance(color, (list, np.ndarray, torch.Tensor)):
+            color = to_numpy(color)
+            col = np.concatenate([p[m] for p, m in zip(color, mask)])
+            assert col.shape == pts.shape
+            pct.visual.vertex_colors = uint8(col.reshape(-1, 3))
+        else:
+            assert len(color) == 3
+            pct.visual.vertex_colors = np.broadcast_to(uint8(color), pts.shape)
+
+        self.scene.add_geometry(pct)
+        return self
+
+    def add_camera(self, pose_c2w, focal=None, color=(0, 0, 0), image=None, imsize=None, cam_size=0.03):
+        pose_c2w, focal, color, image = to_numpy((pose_c2w, focal, color, image))
+        add_scene_cam(self.scene, pose_c2w, color, image, focal, screen_width=cam_size)
+        return self
+
+    def add_cameras(self, poses, focals=None, images=None, imsizes=None, colors=None, **kw):
+        def get(arr, idx): return None if arr is None else arr[idx]
+        for i, pose_c2w in enumerate(poses):
+            self.add_camera(pose_c2w, get(focals, i), image=get(images, i),
+                            color=get(colors, i), imsize=get(imsizes, i), **kw)
+        return self
+
+    def show(self, point_size=2):
+        self.scene.show(line_settings={'point_size': point_size})
+
+
+def show_raw_pointcloud_with_cams(imgs, pts3d, mask, focals, cams2world,
+                                  point_size=2, cam_size=0.05, cam_color=None):
+    """ Visualization of a pointcloud with cameras
+    imgs = (N, H, W, 3) or N-size list of [(H,W,3), ...]
+    pts3d = (N, H, W, 3) or N-size list of [(H,W,3), ...]
+    focals = (N,) or N-size list of [focal, ...]
+    cams2world = (N,4,4) or N-size list of [(4,4), ...]
+    """
+    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
+    pts3d = to_numpy(pts3d)
+    imgs = to_numpy(imgs)
+    focals = to_numpy(focals)
+    cams2world = to_numpy(cams2world)
+
+    scene = trimesh.Scene()
+
+    # full pointcloud
+    pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
+    col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
+    pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
+    scene.add_geometry(pct)
+
+    # add each camera
+    for i, pose_c2w in enumerate(cams2world):
+        if isinstance(cam_color, list):
+            camera_edge_color = cam_color[i]
+        else:
+            camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
+        add_scene_cam(scene, pose_c2w, camera_edge_color,
+                      imgs[i] if i < len(imgs) else None, focals[i], screen_width=cam_size)
+
+    scene.show(line_settings={'point_size': point_size})
+
+
+def add_scene_cam(scene, pose_c2w, edge_color, image=None, focal=None, imsize=None, screen_width=0.03):
+
+    if image is not None:
+        H, W, THREE = image.shape
+        assert THREE == 3
+        if image.dtype != np.uint8:
+            image = np.uint8(255*image)
+    elif imsize is not None:
+        W, H = imsize
+    elif focal is not None:
+        H = W = focal / 1.1
+    else:
+        H = W = 1
+
+    if focal is None:
+        focal = min(H, W) * 1.1  # default value
+    elif isinstance(focal, np.ndarray):
+        focal = focal[0]
+
+    # create fake camera
+    height = focal * screen_width / H
+    width = screen_width * 0.5**0.5
+    rot45 = np.eye(4)
+    rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix()
+    rot45[2, 3] = -height  # set the tip of the cone = optical center
+    aspect_ratio = np.eye(4)
+    aspect_ratio[0, 0] = W/H
+    transform = pose_c2w @ OPENGL @ aspect_ratio @ rot45
+    cam = trimesh.creation.cone(width, height, sections=4)  # , transform=transform)
+
+    # this is the image
+    if image is not None:
+        vertices = geotrf(transform, cam.vertices[[4, 5, 1, 3]])
+        faces = np.array([[0, 1, 2], [0, 2, 3], [2, 1, 0], [3, 2, 0]])
+        img = trimesh.Trimesh(vertices=vertices, faces=faces)
+        uv_coords = np.float32([[0, 0], [1, 0], [1, 1], [0, 1]])
+        img.visual = trimesh.visual.TextureVisuals(uv_coords, image=PIL.Image.fromarray(image))
+        scene.add_geometry(img)
+
+    # this is the camera mesh
+    rot2 = np.eye(4)
+    rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(2)).as_matrix()
+    vertices = np.r_[cam.vertices, 0.95*cam.vertices, geotrf(rot2, cam.vertices)]
+    vertices = geotrf(transform, vertices)
+    faces = []
+    for face in cam.faces:
+        if 0 in face:
+            continue
+        a, b, c = face
+        a2, b2, c2 = face + len(cam.vertices)
+        a3, b3, c3 = face + 2*len(cam.vertices)
+
+        # add 3 pseudo-edges
+        faces.append((a, b, b2))
+        faces.append((a, a2, c))
+        faces.append((c2, b, c))
+
+        faces.append((a, b, b3))
+        faces.append((a, a3, c))
+        faces.append((c3, b, c))
+
+    # no culling
+    faces += [(c, b, a) for a, b, c in faces]
+
+    cam = trimesh.Trimesh(vertices=vertices, faces=faces)
+    cam.visual.face_colors[:, :3] = edge_color
+    scene.add_geometry(cam)
+
+
+def cat(a, b):
+    return np.concatenate((a.reshape(-1, 3), b.reshape(-1, 3)))
+
+
+OPENGL = np.array([[1, 0, 0, 0],
+                   [0, -1, 0, 0],
+                   [0, 0, -1, 0],
+                   [0, 0, 0, 1]])
+
+
+CAM_COLORS = [(255, 0, 0), (0, 0, 255), (0, 255, 0), (255, 0, 255), (255, 204, 0), (0, 204, 204),
+              (128, 255, 255), (255, 128, 255), (255, 255, 128), (0, 0, 0), (128, 128, 128)]
+
+
+def uint8(colors):
+    if not isinstance(colors, np.ndarray):
+        colors = np.array(colors)
+    if np.issubdtype(colors.dtype, np.floating):
+        colors *= 255
+    assert 0 <= colors.min() and colors.max() < 256
+    return np.uint8(colors)
+
+
+def segment_sky(image):
+    import cv2
+    from scipy import ndimage
+
+    # Convert to HSV
+    image = to_numpy(image)
+    if np.issubdtype(image.dtype, np.floating):
+        image = np.uint8(255*image.clip(min=0, max=1))
+    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+
+    # Define range for blue color and create mask
+    lower_blue = np.array([0, 0, 100])
+    upper_blue = np.array([30, 255, 255])
+    mask = cv2.inRange(hsv, lower_blue, upper_blue).view(bool)
+
+    # add luminous gray
+    mask |= (hsv[:, :, 1] < 10) & (hsv[:, :, 2] > 150)
+    mask |= (hsv[:, :, 1] < 30) & (hsv[:, :, 2] > 180)
+    mask |= (hsv[:, :, 1] < 50) & (hsv[:, :, 2] > 220)
+
+    # Morphological operations
+    kernel = np.ones((5, 5), np.uint8)
+    mask2 = ndimage.binary_opening(mask, structure=kernel)
+
+    # keep only largest CC
+    _, labels, stats, _ = cv2.connectedComponentsWithStats(mask2.view(np.uint8), connectivity=8)
+    cc_sizes = stats[1:, cv2.CC_STAT_AREA]
+    order = cc_sizes.argsort()[::-1]  # bigger first
+    i = 0
+    selection = []
+    while i < len(order) and cc_sizes[order[i]] > cc_sizes[order[0]] / 2:
+        selection.append(1 + order[i])
+        i += 1
+    mask3 = np.in1d(labels, selection).reshape(labels.shape)
+
+    # Apply mask
+    return torch.from_numpy(mask3)
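To close, a hedged usage sketch of SceneViz with synthetic points and poses (illustrative only, not part of this commit; the data below is made up, and viz.show() opens an interactive trimesh window):

    import numpy as np
    from mini_dust3r.viz import SceneViz, auto_cam_size

    pts3d = np.random.rand(2, 64, 64, 3).astype(np.float32)    # two synthetic point maps
    colors = np.random.rand(2, 64, 64, 3).astype(np.float32)   # matching RGB values in [0, 1]
    pose2 = np.eye(4)
    pose2[:3, 3] = [0.5, 0.0, 0.0]                              # second camera shifted along x
    poses = np.stack([np.eye(4), pose2])                        # two cam-to-world poses

    viz = SceneViz()
    viz.add_pointcloud(pts3d, colors)
    viz.add_cameras(poses, focals=[300.0, 300.0],
                    colors=[(255, 0, 0), (0, 0, 255)],
                    cam_size=auto_cam_size(poses))
    viz.show(point_size=2)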