kxhit committed on
Commit
49f0812
1 Parent(s): 5ca3a35

debug mini

Browse files
Files changed (1) hide show
  1. app.py +97 -110
app.py CHANGED
@@ -268,6 +268,9 @@ from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_
268
  from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
269
  import math
270
 
 
 
 
271
  # @spaces.GPU(duration=120)
272
  def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
273
  cam_color=None, as_pointcloud=False,
@@ -368,39 +371,26 @@ def get_reconstructed_scene(filelist, schedule, niter, min_conf_thr,
368
  if os.path.exists(outdir):
369
  shutil.rmtree(outdir)
370
  os.makedirs(outdir, exist_ok=True)
371
- imgs, imgs_rgba = load_images(filelist, size=image_size, verbose=not silent, do_remove_background=True, rembg_session=rembg_session, predictor=predictor)
372
- if len(imgs) == 1:
373
- imgs = [imgs[0], copy.deepcopy(imgs[0])]
374
- imgs[1]['idx'] = 1
375
- if scenegraph_type == "swin":
376
- scenegraph_type = scenegraph_type + "-" + str(winsize)
377
- elif scenegraph_type == "oneref":
378
- scenegraph_type = scenegraph_type + "-" + str(refid)
379
 
380
- pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
381
- output = inference(pairs, model, device, batch_size=1, verbose=not silent)
382
-
383
- mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
384
- scene = global_aligner(output, device=device, mode=mode, verbose=not silent, same_focals=same_focals)
385
- lr = 0.01
 
 
 
 
 
 
 
 
386
 
387
- if mode == GlobalAlignerMode.PointCloudOptimizer:
388
- loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)
389
 
390
- # outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
391
- # clean_depth, transparent_cams, cam_size, same_focals=same_focals)
392
 
393
- # also return rgb, depth and confidence imgs
394
- # depth is normalized with the max value for all images
395
- # we apply the jet colormap on the confidence maps
396
- rgbimg = scene.imgs
397
- # depths = to_numpy(scene.get_depthmaps())
398
- # confs = to_numpy([c for c in scene.im_conf])
399
- # cmap = pl.get_cmap('jet')
400
- # depths_max = max([d.max() for d in depths])
401
- # depths = [d / depths_max for d in depths]
402
- # confs_max = max([d.max() for d in confs])
403
- # confs = [cmap(d / confs_max) for d in confs]
404
 
405
  imgs = []
406
  rgbaimg = []
