kxhit committed on
Commit
ce137a5
1 Parent(s): c1a2c97
Files changed (2)
  1. README.md +1 -0
  2. app.py +775 -114
README.md CHANGED
@@ -4,6 +4,7 @@ emoji: 📸📸➡️🖼️🖼️🖼️
  app_file: app.py
  sdk: gradio
  sdk_version: 4.31.0
+ short_description: 3D novel view synthesis from any number of images!
  ---
  [comment]: <> (# EscherNet: A Generative Model for Scalable View Synthesis)
 
app.py CHANGED
@@ -1,125 +1,786 @@
- import gradio as gr
-
  import spaces
  import torch
- from gradio_rerun import Rerun
- import rerun as rr
- import rerun.blueprint as rrb
- from pathlib import Path
- import uuid
-
- from mini_dust3r.api import OptimizedResult, inferece_dust3r, log_optimized_result
- from mini_dust3r.model import AsymmetricCroCo3DStereo
-
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- model = AsymmetricCroCo3DStereo.from_pretrained(
-     "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
- ).to(DEVICE)
-
-
- def create_blueprint(image_name_list: list[str], log_path: Path) -> rrb.Blueprint:
-     # dont show 2d views if there are more than 4 images as to not clutter the view
-     if len(image_name_list) > 4:
-         blueprint = rrb.Blueprint(
-             rrb.Horizontal(
-                 rrb.Spatial3DView(origin=f"{log_path}"),
-             ),
-             collapse_panels=True,
-         )
-     else:
-         blueprint = rrb.Blueprint(
-             rrb.Horizontal(
-                 contents=[
-                     rrb.Spatial3DView(origin=f"{log_path}"),
-                     rrb.Vertical(
-                         contents=[
-                             rrb.Spatial2DView(
-                                 origin=f"{log_path}/camera_{i}/pinhole/",
-                                 contents=[
-                                     "+ $origin/**",
-                                 ],
-                             )
-                             for i in range(len(image_name_list))
-                         ]
-                     ),
-                 ],
-                 column_shares=[3, 1],
-             ),
-             collapse_panels=True,
-         )
-     return blueprint
-
-
- @spaces.GPU
- def predict(image_name_list: list[str] | str):
-     # check if is list or string and if not raise error
-     if not isinstance(image_name_list, list) and not isinstance(image_name_list, str):
-         raise gr.Error(
-             f"Input must be a list of strings or a string, got: {type(image_name_list)}"
-         )
-     uuid_str = str(uuid.uuid4())
-     filename = Path(f"/tmp/gradio/{uuid_str}.rrd")
-     rr.init(f"{uuid_str}")
-     log_path = Path("world")
-
-     if isinstance(image_name_list, str):
-         image_name_list = [image_name_list]
-
-     optimized_results: OptimizedResult = inferece_dust3r(
-         image_dir_or_list=image_name_list,
-         model=model,
-         device=DEVICE,
-         batch_size=1,
-     )
 
-     blueprint: rrb.Blueprint = create_blueprint(image_name_list, log_path)
-     rr.send_blueprint(blueprint)
 
-     rr.set_time_sequence("sequence", 0)
-     log_optimized_result(optimized_results, log_path)
-     rr.save(filename.as_posix())
-     return filename.as_posix()
 
-
- with gr.Blocks(
-     css=""".gradio-container {margin: 0 !important; min-width: 100%};""",
-     title="Mini-DUSt3R Demo",
- ) as demo:
-     # scene state is save so that you can change conf_thr, cam_size... without rerunning the inference
-     gr.HTML('<h2 style="text-align: center;">Mini-DUSt3R Demo</h2>')
-     gr.HTML(
-         '<p style="text-align: center;">Unofficial DUSt3R demo using the mini-dust3r pip package</p>'
-     )
-     gr.HTML(
-         '<p style="text-align: center;">More info <a href="https://github.com/pablovela5620/mini-dust3r">here</a></p>'
      )
-     with gr.Tab(label="Single Image"):
          with gr.Column():
-             single_image = gr.Image(type="filepath", height=300)
-             run_btn_single = gr.Button("Run")
-             rerun_viewer_single = Rerun(height=900)
-             run_btn_single.click(
-                 fn=predict, inputs=[single_image], outputs=[rerun_viewer_single]
-             )
-
-             example_single_dir = Path("examples/single_image")
-             example_single_files = sorted(example_single_dir.glob("*.png"))
-
-             examples_single = gr.Examples(
-                 examples=example_single_files,
-                 inputs=[single_image],
-                 outputs=[rerun_viewer_single],
-                 fn=predict,
-                 cache_examples="lazy",
-             )
-     with gr.Tab(label="Multi Image"):
          with gr.Column():
-             multi_files = gr.File(file_count="multiple")
-             run_btn_multi = gr.Button("Run")
-             rerun_viewer_multi = Rerun(height=900)
-             run_btn_multi.click(
-                 fn=predict, inputs=[multi_files], outputs=[rerun_viewer_multi]
-             )
-
-
- demo.launch()
  import spaces
  import torch
+ print("cuda is available: ", torch.cuda.is_available())
+
+ import gradio as gr
+ import os
+ import shutil
+ import rembg
+ import numpy as np
+ import math
+ import open3d as o3d
+ from PIL import Image
+ import torchvision
+ import trimesh
+ from skimage.io import imsave
+ import imageio
+ import cv2
+ import matplotlib.pyplot as pl
+ pl.ion()
+
+ CaPE_TYPE = "6DoF"
+ device = 'cuda' #if torch.cuda.is_available() else 'cpu'
+ weight_dtype = torch.float16
+ torch.backends.cuda.matmul.allow_tf32 = True  # for gpu >= Ampere and pytorch >= 1.12
+
+ # EscherNet
+ # create angles in archimedean spiral with N steps
+ def get_archimedean_spiral(sphere_radius, num_steps=250):
+     # x-z plane, around upper y
+     '''
+     https://en.wikipedia.org/wiki/Spiral, section "Spherical spiral". c = a / pi
+     '''
+     a = 40
+     r = sphere_radius
+
+     translations = []
+     angles = []
+
+     # i = a / 2
+     i = 0.01
+     while i < a:
+         theta = i / a * math.pi
+         x = r * math.sin(theta) * math.cos(-i)
+         z = r * math.sin(-theta + math.pi) * math.sin(-i)
+         y = r * - math.cos(theta)
+
+         # translations.append((x, y, z)) # origin
+         translations.append((x, z, -y))
+         angles.append([np.rad2deg(-i), np.rad2deg(theta)])
+
+         # i += a / (2 * num_steps)
+         i += a / (1 * num_steps)
+
+     return np.array(translations), np.stack(angles)
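
A quick sanity check of the spiral sampling (an illustrative sketch, not part of the committed file): every returned translation lies on the sphere of radius sphere_radius, and the angles are (azimuth, polar) pairs in degrees.

import numpy as np

# Sketch: check the spiral produced by get_archimedean_spiral above.
xyzs, angles = get_archimedean_spiral(sphere_radius=1.5, num_steps=200)
print(xyzs.shape, angles.shape)                  # (200, 3), (200, 2)

# Every camera position sits on the sphere of radius 1.5.
assert np.allclose(np.linalg.norm(xyzs, axis=-1), 1.5)

# angles[:, 1] is the polar angle, sweeping roughly 0 -> 180 degrees over the full
# spiral; angles[:, 0] is the (negative, unwrapped) azimuth in degrees.
print(angles[0], angles[-1])
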
+
+ def look_at(origin, target, up):
+     forward = (target - origin)
+     forward = forward / np.linalg.norm(forward)
+     right = np.cross(up, forward)
+     right = right / np.linalg.norm(right)
+     new_up = np.cross(forward, right)
+     rotation_matrix = np.column_stack((right, new_up, -forward, target))
+     matrix = np.row_stack((rotation_matrix, [0, 0, 0, 1]))
+     return matrix
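
Note that in the call inside run_eschernet below, `origin` is the point being looked at (the object centre) and `target` is the camera position, so the translation column of the returned matrix is the camera position. A small structural check (sketch only; the camera position is hypothetical):

import numpy as np

# Sketch: verify the structure of the matrix returned by look_at above.
cam_pos = np.array([0.0, 1.5, 0.5])                   # hypothetical camera position
m = look_at(origin=np.zeros(3), target=cam_pos, up=np.array([0.0, 0.0, 1.0]))

R = m[:3, :3]                                         # columns: right, new_up, -forward
assert np.allclose(R.T @ R, np.eye(3), atol=1e-6)     # orthonormal columns
assert np.allclose(m[:3, 3], cam_pos)                 # translation = camera position
assert np.allclose(m[3], [0.0, 0.0, 0.0, 1.0])        # homogeneous bottom row
# run_eschernet later inverts this matrix and flips one row to get the CaPE pose.
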
65
+
66
+ import einops
67
+ import sys
68
+
69
+ sys.path.insert(0, "./6DoF/") # TODO change it when deploying
70
+ # use the customized diffusers modules
71
+ from diffusers import DDIMScheduler
72
+ from dataset import get_pose
73
+ from CN_encoder import CN_encoder
74
+ from pipeline_zero1to3 import Zero1to3StableDiffusionPipeline
75
+ from segment_anything import sam_model_registry, SamPredictor
76
+
77
+ # import rembg
78
+ from carvekit.api.high import HiInterface
79
+
80
+
81
+ pretrained_model_name_or_path = "kxic/EscherNet_demo"
82
+ resolution = 256
83
+ h,w = resolution,resolution
84
+ guidance_scale = 3.0
85
+ radius = 2.2
86
+ bg_color = [1., 1., 1., 1.]
87
+ image_transforms = torchvision.transforms.Compose(
88
+ [
89
+ torchvision.transforms.Resize((resolution, resolution)), # 256, 256
90
+ torchvision.transforms.ToTensor(),
91
+ torchvision.transforms.Normalize([0.5], [0.5])
92
+ ]
93
  )
94
+ xyzs_spiral, angles_spiral = get_archimedean_spiral(1.5, 200)
95
+ # keep only the top half of the spiral
96
+ xyzs_spiral = xyzs_spiral[:100]
97
+ angles_spiral = angles_spiral[:100]
98
+
99
+ # Init pipeline
100
+ scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler", revision=None)
101
+ image_encoder = CN_encoder.from_pretrained(pretrained_model_name_or_path, subfolder="image_encoder", revision=None)
102
+ pipeline = Zero1to3StableDiffusionPipeline.from_pretrained(
103
+ pretrained_model_name_or_path,
104
+ revision=None,
105
+ scheduler=scheduler,
106
+ image_encoder=None,
107
+ safety_checker=None,
108
+ feature_extractor=None,
109
+ torch_dtype=weight_dtype,
110
+ )
111
+ pipeline.image_encoder = image_encoder.to(weight_dtype)
112
+
113
+ pipeline.set_progress_bar_config(disable=False)
114
+
115
+ pipeline = pipeline.to(device)
116
+
117
+ # pipeline.enable_xformers_memory_efficient_attention()
118
+ # enable vae slicing
119
+ pipeline.enable_vae_slicing()
120
+ # pipeline.enable_xformers_memory_efficient_attention()
121
+
122
+
123
+ #### object segmentation
124
+ def sam_init():
125
+ sam_checkpoint = os.path.join("./sam_pt/sam_vit_h_4b8939.pth")
126
+ if os.path.exists(sam_checkpoint) is False:
127
+ os.system("wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P ./sam_pt/")
128
+ model_type = "vit_h"
129
+
130
+ sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(device=device)
131
+ predictor = SamPredictor(sam)
132
+ return predictor
133
+
134
+ def create_carvekit_interface():
135
+ # Check doc strings for more information
136
+ interface = HiInterface(object_type="object", # Can be "object" or "hairs-like".
137
+ batch_size_seg=6,
138
+ batch_size_matting=1,
139
+ device="cpu",
140
+ seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
141
+ matting_mask_size=2048,
142
+ trimap_prob_threshold=231,
143
+ trimap_dilation=30,
144
+ trimap_erosion_iters=5,
145
+ fp16=True)
146
+
147
+ return interface
148
+
149
+
150
+ # rembg_session = rembg.new_session()
151
+ rembg_session = create_carvekit_interface()
152
+ predictor = sam_init()
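
Both the CarveKit interface and the SAM predictor are consumed inside the customized dust3r load_images call further down (with do_remove_background=True). For reference, a typical box-prompted SamPredictor call looks like the sketch below; the image path is hypothetical and this is not the exact code path used by load_images.

import numpy as np
from PIL import Image

# Sketch: box-prompted segmentation with the SAM predictor initialised above.
img = np.array(Image.open("examples/obj/000.png").convert("RGB"))   # hypothetical input view
predictor.set_image(img)
h_img, w_img = img.shape[:2]
masks, scores, _ = predictor.predict(
    box=np.array([0, 0, w_img, h_img]),    # XYXY prompt box (here: the whole image)
    multimask_output=False,
)
object_mask = masks[0]                     # boolean H x W mask of the foreground object
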
153
+
154
+
155
+
156
+ @spaces.GPU(duration=120)
157
+ def run_eschernet(eschernet_input_dict, sample_steps, sample_seed, nvs_num, nvs_mode):
158
+ # set the random seed
159
+ generator = torch.Generator(device=device).manual_seed(sample_seed)
160
+ # generator = None
161
+ T_out = nvs_num
162
+ T_in = len(eschernet_input_dict['imgs'])
163
+ ####### output pose
164
+ # TODO choose T_out number of poses sequentially from the spiral
165
+ xyzs = xyzs_spiral[::(len(xyzs_spiral) // T_out)]
166
+ angles_out = angles_spiral[::(len(xyzs_spiral) // T_out)]
167
+
168
+ ####### input's max radius for translation scaling
169
+ radii = eschernet_input_dict['radii']
170
+ max_t = np.max(radii)
171
+ min_t = np.min(radii)
172
+
173
+ ####### input pose
174
+ pose_in = []
175
+ for T_in_index in range(T_in):
176
+ pose = get_pose(np.linalg.inv(eschernet_input_dict['poses'][T_in_index]))
177
+ pose[1:3, :] *= -1 # coordinate system conversion
178
+ pose[3, 3] *= 1. / max_t * radius # scale radius to [1.5, 2.2]
179
+ pose_in.append(torch.from_numpy(pose))
180
+
181
+ ####### input image
182
+ img = eschernet_input_dict['imgs'] / 255.
183
+ img[img[:, :, :, -1] == 0.] = bg_color
184
+ # TODO batch image_transforms
185
+ input_image = [image_transforms(Image.fromarray(np.uint8(im[:, :, :3] * 255.)).convert("RGB")) for im in img]
186
+
187
+ ####### nvs pose
188
+ pose_out = []
189
+ for T_out_index in range(T_out):
190
+ azimuth, polar = angles_out[T_out_index]
191
+ if CaPE_TYPE == "4DoF":
192
+ pose_out.append(torch.tensor([np.deg2rad(polar), np.deg2rad(azimuth), 0., 0.]))
193
+ elif CaPE_TYPE == "6DoF":
194
+ pose = look_at(origin=np.array([0, 0, 0]), target=xyzs[T_out_index], up=np.array([0, 0, 1]))
195
+ pose = np.linalg.inv(pose)
196
+ pose[2, :] *= -1
197
+ pose_out.append(torch.from_numpy(get_pose(pose)))
198
+
199
+
200
+
201
+ # [B, T, C, H, W]
202
+ input_image = torch.stack(input_image, dim=0).to(device).to(weight_dtype).unsqueeze(0)
203
+ # [B, T, 4]
204
+ pose_in = np.stack(pose_in)
205
+ pose_out = np.stack(pose_out)
206
+
207
+ if CaPE_TYPE == "6DoF":
208
+ pose_in_inv = np.linalg.inv(pose_in).transpose([0, 2, 1])
209
+ pose_out_inv = np.linalg.inv(pose_out).transpose([0, 2, 1])
210
+ pose_in_inv = torch.from_numpy(pose_in_inv).to(device).to(weight_dtype).unsqueeze(0)
211
+ pose_out_inv = torch.from_numpy(pose_out_inv).to(device).to(weight_dtype).unsqueeze(0)
212
+
213
+ pose_in = torch.from_numpy(pose_in).to(device).to(weight_dtype).unsqueeze(0)
214
+ pose_out = torch.from_numpy(pose_out).to(device).to(weight_dtype).unsqueeze(0)
215
+
216
+ input_image = einops.rearrange(input_image, "b t c h w -> (b t) c h w")
217
+ assert T_in == input_image.shape[0]
218
+ assert T_in == pose_in.shape[1]
219
+ assert T_out == pose_out.shape[1]
220
+
221
+ # run inference
222
+ # pipeline.to(device)
223
+ pipeline.enable_xformers_memory_efficient_attention()
224
+ image = pipeline(input_imgs=input_image, prompt_imgs=input_image,
225
+ poses=[[pose_out, pose_out_inv], [pose_in, pose_in_inv]],
226
+ height=h, width=w, T_in=T_in, T_out=T_out,
227
+ guidance_scale=guidance_scale, num_inference_steps=50, generator=generator,
228
+ output_type="numpy").images
229
+
230
+ # save output image
231
+ output_dir = os.path.join(tmpdirname, "eschernet")
232
+ if os.path.exists(output_dir):
233
+ shutil.rmtree(output_dir)
234
+ os.makedirs(output_dir, exist_ok=True)
235
+ # # save to N imgs
236
+ # for i in range(T_out):
237
+ # imsave(os.path.join(output_dir, f'{i}.png'), (image[i] * 255).astype(np.uint8))
238
+ # make a gif
239
+ frames = [Image.fromarray((image[i] * 255).astype(np.uint8)) for i in range(T_out)]
240
+ # frame_one = frames[0]
241
+ # frame_one.save(os.path.join(output_dir, "output.gif"), format="GIF", append_images=frames,
242
+ # save_all=True, duration=50, loop=1)
243
+
244
+ # get a video
245
+ video_path = os.path.join(output_dir, "output.mp4")
246
+ imageio.mimwrite(video_path, np.stack(frames), fps=10, codec='h264')
247
+
248
+
249
+ return video_path
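
run_eschernet only consumes the dictionary assembled by get_reconstructed_scene below ('poses', 'radii', 'imgs'); the same data is also written to tmpdirname as 000.npy / 000_rgba.png / radii.npy. A hedged sketch of rebuilding that dictionary from those files and calling the function directly (the helper name is ours, not part of the app):

import os
import numpy as np
import matplotlib.pyplot as pl

def load_eschernet_input(outdir):
    # Rebuild the dict saved by get_reconstructed_scene (sketch; assumes its file layout).
    poses, imgs = [], []
    i = 0
    while os.path.exists(os.path.join(outdir, f"{i:03d}.npy")):
        poses.append(np.load(os.path.join(outdir, f"{i:03d}.npy")))
        rgba = pl.imread(os.path.join(outdir, f"{i:03d}_rgba.png"))   # float RGBA in [0, 1]
        imgs.append((rgba * 255).astype(np.uint8))
        i += 1
    return {"poses": np.stack(poses),
            "radii": np.load(os.path.join(outdir, "radii.npy")),
            "imgs": np.stack(imgs)}

# video_path = run_eschernet(load_eschernet_input("logs/user_object"),
#                            sample_steps=50, sample_seed=42,
#                            nvs_num=30, nvs_mode="archimedes circle")
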
250
+
251
+ # TODO mesh it
252
+ @spaces.GPU(duration=120)
253
+ def make3d():
254
+ pass
255
+
256
+
257
+
258
+ ############################ Dust3r as Pose Estimation ############################
259
+ from scipy.spatial.transform import Rotation
260
+ import copy
261
+
262
+ from dust3r.inference import inference
263
+ from dust3r.model import AsymmetricCroCo3DStereo
264
+ from dust3r.image_pairs import make_pairs
265
+ from dust3r.utils.image import load_images, rgb
266
+ from dust3r.utils.device import to_numpy
267
+ from dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
268
+ from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
269
+ import math
270
+
271
+ @spaces.GPU(duration=120)
272
+ def _convert_scene_output_to_glb(outdir, imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
273
+ cam_color=None, as_pointcloud=False,
274
+ transparent_cams=False, silent=False, same_focals=False):
275
+ assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world)
276
+ if not same_focals:
277
+ assert (len(cams2world) == len(focals))
278
+ pts3d = to_numpy(pts3d)
279
+ imgs = to_numpy(imgs)
280
+ focals = to_numpy(focals)
281
+ cams2world = to_numpy(cams2world)
282
+
283
+ scene = trimesh.Scene()
284
+
285
+ # add axes
286
+ scene.add_geometry(trimesh.creation.axis(axis_length=0.5, axis_radius=0.001))
287
+
288
+ # full pointcloud
289
+ if as_pointcloud:
290
+ pts = np.concatenate([p[m] for p, m in zip(pts3d, mask)])
291
+ col = np.concatenate([p[m] for p, m in zip(imgs, mask)])
292
+ pct = trimesh.PointCloud(pts.reshape(-1, 3), colors=col.reshape(-1, 3))
293
+ scene.add_geometry(pct)
294
+ else:
295
+ meshes = []
296
+ for i in range(len(imgs)):
297
+ meshes.append(pts3d_to_trimesh(imgs[i], pts3d[i], mask[i]))
298
+ mesh = trimesh.Trimesh(**cat_meshes(meshes))
299
+ scene.add_geometry(mesh)
300
+
301
+ # add each camera
302
+ for i, pose_c2w in enumerate(cams2world):
303
+ if isinstance(cam_color, list):
304
+ camera_edge_color = cam_color[i]
305
+ else:
306
+ camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
307
+ if same_focals:
308
+ focal = focals[0]
309
+ else:
310
+ focal = focals[i]
311
+ add_scene_cam(scene, pose_c2w, camera_edge_color,
312
+ None if transparent_cams else imgs[i], focal,
313
+ imsize=imgs[i].shape[1::-1], screen_width=cam_size)
314
+
315
+ rot = np.eye(4)
316
+ rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
317
+ scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
318
+ outfile = os.path.join(outdir, 'scene.glb')
319
+ if not silent:
320
+ print('(exporting 3D scene to', outfile, ')')
321
+ scene.export(file_obj=outfile)
322
+ return outfile
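
The exported GLB can be re-opened with trimesh to check what was written (sketch; the path assumes the tmpdirname set further down):

import trimesh

# Sketch: inspect the scene exported by _convert_scene_output_to_glb.
scene_glb = trimesh.load("logs/user_object/scene.glb")
print(len(scene_glb.geometry), "geometries (cameras + point cloud or mesh)")
print(scene_glb.bounds)   # axis-aligned bounds of the whole scene
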
323
+
324
+ @spaces.GPU(duration=120)
325
+ def get_3D_model_from_scene(outdir, silent, scene, min_conf_thr=3, as_pointcloud=False, mask_sky=False,
326
+ clean_depth=False, transparent_cams=False, cam_size=0.05, same_focals=False):
327
+ """
328
+ extract 3D_model (glb file) from a reconstructed scene
329
+ """
330
+ if scene is None:
331
+ return None
332
+ # post processes
333
+ if clean_depth:
334
+ scene = scene.clean_pointcloud()
335
+ if mask_sky:
336
+ scene = scene.mask_sky()
337
+
338
+ # get optimized values from scene
339
+ rgbimg = to_numpy(scene.imgs)
340
+ focals = to_numpy(scene.get_focals().cpu())
341
+ # cams2world = to_numpy(scene.get_im_poses().cpu())
342
+ # TODO use the vis_poses
343
+ cams2world = scene.vis_poses
344
+
345
+ # 3D pointcloud from depthmap, poses and intrinsics
346
+ # pts3d = to_numpy(scene.get_pts3d())
347
+ # TODO use the vis_poses
348
+ pts3d = scene.vis_pts3d
349
+ scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
350
+ msk = to_numpy(scene.get_masks())
351
+
352
+ return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
353
+ transparent_cams=transparent_cams, cam_size=cam_size, silent=silent,
354
+ same_focals=same_focals)
355
+
356
+ @spaces.GPU(duration=120)
357
+ def get_reconstructed_scene(filelist, schedule, niter, min_conf_thr,
358
+ as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
359
+ scenegraph_type, winsize, refid, same_focals):
360
+ """
361
+ from a list of images, run dust3r inference, global aligner.
362
+ then run get_3D_model_from_scene
363
+ """
364
+ silent = False
365
+ image_size = 224
366
+ # remove the directory if it already exists
367
+ outdir = tmpdirname
368
+ if os.path.exists(outdir):
369
+ shutil.rmtree(outdir)
370
+ os.makedirs(outdir, exist_ok=True)
371
+ imgs, imgs_rgba = load_images(filelist, size=image_size, verbose=not silent, do_remove_background=True, rembg_session=rembg_session, predictor=predictor)
372
+ if len(imgs) == 1:
373
+ imgs = [imgs[0], copy.deepcopy(imgs[0])]
374
+ imgs[1]['idx'] = 1
375
+ if scenegraph_type == "swin":
376
+ scenegraph_type = scenegraph_type + "-" + str(winsize)
377
+ elif scenegraph_type == "oneref":
378
+ scenegraph_type = scenegraph_type + "-" + str(refid)
379
+
380
+ pairs = make_pairs(imgs, scene_graph=scenegraph_type, prefilter=None, symmetrize=True)
381
+ output = inference(pairs, model, device, batch_size=1, verbose=not silent)
382
+
383
+ mode = GlobalAlignerMode.PointCloudOptimizer if len(imgs) > 2 else GlobalAlignerMode.PairViewer
384
+ scene = global_aligner(output, device=device, mode=mode, verbose=not silent, same_focals=same_focals)
385
+ lr = 0.01
386
+
387
+ if mode == GlobalAlignerMode.PointCloudOptimizer:
388
+ loss = scene.compute_global_alignment(init='mst', niter=niter, schedule=schedule, lr=lr)
389
+
390
+ # outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
391
+ # clean_depth, transparent_cams, cam_size, same_focals=same_focals)
392
+
393
+ # also return rgb, depth and confidence imgs
394
+ # depth is normalized with the max value for all images
395
+ # we apply the jet colormap on the confidence maps
396
+ rgbimg = scene.imgs
397
+ # depths = to_numpy(scene.get_depthmaps())
398
+ # confs = to_numpy([c for c in scene.im_conf])
399
+ # cmap = pl.get_cmap('jet')
400
+ # depths_max = max([d.max() for d in depths])
401
+ # depths = [d / depths_max for d in depths]
402
+ # confs_max = max([d.max() for d in confs])
403
+ # confs = [cmap(d / confs_max) for d in confs]
404
+
405
+ imgs = []
406
+ rgbaimg = []
407
+ for i in range(len(rgbimg)): # when only 1 image, scene.imgs is two
408
+ imgs.append(rgbimg[i])
409
+ # imgs.append(rgb(depths[i]))
410
+ # imgs.append(rgb(confs[i]))
411
+ # imgs.append(imgs_rgba[i])
412
+ if len(imgs_rgba) == 1 and i == 1:
413
+ imgs.append(imgs_rgba[0])
414
+ rgbaimg.append(np.array(imgs_rgba[0]))
415
+ else:
416
+ imgs.append(imgs_rgba[i])
417
+ rgbaimg.append(np.array(imgs_rgba[i]))
418
+
419
+ rgbaimg = np.array(rgbaimg)
420
+
421
+ # for eschernet
422
+ # get optimized values from scene
423
+ rgbimg = to_numpy(scene.imgs)
424
+ # focals = to_numpy(scene.get_focals().cpu())
425
+ cams2world = to_numpy(scene.get_im_poses().cpu())
426
+
427
+ # 3D pointcloud from depthmap, poses and intrinsics
428
+ pts3d = to_numpy(scene.get_pts3d())
429
+ scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
430
+ msk = to_numpy(scene.get_masks())
431
+ obj_mask = rgbaimg[..., 3] > 0
432
+
433
+ # TODO set global coordinate system at the center of the scene, z-axis is up
434
+ pts = np.concatenate([p[m] for p, m in zip(pts3d, msk)]).reshape(-1, 3)
435
+ pts_obj = np.concatenate([p[m&obj_m] for p, m, obj_m in zip(pts3d, msk, obj_mask)]).reshape(-1, 3)
436
+ centroid = np.mean(pts_obj, axis=0) # obj center
437
+ obj2world = np.eye(4)
438
+ obj2world[:3, 3] = -centroid # T_wc
439
+
440
+ # get z_up vector
441
+ # TODO fit a plane and get the normal vector
442
+ pcd = o3d.geometry.PointCloud()
443
+ pcd.points = o3d.utility.Vector3dVector(pts)
444
+ plane_model, inliers = pcd.segment_plane(distance_threshold=0.01, ransac_n=3, num_iterations=1000)
445
+ # get the normalised normal vector dim = 3
446
+ normal = plane_model[:3] / np.linalg.norm(plane_model[:3])
447
+ # the normal direction should be pointing up
448
+ if normal[1] < 0:
449
+ normal = -normal
450
+ # print("normal", normal)
451
+
452
+ # # TODO z-up 180
453
+ # z_up = np.array([[1,0,0,0],
454
+ # [0,-1,0,0],
455
+ # [0,0,-1,0],
456
+ # [0,0,0,1]])
457
+ # obj2world = z_up @ obj2world
458
+
459
+ # # avg the y
460
+ # z_up_avg = cams2world[:,:3,3].sum(0) / np.linalg.norm(cams2world[:,:3,3].sum(0), axis=-1) # average direction in cam coordinate
461
+ # # import pdb; pdb.set_trace()
462
+ # rot_axis = np.cross(np.array([0, 0, 1]), z_up_avg)
463
+ # rot_angle = np.arccos(np.dot(np.array([0, 0, 1]), z_up_avg) / (np.linalg.norm(z_up_avg) + 1e-6))
464
+ # rot = Rotation.from_rotvec(rot_angle * rot_axis)
465
+ # z_up = np.eye(4)
466
+ # z_up[:3, :3] = rot.as_matrix()
467
+
468
+ # get the rotation matrix from normal to z-axis
469
+ z_axis = np.array([0, 0, 1])
470
+ rot_axis = np.cross(normal, z_axis)
+ rot_axis = rot_axis / (np.linalg.norm(rot_axis) + 1e-6)  # unit axis, so the rotvec magnitude equals the angle
+ rot_angle = np.arccos(np.dot(normal, z_axis) / (np.linalg.norm(normal) + 1e-6))
+ rot = Rotation.from_rotvec(rot_angle * rot_axis)
473
+ z_up = np.eye(4)
474
+ z_up[:3, :3] = rot.as_matrix()
475
+ obj2world = z_up @ obj2world
476
+ # flip 180
477
+ flip_rot = np.array([[1, 0, 0, 0],
478
+ [0, -1, 0, 0],
479
+ [0, 0, -1, 0],
480
+ [0, 0, 0, 1]])
481
+ obj2world = flip_rot @ obj2world
482
+
483
+ # get new cams2obj
484
+ cams2obj = []
485
+ for i, cam2world in enumerate(cams2world):
486
+ cams2obj.append(obj2world @ cam2world)
487
+ # TODO transform pts3d to the new coordinate system
488
+ for i, pts in enumerate(pts3d):
489
+ pts3d[i] = (obj2world @ np.concatenate([pts, np.ones_like(pts)[..., :1]], axis=-1).transpose(2, 0, 1).reshape(4,
490
+ -1)) \
491
+ .reshape(4, pts.shape[0], pts.shape[1]).transpose(1, 2, 0)[..., :3]
492
+ cams2world = np.array(cams2obj)
493
+ # TODO rewrite hack
494
+ scene.vis_poses = cams2world.copy()
495
+ scene.vis_pts3d = pts3d.copy()
496
+
497
+ # TODO save cams2world and rgbimg to each file, file name "000.npy", "001.npy", ... and "000.png", "001.png", ...
498
+ for i, (img, img_rgba, pose) in enumerate(zip(rgbimg, rgbaimg, cams2world)):
499
+ np.save(os.path.join(outdir, f"{i:03d}.npy"), pose)
500
+ pl.imsave(os.path.join(outdir, f"{i:03d}.png"), img)
501
+ pl.imsave(os.path.join(outdir, f"{i:03d}_rgba.png"), img_rgba)
502
+ # np.save(os.path.join(outdir, f"{i:03d}_focal.npy"), to_numpy(focal))
503
+ # per-camera distance to the object centre (its min/max is used to rescale translations)
+ radii = np.linalg.norm(np.linalg.inv(cams2world)[..., :3, 3], axis=-1)
505
+ np.save(os.path.join(outdir, "radii.npy"), radii)
506
+
507
+ eschernet_input = {"poses": cams2world,
508
+ "radii": radii,
509
+ "imgs": rgbaimg}
510
+ print("got eschernet input")
511
+ outfile = get_3D_model_from_scene(outdir, silent, scene, min_conf_thr, as_pointcloud, mask_sky,
512
+ clean_depth, transparent_cams, cam_size, same_focals=same_focals)
513
+
514
+ return scene, outfile, imgs, eschernet_input
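
The z-up handling above boils down to: fit a ground plane, rotate its normal onto +z, recentre on the object, then flip 180 degrees. The rotation-from-normal step in isolation (illustrative sketch using the same scipy Rotation API):

import numpy as np
from scipy.spatial.transform import Rotation

def rotation_aligning_normal_to_z(normal):
    # Build the rotation matrix that maps the unit vector `normal` onto the +z axis.
    z_axis = np.array([0.0, 0.0, 1.0])
    axis = np.cross(normal, z_axis)
    axis = axis / (np.linalg.norm(axis) + 1e-6)              # unit rotation axis
    angle = np.arccos(np.clip(np.dot(normal, z_axis), -1.0, 1.0))
    return Rotation.from_rotvec(angle * axis).as_matrix()

n = np.array([0.2, -0.3, 0.9])
n = n / np.linalg.norm(n)
R = rotation_aligning_normal_to_z(n)
print(R @ n)   # approximately [0, 0, 1]
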
515
+
516
+
517
+ def set_scenegraph_options(inputfiles, winsize, refid, scenegraph_type):
518
+ num_files = len(inputfiles) if inputfiles is not None else 1
519
+ max_winsize = max(1, math.ceil((num_files - 1) / 2))
520
+ if scenegraph_type == "swin":
521
+ winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
522
+ minimum=1, maximum=max_winsize, step=1, visible=True)
523
+ refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
524
+ maximum=num_files - 1, step=1, visible=False)
525
+ elif scenegraph_type == "oneref":
526
+ winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
527
+ minimum=1, maximum=max_winsize, step=1, visible=False)
528
+ refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
529
+ maximum=num_files - 1, step=1, visible=True)
530
+ else:
531
+ winsize = gr.Slider(label="Scene Graph: Window Size", value=max_winsize,
532
+ minimum=1, maximum=max_winsize, step=1, visible=False)
533
+ refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0,
534
+ maximum=num_files - 1, step=1, visible=False)
535
+ return winsize, refid
536
+
537
+
538
+ def get_examples(path):
539
+ objs = []
540
+ for obj_name in sorted(os.listdir(path)):
541
+ img_files = []
542
+ for img_file in sorted(os.listdir(os.path.join(path, obj_name))):
543
+ img_files.append(os.path.join(path, obj_name, img_file))
544
+ objs.append([img_files])
545
+ print("objs = ", objs)
546
+ return objs
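
get_examples expects one sub-folder per example object, each holding that object's input views, and returns one entry per object: a single-element list wrapping that object's image paths, which is the format gr.Examples feeds into the multi-file input below. A hypothetical layout:

# Hypothetical folder layout consumed by get_examples("examples"):
#
#   examples/
#       mug/     000.png  001.png  002.png
#       chair/   000.png  001.png  002.png  003.png
#
print(get_examples("examples"))
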
547
+
548
+ def preview_input(inputfiles):
549
+ if inputfiles is None:
550
+ return None
551
+ imgs = []
552
+ for img_file in inputfiles:
553
+ img = pl.imread(img_file)
554
+ imgs.append(img)
555
+ return imgs
556
+
557
+ # def main():
558
+ # dust3r init
559
+ silent = False
560
+ image_size = 224
561
+ weights_path = 'checkpoints/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth'
562
+ model = AsymmetricCroCo3DStereo.from_pretrained(weights_path).to(device)
563
+ # dust3r will write the 3D model inside tmpdirname
564
+ # with tempfile.TemporaryDirectory(suffix='dust3r_gradio_demo') as tmpdirname:
565
+ tmpdirname = os.path.join('logs/user_object')
566
+ # remove the directory if it already exists
567
+ if os.path.exists(tmpdirname):
568
+ shutil.rmtree(tmpdirname)
569
+ os.makedirs(tmpdirname, exist_ok=True)
570
+ if not silent:
571
+ print('Outputting stuff in', tmpdirname)
572
+
573
+ _HEADER_ = '''
+ <h2><b>[CVPR'24 Oral] EscherNet: A Generative Model for Scalable View Synthesis</b></h2>
+ <b>EscherNet</b> is a multiview diffusion model for scalable generative any-to-any number/pose novel view synthesis.
+
+ Image views are treated as tokens and the camera pose is encoded by <b>CaPE (Camera Positional Encoding)</b>.
+
+ <a href='https://kxhit.github.io/EscherNet' target='_blank'>Project</a> <b>|</b>
+ <a href='https://github.com/kxhit/EscherNet' target='_blank'>GitHub</a> <b>|</b>
+ <a href='https://arxiv.org/abs/2402.03908' target='_blank'>ArXiv</a>
+
+ <h4><b>Tips:</b></h4>
+
+ - Our model can take <b>any number of input images</b>. The more images you provide <b>(>=3 for this demo)</b>, the better the results.
+
+ - Our model can generate novel views at <b>any number and any pose</b>. You can specify how many views you want to generate; in this demo the novel views are placed on an <b>Archimedean spiral</b> for simplicity.
+
+ - The pose estimation is done with <a href='https://github.com/naver/dust3r' target='_blank'>DUSt3R</a>. You can also provide your own poses or obtain them from any SLAM system.
+
+ - The current checkpoint supports 6DoF camera poses and is trained on 30k 3D <a href='https://objaverse.allenai.org/' target='_blank'>Objaverse</a> objects for this demo. Scaling up is on the roadmap!
+
+ '''
594
+
595
+ _CITE_ = r"""
596
+ 📝 <b>Citation</b>:
597
+ ```bibtex
598
+ @article{kong2024eschernet,
599
+ title={EscherNet: A Generative Model for Scalable View Synthesis},
600
+ author={Kong, Xin and Liu, Shikun and Lyu, Xiaoyang and Taher, Marwan and Qi, Xiaojuan and Davison, Andrew J},
601
+ journal={arXiv preprint arXiv:2402.03908},
602
+ year={2024}
603
+ }
604
+ ```
605
+ """
606
+
607
+ with gr.Blocks() as demo:
608
+ gr.Markdown(_HEADER_)
609
+ # mv_images = gr.State()
610
+ scene = gr.State(None)
611
+ eschernet_input = gr.State(None)
612
+ with gr.Row(variant="panel"):
613
+ # left column
614
  with gr.Column():
615
+ with gr.Row():
616
+ input_image = gr.File(file_count="multiple")
617
+ with gr.Row():
618
+ run_dust3r = gr.Button("Get Pose!", elem_id="dust3r")
619
+ with gr.Row():
620
+ processed_image = gr.Gallery(label='Input Views', columns=2, height="100%")
621
+ with gr.Row(variant="panel"):
622
+ # input examples under "examples" folder
623
+ gr.Examples(
624
+ examples=get_examples('examples'),
625
+ inputs=[input_image],
626
+ label="Examples (click one set of images to start!)",
627
+ examples_per_page=20
628
+ )
629
+
630
+
631
+
632
+
633
+
634
+ # right column
635
  with gr.Column():
636
 
637
+ with gr.Row():
638
+ outmodel = gr.Model3D()
639
+
640
+ with gr.Row():
641
+ gr.Markdown('''
642
+ <h4><b>Check that the pose (the blue axis is the estimated z-up direction) and the segmentation look correct. If not, remove the incorrect images and try again.</b></h4>
643
+ ''')
644
+
645
+ with gr.Row():
646
+ with gr.Group():
647
+ do_remove_background = gr.Checkbox(
648
+ label="Remove Background", value=True
649
+ )
650
+ sample_seed = gr.Number(value=42, label="Seed Value", precision=0)
651
+
652
+ sample_steps = gr.Slider(
653
+ label="Sample Steps",
654
+ minimum=30,
655
+ maximum=75,
656
+ value=50,
657
+ step=5,
658
+ visible=False
659
+ )
660
+
661
+ nvs_num = gr.Slider(
662
+ label="Number of Novel Views",
663
+ minimum=5,
664
+ maximum=100,
665
+ value=30,
666
+ step=1
667
+ )
668
+
669
+ nvs_mode = gr.Dropdown(["archimedes circle"], # "fixed 4 views", "fixed 8 views"
670
+ value="archimedes circle", label="Novel Views Pose Chosen", visible=True)
671
+
672
+ with gr.Row():
673
+ gr.Markdown('''
674
+ <h4><b>Choose the number of novel view poses you want and generate! The more output images, the longer it takes.</b></h4>
675
+ ''')
676
+
677
+ with gr.Row():
678
+ submit = gr.Button("Submit", elem_id="eschernet", variant="primary")
679
+
680
+ with gr.Row():
681
+ with gr.Column():
682
+ output_video = gr.Video(
683
+ label="video", format="mp4",
684
+ width=379,
685
+ autoplay=True,
686
+ interactive=False
687
+ )
688
+
689
+ with gr.Row():
690
+ gr.Markdown('''
691
+ <h4><b>The novel views are generated on an Archimedean spiral (rotating around the z-up axis and looking at the object center). You can download the video.</b></h4>
692
+ ''')
693
+
694
+ gr.Markdown(_CITE_)
695
+
696
+ # set dust3r parameter invisible to be clean
697
+ with gr.Column():
698
+ with gr.Row():
699
+ schedule = gr.Dropdown(["linear", "cosine"],
700
+ value='linear', label="schedule", info="For global alignment!", visible=False)
701
+ niter = gr.Number(value=300, precision=0, minimum=0, maximum=5000,
702
+ label="num_iterations", info="For global alignment!", visible=False)
703
+ scenegraph_type = gr.Dropdown(["complete", "swin", "oneref"],
704
+ value='complete', label="Scenegraph",
705
+ info="Define how to make pairs",
706
+ interactive=True, visible=False)
707
+ same_focals = gr.Checkbox(value=True, label="Focal", info="Use the same focal for all cameras", visible=False)
708
+ winsize = gr.Slider(label="Scene Graph: Window Size", value=1,
709
+ minimum=1, maximum=1, step=1, visible=False)
710
+ refid = gr.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
711
+
712
+ with gr.Row():
713
+ # adjust the confidence threshold
714
+ min_conf_thr = gr.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1, visible=False)
715
+ # adjust the camera size in the output pointcloud
716
+ cam_size = gr.Slider(label="cam_size", value=0.05, minimum=0.01, maximum=0.5, step=0.001, visible=False)
717
+ with gr.Row():
718
+ as_pointcloud = gr.Checkbox(value=False, label="As pointcloud", visible=False)
719
+ # two post process implemented
720
+ mask_sky = gr.Checkbox(value=False, label="Mask sky", visible=False)
721
+ clean_depth = gr.Checkbox(value=True, label="Clean-up depthmaps", visible=False)
722
+ transparent_cams = gr.Checkbox(value=False, label="Transparent cameras", visible=False)
723
+
724
+ # events
725
+ # scenegraph_type.change(set_scenegraph_options,
726
+ # inputs=[input_image, winsize, refid, scenegraph_type],
727
+ # outputs=[winsize, refid])
728
+ # min_conf_thr.release(fn=model_from_scene_fun,
729
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
730
+ # clean_depth, transparent_cams, cam_size, same_focals],
731
+ # outputs=outmodel)
732
+ # cam_size.change(fn=model_from_scene_fun,
733
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
734
+ # clean_depth, transparent_cams, cam_size, same_focals],
735
+ # outputs=outmodel)
736
+ # as_pointcloud.change(fn=model_from_scene_fun,
737
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
738
+ # clean_depth, transparent_cams, cam_size, same_focals],
739
+ # outputs=outmodel)
740
+ # mask_sky.change(fn=model_from_scene_fun,
741
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
742
+ # clean_depth, transparent_cams, cam_size, same_focals],
743
+ # outputs=outmodel)
744
+ # clean_depth.change(fn=model_from_scene_fun,
745
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
746
+ # clean_depth, transparent_cams, cam_size, same_focals],
747
+ # outputs=outmodel)
748
+ # transparent_cams.change(model_from_scene_fun,
749
+ # inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
750
+ # clean_depth, transparent_cams, cam_size, same_focals],
751
+ # outputs=outmodel)
752
+ # run_dust3r.click(fn=recon_fun,
753
+ # inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
754
+ # mask_sky, clean_depth, transparent_cams, cam_size,
755
+ # scenegraph_type, winsize, refid, same_focals],
756
+ # outputs=[scene, outmodel, processed_image, eschernet_input])
757
+
758
+ # events
759
+ input_image.change(set_scenegraph_options,
760
+ inputs=[input_image, winsize, refid, scenegraph_type],
761
+ outputs=[winsize, refid])
762
+ run_dust3r.click(fn=get_reconstructed_scene,
763
+ inputs=[input_image, schedule, niter, min_conf_thr, as_pointcloud,
764
+ mask_sky, clean_depth, transparent_cams, cam_size,
765
+ scenegraph_type, winsize, refid, same_focals],
766
+ outputs=[scene, outmodel, processed_image, eschernet_input])
767
+
768
+
769
+ # events
770
+ input_image.change(fn=preview_input,
771
+ inputs=[input_image],
772
+ outputs=[processed_image])
773
+
774
+ submit.click(fn=run_eschernet,
775
+ inputs=[eschernet_input, sample_steps, sample_seed,
776
+ nvs_num, nvs_mode],
777
+ outputs=[output_video])
778
+
779
+
780
+
781
+ # demo.queue(max_size=10)
782
+ # demo.launch(share=True, server_name="0.0.0.0", server_port=None)
783
+ demo.queue(max_size=10).launch()
784
 
785
+ # if __name__ == '__main__':
786
+ # main()