kxhit committed
Commit d80ec21
1 Parent(s): c73e772
Files changed (2)
  1. app.py +105 -753
  2. app_mini.py +773 -0
app.py CHANGED
@@ -1,773 +1,125 @@
1
- import spaces
2
- import torch
3
- print("cuda is available: ", torch.cuda.is_available())
4
-
5
  import gradio as gr
6
- import os
7
- import shutil
8
- import rembg
9
- import numpy as np
10
- import math
11
- import open3d as o3d
12
- from PIL import Image
13
- import torchvision
14
- import trimesh
15
- from skimage.io import imsave
16
- import imageio
17
- import cv2
18
- import matplotlib.pyplot as pl
19
- pl.ion()
20
-
21
- CaPE_TYPE = "6DoF"
22
- device = 'cuda' #if torch.cuda.is_available() else 'cpu'
23
- weight_dtype = torch.float16
24
- torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
25
-
26
- # EscherNet
27
- # create angles in archimedean spiral with N steps
28
- def get_archimedean_spiral(sphere_radius, num_steps=250):
29
- # x-z plane, around upper y
30
- '''
31
- https://en.wikipedia.org/wiki/Spiral, section "Spherical spiral". c = a / pi
32
- '''
33
- a = 40
34
- r = sphere_radius
35
-
36
- translations = []
37
- angles = []
38
-
39
- # i = a / 2
40
- i = 0.01
41
- while i < a:
42
- theta = i / a * math.pi
43
- x = r * math.sin(theta) * math.cos(-i)
44
- z = r * math.sin(-theta + math.pi) * math.sin(-i)
45
- y = r * - math.cos(theta)
46
-
47
- # translations.append((x, y, z)) # origin
48
- translations.append((x, z, -y))
49
- angles.append([np.rad2deg(-i), np.rad2deg(theta)])
50
-
51
- # i += a / (2 * num_steps)
52
- i += a / (1 * num_steps)
53
-
54
- return np.array(translations), np.stack(angles)
55
-
56
- def look_at(origin, target, up):
57
- forward = (target - origin)
58
- forward = forward / np.linalg.norm(forward)
59
- right = np.cross(up, forward)
60
- right = right / np.linalg.norm(right)
61
- new_up = np.cross(forward, right)
62
- rotation_matrix = np.column_stack((right, new_up, -forward, target))
63
- matrix = np.row_stack((rotation_matrix, [0, 0, 0, 1]))
64
- return matrix
65
-
66
- import einops
67
- import sys
68
-
69
- sys.path.insert(0, "./6DoF/") # TODO change it when deploying
70
- # use the customized diffusers modules
71
- from diffusers import DDIMScheduler
72
- from dataset import get_pose
73
- from CN_encoder import CN_encoder
74
- from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline
75
- from segment_anything import sam_model_registry, SamPredictor
76
-
77
- # import rembg
78
- from carvekit.api.high import HiInterface
79
-
80
-
81
- pretrained_model_name_or_path = "kxic/EscherNet_demo"
82
- resolution = 256
83
- h,w = resolution,resolution
84
- guidance_scale = 3.0
85
- radius = 2.2
86
- bg_color = [1., 1., 1., 1.]
87
- image_transforms = torchvision.transforms.Compose(
88
- [
89
- torchvision.transforms.Resize((resolution, resolution)), # 256, 256
90
- torchvision.transforms.ToTensor(),
91
- torchvision.transforms.Normalize([0.5], [0.5])
92
- ]
93
- )
94
- xyzs_spiral, angles_spiral = get_archimedean_spiral(1.5, 200)
95
- # only half toop
96
- xyzs_spiral = xyzs_spiral[:100]
97
- angles_spiral = angles_spiral[:100]
98
-
99
- # Init pipeline
100
- scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler", revision=None)
101
- image_encoder = CN_encoder.from_pretrained(pretrained_model_name_or_path, subfolder="image_encoder", revision=None)
102
- pipeline = Zero1to3StableDiffusionPipeline.from_pretrained(
103
- pretrained_model_name_or_path,
104
- revision=None,
105
- scheduler=scheduler,
106
- image_encoder=None,
107
- safety_checker=None,
108
- feature_extractor=None,
109
- torch_dtype=weight_dtype,
110
- )
111
- pipeline.image_encoder = image_encoder.to(weight_dtype)
112
-
113
- pipeline.set_progress_bar_config(disable=False)
114
-
115
- pipeline = pipeline.to(device)
116
-
117
- # pipeline.enable_xformers_memory_efficient_attention()
118
- # enable vae slicing
119
- pipeline.enable_vae_slicing()
120
- # pipeline.enable_xformers_memory_efficient_attention()
121
-
122
-
123
- #### object segmentation
124
- def sam_init():
125
- sam_checkpoint = os.path.join("./sam_pt/sam_vit_h_4b8939.pth")
126
- if os.path.exists(sam_checkpoint) is False:
127
- os.system("wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P ./sam_pt/")
128
- model_type = "vit_h"
129
-
130
- sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)
131
- predictor = SamPredictor(sam)
132
- return predictor
133
-
134
- def create_carvekit_interface():
135
- # Check doc strings for more information
136
- interface = HiInterface(object_type="object", # Can be "object" or "hairs-like".
137
- batch_size_seg=6,
138
- batch_size_matting=1,
139
- device="cpu",
140
- seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
141
- matting_mask_size=2048,
142
- trimap_prob_threshold=231,
143
- trimap_dilation=30,
144
- trimap_erosion_iters=5,
145
- fp16=True)
146
-
147
- return interface
148
-
149
-
150
- # rembg_session = rembg.new_session()
151
- rembg_session = create_carvekit_interface()
152
- predictor = sam_init()
153
-
154
 
155
-
156
- @spaces.GPU(duration=120)
157
- def run_eschernet(eschernet_input_dict, sample_steps, sample_seed, nvs_num, nvs_mode):
158
- # set the random seed
159
- generator = torch.Generator(device=device).manual_seed(sample_seed)
160
- # generator = None
161
- T_out = nvs_num
162
- T_in = len(eschernet_input_dict['imgs'])
163
- ####### output pose
164
- # TODO choose T_out number of poses sequentially from the spiral
165
- xyzs = xyzs_spiral[::(len(xyzs_spiral) // T_out)]
166
- angles_out = angles_spiral[::(len(xyzs_spiral) // T_out)]
167
-
168
- ####### input's max radius for translation scaling
169
- radii = eschernet_input_dict['radii']
170
- max_t = np.max(radii)
171
- min_t = np.min(radii)
172
-
173
- ####### input pose
174
- pose_in = []
175
- for T_in_index in range(T_in):
176
- pose = get_pose(np.linalg.inv(eschernet_input_dict['poses'][T_in_index]))
177
- pose[1:3, :] *= -1 # coordinate system conversion
178
- pose[3, 3] *= 1. / max_t * radius # scale radius to [1.5, 2.2]
179
- pose_in.append(torch.from_numpy(pose))
180
-
181
- ####### input image
182
- img = eschernet_input_dict['imgs'] / 255.
183
- img[img[:, :, :, -1] == 0.] = bg_color
184
- # TODO batch image_transforms
185
- input_image = [image_transforms(Image.fromarray(np.uint8(im[:, :, :3] * 255.)).convert("RGB")) for im in img]
186
-
187
- ####### nvs pose
188
- pose_out = []
189
- for T_out_index in range(T_out):
190
- azimuth, polar = angles_out[T_out_index]
191
- if CaPE_TYPE == "4DoF":
192
- pose_out.append(torch.tensor([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.]))
193
- elif CaPE_TYPE == "6DoF":
194
- pose = look_at(origin=np.array([0, 0, 0]), target=xyzs[T_out_index], up=np.array([0, 0, 1]))
195
- pose = np.linalg.inv(pose)
196
- pose[2, :] *= -1
197
- pose_out.append(torch.from_numpy(get_pose(pose)))
198
-
199
-
200
-
201
- # [B, T, C, H, W]
202
- input_image = torch.stack(input_image, dim=0).to(device).to(weight_dtype).unsqueeze(0)
203
- # [B, T, 4]
204
- pose_in = np.stack(pose_in)
205
- pose_out = np.stack(pose_out)
206
-
207
- if CaPE_TYPE == "6DoF":
208
- pose_in_inv = np.linalg.inv(pose_in).transpose([0, 2, 1])
209
- pose_out_inv = np.linalg.inv(pose_out).transpose([0, 2, 1])
210
- pose_in_inv = torch.from_numpy(pose_in_inv).to(device).to(weight_dtype).unsqueeze(0)
211
- pose_out_inv = torch.from_numpy(pose_out_inv).to(device).to(weight_dtype).unsqueeze(0)
212
-
213
- pose_in = torch.from_numpy(pose_in).to(device).to(weight_dtype).unsqueeze(0)
214
- pose_out = torch.from_numpy(pose_out).to(device).to(weight_dtype).unsqueeze(0)
215
-
216
- input_image = einops.rearrange(input_image, "b t c h w -> (b t) c h w")
217
- assert T_in == input_image.shape[0]
218
- assert T_in == pose_in.shape[1]
219
- assert T_out == pose_out.shape[1]
220
-
221
- # run inference
222
- # pipeline.to(device)
223
- pipeline.enable_xformers_memory_efficient_attention()
224
- image = pipeline(input_imgs=input_image, prompt_imgs=input_image,
225
- poses=[[pose_out, pose_out_inv], [pose_in, pose_in_inv]],
226
- height=h, width=w, T_in=T_in, T_out=T_out,
227
- guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
228
- output_type="numpy").images
229
-
230
- # save output image
231
- output_dir = os.path.join(tmpdirname, "eschernet")
232
- if os.path.exists(output_dir):
233
- shutil.rmtree(output_dir)
234
- os.makedirs(output_dir, exist_ok=True)
235
- # # save to N imgs
236
- # for i in range(T_out):
237
- # imsave(os.path.join(output_dir, f'{i}.png'), (image[i] * 255).astype(np.uint8))
238
- # make a gif
239
- frames = [Image.fromarray((image[i] * 255).astype(np.uint8)) for i in range(T_out)]
240
- # frame_one = frames[0]
241
- # frame_one.save(os.path.join(output_dir, "output.gif"), format="GIF", append_images=frames,
242
- # save_all=True, duration=50, loop=1)
243
-
244
- # get a video
245
- video_path = os.path.join(output_dir, "output.mp4")
246
- imageio.mimwrite(video_path, np.stack(frames), fps=10, codec='h264')
247
-
248
-
249
- return video_path
250
-
251
- # TODO mesh it
252
- @spaces.GPU(duration=120)
253
- def make3d():
254
- pass
255
-
256
-
257
-
258
- ############################ Dust3r as Pose Estimation ############################
259
- from scipy.spatial.transform import Rotation
260
- import copy
261
-
262
- from dust3r.inference import inference
263
- from dust3r.model import AsymmetricCroCo3DStereo
264
- from dust3r.image_pairs import make_pairs
265
- from dust3r.utils.image import load_images, rgb
266
- from dust3r.utils.device import to_numpy
267
- from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
268
- from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
269
- import math
270
 
271
  from mini_dust3r.api import OptimizedResult, inferece_dust3r, log_optimized_result
272
  from mini_dust3r.model import AsymmetricCroCo3DStereo
273
 
274
- # @spaces.GPU(duration=120)
275
- def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
276
- cam_color=None, as_pointcloud=False,
277
- transparent_cams=False, silent=False, same_focals=False):
278
- assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world)
279
- if not same_focals:
280
- assert (len(cams2world) == len(focals))
281
- pts3d = to_numpy(pts3d)
282
- imgs = to_numpy(imgs)
283
- focals = to_numpy(focals)
284
- cams2world = to_numpy(cams2world)
285
-
286
- scene = trimesh.Scene()
287
-
288
- # add axes
289
- scene.add_geometry(trimesh.creation.axis(axis_length=0.5, axis_radius=0.001))
290
-
291
- # full pointcloud
292
- if as_pointcloud:
293
- pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
294
- col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
295
- pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
296
- scene.add_geometry(pct)
297
  else:
298
- meshes = []
299
- for i in range(len(imgs)):
300
- meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]))
301
- mesh = trimesh.Trimesh(**cat_meshes(meshes))
302
- scene.add_geometry(mesh)
303
-
304
- # add each camera
305
- for i, pose_c2w in enumerate(cams2world):
306
- if isinstance(cam_color, list):
307
- camera_edge_color = cam_color[i]
308
- else:
309
- camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
310
- if same_focals:
311
- focal = focals[0]
312
- else:
313
- focal = focals[i]
314
- add_scene_cam(scene, pose_c2w, camera_edge_color,
315
- None if transparent_cams else imgs[i], focal,
316
- imsize=imgs[i].shape[1::-1], screen_width=cam_size)
317
-
318
- rot = np.eye(4)
319
- rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
320
- scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
321
- outfile = os.path.join(outdir, 'scene.glb')
322
- if not silent:
323
- print('(exporting 3D scene to', outfile, ')')
324
- scene.export(file_obj=outfile)
325
- return outfile
326
-
327
- # @spaces.GPU(duration=120)
328
- def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
329
- clean_depth=False, transparent_cams=False, cam_size=0.05, same_focals=False):
330
- """
331
- extract 3D_model (glb file) from a reconstructed scene
332
- """
333
- if scene is None:
334
- return None
335
- # post processes
336
- if clean_depth:
337
- scene = scene.clean_pointcloud()
338
- if mask_sky:
339
- scene = scene.mask_sky()
340
-
341
- # get optimized values from scene
342
- rgbimg = to_numpy(scene.imgs)
343
- focals = to_numpy(scene.get_focals().cpu())
344
- # cams2world = to_numpy(scene.get_im_poses().cpu())
345
- # TODO use the vis_poses
346
- cams2world = scene.vis_poses
347
-
348
- # 3D pointcloud from depthmap, poses and intrinsics
349
- # pts3d = to_numpy(scene.get_pts3d())
350
- # TODO use the vis_poses
351
- pts3d = scene.vis_pts3d
352
- scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
353
- msk = to_numpy(scene.get_masks())
354
-
355
- return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
356
- transparent_cams=transparent_cams, cam_size=cam_size, silent=silent,
357
- same_focals=same_focals)
358
-
359
- @spaces.GPU(duration=120)
360
- def get_reconstructed_scene(filelist, schedule, niter, min_conf_thr,
361
- as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
362
- scenegraph_type, winsize, refid, same_focals):
363
- """
364
- from a list of images, run dust3r inference, global aligner.
365
- then run get_3D_model_from_scene
366
- """
367
- silent = False
368
- image_size = 224
369
- # remove the directory if it already exists
370
- outdir = tmpdirname
371
- if os.path.exists(outdir):
372
- shutil.rmtree(outdir)
373
- os.makedirs(outdir, exist_ok=True)
374
- # imgs, imgs_rgba = load_images(filelist, size=image_size, verbose=not silent, do_remove_background=True, rembg_session=rembg_session, predictor=predictor)
375
 
376
  optimized_results: OptimizedResult = inferece_dust3r(
377
- image_dir_or_list=filelist,
378
  model=model,
379
- device=device,
380
  batch_size=1,
381
  )
382
- rgbimg = optimized_results.rgb_hw3_list
383
- imgs_rgba = rgbimg
384
- cams2world = optimized_results.world_T_cam_b44
385
- pts3d = optimized_results.point_cloud
386
- pts_obj = pts3d
387
- outfile = os.path.join(outdir, 'scene.glb')
388
- # save point cloud trimesh.PointCloud to .ply
389
- pts3d.export(os.path.join(outdir, 'scene.glb'))
390
-
391
-
392
-
393
- # rgbimg = to_numpy(scene.imgs)
394
-
395
- imgs = []
396
- rgbaimg = []
397
- for i in range(len(rgbimg)): # when only 1 image, scene.imgs is two
398
- imgs.append(rgbimg[i])
399
- # imgs.append(rgb(depths[i]))
400
- # imgs.append(rgb(confs[i]))
401
- # imgs.append(imgs_rgba[i])
402
- if len(imgs_rgba) == 1 and i == 1:
403
- imgs.append(imgs_rgba[0])
404
- rgbaimg.append(np.array(imgs_rgba[0]))
405
- else:
406
- imgs.append(imgs_rgba[i])
407
- rgbaimg.append(np.array(imgs_rgba[i]))
408
-
409
- rgbaimg = np.array(rgbaimg)
410
-
411
- # for eschernet
412
- # cams2world = to_numpy(scene.get_im_poses().cpu())
413
- # pts3d = to_numpy(scene.get_pts3d())
414
- # scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
415
- # msk = to_numpy(scene.get_masks())
416
- # obj_mask = rgbaimg[..., 3] > 0
417
-
418
- # # TODO set global coordinate system at the center of the scene, z-axis is up
419
- # # pts = np.concatenate([p[m] for p, m in zip(pts3d, msk)]).reshape(-1, 3)
420
- # # pts_obj = np.concatenate([p[m&obj_m] for p, m, obj_m in zip(pts3d, msk, obj_mask)]).reshape(-1, 3)
421
- # centroid = np.mean(pts_obj, axis=0) # obj center
422
- # obj2world = np.eye(4)
423
- # obj2world[:3, 3] = -centroid # T_wc
424
- #
425
- # # get z_up vector
426
- # # TODO fit a plane and get the normal vector
427
- # pcd = o3d.geometry.PointCloud()
428
- # pcd.points = o3d.utility.Vector3dVector(pts)
429
- # plane_model, inliers = pcd.segment_plane(distance_threshold=0.01, ransac_n=3, num_iterations=1000)
430
- # # get the normalised normal vector dim = 3
431
- # normal = plane_model[:3] / np.linalg.norm(plane_model[:3])
432
- # # the normal direction should be pointing up
433
- # if normal[1] < 0:
434
- # normal = -normal
435
- # # print("normal", normal)
436
- #
437
- # # # TODO z-up 180
438
- # # z_up = np.array([[1,0,0,0],
439
- # # [0,-1,0,0],
440
- # # [0,0,-1,0],
441
- # # [0,0,0,1]])
442
- # # obj2world = z_up @ obj2world
443
- #
444
- # # # avg the y
445
- # # z_up_avg = cams2world[:,:3,3].sum(0) / np.linalg.norm(cams2world[:,:3,3].sum(0), axis=-1) # average direction in cam coordinate
446
- # # # import pdb; pdb.set_trace()
447
- # # rot_axis = np.cross(np.array([0, 0, 1]), z_up_avg)
448
- # # rot_angle = np.arccos(np.dot(np.array([0, 0, 1]), z_up_avg) / (np.linalg.norm(z_up_avg) + 1e-6))
449
- # # rot = Rotation.from_rotvec(rot_angle * rot_axis)
450
- # # z_up = np.eye(4)
451
- # # z_up[:3, :3] = rot.as_matrix()
452
- #
453
- # # get the rotation matrix from normal to z-axis
454
- # z_axis = np.array([0, 0, 1])
455
- # rot_axis = np.cross(normal, z_axis)
456
- # rot_angle = np.arccos(np.dot(normal, z_axis) / (np.linalg.norm(normal) + 1e-6))
457
- # rot = Rotation.from_rotvec(rot_angle * rot_axis)
458
- # z_up = np.eye(4)
459
- # z_up[:3, :3] = rot.as_matrix()
460
- # obj2world = z_up @ obj2world
461
- # # flip 180
462
- # flip_rot = np.array([[1, 0, 0, 0],
463
- # [0, -1, 0, 0],
464
- # [0, 0, -1, 0],
465
- # [0, 0, 0, 1]])
466
- # obj2world = flip_rot @ obj2world
467
- #
468
- # # get new cams2obj
469
- # cams2obj = []
470
- # for i, cam2world in enumerate(cams2world):
471
- # cams2obj.append(obj2world @ cam2world)
472
- # # TODO transform pts3d to the new coordinate system
473
- # for i, pts in enumerate(pts3d):
474
- # pts3d[i] = (obj2world @ np.concatenate([pts, np.ones_like(pts)[..., :1]], axis=-1).transpose(2, 0, 1).reshape(4,
475
- # -1)) \
476
- # .reshape(4, pts.shape[0], pts.shape[1]).transpose(1, 2, 0)[..., :3]
477
- # cams2world = np.array(cams2obj)
478
- # # TODO rewrite hack
479
- # scene.vis_poses = cams2world.copy()
480
- # scene.vis_pts3d = pts3d.copy()
481
-
482
- # # TODO save cams2world and rgbimg to each file, file name "000.npy", "001.npy", ... and "000.png", "001.png", ...
483
- # for i, (img, img_rgba, pose) in enumerate(zip(rgbimg, rgbaimg, cams2world)):
484
- # np.save(os.path.join(outdir, f"{i:03d}.npy"), pose)
485
- # pl.imsave(os.path.join(outdir, f"{i:03d}.png"), img)
486
- # pl.imsave(os.path.join(outdir, f"{i:03d}_rgba.png"), img_rgba)
487
- # # np.save(os.path.join(outdir, f"{i:03d}_focal.npy"), to_numpy(focal))
488
- # save the min/max radius of camera
489
- radii = np.linalg.norm(np.linalg.inv(cams2world)[..., :3, 3])
490
- # np.save(os.path.join(outdir, "radii.npy"), radii)
491
-
492
- eschernet_input = {"poses": cams2world,
493
- "radii": radii,
494
- "imgs": rgbaimg}
495
- print("got eschernet input")
496
- # outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
497
- # clean_depth, transparent_cams, cam_size, same_focals=same_focals)
498
-
499
- return scene, outfile, imgs, eschernet_input
500
-
501
-
502
-
503
-
504
- def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
505
- num_files = len(inputfiles) if inputfiles is not None else 1
506
- max_winsize = max(1, math.ceil((num_files - 1) / 2))
507
- if scenegraph_type == "swin":
508
- winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
509
- minimum=1, maximum=max_winsize, step=1, visible=True)
510
- refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
511
- maximum=num_files - 1, step=1, visible=False)
512
- elif scenegraph_type == "oneref":
513
- winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
514
- minimum=1, maximum=max_winsize, step=1, visible=False)
515
- refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
516
- maximum=num_files - 1, step=1, visible=True)
517
- else:
518
- winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
519
- minimum=1, maximum=max_winsize, step=1, visible=False)
520
- refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
521
- maximum=num_files - 1, step=1, visible=False)
522
- return winsize, refid
523
-
524
-
525
- def get_examples(path):
526
- objs = []
527
- for obj_name in sorted(os.listdir(path)):
528
- img_files = []
529
- for img_file in sorted(os.listdir(os.path.join(path, obj_name))):
530
- img_files.append(os.path.join(path, obj_name, img_file))
531
- objs.append([img_files])
532
- print("objs = ", objs)
533
- return objs
534
-
535
- def preview_input(inputfiles):
536
- if inputfiles is None:
537
- return None
538
- imgs = []
539
- for img_file in inputfiles:
540
- img = pl.imread(img_file)
541
- imgs.append(img)
542
- return imgs
543
-
544
- # def main():
545
- # dustr init
546
- silent = False
547
- image_size = 224
548
- weights_path = 'checkpoints/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth'
549
- model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(device)
550
- # dust3r will write the 3D model inside tmpdirname
551
- # with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname:
552
- tmpdirname = os.path.join('logs/user_object')
553
- # remove the directory if it already exists
554
- if os.path.exists(tmpdirname):
555
- shutil.rmtree(tmpdirname)
556
- os.makedirs(tmpdirname, exist_ok=True)
557
- if not silent:
558
- print('Outputing stuff in', tmpdirname)
559
-
560
- _HEADER_ = '''
561
- <h2><b>[CVPR'24 Oral] EscherNet: A Generative Model for Scalable View Synthesis</b></h2>
562
- <b>EscherNet</b> is a multiview diffusion model for scalable generative any-to-any number/pose novel view synthesis.
563
-
564
- Image views are treated as tokens and the camera pose is encoded by <b>CaPE (Camera Positional Encoding)</b>.
565
-
566
- <a href='https://kxhit.github.io/EscherNet' target='_blank'>Project</a> <b>|</b>
567
- <a href='https://github.com/kxhit/EscherNet' target='_blank'>GitHub</a> <b>|</b>
568
- <a href='https://arxiv.org/abs/2402.03908' target='_blank'>ArXiv</a>
569
-
570
- <h4><b>Tips:</b></h4>
571
-
572
- - Our model can take <b>any number input images</b>. The more images you provide <b>(>=3 for this demo)</b>, the better the results.
573
-
574
- - Our model can generate <b>any number and any pose</b> novel views. You can specify the number of views you want to generate. In this demo, we set novel views on an <b>archemedian spiral</b> for simplicity.
575
 
576
- - The pose estimation is done using <a href='https://github.com/naver/dust3r' target='_blank'>DUSt3R</a>. You can also provide your own poses or get pose via any SLAM system.
 
577
 
578
- - The current checkpoint supports 6DoF camera pose and is trained on 30k 3D <a href='https://objaverse.allenai.org/' target='_blank'>Objaverse</a> objects for demo. Scaling is on the roadmap!
 
 
 
579
 
580
- '''
581
 
582
- _CITE_ = r"""
583
- 📝 <b>Citation</b>:
584
- ```bibtex
585
- @article{kong2024eschernet,
586
- title={EscherNet: A Generative Model for Scalable View Synthesis},
587
- author={Kong, Xin and Liu, Shikun and Lyu, Xiaoyang and Taher, Marwan and Qi, Xiaojuan and Davison, Andrew J},
588
- journal={arXiv preprint arXiv:2402.03908},
589
- year={2024}
590
- }
591
- ```
592
- """
593
-
594
- with gr.Blocks() as demo:
595
- gr.Markdown(_HEADER_)
596
- # mv_images = gr.State()
597
- scene = gr.State(None)
598
- eschernet_input = gr.State(None)
599
- with gr.Row(variant="panel"):
600
- # left column
601
  with gr.Column():
602
- with gr.Row():
603
- input_image = gr.File(file_count="multiple")
604
- with gr.Row():
605
- run_dust3r = gr.Button("Get Pose!", elem_id="dust3r")
606
- with gr.Row():
607
- processed_image = gr.Gallery(label='Input Views', columns=2, height="100%")
608
- with gr.Row(variant="panel"):
609
- # input examples under "examples" folder
610
- gr.Examples(
611
- examples=get_examples('examples'),
612
- inputs=[input_image],
613
- label="Examples (click one set of images to start!)",
614
- examples_per_page=20
615
- )
616
-
617
-
618
-
619
-
620
-
621
- # right column
622
  with gr.Column():
623
 
624
- with gr.Row():
625
- outmodel = gr.Model3D()
626
-
627
- with gr.Row():
628
- gr.Markdown('''
629
- <h4><b>Check if the pose (blue is axis is estimated z-up direction) and segmentation looks correct. If not, remove the incorrect images and try again.</b></h4>
630
- ''')
631
-
632
- with gr.Row():
633
- with gr.Group():
634
- do_remove_background = gr.Checkbox(
635
- label="Remove Background", value=True
636
- )
637
- sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
638
-
639
- sample_steps = gr.Slider(
640
- label="Sample Steps",
641
- minimum=30,
642
- maximum=75,
643
- value=50,
644
- step=5,
645
- visible=False
646
- )
647
-
648
- nvs_num = gr.Slider(
649
- label="Number of Novel Views",
650
- minimum=5,
651
- maximum=100,
652
- value=30,
653
- step=1
654
- )
655
-
656
- nvs_mode = gr.Dropdown(["archimedes circle"], # "fixed 4 views", "fixed 8 views"
657
- value="archimedes circle", label="Novel Views Pose Chosen", visible=True)
658
-
659
- with gr.Row():
660
- gr.Markdown('''
661
- <h4><b>Choose your desired novel view poses number and generate! The more output images the longer it takes.</b></h4>
662
- ''')
663
-
664
- with gr.Row():
665
- submit = gr.Button("Submit", elem_id="eschernet", variant="primary")
666
-
667
- with gr.Row():
668
- with gr.Column():
669
- output_video = gr.Video(
670
- label="video", format="mp4",
671
- width=379,
672
- autoplay=True,
673
- interactive=False
674
- )
675
-
676
- with gr.Row():
677
- gr.Markdown('''
678
- <h4><b>The novel views are generated on an archimedean spiral (rotating around z-up axis and looking at the object center). You can download the video.</b></h4>
679
- ''')
680
-
681
- gr.Markdown(_CITE_)
682
-
683
- # set dust3r parameter invisible to be clean
684
- with gr.Column():
685
- with gr.Row():
686
- schedule = gr.Dropdown(["linear", "cosine"],
687
- value='linear', label="schedule", info="For global alignment!", visible=False)
688
- niter = gr.Number(value=300, precision=0, minimum=0, maximum=5000,
689
- label="num_iterations", info="For global alignment!", visible=False)
690
- scenegraph_type = gr.Dropdown(["complete", "swin", "oneref"],
691
- value='complete', label="Scenegraph",
692
- info="Define how to make pairs",
693
- interactive=True, visible=False)
694
- same_focals = gr.Checkbox(value=True, label="Focal", info="Use the same focal for all cameras", visible=False)
695
- winsize = gr.Slider(label="Scene Graph: Window Size", value=1,
696
- minimum=1, maximum=1, step=1, visible=False)
697
- refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
698
-
699
- with gr.Row():
700
- # adjust the confidence threshold
701
- min_conf_thr = gr.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1, visible=False)
702
- # adjust the camera size in the output pointcloud
703
- cam_size = gr.Slider(label="cam_size", value=0.05, minimum=0.01, maximum=0.5, step=0.001, visible=False)
704
- with gr.Row():
705
- as_pointcloud = gr.Checkbox(value=False, label="As pointcloud", visible=False)
706
- # two post process implemented
707
- mask_sky = gr.Checkbox(value=False, label="Mask sky", visible=False)
708
- clean_depth = gr.Checkbox(value=True, label="Clean-up depthmaps", visible=False)
709
- transparent_cams = gr.Checkbox(value=False, label="Transparent cameras", visible=False)
710
-
711
- # events
712
- # scenegraph_type.change(set_scenegraph_options,
713
- # inputs=[input_image, winsize, refid, scenegraph_type],
714
- # outputs=[winsize, refid])
715
- # min_conf_thr.release(fn=model_from_scene_fun,
716
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
717
- # clean_depth, transparent_cams, cam_size, same_focals],
718
- # outputs=outmodel)
719
- # cam_size.change(fn=model_from_scene_fun,
720
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
721
- # clean_depth, transparent_cams, cam_size, same_focals],
722
- # outputs=outmodel)
723
- # as_pointcloud.change(fn=model_from_scene_fun,
724
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
725
- # clean_depth, transparent_cams, cam_size, same_focals],
726
- # outputs=outmodel)
727
- # mask_sky.change(fn=model_from_scene_fun,
728
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
729
- # clean_depth, transparent_cams, cam_size, same_focals],
730
- # outputs=outmodel)
731
- # clean_depth.change(fn=model_from_scene_fun,
732
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
733
- # clean_depth, transparent_cams, cam_size, same_focals],
734
- # outputs=outmodel)
735
- # transparent_cams.change(model_from_scene_fun,
736
- # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
737
- # clean_depth, transparent_cams, cam_size, same_focals],
738
- # outputs=outmodel)
739
- # run_dust3r.click(fn=recon_fun,
740
- # inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
741
- # mask_sky, clean_depth, transparent_cams, cam_size,
742
- # scenegraph_type, winsize, refid, same_focals],
743
- # outputs=[scene, outmodel, processed_image, eschernet_input])
744
-
745
- # events
746
- input_image.change(set_scenegraph_options,
747
- inputs=[input_image, winsize, refid, scenegraph_type],
748
- outputs=[winsize, refid])
749
- run_dust3r.click(fn=get_reconstructed_scene,
750
- inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
751
- mask_sky, clean_depth, transparent_cams, cam_size,
752
- scenegraph_type, winsize, refid, same_focals],
753
- outputs=[scene, outmodel, processed_image, eschernet_input])
754
-
755
-
756
- # events
757
- input_image.change(fn=preview_input,
758
- inputs=[input_image],
759
- outputs=[processed_image])
760
-
761
- submit.click(fn=run_eschernet,
762
- inputs=[eschernet_input, sample_steps, sample_seed,
763
- nvs_num, nvs_mode],
764
- outputs=[output_video])
765
-
766
-
767
-
768
- # demo.queue(max_size=10)
769
- # demo.launch(share=True, server_name="0.0.0.0", server_port=None)
770
- demo.queue(max_size=10).launch()
771
 
772
- # if __name__ == '__main__':
773
- # main()
1
  import gradio as gr
2
 
3
+ import spaces
4
+ import torch
5
+ from gradio_rerun import Rerun
6
+ import rerun as rr
7
+ import rerun.blueprint as rrb
8
+ from pathlib import Path
9
+ import uuid
 
10
 
11
  from mini_dust3r.api import OptimizedResult, inferece_dust3r, log_optimized_result
12
  from mini_dust3r.model import AsymmetricCroCo3DStereo
13
 
14
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
+ model = AsymmetricCroCo3DStereo.from_pretrained(
16
+ "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
17
+ ).to(DEVICE)
18
+
19
+
20
+ def create_blueprint(image_name_list: list[str], log_path: Path) -> rrb.Blueprint:
21
+ # don't show 2D views when there are more than 4 images, so as not to clutter the viewer
22
+ if len(image_name_list) > 4:
23
+ blueprint = rrb.Blueprint(
24
+ rrb.Horizontal(
25
+ rrb.Spatial3DView(origin=f"{log_path}"),
26
+ ),
27
+ collapse_panels=True,
28
+ )
 
29
  else:
30
+ blueprint = rrb.Blueprint(
31
+ rrb.Horizontal(
32
+ contents=[
33
+ rrb.Spatial3DView(origin=f"{log_path}"),
34
+ rrb.Vertical(
35
+ contents=[
36
+ rrb.Spatial2DView(
37
+ origin=f"{log_path}/camera_{i}/pinhole/",
38
+ contents=[
39
+ "+ $origin/**",
40
+ ],
41
+ )
42
+ for i in range(len(image_name_list))
43
+ ]
44
+ ),
45
+ ],
46
+ column_shares=[3, 1],
47
+ ),
48
+ collapse_panels=True,
49
+ )
50
+ return blueprint
51
+
52
+
53
+ @spaces.GPU
54
+ def predict(image_name_list: list[str] | str):
55
+ # check that the input is a list or a string; raise an error otherwise
56
+ if not isinstance(image_name_list, list) and not isinstance(image_name_list, str):
57
+ raise gr.Error(
58
+ f"Input must be a list of strings or a string, got: {type(image_name_list)}"
59
+ )
60
+ uuid_str = str(uuid.uuid4())
61
+ filename = Path(f"/tmp/gradio/{uuid_str}.rrd")
62
+ rr.init(f"{uuid_str}")
63
+ log_path = Path("world")
64
+
65
+ if isinstance(image_name_list, str):
66
+ image_name_list = [image_name_list]
67
 
68
  optimized_results: OptimizedResult = inferece_dust3r(
69
+ image_dir_or_list=image_name_list,
70
  model=model,
71
+ device=DEVICE,
72
  batch_size=1,
73
  )
74
 
75
+ blueprint: rrb.Blueprint = create_blueprint(image_name_list, log_path)
76
+ rr.send_blueprint(blueprint)
77
 
78
+ rr.set_time_sequence("sequence", 0)
79
+ log_optimized_result(optimized_results, log_path)
80
+ rr.save(filename.as_posix())
81
+ return filename.as_posix()
82
 
 
83
 
84
+ with gr.Blocks(
85
+ css=""".gradio-container {margin: 0 !important; min-width: 100%};""",
86
+ title="Mini-DUSt3R Demo",
87
+ ) as demo:
88
+ # scene state is saved so that you can change conf_thr, cam_size... without rerunning the inference
89
+ gr.HTML('<h2 style="text-align: center;">Mini-DUSt3R Demo</h2>')
90
+ gr.HTML(
91
+ '<p style="text-align: center;">Unofficial DUSt3R demo using the mini-dust3r pip package</p>'
92
+ )
93
+ gr.HTML(
94
+ '<p style="text-align: center;">More info <a href="https://github.com/pablovela5620/mini-dust3r">here</a></p>'
95
+ )
96
+ with gr.Tab(label="Single Image"):
97
  with gr.Column():
98
+ single_image = gr.Image(type="filepath", height=300)
99
+ run_btn_single = gr.Button("Run")
100
+ rerun_viewer_single = Rerun(height=900)
101
+ run_btn_single.click(
102
+ fn=predict, inputs=[single_image], outputs=[rerun_viewer_single]
103
+ )
104
+
105
+ example_single_dir = Path("examples/single_image")
106
+ example_single_files = sorted(example_single_dir.glob("*.png"))
107
+
108
+ examples_single = gr.Examples(
109
+ examples=example_single_files,
110
+ inputs=[single_image],
111
+ outputs=[rerun_viewer_single],
112
+ fn=predict,
113
+ cache_examples="lazy",
114
+ )
115
+ with gr.Tab(label="Multi Image"):
 
 
116
  with gr.Column():
117
+ multi_files = gr.File(file_count="multiple")
118
+ run_btn_multi = gr.Button("Run")
119
+ rerun_viewer_multi = Rerun(height=900)
120
+ run_btn_multi.click(
121
+ fn=predict, inputs=[multi_files], outputs=[rerun_viewer_multi]
122
+ )
123
124
 
125
+ demo.launch()
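
For reference, a minimal standalone sketch of the reconstruction flow that the new `predict()` above wires into Gradio: load the pretrained DUSt3R checkpoint, run `inferece_dust3r` over a set of images, and log the result to a Rerun recording. The checkpoint id, function names, and call signatures follow the diff; the input image paths, the recording name, and the output `.rrd` filename are illustrative assumptions.

```python
# Sketch only: mirrors the mini-dust3r + rerun calls from the new app.py.
# "imgs/a.png", "imgs/b.png" and "scene.rrd" are hypothetical paths.
from pathlib import Path

import rerun as rr
import torch
from mini_dust3r.api import OptimizedResult, inferece_dust3r, log_optimized_result
from mini_dust3r.model import AsymmetricCroCo3DStereo

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Same pretrained weights the Space loads at startup.
model = AsymmetricCroCo3DStereo.from_pretrained(
    "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
).to(DEVICE)

# DUSt3R inference + global alignment over a list of image paths.
optimized_results: OptimizedResult = inferece_dust3r(
    image_dir_or_list=["imgs/a.png", "imgs/b.png"],  # hypothetical inputs
    model=model,
    device=DEVICE,
    batch_size=1,
)

# Log cameras, images and the point cloud under the "world" entity path,
# then save a recording that the Rerun viewer (or the gradio_rerun.Rerun
# component used above) can open.
rr.init("mini-dust3r-example")
rr.set_time_sequence("sequence", 0)
log_optimized_result(optimized_results, Path("world"))
rr.save("scene.rrd")
```

In the Space itself, `predict()` additionally builds a `rerun.blueprint` layout with `create_blueprint()` and sends it via `rr.send_blueprint()` before logging, so the viewer opens with the 3D view already arranged (plus per-camera 2D views when four or fewer images are given).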
 
app_mini.py ADDED
@@ -0,0 +1,773 @@
1
+ import spaces
2
+ import torch
3
+ print("cuda is available: ", torch.cuda.is_available())
4
+
5
+ import gradio as gr
6
+ import os
7
+ import shutil
8
+ import rembg
9
+ import numpy as np
10
+ import math
11
+ import open3d as o3d
12
+ from PIL import Image
13
+ import torchvision
14
+ import trimesh
15
+ from skimage.io import imsave
16
+ import imageio
17
+ import cv2
18
+ import matplotlib.pyplot as pl
19
+ pl.ion()
20
+
21
+ CaPE_TYPE = "6DoF"
22
+ device = 'cuda' #if torch.cuda.is_available() else 'cpu'
23
+ weight_dtype = torch.float16
24
+ torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
25
+
26
+ # EscherNet
27
+ # create angles in archimedean spiral with N steps
28
+ def get_archimedean_spiral(sphere_radius, num_steps=250):
29
+ # x-z plane, around upper y
30
+ '''
31
+ https://en.wikipedia.org/wiki/Spiral, section "Spherical spiral". c = a / pi
32
+ '''
33
+ a = 40
34
+ r = sphere_radius
35
+
36
+ translations = []
37
+ angles = []
38
+
39
+ # i = a / 2
40
+ i = 0.01
41
+ while i < a:
42
+ theta = i / a * math.pi
43
+ x = r * math.sin(theta) * math.cos(-i)
44
+ z = r * math.sin(-theta + math.pi) * math.sin(-i)
45
+ y = r * - math.cos(theta)
46
+
47
+ # translations.append((x, y, z)) # origin
48
+ translations.append((x, z, -y))
49
+ angles.append([np.rad2deg(-i), np.rad2deg(theta)])
50
+
51
+ # i += a / (2 * num_steps)
52
+ i += a / (1 * num_steps)
53
+
54
+ return np.array(translations), np.stack(angles)
55
+
56
+ def look_at(origin, target, up):
57
+ forward = (target - origin)
58
+ forward = forward / np.linalg.norm(forward)
59
+ right = np.cross(up, forward)
60
+ right = right / np.linalg.norm(right)
61
+ new_up = np.cross(forward, right)
62
+ rotation_matrix = np.column_stack((right, new_up, -forward, target))
63
+ matrix = np.row_stack((rotation_matrix, [0, 0, 0, 1]))
64
+ return matrix
65
+
66
+ import einops
67
+ import sys
68
+
69
+ sys.path.insert(0, "./6DoF/") # TODO change it when deploying
70
+ # use the customized diffusers modules
71
+ from diffusers import DDIMScheduler
72
+ from dataset import get_pose
73
+ from CN_encoder import CN_encoder
74
+ from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline
75
+ from segment_anything import sam_model_registry, SamPredictor
76
+
77
+ # import rembg
78
+ from carvekit.api.high import HiInterface
79
+
80
+
81
+ pretrained_model_name_or_path = "kxic/EscherNet_demo"
82
+ resolution = 256
83
+ h,w = resolution,resolution
84
+ guidance_scale = 3.0
85
+ radius = 2.2
86
+ bg_color = [1., 1., 1., 1.]
87
+ image_transforms = torchvision.transforms.Compose(
88
+ [
89
+ torchvision.transforms.Resize((resolution, resolution)), # 256, 256
90
+ torchvision.transforms.ToTensor(),
91
+ torchvision.transforms.Normalize([0.5], [0.5])
92
+ ]
93
+ )
94
+ xyzs_spiral, angles_spiral = get_archimedean_spiral(1.5, 200)
95
+ # only half toop
96
+ xyzs_spiral = xyzs_spiral[:100]
97
+ angles_spiral = angles_spiral[:100]
98
+
99
+ # Init pipeline
100
+ scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler", revision=None)
101
+ image_encoder = CN_encoder.from_pretrained(pretrained_model_name_or_path, subfolder="image_encoder", revision=None)
102
+ pipeline = Zero1to3StableDiffusionPipeline.from_pretrained(
103
+ pretrained_model_name_or_path,
104
+ revision=None,
105
+ scheduler=scheduler,
106
+ image_encoder=None,
107
+ safety_checker=None,
108
+ feature_extractor=None,
109
+ torch_dtype=weight_dtype,
110
+ )
111
+ pipeline.image_encoder = image_encoder.to(weight_dtype)
112
+
113
+ pipeline.set_progress_bar_config(disable=False)
114
+
115
+ pipeline = pipeline.to(device)
116
+
117
+ # pipeline.enable_xformers_memory_efficient_attention()
118
+ # enable vae slicing
119
+ pipeline.enable_vae_slicing()
120
+ # pipeline.enable_xformers_memory_efficient_attention()
121
+
122
+
123
+ #### object segmentation
124
+ def sam_init():
125
+ sam_checkpoint = os.path.join("./sam_pt/sam_vit_h_4b8939.pth")
126
+ if os.path.exists(sam_checkpoint) is False:
127
+ os.system("wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P ./sam_pt/")
128
+ model_type = "vit_h"
129
+
130
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)
131
+ predictor = SamPredictor(sam)
132
+ return predictor
133
+
134
+ def create_carvekit_interface():
135
+ # Check doc strings for more information
136
+ interface = HiInterface(object_type="object", # Can be "object" or "hairs-like".
137
+ batch_size_seg=6,
138
+ batch_size_matting=1,
139
+ device="cpu",
140
+ seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
141
+ matting_mask_size=2048,
142
+ trimap_prob_threshold=231,
143
+ trimap_dilation=30,
144
+ trimap_erosion_iters=5,
145
+ fp16=True)
146
+
147
+ return interface
148
+
149
+
150
+ # rembg_session = rembg.new_session()
151
+ rembg_session = create_carvekit_interface()
152
+ predictor = sam_init()
153
+
154
+
155
+
156
+ @spaces.GPU(duration=120)
157
+ def run_eschernet(eschernet_input_dict, sample_steps, sample_seed, nvs_num, nvs_mode):
158
+ # set the random seed
159
+ generator = torch.Generator(device=device).manual_seed(sample_seed)
160
+ # generator = None
161
+ T_out = nvs_num
162
+ T_in = len(eschernet_input_dict['imgs'])
163
+ ####### output pose
164
+ # TODO choose T_out number of poses sequentially from the spiral
165
+ xyzs = xyzs_spiral[::(len(xyzs_spiral) // T_out)]
166
+ angles_out = angles_spiral[::(len(xyzs_spiral) // T_out)]
167
+
168
+ ####### input's max radius for translation scaling
169
+ radii = eschernet_input_dict['radii']
170
+ max_t = np.max(radii)
171
+ min_t = np.min(radii)
172
+
173
+ ####### input pose
174
+ pose_in = []
175
+ for T_in_index in range(T_in):
176
+ pose = get_pose(np.linalg.inv(eschernet_input_dict['poses'][T_in_index]))
177
+ pose[1:3, :] *= -1 # coordinate system conversion
178
+ pose[3, 3] *= 1. / max_t * radius # scale radius to [1.5, 2.2]
179
+ pose_in.append(torch.from_numpy(pose))
180
+
181
+ ####### input image
182
+ img = eschernet_input_dict['imgs'] / 255.
183
+ img[img[:, :, :, -1] == 0.] = bg_color
184
+ # TODO batch image_transforms
185
+ input_image = [image_transforms(Image.fromarray(np.uint8(im[:, :, :3] * 255.)).convert("RGB")) for im in img]
186
+
187
+ ####### nvs pose
188
+ pose_out = []
189
+ for T_out_index in range(T_out):
190
+ azimuth, polar = angles_out[T_out_index]
191
+ if CaPE_TYPE == "4DoF":
192
+ pose_out.append(torch.tensor([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.]))
193
+ elif CaPE_TYPE == "6DoF":
194
+ pose = look_at(origin=np.array([0, 0, 0]), target=xyzs[T_out_index], up=np.array([0, 0, 1]))
195
+ pose = np.linalg.inv(pose)
196
+ pose[2, :] *= -1
197
+ pose_out.append(torch.from_numpy(get_pose(pose)))
198
+
199
+
200
+
201
+ # [B, T, C, H, W]
202
+ input_image = torch.stack(input_image, dim=0).to(device).to(weight_dtype).unsqueeze(0)
203
+ # [B, T, 4]
204
+ pose_in = np.stack(pose_in)
205
+ pose_out = np.stack(pose_out)
206
+
207
+ if CaPE_TYPE == "6DoF":
208
+ pose_in_inv = np.linalg.inv(pose_in).transpose([0, 2, 1])
209
+ pose_out_inv = np.linalg.inv(pose_out).transpose([0, 2, 1])
210
+ pose_in_inv = torch.from_numpy(pose_in_inv).to(device).to(weight_dtype).unsqueeze(0)
211
+ pose_out_inv = torch.from_numpy(pose_out_inv).to(device).to(weight_dtype).unsqueeze(0)
212
+
213
+ pose_in = torch.from_numpy(pose_in).to(device).to(weight_dtype).unsqueeze(0)
214
+ pose_out = torch.from_numpy(pose_out).to(device).to(weight_dtype).unsqueeze(0)
215
+
216
+ input_image = einops.rearrange(input_image, "b t c h w -> (b t) c h w")
217
+ assert T_in == input_image.shape[0]
218
+ assert T_in == pose_in.shape[1]
219
+ assert T_out == pose_out.shape[1]
220
+
221
+ # run inference
222
+ # pipeline.to(device)
223
+ pipeline.enable_xformers_memory_efficient_attention()
224
+ image = pipeline(input_imgs=input_image, prompt_imgs=input_image,
225
+ poses=[[pose_out, pose_out_inv], [pose_in, pose_in_inv]],
226
+ height=h, width=w, T_in=T_in, T_out=T_out,
227
+ guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
228
+ output_type="numpy").images
229
+
230
+ # save output image
231
+ output_dir = os.path.join(tmpdirname, "eschernet")
232
+ if os.path.exists(output_dir):
233
+ shutil.rmtree(output_dir)
234
+ os.makedirs(output_dir, exist_ok=True)
235
+ # # save to N imgs
236
+ # for i in range(T_out):
237
+ # imsave(os.path.join(output_dir, f'{i}.png'), (image[i] * 255).astype(np.uint8))
238
+ # make a gif
239
+ frames = [Image.fromarray((image[i] * 255).astype(np.uint8)) for i in range(T_out)]
240
+ # frame_one = frames[0]
241
+ # frame_one.save(os.path.join(output_dir, "output.gif"), format="GIF", append_images=frames,
242
+ # save_all=True, duration=50, loop=1)
243
+
244
+ # get a video
245
+ video_path = os.path.join(output_dir, "output.mp4")
246
+ imageio.mimwrite(video_path, np.stack(frames), fps=10, codec='h264')
247
+
248
+
249
+ return video_path
250
+
251
+ # TODO mesh it
252
+ @spaces.GPU(duration=120)
253
+ def make3d():
254
+ pass
255
+
256
+
257
+
258
+ ############################ Dust3r as Pose Estimation ############################
259
+ from scipy.spatial.transform import Rotation
260
+ import copy
261
+
262
+ from dust3r.inference import inference
263
+ from dust3r.model import AsymmetricCroCo3DStereo
264
+ from dust3r.image_pairs import make_pairs
265
+ from dust3r.utils.image import load_images, rgb
266
+ from dust3r.utils.device import to_numpy
267
+ from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
268
+ from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
269
+ import math
270
+
271
+ from mini_dust3r.api import OptimizedResult, inferece_dust3r, log_optimized_result
272
+ from mini_dust3r.model import AsymmetricCroCo3DStereo
273
+
274
+ # @spaces.GPU(duration=120)
275
+ def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
276
+ cam_color=None, as_pointcloud=False,
277
+ transparent_cams=False, silent=False, same_focals=False):
278
+ assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world)
279
+ if not same_focals:
280
+ assert (len(cams2world) == len(focals))
281
+ pts3d = to_numpy(pts3d)
282
+ imgs = to_numpy(imgs)
283
+ focals = to_numpy(focals)
284
+ cams2world = to_numpy(cams2world)
285
+
286
+ scene = trimesh.Scene()
287
+
288
+ # add axes
289
+ scene.add_geometry(trimesh.creation.axis(axis_length=0.5, axis_radius=0.001))
290
+
291
+ # full pointcloud
292
+ if as_pointcloud:
293
+ pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
294
+ col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
295
+ pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
296
+ scene.add_geometry(pct)
297
+ else:
298
+ meshes = []
299
+ for i in range(len(imgs)):
300
+ meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]))
301
+ mesh = trimesh.Trimesh(**cat_meshes(meshes))
302
+ scene.add_geometry(mesh)
303
+
304
+ # add each camera
305
+ for i, pose_c2w in enumerate(cams2world):
306
+ if isinstance(cam_color, list):
307
+ camera_edge_color = cam_color[i]
308
+ else:
309
+ camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
310
+ if same_focals:
311
+ focal = focals[0]
312
+ else:
313
+ focal = focals[i]
314
+ add_scene_cam(scene, pose_c2w, camera_edge_color,
315
+ None if transparent_cams else imgs[i], focal,
316
+ imsize=imgs[i].shape[1::-1], screen_width=cam_size)
317
+
318
+ rot = np.eye(4)
319
+ rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
320
+ scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
321
+ outfile = os.path.join(outdir, 'scene.glb')
322
+ if not silent:
323
+ print('(exporting 3D scene to', outfile, ')')
324
+ scene.export(file_obj=outfile)
325
+ return outfile
326
+
327
+ # @spaces.GPU(duration=120)
328
+ def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
329
+ clean_depth=False, transparent_cams=False, cam_size=0.05, same_focals=False):
330
+ """
331
+ extract 3D_model (glb file) from a reconstructed scene
332
+ """
333
+ if scene is None:
334
+ return None
335
+ # post processes
336
+ if clean_depth:
337
+ scene = scene.clean_pointcloud()
338
+ if mask_sky:
339
+ scene = scene.mask_sky()
340
+
341
+ # get optimized values from scene
342
+ rgbimg = to_numpy(scene.imgs)
343
+ focals = to_numpy(scene.get_focals().cpu())
344
+ # cams2world = to_numpy(scene.get_im_poses().cpu())
345
+ # TODO use the vis_poses
346
+ cams2world = scene.vis_poses
347
+
348
+ # 3D pointcloud from depthmap, poses and intrinsics
349
+ # pts3d = to_numpy(scene.get_pts3d())
350
+ # TODO use the vis_poses
351
+ pts3d = scene.vis_pts3d
352
+ scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
353
+ msk = to_numpy(scene.get_masks())
354
+
355
+ return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
356
+ transparent_cams=transparent_cams, cam_size=cam_size, silent=silent,
357
+ same_focals=same_focals)
358
+
359
+ @spaces.GPU(duration=120)
360
+ def get_reconstructed_scene(filelist, schedule, niter, min_conf_thr,
361
+ as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
362
+ scenegraph_type, winsize, refid, same_focals):
363
+ """
364
+ from a list of images, run dust3r inference, global aligner.
365
+ then run get_3D_model_from_scene
366
+ """
367
+ silent = False
368
+ image_size = 224
369
+ # remove the directory if it already exists
370
+ outdir = tmpdirname
371
+ if os.path.exists(outdir):
372
+ shutil.rmtree(outdir)
373
+ os.makedirs(outdir, exist_ok=True)
374
+ # imgs, imgs_rgba = load_images(filelist, size=image_size, verbose=not silent, do_remove_background=True, rembg_session=rembg_session, predictor=predictor)
375
+
376
+ optimized_results: OptimizedResult = inferece_dust3r(
377
+ image_dir_or_list=filelist,
378
+ model=model,
379
+ device=device,
380
+ batch_size=1,
381
+ )
382
+ rgbimg = optimized_results.rgb_hw3_list
383
+ imgs_rgba = rgbimg
384
+ cams2world = optimized_results.world_T_cam_b44
385
+ pts3d = optimized_results.point_cloud
386
+ pts_obj = pts3d
387
+ outfile = os.path.join(outdir, 'scene.glb')
388
+ # save point cloud trimesh.PointCloud to .ply
389
+ pts3d.export(os.path.join(outdir, 'scene.glb'))
390
+
391
+
392
+
393
+ # rgbimg = to_numpy(scene.imgs)
394
+
395
+ imgs = []
396
+ rgbaimg = []
397
+ for i in range(len(rgbimg)): # when only 1 image, scene.imgs is two
398
+ imgs.append(rgbimg[i])
399
+ # imgs.append(rgb(depths[i]))
400
+ # imgs.append(rgb(confs[i]))
401
+ # imgs.append(imgs_rgba[i])
402
+ if len(imgs_rgba) == 1 and i == 1:
403
+ imgs.append(imgs_rgba[0])
404
+ rgbaimg.append(np.array(imgs_rgba[0]))
405
+ else:
406
+ imgs.append(imgs_rgba[i])
407
+ rgbaimg.append(np.array(imgs_rgba[i]))
408
+
409
+ rgbaimg = np.array(rgbaimg)
410
+
411
+ # for eschernet
412
+ # cams2world = to_numpy(scene.get_im_poses().cpu())
413
+ # pts3d = to_numpy(scene.get_pts3d())
414
+ # scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
415
+ # msk = to_numpy(scene.get_masks())
416
+ # obj_mask = rgbaimg[..., 3] > 0
417
+
418
+ # # TODO set global coordinate system at the center of the scene, z-axis is up
419
+ # # pts = np.concatenate([p[m] for p, m in zip(pts3d, msk)]).reshape(-1, 3)
420
+ # # pts_obj = np.concatenate([p[m&obj_m] for p, m, obj_m in zip(pts3d, msk, obj_mask)]).reshape(-1, 3)
421
+ # centroid = np.mean(pts_obj, axis=0) # obj center
422
+ # obj2world = np.eye(4)
423
+ # obj2world[:3, 3] = -centroid # T_wc
424
+ #
425
+ # # get z_up vector
426
+ # # TODO fit a plane and get the normal vector
427
+ # pcd = o3d.geometry.PointCloud()
428
+ # pcd.points = o3d.utility.Vector3dVector(pts)
429
+ # plane_model, inliers = pcd.segment_plane(distance_threshold=0.01, ransac_n=3, num_iterations=1000)
430
+ # # get the normalised normal vector dim = 3
431
+ # normal = plane_model[:3] / np.linalg.norm(plane_model[:3])
432
+ # # the normal direction should be pointing up
433
+ # if normal[1] < 0:
434
+ # normal = -normal
435
+ # # print("normal", normal)
436
+ #
437
+ # # # TODO z-up 180
438
+ # # z_up = np.array([[1,0,0,0],
439
+ # # [0,-1,0,0],
440
+ # # [0,0,-1,0],
441
+ # # [0,0,0,1]])
442
+ # # obj2world = z_up @ obj2world
443
+ #
444
+ # # # avg the y
445
+ # # z_up_avg = cams2world[:,:3,3].sum(0) / np.linalg.norm(cams2world[:,:3,3].sum(0), axis=-1) # average direction in cam coordinate
446
+ # # # import pdb; pdb.set_trace()
447
+ # # rot_axis = np.cross(np.array([0, 0, 1]), z_up_avg)
448
+ # # rot_angle = np.arccos(np.dot(np.array([0, 0, 1]), z_up_avg) / (np.linalg.norm(z_up_avg) + 1e-6))
449
+ # # rot = Rotation.from_rotvec(rot_angle * rot_axis)
450
+ # # z_up = np.eye(4)
451
+ # # z_up[:3, :3] = rot.as_matrix()
452
+ #
453
+ # # get the rotation matrix from normal to z-axis
454
+ # z_axis = np.array([0, 0, 1])
455
+ # rot_axis = np.cross(normal, z_axis)
456
+ # rot_angle = np.arccos(np.dot(normal, z_axis) / (np.linalg.norm(normal) + 1e-6))
457
+ # rot = Rotation.from_rotvec(rot_angle * rot_axis)
458
+ # z_up = np.eye(4)
459
+ # z_up[:3, :3] = rot.as_matrix()
460
+ # obj2world = z_up @ obj2world
461
+ # # flip 180
462
+ # flip_rot = np.array([[1, 0, 0, 0],
463
+ # [0, -1, 0, 0],
464
+ # [0, 0, -1, 0],
465
+ # [0, 0, 0, 1]])
466
+ # obj2world = flip_rot @ obj2world
467
+ #
468
+ # # get new cams2obj
469
+ # cams2obj = []
470
+ # for i, cam2world in enumerate(cams2world):
471
+ # cams2obj.append(obj2world @ cam2world)
472
+ # # TODO transform pts3d to the new coordinate system
473
+ # for i, pts in enumerate(pts3d):
474
+ # pts3d[i] = (obj2world @ np.concatenate([pts, np.ones_like(pts)[..., :1]], axis=-1).transpose(2, 0, 1).reshape(4,
475
+ # -1)) \
476
+ # .reshape(4, pts.shape[0], pts.shape[1]).transpose(1, 2, 0)[..., :3]
477
+ # cams2world = np.array(cams2obj)
478
+ # # TODO rewrite hack
479
+ # scene.vis_poses = cams2world.copy()
480
+ # scene.vis_pts3d = pts3d.copy()
481
+
482
+ # # TODO save cams2world and rgbimg to each file, file name "000.npy", "001.npy", ... and "000.png", "001.png", ...
483
+ # for i, (img, img_rgba, pose) in enumerate(zip(rgbimg, rgbaimg, cams2world)):
484
+ # np.save(os.path.join(outdir, f"{i:03d}.npy"), pose)
485
+ # pl.imsave(os.path.join(outdir, f"{i:03d}.png"), img)
486
+ # pl.imsave(os.path.join(outdir, f"{i:03d}_rgba.png"), img_rgba)
487
+ # # np.save(os.path.join(outdir, f"{i:03d}_focal.npy"), to_numpy(focal))
488
+ # save the min/max radius of camera
489
+ radii = np.linalg.norm(np.linalg.inv(cams2world)[..., :3, 3])
490
+ # np.save(os.path.join(outdir, "radii.npy"), radii)
491
+
492
+ eschernet_input = {"poses": cams2world,
493
+ "radii": radii,
494
+ "imgs": rgbaimg}
495
+ print("got eschernet input")
496
+ # outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
497
+ # clean_depth, transparent_cams, cam_size, same_focals=same_focals)
498
+
499
+ return scene, outfile, imgs, eschernet_input
+
+
+
+
+ def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
+     num_files = len(inputfiles) if inputfiles is not None else 1
+     max_winsize = max(1, math.ceil((num_files - 1) / 2))
+     if scenegraph_type == "swin":
+         winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
+                             minimum=1, maximum=max_winsize, step=1, visible=True)
+         refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
+                           maximum=num_files - 1, step=1, visible=False)
+     elif scenegraph_type == "oneref":
+         winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
+                             minimum=1, maximum=max_winsize, step=1, visible=False)
+         refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
+                           maximum=num_files - 1, step=1, visible=True)
+     else:
+         winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
+                             minimum=1, maximum=max_winsize, step=1, visible=False)
+         refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
+                           maximum=num_files - 1, step=1, visible=False)
+     return winsize, refid
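+ # Quick sanity check of the window-size arithmetic above (illustrative numbers): with 7
+ # uploaded images, max_winsize = max(1, ceil((7 - 1) / 2)) = 3, so the "swin" scene graph
+ # exposes window sizes 1..3, while "oneref" instead exposes a reference id in 0..6.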
+
+
+ def get_examples(path):
+     objs = []
+     for obj_name in sorted(os.listdir(path)):
+         img_files = []
+         for img_file in sorted(os.listdir(os.path.join(path, obj_name))):
+             img_files.append(os.path.join(path, obj_name, img_file))
+         objs.append([img_files])
+     print("objs = ", objs)
+     return objs
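+ # get_examples returns one entry per object folder, each wrapped in a list so it matches the
+ # single component in `inputs=[input_image]` for gr.Examples. Illustrative (hypothetical paths):
+ # [[['examples/objA/000.png', 'examples/objA/001.png']], [['examples/objB/000.png', ...]]]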
+
+ def preview_input(inputfiles):
+     if inputfiles is None:
+         return None
+     imgs = []
+     for img_file in inputfiles:
+         img = pl.imread(img_file)
+         imgs.append(img)
+     return imgs
+
+ # def main():
+ # dust3r init
+ silent = False
+ image_size = 224
+ weights_path = 'checkpoints/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth'
+ model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(device)
+ # dust3r will write the 3D model inside tmpdirname
+ # with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname:
+ tmpdirname = os.path.join('logs/user_object')
+ # remove the directory if it already exists
+ if os.path.exists(tmpdirname):
+     shutil.rmtree(tmpdirname)
+ os.makedirs(tmpdirname, exist_ok=True)
+ if not silent:
+     print('Outputting stuff in', tmpdirname)
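+ # Note: the DUSt3R ViT-L/224 checkpoint is loaded once at import time, and every run writes its
+ # reconstruction outputs under logs/user_object, which is wiped and recreated on each restart.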
+
+ _HEADER_ = '''
+ <h2><b>[CVPR'24 Oral] EscherNet: A Generative Model for Scalable View Synthesis</b></h2>
+ <b>EscherNet</b> is a multiview diffusion model for scalable generative any-to-any number/pose novel view synthesis.
+
+ Image views are treated as tokens and the camera pose is encoded by <b>CaPE (Camera Positional Encoding)</b>.
+
+ <a href='https://kxhit.github.io/EscherNet' target='_blank'>Project</a> <b>|</b>
+ <a href='https://github.com/kxhit/EscherNet' target='_blank'>GitHub</a> <b>|</b>
+ <a href='https://arxiv.org/abs/2402.03908' target='_blank'>ArXiv</a>
+
+ <h4><b>Tips:</b></h4>
+
+ - Our model can take <b>any number of input images</b>. The more images you provide <b>(>=3 for this demo)</b>, the better the results.
+
+ - Our model can generate novel views at <b>any number and any pose</b>. You can specify how many views to generate; in this demo the novel views are placed on an <b>archimedean spiral</b> for simplicity.
+
+ - The pose estimation is done with <a href='https://github.com/naver/dust3r' target='_blank'>DUSt3R</a>. You can also provide your own poses, or obtain them from any SLAM system.
+
+ - The current checkpoint supports 6DoF camera poses and is trained on 30k 3D <a href='https://objaverse.allenai.org/' target='_blank'>Objaverse</a> objects for this demo. Scaling up is on the roadmap!
+
+ '''
+
+ _CITE_ = r"""
+ 📝 <b>Citation</b>:
+ ```bibtex
+ @article{kong2024eschernet,
+   title={EscherNet: A Generative Model for Scalable View Synthesis},
+   author={Kong, Xin and Liu, Shikun and Lyu, Xiaoyang and Taher, Marwan and Qi, Xiaojuan and Davison, Andrew J},
+   journal={arXiv preprint arXiv:2402.03908},
+   year={2024}
+ }
+ ```
+ """
+
+ with gr.Blocks() as demo:
+     gr.Markdown(_HEADER_)
+     # mv_images = gr.State()
+     scene = gr.State(None)
+     eschernet_input = gr.State(None)
+     with gr.Row(variant="panel"):
+         # left column
+         with gr.Column():
+             with gr.Row():
+                 input_image = gr.File(file_count="multiple")
+             with gr.Row():
+                 run_dust3r = gr.Button("Get Pose!", elem_id="dust3r")
+             with gr.Row():
+                 processed_image = gr.Gallery(label='Input Views', columns=2, height="100%")
+             with gr.Row(variant="panel"):
+                 # input examples under "examples" folder
+                 gr.Examples(
+                     examples=get_examples('examples'),
+                     inputs=[input_image],
+                     label="Examples (click one set of images to start!)",
+                     examples_per_page=20
+                 )
616
+
617
+
618
+
619
+
620
+
621
+ # right column
622
+ with gr.Column():
623
+
624
+ with gr.Row():
625
+ outmodel = gr.Model3D()
626
+
627
+ with gr.Row():
628
+ gr.Markdown('''
629
+ <h4><b>Check if the pose (blue is axis is estimated z-up direction) and segmentation looks correct. If not, remove the incorrect images and try again.</b></h4>
630
+ ''')
631
+
632
+ with gr.Row():
633
+ with gr.Group():
634
+ do_remove_background = gr.Checkbox(
635
+ label="Remove Background", value=True
636
+ )
637
+ sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
638
+
639
+ sample_steps = gr.Slider(
640
+ label="Sample Steps",
641
+ minimum=30,
642
+ maximum=75,
643
+ value=50,
644
+ step=5,
645
+ visible=False
646
+ )
647
+
648
+ nvs_num = gr.Slider(
649
+ label="Number of Novel Views",
650
+ minimum=5,
651
+ maximum=100,
652
+ value=30,
653
+ step=1
654
+ )
655
+
656
+ nvs_mode = gr.Dropdown(["archimedes circle"], # "fixed 4 views", "fixed 8 views"
657
+ value="archimedes circle", label="Novel Views Pose Chosen", visible=True)
658
+
659
+ with gr.Row():
660
+ gr.Markdown('''
661
+ <h4><b>Choose your desired novel view poses number and generate! The more output images the longer it takes.</b></h4>
662
+ ''')
663
+
664
+ with gr.Row():
665
+ submit = gr.Button("Submit", elem_id="eschernet", variant="primary")
666
+
667
+ with gr.Row():
668
+ with gr.Column():
669
+ output_video = gr.Video(
670
+ label="video", format="mp4",
671
+ width=379,
672
+ autoplay=True,
673
+ interactive=False
674
+ )
675
+
676
+ with gr.Row():
677
+ gr.Markdown('''
678
+ <h4><b>The novel views are generated on an archimedean spiral (rotating around z-up axis and looking at the object center). You can download the video.</b></h4>
679
+ ''')
680
+
681
+ gr.Markdown(_CITE_)
682
+
683
+ # set dust3r parameter invisible to be clean
684
+ with gr.Column():
685
+ with gr.Row():
686
+ schedule = gr.Dropdown(["linear", "cosine"],
687
+ value='linear', label="schedule", info="For global alignment!", visible=False)
688
+ niter = gr.Number(value=300, precision=0, minimum=0, maximum=5000,
689
+ label="num_iterations", info="For global alignment!", visible=False)
690
+ scenegraph_type = gr.Dropdown(["complete", "swin", "oneref"],
691
+ value='complete', label="Scenegraph",
692
+ info="Define how to make pairs",
693
+ interactive=True, visible=False)
694
+ same_focals = gr.Checkbox(value=True, label="Focal", info="Use the same focal for all cameras", visible=False)
695
+ winsize = gr.Slider(label="Scene Graph: Window Size", value=1,
696
+ minimum=1, maximum=1, step=1, visible=False)
697
+ refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
698
+
699
+ with gr.Row():
700
+ # adjust the confidence threshold
701
+ min_conf_thr = gr.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1, visible=False)
702
+ # adjust the camera size in the output pointcloud
703
+ cam_size = gr.Slider(label="cam_size", value=0.05, minimum=0.01, maximum=0.5, step=0.001, visible=False)
704
+ with gr.Row():
705
+ as_pointcloud = gr.Checkbox(value=False, label="As pointcloud", visible=False)
706
+ # two post process implemented
707
+ mask_sky = gr.Checkbox(value=False, label="Mask sky", visible=False)
708
+ clean_depth = gr.Checkbox(value=True, label="Clean-up depthmaps", visible=False)
709
+ transparent_cams = gr.Checkbox(value=False, label="Transparent cameras", visible=False)
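+     # All of these hidden controls are fixed at their defaults and only exist so they can be
+     # passed straight through to get_reconstructed_scene when "Get Pose!" is clicked.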
710
+
711
+ # events
712
+ # scenegraph_type.change(set_scenegraph_options,
713
+ # inputs=[input_image, winsize, refid, scenegraph_type],
714
+ # outputs=[winsize, refid])
715
+ # min_conf_thr.release(fn=model_from_scene_fun,
716
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
717
+ # clean_depth, transparent_cams, cam_size, same_focals],
718
+ # outputs=outmodel)
719
+ # cam_size.change(fn=model_from_scene_fun,
720
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
721
+ # clean_depth, transparent_cams, cam_size, same_focals],
722
+ # outputs=outmodel)
723
+ # as_pointcloud.change(fn=model_from_scene_fun,
724
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
725
+ # clean_depth, transparent_cams, cam_size, same_focals],
726
+ # outputs=outmodel)
727
+ # mask_sky.change(fn=model_from_scene_fun,
728
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
729
+ # clean_depth, transparent_cams, cam_size, same_focals],
730
+ # outputs=outmodel)
731
+ # clean_depth.change(fn=model_from_scene_fun,
732
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
733
+ # clean_depth, transparent_cams, cam_size, same_focals],
734
+ # outputs=outmodel)
735
+ # transparent_cams.change(model_from_scene_fun,
736
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
737
+ # clean_depth, transparent_cams, cam_size, same_focals],
738
+ # outputs=outmodel)
739
+ # run_dust3r.click(fn=recon_fun,
740
+ # inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
741
+ # mask_sky, clean_depth, transparent_cams, cam_size,
742
+ # scenegraph_type, winsize, refid, same_focals],
743
+ # outputs=[scene, outmodel, processed_image, eschernet_input])
744
+
745
+ # events
746
+ input_image.change(set_scenegraph_options,
747
+ inputs=[input_image, winsize, refid, scenegraph_type],
748
+ outputs=[winsize, refid])
749
+ run_dust3r.click(fn=get_reconstructed_scene,
750
+ inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
751
+ mask_sky, clean_depth, transparent_cams, cam_size,
752
+ scenegraph_type, winsize, refid, same_focals],
753
+ outputs=[scene, outmodel, processed_image, eschernet_input])
754
+
755
+
756
+ # events
757
+ input_image.change(fn=preview_input,
758
+ inputs=[input_image],
759
+ outputs=[processed_image])
760
+
761
+ submit.click(fn=run_eschernet,
762
+ inputs=[eschernet_input, sample_steps, sample_seed,
763
+ nvs_num, nvs_mode],
764
+ outputs=[output_video])
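+     # End-to-end flow: uploading images immediately previews them in the gallery and refreshes
+     # the scene-graph sliders; "Get Pose!" runs DUSt3R reconstruction and fills eschernet_input;
+     # "Submit" then calls run_eschernet to synthesise the novel-view video.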
765
+
766
+
767
+
768
+ # demo.queue(max_size=10)
769
+ # demo.launch(share=True, server_name="0.0.0.0", server_port=None)
770
+ demo.queue(max_size=10).launch()
771
+
772
+ # if __name__ == '__main__':
773
+ # main()