@@ -419,101 +409,98 @@ def get_reconstructed_scene(filelist, schedule, niter, min_conf_thr,
419
  rgbaimg = np.array(rgbaimg)
420
 
421
  # for eschernet
422
- # get optimized values from scene
423
- rgbimg = to_numpy(scene.imgs)
424
- # focals = to_numpy(scene.get_focals().cpu())
425
- cams2world = to_numpy(scene.get_im_poses().cpu())
426
-
427
- # 3D pointcloud from depthmap, poses and intrinsics
428
- pts3d = to_numpy(scene.get_pts3d())
429
- scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
430
- msk = to_numpy(scene.get_masks())
431
- obj_mask = rgbaimg[..., 3] > 0
432
-
433
- # TODO set global coordinate system at the center of the scene, z-axis is up
434
- pts = np.concatenate([p[m] for p, m in zip(pts3d, msk)]).reshape(-1, 3)
435
- pts_obj = np.concatenate([p[m&obj_m] for p, m, obj_m in zip(pts3d, msk, obj_mask)]).reshape(-1, 3)
436
- centroid = np.mean(pts_obj, axis=0) # obj center
437
- obj2world = np.eye(4)
438
- obj2world[:3, 3] = -centroid # T_wc
439
-
440
- # get z_up vector
441
- # TODO fit a plane and get the normal vector
442
- pcd = o3d.geometry.PointCloud()
443
- pcd.points = o3d.utility.Vector3dVector(pts)
444
- plane_model, inliers = pcd.segment_plane(distance_threshold=0.01, ransac_n=3, num_iterations=1000)
445
- # get the normalised normal vector dim = 3
446
- normal = plane_model[:3] / np.linalg.norm(plane_model[:3])
447
- # the normal direction should be pointing up
448
- if normal[1] < 0:
449
- normal = -normal
450
- # print("normal", normal)
451
-
452
- # # TODO z-up 180
453
- # z_up = np.array([[1,0,0,0],
454
- # [0,-1,0,0],
455
- # [0,0,-1,0],
456
- # [0,0,0,1]])
457
- # obj2world = z_up @ obj2world
458
-
459
- # # avg the y
460
- # z_up_avg = cams2world[:,:3,3].sum(0) / np.linalg.norm(cams2world[:,:3,3].sum(0), axis=-1) # average direction in cam coordinate
461
- # # import pdb; pdb.set_trace()
462
- # rot_axis = np.cross(np.array([0, 0, 1]), z_up_avg)
463
- # rot_angle = np.arccos(np.dot(np.array([0, 0, 1]), z_up_avg) / (np.linalg.norm(z_up_avg) + 1e-6))
 
 
 
464
  # rot = Rotation.from_rotvec(rot_angle * rot_axis)
465
  # z_up = np.eye(4)
466
  # z_up[:3, :3] = rot.as_matrix()
467
-
468
- # get the rotation matrix from normal to z-axis
469
- z_axis = np.array([0, 0, 1])
470
- rot_axis = np.cross(normal, z_axis)
471
- rot_angle = np.arccos(np.dot(normal, z_axis) / (np.linalg.norm(normal) + 1e-6))
472
- rot = Rotation.from_rotvec(rot_angle * rot_axis)
473
- z_up = np.eye(4)
474
- z_up[:3, :3] = rot.as_matrix()
475
- obj2world = z_up @ obj2world
476
- # flip 180
477
- flip_rot = np.array([[1, 0, 0, 0],
478
- [0, -1, 0, 0],
479
- [0, 0, -1, 0],
480
- [0, 0, 0, 1]])
481
- obj2world = flip_rot @ obj2world
482
-
483
- # get new cams2obj
484
- cams2obj = []
485
- for i, cam2world in enumerate(cams2world):
486
- cams2obj.append(obj2world @ cam2world)
487
- # TODO transform pts3d to the new coordinate system
488
- for i, pts in enumerate(pts3d):
489
- pts3d[i] = (obj2world @ np.concatenate([pts, np.ones_like(pts)[..., :1]], axis=-1).transpose(2, 0, 1).reshape(4,
490
- -1)) \
491
- .reshape(4, pts.shape[0], pts.shape[1]).transpose(1, 2, 0)[..., :3]
492
- cams2world = np.array(cams2obj)
493
- # TODO rewrite hack
494
- scene.vis_poses = cams2world.copy()
495
- scene.vis_pts3d = pts3d.copy()
496
-
497
- # TODO save cams2world and rgbimg to each file, file name "000.npy", "001.npy", ... and "000.png", "001.png", ...
498
- for i, (img, img_rgba, pose) in enumerate(zip(rgbimg, rgbaimg, cams2world)):
499
- np.save(os.path.join(outdir, f"{i:03d}.npy"), pose)
500
- pl.imsave(os.path.join(outdir, f"{i:03d}.png"), img)
501
- pl.imsave(os.path.join(outdir, f"{i:03d}_rgba.png"), img_rgba)
502
- # np.save(os.path.join(outdir, f"{i:03d}_focal.npy"), to_numpy(focal))
503
  # save the min/max radius of camera
504
  radii = np.linalg.norm(np.linalg.inv(cams2world)[..., :3, 3])
505
- np.save(os.path.join(outdir, "radii.npy"), radii)
506
 
507
  eschernet_input = {"poses": cams2world,
508
  "radii": radii,
509
  "imgs": rgbaimg}
510
  print("got eschernet input")
511
- outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
512
- clean_depth, transparent_cams, cam_size, same_focals=same_focals)
513
 
