kxhit committed on
Commit
5ca3a35
1 Parent(s): d161cfd

cuda reinit?

app.py CHANGED
@@ -268,7 +268,7 @@ from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_
 from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
 import math
 
-@spaces.GPU(duration=120)
+# @spaces.GPU(duration=120)
 def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
                                  cam_color=None, as_pointcloud=False,
                                  transparent_cams=False, silent=False, same_focals=False):
@@ -321,7 +321,7 @@ def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world,
     scene.export(file_obj=outfile)
     return outfile
 
-@spaces.GPU(duration=120)
+# @spaces.GPU(duration=120)
 def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
                             clean_depth=False, transparent_cams=False, cam_size=0.05, same_focals=False):
     """
app_bk.py ADDED
@@ -0,0 +1,786 @@
1
+ import spaces
2
+ import torch
3
+ print("cuda is available: ", torch.cuda.is_available())
4
+
5
+ import gradio as gr
6
+ import os
7
+ import shutil
8
+ import rembg
9
+ import numpy as np
10
+ import math
11
+ import open3d as o3d
12
+ from PIL import Image
13
+ import torchvision
14
+ import trimesh
15
+ from skimage.io import imsave
16
+ import imageio
17
+ import cv2
18
+ import matplotlib.pyplot as pl
19
+ pl.ion()
20
+
21
+ CaPE_TYPE = "6DoF"
22
+ device = 'cuda' #if torch.cuda.is_available() else 'cpu'
23
+ weight_dtype = torch.float16
24
+ torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
25
+
26
+ # EscherNet
27
+ # create angles in archimedean spiral with N steps
28
+ def get_archimedean_spiral(sphere_radius, num_steps=250):
29
+ # x-z plane, around upper y
30
+ '''
31
+ https://en.wikipedia.org/wiki/Spiral, section "Spherical spiral". c = a / pi
32
+ '''
33
+ a = 40
34
+ r = sphere_radius
35
+
36
+ translations = []
37
+ angles = []
38
+
39
+ # i = a / 2
40
+ i = 0.01
41
+ while i < a:
42
+ theta = i / a * math.pi
43
+ x = r * math.sin(theta) * math.cos(-i)
44
+ z = r * math.sin(-theta + math.pi) * math.sin(-i)
45
+ y = r * - math.cos(theta)
46
+
47
+ # translations.append((x, y, z)) # origin
48
+ translations.append((x, z, -y))
49
+ angles.append([np.rad2deg(-i), np.rad2deg(theta)])
50
+
51
+ # i += a / (2 * num_steps)
52
+ i += a / (1 * num_steps)
53
+
54
+ return np.array(translations), np.stack(angles)
55
+
56
+ def look_at(origin, target, up):
57
+ forward = (target - origin)
58
+ forward = forward / np.linalg.norm(forward)
59
+ right = np.cross(up, forward)
60
+ right = right / np.linalg.norm(right)
61
+ new_up = np.cross(forward, right)
62
+ rotation_matrix = np.column_stack((right, new_up, -forward, target))
63
+ matrix = np.row_stack((rotation_matrix, [0, 0, 0, 1]))
64
+ return matrix
65
+
66
+ import einops
67
+ import sys
68
+
69
+ sys.path.insert(0, "./6DoF/") # TODO change it when deploying
70
+ # use the customized diffusers modules
71
+ from diffusers import DDIMScheduler
72
+ from dataset import get_pose
73
+ from CN_encoder import CN_encoder
74
+ from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline
75
+ from segment_anything import sam_model_registry, SamPredictor
76
+
77
+ # import rembg
78
+ from carvekit.api.high import HiInterface
79
+
80
+
81
+ pretrained_model_name_or_path = "kxic/EscherNet_demo"
82
+ resolution = 256
83
+ h,w = resolution,resolution
84
+ guidance_scale = 3.0
85
+ radius = 2.2
86
+ bg_color = [1., 1., 1., 1.]
87
+ image_transforms = torchvision.transforms.Compose(
88
+ [
89
+ torchvision.transforms.Resize((resolution, resolution)), # 256, 256
90
+ torchvision.transforms.ToTensor(),
91
+ torchvision.transforms.Normalize([0.5], [0.5])
92
+ ]
93
+ )
94
+ xyzs_spiral, angles_spiral = get_archimedean_spiral(1.5, 200)
95
+ # only use the first half of the spiral
96
+ xyzs_spiral = xyzs_spiral[:100]
97
+ angles_spiral = angles_spiral[:100]
98
+
99
+ # Init pipeline
100
+ scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler", revision=None)
101
+ image_encoder = CN_encoder.from_pretrained(pretrained_model_name_or_path, subfolder="image_encoder", revision=None)
102
+ pipeline = Zero1to3StableDiffusionPipeline.from_pretrained(
103
+ pretrained_model_name_or_path,
104
+ revision=None,
105
+ scheduler=scheduler,
106
+ image_encoder=None,
107
+ safety_checker=None,
108
+ feature_extractor=None,
109
+ torch_dtype=weight_dtype,
110
+ )
111
+ pipeline.image_encoder = image_encoder.to(weight_dtype)
112
+
113
+ pipeline.set_progress_bar_config(disable=False)
114
+
115
+ pipeline = pipeline.to(device)
116
+
117
+ # pipeline.enable_xformers_memory_efficient_attention()
118
+ # enable vae slicing
119
+ pipeline.enable_vae_slicing()
120
+ # pipeline.enable_xformers_memory_efficient_attention()
121
+
122
+
123
+ #### object segmentation
124
+ def sam_init():
125
+ sam_checkpoint = os.path.join("./sam_pt/sam_vit_h_4b8939.pth")
126
+ if os.path.exists(sam_checkpoint) is False:
127
+ os.system("wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P ./sam_pt/")
128
+ model_type = "vit_h"
129
+
130
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)
131
+ predictor = SamPredictor(sam)
132
+ return predictor
133
+
134
+ def create_carvekit_interface():
135
+ # Check doc strings for more information
136
+ interface = HiInterface(object_type="object", # Can be "object" or "hairs-like".
137
+ batch_size_seg=6,
138
+ batch_size_matting=1,
139
+ device="cpu",
140
+ seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
141
+ matting_mask_size=2048,
142
+ trimap_prob_threshold=231,
143
+ trimap_dilation=30,
144
+ trimap_erosion_iters=5,
145
+ fp16=True)
146
+
147
+ return interface
148
+
149
+
150
+ # rembg_session = rembg.new_session()
151
+ rembg_session = create_carvekit_interface()
152
+ predictor = sam_init()
153
+
154
+
155
+
156
+ @spaces.GPU(duration=120)
157
+ def run_eschernet(eschernet_input_dict, sample_steps, sample_seed, nvs_num, nvs_mode):
158
+ # set the random seed
159
+ generator = torch.Generator(device=device).manual_seed(sample_seed)
160
+ # generator = None
161
+ T_out = nvs_num
162
+ T_in = len(eschernet_input_dict['imgs'])
163
+ ####### output pose
164
+ # TODO choose T_out number of poses sequentially from the spiral
165
+ xyzs = xyzs_spiral[::(len(xyzs_spiral) // T_out)]
166
+ angles_out = angles_spiral[::(len(xyzs_spiral) // T_out)]
167
+
168
+ ####### input's max radius for translation scaling
169
+ radii = eschernet_input_dict['radii']
170
+ max_t = np.max(radii)
171
+ min_t = np.min(radii)
172
+
173
+ ####### input pose
174
+ pose_in = []
175
+ for T_in_index in range(T_in):
176
+ pose = get_pose(np.linalg.inv(eschernet_input_dict['poses'][T_in_index]))
177
+ pose[1:3, :] *= -1 # coordinate system conversion
178
+ pose[3, 3] *= 1. / max_t * radius # scale radius to [1.5, 2.2]
179
+ pose_in.append(torch.from_numpy(pose))
180
+
181
+ ####### input image
182
+ img = eschernet_input_dict['imgs'] / 255.
183
+ img[img[:, :, :, -1] == 0.] = bg_color
184
+ # TODO batch image_transforms
185
+ input_image = [image_transforms(Image.fromarray(np.uint8(im[:, :, :3] * 255.)).convert("RGB")) for im in img]
186
+
187
+ ####### nvs pose
188
+ pose_out = []
189
+ for T_out_index in range(T_out):
190
+ azimuth, polar = angles_out[T_out_index]
191
+ if CaPE_TYPE == "4DoF":
192
+ pose_out.append(torch.tensor([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.]))
193
+ elif CaPE_TYPE == "6DoF":
194
+ pose = look_at(origin=np.array([0, 0, 0]), target=xyzs[T_out_index], up=np.array([0, 0, 1]))
195
+ pose = np.linalg.inv(pose)
196
+ pose[2, :] *= -1
197
+ pose_out.append(torch.from_numpy(get_pose(pose)))
198
+
199
+
200
+
201
+ # [B, T, C, H, W]
202
+ input_image = torch.stack(input_image, dim=0).to(device).to(weight_dtype).unsqueeze(0)
203
+ # [B, T, 4]
204
+ pose_in = np.stack(pose_in)
205
+ pose_out = np.stack(pose_out)
206
+
207
+ if CaPE_TYPE == "6DoF":
208
+ pose_in_inv = np.linalg.inv(pose_in).transpose([0, 2, 1])
209
+ pose_out_inv = np.linalg.inv(pose_out).transpose([0, 2, 1])
210
+ pose_in_inv = torch.from_numpy(pose_in_inv).to(device).to(weight_dtype).unsqueeze(0)
211
+ pose_out_inv = torch.from_numpy(pose_out_inv).to(device).to(weight_dtype).unsqueeze(0)
212
+
213
+ pose_in = torch.from_numpy(pose_in).to(device).to(weight_dtype).unsqueeze(0)
214
+ pose_out = torch.from_numpy(pose_out).to(device).to(weight_dtype).unsqueeze(0)
215
+
216
+ input_image = einops.rearrange(input_image, "b t c h w -> (b t) c h w")
217
+ assert T_in == input_image.shape[0]
218
+ assert T_in == pose_in.shape[1]
219
+ assert T_out == pose_out.shape[1]
220
+
221
+ # run inference
222
+ # pipeline.to(device)
223
+ pipeline.enable_xformers_memory_efficient_attention()
224
+ image = pipeline(input_imgs=input_image, prompt_imgs=input_image,
225
+ poses=[[pose_out, pose_out_inv], [pose_in, pose_in_inv]],
226
+ height=h, width=w, T_in=T_in, T_out=T_out,
227
+ guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
228
+ output_type="numpy").images
229
+
230
+ # save output image
231
+ output_dir = os.path.join(tmpdirname, "eschernet")
232
+ if os.path.exists(output_dir):
233
+ shutil.rmtree(output_dir)
234
+ os.makedirs(output_dir, exist_ok=True)
235
+ # # save to N imgs
236
+ # for i in range(T_out):
237
+ # imsave(os.path.join(output_dir, f'{i}.png'), (image[i] * 255).astype(np.uint8))
238
+ # make a gif
239
+ frames = [Image.fromarray((image[i] * 255).astype(np.uint8)) for i in range(T_out)]
240
+ # frame_one = frames[0]
241
+ # frame_one.save(os.path.join(output_dir, "output.gif"), format="GIF", append_images=frames,
242
+ # save_all=True, duration=50, loop=1)
243
+
244
+ # get a video
245
+ video_path = os.path.join(output_dir, "output.mp4")
246
+ imageio.mimwrite(video_path, np.stack(frames), fps=10, codec='h264')
247
+
248
+
249
+ return video_path
250
+
251
+ # TODO mesh it
252
+ @spaces.GPU(duration=120)
253
+ def make3d():
254
+ pass
255
+
256
+
257
+
258
+ ############################ Dust3r as Pose Estimation ############################
259
+ from scipy.spatial.transform import Rotation
260
+ import copy
261
+
262
+ from dust3r.inference import inference
263
+ from dust3r.model import AsymmetricCroCo3DStereo
264
+ from dust3r.image_pairs import make_pairs
265
+ from dust3r.utils.image import load_images, rgb
266
+ from dust3r.utils.device import to_numpy
267
+ from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
268
+ from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
269
+ import math
270
+
271
+ @spaces.GPU(duration=120)
272
+ def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
273
+ cam_color=None, as_pointcloud=False,
274
+ transparent_cams=False, silent=False, same_focals=False):
275
+ assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world)
276
+ if not same_focals:
277
+ assert (len(cams2world) == len(focals))
278
+ pts3d = to_numpy(pts3d)
279
+ imgs = to_numpy(imgs)
280
+ focals = to_numpy(focals)
281
+ cams2world = to_numpy(cams2world)
282
+
283
+ scene = trimesh.Scene()
284
+
285
+ # add axes
286
+ scene.add_geometry(trimesh.creation.axis(axis_length=0.5, axis_radius=0.001))
287
+
288
+ # full pointcloud
289
+ if as_pointcloud:
290
+ pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
291
+ col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
292
+ pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
293
+ scene.add_geometry(pct)
294
+ else:
295
+ meshes = []
296
+ for i in range(len(imgs)):
297
+ meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]))
298
+ mesh = trimesh.Trimesh(**cat_meshes(meshes))
299
+ scene.add_geometry(mesh)
300
+
301
+ # add each camera
302
+ for i, pose_c2w in enumerate(cams2world):
303
+ if isinstance(cam_color, list):
304
+ camera_edge_color = cam_color[i]
305
+ else:
306
+ camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
307
+ if same_focals:
308
+ focal = focals[0]
309
+ else:
310
+ focal = focals[i]
311
+ add_scene_cam(scene, pose_c2w, camera_edge_color,
312
+ None if transparent_cams else imgs[i], focal,
313
+ imsize=imgs[i].shape[1::-1], screen_width=cam_size)
314
+
315
+ rot = np.eye(4)
316
+ rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
317
+ scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
318
+ outfile = os.path.join(outdir, 'scene.glb')
319
+ if not silent:
320
+ print('(exporting 3D scene to', outfile, ')')
321
+ scene.export(file_obj=outfile)
322
+ return outfile
323
+
324
+ @spaces.GPU(duration=120)
325
+ def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
326
+ clean_depth=False, transparent_cams=False, cam_size=0.05, same_focals=False):
327
+ """
328
+ extract 3D_model (glb file) from a reconstructed scene
329
+ """
330
+ if scene is None:
331
+ return None
332
+ # post processes
333
+ if clean_depth:
334
+ scene = scene.clean_pointcloud()
335
+ if mask_sky:
336
+ scene = scene.mask_sky()
337
+
338
+ # get optimized values from scene
339
+ rgbimg = to_numpy(scene.imgs)
340
+ focals = to_numpy(scene.get_focals().cpu())
341
+ # cams2world = to_numpy(scene.get_im_poses().cpu())
342
+ # TODO use the vis_poses
343
+ cams2world = scene.vis_poses
344
+
345
+ # 3D pointcloud from depthmap, poses and intrinsics
346
+ # pts3d = to_numpy(scene.get_pts3d())
347
+ # TODO use the vis_poses
348
+ pts3d = scene.vis_pts3d
349
+ scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
350
+ msk = to_numpy(scene.get_masks())
351
+
352
+ return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
353
+ transparent_cams=transparent_cams, cam_size=cam_size, silent=silent,
354
+ same_focals=same_focals)
355
+
356
+ @spaces.GPU(duration=120)
357
+ def get_reconstructed_scene(filelist, schedule, niter, min_conf_thr,
358
+ as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
359
+ scenegraph_type, winsize, refid, same_focals):
360
+ """
361
+ from a list of images, run dust3r inference, global aligner.
362
+ then run get_3D_model_from_scene
363
+ """
364
+ silent = False
365
+ image_size = 224
366
+ # remove the directory if it already exists
367
+ outdir = tmpdirname
368
+ if os.path.exists(outdir):
369
+ shutil.rmtree(outdir)
370
+ os.makedirs(outdir, exist_ok=True)
371
+ imgs, imgs_rgba = load_images(filelist, size=image_size, verbose=not silent, do_remove_background=True, rembg_session=rembg_session, predictor=predictor)
372
+ if len(imgs) == 1:
373
+ imgs = [imgs[0], copy.deepcopy(imgs[0])]
374
+ imgs[1]['idx'] = 1
375
+ if scenegraph_type == "swin":
376
+ scenegraph_type = scenegraph_type + "-" + str(winsize)
377
+ elif scenegraph_type == "oneref":
378
+ scenegraph_type = scenegraph_type + "-" + str(refid)
379
+
380
+ pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
381
+ output = inference(pairs, model, device, batch_size=1, verbose=not silent)
382
+
383
+ mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
384
+ scene = global_aligner(output, device=device, mode=mode, verbose=not silent, same_focals=same_focals)
385
+ lr = 0.01
386
+
387
+ if mode == GlobalAlignerMode.PointCloudOptimizer:
388
+ loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)
389
+
390
+ # outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
391
+ # clean_depth, transparent_cams, cam_size, same_focals=same_focals)
392
+
393
+ # also return rgb, depth and confidence imgs
394
+ # depth is normalized with the max value for all images
395
+ # we apply the jet colormap on the confidence maps
396
+ rgbimg = scene.imgs
397
+ # depths = to_numpy(scene.get_depthmaps())
398
+ # confs = to_numpy([c for c in scene.im_conf])
399
+ # cmap = pl.get_cmap('jet')
400
+ # depths_max = max([d.max() for d in depths])
401
+ # depths = [d / depths_max for d in depths]
402
+ # confs_max = max([d.max() for d in confs])
403
+ # confs = [cmap(d / confs_max) for d in confs]
404
+
405
+ imgs = []
406
+ rgbaimg = []
407
+ for i in range(len(rgbimg)): # when only 1 image, scene.imgs is two
408
+ imgs.append(rgbimg[i])
409
+ # imgs.append(rgb(depths[i]))
410
+ # imgs.append(rgb(confs[i]))
411
+ # imgs.append(imgs_rgba[i])
412
+ if len(imgs_rgba) == 1 and i == 1:
413
+ imgs.append(imgs_rgba[0])
414
+ rgbaimg.append(np.array(imgs_rgba[0]))
415
+ else:
416
+ imgs.append(imgs_rgba[i])
417
+ rgbaimg.append(np.array(imgs_rgba[i]))
418
+
419
+ rgbaimg = np.array(rgbaimg)
420
+
421
+ # for eschernet
422
+ # get optimized values from scene
423
+ rgbimg = to_numpy(scene.imgs)
424
+ # focals = to_numpy(scene.get_focals().cpu())
425
+ cams2world = to_numpy(scene.get_im_poses().cpu())
426
+
427
+ # 3D pointcloud from depthmap, poses and intrinsics
428
+ pts3d = to_numpy(scene.get_pts3d())
429
+ scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
430
+ msk = to_numpy(scene.get_masks())
431
+ obj_mask = rgbaimg[..., 3] > 0
432
+
433
+ # TODO set global coordinate system at the center of the scene, z-axis is up
434
+ pts = np.concatenate([p[m] for p, m in zip(pts3d, msk)]).reshape(-1, 3)
435
+ pts_obj = np.concatenate([p[m&obj_m] for p, m, obj_m in zip(pts3d, msk, obj_mask)]).reshape(-1, 3)
436
+ centroid = np.mean(pts_obj, axis=0) # obj center
437
+ obj2world = np.eye(4)
438
+ obj2world[:3, 3] = -centroid # T_wc
439
+
440
+ # get z_up vector
441
+ # TODO fit a plane and get the normal vector
442
+ pcd = o3d.geometry.PointCloud()
443
+ pcd.points = o3d.utility.Vector3dVector(pts)
444
+ plane_model, inliers = pcd.segment_plane(distance_threshold=0.01, ransac_n=3, num_iterations=1000)
445
+ # get the normalised normal vector dim = 3
446
+ normal = plane_model[:3] / np.linalg.norm(plane_model[:3])
447
+ # the normal direction should be pointing up
448
+ if normal[1] < 0:
449
+ normal = -normal
450
+ # print("normal", normal)
451
+
452
+ # # TODO z-up 180
453
+ # z_up = np.array([[1,0,0,0],
454
+ # [0,-1,0,0],
455
+ # [0,0,-1,0],
456
+ # [0,0,0,1]])
457
+ # obj2world = z_up @ obj2world
458
+
459
+ # # avg the y
460
+ # z_up_avg = cams2world[:,:3,3].sum(0) / np.linalg.norm(cams2world[:,:3,3].sum(0), axis=-1) # average direction in cam coordinate
461
+ # # import pdb; pdb.set_trace()
462
+ # rot_axis = np.cross(np.array([0, 0, 1]), z_up_avg)
463
+ # rot_angle = np.arccos(np.dot(np.array([0, 0, 1]), z_up_avg) / (np.linalg.norm(z_up_avg) + 1e-6))
464
+ # rot = Rotation.from_rotvec(rot_angle * rot_axis)
465
+ # z_up = np.eye(4)
466
+ # z_up[:3, :3] = rot.as_matrix()
467
+
468
+ # get the rotation matrix from normal to z-axis
469
+ z_axis = np.array([0, 0, 1])
470
+ rot_axis = np.cross(normal, z_axis)
471
+ rot_angle = np.arccos(np.dot(normal, z_axis) / (np.linalg.norm(normal) + 1e-6))
472
+ rot = Rotation.from_rotvec(rot_angle * rot_axis)
473
+ z_up = np.eye(4)
474
+ z_up[:3, :3] = rot.as_matrix()
475
+ obj2world = z_up @ obj2world
476
+ # flip 180
477
+ flip_rot = np.array([[1, 0, 0, 0],
478
+ [0, -1, 0, 0],
479
+ [0, 0, -1, 0],
480
+ [0, 0, 0, 1]])
481
+ obj2world = flip_rot @ obj2world
482
+
483
+ # get new cams2obj
484
+ cams2obj = []
485
+ for i, cam2world in enumerate(cams2world):
486
+ cams2obj.append(obj2world @ cam2world)
487
+ # TODO transform pts3d to the new coordinate system
488
+ for i, pts in enumerate(pts3d):
489
+ pts3d[i] = (obj2world @ np.concatenate([pts, np.ones_like(pts)[..., :1]], axis=-1).transpose(2, 0, 1).reshape(4,
490
+ -1)) \
491
+ .reshape(4, pts.shape[0], pts.shape[1]).transpose(1, 2, 0)[..., :3]
492
+ cams2world = np.array(cams2obj)
493
+ # TODO rewrite hack
494
+ scene.vis_poses = cams2world.copy()
495
+ scene.vis_pts3d = pts3d.copy()
496
+
497
+ # TODO save cams2world and rgbimg to each file, file name "000.npy", "001.npy", ... and "000.png", "001.png", ...
498
+ for i, (img, img_rgba, pose) in enumerate(zip(rgbimg, rgbaimg, cams2world)):
499
+ np.save(os.path.join(outdir, f"{i:03d}.npy"), pose)
500
+ pl.imsave(os.path.join(outdir, f"{i:03d}.png"), img)
501
+ pl.imsave(os.path.join(outdir, f"{i:03d}_rgba.png"), img_rgba)
502
+ # np.save(os.path.join(outdir, f"{i:03d}_focal.npy"), to_numpy(focal))
503
+ # save the min/max radius of camera
504
+ radii = np.linalg.norm(np.linalg.inv(cams2world)[..., :3, 3])
505
+ np.save(os.path.join(outdir, "radii.npy"), radii)
506
+
507
+ eschernet_input = {"poses": cams2world,
508
+ "radii": radii,
509
+ "imgs": rgbaimg}
510
+ print("got eschernet input")
511
+ outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
512
+ clean_depth, transparent_cams, cam_size, same_focals=same_focals)
513
+
514
+ return scene, outfile, imgs, eschernet_input
515
+
516
+
517
+ def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
518
+ num_files = len(inputfiles) if inputfiles is not None else 1
519
+ max_winsize = max(1, math.ceil((num_files - 1) / 2))
520
+ if scenegraph_type == "swin":
521
+ winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
522
+ minimum=1, maximum=max_winsize, step=1, visible=True)
523
+ refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
524
+ maximum=num_files - 1, step=1, visible=False)
525
+ elif scenegraph_type == "oneref":
526
+ winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
527
+ minimum=1, maximum=max_winsize, step=1, visible=False)
528
+ refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
529
+ maximum=num_files - 1, step=1, visible=True)
530
+ else:
531
+ winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
532
+ minimum=1, maximum=max_winsize, step=1, visible=False)
533
+ refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
534
+ maximum=num_files - 1, step=1, visible=False)
535
+ return winsize, refid
536
+
537
+
538
+ def get_examples(path):
539
+ objs = []
540
+ for obj_name in sorted(os.listdir(path)):
541
+ img_files = []
542
+ for img_file in sorted(os.listdir(os.path.join(path, obj_name))):
543
+ img_files.append(os.path.join(path, obj_name, img_file))
544
+ objs.append([img_files])
545
+ print("objs = ", objs)
546
+ return objs
547
+
548
+ def preview_input(inputfiles):
549
+ if inputfiles is None:
550
+ return None
551
+ imgs = []
552
+ for img_file in inputfiles:
553
+ img = pl.imread(img_file)
554
+ imgs.append(img)
555
+ return imgs
556
+
557
+ # def main():
558
+ # dust3r init
559
+ silent = False
560
+ image_size = 224
561
+ weights_path = 'checkpoints/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth'
562
+ model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(device)
563
+ # dust3r will write the 3D model inside tmpdirname
564
+ # with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname:
565
+ tmpdirname = os.path.join('logs/user_object')
566
+ # remove the directory if it already exists
567
+ if os.path.exists(tmpdirname):
568
+ shutil.rmtree(tmpdirname)
569
+ os.makedirs(tmpdirname, exist_ok=True)
570
+ if not silent:
571
+ print('Outputting stuff in', tmpdirname)
572
+
573
+ _HEADER_ = '''
574
+ <h2><b>[CVPR'24 Oral] EscherNet: A Generative Model for Scalable View Synthesis</b></h2>
575
+ <b>EscherNet</b> is a multiview diffusion model for scalable generative any-to-any number/pose novel view synthesis.
576
+
577
+ Image views are treated as tokens and the camera pose is encoded by <b>CaPE (Camera Positional Encoding)</b>.
578
+
579
+ <a href='https://kxhit.github.io/EscherNet' target='_blank'>Project</a> <b>|</b>
580
+ <a href='https://github.com/kxhit/EscherNet' target='_blank'>GitHub</a> <b>|</b>
581
+ <a href='https://arxiv.org/abs/2402.03908' target='_blank'>ArXiv</a>
582
+
583
+ <h4><b>Tips:</b></h4>
584
+
585
+ - Our model can take <b>any number of input images</b>. The more images you provide <b>(>=3 for this demo)</b>, the better the results.
586
+
587
+ - Our model can generate novel views at <b>any number and any pose</b>. You can specify the number of views you want to generate. In this demo, we place the novel views on an <b>archimedean spiral</b> for simplicity.
588
+
589
+ - The pose estimation is done using <a href='https://github.com/naver/dust3r' target='_blank'>DUSt3R</a>. You can also provide your own poses or obtain them from any SLAM system.
590
+
591
+ - The current checkpoint supports 6DoF camera pose and is trained on 30k 3D <a href='https://objaverse.allenai.org/' target='_blank'>Objaverse</a> objects for demo. Scaling is on the roadmap!
592
+
593
+ '''
594
+
595
+ _CITE_ = r"""
596
+ 📝 <b>Citation</b>:
597
+ ```bibtex
598
+ @article{kong2024eschernet,
599
+ title={EscherNet: A Generative Model for Scalable View Synthesis},
600
+ author={Kong, Xin and Liu, Shikun and Lyu, Xiaoyang and Taher, Marwan and Qi, Xiaojuan and Davison, Andrew J},
601
+ journal={arXiv preprint arXiv:2402.03908},
602
+ year={2024}
603
+ }
604
+ ```
605
+ """
606
+
607
+ with gr.Blocks() as demo:
608
+ gr.Markdown(_HEADER_)
609
+ # mv_images = gr.State()
610
+ scene = gr.State(None)
611
+ eschernet_input = gr.State(None)
612
+ with gr.Row(variant="panel"):
613
+ # left column
614
+ with gr.Column():
615
+ with gr.Row():
616
+ input_image = gr.File(file_count="multiple")
617
+ with gr.Row():
618
+ run_dust3r = gr.Button("Get Pose!", elem_id="dust3r")
619
+ with gr.Row():
620
+ processed_image = gr.Gallery(label='Input Views', columns=2, height="100%")
621
+ with gr.Row(variant="panel"):
622
+ # input examples under "examples" folder
623
+ gr.Examples(
624
+ examples=get_examples('examples'),
625
+ inputs=[input_image],
626
+ label="Examples (click one set of images to start!)",
627
+ examples_per_page=20
628
+ )
629
+
630
+
631
+
632
+
633
+
634
+ # right column
635
+ with gr.Column():
636
+
637
+ with gr.Row():
638
+ outmodel = gr.Model3D()
639
+
640
+ with gr.Row():
641
+ gr.Markdown('''
642
+ <h4><b>Check that the pose (the blue axis is the estimated z-up direction) and the segmentation look correct. If not, remove the incorrect images and try again.</b></h4>
643
+ ''')
644
+
645
+ with gr.Row():
646
+ with gr.Group():
647
+ do_remove_background = gr.Checkbox(
648
+ label="Remove Background", value=True
649
+ )
650
+ sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
651
+
652
+ sample_steps = gr.Slider(
653
+ label="Sample Steps",
654
+ minimum=30,
655
+ maximum=75,
656
+ value=50,
657
+ step=5,
658
+ visible=False
659
+ )
660
+
661
+ nvs_num = gr.Slider(
662
+ label="Number of Novel Views",
663
+ minimum=5,
664
+ maximum=100,
665
+ value=30,
666
+ step=1
667
+ )
668
+
669
+ nvs_mode = gr.Dropdown(["archimedes circle"], # "fixed 4 views", "fixed 8 views"
670
+ value="archimedes circle", label="Novel Views Pose Chosen", visible=True)
671
+
672
+ with gr.Row():
673
+ gr.Markdown('''
674
+ <h4><b>Choose the number of novel view poses you want and generate! The more output images, the longer it takes.</b></h4>
675
+ ''')
676
+
677
+ with gr.Row():
678
+ submit = gr.Button("Submit", elem_id="eschernet", variant="primary")
679
+
680
+ with gr.Row():
681
+ with gr.Column():
682
+ output_video = gr.Video(
683
+ label="video", format="mp4",
684
+ width=379,
685
+ autoplay=True,
686
+ interactive=False
687
+ )
688
+
689
+ with gr.Row():
690
+ gr.Markdown('''
691
+ <h4><b>The novel views are generated on an archimedean spiral (rotating around the z-up axis and looking at the object center). You can download the video.</b></h4>
692
+ ''')
693
+
694
+ gr.Markdown(_CITE_)
695
+
696
+ # set dust3r parameter invisible to be clean
697
+ with gr.Column():
698
+ with gr.Row():
699
+ schedule = gr.Dropdown(["linear", "cosine"],
700
+ value='linear', label="schedule", info="For global alignment!", visible=False)
701
+ niter = gr.Number(value=300, precision=0, minimum=0, maximum=5000,
702
+ label="num_iterations", info="For global alignment!", visible=False)
703
+ scenegraph_type = gr.Dropdown(["complete", "swin", "oneref"],
704
+ value='complete', label="Scenegraph",
705
+ info="Define how to make pairs",
706
+ interactive=True, visible=False)
707
+ same_focals = gr.Checkbox(value=True, label="Focal", info="Use the same focal for all cameras", visible=False)
708
+ winsize = gr.Slider(label="Scene Graph: Window Size", value=1,
709
+ minimum=1, maximum=1, step=1, visible=False)
710
+ refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
711
+
712
+ with gr.Row():
713
+ # adjust the confidence threshold
714
+ min_conf_thr = gr.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1, visible=False)
715
+ # adjust the camera size in the output pointcloud
716
+ cam_size = gr.Slider(label="cam_size", value=0.05, minimum=0.01, maximum=0.5, step=0.001, visible=False)
717
+ with gr.Row():
718
+ as_pointcloud = gr.Checkbox(value=False, label="As pointcloud", visible=False)
719
+ # two post process implemented
720
+ mask_sky = gr.Checkbox(value=False, label="Mask sky", visible=False)
721
+ clean_depth = gr.Checkbox(value=True, label="Clean-up depthmaps", visible=False)
722
+ transparent_cams = gr.Checkbox(value=False, label="Transparent cameras", visible=False)
723
+
724
+ # events
725
+ # scenegraph_type.change(set_scenegraph_options,
726
+ # inputs=[input_image, winsize, refid, scenegraph_type],
727
+ # outputs=[winsize, refid])
728
+ # min_conf_thr.release(fn=model_from_scene_fun,
729
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
730
+ # clean_depth, transparent_cams, cam_size, same_focals],
731
+ # outputs=outmodel)
732
+ # cam_size.change(fn=model_from_scene_fun,
733
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
734
+ # clean_depth, transparent_cams, cam_size, same_focals],
735
+ # outputs=outmodel)
736
+ # as_pointcloud.change(fn=model_from_scene_fun,
737
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
738
+ # clean_depth, transparent_cams, cam_size, same_focals],
739
+ # outputs=outmodel)
740
+ # mask_sky.change(fn=model_from_scene_fun,
741
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
742
+ # clean_depth, transparent_cams, cam_size, same_focals],
743
+ # outputs=outmodel)
744
+ # clean_depth.change(fn=model_from_scene_fun,
745
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
746
+ # clean_depth, transparent_cams, cam_size, same_focals],
747
+ # outputs=outmodel)
748
+ # transparent_cams.change(model_from_scene_fun,
749
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
750
+ # clean_depth, transparent_cams, cam_size, same_focals],
751
+ # outputs=outmodel)
752
+ # run_dust3r.click(fn=recon_fun,
753
+ # inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
754
+ # mask_sky, clean_depth, transparent_cams, cam_size,
755
+ # scenegraph_type, winsize, refid, same_focals],
756
+ # outputs=[scene, outmodel, processed_image, eschernet_input])
757
+
758
+ # events
759
+ input_image.change(set_scenegraph_options,
760
+ inputs=[input_image, winsize, refid, scenegraph_type],
761
+ outputs=[winsize, refid])
762
+ run_dust3r.click(fn=get_reconstructed_scene,
763
+ inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
764
+ mask_sky, clean_depth, transparent_cams, cam_size,
765
+ scenegraph_type, winsize, refid, same_focals],
766
+ outputs=[scene, outmodel, processed_image, eschernet_input])
767
+
768
+
769
+ # events
770
+ input_image.change(fn=preview_input,
771
+ inputs=[input_image],
772
+ outputs=[processed_image])
773
+
774
+ submit.click(fn=run_eschernet,
775
+ inputs=[eschernet_input, sample_steps, sample_seed,
776
+ nvs_num, nvs_mode],
777
+ outputs=[output_video])
778
+
779
+
780
+
781
+ # demo.queue(max_size=10)
782
+ # demo.launch(share=True, server_name="0.0.0.0", server_port=None)
783
+ demo.queue(max_size=10).launch()
784
+
785
+ # if __name__ == '__main__':
786
+ # main()
mini_dust3r/__init__.py ADDED
File without changes
mini_dust3r/api/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .inference import inferece_dust3r, OptimizedResult, log_optimized_result
2
+
3
+ __all__ = ["inferece_dust3r", "OptimizedResult", "log_optimized_result"]
mini_dust3r/api/inference.py ADDED
@@ -0,0 +1,225 @@
1
+ import rerun as rr
2
+ from pathlib import Path
3
+ from typing import Literal
4
+ import copy
5
+ import torch
6
+ import numpy as np
7
+ from jaxtyping import Float32, Bool
8
+ import trimesh
9
+ from tqdm import tqdm
10
+
11
+ from mini_dust3r.utils.image import load_images, ImageDict
12
+ from mini_dust3r.inference import inference, Dust3rResult
13
+ from mini_dust3r.model import AsymmetricCroCo3DStereo
14
+ from mini_dust3r.image_pairs import make_pairs
15
+ from mini_dust3r.cloud_opt import global_aligner, GlobalAlignerMode
16
+ from mini_dust3r.cloud_opt.base_opt import BasePCOptimizer
17
+ from mini_dust3r.viz import pts3d_to_trimesh, cat_meshes
18
+ from dataclasses import dataclass
19
+
20
+
21
+ @dataclass
22
+ class OptimizedResult:
23
+ K_b33: Float32[np.ndarray, "b 3 3"]
24
+ world_T_cam_b44: Float32[np.ndarray, "b 4 4"]
25
+ rgb_hw3_list: list[Float32[np.ndarray, "h w 3"]]
26
+ depth_hw_list: list[Float32[np.ndarray, "h w"]]
27
+ conf_hw_list: list[Float32[np.ndarray, "h w"]]
28
+ masks_list: Bool[np.ndarray, "h w"]
29
+ point_cloud: trimesh.PointCloud
30
+ mesh: trimesh.Trimesh
31
+
32
+
33
+ def log_optimized_result(
34
+ optimized_result: OptimizedResult, parent_log_path: Path
35
+ ) -> None:
36
+ rr.log(f"{parent_log_path}", rr.ViewCoordinates.RDF, timeless=True)
37
+ # log pointcloud
38
+ rr.log(
39
+ f"{parent_log_path}/pointcloud",
40
+ rr.Points3D(
41
+ positions=optimized_result.point_cloud.vertices,
42
+ colors=optimized_result.point_cloud.colors,
43
+ ),
44
+ timeless=True,
45
+ )
46
+
47
+ mesh = optimized_result.mesh
48
+ rr.log(
49
+ f"{parent_log_path}/mesh",
50
+ rr.Mesh3D(
51
+ vertex_positions=mesh.vertices,
52
+ vertex_colors=mesh.visual.vertex_colors,
53
+ indices=mesh.faces,
54
+ ),
55
+ timeless=True,
56
+ )
57
+ pbar = tqdm(
58
+ zip(
59
+ optimized_result.rgb_hw3_list,
60
+ optimized_result.depth_hw_list,
61
+ optimized_result.K_b33,
62
+ optimized_result.world_T_cam_b44,
63
+ ),
64
+ total=len(optimized_result.rgb_hw3_list),
65
+ )
66
+ for i, (rgb_hw3, depth_hw, k_33, world_T_cam_44) in enumerate(pbar):
67
+ camera_log_path = f"{parent_log_path}/camera_{i}"
68
+ height, width, _ = rgb_hw3.shape
69
+ rr.log(
70
+ f"{camera_log_path}",
71
+ rr.Transform3D(
72
+ translation=world_T_cam_44[:3, 3],
73
+ mat3x3=world_T_cam_44[:3, :3],
74
+ from_parent=False,
75
+ ),
76
+ )
77
+ rr.log(
78
+ f"{camera_log_path}/pinhole",
79
+ rr.Pinhole(
80
+ image_from_camera=k_33,
81
+ height=height,
82
+ width=width,
83
+ camera_xyz=rr.ViewCoordinates.RDF,
84
+ ),
85
+ )
86
+ rr.log(
87
+ f"{camera_log_path}/pinhole/rgb",
88
+ rr.Image(rgb_hw3),
89
+ )
90
+ rr.log(
91
+ f"{camera_log_path}/pinhole/depth",
92
+ rr.DepthImage(depth_hw),
93
+ )
94
+
95
+
96
+ def scene_to_results(scene: BasePCOptimizer, min_conf_thr: int) -> OptimizedResult:
97
+ ### get camera parameters K and T
98
+ K_b33: Float32[np.ndarray, "b 3 3"] = scene.get_intrinsics().numpy(force=True)
99
+ world_T_cam_b44: Float32[np.ndarray, "b 4 4"] = scene.get_im_poses().numpy(
100
+ force=True
101
+ )
102
+ ### image, confidence, depths
103
+ rgb_hw3_list: list[Float32[np.ndarray, "h w 3"]] = scene.imgs
104
+ depth_hw_list: list[Float32[np.ndarray, "h w"]] = [
105
+ depth.numpy(force=True) for depth in scene.get_depthmaps()
106
+ ]
107
+ # normalized depth
108
+ # depth_hw_list = [depth_hw / depth_hw.max() for depth_hw in depth_hw_list]
109
+
110
+ conf_hw_list: list[Float32[np.ndarray, "h w"]] = [
111
+ c.numpy(force=True) for c in scene.im_conf
112
+ ]
113
+ # normalize confidence
114
+ # conf_hw_list = [conf_hw / conf_hw.max() for conf_hw in conf_hw_list]
115
+
116
+ # point cloud, mesh
117
+ pts3d_list: list[Float32[np.ndarray, "h w 3"]] = [
118
+ pt3d.numpy(force=True) for pt3d in scene.get_pts3d()
119
+ ]
120
+ # get log confidence
121
+ log_conf_trf: Float32[torch.Tensor, ""] = scene.conf_trf(torch.tensor(min_conf_thr))
122
+ # set the minimum confidence threshold
123
+ scene.min_conf_thr = float(log_conf_trf)
124
+ masks_list: Bool[np.ndarray, "h w"] = [
125
+ mask.numpy(force=True) for mask in scene.get_masks()
126
+ ]
127
+
128
+ point_cloud: Float32[np.ndarray, "num_points 3"] = np.concatenate(
129
+ [p[m] for p, m in zip(pts3d_list, masks_list)]
130
+ )
131
+ colors: Float32[np.ndarray, "num_points 3"] = np.concatenate(
132
+ [p[m] for p, m in zip(rgb_hw3_list, masks_list)]
133
+ )
134
+ point_cloud = trimesh.PointCloud(
135
+ point_cloud.reshape(-1, 3), colors=colors.reshape(-1, 3)
136
+ )
137
+
138
+ meshes = []
139
+ pbar = tqdm(zip(rgb_hw3_list, pts3d_list, masks_list), total=len(rgb_hw3_list))
140
+ for rgb_hw3, pts3d, mask in pbar:
141
+ meshes.append(pts3d_to_trimesh(rgb_hw3, pts3d, mask))
142
+
143
+ mesh = trimesh.Trimesh(**cat_meshes(meshes))
144
+ optimised_result = OptimizedResult(
145
+ K_b33=K_b33,
146
+ world_T_cam_b44=world_T_cam_b44,
147
+ rgb_hw3_list=rgb_hw3_list,
148
+ depth_hw_list=depth_hw_list,
149
+ conf_hw_list=conf_hw_list,
150
+ masks_list=masks_list,
151
+ point_cloud=point_cloud,
152
+ mesh=mesh,
153
+ )
154
+ return optimised_result
155
+
156
+
157
+ def inferece_dust3r(
158
+ image_dir_or_list: Path | list[Path],
159
+ model: AsymmetricCroCo3DStereo,
160
+ device: Literal["cpu", "cuda", "mps"],
161
+ batch_size: int = 1,
162
+ image_size: Literal[224, 512] = 512,
163
+ niter: int = 100,
164
+ schedule: Literal["linear", "cosine"] = "linear",
165
+ min_conf_thr: float = 10,
166
+ ) -> OptimizedResult:
167
+ """
168
+ Perform inference using the Dust3r algorithm.
169
+
170
+ Args:
171
+ image_dir_or_list (Union[Path, List[Path]]): Path to the directory containing images or a list of image paths.
172
+ model (AsymmetricCroCo3DStereo): The Dust3r model to use for inference.
173
+ device (Literal["cpu", "cuda", "mps"]): The device to use for inference ("cpu", "cuda", or "mps").
174
+ batch_size (int, optional): The batch size for inference. Defaults to 1.
175
+ image_size (Literal[224, 512], optional): The size of the input images. Defaults to 512.
176
+ niter (int, optional): The number of iterations for the global alignment optimization. Defaults to 100.
177
+ schedule (Literal["linear", "cosine"], optional): The learning rate schedule for the global alignment optimization. Defaults to "linear".
178
+ min_conf_thr (float, optional): The minimum confidence threshold for the optimized result. Defaults to 10.
179
+
180
+ Returns:
181
+ OptimizedResult: The optimized result containing the RGB, depth, and confidence images.
182
+
183
+ Raises:
184
+ ValueError: If `image_dir_or_list` is neither a list of paths nor a path.
185
+ """
186
+ if isinstance(image_dir_or_list, list):
187
+ imgs: list[ImageDict] = load_images(
188
+ folder_or_list=image_dir_or_list, size=image_size, verbose=True
189
+ )
190
+ elif isinstance(image_dir_or_list, Path):
191
+ imgs: list[ImageDict] = load_images(
192
+ folder_or_list=str(image_dir_or_list), size=image_size, verbose=True
193
+ )
194
+ else:
195
+ raise ValueError("image_dir_or_list should be a list of paths or a path")
196
+
197
+ # if only one image was loaded, duplicate it to feed into stereo network
198
+ if len(imgs) == 1:
199
+ imgs = [imgs[0], copy.deepcopy(imgs[0])]
200
+ imgs[1]["idx"] = 1
201
+
202
+ pairs: list[tuple[ImageDict, ImageDict]] = make_pairs(
203
+ imgs, scene_graph="complete", prefilter=None, symmetrize=True
204
+ )
205
+ output: Dust3rResult = inference(pairs, model, device, batch_size=batch_size)
206
+
207
+ mode = (
208
+ GlobalAlignerMode.PointCloudOptimizer
209
+ if len(imgs) > 2
210
+ else GlobalAlignerMode.PairViewer
211
+ )
212
+ scene: BasePCOptimizer = global_aligner(
213
+ dust3r_output=output, device=device, mode=mode
214
+ )
215
+
216
+ lr = 0.01
217
+
218
+ if mode == GlobalAlignerMode.PointCloudOptimizer:
219
+ loss = scene.compute_global_alignment(
220
+ init="mst", niter=niter, schedule=schedule, lr=lr
221
+ )
222
+
223
+ # get the optimized result from the scene
224
+ optimized_result: OptimizedResult = scene_to_results(scene, min_conf_thr)
225
+ return optimized_result
mini_dust3r/cloud_opt/__init__.py ADDED
@@ -0,0 +1,44 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # global alignment optimization wrapper function
6
+ # --------------------------------------------------------
7
+ from enum import Enum
8
+
9
+ from .optimizer import PointCloudOptimizer
10
+ from .modular_optimizer import ModularPointCloudOptimizer
11
+ from .pair_viewer import PairViewer
12
+ from mini_dust3r.inference import Dust3rResult
13
+ from typing import Literal
14
+
15
+
16
+ class GlobalAlignerMode(Enum):
17
+ PointCloudOptimizer = "PointCloudOptimizer"
18
+ ModularPointCloudOptimizer = "ModularPointCloudOptimizer"
19
+ PairViewer = "PairViewer"
20
+
21
+
22
+ def global_aligner(
23
+ dust3r_output: Dust3rResult,
24
+ device: Literal["cpu", "cuda", "mps"],
25
+ mode: GlobalAlignerMode = GlobalAlignerMode.PointCloudOptimizer,
26
+ **optim_kw,
27
+ ):
28
+ # extract all inputs
29
+ view1, view2, pred1, pred2 = [
30
+ dust3r_output[k] for k in "view1 view2 pred1 pred2".split()
31
+ ]
32
+ # build the optimizer
33
+ if mode == GlobalAlignerMode.PointCloudOptimizer:
34
+ net = PointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(device)
35
+ elif mode == GlobalAlignerMode.ModularPointCloudOptimizer:
36
+ net = ModularPointCloudOptimizer(view1, view2, pred1, pred2, **optim_kw).to(
37
+ device
38
+ )
39
+ elif mode == GlobalAlignerMode.PairViewer:
40
+ net = PairViewer(view1, view2, pred1, pred2, **optim_kw).to(device)
41
+ else:
42
+ raise NotImplementedError(f"Unknown mode {mode}")
43
+
44
+ return net
mini_dust3r/cloud_opt/base_opt.py ADDED
@@ -0,0 +1,390 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Base class for the global alignment procedure
6
+ # --------------------------------------------------------
7
+ from copy import deepcopy
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import roma
13
+ from copy import deepcopy
14
+ import tqdm
15
+
16
+ from mini_dust3r.utils.geometry import inv, geotrf
17
+ from mini_dust3r.utils.device import to_numpy
18
+ from mini_dust3r.utils.image import rgb
19
+ from mini_dust3r.viz import SceneViz, segment_sky, auto_cam_size
20
+ from mini_dust3r.optim_factory import adjust_learning_rate_by_lr
21
+
22
+ from mini_dust3r.cloud_opt.commons import (edge_str, ALL_DISTS, NoGradParamDict, get_imshapes, signed_expm1, signed_log1p,
23
+ cosine_schedule, linear_schedule, get_conf_trf)
24
+ import mini_dust3r.cloud_opt.init_im_poses as init_fun
25
+
26
+
27
+ class BasePCOptimizer (nn.Module):
28
+ """ Optimize a global scene, given a list of pairwise observations.
29
+ Graph node: images
30
+ Graph edges: observations = (pred1, pred2)
31
+ """
32
+
33
+ def __init__(self, *args, **kwargs):
34
+ if len(args) == 1 and len(kwargs) == 0:
35
+ other = deepcopy(args[0])
36
+ attrs = '''edges is_symmetrized dist n_imgs pred_i pred_j imshapes
37
+ min_conf_thr conf_thr conf_i conf_j im_conf
38
+ base_scale norm_pw_scale POSE_DIM pw_poses
39
+ pw_adaptors pw_adaptors has_im_poses rand_pose imgs verbose'''.split()
40
+ self.__dict__.update({k: other[k] for k in attrs})
41
+ else:
42
+ self._init_from_views(*args, **kwargs)
43
+
44
+ def _init_from_views(self, view1, view2, pred1, pred2,
45
+ dist='l1',
46
+ conf='log',
47
+ min_conf_thr=3,
48
+ base_scale=0.5,
49
+ allow_pw_adaptors=False,
50
+ pw_break=20,
51
+ rand_pose=torch.randn,
52
+ iterationsCount=None,
53
+ verbose=True):
54
+ super().__init__()
55
+ if not isinstance(view1['idx'], list):
56
+ view1['idx'] = view1['idx'].tolist()
57
+ if not isinstance(view2['idx'], list):
58
+ view2['idx'] = view2['idx'].tolist()
59
+ self.edges = [(int(i), int(j)) for i, j in zip(view1['idx'], view2['idx'])]
60
+ self.is_symmetrized = set(self.edges) == {(j, i) for i, j in self.edges}
61
+ self.dist = ALL_DISTS[dist]
62
+ self.verbose = verbose
63
+
64
+ self.n_imgs = self._check_edges()
65
+
66
+ # input data
67
+ pred1_pts = pred1['pts3d']
68
+ pred2_pts = pred2['pts3d_in_other_view']
69
+ self.pred_i = NoGradParamDict({ij: pred1_pts[n] for n, ij in enumerate(self.str_edges)})
70
+ self.pred_j = NoGradParamDict({ij: pred2_pts[n] for n, ij in enumerate(self.str_edges)})
71
+ self.imshapes = get_imshapes(self.edges, pred1_pts, pred2_pts)
72
+
73
+ # work in log-scale with conf
74
+ pred1_conf = pred1['conf']
75
+ pred2_conf = pred2['conf']
76
+ self.min_conf_thr = min_conf_thr
77
+ self.conf_trf = get_conf_trf(conf)
78
+
79
+ self.conf_i = NoGradParamDict({ij: pred1_conf[n] for n, ij in enumerate(self.str_edges)})
80
+ self.conf_j = NoGradParamDict({ij: pred2_conf[n] for n, ij in enumerate(self.str_edges)})
81
+ self.im_conf = self._compute_img_conf(pred1_conf, pred2_conf)
82
+
83
+ # pairwise pose parameters
84
+ self.base_scale = base_scale
85
+ self.norm_pw_scale = True
86
+ self.pw_break = pw_break
87
+ self.POSE_DIM = 7
88
+ self.pw_poses = nn.Parameter(rand_pose((self.n_edges, 1+self.POSE_DIM))) # pairwise poses
89
+ self.pw_adaptors = nn.Parameter(torch.zeros((self.n_edges, 2))) # slight xy/z adaptation
90
+ self.pw_adaptors.requires_grad_(allow_pw_adaptors)
91
+ self.has_im_poses = False
92
+ self.rand_pose = rand_pose
93
+
94
+ # possibly store images for show_pointcloud
95
+ self.imgs = None
96
+ if 'img' in view1 and 'img' in view2:
97
+ imgs = [torch.zeros((3,)+hw) for hw in self.imshapes]
98
+ for v in range(len(self.edges)):
99
+ idx = view1['idx'][v]
100
+ imgs[idx] = view1['img'][v]
101
+ idx = view2['idx'][v]
102
+ imgs[idx] = view2['img'][v]
103
+ self.imgs = rgb(imgs)
104
+
105
+ @property
106
+ def n_edges(self):
107
+ return len(self.edges)
108
+
109
+ @property
110
+ def str_edges(self):
111
+ return [edge_str(i, j) for i, j in self.edges]
112
+
113
+ @property
114
+ def imsizes(self):
115
+ return [(w, h) for h, w in self.imshapes]
116
+
117
+ @property
118
+ def device(self):
119
+ return next(iter(self.parameters())).device
120
+
121
+ def state_dict(self, trainable=True):
122
+ all_params = super().state_dict()
123
+ return {k: v for k, v in all_params.items() if k.startswith(('_', 'pred_i.', 'pred_j.', 'conf_i.', 'conf_j.')) != trainable}
124
+
125
+ def load_state_dict(self, data):
126
+ return super().load_state_dict(self.state_dict(trainable=False) | data)
127
+
128
+ def _check_edges(self):
129
+ indices = sorted({i for edge in self.edges for i in edge})
130
+ assert indices == list(range(len(indices))), 'bad pair indices: missing values '
131
+ return len(indices)
132
+
133
+ @torch.no_grad()
134
+ def _compute_img_conf(self, pred1_conf, pred2_conf):
135
+ im_conf = nn.ParameterList([torch.zeros(hw, device=self.device) for hw in self.imshapes])
136
+ for e, (i, j) in enumerate(self.edges):
137
+ im_conf[i] = torch.maximum(im_conf[i], pred1_conf[e])
138
+ im_conf[j] = torch.maximum(im_conf[j], pred2_conf[e])
139
+ return im_conf
140
+
141
+ def get_adaptors(self):
142
+ adapt = self.pw_adaptors
143
+ adapt = torch.cat((adapt[:, 0:1], adapt), dim=-1) # (scale_xy, scale_xy, scale_z)
144
+ if self.norm_pw_scale: # normalize so that the product == 1
145
+ adapt = adapt - adapt.mean(dim=1, keepdim=True)
146
+ return (adapt / self.pw_break).exp()
147
+
148
+ def _get_poses(self, poses):
149
+ # normalize rotation
150
+ Q = poses[:, :4]
151
+ T = signed_expm1(poses[:, 4:7])
152
+ RT = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous()
153
+ return RT
154
+
155
+ def _set_pose(self, poses, idx, R, T=None, scale=None, force=False):
156
+ # all poses == cam-to-world
157
+ pose = poses[idx]
158
+ if not (pose.requires_grad or force):
159
+ return pose
160
+
161
+ if R.shape == (4, 4):
162
+ assert T is None
163
+ T = R[:3, 3]
164
+ R = R[:3, :3]
165
+
166
+ if R is not None:
167
+ pose.data[0:4] = roma.rotmat_to_unitquat(R)
168
+ if T is not None:
169
+ pose.data[4:7] = signed_log1p(T / (scale or 1)) # translation is function of scale
170
+
171
+ if scale is not None:
172
+ assert poses.shape[-1] in (8, 13)
173
+ pose.data[-1] = np.log(float(scale))
174
+ return pose
175
+
176
+ def get_pw_norm_scale_factor(self):
177
+ if self.norm_pw_scale:
178
+ # normalize scales so that things cannot go south
179
+ # we want that exp(scale) ~= self.base_scale
180
+ return (np.log(self.base_scale) - self.pw_poses[:, -1].mean()).exp()
181
+ else:
182
+ return 1 # don't norm scale for known poses
183
+
184
+ def get_pw_scale(self):
185
+ scale = self.pw_poses[:, -1].exp() # (n_edges,)
186
+ scale = scale * self.get_pw_norm_scale_factor()
187
+ return scale
188
+
189
+ def get_pw_poses(self): # cam to world
190
+ RT = self._get_poses(self.pw_poses)
191
+ scaled_RT = RT.clone()
192
+ scaled_RT[:, :3] *= self.get_pw_scale().view(-1, 1, 1) # scale the rotation AND translation
193
+ return scaled_RT
194
+
195
+ def get_masks(self):
196
+ return [(conf > self.min_conf_thr) for conf in self.im_conf]
197
+
198
+ def depth_to_pts3d(self):
199
+ raise NotImplementedError()
200
+
201
+ def get_pts3d(self, raw=False):
202
+ res = self.depth_to_pts3d()
203
+ if not raw:
204
+ res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
205
+ return res
206
+
207
+ def _set_focal(self, idx, focal, force=False):
208
+ raise NotImplementedError()
209
+
210
+ def get_focals(self):
211
+ raise NotImplementedError()
212
+
213
+ def get_known_focal_mask(self):
214
+ raise NotImplementedError()
215
+
216
+ def get_principal_points(self):
217
+ raise NotImplementedError()
218
+
219
+ def get_conf(self, mode=None):
220
+ trf = self.conf_trf if mode is None else get_conf_trf(mode)
221
+ return [trf(c) for c in self.im_conf]
222
+
223
+ def get_im_poses(self):
224
+ raise NotImplementedError()
225
+
226
+ def _set_depthmap(self, idx, depth, force=False):
227
+ raise NotImplementedError()
228
+
229
+ def get_depthmaps(self, raw=False):
230
+ raise NotImplementedError()
231
+
232
+ @torch.no_grad()
233
+ def clean_pointcloud(self, tol=0.001, max_bad_conf=0):
234
+ """ Method:
235
+ 1) express all 3d points in each camera coordinate frame
236
+ 2) if they're in front of a depthmap --> then lower their confidence
237
+ """
238
+ assert 0 <= tol < 1
239
+ cams = inv(self.get_im_poses())
240
+ K = self.get_intrinsics()
241
+ depthmaps = self.get_depthmaps()
242
+ res = deepcopy(self)
243
+
244
+ for i, pts3d in enumerate(self.depth_to_pts3d()):
245
+ for j in range(self.n_imgs):
246
+ if i == j:
247
+ continue
248
+
249
+ # project 3dpts in other view
250
+ Hi, Wi = self.imshapes[i]
251
+ Hj, Wj = self.imshapes[j]
252
+ proj = geotrf(cams[j], pts3d[:Hi*Wi]).reshape(Hi, Wi, 3)
253
+ proj_depth = proj[:, :, 2]
254
+ u, v = geotrf(K[j], proj, norm=1, ncol=2).round().long().unbind(-1)
255
+
256
+ # check which points are actually in the visible cone
257
+ msk_i = (proj_depth > 0) & (0 <= u) & (u < Wj) & (0 <= v) & (v < Hj)
258
+ msk_j = v[msk_i], u[msk_i]
259
+
260
+ # find bad points = those in front but less confident
261
+ bad_points = (proj_depth[msk_i] < (1-tol) * depthmaps[j][msk_j]
262
+ ) & (res.im_conf[i][msk_i] < res.im_conf[j][msk_j])
263
+
264
+ bad_msk_i = msk_i.clone()
265
+ bad_msk_i[msk_i] = bad_points
266
+ res.im_conf[i][bad_msk_i] = res.im_conf[i][bad_msk_i].clip_(max=max_bad_conf)
267
+
268
+ return res
269
+
270
+ def forward(self, ret_details=False):
271
+ pw_poses = self.get_pw_poses() # cam-to-world
272
+ pw_adapt = self.get_adaptors()
273
+ proj_pts3d = self.get_pts3d()
274
+ # pre-compute pixel weights
275
+ weight_i = {i_j: self.conf_trf(c) for i_j, c in self.conf_i.items()}
276
+ weight_j = {i_j: self.conf_trf(c) for i_j, c in self.conf_j.items()}
277
+
278
+ loss = 0
279
+ if ret_details:
280
+ details = -torch.ones((self.n_imgs, self.n_imgs))
281
+
282
+ for e, (i, j) in enumerate(self.edges):
283
+ i_j = edge_str(i, j)
284
+ # distance in image i and j
285
+ aligned_pred_i = geotrf(pw_poses[e], pw_adapt[e] * self.pred_i[i_j])
286
+ aligned_pred_j = geotrf(pw_poses[e], pw_adapt[e] * self.pred_j[i_j])
287
+ li = self.dist(proj_pts3d[i], aligned_pred_i, weight=weight_i[i_j]).mean()
288
+ lj = self.dist(proj_pts3d[j], aligned_pred_j, weight=weight_j[i_j]).mean()
289
+ loss = loss + li + lj
290
+
291
+ if ret_details:
292
+ details[i, j] = li + lj
293
+ loss /= self.n_edges # average over all pairs
294
+
295
+ if ret_details:
296
+ return loss, details
297
+ return loss
298
+
299
+ @torch.cuda.amp.autocast(enabled=False)
300
+ def compute_global_alignment(self, init=None, niter_PnP=10, **kw):
301
+ if init is None:
302
+ pass
303
+ elif init == 'msp' or init == 'mst':
304
+ init_fun.init_minimum_spanning_tree(self, niter_PnP=niter_PnP)
305
+ elif init == 'known_poses':
306
+ init_fun.init_from_known_poses(self, min_conf_thr=self.min_conf_thr,
307
+ niter_PnP=niter_PnP)
308
+ else:
309
+ raise ValueError(f'bad value for {init=}')
310
+
311
+ return global_alignment_loop(self, **kw)
312
+
313
+ @torch.no_grad()
314
+ def mask_sky(self):
315
+ res = deepcopy(self)
316
+ for i in range(self.n_imgs):
317
+ sky = segment_sky(self.imgs[i])
318
+ res.im_conf[i][sky] = 0
319
+ return res
320
+
321
+ def show(self, show_pw_cams=False, show_pw_pts3d=False, cam_size=None, **kw):
322
+ viz = SceneViz()
323
+ if self.imgs is None:
324
+ colors = np.random.randint(0, 256, size=(self.n_imgs, 3))
325
+ colors = list(map(tuple, colors.tolist()))
326
+ for n in range(self.n_imgs):
327
+ viz.add_pointcloud(self.get_pts3d()[n], colors[n], self.get_masks()[n])
328
+ else:
329
+ viz.add_pointcloud(self.get_pts3d(), self.imgs, self.get_masks())
330
+ colors = np.random.randint(256, size=(self.n_imgs, 3))
331
+
332
+ # camera poses
333
+ im_poses = to_numpy(self.get_im_poses())
334
+ if cam_size is None:
335
+ cam_size = auto_cam_size(im_poses)
336
+ viz.add_cameras(im_poses, self.get_focals(), colors=colors,
337
+ images=self.imgs, imsizes=self.imsizes, cam_size=cam_size)
338
+ if show_pw_cams:
339
+ pw_poses = self.get_pw_poses()
340
+ viz.add_cameras(pw_poses, color=(192, 0, 192), cam_size=cam_size)
341
+
342
+ if show_pw_pts3d:
343
+ pts = [geotrf(pw_poses[e], self.pred_i[edge_str(i, j)]) for e, (i, j) in enumerate(self.edges)]
344
+ viz.add_pointcloud(pts, (128, 0, 128))
345
+
346
+ viz.show(**kw)
347
+ return viz
348
+
349
+
350
+ def global_alignment_loop(net, lr=0.01, niter=300, schedule='cosine', lr_min=1e-6):
351
+ params = [p for p in net.parameters() if p.requires_grad]
352
+ if not params:
353
+ return net
354
+
355
+ verbose = net.verbose
356
+ if verbose:
357
+ print('Global alignment - optimizing for:')
358
+ print([name for name, value in net.named_parameters() if value.requires_grad])
359
+
360
+ lr_base = lr
361
+ optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.9))
362
+
363
+ loss = float('inf')
364
+ if verbose:
365
+ with tqdm.tqdm(total=niter) as bar:
366
+ while bar.n < bar.total:
367
+ loss = global_alignment_iter(net, bar.n, niter, lr_base, lr_min, optimizer, schedule)
368
+ bar.set_postfix_str(f'{lr=:g} loss={loss:g}')
369
+ bar.update()
370
+ else:
371
+ for n in range(niter):
372
+ loss = global_alignment_iter(net, n, niter, lr_base, lr_min, optimizer, schedule)
373
+ return loss
374
+
375
+
376
+ def global_alignment_iter(net, cur_iter, niter, lr_base, lr_min, optimizer, schedule):
377
+ t = cur_iter / niter
378
+ if schedule == 'cosine':
379
+ lr = cosine_schedule(t, lr_base, lr_min)
380
+ elif schedule == 'linear':
381
+ lr = linear_schedule(t, lr_base, lr_min)
382
+ else:
383
+ raise ValueError(f'bad lr {schedule=}')
384
+ adjust_learning_rate_by_lr(optimizer, lr)
385
+ optimizer.zero_grad()
386
+ loss = net()
387
+ loss.backward()
388
+ optimizer.step()
389
+
390
+ return float(loss)
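Note on the loop above: global_alignment_loop is a plain Adam loop whose learning rate is recomputed every iteration from the chosen schedule before stepping. A minimal, self-contained sketch of the same pattern on a toy module (ToyScene, the quadratic loss and the hyper-parameters are illustrative stand-ins, not part of mini_dust3r):

import math
import torch
import torch.nn as nn

def cosine_lr(t, lr_start, lr_end):
    # same formula as cosine_schedule() in cloud_opt/commons.py
    return lr_end + (lr_start - lr_end) * (1 + math.cos(t * math.pi)) / 2

class ToyScene(nn.Module):
    # stand-in for the scene optimizer: calling it returns a scalar loss
    def __init__(self):
        super().__init__()
        self.x = nn.Parameter(torch.randn(10))

    def forward(self):
        return (self.x ** 2).mean()

net = ToyScene()
optimizer = torch.optim.Adam(net.parameters(), lr=0.01, betas=(0.9, 0.9))
niter, lr_base, lr_min = 300, 0.01, 1e-6

for n in range(niter):
    lr = cosine_lr(n / niter, lr_base, lr_min)
    for group in optimizer.param_groups:   # what adjust_learning_rate_by_lr does
        group['lr'] = lr
    optimizer.zero_grad()
    loss = net()
    loss.backward()
    optimizer.step()

print('final loss:', float(loss))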
mini_dust3r/cloud_opt/commons.py ADDED
@@ -0,0 +1,90 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utility functions for global alignment
6
+ # --------------------------------------------------------
7
+ import torch
8
+ import torch.nn as nn
9
+ import numpy as np
10
+
11
+
12
+ def edge_str(i, j):
13
+ return f'{i}_{j}'
14
+
15
+
16
+ def i_j_ij(ij):
17
+ return edge_str(*ij), ij
18
+
19
+
20
+ def edge_conf(conf_i, conf_j, edge):
21
+ return float(conf_i[edge].mean() * conf_j[edge].mean())
22
+
23
+
24
+ def compute_edge_scores(edges, conf_i, conf_j):
25
+ return {(i, j): edge_conf(conf_i, conf_j, e) for e, (i, j) in edges}
26
+
27
+
28
+ def NoGradParamDict(x):
29
+ assert isinstance(x, dict)
30
+ return nn.ParameterDict(x).requires_grad_(False)
31
+
32
+
33
+ def get_imshapes(edges, pred_i, pred_j):
34
+ n_imgs = max(max(e) for e in edges) + 1
35
+ imshapes = [None] * n_imgs
36
+ for e, (i, j) in enumerate(edges):
37
+ shape_i = tuple(pred_i[e].shape[0:2])
38
+ shape_j = tuple(pred_j[e].shape[0:2])
39
+ if imshapes[i]:
40
+ assert imshapes[i] == shape_i, f'incorrect shape for image {i}'
41
+ if imshapes[j]:
42
+ assert imshapes[j] == shape_j, f'incorrect shape for image {j}'
43
+ imshapes[i] = shape_i
44
+ imshapes[j] = shape_j
45
+ return imshapes
46
+
47
+
48
+ def get_conf_trf(mode):
49
+ if mode == 'log':
50
+ def conf_trf(x): return x.log()
51
+ elif mode == 'sqrt':
52
+ def conf_trf(x): return x.sqrt()
53
+ elif mode == 'm1':
54
+ def conf_trf(x): return x-1
55
+ elif mode in ('id', 'none'):
56
+ def conf_trf(x): return x
57
+ else:
58
+ raise ValueError(f'bad mode for {mode=}')
59
+ return conf_trf
60
+
61
+
62
+ def l2_dist(a, b, weight):
63
+ return ((a - b).square().sum(dim=-1) * weight)
64
+
65
+
66
+ def l1_dist(a, b, weight):
67
+ return ((a - b).norm(dim=-1) * weight)
68
+
69
+
70
+ ALL_DISTS = dict(l1=l1_dist, l2=l2_dist)
71
+
72
+
73
+ def signed_log1p(x):
74
+ sign = torch.sign(x)
75
+ return sign * torch.log1p(torch.abs(x))
76
+
77
+
78
+ def signed_expm1(x):
79
+ sign = torch.sign(x)
80
+ return sign * torch.expm1(torch.abs(x))
81
+
82
+
83
+ def cosine_schedule(t, lr_start, lr_end):
84
+ assert 0 <= t <= 1
85
+ return lr_end + (lr_start - lr_end) * (1+np.cos(t * np.pi))/2
86
+
87
+
88
+ def linear_schedule(t, lr_start, lr_end):
89
+ assert 0 <= t <= 1
90
+ return lr_start + (lr_end - lr_start) * t
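The two schedules above interpolate between lr_start and lr_end as t runs from 0 to 1, and signed_log1p compresses values symmetrically around zero. A quick standalone check (the numeric values are chosen only for illustration):

import numpy as np
import torch

def cosine_schedule(t, lr_start, lr_end):
    return lr_end + (lr_start - lr_end) * (1 + np.cos(t * np.pi)) / 2

def linear_schedule(t, lr_start, lr_end):
    return lr_start + (lr_end - lr_start) * t

for t in (0.0, 0.5, 1.0):
    print(t, cosine_schedule(t, 0.01, 1e-6), linear_schedule(t, 0.01, 1e-6))
# both start at 0.01 and end at 1e-6; the cosine curve decays more slowly at first

x = torch.tensor([-3.0, 0.0, 3.0])
print(torch.sign(x) * torch.log1p(torch.abs(x)))   # signed_log1p: symmetric log compression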
mini_dust3r/cloud_opt/init_im_poses.py ADDED
@@ -0,0 +1,316 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Initialization functions for global alignment
6
+ # --------------------------------------------------------
7
+ from functools import cache
8
+
9
+ import numpy as np
10
+ import scipy.sparse as sp
11
+ import torch
12
+ import cv2
13
+ import roma
14
+ from tqdm import tqdm
15
+
16
+ from mini_dust3r.utils.geometry import geotrf, inv, get_med_dist_between_poses
17
+ from mini_dust3r.post_process import estimate_focal_knowing_depth
18
+ from mini_dust3r.viz import to_numpy
19
+
20
+ from mini_dust3r.cloud_opt.commons import edge_str, i_j_ij, compute_edge_scores
21
+
22
+
23
+ @torch.no_grad()
24
+ def init_from_known_poses(self, niter_PnP=10, min_conf_thr=3):
25
+ device = self.device
26
+
27
+ # indices of known poses
28
+ nkp, known_poses_msk, known_poses = get_known_poses(self)
29
+ assert nkp == self.n_imgs, 'not all poses are known'
30
+
31
+ # get all focals
32
+ nkf, _, im_focals = get_known_focals(self)
33
+ assert nkf == self.n_imgs
34
+ im_pp = self.get_principal_points()
35
+
36
+ best_depthmaps = {}
37
+ # init all pairwise poses
38
+ for e, (i, j) in enumerate(tqdm(self.edges, disable=not self.verbose)):
39
+ i_j = edge_str(i, j)
40
+
41
+ # find relative pose for this pair
42
+ P1 = torch.eye(4, device=device)
43
+ msk = self.conf_i[i_j] > min(min_conf_thr, self.conf_i[i_j].min() - 0.1)
44
+ _, P2 = fast_pnp(self.pred_j[i_j], float(im_focals[i].mean()),
45
+ pp=im_pp[i], msk=msk, device=device, niter_PnP=niter_PnP)
46
+
47
+ # align the two predicted camera with the two gt cameras
48
+ s, R, T = align_multiple_poses(torch.stack((P1, P2)), known_poses[[i, j]])
49
+ # normally we have known_poses[i] ~= sRT_to_4x4(s,R,T,device) @ P1
50
+ # and geotrf(sRT_to_4x4(1,R,T,device), s*P2[:3,3])
51
+ self._set_pose(self.pw_poses, e, R, T, scale=s)
52
+
53
+ # remember if this is a good depthmap
54
+ score = float(self.conf_i[i_j].mean())
55
+ if score > best_depthmaps.get(i, (0,))[0]:
56
+ best_depthmaps[i] = score, i_j, s
57
+
58
+ # init all image poses
59
+ for n in range(self.n_imgs):
60
+ assert known_poses_msk[n]
61
+ _, i_j, scale = best_depthmaps[n]
62
+ depth = self.pred_i[i_j][:, :, 2]
63
+ self._set_depthmap(n, depth * scale)
64
+
65
+
66
+ @torch.no_grad()
67
+ def init_minimum_spanning_tree(self, **kw):
68
+ """ Init all camera poses (image-wise and pairwise poses) given
69
+ an initial set of pairwise estimations.
70
+ """
71
+ device = self.device
72
+ pts3d, _, im_focals, im_poses = minimum_spanning_tree(self.imshapes, self.edges,
73
+ self.pred_i, self.pred_j, self.conf_i, self.conf_j, self.im_conf, self.min_conf_thr,
74
+ device, has_im_poses=self.has_im_poses, verbose=self.verbose,
75
+ **kw)
76
+
77
+ return init_from_pts3d(self, pts3d, im_focals, im_poses)
78
+
79
+
80
+ def init_from_pts3d(self, pts3d, im_focals, im_poses):
81
+ # init poses
82
+ nkp, known_poses_msk, known_poses = get_known_poses(self)
83
+ if nkp == 1:
84
+ raise NotImplementedError("Would be simpler to just align everything afterwards on the single known pose")
85
+ elif nkp > 1:
86
+ # global rigid SE3 alignment
87
+ s, R, T = align_multiple_poses(im_poses[known_poses_msk], known_poses[known_poses_msk])
88
+ trf = sRT_to_4x4(s, R, T, device=known_poses.device)
89
+
90
+ # rotate everything
91
+ im_poses = trf @ im_poses
92
+ im_poses[:, :3, :3] /= s # undo scaling on the rotation part
93
+ for img_pts3d in pts3d:
94
+ img_pts3d[:] = geotrf(trf, img_pts3d)
95
+
96
+ # set all pairwise poses
97
+ for e, (i, j) in enumerate(self.edges):
98
+ i_j = edge_str(i, j)
99
+ # compute transform that goes from cam to world
100
+ s, R, T = rigid_points_registration(self.pred_i[i_j], pts3d[i], conf=self.conf_i[i_j])
101
+ self._set_pose(self.pw_poses, e, R, T, scale=s)
102
+
103
+ # take into account the scale normalization
104
+ s_factor = self.get_pw_norm_scale_factor()
105
+ im_poses[:, :3, 3] *= s_factor # apply downscaling factor
106
+ for img_pts3d in pts3d:
107
+ img_pts3d *= s_factor
108
+
109
+ # init all image poses
110
+ if self.has_im_poses:
111
+ for i in range(self.n_imgs):
112
+ cam2world = im_poses[i]
113
+ depth = geotrf(inv(cam2world), pts3d[i])[..., 2]
114
+ self._set_depthmap(i, depth)
115
+ self._set_pose(self.im_poses, i, cam2world)
116
+ if im_focals[i] is not None:
117
+ self._set_focal(i, im_focals[i])
118
+
119
+ if self.verbose:
120
+ print(' init loss =', float(self()))
121
+
122
+
123
+ def minimum_spanning_tree(imshapes, edges, pred_i, pred_j, conf_i, conf_j, im_conf, min_conf_thr,
124
+ device, has_im_poses=True, niter_PnP=10, verbose=True):
125
+ n_imgs = len(imshapes)
126
+ sparse_graph = -dict_to_sparse_graph(compute_edge_scores(map(i_j_ij, edges), conf_i, conf_j))
127
+ msp = sp.csgraph.minimum_spanning_tree(sparse_graph).tocoo()
128
+
129
+ # temp variable to store 3d points
130
+ pts3d = [None] * len(imshapes)
131
+
132
+ todo = sorted(zip(-msp.data, msp.row, msp.col)) # sorted edges
133
+ im_poses = [None] * n_imgs
134
+ im_focals = [None] * n_imgs
135
+
136
+ # init with strongest edge
137
+ score, i, j = todo.pop()
138
+ if verbose:
139
+ print(f' init edge ({i}*,{j}*) {score=}')
140
+ i_j = edge_str(i, j)
141
+ pts3d[i] = pred_i[i_j].clone()
142
+ pts3d[j] = pred_j[i_j].clone()
143
+ done = {i, j}
144
+ if has_im_poses:
145
+ im_poses[i] = torch.eye(4, device=device)
146
+ im_focals[i] = estimate_focal(pred_i[i_j])
147
+
148
+ # set initial pointcloud based on pairwise graph
149
+ msp_edges = [(i, j)]
150
+ while todo:
151
+ # each time, predict the next one
152
+ score, i, j = todo.pop()
153
+
154
+ if im_focals[i] is None:
155
+ im_focals[i] = estimate_focal(pred_i[i_j])
156
+
157
+ if i in done:
158
+ if verbose:
159
+ print(f' init edge ({i},{j}*) {score=}')
160
+ assert j not in done
161
+ # align pred[i] with pts3d[i], and then set j accordingly
162
+ i_j = edge_str(i, j)
163
+ s, R, T = rigid_points_registration(pred_i[i_j], pts3d[i], conf=conf_i[i_j])
164
+ trf = sRT_to_4x4(s, R, T, device)
165
+ pts3d[j] = geotrf(trf, pred_j[i_j])
166
+ done.add(j)
167
+ msp_edges.append((i, j))
168
+
169
+ if has_im_poses and im_poses[i] is None:
170
+ im_poses[i] = sRT_to_4x4(1, R, T, device)
171
+
172
+ elif j in done:
173
+ if verbose:
174
+ print(f' init edge ({i}*,{j}) {score=}')
175
+ assert i not in done
176
+ i_j = edge_str(i, j)
177
+ s, R, T = rigid_points_registration(pred_j[i_j], pts3d[j], conf=conf_j[i_j])
178
+ trf = sRT_to_4x4(s, R, T, device)
179
+ pts3d[i] = geotrf(trf, pred_i[i_j])
180
+ done.add(i)
181
+ msp_edges.append((i, j))
182
+
183
+ if has_im_poses and im_poses[i] is None:
184
+ im_poses[i] = sRT_to_4x4(1, R, T, device)
185
+ else:
186
+ # let's try again later
187
+ todo.insert(0, (score, i, j))
188
+
189
+ if has_im_poses:
190
+ # complete all missing information
191
+ pair_scores = list(sparse_graph.values()) # already negative scores: less is best
192
+ edges_from_best_to_worse = np.array(list(sparse_graph.keys()))[np.argsort(pair_scores)]
193
+ for i, j in edges_from_best_to_worse.tolist():
194
+ if im_focals[i] is None:
195
+ im_focals[i] = estimate_focal(pred_i[edge_str(i, j)])
196
+
197
+ for i in range(n_imgs):
198
+ if im_poses[i] is None:
199
+ msk = im_conf[i] > min_conf_thr
200
+ res = fast_pnp(pts3d[i], im_focals[i], msk=msk, device=device, niter_PnP=niter_PnP)
201
+ if res:
202
+ im_focals[i], im_poses[i] = res
203
+ if im_poses[i] is None:
204
+ im_poses[i] = torch.eye(4, device=device)
205
+ im_poses = torch.stack(im_poses)
206
+ else:
207
+ im_poses = im_focals = None
208
+
209
+ return pts3d, msp_edges, im_focals, im_poses
210
+
211
+
212
+ def dict_to_sparse_graph(dic):
213
+ n_imgs = max(max(e) for e in dic) + 1
214
+ res = sp.dok_array((n_imgs, n_imgs))
215
+ for edge, value in dic.items():
216
+ res[edge] = value
217
+ return res
218
+
219
+
220
+ def rigid_points_registration(pts1, pts2, conf):
221
+ R, T, s = roma.rigid_points_registration(
222
+ pts1.reshape(-1, 3), pts2.reshape(-1, 3), weights=conf.ravel(), compute_scaling=True)
223
+ return s, R, T # return un-scaled (R, T)
224
+
225
+
226
+ def sRT_to_4x4(scale, R, T, device):
227
+ trf = torch.eye(4, device=device)
228
+ trf[:3, :3] = R * scale
229
+ trf[:3, 3] = T.ravel() # doesn't need scaling
230
+ return trf
231
+
232
+
233
+ def estimate_focal(pts3d_i, pp=None):
234
+ if pp is None:
235
+ H, W, THREE = pts3d_i.shape
236
+ assert THREE == 3
237
+ pp = torch.tensor((W/2, H/2), device=pts3d_i.device)
238
+ focal = estimate_focal_knowing_depth(pts3d_i.unsqueeze(0), pp.unsqueeze(0), focal_mode='weiszfeld').ravel()
239
+ return float(focal)
240
+
241
+
242
+ @cache
243
+ def pixel_grid(H, W):
244
+ return np.mgrid[:W, :H].T.astype(np.float32)
245
+
246
+
247
+ def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
248
+ # extract camera poses and focals with RANSAC-PnP
249
+ if msk.sum() < 4:
250
+ return None # we need at least 4 points for PnP
251
+ pts3d, msk = map(to_numpy, (pts3d, msk))
252
+
253
+ H, W, THREE = pts3d.shape
254
+ assert THREE == 3
255
+ pixels = pixel_grid(H, W)
256
+
257
+ if focal is None:
258
+ S = max(W, H)
259
+ tentative_focals = np.geomspace(S/2, S*3, 21)
260
+ else:
261
+ tentative_focals = [focal]
262
+
263
+ if pp is None:
264
+ pp = (W/2, H/2)
265
+ else:
266
+ pp = to_numpy(pp)
267
+
268
+ best = 0,
269
+ for focal in tentative_focals:
270
+ K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])
271
+
272
+ success, R, T, inliers = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
273
+ iterationsCount=niter_PnP, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
274
+ if not success:
275
+ continue
276
+
277
+ score = len(inliers)
278
+ if success and score > best[0]:
279
+ best = score, R, T, focal
280
+
281
+ if not best[0]:
282
+ return None
283
+
284
+ _, R, T, best_focal = best
285
+ R = cv2.Rodrigues(R)[0] # world to cam
286
+ R, T = map(torch.from_numpy, (R, T))
287
+ return best_focal, inv(sRT_to_4x4(1, R, T, device)) # cam to world
288
+
289
+
290
+ def get_known_poses(self):
291
+ if self.has_im_poses:
292
+ known_poses_msk = torch.tensor([not (p.requires_grad) for p in self.im_poses])
293
+ known_poses = self.get_im_poses()
294
+ return known_poses_msk.sum(), known_poses_msk, known_poses
295
+ else:
296
+ return 0, None, None
297
+
298
+
299
+ def get_known_focals(self):
300
+ if self.has_im_poses:
301
+ known_focal_msk = self.get_known_focal_mask()
302
+ known_focals = self.get_focals()
303
+ return known_focal_msk.sum(), known_focal_msk, known_focals
304
+ else:
305
+ return 0, None, None
306
+
307
+
308
+ def align_multiple_poses(src_poses, target_poses):
309
+ N = len(src_poses)
310
+ assert src_poses.shape == target_poses.shape == (N, 4, 4)
311
+
312
+ def center_and_z(poses):
313
+ eps = get_med_dist_between_poses(poses) / 100
314
+ return torch.cat((poses[:, :3, 3], poses[:, :3, 3] + eps*poses[:, :3, 2]))
315
+ R, T, s = roma.rigid_points_registration(center_and_z(src_poses), center_and_z(target_poses), compute_scaling=True)
316
+ return s, R, T
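sRT_to_4x4 above packs a similarity transform (scale folded into the rotation block, unscaled translation) into a single 4x4 matrix. A small standalone check with toy numbers (the rotation, translation and points below are illustrative, not part of the commit):

import torch

def sRT_to_4x4(scale, R, T, device):
    trf = torch.eye(4, device=device)
    trf[:3, :3] = R * scale
    trf[:3, 3] = T.ravel()
    return trf

R = torch.eye(3)                              # identity rotation for the example
T = torch.tensor([1.0, 2.0, 3.0])
trf = sRT_to_4x4(2.0, R, T, device='cpu')

pts = torch.tensor([[0.0, 0.0, 0.0],
                    [1.0, 1.0, 1.0]])
pts_h = torch.cat([pts, torch.ones(len(pts), 1)], dim=1)   # homogeneous coordinates
out = (trf @ pts_h.T).T[:, :3]                             # apply the 4x4 transform to the points
print(out)   # [[1, 2, 3], [3, 4, 5]] = 2 * p + T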
mini_dust3r/cloud_opt/modular_optimizer.py ADDED
@@ -0,0 +1,145 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Slower implementation of the global alignment that allows to freeze partial poses/intrinsics
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from mini_dust3r.cloud_opt.base_opt import BasePCOptimizer
12
+ from mini_dust3r.utils.geometry import geotrf
13
+ from mini_dust3r.utils.device import to_cpu, to_numpy
14
+ from mini_dust3r.utils.geometry import depthmap_to_pts3d
15
+
16
+
17
+ class ModularPointCloudOptimizer (BasePCOptimizer):
18
+ """ Optimize a global scene, given a list of pairwise observations.
19
+ Unlike PointCloudOptimizer, you can fix parts of the optimization process (partial poses/intrinsics)
20
+ Graph node: images
21
+ Graph edges: observations = (pred1, pred2)
22
+ """
23
+
24
+ def __init__(self, *args, optimize_pp=False, fx_and_fy=False, focal_brake=20, **kwargs):
25
+ super().__init__(*args, **kwargs)
26
+ self.has_im_poses = True # by definition of this class
27
+ self.focal_brake = focal_brake
28
+
29
+ # adding thing to optimize
30
+ self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes) # log(depth)
31
+ self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs)) # camera poses
32
+ default_focals = [self.focal_brake * np.log(max(H, W)) for H, W in self.imshapes]
33
+ self.im_focals = nn.ParameterList(torch.FloatTensor([f, f] if fx_and_fy else [
34
+ f]) for f in default_focals) # camera intrinsics
35
+ self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs)) # camera intrinsics
36
+ self.im_pp.requires_grad_(optimize_pp)
37
+
38
+ def preset_pose(self, known_poses, pose_msk=None): # cam-to-world
39
+ if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:
40
+ known_poses = [known_poses]
41
+ for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
42
+ if self.verbose:
43
+ print(f' (setting pose #{idx} = {pose[:3,3]})')
44
+ self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose), force=True))
45
+
46
+ # normalize scale if there's at most one known pose
47
+ n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
48
+ self.norm_pw_scale = (n_known_poses <= 1)
49
+
50
+ def preset_intrinsics(self, known_intrinsics, msk=None):
51
+ if isinstance(known_intrinsics, torch.Tensor) and known_intrinsics.ndim == 2:
52
+ known_intrinsics = [known_intrinsics]
53
+ for K in known_intrinsics:
54
+ assert K.shape == (3, 3)
55
+ self.preset_focal([K.diagonal()[:2].mean() for K in known_intrinsics], msk)
56
+ self.preset_principal_point([K[:2, 2] for K in known_intrinsics], msk)
57
+
58
+ def preset_focal(self, known_focals, msk=None):
59
+ for idx, focal in zip(self._get_msk_indices(msk), known_focals):
60
+ if self.verbose:
61
+ print(f' (setting focal #{idx} = {focal})')
62
+ self._no_grad(self._set_focal(idx, focal, force=True))
63
+
64
+ def preset_principal_point(self, known_pp, msk=None):
65
+ for idx, pp in zip(self._get_msk_indices(msk), known_pp):
66
+ if self.verbose:
67
+ print(f' (setting principal point #{idx} = {pp})')
68
+ self._no_grad(self._set_principal_point(idx, pp, force=True))
69
+
70
+ def _no_grad(self, tensor):
71
+ return tensor.requires_grad_(False)
72
+
73
+ def _get_msk_indices(self, msk):
74
+ if msk is None:
75
+ return range(self.n_imgs)
76
+ elif isinstance(msk, int):
77
+ return [msk]
78
+ elif isinstance(msk, (tuple, list)):
79
+ return self._get_msk_indices(np.array(msk))
80
+ elif msk.dtype in (bool, torch.bool, np.bool_):
81
+ assert len(msk) == self.n_imgs
82
+ return np.where(msk)[0]
83
+ elif np.issubdtype(msk.dtype, np.integer):
84
+ return msk
85
+ else:
86
+ raise ValueError(f'bad {msk=}')
87
+
88
+ def _set_focal(self, idx, focal, force=False):
89
+ param = self.im_focals[idx]
90
+ if param.requires_grad or force: # can only init a parameter not already initialized
91
+ param.data[:] = self.focal_brake * np.log(focal)
92
+ return param
93
+
94
+ def get_focals(self):
95
+ log_focals = torch.stack(list(self.im_focals), dim=0)
96
+ return (log_focals / self.focal_brake).exp()
97
+
98
+ def _set_principal_point(self, idx, pp, force=False):
99
+ param = self.im_pp[idx]
100
+ H, W = self.imshapes[idx]
101
+ if param.requires_grad or force: # can only init a parameter not already initialized
102
+ param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10
103
+ return param
104
+
105
+ def get_principal_points(self):
106
+ return torch.stack([pp.new((W/2, H/2))+10*pp for pp, (H, W) in zip(self.im_pp, self.imshapes)])
107
+
108
+ def get_intrinsics(self):
109
+ K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
110
+ focals = self.get_focals().view(self.n_imgs, -1)
111
+ K[:, 0, 0] = focals[:, 0]
112
+ K[:, 1, 1] = focals[:, -1]
113
+ K[:, :2, 2] = self.get_principal_points()
114
+ K[:, 2, 2] = 1
115
+ return K
116
+
117
+ def get_im_poses(self): # cam to world
118
+ cam2world = self._get_poses(torch.stack(list(self.im_poses)))
119
+ return cam2world
120
+
121
+ def _set_depthmap(self, idx, depth, force=False):
122
+ param = self.im_depthmaps[idx]
123
+ if param.requires_grad or force: # can only init a parameter not already initialized
124
+ param.data[:] = depth.log().nan_to_num(neginf=0)
125
+ return param
126
+
127
+ def get_depthmaps(self):
128
+ return [d.exp() for d in self.im_depthmaps]
129
+
130
+ def depth_to_pts3d(self):
131
+ # Get depths and projection params if not provided
132
+ focals = self.get_focals()
133
+ pp = self.get_principal_points()
134
+ im_poses = self.get_im_poses()
135
+ depth = self.get_depthmaps()
136
+
137
+ # convert focal to (1,2,H,W) constant field
138
+ def focal_ex(i): return focals[i][..., None, None].expand(1, *focals[i].shape, *self.imshapes[i])
139
+ # get pointmaps in camera frame
140
+ rel_ptmaps = [depthmap_to_pts3d(depth[i][None], focal_ex(i), pp=pp[i:i+1])[0] for i in range(im_poses.shape[0])]
141
+ # project to world frame
142
+ return [geotrf(pose, ptmap) for pose, ptmap in zip(im_poses, rel_ptmaps)]
143
+
144
+ def get_pts3d(self):
145
+ return self.depth_to_pts3d()
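The focal length is stored as focal_brake * log(f), so the optimizer works on a well-scaled log variable and get_focals() exponentiates it back. The round trip looks like this (image size and focal values are illustrative):

import numpy as np
import torch

focal_brake = 20
H, W = 384, 512
param = torch.tensor([focal_brake * np.log(max(H, W))], dtype=torch.float32)  # default init

print(float((param / focal_brake).exp()))    # 512.0 -> the default focal equals max(H, W)

param.data[:] = focal_brake * np.log(600.0)  # what _set_focal(idx, 600.0) writes
print(float((param / focal_brake).exp()))    # ~600.0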
mini_dust3r/cloud_opt/optimizer.py ADDED
@@ -0,0 +1,248 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Main class for the implementation of the global alignment
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from mini_dust3r.cloud_opt.base_opt import BasePCOptimizer
12
+ from mini_dust3r.utils.geometry import xy_grid, geotrf
13
+ from mini_dust3r.utils.device import to_cpu, to_numpy
14
+
15
+
16
+ class PointCloudOptimizer(BasePCOptimizer):
17
+ """ Optimize a global scene, given a list of pairwise observations.
18
+ Graph node: images
19
+ Graph edges: observations = (pred1, pred2)
20
+ """
21
+
22
+ def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+
25
+ self.has_im_poses = True # by definition of this class
26
+ self.focal_break = focal_break
27
+
28
+ # adding thing to optimize
29
+ self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes) # log(depth)
30
+ self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs)) # camera poses
31
+ self.im_focals = nn.ParameterList(torch.FloatTensor(
32
+ [self.focal_break*np.log(max(H, W))]) for H, W in self.imshapes) # camera intrinsics
33
+ self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs)) # camera intrinsics
34
+ self.im_pp.requires_grad_(optimize_pp)
35
+
36
+ self.imshape = self.imshapes[0]
37
+ im_areas = [h*w for h, w in self.imshapes]
38
+ self.max_area = max(im_areas)
39
+
40
+ # adding thing to optimize
41
+ self.im_depthmaps = ParameterStack(self.im_depthmaps, is_param=True, fill=self.max_area)
42
+ self.im_poses = ParameterStack(self.im_poses, is_param=True)
43
+ self.im_focals = ParameterStack(self.im_focals, is_param=True)
44
+ self.im_pp = ParameterStack(self.im_pp, is_param=True)
45
+ self.register_buffer('_pp', torch.tensor([(w/2, h/2) for h, w in self.imshapes]))
46
+ self.register_buffer('_grid', ParameterStack(
47
+ [xy_grid(W, H, device=self.device) for H, W in self.imshapes], fill=self.max_area))
48
+
49
+ # pre-compute pixel weights
50
+ self.register_buffer('_weight_i', ParameterStack(
51
+ [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges], fill=self.max_area))
52
+ self.register_buffer('_weight_j', ParameterStack(
53
+ [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges], fill=self.max_area))
54
+
55
+ # precompute aa
56
+ self.register_buffer('_stacked_pred_i', ParameterStack(self.pred_i, self.str_edges, fill=self.max_area))
57
+ self.register_buffer('_stacked_pred_j', ParameterStack(self.pred_j, self.str_edges, fill=self.max_area))
58
+ self.register_buffer('_ei', torch.tensor([i for i, j in self.edges]))
59
+ self.register_buffer('_ej', torch.tensor([j for i, j in self.edges]))
60
+ self.total_area_i = sum([im_areas[i] for i, j in self.edges])
61
+ self.total_area_j = sum([im_areas[j] for i, j in self.edges])
62
+
63
+ def _check_all_imgs_are_selected(self, msk):
64
+ assert np.all(self._get_msk_indices(msk) == np.arange(self.n_imgs)), 'incomplete mask!'
65
+
66
+ def preset_pose(self, known_poses, pose_msk=None): # cam-to-world
67
+ self._check_all_imgs_are_selected(pose_msk)
68
+
69
+ if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:
70
+ known_poses = [known_poses]
71
+ for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
72
+ if self.verbose:
73
+ print(f' (setting pose #{idx} = {pose[:3,3]})')
74
+ self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose)))
75
+
76
+ # normalize scale if there's at most one known pose
77
+ n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
78
+ self.norm_pw_scale = (n_known_poses <= 1)
79
+
80
+ self.im_poses.requires_grad_(False)
81
+ self.norm_pw_scale = False
82
+
83
+ def preset_focal(self, known_focals, msk=None):
84
+ self._check_all_imgs_are_selected(msk)
85
+
86
+ for idx, focal in zip(self._get_msk_indices(msk), known_focals):
87
+ if self.verbose:
88
+ print(f' (setting focal #{idx} = {focal})')
89
+ self._no_grad(self._set_focal(idx, focal))
90
+
91
+ self.im_focals.requires_grad_(False)
92
+
93
+ def preset_principal_point(self, known_pp, msk=None):
94
+ self._check_all_imgs_are_selected(msk)
95
+
96
+ for idx, pp in zip(self._get_msk_indices(msk), known_pp):
97
+ if self.verbose:
98
+ print(f' (setting principal point #{idx} = {pp})')
99
+ self._no_grad(self._set_principal_point(idx, pp))
100
+
101
+ self.im_pp.requires_grad_(False)
102
+
103
+ def _get_msk_indices(self, msk):
104
+ if msk is None:
105
+ return range(self.n_imgs)
106
+ elif isinstance(msk, int):
107
+ return [msk]
108
+ elif isinstance(msk, (tuple, list)):
109
+ return self._get_msk_indices(np.array(msk))
110
+ elif msk.dtype in (bool, torch.bool, np.bool_):
111
+ assert len(msk) == self.n_imgs
112
+ return np.where(msk)[0]
113
+ elif np.issubdtype(msk.dtype, np.integer):
114
+ return msk
115
+ else:
116
+ raise ValueError(f'bad {msk=}')
117
+
118
+ def _no_grad(self, tensor):
119
+ assert tensor.requires_grad, 'it must be True at this point, otherwise no modification occurs'
120
+
121
+ def _set_focal(self, idx, focal, force=False):
122
+ param = self.im_focals[idx]
123
+ if param.requires_grad or force: # can only init a parameter not already initialized
124
+ param.data[:] = self.focal_break * np.log(focal)
125
+ return param
126
+
127
+ def get_focals(self):
128
+ log_focals = torch.stack(list(self.im_focals), dim=0)
129
+ return (log_focals / self.focal_break).exp()
130
+
131
+ def get_known_focal_mask(self):
132
+ return torch.tensor([not (p.requires_grad) for p in self.im_focals])
133
+
134
+ def _set_principal_point(self, idx, pp, force=False):
135
+ param = self.im_pp[idx]
136
+ H, W = self.imshapes[idx]
137
+ if param.requires_grad or force: # can only init a parameter not already initialized
138
+ param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10
139
+ return param
140
+
141
+ def get_principal_points(self):
142
+ return self._pp + 10 * self.im_pp
143
+
144
+ def get_intrinsics(self):
145
+ K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
146
+ focals = self.get_focals().flatten()
147
+ K[:, 0, 0] = K[:, 1, 1] = focals
148
+ K[:, :2, 2] = self.get_principal_points()
149
+ K[:, 2, 2] = 1
150
+ return K
151
+
152
+ def get_im_poses(self): # cam to world
153
+ cam2world = self._get_poses(self.im_poses)
154
+ return cam2world
155
+
156
+ def _set_depthmap(self, idx, depth, force=False):
157
+ depth = _ravel_hw(depth, self.max_area)
158
+
159
+ param = self.im_depthmaps[idx]
160
+ if param.requires_grad or force: # can only init a parameter not already initialized
161
+ param.data[:] = depth.log().nan_to_num(neginf=0)
162
+ return param
163
+
164
+ def get_depthmaps(self, raw=False):
165
+ res = self.im_depthmaps.exp()
166
+ if not raw:
167
+ res = [dm[:h*w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)]
168
+ return res
169
+
170
+ def depth_to_pts3d(self):
171
+ # Get depths and projection params if not provided
172
+ focals = self.get_focals()
173
+ pp = self.get_principal_points()
174
+ im_poses = self.get_im_poses()
175
+ depth = self.get_depthmaps(raw=True)
176
+
177
+ # get pointmaps in camera frame
178
+ rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp)
179
+ # project to world frame
180
+ return geotrf(im_poses, rel_ptmaps)
181
+
182
+ def get_pts3d(self, raw=False):
183
+ res = self.depth_to_pts3d()
184
+ if not raw:
185
+ res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
186
+ return res
187
+
188
+ def forward(self):
189
+ pw_poses = self.get_pw_poses() # cam-to-world
190
+ pw_adapt = self.get_adaptors().unsqueeze(1)
191
+ proj_pts3d = self.get_pts3d(raw=True)
192
+
193
+ # rotate pairwise prediction according to pw_poses
194
+ aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i)
195
+ aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j)
196
+
197
+ # compute the loss
198
+ li = self.dist(proj_pts3d[self._ei], aligned_pred_i, weight=self._weight_i).sum() / self.total_area_i
199
+ lj = self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum() / self.total_area_j
200
+
201
+ return li + lj
202
+
203
+
204
+ def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp):
205
+ pp = pp.unsqueeze(1)
206
+ focal = focal.unsqueeze(1)
207
+ assert focal.shape == (len(depth), 1, 1)
208
+ assert pp.shape == (len(depth), 1, 2)
209
+ assert pixel_grid.shape == depth.shape + (2,)
210
+ depth = depth.unsqueeze(-1)
211
+ return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1)
212
+
213
+
214
+ def ParameterStack(params, keys=None, is_param=None, fill=0):
215
+ if keys is not None:
216
+ params = [params[k] for k in keys]
217
+
218
+ if fill > 0:
219
+ params = [_ravel_hw(p, fill) for p in params]
220
+
221
+ requires_grad = params[0].requires_grad
222
+ assert all(p.requires_grad == requires_grad for p in params)
223
+
224
+ params = torch.stack(list(params)).float().detach()
225
+ if is_param or requires_grad:
226
+ params = nn.Parameter(params)
227
+ params.requires_grad_(requires_grad)
228
+ return params
229
+
230
+
231
+ def _ravel_hw(tensor, fill=0):
232
+ # ravel H,W
233
+ tensor = tensor.view((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
234
+
235
+ if len(tensor) < fill:
236
+ tensor = torch.cat((tensor, tensor.new_zeros((fill - len(tensor),)+tensor.shape[1:])))
237
+ return tensor
238
+
239
+
240
+ def acceptable_focal_range(H, W, minf=0.5, maxf=3.5):
241
+ focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) # size / 1.1547005383792515
242
+ return minf*focal_base, maxf*focal_base
243
+
244
+
245
+ def apply_mask(img, msk):
246
+ img = img.copy()
247
+ img[msk] = 0
248
+ return img
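ParameterStack relies on _ravel_hw to flatten each H x W map and zero-pad it to max_area so that maps of different resolutions can be stacked into one tensor. A standalone illustration with toy shapes:

import torch

def _ravel_hw(tensor, fill=0):
    # flatten H, W and zero-pad up to `fill` rows (same logic as above)
    tensor = tensor.view((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
    if len(tensor) < fill:
        tensor = torch.cat((tensor, tensor.new_zeros((fill - len(tensor),) + tensor.shape[1:])))
    return tensor

a = torch.rand(4, 6, 3)            # a 4x6 pointmap
b = torch.rand(3, 5, 3)            # a smaller 3x5 pointmap
max_area = max(4 * 6, 3 * 5)

stacked = torch.stack([_ravel_hw(a, max_area), _ravel_hw(b, max_area)])
print(stacked.shape)               # torch.Size([2, 24, 3]); the 3x5 map is zero-padded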
mini_dust3r/cloud_opt/pair_viewer.py ADDED
@@ -0,0 +1,127 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Dummy optimizer for visualizing pairs
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import cv2
11
+
12
+ from mini_dust3r.cloud_opt.base_opt import BasePCOptimizer
13
+ from mini_dust3r.utils.geometry import inv, geotrf, depthmap_to_absolute_camera_coordinates
14
+ from mini_dust3r.cloud_opt.commons import edge_str
15
+ from mini_dust3r.post_process import estimate_focal_knowing_depth
16
+
17
+
18
+ class PairViewer (BasePCOptimizer):
19
+ """
20
+ This is a dummy optimizer.
21
+ To use only when the goal is to visualize the results for a pair of images (with is_symmetrized)
22
+ """
23
+
24
+ def __init__(self, *args, **kwargs):
25
+ super().__init__(*args, **kwargs)
26
+ assert self.is_symmetrized and self.n_edges == 2
27
+ self.has_im_poses = True
28
+
29
+ # compute all parameters directly from raw input
30
+ self.focals = []
31
+ self.pp = []
32
+ rel_poses = []
33
+ confs = []
34
+ for i in range(self.n_imgs):
35
+ conf = float(self.conf_i[edge_str(i, 1-i)].mean() * self.conf_j[edge_str(i, 1-i)].mean())
36
+ if self.verbose:
37
+ print(f' - {conf=:.3} for edge {i}-{1-i}')
38
+ confs.append(conf)
39
+
40
+ H, W = self.imshapes[i]
41
+ pts3d = self.pred_i[edge_str(i, 1-i)]
42
+ pp = torch.tensor((W/2, H/2))
43
+ focal = float(estimate_focal_knowing_depth(pts3d[None], pp, focal_mode='weiszfeld'))
44
+ self.focals.append(focal)
45
+ self.pp.append(pp)
46
+
47
+ # estimate the pose of pts1 in image 2
48
+ pixels = np.mgrid[:W, :H].T.astype(np.float32)
49
+ pts3d = self.pred_j[edge_str(1-i, i)].numpy()
50
+ assert pts3d.shape[:2] == (H, W)
51
+ msk = self.get_masks()[i].numpy()
52
+ K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])
53
+
54
+ try:
55
+ res = cv2.solvePnPRansac(pts3d[msk], pixels[msk], K, None,
56
+ iterationsCount=100, reprojectionError=5, flags=cv2.SOLVEPNP_SQPNP)
57
+ success, R, T, inliers = res
58
+ assert success
59
+
60
+ R = cv2.Rodrigues(R)[0] # world to cam
61
+ pose = inv(np.r_[np.c_[R, T], [(0, 0, 0, 1)]]) # cam to world
62
+ except:
63
+ pose = np.eye(4)
64
+ rel_poses.append(torch.from_numpy(pose.astype(np.float32)))
65
+
66
+ # let's use the pair with the most confidence
67
+ if confs[0] > confs[1]:
68
+ # ptcloud is expressed in camera1
69
+ self.im_poses = [torch.eye(4), rel_poses[1]] # I, cam2-to-cam1
70
+ self.depth = [self.pred_i['0_1'][..., 2], geotrf(inv(rel_poses[1]), self.pred_j['0_1'])[..., 2]]
71
+ else:
72
+ # ptcloud is expressed in camera2
73
+ self.im_poses = [rel_poses[0], torch.eye(4)] # I, cam1-to-cam2
74
+ self.depth = [geotrf(inv(rel_poses[0]), self.pred_j['1_0'])[..., 2], self.pred_i['1_0'][..., 2]]
75
+
76
+ self.im_poses = nn.Parameter(torch.stack(self.im_poses, dim=0), requires_grad=False)
77
+ self.focals = nn.Parameter(torch.tensor(self.focals), requires_grad=False)
78
+ self.pp = nn.Parameter(torch.stack(self.pp, dim=0), requires_grad=False)
79
+ self.depth = nn.ParameterList(self.depth)
80
+ for p in self.parameters():
81
+ p.requires_grad = False
82
+
83
+ def _set_depthmap(self, idx, depth, force=False):
84
+ if self.verbose:
85
+ print('_set_depthmap is ignored in PairViewer')
86
+ return
87
+
88
+ def get_depthmaps(self, raw=False):
89
+ depth = [d.to(self.device) for d in self.depth]
90
+ return depth
91
+
92
+ def _set_focal(self, idx, focal, force=False):
93
+ self.focals[idx] = focal
94
+
95
+ def get_focals(self):
96
+ return self.focals
97
+
98
+ def get_known_focal_mask(self):
99
+ return torch.tensor([not (p.requires_grad) for p in self.focals])
100
+
101
+ def get_principal_points(self):
102
+ return self.pp
103
+
104
+ def get_intrinsics(self):
105
+ focals = self.get_focals()
106
+ pps = self.get_principal_points()
107
+ K = torch.zeros((len(focals), 3, 3), device=self.device)
108
+ for i in range(len(focals)):
109
+ K[i, 0, 0] = K[i, 1, 1] = focals[i]
110
+ K[i, :2, 2] = pps[i]
111
+ K[i, 2, 2] = 1
112
+ return K
113
+
114
+ def get_im_poses(self):
115
+ return self.im_poses
116
+
117
+ def depth_to_pts3d(self):
118
+ pts3d = []
119
+ for d, intrinsics, im_pose in zip(self.depth, self.get_intrinsics(), self.get_im_poses()):
120
+ pts, _ = depthmap_to_absolute_camera_coordinates(d.cpu().numpy(),
121
+ intrinsics.cpu().numpy(),
122
+ im_pose.cpu().numpy())
123
+ pts3d.append(torch.from_numpy(pts).to(device=self.device))
124
+ return pts3d
125
+
126
+ def forward(self):
127
+ return float('nan')
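depth_to_pts3d above unprojects each depth map with its intrinsics and lifts the result to world coordinates with the camera pose. The same operation for a single pixel in plain NumPy (the intrinsics, pose, pixel and depth are toy values, not part of the commit):

import numpy as np

focal, W, H = 500.0, 640, 480
K = np.array([[focal, 0.0, W / 2],
              [0.0, focal, H / 2],
              [0.0, 0.0, 1.0]])
cam2world = np.eye(4)
cam2world[:3, 3] = [0.0, 0.0, 1.0]        # camera shifted one unit along +z

u, v, depth = 400.0, 300.0, 2.5           # pixel coordinates and its depth
ray = np.linalg.inv(K) @ np.array([u, v, 1.0])
pt_cam = ray * depth                      # 3D point in camera coordinates (z == depth)
pt_world = cam2world[:3, :3] @ pt_cam + cam2world[:3, 3]
print(pt_cam, pt_world)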
mini_dust3r/croco/blocks.py ADDED
@@ -0,0 +1,241 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+
5
+ # --------------------------------------------------------
6
+ # Main encoder/decoder blocks
7
+ # --------------------------------------------------------
8
+ # References:
9
+ # timm
10
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
11
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py
12
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
13
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py
14
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from itertools import repeat
21
+ import collections.abc
22
+
23
+
24
+ def _ntuple(n):
25
+ def parse(x):
26
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
27
+ return x
28
+ return tuple(repeat(x, n))
29
+ return parse
30
+ to_2tuple = _ntuple(2)
31
+
32
+ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
33
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
34
+ """
35
+ if drop_prob == 0. or not training:
36
+ return x
37
+ keep_prob = 1 - drop_prob
38
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
39
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
40
+ if keep_prob > 0.0 and scale_by_keep:
41
+ random_tensor.div_(keep_prob)
42
+ return x * random_tensor
43
+
44
+ class DropPath(nn.Module):
45
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
46
+ """
47
+ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
48
+ super(DropPath, self).__init__()
49
+ self.drop_prob = drop_prob
50
+ self.scale_by_keep = scale_by_keep
51
+
52
+ def forward(self, x):
53
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
54
+
55
+ def extra_repr(self):
56
+ return f'drop_prob={round(self.drop_prob,3):0.3f}'
57
+
58
+ class Mlp(nn.Module):
59
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks"""
60
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
61
+ super().__init__()
62
+ out_features = out_features or in_features
63
+ hidden_features = hidden_features or in_features
64
+ bias = to_2tuple(bias)
65
+ drop_probs = to_2tuple(drop)
66
+
67
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
68
+ self.act = act_layer()
69
+ self.drop1 = nn.Dropout(drop_probs[0])
70
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
71
+ self.drop2 = nn.Dropout(drop_probs[1])
72
+
73
+ def forward(self, x):
74
+ x = self.fc1(x)
75
+ x = self.act(x)
76
+ x = self.drop1(x)
77
+ x = self.fc2(x)
78
+ x = self.drop2(x)
79
+ return x
80
+
81
+ class Attention(nn.Module):
82
+
83
+ def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
84
+ super().__init__()
85
+ self.num_heads = num_heads
86
+ head_dim = dim // num_heads
87
+ self.scale = head_dim ** -0.5
88
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
89
+ self.attn_drop = nn.Dropout(attn_drop)
90
+ self.proj = nn.Linear(dim, dim)
91
+ self.proj_drop = nn.Dropout(proj_drop)
92
+ self.rope = rope
93
+
94
+ def forward(self, x, xpos):
95
+ B, N, C = x.shape
96
+
97
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1,3)
98
+ q, k, v = [qkv[:,:,i] for i in range(3)]
99
+ # q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple)
100
+
101
+ if self.rope is not None:
102
+ q = self.rope(q, xpos)
103
+ k = self.rope(k, xpos)
104
+
105
+ attn = (q @ k.transpose(-2, -1)) * self.scale
106
+ attn = attn.softmax(dim=-1)
107
+ attn = self.attn_drop(attn)
108
+
109
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
110
+ x = self.proj(x)
111
+ x = self.proj_drop(x)
112
+ return x
113
+
114
+ class Block(nn.Module):
115
+
116
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
117
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, rope=None):
118
+ super().__init__()
119
+ self.norm1 = norm_layer(dim)
120
+ self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
121
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
122
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
123
+ self.norm2 = norm_layer(dim)
124
+ mlp_hidden_dim = int(dim * mlp_ratio)
125
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
126
+
127
+ def forward(self, x, xpos):
128
+ x = x + self.drop_path(self.attn(self.norm1(x), xpos))
129
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
130
+ return x
131
+
132
+ class CrossAttention(nn.Module):
133
+
134
+ def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
135
+ super().__init__()
136
+ self.num_heads = num_heads
137
+ head_dim = dim // num_heads
138
+ self.scale = head_dim ** -0.5
139
+
140
+ self.projq = nn.Linear(dim, dim, bias=qkv_bias)
141
+ self.projk = nn.Linear(dim, dim, bias=qkv_bias)
142
+ self.projv = nn.Linear(dim, dim, bias=qkv_bias)
143
+ self.attn_drop = nn.Dropout(attn_drop)
144
+ self.proj = nn.Linear(dim, dim)
145
+ self.proj_drop = nn.Dropout(proj_drop)
146
+
147
+ self.rope = rope
148
+
149
+ def forward(self, query, key, value, qpos, kpos):
150
+ B, Nq, C = query.shape
151
+ Nk = key.shape[1]
152
+ Nv = value.shape[1]
153
+
154
+ q = self.projq(query).reshape(B,Nq,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
155
+ k = self.projk(key).reshape(B,Nk,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
156
+ v = self.projv(value).reshape(B,Nv,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
157
+
158
+ if self.rope is not None:
159
+ q = self.rope(q, qpos)
160
+ k = self.rope(k, kpos)
161
+
162
+ attn = (q @ k.transpose(-2, -1)) * self.scale
163
+ attn = attn.softmax(dim=-1)
164
+ attn = self.attn_drop(attn)
165
+
166
+ x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
167
+ x = self.proj(x)
168
+ x = self.proj_drop(x)
169
+ return x
170
+
171
+ class DecoderBlock(nn.Module):
172
+
173
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
174
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None):
175
+ super().__init__()
176
+ self.norm1 = norm_layer(dim)
177
+ self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
178
+ self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
179
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
180
+ self.norm2 = norm_layer(dim)
181
+ self.norm3 = norm_layer(dim)
182
+ mlp_hidden_dim = int(dim * mlp_ratio)
183
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
184
+ self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
185
+
186
+ def forward(self, x, y, xpos, ypos):
187
+ x = x + self.drop_path(self.attn(self.norm1(x), xpos))
188
+ y_ = self.norm_y(y)
189
+ x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
190
+ x = x + self.drop_path(self.mlp(self.norm3(x)))
191
+ return x, y
192
+
193
+
194
+ # patch embedding
195
+ class PositionGetter(object):
196
+ """ return positions of patches """
197
+
198
+ def __init__(self):
199
+ self.cache_positions = {}
200
+
201
+ def __call__(self, b, h, w, device):
202
+ if not (h,w) in self.cache_positions:
203
+ x = torch.arange(w, device=device)
204
+ y = torch.arange(h, device=device)
205
+ self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h, w, 2)
206
+ pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone()
207
+ return pos
208
+
209
+ class PatchEmbed(nn.Module):
210
+ """ just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed"""
211
+
212
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
213
+ super().__init__()
214
+ img_size = to_2tuple(img_size)
215
+ patch_size = to_2tuple(patch_size)
216
+ self.img_size = img_size
217
+ self.patch_size = patch_size
218
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
219
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
220
+ self.flatten = flatten
221
+
222
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
223
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
224
+
225
+ self.position_getter = PositionGetter()
226
+
227
+ def forward(self, x):
228
+ B, C, H, W = x.shape
229
+ torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
230
+ torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
231
+ x = self.proj(x)
232
+ pos = self.position_getter(B, x.size(2), x.size(3), x.device)
233
+ if self.flatten:
234
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
235
+ x = self.norm(x)
236
+ return x, pos
237
+
238
+ def _init_weights(self):
239
+ w = self.proj.weight.data
240
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
241
+
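Assuming the package is importable as laid out in this commit, the encoder and decoder blocks can be exercised on random tokens; xpos/ypos are only consumed when a RoPE embedding is passed, so None works here (the dimensions are illustrative):

import torch
from mini_dust3r.croco.blocks import Block, DecoderBlock

blk = Block(dim=64, num_heads=4)
x = torch.randn(2, 16, 64)                 # (batch, tokens, dim)
print(blk(x, xpos=None).shape)             # torch.Size([2, 16, 64])

dec = DecoderBlock(dim=64, num_heads=4)
y = torch.randn(2, 16, 64)                 # 'memory' tokens from the other image
out, _ = dec(x, y, xpos=None, ypos=None)
print(out.shape)                           # torch.Size([2, 16, 64])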
mini_dust3r/croco/croco.py ADDED
@@ -0,0 +1,249 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+
5
+ # --------------------------------------------------------
6
+ # CroCo model during pretraining
7
+ # --------------------------------------------------------
8
+
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
14
+ from functools import partial
15
+
16
+ from mini_dust3r.croco.blocks import Block, DecoderBlock, PatchEmbed
17
+ from mini_dust3r.croco.pos_embed import get_2d_sincos_pos_embed, RoPE2D
18
+ from mini_dust3r.croco.masking import RandomMask
19
+
20
+
21
+ class CroCoNet(nn.Module):
22
+
23
+ def __init__(self,
24
+ img_size=224, # input image size
25
+ patch_size=16, # patch_size
26
+ mask_ratio=0.9, # ratios of masked tokens
27
+ enc_embed_dim=768, # encoder feature dimension
28
+ enc_depth=12, # encoder depth
29
+ enc_num_heads=12, # encoder number of heads in the transformer block
30
+ dec_embed_dim=512, # decoder feature dimension
31
+ dec_depth=8, # decoder depth
32
+ dec_num_heads=16, # decoder number of heads in the transformer block
33
+ mlp_ratio=4,
34
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
35
+ norm_im2_in_dec=True, # whether to apply normalization of the 'memory' = (second image) in the decoder
36
+ pos_embed='cosine', # positional embedding (either cosine or RoPE100)
37
+ ):
38
+
39
+ super(CroCoNet, self).__init__()
40
+
41
+ # patch embeddings (with initialization done as in MAE)
42
+ self._set_patch_embed(img_size, patch_size, enc_embed_dim)
43
+
44
+ # mask generations
45
+ self._set_mask_generator(self.patch_embed.num_patches, mask_ratio)
46
+
47
+ self.pos_embed = pos_embed
48
+ if pos_embed=='cosine':
49
+ # positional embedding of the encoder
50
+ enc_pos_embed = get_2d_sincos_pos_embed(enc_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0)
51
+ self.register_buffer('enc_pos_embed', torch.from_numpy(enc_pos_embed).float())
52
+ # positional embedding of the decoder
53
+ dec_pos_embed = get_2d_sincos_pos_embed(dec_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0)
54
+ self.register_buffer('dec_pos_embed', torch.from_numpy(dec_pos_embed).float())
55
+ # pos embedding in each block
56
+ self.rope = None # nothing for cosine
57
+ elif pos_embed.startswith('RoPE'): # eg RoPE100
58
+ self.enc_pos_embed = None # nothing to add in the encoder with RoPE
59
+ self.dec_pos_embed = None # nothing to add in the decoder with RoPE
60
+ if RoPE2D is None: raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
61
+ freq = float(pos_embed[len('RoPE'):])
62
+ self.rope = RoPE2D(freq=freq)
63
+ else:
64
+ raise NotImplementedError('Unknown pos_embed '+pos_embed)
65
+
66
+ # transformer for the encoder
67
+ self.enc_depth = enc_depth
68
+ self.enc_embed_dim = enc_embed_dim
69
+ self.enc_blocks = nn.ModuleList([
70
+ Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=self.rope)
71
+ for i in range(enc_depth)])
72
+ self.enc_norm = norm_layer(enc_embed_dim)
73
+
74
+ # masked tokens
75
+ self._set_mask_token(dec_embed_dim)
76
+
77
+ # decoder
78
+ self._set_decoder(enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec)
79
+
80
+ # prediction head
81
+ self._set_prediction_head(dec_embed_dim, patch_size)
82
+
83
+ # initializer weights
84
+ self.initialize_weights()
85
+
86
+ def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
87
+ self.patch_embed = PatchEmbed(img_size, patch_size, 3, enc_embed_dim)
88
+
89
+ def _set_mask_generator(self, num_patches, mask_ratio):
90
+ self.mask_generator = RandomMask(num_patches, mask_ratio)
91
+
92
+ def _set_mask_token(self, dec_embed_dim):
93
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim))
94
+
95
+ def _set_decoder(self, enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec):
96
+ self.dec_depth = dec_depth
97
+ self.dec_embed_dim = dec_embed_dim
98
+ # transfer from encoder to decoder
99
+ self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
100
+ # transformer for the decoder
101
+ self.dec_blocks = nn.ModuleList([
102
+ DecoderBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope)
103
+ for i in range(dec_depth)])
104
+ # final norm layer
105
+ self.dec_norm = norm_layer(dec_embed_dim)
106
+
107
+ def _set_prediction_head(self, dec_embed_dim, patch_size):
108
+ self.prediction_head = nn.Linear(dec_embed_dim, patch_size**2 * 3, bias=True)
109
+
110
+
111
+ def initialize_weights(self):
112
+ # patch embed
113
+ self.patch_embed._init_weights()
114
+ # mask tokens
115
+ if self.mask_token is not None: torch.nn.init.normal_(self.mask_token, std=.02)
116
+ # linears and layer norms
117
+ self.apply(self._init_weights)
118
+
119
+ def _init_weights(self, m):
120
+ if isinstance(m, nn.Linear):
121
+ # we use xavier_uniform following official JAX ViT:
122
+ torch.nn.init.xavier_uniform_(m.weight)
123
+ if isinstance(m, nn.Linear) and m.bias is not None:
124
+ nn.init.constant_(m.bias, 0)
125
+ elif isinstance(m, nn.LayerNorm):
126
+ nn.init.constant_(m.bias, 0)
127
+ nn.init.constant_(m.weight, 1.0)
128
+
129
+ def _encode_image(self, image, do_mask=False, return_all_blocks=False):
130
+ """
131
+ image has B x 3 x img_size x img_size
132
+ do_mask: whether to perform masking or not
133
+ return_all_blocks: if True, return the features at the end of every block
134
+ instead of just the features from the last block (eg for some prediction heads)
135
+ """
136
+ # embed the image into patches (x has size B x Npatches x C)
137
+ # and get position if each return patch (pos has size B x Npatches x 2)
138
+ x, pos = self.patch_embed(image)
139
+ # add positional embedding without cls token
140
+ if self.enc_pos_embed is not None:
141
+ x = x + self.enc_pos_embed[None,...]
142
+ # apply masking
143
+ B,N,C = x.size()
144
+ if do_mask:
145
+ masks = self.mask_generator(x)
146
+ x = x[~masks].view(B, -1, C)
147
+ posvis = pos[~masks].view(B, -1, 2)
148
+ else:
149
+ B,N,C = x.size()
150
+ masks = torch.zeros((B,N), dtype=bool)
151
+ posvis = pos
152
+ # now apply the transformer encoder and normalization
153
+ if return_all_blocks:
154
+ out = []
155
+ for blk in self.enc_blocks:
156
+ x = blk(x, posvis)
157
+ out.append(x)
158
+ out[-1] = self.enc_norm(out[-1])
159
+ return out, pos, masks
160
+ else:
161
+ for blk in self.enc_blocks:
162
+ x = blk(x, posvis)
163
+ x = self.enc_norm(x)
164
+ return x, pos, masks
165
+
166
+ def _decoder(self, feat1, pos1, masks1, feat2, pos2, return_all_blocks=False):
167
+ """
168
+ return_all_blocks: if True, return the features at the end of every block
169
+ instead of just the features from the last block (eg for some prediction heads)
170
+
171
+ masks1 can be None => assume image1 fully visible
172
+ """
173
+ # encoder to decoder layer
174
+ visf1 = self.decoder_embed(feat1)
175
+ f2 = self.decoder_embed(feat2)
176
+ # append masked tokens to the sequence
177
+ B,Nenc,C = visf1.size()
178
+ if masks1 is None: # downstreams
179
+ f1_ = visf1
180
+ else: # pretraining
181
+ Ntotal = masks1.size(1)
182
+ f1_ = self.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype)
183
+ f1_[~masks1] = visf1.view(B * Nenc, C)
184
+ # add positional embedding
185
+ if self.dec_pos_embed is not None:
186
+ f1_ = f1_ + self.dec_pos_embed
187
+ f2 = f2 + self.dec_pos_embed
188
+ # apply Transformer blocks
189
+ out = f1_
190
+ out2 = f2
191
+ if return_all_blocks:
192
+ _out, out = out, []
193
+ for blk in self.dec_blocks:
194
+ _out, out2 = blk(_out, out2, pos1, pos2)
195
+ out.append(_out)
196
+ out[-1] = self.dec_norm(out[-1])
197
+ else:
198
+ for blk in self.dec_blocks:
199
+ out, out2 = blk(out, out2, pos1, pos2)
200
+ out = self.dec_norm(out)
201
+ return out
202
+
203
+ def patchify(self, imgs):
204
+ """
205
+ imgs: (B, 3, H, W)
206
+ x: (B, L, patch_size**2 *3)
207
+ """
208
+ p = self.patch_embed.patch_size[0]
209
+ assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
210
+
211
+ h = w = imgs.shape[2] // p
212
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
213
+ x = torch.einsum('nchpwq->nhwpqc', x)
214
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
215
+
216
+ return x
217
+
218
+ def unpatchify(self, x, channels=3):
219
+ """
220
+ x: (N, L, patch_size**2 *channels)
221
+ imgs: (N, 3, H, W)
222
+ """
223
+ patch_size = self.patch_embed.patch_size[0]
224
+ h = w = int(x.shape[1]**.5)
225
+ assert h * w == x.shape[1]
226
+ x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, channels))
227
+ x = torch.einsum('nhwpqc->nchpwq', x)
228
+ imgs = x.reshape(shape=(x.shape[0], channels, h * patch_size, h * patch_size))
229
+ return imgs
230
+
231
+ def forward(self, img1, img2):
232
+ """
233
+ img1: tensor of size B x 3 x img_size x img_size
234
+ img2: tensor of size B x 3 x img_size x img_size
235
+
236
+ out will be B x N x (3*patch_size*patch_size)
237
+ masks are also returned as B x N just in case
238
+ """
239
+ # encoder of the masked first image
240
+ feat1, pos1, mask1 = self._encode_image(img1, do_mask=True)
241
+ # encoder of the second image
242
+ feat2, pos2, _ = self._encode_image(img2, do_mask=False)
243
+ # decoder
244
+ decfeat = self._decoder(feat1, pos1, mask1, feat2, pos2)
245
+ # prediction head
246
+ out = self.prediction_head(decfeat)
247
+ # get target
248
+ target = self.patchify(img1)
249
+ return out, mask1, target
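Editorial note: a minimal standalone sketch (not part of this commit) that mirrors the patchify/unpatchify pair defined above; it assumes only torch and checks that the two reshapes are exact inverses for square images whose side is a multiple of the patch size.

import torch

def patchify(imgs, p):
    # imgs: (B, 3, H, W) with H == W and H % p == 0, mirroring the patchify method above
    h = w = imgs.shape[2] // p
    x = imgs.reshape(imgs.shape[0], 3, h, p, w, p)
    x = torch.einsum('nchpwq->nhwpqc', x)
    return x.reshape(imgs.shape[0], h * w, p**2 * 3)

def unpatchify(x, p, channels=3):
    # x: (B, L, p*p*channels) -> (B, channels, H, W), mirroring the unpatchify method above
    h = w = int(x.shape[1] ** .5)
    x = x.reshape(x.shape[0], h, w, p, p, channels)
    x = torch.einsum('nhwpqc->nchpwq', x)
    return x.reshape(x.shape[0], channels, h * p, h * p)

imgs = torch.randn(2, 3, 224, 224)
tokens = patchify(imgs, p=16)                       # (2, 196, 768)
assert torch.equal(unpatchify(tokens, p=16), imgs)  # lossless round trip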
mini_dust3r/croco/dpt_block.py ADDED
@@ -0,0 +1,450 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # DPT head for ViTs
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # https://github.com/isl-org/DPT
9
+ # https://github.com/EPFL-VILAB/MultiMAE/blob/main/multimae/output_adapters.py
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from einops import rearrange, repeat
15
+ from typing import Union, Tuple, Iterable, List, Optional, Dict
16
+
17
+ def pair(t):
18
+ return t if isinstance(t, tuple) else (t, t)
19
+
20
+ def make_scratch(in_shape, out_shape, groups=1, expand=False):
21
+ scratch = nn.Module()
22
+
23
+ out_shape1 = out_shape
24
+ out_shape2 = out_shape
25
+ out_shape3 = out_shape
26
+ out_shape4 = out_shape
27
+ if expand == True:
28
+ out_shape1 = out_shape
29
+ out_shape2 = out_shape * 2
30
+ out_shape3 = out_shape * 4
31
+ out_shape4 = out_shape * 8
32
+
33
+ scratch.layer1_rn = nn.Conv2d(
34
+ in_shape[0],
35
+ out_shape1,
36
+ kernel_size=3,
37
+ stride=1,
38
+ padding=1,
39
+ bias=False,
40
+ groups=groups,
41
+ )
42
+ scratch.layer2_rn = nn.Conv2d(
43
+ in_shape[1],
44
+ out_shape2,
45
+ kernel_size=3,
46
+ stride=1,
47
+ padding=1,
48
+ bias=False,
49
+ groups=groups,
50
+ )
51
+ scratch.layer3_rn = nn.Conv2d(
52
+ in_shape[2],
53
+ out_shape3,
54
+ kernel_size=3,
55
+ stride=1,
56
+ padding=1,
57
+ bias=False,
58
+ groups=groups,
59
+ )
60
+ scratch.layer4_rn = nn.Conv2d(
61
+ in_shape[3],
62
+ out_shape4,
63
+ kernel_size=3,
64
+ stride=1,
65
+ padding=1,
66
+ bias=False,
67
+ groups=groups,
68
+ )
69
+
70
+ scratch.layer_rn = nn.ModuleList([
71
+ scratch.layer1_rn,
72
+ scratch.layer2_rn,
73
+ scratch.layer3_rn,
74
+ scratch.layer4_rn,
75
+ ])
76
+
77
+ return scratch
78
+
79
+ class ResidualConvUnit_custom(nn.Module):
80
+ """Residual convolution module."""
81
+
82
+ def __init__(self, features, activation, bn):
83
+ """Init.
84
+ Args:
85
+ features (int): number of features
86
+ """
87
+ super().__init__()
88
+
89
+ self.bn = bn
90
+
91
+ self.groups = 1
92
+
93
+ self.conv1 = nn.Conv2d(
94
+ features,
95
+ features,
96
+ kernel_size=3,
97
+ stride=1,
98
+ padding=1,
99
+ bias=not self.bn,
100
+ groups=self.groups,
101
+ )
102
+
103
+ self.conv2 = nn.Conv2d(
104
+ features,
105
+ features,
106
+ kernel_size=3,
107
+ stride=1,
108
+ padding=1,
109
+ bias=not self.bn,
110
+ groups=self.groups,
111
+ )
112
+
113
+ if self.bn == True:
114
+ self.bn1 = nn.BatchNorm2d(features)
115
+ self.bn2 = nn.BatchNorm2d(features)
116
+
117
+ self.activation = activation
118
+
119
+ self.skip_add = nn.quantized.FloatFunctional()
120
+
121
+ def forward(self, x):
122
+ """Forward pass.
123
+ Args:
124
+ x (tensor): input
125
+ Returns:
126
+ tensor: output
127
+ """
128
+
129
+ out = self.activation(x)
130
+ out = self.conv1(out)
131
+ if self.bn == True:
132
+ out = self.bn1(out)
133
+
134
+ out = self.activation(out)
135
+ out = self.conv2(out)
136
+ if self.bn == True:
137
+ out = self.bn2(out)
138
+
139
+ if self.groups > 1:
140
+ out = self.conv_merge(out)
141
+
142
+ return self.skip_add.add(out, x)
143
+
144
+ class FeatureFusionBlock_custom(nn.Module):
145
+ """Feature fusion block."""
146
+
147
+ def __init__(
148
+ self,
149
+ features,
150
+ activation,
151
+ deconv=False,
152
+ bn=False,
153
+ expand=False,
154
+ align_corners=True,
155
+ width_ratio=1,
156
+ ):
157
+ """Init.
158
+ Args:
159
+ features (int): number of features
160
+ """
161
+ super(FeatureFusionBlock_custom, self).__init__()
162
+ self.width_ratio = width_ratio
163
+
164
+ self.deconv = deconv
165
+ self.align_corners = align_corners
166
+
167
+ self.groups = 1
168
+
169
+ self.expand = expand
170
+ out_features = features
171
+ if self.expand == True:
172
+ out_features = features // 2
173
+
174
+ self.out_conv = nn.Conv2d(
175
+ features,
176
+ out_features,
177
+ kernel_size=1,
178
+ stride=1,
179
+ padding=0,
180
+ bias=True,
181
+ groups=1,
182
+ )
183
+
184
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
185
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
186
+
187
+ self.skip_add = nn.quantized.FloatFunctional()
188
+
189
+ def forward(self, *xs):
190
+ """Forward pass.
191
+ Returns:
192
+ tensor: output
193
+ """
194
+ output = xs[0]
195
+
196
+ if len(xs) == 2:
197
+ res = self.resConfUnit1(xs[1])
198
+ if self.width_ratio != 1:
199
+ res = F.interpolate(res, size=(output.shape[2], output.shape[3]), mode='bilinear')
200
+
201
+ output = self.skip_add.add(output, res)
202
+ # output += res
203
+
204
+ output = self.resConfUnit2(output)
205
+
206
+ if self.width_ratio != 1:
207
+ # and output.shape[3] < self.width_ratio * output.shape[2]
208
+ #size=(image.shape[])
209
+ if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio:
210
+ shape = 3 * output.shape[3]
211
+ else:
212
+ shape = int(self.width_ratio * 2 * output.shape[2])
213
+ output = F.interpolate(output, size=(2* output.shape[2], shape), mode='bilinear')
214
+ else:
215
+ output = nn.functional.interpolate(output, scale_factor=2,
216
+ mode="bilinear", align_corners=self.align_corners)
217
+ output = self.out_conv(output)
218
+ return output
219
+
220
+ def make_fusion_block(features, use_bn, width_ratio=1):
221
+ return FeatureFusionBlock_custom(
222
+ features,
223
+ nn.ReLU(False),
224
+ deconv=False,
225
+ bn=use_bn,
226
+ expand=False,
227
+ align_corners=True,
228
+ width_ratio=width_ratio,
229
+ )
230
+
231
+ class Interpolate(nn.Module):
232
+ """Interpolation module."""
233
+
234
+ def __init__(self, scale_factor, mode, align_corners=False):
235
+ """Init.
236
+ Args:
237
+ scale_factor (float): scaling
238
+ mode (str): interpolation mode
239
+ """
240
+ super(Interpolate, self).__init__()
241
+
242
+ self.interp = nn.functional.interpolate
243
+ self.scale_factor = scale_factor
244
+ self.mode = mode
245
+ self.align_corners = align_corners
246
+
247
+ def forward(self, x):
248
+ """Forward pass.
249
+ Args:
250
+ x (tensor): input
251
+ Returns:
252
+ tensor: interpolated data
253
+ """
254
+
255
+ x = self.interp(
256
+ x,
257
+ scale_factor=self.scale_factor,
258
+ mode=self.mode,
259
+ align_corners=self.align_corners,
260
+ )
261
+
262
+ return x
263
+
264
+ class DPTOutputAdapter(nn.Module):
265
+ """DPT output adapter.
266
+
267
+ :param num_channels: Number of output channels
268
+ :param stride_level: Stride level compared to the full-sized image.
269
+ E.g. 4 for 1/4th the size of the image.
270
+ :param patch_size_full: Int or tuple of the patch size over the full image size.
271
+ Patch size for smaller inputs will be computed accordingly.
272
+ :param hooks: Index of intermediate layers
273
+ :param layer_dims: Dimension of intermediate layers
274
+ :param feature_dim: Feature dimension
275
+ :param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression
276
+ :param use_bn: If set to True, activates batch norm
277
+ :param dim_tokens_enc: Dimension of tokens coming from encoder
278
+ """
279
+
280
+ def __init__(self,
281
+ num_channels: int = 1,
282
+ stride_level: int = 1,
283
+ patch_size: Union[int, Tuple[int, int]] = 16,
284
+ main_tasks: Iterable[str] = ('rgb',),
285
+ hooks: List[int] = [2, 5, 8, 11],
286
+ layer_dims: List[int] = [96, 192, 384, 768],
287
+ feature_dim: int = 256,
288
+ last_dim: int = 32,
289
+ use_bn: bool = False,
290
+ dim_tokens_enc: Optional[int] = None,
291
+ head_type: str = 'regression',
292
+ output_width_ratio=1,
293
+ **kwargs):
294
+ super().__init__()
295
+ self.num_channels = num_channels
296
+ self.stride_level = stride_level
297
+ self.patch_size = pair(patch_size)
298
+ self.main_tasks = main_tasks
299
+ self.hooks = hooks
300
+ self.layer_dims = layer_dims
301
+ self.feature_dim = feature_dim
302
+ self.dim_tokens_enc = dim_tokens_enc * len(self.main_tasks) if dim_tokens_enc is not None else None
303
+ self.head_type = head_type
304
+
305
+ # Actual patch height and width, taking into account stride of input
306
+ self.P_H = max(1, self.patch_size[0] // stride_level)
307
+ self.P_W = max(1, self.patch_size[1] // stride_level)
308
+
309
+ self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False)
310
+
311
+ self.scratch.refinenet1 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
312
+ self.scratch.refinenet2 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
313
+ self.scratch.refinenet3 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
314
+ self.scratch.refinenet4 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
315
+
316
+ if self.head_type == 'regression':
317
+ # The "DPTDepthModel" head
318
+ self.head = nn.Sequential(
319
+ nn.Conv2d(feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1),
320
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
321
+ nn.Conv2d(feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1),
322
+ nn.ReLU(True),
323
+ nn.Conv2d(last_dim, self.num_channels, kernel_size=1, stride=1, padding=0)
324
+ )
325
+ elif self.head_type == 'semseg':
326
+ # The "DPTSegmentationModel" head
327
+ self.head = nn.Sequential(
328
+ nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1, bias=False),
329
+ nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(),
330
+ nn.ReLU(True),
331
+ nn.Dropout(0.1, False),
332
+ nn.Conv2d(feature_dim, self.num_channels, kernel_size=1),
333
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
334
+ )
335
+ else:
336
+ raise ValueError('DPT head_type must be "regression" or "semseg".')
337
+
338
+ if self.dim_tokens_enc is not None:
339
+ self.init(dim_tokens_enc=dim_tokens_enc)
340
+
341
+ def init(self, dim_tokens_enc=768):
342
+ """
343
+ Initialize parts of decoder that are dependent on dimension of encoder tokens.
344
+ Should be called when setting up MultiMAE.
345
+
346
+ :param dim_tokens_enc: Dimension of tokens coming from encoder
347
+ """
348
+ #print(dim_tokens_enc)
349
+
350
+ # Set up activation postprocessing layers
351
+ if isinstance(dim_tokens_enc, int):
352
+ dim_tokens_enc = 4 * [dim_tokens_enc]
353
+
354
+ self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc]
355
+
356
+ self.act_1_postprocess = nn.Sequential(
357
+ nn.Conv2d(
358
+ in_channels=self.dim_tokens_enc[0],
359
+ out_channels=self.layer_dims[0],
360
+ kernel_size=1, stride=1, padding=0,
361
+ ),
362
+ nn.ConvTranspose2d(
363
+ in_channels=self.layer_dims[0],
364
+ out_channels=self.layer_dims[0],
365
+ kernel_size=4, stride=4, padding=0,
366
+ bias=True, dilation=1, groups=1,
367
+ )
368
+ )
369
+
370
+ self.act_2_postprocess = nn.Sequential(
371
+ nn.Conv2d(
372
+ in_channels=self.dim_tokens_enc[1],
373
+ out_channels=self.layer_dims[1],
374
+ kernel_size=1, stride=1, padding=0,
375
+ ),
376
+ nn.ConvTranspose2d(
377
+ in_channels=self.layer_dims[1],
378
+ out_channels=self.layer_dims[1],
379
+ kernel_size=2, stride=2, padding=0,
380
+ bias=True, dilation=1, groups=1,
381
+ )
382
+ )
383
+
384
+ self.act_3_postprocess = nn.Sequential(
385
+ nn.Conv2d(
386
+ in_channels=self.dim_tokens_enc[2],
387
+ out_channels=self.layer_dims[2],
388
+ kernel_size=1, stride=1, padding=0,
389
+ )
390
+ )
391
+
392
+ self.act_4_postprocess = nn.Sequential(
393
+ nn.Conv2d(
394
+ in_channels=self.dim_tokens_enc[3],
395
+ out_channels=self.layer_dims[3],
396
+ kernel_size=1, stride=1, padding=0,
397
+ ),
398
+ nn.Conv2d(
399
+ in_channels=self.layer_dims[3],
400
+ out_channels=self.layer_dims[3],
401
+ kernel_size=3, stride=2, padding=1,
402
+ )
403
+ )
404
+
405
+ self.act_postprocess = nn.ModuleList([
406
+ self.act_1_postprocess,
407
+ self.act_2_postprocess,
408
+ self.act_3_postprocess,
409
+ self.act_4_postprocess
410
+ ])
411
+
412
+ def adapt_tokens(self, encoder_tokens):
413
+ # Adapt tokens
414
+ x = []
415
+ x.append(encoder_tokens[:, :])
416
+ x = torch.cat(x, dim=-1)
417
+ return x
418
+
419
+ def forward(self, encoder_tokens: List[torch.Tensor], image_size):
420
+ #input_info: Dict):
421
+ assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
422
+ H, W = image_size
423
+
424
+ # Number of patches in height and width
425
+ N_H = H // (self.stride_level * self.P_H)
426
+ N_W = W // (self.stride_level * self.P_W)
427
+
428
+ # Hook decoder onto 4 layers from specified ViT layers
429
+ layers = [encoder_tokens[hook] for hook in self.hooks]
430
+
431
+ # Extract only task-relevant tokens and ignore global tokens.
432
+ layers = [self.adapt_tokens(l) for l in layers]
433
+
434
+ # Reshape tokens to spatial representation
435
+ layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
436
+
437
+ layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
438
+ # Project layers to chosen feature dim
439
+ layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
440
+
441
+ # Fuse layers using refinement stages
442
+ path_4 = self.scratch.refinenet4(layers[3])
443
+ path_3 = self.scratch.refinenet3(path_4, layers[2])
444
+ path_2 = self.scratch.refinenet2(path_3, layers[1])
445
+ path_1 = self.scratch.refinenet1(path_2, layers[0])
446
+
447
+ # Output head
448
+ out = self.head(path_1)
449
+
450
+ return out
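Editorial note: a standalone sketch (not part of this commit) of the token-to-feature-map reshaping that DPTOutputAdapter.forward performs before the refinement stages; it assumes torch and einops are available and uses illustrative sizes.

import torch
from einops import rearrange

H, W, patch_size, dim = 224, 224, 16, 768
N_H, N_W = H // patch_size, W // patch_size      # 14 x 14 patch grid
tokens = torch.randn(1, N_H * N_W, dim)          # (B, N, C) tokens from one hooked layer
fmap = rearrange(tokens, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W)
print(fmap.shape)                                # torch.Size([1, 768, 14, 14])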
mini_dust3r/croco/masking.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+
5
+ # --------------------------------------------------------
6
+ # Masking utils
7
+ # --------------------------------------------------------
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ class RandomMask(nn.Module):
13
+ """
14
+ random masking
15
+ """
16
+
17
+ def __init__(self, num_patches, mask_ratio):
18
+ super().__init__()
19
+ self.num_patches = num_patches
20
+ self.num_mask = int(mask_ratio * self.num_patches)
21
+
22
+ def __call__(self, x):
23
+ noise = torch.rand(x.size(0), self.num_patches, device=x.device)
24
+ argsort = torch.argsort(noise, dim=1)
25
+ return argsort < self.num_mask
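Editorial note: a standalone sketch (not part of this commit) of the argsort trick used by RandomMask above; each sample gets exactly int(mask_ratio * num_patches) masked positions.

import torch

num_patches, mask_ratio = 196, 0.9
num_mask = int(mask_ratio * num_patches)      # 176
noise = torch.rand(4, num_patches)            # one noise vector per sample
masks = torch.argsort(noise, dim=1) < num_mask
print(masks.shape, masks.sum(dim=1))          # torch.Size([4, 196]) tensor([176, 176, 176, 176])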
mini_dust3r/croco/pos_embed.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+
5
+ # --------------------------------------------------------
6
+ # Position embedding utils
7
+ # --------------------------------------------------------
8
+
9
+
10
+
11
+ import numpy as np
12
+
13
+ import torch
14
+
15
+ # --------------------------------------------------------
16
+ # 2D sine-cosine position embedding
17
+ # References:
18
+ # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
19
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
20
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
21
+ # --------------------------------------------------------
22
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
23
+ """
24
+ grid_size: int of the grid height and width
25
+ return:
26
+ pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
27
+ """
28
+ grid_h = np.arange(grid_size, dtype=np.float32)
29
+ grid_w = np.arange(grid_size, dtype=np.float32)
30
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
31
+ grid = np.stack(grid, axis=0)
32
+
33
+ grid = grid.reshape([2, 1, grid_size, grid_size])
34
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
35
+ if n_cls_token>0:
36
+ pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0)
37
+ return pos_embed
38
+
39
+
40
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
41
+ assert embed_dim % 2 == 0
42
+
43
+ # use half of dimensions to encode grid_h
44
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
45
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
46
+
47
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
48
+ return emb
49
+
50
+
51
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
52
+ """
53
+ embed_dim: output dimension for each position
54
+ pos: a list of positions to be encoded: size (M,)
55
+ out: (M, D)
56
+ """
57
+ assert embed_dim % 2 == 0
58
+ omega = np.arange(embed_dim // 2, dtype=float)
59
+ omega /= embed_dim / 2.
60
+ omega = 1. / 10000**omega # (D/2,)
61
+
62
+ pos = pos.reshape(-1) # (M,)
63
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
64
+
65
+ emb_sin = np.sin(out) # (M, D/2)
66
+ emb_cos = np.cos(out) # (M, D/2)
67
+
68
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
69
+ return emb
70
+
71
+
72
+ # --------------------------------------------------------
73
+ # Interpolate position embeddings for high-resolution
74
+ # References:
75
+ # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
76
+ # DeiT: https://github.com/facebookresearch/deit
77
+ # --------------------------------------------------------
78
+ def interpolate_pos_embed(model, checkpoint_model):
79
+ if 'pos_embed' in checkpoint_model:
80
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
81
+ embedding_size = pos_embed_checkpoint.shape[-1]
82
+ num_patches = model.patch_embed.num_patches
83
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
84
+ # height (== width) for the checkpoint position embedding
85
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
86
+ # height (== width) for the new position embedding
87
+ new_size = int(num_patches ** 0.5)
88
+ # class_token and dist_token are kept unchanged
89
+ if orig_size != new_size:
90
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
91
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
92
+ # only the position tokens are interpolated
93
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
94
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
95
+ pos_tokens = torch.nn.functional.interpolate(
96
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
97
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
98
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
99
+ checkpoint_model['pos_embed'] = new_pos_embed
100
+
101
+
102
+ #----------------------------------------------------------
103
+ # RoPE2D: RoPE implementation in 2D
104
+ #----------------------------------------------------------
105
+
106
+ try:
107
+ from mini_dust3r.croco.curope import cuRoPE2D
108
+ RoPE2D = cuRoPE2D
109
+ except ImportError:
110
+ print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead')
111
+
112
+ class RoPE2D(torch.nn.Module):
113
+
114
+ def __init__(self, freq=100.0, F0=1.0):
115
+ super().__init__()
116
+ self.base = freq
117
+ self.F0 = F0
118
+ self.cache = {}
119
+
120
+ def get_cos_sin(self, D, seq_len, device, dtype):
121
+ if (D,seq_len,device,dtype) not in self.cache:
122
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
123
+ t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
124
+ freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
125
+ freqs = torch.cat((freqs, freqs), dim=-1)
126
+ cos = freqs.cos() # (Seq, Dim)
127
+ sin = freqs.sin()
128
+ self.cache[D,seq_len,device,dtype] = (cos,sin)
129
+ return self.cache[D,seq_len,device,dtype]
130
+
131
+ @staticmethod
132
+ def rotate_half(x):
133
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
134
+ return torch.cat((-x2, x1), dim=-1)
135
+
136
+ def apply_rope1d(self, tokens, pos1d, cos, sin):
137
+ assert pos1d.ndim==2
138
+ cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
139
+ sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
140
+ return (tokens * cos) + (self.rotate_half(tokens) * sin)
141
+
142
+ def forward(self, tokens, positions):
143
+ """
144
+ input:
145
+ * tokens: batch_size x nheads x ntokens x dim
146
+ * positions: batch_size x ntokens x 2 (y and x position of each token)
147
+ output:
148
+ * tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
149
+ """
150
+ assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
151
+ D = tokens.size(3) // 2
152
+ assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
153
+ cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
154
+ # split features into two along the feature dimension, and apply rope1d on each half
155
+ y, x = tokens.chunk(2, dim=-1)
156
+ y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
157
+ x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
158
+ tokens = torch.cat((y, x), dim=-1)
159
+ return tokens
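Editorial note: a small usage sketch (not part of this commit) for the sine-cosine embedding defined above, assuming the mini_dust3r package is importable in the current environment.

import numpy as np
from mini_dust3r.croco.pos_embed import get_2d_sincos_pos_embed

pos_embed = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14)
print(pos_embed.shape)                   # (196, 768): one embedding per patch of a 14x14 grid
print(np.abs(pos_embed).max() <= 1.0)    # True: every entry is a sine or cosine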
mini_dust3r/heads/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # head factory
6
+ # --------------------------------------------------------
7
+ from .linear_head import LinearPts3d
8
+ from .dpt_head import create_dpt_head
9
+
10
+
11
+ def head_factory(head_type, output_mode, net, has_conf=False):
12
+ """" build a prediction head for the decoder
13
+ """
14
+ if head_type == 'linear' and output_mode == 'pts3d':
15
+ return LinearPts3d(net, has_conf)
16
+ elif head_type == 'dpt' and output_mode == 'pts3d':
17
+ return create_dpt_head(net, has_conf=has_conf)
18
+ else:
19
+ raise NotImplementedError(f"unexpected {head_type=} and {output_mode=}")
mini_dust3r/heads/dpt_head.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # dpt head implementation for DUST3R
6
+ # Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ;
7
+ # or if it takes as input the output at every layer, the attribute return_all_layers should be set to True
8
+ # the forward function also takes as input a dictionary img_info with keys "height" and "width"
9
+ # for PixelwiseTask, the output will be of dimension B x num_channels x H x W
10
+ # --------------------------------------------------------
11
+ from einops import rearrange
12
+ from typing import List
13
+ import torch
14
+ import torch.nn as nn
15
+ from mini_dust3r.heads.postprocess import postprocess
16
+ from mini_dust3r.croco.dpt_block import DPTOutputAdapter
17
+
18
+
19
+ class DPTOutputAdapter_fix(DPTOutputAdapter):
20
+ """
21
+ Adapt croco's DPTOutputAdapter implementation for dust3r:
22
+ remove duplicated weights, and fix forward for dust3r
23
+ """
24
+
25
+ def init(self, dim_tokens_enc=768):
26
+ super().init(dim_tokens_enc)
27
+ # these are duplicated weights
28
+ del self.act_1_postprocess
29
+ del self.act_2_postprocess
30
+ del self.act_3_postprocess
31
+ del self.act_4_postprocess
32
+
33
+ def forward(self, encoder_tokens: List[torch.Tensor], image_size=None):
34
+ assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
35
+ # H, W = input_info['image_size']
36
+ image_size = self.image_size if image_size is None else image_size
37
+ H, W = image_size
38
+ # Number of patches in height and width
39
+ N_H = H // (self.stride_level * self.P_H)
40
+ N_W = W // (self.stride_level * self.P_W)
41
+
42
+ # Hook decoder onto 4 layers from specified ViT layers
43
+ layers = [encoder_tokens[hook] for hook in self.hooks]
44
+
45
+ # Extract only task-relevant tokens and ignore global tokens.
46
+ layers = [self.adapt_tokens(l) for l in layers]
47
+
48
+ # Reshape tokens to spatial representation
49
+ layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
50
+
51
+ layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
52
+ # Project layers to chosen feature dim
53
+ layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
54
+
55
+ # Fuse layers using refinement stages
56
+ path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]
57
+ path_3 = self.scratch.refinenet3(path_4, layers[2])
58
+ path_2 = self.scratch.refinenet2(path_3, layers[1])
59
+ path_1 = self.scratch.refinenet1(path_2, layers[0])
60
+
61
+ # Output head
62
+ out = self.head(path_1)
63
+
64
+ return out
65
+
66
+
67
+ class PixelwiseTaskWithDPT(nn.Module):
68
+ """ DPT module for dust3r, can return 3D points + confidence for all pixels"""
69
+
70
+ def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
71
+ output_width_ratio=1, num_channels=1, postprocess=None, depth_mode=None, conf_mode=None, **kwargs):
72
+ super(PixelwiseTaskWithDPT, self).__init__()
73
+ self.return_all_layers = True # backbone needs to return all layers
74
+ self.postprocess = postprocess
75
+ self.depth_mode = depth_mode
76
+ self.conf_mode = conf_mode
77
+
78
+ assert n_cls_token == 0, "Not implemented"
79
+ dpt_args = dict(output_width_ratio=output_width_ratio,
80
+ num_channels=num_channels,
81
+ **kwargs)
82
+ if hooks_idx is not None:
83
+ dpt_args.update(hooks=hooks_idx)
84
+ self.dpt = DPTOutputAdapter_fix(**dpt_args)
85
+ dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens}
86
+ self.dpt.init(**dpt_init_args)
87
+
88
+ def forward(self, x, img_info):
89
+ out = self.dpt(x, image_size=(img_info[0], img_info[1]))
90
+ if self.postprocess:
91
+ out = self.postprocess(out, self.depth_mode, self.conf_mode)
92
+ return out
93
+
94
+
95
+ def create_dpt_head(net, has_conf=False):
96
+ """
97
+ return PixelwiseTaskWithDPT for given net params
98
+ """
99
+ assert net.dec_depth > 9
100
+ l2 = net.dec_depth
101
+ feature_dim = 256
102
+ last_dim = feature_dim//2
103
+ out_nchan = 3
104
+ ed = net.enc_embed_dim
105
+ dd = net.dec_embed_dim
106
+ return PixelwiseTaskWithDPT(num_channels=out_nchan + has_conf,
107
+ feature_dim=feature_dim,
108
+ last_dim=last_dim,
109
+ hooks_idx=[0, l2*2//4, l2*3//4, l2],
110
+ dim_tokens=[ed, dd, dd, dd],
111
+ postprocess=postprocess,
112
+ depth_mode=net.depth_mode,
113
+ conf_mode=net.conf_mode,
114
+ head_type='regression')
mini_dust3r/heads/linear_head.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # linear head implementation for DUST3R
6
+ # --------------------------------------------------------
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from mini_dust3r.heads.postprocess import postprocess
10
+
11
+
12
+ class LinearPts3d (nn.Module):
13
+ """
14
+ Linear head for dust3r
15
+ Each token outputs 16x16 3D points (+ confidence)
16
+ """
17
+
18
+ def __init__(self, net, has_conf=False):
19
+ super().__init__()
20
+ self.patch_size = net.patch_embed.patch_size[0]
21
+ self.depth_mode = net.depth_mode
22
+ self.conf_mode = net.conf_mode
23
+ self.has_conf = has_conf
24
+
25
+ self.proj = nn.Linear(net.dec_embed_dim, (3 + has_conf)*self.patch_size**2)
26
+
27
+ def setup(self, croconet):
28
+ pass
29
+
30
+ def forward(self, decout, img_shape):
31
+ H, W = img_shape
32
+ tokens = decout[-1]
33
+ B, S, D = tokens.shape
34
+
35
+ # extract 3D points
36
+ feat = self.proj(tokens) # B,S,D
37
+ feat = feat.transpose(-1, -2).view(B, -1, H//self.patch_size, W//self.patch_size)
38
+ feat = F.pixel_shuffle(feat, self.patch_size) # B,3,H,W
39
+
40
+ # permute + norm depth
41
+ return postprocess(feat, self.depth_mode, self.conf_mode)
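Editorial note: a standalone sketch (not part of this commit) of the pixel_shuffle step used by LinearPts3d above, with illustrative sizes; it shows how per-token features become a dense per-pixel map.

import torch
import torch.nn.functional as F

B, H, W, patch_size = 1, 224, 224, 16
channels = 3 + 1                                       # xyz + confidence
tokens = torch.randn(B, (H // patch_size) * (W // patch_size),
                     channels * patch_size**2)         # (B, S, D) output of self.proj
feat = tokens.transpose(-1, -2).view(B, -1, H // patch_size, W // patch_size)
feat = F.pixel_shuffle(feat, patch_size)
print(feat.shape)                                      # torch.Size([1, 4, 224, 224])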
mini_dust3r/heads/postprocess.py ADDED
@@ -0,0 +1,58 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # post process function for all heads: extract 3D points/confidence from output
6
+ # --------------------------------------------------------
7
+ import torch
8
+
9
+
10
+ def postprocess(out, depth_mode, conf_mode):
11
+ """
12
+ extract 3D points/confidence from prediction head output
13
+ """
14
+ fmap = out.permute(0, 2, 3, 1) # B,H,W,3
15
+ res = dict(pts3d=reg_dense_depth(fmap[:, :, :, 0:3], mode=depth_mode))
16
+
17
+ if conf_mode is not None:
18
+ res['conf'] = reg_dense_conf(fmap[:, :, :, 3], mode=conf_mode)
19
+ return res
20
+
21
+
22
+ def reg_dense_depth(xyz, mode):
23
+ """
24
+ extract 3D points from prediction head output
25
+ """
26
+ mode, vmin, vmax = mode
27
+
28
+ no_bounds = (vmin == -float('inf')) and (vmax == float('inf'))
29
+ assert no_bounds
30
+
31
+ if mode == 'linear':
32
+ if no_bounds:
33
+ return xyz # [-inf, +inf]
34
+ return xyz.clip(min=vmin, max=vmax)
35
+
36
+ # distance to origin
37
+ d = xyz.norm(dim=-1, keepdim=True)
38
+ xyz = xyz / d.clip(min=1e-8)
39
+
40
+ if mode == 'square':
41
+ return xyz * d.square()
42
+
43
+ if mode == 'exp':
44
+ return xyz * torch.expm1(d)
45
+
46
+ raise ValueError(f'bad {mode=}')
47
+
48
+
49
+ def reg_dense_conf(x, mode):
50
+ """
51
+ extract confidence from prediction head output
52
+ """
53
+ mode, vmin, vmax = mode
54
+ if mode == 'exp':
55
+ return vmin + x.exp().clip(max=vmax-vmin)
56
+ if mode == 'sigmoid':
57
+ return (vmax - vmin) * torch.sigmoid(x) + vmin
58
+ raise ValueError(f'bad {mode=}')
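Editorial note: a standalone sketch (not part of this commit) of the 'exp' depth and confidence modes above, using the defaults that model.py passes in (depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf)).

import torch

xyz = torch.randn(1, 4, 4, 3)                       # raw head output for a tiny 4x4 map
d = xyz.norm(dim=-1, keepdim=True)
pts3d = xyz / d.clip(min=1e-8) * torch.expm1(d)     # 'exp' depth: unit direction * (e^d - 1)

conf_raw = torch.randn(1, 4, 4)
conf = 1 + conf_raw.exp()                           # 'exp' confidence with vmin=1, vmax=inf
print(pts3d.shape, bool((conf > 1).all()))          # torch.Size([1, 4, 4, 3]) True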
mini_dust3r/image_pairs.py ADDED
@@ -0,0 +1,85 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utilities needed to load image pairs
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ import torch
9
+ from mini_dust3r.utils.image import ImageDict
10
+
11
+
12
+ def make_pairs(
13
+ imgs: list[ImageDict],
14
+ scene_graph: str = "complete",
15
+ prefilter=None,
16
+ symmetrize=True,
17
+ ) -> list[tuple[ImageDict, ImageDict]]:
18
+ pairs = []
19
+ if scene_graph == "complete": # complete graph
20
+ for i in range(len(imgs)):
21
+ for j in range(i):
22
+ pairs.append((imgs[i], imgs[j]))
23
+ elif scene_graph.startswith("swin"):
24
+ winsize = int(scene_graph.split("-")[1]) if "-" in scene_graph else 3
25
+ pairsid = set()
26
+ for i in range(len(imgs)):
27
+ for j in range(1, winsize + 1):
28
+ idx = (i + j) % len(imgs) # explicit loop closure
29
+ pairsid.add((i, idx) if i < idx else (idx, i))
30
+ for i, j in pairsid:
31
+ pairs.append((imgs[i], imgs[j]))
32
+ elif scene_graph.startswith("oneref"):
33
+ refid = int(scene_graph.split("-")[1]) if "-" in scene_graph else 0
34
+ for j in range(len(imgs)):
35
+ if j != refid:
36
+ pairs.append((imgs[refid], imgs[j]))
37
+ if symmetrize:
38
+ pairs += [(img2, img1) for img1, img2 in pairs]
39
+
40
+ # now, remove edges
41
+ if isinstance(prefilter, str) and prefilter.startswith("seq"):
42
+ pairs = filter_pairs_seq(pairs, int(prefilter[3:]))
43
+
44
+ if isinstance(prefilter, str) and prefilter.startswith("cyc"):
45
+ pairs = filter_pairs_seq(pairs, int(prefilter[3:]), cyclic=True)
46
+
47
+ return pairs
48
+
49
+
50
+ def sel(x, kept):
51
+ if isinstance(x, dict):
52
+ return {k: sel(v, kept) for k, v in x.items()}
53
+ if isinstance(x, (torch.Tensor, np.ndarray)):
54
+ return x[kept]
55
+ if isinstance(x, (tuple, list)):
56
+ return type(x)([x[k] for k in kept])
57
+
58
+
59
+ def _filter_edges_seq(edges, seq_dis_thr, cyclic=False):
60
+ # number of images
61
+ n = max(max(e) for e in edges) + 1
62
+
63
+ kept = []
64
+ for e, (i, j) in enumerate(edges):
65
+ dis = abs(i - j)
66
+ if cyclic:
67
+ dis = min(dis, abs(i + n - j), abs(i - n - j))
68
+ if dis <= seq_dis_thr:
69
+ kept.append(e)
70
+ return kept
71
+
72
+
73
+ def filter_pairs_seq(pairs, seq_dis_thr, cyclic=False):
74
+ edges = [(img1["idx"], img2["idx"]) for img1, img2 in pairs]
75
+ kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
76
+ return [pairs[i] for i in kept]
77
+
78
+
79
+ def filter_edges_seq(view1, view2, pred1, pred2, seq_dis_thr, cyclic=False):
80
+ edges = [(int(i), int(j)) for i, j in zip(view1["idx"], view2["idx"])]
81
+ kept = _filter_edges_seq(edges, seq_dis_thr, cyclic=cyclic)
82
+ print(
83
+ f">> Filtering edges more than {seq_dis_thr} frames apart: kept {len(kept)}/{len(edges)} edges"
84
+ )
85
+ return sel(view1, kept), sel(view2, kept), sel(pred1, kept), sel(pred2, kept)
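Editorial note: a standalone sketch (not part of this commit) of the pair index patterns make_pairs produces for the 'complete' and 'swin-2' scene graphs, reduced to bare image indices.

n = 6

complete = [(i, j) for i in range(n) for j in range(i)]       # every unordered pair

winsize = 2                     # 'swin-2': pair each image with the next 2, wrapping around
swin = set()
for i in range(n):
    for j in range(1, winsize + 1):
        idx = (i + j) % n       # explicit loop closure, as in make_pairs
        swin.add((i, idx) if i < idx else (idx, i))

print(len(complete), len(swin))  # 15 12: the sliding window drops the long-range edges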
mini_dust3r/inference.py ADDED
@@ -0,0 +1,204 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utilities needed for the inference
6
+ # --------------------------------------------------------
7
+ import tqdm
8
+ import torch
9
+ from mini_dust3r.utils.device import to_cpu, collate_with_cat
10
+ from mini_dust3r.utils.misc import invalid_to_nans
11
+ from mini_dust3r.utils.geometry import depthmap_to_pts3d, geotrf
12
+ from mini_dust3r.utils.image import ImageDict
13
+ from mini_dust3r.model import AsymmetricCroCo3DStereo
14
+
15
+ from typing import Literal, TypedDict, Optional
16
+ from jaxtyping import Float32
17
+
18
+
19
+ class Dust3rPred1(TypedDict):
20
+ pts3d: Float32[torch.Tensor, "b h w c"]
21
+ conf: Float32[torch.Tensor, "b h w"]
22
+
23
+
24
+ class Dust3rPred2(TypedDict):
25
+ pts3d_in_other_view: Float32[torch.Tensor, "b h w c"]
26
+ conf: Float32[torch.Tensor, "b h w"]
27
+
28
+
29
+ class Dust3rResult(TypedDict):
30
+ view1: ImageDict
31
+ view2: ImageDict
32
+ pred1: Dust3rPred1
33
+ pred2: Dust3rPred2
34
+ loss: Optional[int]
35
+
36
+
37
+ def _interleave_imgs(img1, img2):
38
+ res = {}
39
+ for key, value1 in img1.items():
40
+ value2 = img2[key]
41
+ if isinstance(value1, torch.Tensor):
42
+ value = torch.stack((value1, value2), dim=1).flatten(0, 1)
43
+ else:
44
+ value = [x for pair in zip(value1, value2) for x in pair]
45
+ res[key] = value
46
+ return res
47
+
48
+
49
+ def make_batch_symmetric(batch):
50
+ view1, view2 = batch
51
+ view1, view2 = (_interleave_imgs(view1, view2), _interleave_imgs(view2, view1))
52
+ return view1, view2
53
+
54
+
55
+ def loss_of_one_batch(
56
+ batch, model, criterion, device, symmetrize_batch=False, use_amp=False, ret=None
57
+ ):
58
+ view1, view2 = batch
59
+ for view in batch:
60
+ for name in (
61
+ "img pts3d valid_mask camera_pose camera_intrinsics F_matrix corres".split()
62
+ ): # pseudo_focal
63
+ if name not in view:
64
+ continue
65
+ view[name] = view[name].to(device, non_blocking=True)
66
+
67
+ if symmetrize_batch:
68
+ view1, view2 = make_batch_symmetric(batch)
69
+
70
+ with torch.cuda.amp.autocast(enabled=bool(use_amp)):
71
+ pred1, pred2 = model(view1, view2)
72
+
73
+ # loss is supposed to be symmetric
74
+ with torch.cuda.amp.autocast(enabled=False):
75
+ loss = (
76
+ criterion(view1, view2, pred1, pred2) if criterion is not None else None
77
+ )
78
+
79
+ result = dict(view1=view1, view2=view2, pred1=pred1, pred2=pred2, loss=loss)
80
+ return result[ret] if ret else result
81
+
82
+
83
+ @torch.no_grad()
84
+ def inference(
85
+ pairs: list[tuple[ImageDict, ImageDict]],
86
+ model: AsymmetricCroCo3DStereo,
87
+ device: Literal["cpu", "cuda", "mps"],
88
+ batch_size: int = 8,
89
+ verbose: bool = True,
90
+ ) -> Dust3rResult:
91
+ if verbose:
92
+ print(f">> Inference with model on {len(pairs)} image pairs")
93
+ result = []
94
+
95
+ # first, check if all images have the same size
96
+ multiple_shapes = not (check_if_same_size(pairs))
97
+ if multiple_shapes: # force bs=1
98
+ batch_size = 1
99
+
100
+ for i in tqdm.trange(0, len(pairs), batch_size, disable=not verbose):
101
+ res: Dust3rResult = loss_of_one_batch(
102
+ collate_with_cat(pairs[i : i + batch_size]), model, None, device
103
+ )
104
+ result.append(to_cpu(res))
105
+
106
+ result = collate_with_cat(result, lists=multiple_shapes)
107
+
108
+ return result
109
+
110
+
111
+ def check_if_same_size(pairs):
112
+ shapes1 = [img1["img"].shape[-2:] for img1, img2 in pairs]
113
+ shapes2 = [img2["img"].shape[-2:] for img1, img2 in pairs]
114
+ return all(shapes1[0] == s for s in shapes1) and all(
115
+ shapes2[0] == s for s in shapes2
116
+ )
117
+
118
+
119
+ def get_pred_pts3d(gt, pred, use_pose=False):
120
+ if "depth" in pred and "pseudo_focal" in pred:
121
+ try:
122
+ pp = gt["camera_intrinsics"][..., :2, 2]
123
+ except KeyError:
124
+ pp = None
125
+ pts3d = depthmap_to_pts3d(**pred, pp=pp)
126
+
127
+ elif "pts3d" in pred:
128
+ # pts3d from my camera
129
+ pts3d = pred["pts3d"]
130
+
131
+ elif "pts3d_in_other_view" in pred:
132
+ # pts3d from the other camera, already transformed
133
+ assert use_pose is True
134
+ return pred["pts3d_in_other_view"] # return!
135
+
136
+ if use_pose:
137
+ camera_pose = pred.get("camera_pose")
138
+ assert camera_pose is not None
139
+ pts3d = geotrf(camera_pose, pts3d)
140
+
141
+ return pts3d
142
+
143
+
144
+ def find_opt_scaling(
145
+ gt_pts1,
146
+ gt_pts2,
147
+ pr_pts1,
148
+ pr_pts2=None,
149
+ fit_mode="weiszfeld_stop_grad",
150
+ valid1=None,
151
+ valid2=None,
152
+ ):
153
+ assert gt_pts1.ndim == pr_pts1.ndim == 4
154
+ assert gt_pts1.shape == pr_pts1.shape
155
+ if gt_pts2 is not None:
156
+ assert gt_pts2.ndim == pr_pts2.ndim == 4
157
+ assert gt_pts2.shape == pr_pts2.shape
158
+
159
+ # concat the pointcloud
160
+ nan_gt_pts1 = invalid_to_nans(gt_pts1, valid1).flatten(1, 2)
161
+ nan_gt_pts2 = (
162
+ invalid_to_nans(gt_pts2, valid2).flatten(1, 2) if gt_pts2 is not None else None
163
+ )
164
+
165
+ pr_pts1 = invalid_to_nans(pr_pts1, valid1).flatten(1, 2)
166
+ pr_pts2 = (
167
+ invalid_to_nans(pr_pts2, valid2).flatten(1, 2) if pr_pts2 is not None else None
168
+ )
169
+
170
+ all_gt = (
171
+ torch.cat((nan_gt_pts1, nan_gt_pts2), dim=1)
172
+ if gt_pts2 is not None
173
+ else nan_gt_pts1
174
+ )
175
+ all_pr = torch.cat((pr_pts1, pr_pts2), dim=1) if pr_pts2 is not None else pr_pts1
176
+
177
+ dot_gt_pr = (all_pr * all_gt).sum(dim=-1)
178
+ dot_gt_gt = all_gt.square().sum(dim=-1)
179
+
180
+ if fit_mode.startswith("avg"):
181
+ # scaling = (all_pr / all_gt).view(B, -1).mean(dim=1)
182
+ scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
183
+ elif fit_mode.startswith("median"):
184
+ scaling = (dot_gt_pr / dot_gt_gt).nanmedian(dim=1).values
185
+ elif fit_mode.startswith("weiszfeld"):
186
+ # init scaling with l2 closed form
187
+ scaling = dot_gt_pr.nanmean(dim=1) / dot_gt_gt.nanmean(dim=1)
188
+ # iterative re-weighted least-squares
189
+ for iter in range(10):
190
+ # re-weighting by inverse of distance
191
+ dis = (all_pr - scaling.view(-1, 1, 1) * all_gt).norm(dim=-1)
192
+ # print(dis.nanmean(-1))
193
+ w = dis.clip_(min=1e-8).reciprocal()
194
+ # update the scaling with the new weights
195
+ scaling = (w * dot_gt_pr).nanmean(dim=1) / (w * dot_gt_gt).nanmean(dim=1)
196
+ else:
197
+ raise ValueError(f"bad {fit_mode=}")
198
+
199
+ if fit_mode.endswith("stop_grad"):
200
+ scaling = scaling.detach()
201
+
202
+ scaling = scaling.clip(min=1e-3)
203
+ # assert scaling.isfinite().all(), bb()
204
+ return scaling
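Editorial note: a standalone sketch (not part of this commit) of the Weiszfeld-style scale fit used by find_opt_scaling, reduced to a single point cloud with no invalid points.

import torch

gt = torch.randn(1, 1000, 3)
pr = 2.5 * gt + 0.01 * torch.randn(1, 1000, 3)             # predictions ~2.5x the ground truth

dot_gt_pr = (pr * gt).sum(dim=-1)
dot_gt_gt = gt.square().sum(dim=-1)
scaling = dot_gt_pr.mean(dim=1) / dot_gt_gt.mean(dim=1)    # closed-form least-squares init
for _ in range(10):                                        # re-weight by inverse residual
    dis = (pr - scaling.view(-1, 1, 1) * gt).norm(dim=-1)
    w = dis.clip(min=1e-8).reciprocal()
    scaling = (w * dot_gt_pr).mean(dim=1) / (w * dot_gt_gt).mean(dim=1)
print(scaling)                                             # close to tensor([2.5])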
mini_dust3r/model.py ADDED
@@ -0,0 +1,259 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # DUSt3R model class
6
+ # --------------------------------------------------------
7
+ from copy import deepcopy
8
+ import torch
9
+ import os
10
+ from packaging import version
11
+ import huggingface_hub
12
+
13
+ from .utils.misc import (
14
+ fill_default_args,
15
+ freeze_all_params,
16
+ is_symmetrized,
17
+ interleave,
18
+ transpose_to_landscape,
19
+ )
20
+ from .heads import head_factory
21
+ from mini_dust3r.patch_embed import get_patch_embed
22
+
23
+ from mini_dust3r.croco.croco import CroCoNet
24
+
25
+ inf = float("inf")
26
+
27
+ hf_version_number = huggingface_hub.__version__
28
+ assert version.parse(hf_version_number) >= version.parse(
29
+ "0.22.0"
30
+ ), "Outdated huggingface_hub version, please reinstall requirements.txt"
31
+
32
+
33
+ def load_model(model_path, device, verbose=True):
34
+ if verbose:
35
+ print("... loading model from", model_path)
36
+ ckpt = torch.load(model_path, map_location="cpu")
37
+ args = ckpt["args"].model.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
38
+ if "landscape_only" not in args:
39
+ args = args[:-1] + ", landscape_only=False)"
40
+ else:
41
+ args = args.replace(" ", "").replace(
42
+ "landscape_only=True", "landscape_only=False"
43
+ )
44
+ assert "landscape_only=False" in args
45
+ if verbose:
46
+ print(f"instantiating : {args}")
47
+ net = eval(args)
48
+ s = net.load_state_dict(ckpt["model"], strict=False)
49
+ if verbose:
50
+ print(s)
51
+ return net.to(device)
52
+
53
+
54
+ class AsymmetricCroCo3DStereo(
55
+ CroCoNet,
56
+ huggingface_hub.PyTorchModelHubMixin,
57
+ library_name="dust3r",
58
+ repo_url="https://github.com/naver/dust3r",
59
+ tags=["image-to-3d"],
60
+ ):
61
+ """Two siamese encoders, followed by two decoders.
62
+ The goal is to output 3D points directly for both images, expressed in view1's frame
63
+ (hence the asymmetry).
64
+ """
65
+
66
+ def __init__(
67
+ self,
68
+ output_mode="pts3d",
69
+ head_type="linear",
70
+ depth_mode=("exp", -inf, inf),
71
+ conf_mode=("exp", 1, inf),
72
+ freeze="none",
73
+ landscape_only=True,
74
+ patch_embed_cls="PatchEmbedDust3R", # PatchEmbedDust3R or ManyAR_PatchEmbed
75
+ **croco_kwargs,
76
+ ):
77
+ self.patch_embed_cls = patch_embed_cls
78
+ self.croco_args = fill_default_args(croco_kwargs, super().__init__)
79
+ super().__init__(**croco_kwargs)
80
+
81
+ # dust3r specific initialization
82
+ self.dec_blocks2 = deepcopy(self.dec_blocks)
83
+ self.set_downstream_head(
84
+ output_mode,
85
+ head_type,
86
+ landscape_only,
87
+ depth_mode,
88
+ conf_mode,
89
+ **croco_kwargs,
90
+ )
91
+ self.set_freeze(freeze)
92
+
93
+ @classmethod
94
+ def from_pretrained(cls, pretrained_model_name_or_path, **kw):
95
+ if os.path.isfile(pretrained_model_name_or_path):
96
+ return load_model(pretrained_model_name_or_path, device="cpu")
97
+ else:
98
+ return super(AsymmetricCroCo3DStereo, cls).from_pretrained(
99
+ pretrained_model_name_or_path, **kw
100
+ )
101
+
102
+ def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
103
+ self.patch_embed = get_patch_embed(
104
+ self.patch_embed_cls, img_size, patch_size, enc_embed_dim
105
+ )
106
+
107
+ def load_state_dict(self, ckpt, **kw):
108
+ # duplicate all weights for the second decoder if not present
109
+ new_ckpt = dict(ckpt)
110
+ if not any(k.startswith("dec_blocks2") for k in ckpt):
111
+ for key, value in ckpt.items():
112
+ if key.startswith("dec_blocks"):
113
+ new_ckpt[key.replace("dec_blocks", "dec_blocks2")] = value
114
+ return super().load_state_dict(new_ckpt, **kw)
115
+
116
+ def set_freeze(self, freeze): # this is for use by downstream models
117
+ self.freeze = freeze
118
+ to_be_frozen = {
119
+ "none": [],
120
+ "mask": [self.mask_token],
121
+ "encoder": [self.mask_token, self.patch_embed, self.enc_blocks],
122
+ }
123
+ freeze_all_params(to_be_frozen[freeze])
124
+
125
+ def _set_prediction_head(self, *args, **kwargs):
126
+ """No prediction head"""
127
+ return
128
+
129
+ def set_downstream_head(
130
+ self,
131
+ output_mode,
132
+ head_type,
133
+ landscape_only,
134
+ depth_mode,
135
+ conf_mode,
136
+ patch_size,
137
+ img_size,
138
+ **kw,
139
+ ):
140
+ assert (
141
+ img_size[0] % patch_size == 0 and img_size[1] % patch_size == 0
142
+ ), f"{img_size=} must be multiple of {patch_size=}"
143
+ self.output_mode = output_mode
144
+ self.head_type = head_type
145
+ self.depth_mode = depth_mode
146
+ self.conf_mode = conf_mode
147
+ # allocate heads
148
+ self.downstream_head1 = head_factory(
149
+ head_type, output_mode, self, has_conf=bool(conf_mode)
150
+ )
151
+ self.downstream_head2 = head_factory(
152
+ head_type, output_mode, self, has_conf=bool(conf_mode)
153
+ )
154
+ # magic wrapper
155
+ self.head1 = transpose_to_landscape(
156
+ self.downstream_head1, activate=landscape_only
157
+ )
158
+ self.head2 = transpose_to_landscape(
159
+ self.downstream_head2, activate=landscape_only
160
+ )
161
+
162
+ def _encode_image(self, image, true_shape):
163
+ # embed the image into patches (x has size B x Npatches x C)
164
+ x, pos = self.patch_embed(image, true_shape=true_shape)
165
+
166
+ # add positional embedding without cls token
167
+ assert self.enc_pos_embed is None
168
+
169
+ # now apply the transformer encoder and normalization
170
+ for blk in self.enc_blocks:
171
+ x = blk(x, pos)
172
+
173
+ x = self.enc_norm(x)
174
+ return x, pos, None
175
+
176
+ def _encode_image_pairs(self, img1, img2, true_shape1, true_shape2):
177
+ if img1.shape[-2:] == img2.shape[-2:]:
178
+ out, pos, _ = self._encode_image(
179
+ torch.cat((img1, img2), dim=0),
180
+ torch.cat((true_shape1, true_shape2), dim=0),
181
+ )
182
+ out, out2 = out.chunk(2, dim=0)
183
+ pos, pos2 = pos.chunk(2, dim=0)
184
+ else:
185
+ out, pos, _ = self._encode_image(img1, true_shape1)
186
+ out2, pos2, _ = self._encode_image(img2, true_shape2)
187
+ return out, out2, pos, pos2
188
+
189
+ def _encode_symmetrized(self, view1, view2):
190
+ img1 = view1["img"]
191
+ img2 = view2["img"]
192
+ B = img1.shape[0]
193
+ # Recover true_shape when available, otherwise assume that the img shape is the true one
194
+ shape1 = view1.get(
195
+ "true_shape", torch.tensor(img1.shape[-2:])[None].repeat(B, 1)
196
+ )
197
+ shape2 = view2.get(
198
+ "true_shape", torch.tensor(img2.shape[-2:])[None].repeat(B, 1)
199
+ )
200
+ # warning! maybe the images have different portrait/landscape orientations
201
+
202
+ if is_symmetrized(view1, view2):
203
+ # computing half of forward pass!
204
+ feat1, feat2, pos1, pos2 = self._encode_image_pairs(
205
+ img1[::2], img2[::2], shape1[::2], shape2[::2]
206
+ )
207
+ feat1, feat2 = interleave(feat1, feat2)
208
+ pos1, pos2 = interleave(pos1, pos2)
209
+ else:
210
+ feat1, feat2, pos1, pos2 = self._encode_image_pairs(
211
+ img1, img2, shape1, shape2
212
+ )
213
+
214
+ return (shape1, shape2), (feat1, feat2), (pos1, pos2)
215
+
216
+ def _decoder(self, f1, pos1, f2, pos2):
217
+ final_output = [(f1, f2)] # before projection
218
+
219
+ # project to decoder dim
220
+ f1 = self.decoder_embed(f1)
221
+ f2 = self.decoder_embed(f2)
222
+
223
+ final_output.append((f1, f2))
224
+ for blk1, blk2 in zip(self.dec_blocks, self.dec_blocks2):
225
+ # img1 side
226
+ f1, _ = blk1(*final_output[-1][::+1], pos1, pos2)
227
+ # img2 side
228
+ f2, _ = blk2(*final_output[-1][::-1], pos2, pos1)
229
+ # store the result
230
+ final_output.append((f1, f2))
231
+
232
+ # normalize last output
233
+ del final_output[1] # duplicate with final_output[0]
234
+ final_output[-1] = tuple(map(self.dec_norm, final_output[-1]))
235
+ return zip(*final_output)
236
+
237
+ def _downstream_head(self, head_num, decout, img_shape):
238
+ B, S, D = decout[-1].shape
239
+ # img_shape = tuple(map(int, img_shape))
240
+ head = getattr(self, f"head{head_num}")
241
+ return head(decout, img_shape)
242
+
243
+ def forward(self, view1, view2):
244
+ # encode the two images --> B,S,D
245
+ (shape1, shape2), (feat1, feat2), (pos1, pos2) = self._encode_symmetrized(
246
+ view1, view2
247
+ )
248
+
249
+ # combine all ref images into object-centric representation
250
+ dec1, dec2 = self._decoder(feat1, pos1, feat2, pos2)
251
+
252
+ with torch.cuda.amp.autocast(enabled=False):
253
+ res1 = self._downstream_head(1, [tok.float() for tok in dec1], shape1)
254
+ res2 = self._downstream_head(2, [tok.float() for tok in dec2], shape2)
255
+
256
+ res2["pts3d_in_other_view"] = res2.pop(
257
+ "pts3d"
258
+ ) # predict view2's pts3d in view1's frame
259
+ return res1, res2
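Editorial note: a standalone sketch (not part of this commit) of the checkpoint-key duplication performed by load_state_dict above, using a toy dict in place of a real state dict.

ckpt = {"dec_blocks.0.attn.weight": 1, "enc_blocks.0.attn.weight": 2}
new_ckpt = dict(ckpt)
if not any(k.startswith("dec_blocks2") for k in ckpt):
    for key, value in ckpt.items():
        if key.startswith("dec_blocks"):
            new_ckpt[key.replace("dec_blocks", "dec_blocks2")] = value
print(sorted(new_ckpt))
# ['dec_blocks.0.attn.weight', 'dec_blocks2.0.attn.weight', 'enc_blocks.0.attn.weight']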
mini_dust3r/optim_factory.py ADDED
@@ -0,0 +1,14 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # optimization functions
6
+ # --------------------------------------------------------
7
+
8
+
9
+ def adjust_learning_rate_by_lr(optimizer, lr):
10
+ for param_group in optimizer.param_groups:
11
+ if "lr_scale" in param_group:
12
+ param_group["lr"] = lr * param_group["lr_scale"]
13
+ else:
14
+ param_group["lr"] = lr
mini_dust3r/patch_embed.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # PatchEmbed implementation for DUST3R,
6
+ # in particular ManyAR_PatchEmbed, which handles images with a non-square aspect ratio
7
+ # --------------------------------------------------------
8
+ import torch
9
+ from mini_dust3r.croco.blocks import PatchEmbed
10
+
11
+
12
+ def get_patch_embed(patch_embed_cls, img_size, patch_size, enc_embed_dim):
13
+ assert patch_embed_cls in ['PatchEmbedDust3R', 'ManyAR_PatchEmbed']
14
+ patch_embed = eval(patch_embed_cls)(img_size, patch_size, 3, enc_embed_dim)
15
+ return patch_embed
16
+
17
+
18
+ class PatchEmbedDust3R(PatchEmbed):
19
+ def forward(self, x, **kw):
20
+ B, C, H, W = x.shape
21
+ assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
22
+ assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
23
+ x = self.proj(x)
24
+ pos = self.position_getter(B, x.size(2), x.size(3), x.device)
25
+ if self.flatten:
26
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
27
+ x = self.norm(x)
28
+ return x, pos
29
+
30
+
31
+ class ManyAR_PatchEmbed (PatchEmbed):
32
+ """ Handle images with non-square aspect ratio.
33
+ All images in the same batch have the same aspect ratio.
34
+ true_shape = [(height, width) ...] indicates the actual shape of each image.
35
+ """
36
+
37
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
38
+ self.embed_dim = embed_dim
39
+ super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten)
40
+
41
+ def forward(self, img, true_shape):
42
+ B, C, H, W = img.shape
43
+ assert W >= H, f'img should be in landscape mode, but got {W=} {H=}'
44
+ assert H % self.patch_size[0] == 0, f"Input image height ({H}) is not a multiple of patch size ({self.patch_size[0]})."
45
+ assert W % self.patch_size[1] == 0, f"Input image width ({W}) is not a multiple of patch size ({self.patch_size[1]})."
46
+ assert true_shape.shape == (B, 2), f"true_shape has the wrong shape={true_shape.shape}"
47
+
48
+ # size expressed in tokens
49
+ W //= self.patch_size[0]
50
+ H //= self.patch_size[1]
51
+ n_tokens = H * W
52
+
53
+ height, width = true_shape.T
54
+ is_landscape = (width >= height)
55
+ is_portrait = ~is_landscape
56
+
57
+ # allocate result
58
+ x = img.new_zeros((B, n_tokens, self.embed_dim))
59
+ pos = img.new_zeros((B, n_tokens, 2), dtype=torch.int64)
60
+
61
+ # linear projection, transposed if necessary
62
+ x[is_landscape] = self.proj(img[is_landscape]).permute(0, 2, 3, 1).flatten(1, 2).float()
63
+ x[is_portrait] = self.proj(img[is_portrait].swapaxes(-1, -2)).permute(0, 2, 3, 1).flatten(1, 2).float()
64
+
65
+ pos[is_landscape] = self.position_getter(1, H, W, pos.device)
66
+ pos[is_portrait] = self.position_getter(1, W, H, pos.device)
67
+
68
+ x = self.norm(x)
69
+ return x, pos
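A hedged sketch of the ManyAR_PatchEmbed calling convention, assuming the vendored mini_dust3r package (and its CroCo blocks) is importable: the batch is stored in landscape layout, and true_shape tells the embedder which samples are really portraits.

import torch
from mini_dust3r.patch_embed import get_patch_embed

embed = get_patch_embed("ManyAR_PatchEmbed", img_size=512, patch_size=16, enc_embed_dim=768)

img = torch.randn(2, 3, 384, 512)           # batch stored in landscape layout (W >= H)
true_shape = torch.tensor([[384, 512],       # first sample really is landscape
                           [512, 384]])      # second one is a portrait stored transposed
tokens, pos = embed(img, true_shape)
print(tokens.shape, pos.shape)               # (2, 768, 768): 24*32 tokens of dim 768; pos is (2, 768, 2)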
mini_dust3r/post_process.py ADDED
@@ -0,0 +1,60 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utilities for interpreting the DUST3R output
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ import torch
9
+ from mini_dust3r.utils.geometry import xy_grid
10
+
11
+
12
+ def estimate_focal_knowing_depth(pts3d, pp, focal_mode='median', min_focal=0., max_focal=np.inf):
13
+ """ Reprojection method, for when the absolute depth is known:
14
+ 1) estimate the camera focal using a robust estimator
15
+ 2) reproject points onto true rays, minimizing a certain error
16
+ """
17
+ B, H, W, THREE = pts3d.shape
18
+ assert THREE == 3
19
+
20
+ # centered pixel grid
21
+ pixels = xy_grid(W, H, device=pts3d.device).view(1, -1, 2) - pp.view(-1, 1, 2) # B,HW,2
22
+ pts3d = pts3d.flatten(1, 2) # (B, HW, 3)
23
+
24
+ if focal_mode == 'median':
25
+ with torch.no_grad():
26
+ # direct estimation of focal
27
+ u, v = pixels.unbind(dim=-1)
28
+ x, y, z = pts3d.unbind(dim=-1)
29
+ fx_votes = (u * z) / x
30
+ fy_votes = (v * z) / y
31
+
32
+ # assume square pixels, hence same focal for X and Y
33
+ f_votes = torch.cat((fx_votes.view(B, -1), fy_votes.view(B, -1)), dim=-1)
34
+ focal = torch.nanmedian(f_votes, dim=-1).values
35
+
36
+ elif focal_mode == 'weiszfeld':
37
+ # init focal with l2 closed form
38
+ # we try to find focal = argmin Sum | pixel - focal * (x,y)/z|
39
+ xy_over_z = (pts3d[..., :2] / pts3d[..., 2:3]).nan_to_num(posinf=0, neginf=0) # homogeneous (x,y,1)
40
+
41
+ dot_xy_px = (xy_over_z * pixels).sum(dim=-1)
42
+ dot_xy_xy = xy_over_z.square().sum(dim=-1)
43
+
44
+ focal = dot_xy_px.mean(dim=1) / dot_xy_xy.mean(dim=1)
45
+
46
+ # iterative re-weighted least-squares
47
+ for iter in range(10):
48
+ # re-weighting by inverse of distance
49
+ dis = (pixels - focal.view(-1, 1, 1) * xy_over_z).norm(dim=-1)
50
+ # print(dis.nanmean(-1))
51
+ w = dis.clip(min=1e-8).reciprocal()
52
+ # update the scaling with the new weights
53
+ focal = (w * dot_xy_px).mean(dim=1) / (w * dot_xy_xy).mean(dim=1)
54
+ else:
55
+ raise ValueError(f'bad {focal_mode=}')
56
+
57
+ focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2)) # size / 1.1547005383792515
58
+ focal = focal.clip(min=min_focal*focal_base, max=max_focal*focal_base)
59
+ # print(focal)
60
+ return focal
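To sanity-check the focal recovery above, here is a hedged sketch that builds a synthetic fronto-parallel point map from a known focal and asks estimate_focal_knowing_depth to recover it; all values are illustrative.

import torch
from mini_dust3r.post_process import estimate_focal_knowing_depth
from mini_dust3r.utils.geometry import xy_grid

H, W, true_focal = 384, 512, 400.0
pp = torch.tensor([[W / 2, H / 2]])                 # principal point at the image centre
grid = xy_grid(W, H, device="cpu").float()          # (H, W, 2) pixel coordinates
z = torch.full((H, W, 1), 2.0)                      # constant 2 m depth
xy = (grid - pp.view(1, 1, 2)) * z / true_focal     # back-project through the known focal
pts3d = torch.cat([xy, z], dim=-1)[None]            # (1, H, W, 3) point map

focal = estimate_focal_knowing_depth(pts3d, pp, focal_mode="weiszfeld")
print(focal)                                        # ~tensor([400.])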
mini_dust3r/utils/device.py ADDED
@@ -0,0 +1,76 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utility functions for DUSt3R
6
+ # --------------------------------------------------------
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
+ def todevice(batch, device, callback=None, non_blocking=False):
12
+ ''' Transfer some variables to another device (i.e. GPU, CPU:torch, CPU:numpy).
13
+
14
+ batch: list, tuple, dict of tensors or other things
15
+ device: pytorch device or 'numpy'
16
+ callback: function that would be called on every sub-elements.
17
+ '''
18
+ if callback:
19
+ batch = callback(batch)
20
+
21
+ if isinstance(batch, dict):
22
+ return {k: todevice(v, device) for k, v in batch.items()}
23
+
24
+ if isinstance(batch, (tuple, list)):
25
+ return type(batch)(todevice(x, device) for x in batch)
26
+
27
+ x = batch
28
+ if device == 'numpy':
29
+ if isinstance(x, torch.Tensor):
30
+ x = x.detach().cpu().numpy()
31
+ elif x is not None:
32
+ if isinstance(x, np.ndarray):
33
+ x = torch.from_numpy(x)
34
+ if torch.is_tensor(x):
35
+ x = x.to(device, non_blocking=non_blocking)
36
+ return x
37
+
38
+
39
+ to_device = todevice # alias
40
+
41
+
42
+ def to_numpy(x): return todevice(x, 'numpy')
43
+ def to_cpu(x): return todevice(x, 'cpu')
44
+ def to_cuda(x): return todevice(x, 'cuda')
45
+
46
+
47
+ def collate_with_cat(whatever, lists=False):
48
+ if isinstance(whatever, dict):
49
+ return {k: collate_with_cat(vals, lists=lists) for k, vals in whatever.items()}
50
+
51
+ elif isinstance(whatever, (tuple, list)):
52
+ if len(whatever) == 0:
53
+ return whatever
54
+ elem = whatever[0]
55
+ T = type(whatever)
56
+
57
+ if elem is None:
58
+ return None
59
+ if isinstance(elem, (bool, float, int, str)):
60
+ return whatever
61
+ if isinstance(elem, tuple):
62
+ return T(collate_with_cat(x, lists=lists) for x in zip(*whatever))
63
+ if isinstance(elem, dict):
64
+ return {k: collate_with_cat([e[k] for e in whatever], lists=lists) for k in elem}
65
+
66
+ if isinstance(elem, torch.Tensor):
67
+ return listify(whatever) if lists else torch.cat(whatever)
68
+ if isinstance(elem, np.ndarray):
69
+ return listify(whatever) if lists else torch.cat([torch.from_numpy(x) for x in whatever])
70
+
71
+ # otherwise, we just chain lists
72
+ return sum(whatever, T())
73
+
74
+
75
+ def listify(elems):
76
+ return [x for e in elems for x in e]
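A minimal sketch of these helpers: todevice/to_numpy walk nested containers, and collate_with_cat merges a list of dicts by concatenating tensors along dim 0.

import numpy as np
import torch
from mini_dust3r.utils.device import to_numpy, collate_with_cat

view = {"img": torch.rand(1, 3, 8, 8), "idx": 0, "pose": np.eye(4, dtype=np.float32)}

as_numpy = to_numpy(view)                 # tensors become np.ndarray, ints/arrays pass through
print(type(as_numpy["img"]), as_numpy["idx"])

pair = collate_with_cat([view, view])     # dicts merged key-wise, tensors cat'd on dim 0
print(pair["img"].shape)                  # torch.Size([2, 3, 8, 8])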
mini_dust3r/utils/geometry.py ADDED
@@ -0,0 +1,361 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # geometry utility functions
6
+ # --------------------------------------------------------
7
+ import torch
8
+ import numpy as np
9
+ from scipy.spatial import cKDTree as KDTree
10
+
11
+ from mini_dust3r.utils.misc import invalid_to_zeros, invalid_to_nans
12
+ from mini_dust3r.utils.device import to_numpy
13
+
14
+
15
+ def xy_grid(W, H, device=None, origin=(0, 0), unsqueeze=None, cat_dim=-1, homogeneous=False, **arange_kw):
16
+ """ Output a (H,W,2) array of int32
17
+ with output[j,i,0] = i + origin[0]
18
+ output[j,i,1] = j + origin[1]
19
+ """
20
+ if device is None:
21
+ # numpy
22
+ arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
23
+ else:
24
+ # torch
25
+ arange = lambda *a, **kw: torch.arange(*a, device=device, **kw)
26
+ meshgrid, stack = torch.meshgrid, torch.stack
27
+ ones = lambda *a: torch.ones(*a, device=device)
28
+
29
+ tw, th = [arange(o, o+s, **arange_kw) for s, o in zip((W, H), origin)]
30
+ grid = meshgrid(tw, th, indexing='xy')
31
+ if homogeneous:
32
+ grid = grid + (ones((H, W)),)
33
+ if unsqueeze is not None:
34
+ grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))
35
+ if cat_dim is not None:
36
+ grid = stack(grid, cat_dim)
37
+ return grid
38
+
39
+
40
+ def geotrf(Trf, pts, ncol=None, norm=False):
41
+ """ Apply a geometric transformation to a list of 3-D points.
42
+
43
+ Trf: 3x3 or 4x4 projection matrix (typically a Homography)
44
+ pts: numpy/torch/tuple of coordinates. Shape must be (...,2) or (...,3)
45
+
46
+ ncol: int. number of columns of the result (2 or 3)
47
+ norm: float. if != 0, the result is projected on the z=norm plane.
48
+
49
+ Returns an array of projected 2d points.
50
+ """
51
+ assert Trf.ndim >= 2
52
+ if isinstance(Trf, np.ndarray):
53
+ pts = np.asarray(pts)
54
+ elif isinstance(Trf, torch.Tensor):
55
+ pts = torch.as_tensor(pts, dtype=Trf.dtype)
56
+
57
+ # adapt shape if necessary
58
+ output_reshape = pts.shape[:-1]
59
+ ncol = ncol or pts.shape[-1]
60
+
61
+ # optimized code
62
+ if (isinstance(Trf, torch.Tensor) and isinstance(pts, torch.Tensor) and
63
+ Trf.ndim == 3 and pts.ndim == 4):
64
+ d = pts.shape[3]
65
+ if Trf.shape[-1] == d:
66
+ pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
67
+ elif Trf.shape[-1] == d+1:
68
+ pts = torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts) + Trf[:, None, None, :d, d]
69
+ else:
70
+ raise ValueError(f'bad shape, not ending with 3 or 4, for {pts.shape=}')
71
+ else:
72
+ if Trf.ndim >= 3:
73
+ n = Trf.ndim-2
74
+ assert Trf.shape[:n] == pts.shape[:n], 'batch size does not match'
75
+ Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
76
+
77
+ if pts.ndim > Trf.ndim:
78
+ # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
79
+ pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
80
+ elif pts.ndim == 2:
81
+ # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
82
+ pts = pts[:, None, :]
83
+
84
+ if pts.shape[-1]+1 == Trf.shape[-1]:
85
+ Trf = Trf.swapaxes(-1, -2) # transpose Trf
86
+ pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
87
+ elif pts.shape[-1] == Trf.shape[-1]:
88
+ Trf = Trf.swapaxes(-1, -2) # transpose Trf
89
+ pts = pts @ Trf
90
+ else:
91
+ pts = Trf @ pts.T
92
+ if pts.ndim >= 2:
93
+ pts = pts.swapaxes(-1, -2)
94
+
95
+ if norm:
96
+ pts = pts / pts[..., -1:] # DONT DO /= BECAUSE OF WEIRD PYTORCH BUG
97
+ if norm != 1:
98
+ pts *= norm
99
+
100
+ res = pts[..., :ncol].reshape(*output_reshape, ncol)
101
+ return res
102
+
103
+
104
+ def inv(mat):
105
+ """ Invert a torch or numpy matrix
106
+ """
107
+ if isinstance(mat, torch.Tensor):
108
+ return torch.linalg.inv(mat)
109
+ if isinstance(mat, np.ndarray):
110
+ return np.linalg.inv(mat)
111
+ raise ValueError(f'bad matrix type = {type(mat)}')
112
+
113
+
114
+ def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
115
+ """
116
+ Args:
117
+ - depthmap (BxHxW array):
118
+ - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W]
119
+ Returns:
120
+ pointmap of absolute coordinates (BxHxWx3 array)
121
+ """
122
+
123
+ if len(depth.shape) == 4:
124
+ B, H, W, n = depth.shape
125
+ else:
126
+ B, H, W = depth.shape
127
+ n = None
128
+
129
+ if len(pseudo_focal.shape) == 3: # [B,H,W]
130
+ pseudo_focalx = pseudo_focaly = pseudo_focal
131
+ elif len(pseudo_focal.shape) == 4: # [B,2,H,W] or [B,1,H,W]
132
+ pseudo_focalx = pseudo_focal[:, 0]
133
+ if pseudo_focal.shape[1] == 2:
134
+ pseudo_focaly = pseudo_focal[:, 1]
135
+ else:
136
+ pseudo_focaly = pseudo_focalx
137
+ else:
138
+ raise NotImplementedError("Error, unknown input focal shape format.")
139
+
140
+ assert pseudo_focalx.shape == depth.shape[:3]
141
+ assert pseudo_focaly.shape == depth.shape[:3]
142
+ grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]
143
+
144
+ # set principal point
145
+ if pp is None:
146
+ grid_x = grid_x - (W-1)/2
147
+ grid_y = grid_y - (H-1)/2
148
+ else:
149
+ grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]
150
+ grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]
151
+
152
+ if n is None:
153
+ pts3d = torch.empty((B, H, W, 3), device=depth.device)
154
+ pts3d[..., 0] = depth * grid_x / pseudo_focalx
155
+ pts3d[..., 1] = depth * grid_y / pseudo_focaly
156
+ pts3d[..., 2] = depth
157
+ else:
158
+ pts3d = torch.empty((B, H, W, 3, n), device=depth.device)
159
+ pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None]
160
+ pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None]
161
+ pts3d[..., 2, :] = depth
162
+ return pts3d
163
+
164
+
165
+ def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
166
+ """
167
+ Args:
168
+ - depthmap (HxW array):
169
+ - camera_intrinsics: a 3x3 matrix
170
+ Returns:
171
+ pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
172
+ """
173
+ camera_intrinsics = np.float32(camera_intrinsics)
174
+ H, W = depthmap.shape
175
+
176
+ # Compute 3D ray associated with each pixel
177
+ # Strong assumption: there are no skew terms
178
+ assert camera_intrinsics[0, 1] == 0.0
179
+ assert camera_intrinsics[1, 0] == 0.0
180
+ if pseudo_focal is None:
181
+ fu = camera_intrinsics[0, 0]
182
+ fv = camera_intrinsics[1, 1]
183
+ else:
184
+ assert pseudo_focal.shape == (H, W)
185
+ fu = fv = pseudo_focal
186
+ cu = camera_intrinsics[0, 2]
187
+ cv = camera_intrinsics[1, 2]
188
+
189
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
190
+ z_cam = depthmap
191
+ x_cam = (u - cu) * z_cam / fu
192
+ y_cam = (v - cv) * z_cam / fv
193
+ X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
194
+
195
+ # Mask for valid coordinates
196
+ valid_mask = (depthmap > 0.0)
197
+ return X_cam, valid_mask
198
+
199
+
200
+ def depthmap_to_absolute_camera_coordinates(depthmap, camera_intrinsics, camera_pose, **kw):
201
+ """
202
+ Args:
203
+ - depthmap (HxW array):
204
+ - camera_intrinsics: a 3x3 matrix
205
+ - camera_pose: a 4x3 or 4x4 cam2world matrix
206
+ Returns:
207
+ pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels."""
208
+ X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
209
+
210
+ # R_cam2world = np.float32(camera_params["R_cam2world"])
211
+ # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze()
212
+ R_cam2world = camera_pose[:3, :3]
213
+ t_cam2world = camera_pose[:3, 3]
214
+
215
+ # Express in absolute coordinates (invalid depth values)
216
+ X_world = np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
217
+ return X_world, valid_mask
218
+
219
+
220
+ def colmap_to_opencv_intrinsics(K):
221
+ """
222
+ Modify camera intrinsics to follow a different convention.
223
+ Coordinates of the center of the top-left pixels are by default:
224
+ - (0.5, 0.5) in Colmap
225
+ - (0,0) in OpenCV
226
+ """
227
+ K = K.copy()
228
+ K[0, 2] -= 0.5
229
+ K[1, 2] -= 0.5
230
+ return K
231
+
232
+
233
+ def opencv_to_colmap_intrinsics(K):
234
+ """
235
+ Modify camera intrinsics to follow a different convention.
236
+ Coordinates of the center of the top-left pixels are by default:
237
+ - (0.5, 0.5) in Colmap
238
+ - (0,0) in OpenCV
239
+ """
240
+ K = K.copy()
241
+ K[0, 2] += 0.5
242
+ K[1, 2] += 0.5
243
+ return K
244
+
245
+
246
+ def normalize_pointcloud(pts1, pts2, norm_mode='avg_dis', valid1=None, valid2=None):
247
+ """ renorm pointmaps pts1, pts2 with norm_mode
248
+ """
249
+ assert pts1.ndim >= 3 and pts1.shape[-1] == 3
250
+ assert pts2 is None or (pts2.ndim >= 3 and pts2.shape[-1] == 3)
251
+ norm_mode, dis_mode = norm_mode.split('_')
252
+
253
+ if norm_mode == 'avg':
254
+ # gather all points together (joint normalization)
255
+ nan_pts1, nnz1 = invalid_to_zeros(pts1, valid1, ndim=3)
256
+ nan_pts2, nnz2 = invalid_to_zeros(pts2, valid2, ndim=3) if pts2 is not None else (None, 0)
257
+ all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
258
+
259
+ # compute distance to origin
260
+ all_dis = all_pts.norm(dim=-1)
261
+ if dis_mode == 'dis':
262
+ pass # do nothing
263
+ elif dis_mode == 'log1p':
264
+ all_dis = torch.log1p(all_dis)
265
+ elif dis_mode == 'warp-log1p':
266
+ # actually warp input points before normalizing them
267
+ log_dis = torch.log1p(all_dis)
268
+ warp_factor = log_dis / all_dis.clip(min=1e-8)
269
+ H1, W1 = pts1.shape[1:-1]
270
+ pts1 = pts1 * warp_factor[:, :W1*H1].view(-1, H1, W1, 1)
271
+ if pts2 is not None:
272
+ H2, W2 = pts2.shape[1:-1]
273
+ pts2 = pts2 * warp_factor[:, W1*H1:].view(-1, H2, W2, 1)
274
+ all_dis = log_dis # this is their true distance afterwards
275
+ else:
276
+ raise ValueError(f'bad {dis_mode=}')
277
+
278
+ norm_factor = all_dis.sum(dim=1) / (nnz1 + nnz2 + 1e-8)
279
+ else:
280
+ # gather all points together (joint normalization)
281
+ nan_pts1 = invalid_to_nans(pts1, valid1, ndim=3)
282
+ nan_pts2 = invalid_to_nans(pts2, valid2, ndim=3) if pts2 is not None else None
283
+ all_pts = torch.cat((nan_pts1, nan_pts2), dim=1) if pts2 is not None else nan_pts1
284
+
285
+ # compute distance to origin
286
+ all_dis = all_pts.norm(dim=-1)
287
+
288
+ if norm_mode == 'avg':
289
+ norm_factor = all_dis.nanmean(dim=1)
290
+ elif norm_mode == 'median':
291
+ norm_factor = all_dis.nanmedian(dim=1).values.detach()
292
+ elif norm_mode == 'sqrt':
293
+ norm_factor = all_dis.sqrt().nanmean(dim=1)**2
294
+ else:
295
+ raise ValueError(f'bad {norm_mode=}')
296
+
297
+ norm_factor = norm_factor.clip(min=1e-8)
298
+ while norm_factor.ndim < pts1.ndim:
299
+ norm_factor.unsqueeze_(-1)
300
+
301
+ res = pts1 / norm_factor
302
+ if pts2 is not None:
303
+ res = (res, pts2 / norm_factor)
304
+ return res
305
+
306
+
307
+ @torch.no_grad()
308
+ def get_joint_pointcloud_depth(z1, z2, valid_mask1, valid_mask2=None, quantile=0.5):
309
+ # set invalid points to NaN
310
+ _z1 = invalid_to_nans(z1, valid_mask1).reshape(len(z1), -1)
311
+ _z2 = invalid_to_nans(z2, valid_mask2).reshape(len(z2), -1) if z2 is not None else None
312
+ _z = torch.cat((_z1, _z2), dim=-1) if z2 is not None else _z1
313
+
314
+ # compute median depth overall (ignoring nans)
315
+ if quantile == 0.5:
316
+ shift_z = torch.nanmedian(_z, dim=-1).values
317
+ else:
318
+ shift_z = torch.nanquantile(_z, quantile, dim=-1)
319
+ return shift_z # (B,)
320
+
321
+
322
+ @torch.no_grad()
323
+ def get_joint_pointcloud_center_scale(pts1, pts2, valid_mask1=None, valid_mask2=None, z_only=False, center=True):
324
+ # set invalid points to NaN
325
+ _pts1 = invalid_to_nans(pts1, valid_mask1).reshape(len(pts1), -1, 3)
326
+ _pts2 = invalid_to_nans(pts2, valid_mask2).reshape(len(pts2), -1, 3) if pts2 is not None else None
327
+ _pts = torch.cat((_pts1, _pts2), dim=1) if pts2 is not None else _pts1
328
+
329
+ # compute median center
330
+ _center = torch.nanmedian(_pts, dim=1, keepdim=True).values # (B,1,3)
331
+ if z_only:
332
+ _center[..., :2] = 0 # do not center X and Y
333
+
334
+ # compute median norm
335
+ _norm = ((_pts - _center) if center else _pts).norm(dim=-1)
336
+ scale = torch.nanmedian(_norm, dim=1).values
337
+ return _center[:, None, :, :], scale[:, None, None, None]
338
+
339
+
340
+ def find_reciprocal_matches(P1, P2):
341
+ """
342
+ returns 3 values:
343
+ 1 - reciprocal_in_P2: a boolean array of size P2.shape[0], a "True" value indicates a match
344
+ 2 - nn2_in_P1: a int array of size P2.shape[0], it contains the indexes of the closest points in P1
345
+ 3 - reciprocal_in_P2.sum(): the number of matches
346
+ """
347
+ tree1 = KDTree(P1)
348
+ tree2 = KDTree(P2)
349
+
350
+ _, nn1_in_P2 = tree2.query(P1, workers=8)
351
+ _, nn2_in_P1 = tree1.query(P2, workers=8)
352
+
353
+ reciprocal_in_P1 = (nn2_in_P1[nn1_in_P2] == np.arange(len(nn1_in_P2)))
354
+ reciprocal_in_P2 = (nn1_in_P2[nn2_in_P1] == np.arange(len(nn2_in_P1)))
355
+ assert reciprocal_in_P1.sum() == reciprocal_in_P2.sum()
356
+ return reciprocal_in_P2, nn2_in_P1, reciprocal_in_P2.sum()
357
+
358
+
359
+ def get_med_dist_between_poses(poses):
360
+ from scipy.spatial.distance import pdist
361
+ return np.median(pdist([to_numpy(p[:3, 3]) for p in poses]))
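A hedged sketch tying a few of these helpers together: lift a depth map to camera coordinates, then move it to world coordinates with geotrf, mirroring what depthmap_to_absolute_camera_coordinates does; intrinsics and pose values are illustrative.

import numpy as np
from mini_dust3r.utils.geometry import depthmap_to_camera_coordinates, geotrf, inv

H, W = 48, 64
K = np.array([[100.0, 0.0, W / 2],
              [0.0, 100.0, H / 2],
              [0.0, 0.0, 1.0]])
depth = np.full((H, W), 3.0, dtype=np.float32)
depth[:4] = 0.0                                  # a few invalid rows

X_cam, valid = depthmap_to_camera_coordinates(depth, K)
print(X_cam.shape, valid.sum())                  # (48, 64, 3), 2816 valid pixels

cam2world = np.eye(4)
cam2world[:3, 3] = (0.0, 0.0, 1.0)               # camera one metre along +Z
X_world = geotrf(cam2world, X_cam)               # same point map expressed in world coordinates
assert np.allclose(geotrf(inv(cam2world), X_world), X_cam, atol=1e-5)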
mini_dust3r/utils/image.py ADDED
@@ -0,0 +1,141 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utility functions about images (loading/converting...)
6
+ # --------------------------------------------------------
7
+ import os
8
+ import torch
9
+ import numpy as np
10
+ import PIL.Image
11
+ from PIL.ImageOps import exif_transpose
12
+ import torchvision.transforms as tvf
13
+
14
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
15
+ import cv2 # noqa
16
+ from typing import Literal, TypedDict
17
+ from jaxtyping import Float32, Int32
18
+
19
+ try:
20
+ from pillow_heif import register_heif_opener # noqa
21
+
22
+ register_heif_opener()
23
+ heif_support_enabled = True
24
+ except ImportError:
25
+ heif_support_enabled = False
26
+
27
+ ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
28
+
29
+
30
+ class ImageDict(TypedDict):
31
+ img: Float32[torch.Tensor, "b c h w"]
32
+ true_shape: tuple[int, int] | Int32[torch.Tensor, "b 2"]
33
+ idx: int | list[int]
34
+ instance: str | list[str]
35
+
36
+
37
+ def imread_cv2(path, options=cv2.IMREAD_COLOR):
38
+ """Open an image or a depthmap with opencv-python."""
39
+ if path.endswith((".exr", "EXR")):
40
+ options = cv2.IMREAD_ANYDEPTH
41
+ img = cv2.imread(path, options)
42
+ if img is None:
43
+ raise IOError(f"Could not load image={path} with {options=}")
44
+ if img.ndim == 3:
45
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
46
+ return img
47
+
48
+
49
+ def rgb(ftensor, true_shape=None):
50
+ if isinstance(ftensor, list):
51
+ return [rgb(x, true_shape=true_shape) for x in ftensor]
52
+ if isinstance(ftensor, torch.Tensor):
53
+ ftensor = ftensor.detach().cpu().numpy() # H,W,3
54
+ if ftensor.ndim == 3 and ftensor.shape[0] == 3:
55
+ ftensor = ftensor.transpose(1, 2, 0)
56
+ elif ftensor.ndim == 4 and ftensor.shape[1] == 3:
57
+ ftensor = ftensor.transpose(0, 2, 3, 1)
58
+ if true_shape is not None:
59
+ H, W = true_shape
60
+ ftensor = ftensor[:H, :W]
61
+ if ftensor.dtype == np.uint8:
62
+ img = np.float32(ftensor) / 255
63
+ else:
64
+ img = (ftensor * 0.5) + 0.5
65
+ return img.clip(min=0, max=1)
66
+
67
+
68
+ def _resize_pil_image(img, long_edge_size):
69
+ S = max(img.size)
70
+ if S > long_edge_size:
71
+ interp = PIL.Image.LANCZOS
72
+ elif S <= long_edge_size:
73
+ interp = PIL.Image.BICUBIC
74
+ new_size = tuple(int(round(x * long_edge_size / S)) for x in img.size)
75
+ return img.resize(new_size, interp)
76
+
77
+
78
+ def load_images(
79
+ folder_or_list: str | list,
80
+ size: Literal[224, 512],
81
+ square_ok: bool = False,
82
+ verbose: bool = True,
83
+ ) -> list[ImageDict]:
84
+ """open and convert all images in a list or folder to proper input format for DUSt3R"""
85
+ if isinstance(folder_or_list, str):
86
+ if verbose:
87
+ print(f">> Loading images from {folder_or_list}")
88
+ root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))
89
+
90
+ elif isinstance(folder_or_list, list):
91
+ if verbose:
92
+ print(f">> Loading a list of {len(folder_or_list)} images")
93
+ root, folder_content = "", folder_or_list
94
+
95
+ else:
96
+ raise ValueError(f"bad {folder_or_list=} ({type(folder_or_list)})")
97
+
98
+ supported_images_extensions = [".jpg", ".jpeg", ".png"]
99
+ if heif_support_enabled:
100
+ supported_images_extensions += [".heic", ".heif"]
101
+ supported_images_extensions = tuple(supported_images_extensions)
102
+
103
+ imgs = []
104
+ for path in folder_content:
105
+ if not path.lower().endswith(supported_images_extensions):
106
+ continue
107
+ img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert("RGB")
108
+ W1, H1 = img.size
109
+ if size == 224:
110
+ # resize short side to 224 (then crop)
111
+ img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1)))
112
+ else:
113
+ # resize long side to 512
114
+ img = _resize_pil_image(img, size)
115
+ W, H = img.size
116
+ cx, cy = W // 2, H // 2
117
+ if size == 224:
118
+ half = min(cx, cy)
119
+ img = img.crop((cx - half, cy - half, cx + half, cy + half))
120
+ else:
121
+ halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
122
+ if not (square_ok) and W == H:
123
+ halfh = 3 * halfw / 4
124
+ img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh))
125
+
126
+ W2, H2 = img.size
127
+ if verbose:
128
+ print(f" - adding {path} with resolution {W1}x{H1} --> {W2}x{H2}")
129
+ imgs.append(
130
+ dict(
131
+ img=ImgNorm(img)[None],
132
+ true_shape=np.int32([img.size[::-1]]),
133
+ idx=len(imgs),
134
+ instance=str(len(imgs)),
135
+ )
136
+ )
137
+
138
+ assert imgs, "no images found at " + root
139
+ if verbose:
140
+ print(f" (Found {len(imgs)} images)")
141
+ return imgs
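A minimal sketch of the loader's behaviour at the 512 setting (long side resized to 512, then cropped so both sides are multiples of 16), assuming the repo's dependencies (torchvision, jaxtyping, PIL) are installed; the two dummy images exist only for the example.

import os
import tempfile
import numpy as np
import PIL.Image
from mini_dust3r.utils.image import load_images

with tempfile.TemporaryDirectory() as tmp:
    for name in ("left.png", "right.png"):
        array = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
        PIL.Image.fromarray(array).save(os.path.join(tmp, name))
    views = load_images(tmp, size=512)

print(views[0]["img"].shape)        # torch.Size([1, 3, 384, 512])
print(views[0]["true_shape"])       # [[384 512]]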
mini_dust3r/utils/misc.py ADDED
@@ -0,0 +1,121 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utility functions for DUSt3R
6
+ # --------------------------------------------------------
7
+ import torch
8
+
9
+
10
+ def fill_default_args(kwargs, func):
11
+ import inspect # a bit hacky but it works reliably
12
+ signature = inspect.signature(func)
13
+
14
+ for k, v in signature.parameters.items():
15
+ if v.default is inspect.Parameter.empty:
16
+ continue
17
+ kwargs.setdefault(k, v.default)
18
+
19
+ return kwargs
20
+
21
+
22
+ def freeze_all_params(modules):
23
+ for module in modules:
24
+ try:
25
+ for n, param in module.named_parameters():
26
+ param.requires_grad = False
27
+ except AttributeError:
28
+ # module is directly a parameter
29
+ module.requires_grad = False
30
+
31
+
32
+ def is_symmetrized(gt1, gt2):
33
+ x = gt1['instance']
34
+ y = gt2['instance']
35
+ if len(x) == len(y) and len(x) == 1:
36
+ return False # special case of batchsize 1
37
+ ok = True
38
+ for i in range(0, len(x), 2):
39
+ ok = ok and (x[i] == y[i+1]) and (x[i+1] == y[i])
40
+ return ok
41
+
42
+
43
+ def flip(tensor):
44
+ """ flip so that tensor[0::2] <=> tensor[1::2] """
45
+ return torch.stack((tensor[1::2], tensor[0::2]), dim=1).flatten(0, 1)
46
+
47
+
48
+ def interleave(tensor1, tensor2):
49
+ res1 = torch.stack((tensor1, tensor2), dim=1).flatten(0, 1)
50
+ res2 = torch.stack((tensor2, tensor1), dim=1).flatten(0, 1)
51
+ return res1, res2
52
+
53
+
54
+ def transpose_to_landscape(head, activate=True):
55
+ """ Predict in the correct aspect-ratio,
56
+ then transpose the result in landscape
57
+ and stack everything back together.
58
+ """
59
+ def wrapper_no(decout, true_shape):
60
+ B = len(true_shape)
61
+ assert true_shape[0:1].allclose(true_shape), 'true_shape must be all identical'
62
+ H, W = true_shape[0].cpu().tolist()
63
+ res = head(decout, (H, W))
64
+ return res
65
+
66
+ def wrapper_yes(decout, true_shape):
67
+ B = len(true_shape)
68
+ # by definition, the batch is in landscape mode so W >= H
69
+ H, W = int(true_shape.min()), int(true_shape.max())
70
+
71
+ height, width = true_shape.T
72
+ is_landscape = (width >= height)
73
+ is_portrait = ~is_landscape
74
+
75
+ # true_shape = true_shape.cpu()
76
+ if is_landscape.all():
77
+ return head(decout, (H, W))
78
+ if is_portrait.all():
79
+ return transposed(head(decout, (W, H)))
80
+
81
+ # batch is a mix of both portrait & landscape
82
+ def selout(ar): return [d[ar] for d in decout]
83
+ l_result = head(selout(is_landscape), (H, W))
84
+ p_result = transposed(head(selout(is_portrait), (W, H)))
85
+
86
+ # allocate full result
87
+ result = {}
88
+ for k in l_result | p_result:
89
+ x = l_result[k].new(B, *l_result[k].shape[1:])
90
+ x[is_landscape] = l_result[k]
91
+ x[is_portrait] = p_result[k]
92
+ result[k] = x
93
+
94
+ return result
95
+
96
+ return wrapper_yes if activate else wrapper_no
97
+
98
+
99
+ def transposed(dic):
100
+ return {k: v.swapaxes(1, 2) for k, v in dic.items()}
101
+
102
+
103
+ def invalid_to_nans(arr, valid_mask, ndim=999):
104
+ if valid_mask is not None:
105
+ arr = arr.clone()
106
+ arr[~valid_mask] = float('nan')
107
+ if arr.ndim > ndim:
108
+ arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
109
+ return arr
110
+
111
+
112
+ def invalid_to_zeros(arr, valid_mask, ndim=999):
113
+ if valid_mask is not None:
114
+ arr = arr.clone()
115
+ arr[~valid_mask] = 0
116
+ nnz = valid_mask.view(len(valid_mask), -1).sum(1)
117
+ else:
118
+ nnz = arr.numel() // len(arr) if len(arr) else 0 # number of point per image
119
+ if arr.ndim > ndim:
120
+ arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
121
+ return arr, nnz
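A small self-contained sketch of the pair bookkeeping above: interleave builds the symmetric (i, j) / (j, i) ordering that is_symmetrized checks for, and flip swaps the two members of every pair.

import torch
from mini_dust3r.utils.misc import interleave, flip, is_symmetrized

a = torch.tensor([0, 2])            # view1 instances: images 0 and 2
b = torch.tensor([1, 3])            # view2 instances: images 1 and 3
x, y = interleave(a, b)             # x = [0, 1, 2, 3], y = [1, 0, 3, 2]
print(torch.equal(flip(x), y))      # True: flip swaps the two members of every pair

gt1 = {"instance": [str(i) for i in x.tolist()]}
gt2 = {"instance": [str(i) for i in y.tolist()]}
print(is_symmetrized(gt1, gt2))     # True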
mini_dust3r/viz.py ADDED
@@ -0,0 +1,320 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Visualization utilities using trimesh
6
+ # --------------------------------------------------------
7
+ import PIL.Image
8
+ import numpy as np
9
+ from scipy.spatial.transform import Rotation
10
+ import torch
11
+
12
+ from mini_dust3r.utils.geometry import geotrf, get_med_dist_between_poses
13
+ from mini_dust3r.utils.device import to_numpy
14
+ from mini_dust3r.utils.image import rgb
15
+
16
+ try:
17
+ import trimesh
18
+ except ImportError:
19
+ print('/!\\ module trimesh is not installed, cannot visualize results /!\\')
20
+
21
+
22
+ def cat_3d(vecs):
23
+ if isinstance(vecs, (np.ndarray, torch.Tensor)):
24
+ vecs = [vecs]
25
+ return np.concatenate([p.reshape(-1, 3) for p in to_numpy(vecs)])
26
+
27
+
28
+ def show_raw_pointcloud(pts3d, colors, point_size=2):
29
+ scene = trimesh.Scene()
30
+
31
+ pct = trimesh.PointCloud(cat_3d(pts3d), colors=cat_3d(colors))
32
+ scene.add_geometry(pct)
33
+
34
+ scene.show(line_settings={'point_size': point_size})
35
+
36
+
37
+ def pts3d_to_trimesh(img, pts3d, valid=None):
38
+ H, W, THREE = img.shape
39
+ assert THREE == 3
40
+ assert img.shape == pts3d.shape
41
+
42
+ vertices = pts3d.reshape(-1, 3)
43
+
44
+ # make squares: each pixel == 2 triangles
45
+ idx = np.arange(len(vertices)).reshape(H, W)
46
+ idx1 = idx[:-1, :-1].ravel() # top-left corner
47
+ idx2 = idx[:-1, +1:].ravel() # right-left corner
48
+ idx3 = idx[+1:, :-1].ravel() # bottom-left corner
49
+ idx4 = idx[+1:, +1:].ravel() # bottom-right corner
50
+ faces = np.concatenate((
51
+ np.c_[idx1, idx2, idx3],
52
+ np.c_[idx3, idx2, idx1], # same triangle, but backward (cheap solution to cancel face culling)
53
+ np.c_[idx2, idx3, idx4],
54
+ np.c_[idx4, idx3, idx2], # same triangle, but backward (cheap solution to cancel face culling)
55
+ ), axis=0)
56
+
57
+ # prepare triangle colors
58
+ face_colors = np.concatenate((
59
+ img[:-1, :-1].reshape(-1, 3),
60
+ img[:-1, :-1].reshape(-1, 3),
61
+ img[+1:, +1:].reshape(-1, 3),
62
+ img[+1:, +1:].reshape(-1, 3)
63
+ ), axis=0)
64
+
65
+ # remove invalid faces
66
+ if valid is not None:
67
+ assert valid.shape == (H, W)
68
+ valid_idxs = valid.ravel()
69
+ valid_faces = valid_idxs[faces].all(axis=-1)
70
+ faces = faces[valid_faces]
71
+ face_colors = face_colors[valid_faces]
72
+
73
+ assert len(faces) == len(face_colors)
74
+ return dict(vertices=vertices, face_colors=face_colors, faces=faces)
75
+
76
+
77
+ def cat_meshes(meshes):
78
+ vertices, faces, colors = zip(*[(m['vertices'], m['faces'], m['face_colors']) for m in meshes])
79
+ n_vertices = np.cumsum([0]+[len(v) for v in vertices])
80
+ for i in range(len(faces)):
81
+ faces[i][:] += n_vertices[i]
82
+
83
+ vertices = np.concatenate(vertices)
84
+ colors = np.concatenate(colors)
85
+ faces = np.concatenate(faces)
86
+ return dict(vertices=vertices, face_colors=colors, faces=faces)
87
+
88
+
89
+ def show_duster_pairs(view1, view2, pred1, pred2):
90
+ import matplotlib.pyplot as pl
91
+ pl.ion()
92
+
93
+ for e in range(len(view1['instance'])):
94
+ i = view1['idx'][e]
95
+ j = view2['idx'][e]
96
+ img1 = rgb(view1['img'][e])
97
+ img2 = rgb(view2['img'][e])
98
+ conf1 = pred1['conf'][e].squeeze()
99
+ conf2 = pred2['conf'][e].squeeze()
100
+ score = conf1.mean()*conf2.mean()
101
+ print(f">> Showing pair #{e} {i}-{j} {score=:g}")
102
+ pl.clf()
103
+ pl.subplot(221).imshow(img1)
104
+ pl.subplot(223).imshow(img2)
105
+ pl.subplot(222).imshow(conf1, vmin=1, vmax=30)
106
+ pl.subplot(224).imshow(conf2, vmin=1, vmax=30)
107
+ pts1 = pred1['pts3d'][e]
108
+ pts2 = pred2['pts3d_in_other_view'][e]
109
+ pl.subplots_adjust(0, 0, 1, 1, 0, 0)
110
+ if input('show pointcloud? (y/n) ') == 'y':
111
+ show_raw_pointcloud(cat(pts1, pts2), cat(img1, img2), point_size=5)
112
+
113
+
114
+ def auto_cam_size(im_poses):
115
+ return 0.1 * get_med_dist_between_poses(im_poses)
116
+
117
+
118
+ class SceneViz:
119
+ def __init__(self):
120
+ self.scene = trimesh.Scene()
121
+
122
+ def add_pointcloud(self, pts3d, color, mask=None):
123
+ pts3d = to_numpy(pts3d)
124
+ mask = to_numpy(mask)
125
+ if mask is None:
126
+ mask = [slice(None)] * len(pts3d)
127
+ pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
128
+ pct = trimesh.PointCloud(pts.reshape(-1, 3))
129
+
130
+ if isinstance(color, (list, np.ndarray, torch.Tensor)):
131
+ color = to_numpy(color)
132
+ col = np.concatenate([p[m] for p, m in zip(color, mask)])
133
+ assert col.shape == pts.shape
134
+ pct.visual.vertex_colors = uint8(col.reshape(-1, 3))
135
+ else:
136
+ assert len(color) == 3
137
+ pct.visual.vertex_colors = np.broadcast_to(uint8(color), pts.shape)
138
+
139
+ self.scene.add_geometry(pct)
140
+ return self
141
+
142
+ def add_camera(self, pose_c2w, focal=None, color=(0, 0, 0), image=None, imsize=None, cam_size=0.03):
143
+ pose_c2w, focal, color, image = to_numpy((pose_c2w, focal, color, image))
144
+ add_scene_cam(self.scene, pose_c2w, color, image, focal, screen_width=cam_size)
145
+ return self
146
+
147
+ def add_cameras(self, poses, focals=None, images=None, imsizes=None, colors=None, **kw):
148
+ def get(arr, idx): return None if arr is None else arr[idx]
149
+ for i, pose_c2w in enumerate(poses):
150
+ self.add_camera(pose_c2w, get(focals, i), image=get(images, i),
151
+ color=get(colors, i), imsize=get(imsizes, i), **kw)
152
+ return self
153
+
154
+ def show(self, point_size=2):
155
+ self.scene.show(line_settings={'point_size': point_size})
156
+
157
+
158
+ def show_raw_pointcloud_with_cams(imgs, pts3d, mask, focals, cams2world,
159
+ point_size=2, cam_size=0.05, cam_color=None):
160
+ """ Visualization of a pointcloud with cameras
161
+ imgs = (N, H, W, 3) or N-size list of [(H,W,3), ...]
162
+ pts3d = (N, H, W, 3) or N-size list of [(H,W,3), ...]
163
+ focals = (N,) or N-size list of [focal, ...]
164
+ cams2world = (N,4,4) or N-size list of [(4,4), ...]
165
+ """
166
+ assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
167
+ pts3d = to_numpy(pts3d)
168
+ imgs = to_numpy(imgs)
169
+ focals = to_numpy(focals)
170
+ cams2world = to_numpy(cams2world)
171
+
172
+ scene = trimesh.Scene()
173
+
174
+ # full pointcloud
175
+ pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
176
+ col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
177
+ pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
178
+ scene.add_geometry(pct)
179
+
180
+ # add each camera
181
+ for i, pose_c2w in enumerate(cams2world):
182
+ if isinstance(cam_color, list):
183
+ camera_edge_color = cam_color[i]
184
+ else:
185
+ camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
186
+ add_scene_cam(scene, pose_c2w, camera_edge_color,
187
+ imgs[i] if i < len(imgs) else None, focals[i], screen_width=cam_size)
188
+
189
+ scene.show(line_settings={'point_size': point_size})
190
+
191
+
192
+ def add_scene_cam(scene, pose_c2w, edge_color, image=None, focal=None, imsize=None, screen_width=0.03):
193
+
194
+ if image is not None:
195
+ H, W, THREE = image.shape
196
+ assert THREE == 3
197
+ if image.dtype != np.uint8:
198
+ image = np.uint8(255*image)
199
+ elif imsize is not None:
200
+ W, H = imsize
201
+ elif focal is not None:
202
+ H = W = focal / 1.1
203
+ else:
204
+ H = W = 1
205
+
206
+ if focal is None:
207
+ focal = min(H, W) * 1.1 # default value
208
+ elif isinstance(focal, np.ndarray):
209
+ focal = focal[0]
210
+
211
+ # create fake camera
212
+ height = focal * screen_width / H
213
+ width = screen_width * 0.5**0.5
214
+ rot45 = np.eye(4)
215
+ rot45[:3, :3] = Rotation.from_euler('z', np.deg2rad(45)).as_matrix()
216
+ rot45[2, 3] = -height # set the tip of the cone = optical center
217
+ aspect_ratio = np.eye(4)
218
+ aspect_ratio[0, 0] = W/H
219
+ transform = pose_c2w @ OPENGL @ aspect_ratio @ rot45
220
+ cam = trimesh.creation.cone(width, height, sections=4) # , transform=transform)
221
+
222
+ # this is the image
223
+ if image is not None:
224
+ vertices = geotrf(transform, cam.vertices[[4, 5, 1, 3]])
225
+ faces = np.array([[0, 1, 2], [0, 2, 3], [2, 1, 0], [3, 2, 0]])
226
+ img = trimesh.Trimesh(vertices=vertices, faces=faces)
227
+ uv_coords = np.float32([[0, 0], [1, 0], [1, 1], [0, 1]])
228
+ img.visual = trimesh.visual.TextureVisuals(uv_coords, image=PIL.Image.fromarray(image))
229
+ scene.add_geometry(img)
230
+
231
+ # this is the camera mesh
232
+ rot2 = np.eye(4)
233
+ rot2[:3, :3] = Rotation.from_euler('z', np.deg2rad(2)).as_matrix()
234
+ vertices = np.r_[cam.vertices, 0.95*cam.vertices, geotrf(rot2, cam.vertices)]
235
+ vertices = geotrf(transform, vertices)
236
+ faces = []
237
+ for face in cam.faces:
238
+ if 0 in face:
239
+ continue
240
+ a, b, c = face
241
+ a2, b2, c2 = face + len(cam.vertices)
242
+ a3, b3, c3 = face + 2*len(cam.vertices)
243
+
244
+ # add 3 pseudo-edges
245
+ faces.append((a, b, b2))
246
+ faces.append((a, a2, c))
247
+ faces.append((c2, b, c))
248
+
249
+ faces.append((a, b, b3))
250
+ faces.append((a, a3, c))
251
+ faces.append((c3, b, c))
252
+
253
+ # no culling
254
+ faces += [(c, b, a) for a, b, c in faces]
255
+
256
+ cam = trimesh.Trimesh(vertices=vertices, faces=faces)
257
+ cam.visual.face_colors[:, :3] = edge_color
258
+ scene.add_geometry(cam)
259
+
260
+
261
+ def cat(a, b):
262
+ return np.concatenate((a.reshape(-1, 3), b.reshape(-1, 3)))
263
+
264
+
265
+ OPENGL = np.array([[1, 0, 0, 0],
266
+ [0, -1, 0, 0],
267
+ [0, 0, -1, 0],
268
+ [0, 0, 0, 1]])
269
+
270
+
271
+ CAM_COLORS = [(255, 0, 0), (0, 0, 255), (0, 255, 0), (255, 0, 255), (255, 204, 0), (0, 204, 204),
272
+ (128, 255, 255), (255, 128, 255), (255, 255, 128), (0, 0, 0), (128, 128, 128)]
273
+
274
+
275
+ def uint8(colors):
276
+ if not isinstance(colors, np.ndarray):
277
+ colors = np.array(colors)
278
+ if np.issubdtype(colors.dtype, np.floating):
279
+ colors *= 255
280
+ assert 0 <= colors.min() and colors.max() < 256
281
+ return np.uint8(colors)
282
+
283
+
284
+ def segment_sky(image):
285
+ import cv2
286
+ from scipy import ndimage
287
+
288
+ # Convert to HSV
289
+ image = to_numpy(image)
290
+ if np.issubdtype(image.dtype, np.floating):
291
+ image = np.uint8(255*image.clip(min=0, max=1))
292
+ hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
293
+
294
+ # Define range for blue color and create mask
295
+ lower_blue = np.array([0, 0, 100])
296
+ upper_blue = np.array([30, 255, 255])
297
+ mask = cv2.inRange(hsv, lower_blue, upper_blue).view(bool)
298
+
299
+ # add luminous gray
300
+ mask |= (hsv[:, :, 1] < 10) & (hsv[:, :, 2] > 150)
301
+ mask |= (hsv[:, :, 1] < 30) & (hsv[:, :, 2] > 180)
302
+ mask |= (hsv[:, :, 1] < 50) & (hsv[:, :, 2] > 220)
303
+
304
+ # Morphological operations
305
+ kernel = np.ones((5, 5), np.uint8)
306
+ mask2 = ndimage.binary_opening(mask, structure=kernel)
307
+
308
+ # keep only largest CC
309
+ _, labels, stats, _ = cv2.connectedComponentsWithStats(mask2.view(np.uint8), connectivity=8)
310
+ cc_sizes = stats[1:, cv2.CC_STAT_AREA]
311
+ order = cc_sizes.argsort()[::-1] # bigger first
312
+ i = 0
313
+ selection = []
314
+ while i < len(order) and cc_sizes[order[i]] > cc_sizes[order[0]] / 2:
315
+ selection.append(1 + order[i])
316
+ i += 1
317
+ mask3 = np.in1d(labels, selection).reshape(labels.shape)
318
+
319
+ # Apply mask
320
+ return torch.from_numpy(mask3)
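A hedged sketch of the offline (non-interactive) path through this module: turn one point map into a coloured mesh with pts3d_to_trimesh and export it as GLB, much as the Space's scene-export code does; the flat synthetic plane below is illustrative only.

import numpy as np
import trimesh
from mini_dust3r.viz import pts3d_to_trimesh, cat_meshes, OPENGL

H, W = 32, 48
img = np.random.rand(H, W, 3).astype(np.float32)                   # per-pixel colours in [0, 1]
xy = np.meshgrid(np.arange(W), np.arange(H))
pts3d = np.dstack(xy + [np.full((H, W), 5.0)]).astype(np.float32)  # a flat plane at z = 5
valid = np.ones((H, W), dtype=bool)

mesh = trimesh.Trimesh(**cat_meshes([pts3d_to_trimesh(img, pts3d, valid)]))
scene = trimesh.Scene()
scene.add_geometry(mesh)
scene.apply_transform(OPENGL)                                       # switch to the OpenGL convention
scene.export("scene.glb")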