514
  return scene, outfile, imgs, eschernet_input
515
 
516
 
 
 
517
  def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
518
  num_files = len(inputfiles) if inputfiles is not None else 1
519
  max_winsize = max(1, math.ceil((num_files - 1) / 2))
 
268
  from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
269
  import math
270
 
271
+ from mini_dust3r.api import OptimizedResult, inferece_dust3r, log_optimized_result
272
+ from mini_dust3r.model import AsymmetricCroCo3DStereo
273
+
274
  # @spaces.GPU(duration=120)
275
  def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
276
  cam_color=None, as_pointcloud=False,
 
371
  if os.path.exists(outdir):
372
  shutil.rmtree(outdir)
373
  os.makedirs(outdir, exist_ok=True)
374
+ # imgs, imgs_rgba = load_images(filelist, size=image_size, verbose=not silent, do_remove_background=True, rembg_session=rembg_session, predictor=predictor)
 
 
 
 
 
 
 
375
 
376
+ optimized_results: OptimizedResult = inferece_dust3r(
377
+ image_dir_or_list=filelist,
378
+ model=model,
379
+ device=device,
380
+ batch_size=1,
381
+ )
382
+ rgbimg = optimized_results.rgb_hw3_list
383
+ imgs_rgba = rgbimg
384
+ cams2world = optimized_results.world_T_cam_b44
385
+ pts3d = optimized_results.point_cloud
386
+ pts_obj = pts3d
387
+ outfile = os.path.join(outdir, 'scene.glb')
388
+ # export the point cloud (trimesh.PointCloud) to scene.glb
389
+ pts3d.export(os.path.join(outdir, 'scene.glb'))
390
 
 
 
391
 
 
 
392
 
393
+ # rgbimg = to_numpy(scene.imgs)
 
 
 
 
 
 
 
 
 
 
394
 
395
  imgs = []
396
  rgbaimg = []
 
409
  rgbaimg = np.array(rgbaimg)
410
 
411
  # for eschernet
412
+ # cams2world = to_numpy(scene.get_im_poses().cpu())
413
+ # pts3d = to_numpy(scene.get_pts3d())
414
+ # scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
415
+ # msk = to_numpy(scene.get_masks())
416
+ # obj_mask = rgbaimg[..., 3] > 0
417
+
418
+ # # TODO set global coordinate system at the center of the scene, z-axis is up
419
+ # # pts = np.concatenate([p[m] for p, m in zip(pts3d, msk)]).reshape(-1, 3)
420
+ # # pts_obj = np.concatenate([p[m&obj_m] for p, m, obj_m in zip(pts3d, msk, obj_mask)]).reshape(-1, 3)
421
+ # centroid = np.mean(pts_obj, axis=0) # obj center
422
+ # obj2world = np.eye(4)
423
+ # obj2world[:3, 3] = -centroid # T_wc
424
+ #
425
+ # # get z_up vector
426
+ # # TODO fit a plane and get the normal vector
427
+ # pcd = o3d.geometry.PointCloud()
428
+ # pcd.points = o3d.utility.Vector3dVector(pts)
429
+ # plane_model, inliers = pcd.segment_plane(distance_threshold=0.01, ransac_n=3, num_iterations=1000)
430
+ # # get the normalised normal vector dim = 3
431
+ # normal = plane_model[:3] / np.linalg.norm(plane_model[:3])
432
+ # # the normal direction should be pointing up
433
+ # if normal[1] < 0:
434
+ # normal = -normal
435
+ # # print("normal", normal)
436
+ #
437
+ # # # TODO z-up 180
438
+ # # z_up = np.array([[1,0,0,0],
439
+ # # [0,-1,0,0],
440
+ # # [0,0,-1,0],
441
+ # # [0,0,0,1]])
442
+ # # obj2world = z_up @ obj2world
443
+ #
444
+ # # # avg the y
445
+ # # z_up_avg = cams2world[:,:3,3].sum(0) / np.linalg.norm(cams2world[:,:3,3].sum(0), axis=-1) # average direction in cam coordinate
446
+ # # # import pdb; pdb.set_trace()
447
+ # # rot_axis = np.cross(np.array([0, 0, 1]), z_up_avg)
448
+ # # rot_angle = np.arccos(np.dot(np.array([0, 0, 1]), z_up_avg) / (np.linalg.norm(z_up_avg) + 1e-6))
449
+ # # rot = Rotation.from_rotvec(rot_angle * rot_axis)
450
+ # # z_up = np.eye(4)
451
+ # # z_up[:3, :3] = rot.as_matrix()
452
+ #
453
+ # # get the rotation matrix from normal to z-axis
454
+ # z_axis = np.array([0, 0, 1])
455
+ # rot_axis = np.cross(normal, z_axis)
456
+ # rot_angle = np.arccos(np.dot(normal, z_axis) / (np.linalg.norm(normal) + 1e-6))
457
  # rot = Rotation.from_rotvec(rot_angle * rot_axis)
458
  # z_up = np.eye(4)
459
  # z_up[:3, :3] = rot.as_matrix()
460
+ # obj2world = z_up @ obj2world
461
+ # # flip 180
462
+ # flip_rot = np.array([[1, 0, 0, 0],
463
+ # [0, -1, 0, 0],
464
+ # [0, 0, -1, 0],
465
+ # [0, 0, 0, 1]])
466
+ # obj2world = flip_rot @ obj2world
467
+ #
468
+ # # get new cams2obj
469
+ # cams2obj = []
470
+ # for i, cam2world in enumerate(cams2world):
471
+ # cams2obj.append(obj2world @ cam2world)
472
+ # # TODO transform pts3d to the new coordinate system
473
+ # for i, pts in enumerate(pts3d):
474
+ # pts3d[i] = (obj2world @ np.concatenate([pts, np.ones_like(pts)[..., :1]], axis=-1).transpose(2, 0, 1).reshape(4,
475
+ # -1)) \
476
+ # .reshape(4, pts.shape[0], pts.shape[1]).transpose(1, 2, 0)[..., :3]
477
+ # cams2world = np.array(cams2obj)
478
+ # # TODO rewrite hack
479
+ # scene.vis_poses = cams2world.copy()
480
+ # scene.vis_pts3d = pts3d.copy()
481
+
482
+ # # TODO save cams2world and rgbimg to each file, file name "000.npy", "001.npy", ... and "000.png", "001.png", ...
483
+ # for i, (img, img_rgba, pose) in enumerate(zip(rgbimg, rgbaimg, cams2world)):
484
+ # np.save(os.path.join(outdir, f"{i:03d}.npy"), pose)
485
+ # pl.imsave(os.path.join(outdir, f"{i:03d}.png"), img)
486
+ # pl.imsave(os.path.join(outdir, f"{i:03d}_rgba.png"), img_rgba)
487
+ # # np.save(os.path.join(outdir, f"{i:03d}_focal.npy"), to_numpy(focal))
 
 
 
 
 
 
 
 
488
  # save the min/max radius of camera
489
  radii = np.linalg.norm(np.linalg.inv(cams2world)[..., :3, 3])
490
+ # np.save(os.path.join(outdir, "radii.npy"), radii)
491
 
492
  eschernet_input = {"poses": cams2world,
493
  "radii": radii,
494
  "imgs": rgbaimg}
495
  print("got eschernet input")
496
+ # outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
497
+ # clean_depth, transparent_cams, cam_size, same_focals=same_focals)
498
 
499
  return scene, outfile, imgs, eschernet_input
500
 
501
 
502
+
503
+
504
  def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
505
  num_files = len(inputfiles) if inputfiles is not None else 1
506
  max_winsize = max(1, math.ceil((num_files - 1) / 2))