Cooked mast3r demo
Browse files 🧑‍💻🧑‍💻
![NasralibyjferoGIF.gif](https://cdn-uploads.huggingface.co/production/uploads/623c636949b6a399ee11152e/yi3P8Ifqgr9W39BtzgNNV.gif)
app.py
CHANGED
@@ -1,7 +1,146 @@
|
|
1 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
|
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spaces
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
import os.path as path
|
5 |
+
import torch
|
6 |
+
import tempfile
|
7 |
+
import gradio
|
8 |
+
import shutil
|
9 |
+
import math
|
10 |
|
11 |
+
HERE_PATH = path.normpath(path.dirname(__file__)) # noqa
|
12 |
+
MASt3R_REPO_PATH = path.normpath(path.join(HERE_PATH, './mast3r')) # noqa
|
13 |
+
sys.path.insert(0, MASt3R_REPO_PATH) # noqa
|
14 |
|
15 |
+
from mast3r.demo import get_reconstructed_scene
|
16 |
+
from mast3r.model import AsymmetricMASt3R
|
17 |
+
from mast3r.utils.misc import hash_md5
|
18 |
+
|
19 |
+
import mast3r.utils.path_to_dust3r # noqa
|
20 |
+
from dust3r.demo import set_print_with_timestamp
|
21 |
+
|
22 |
+
import matplotlib.pyplot as pl
|
23 |
+
pl.ion()
|
24 |
+
|
25 |
+
# Allow TF32 matmuls (for gpu >= Ampere and pytorch >= 1.12).
torch.backends.cuda.matmul.allow_tf32 = True
batch_size = 1
set_print_with_timestamp()

# Checkpoint identifier handed to AsymmetricMASt3R.from_pretrained below.
weights_path = "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model = AsymmetricMASt3R.from_pretrained(weights_path)
model = model.to(device)
chkpt_tag = hash_md5(weights_path)

# Scratch directory passed to get_reconstructed_scene; removed once the
# demo server returns from launch().
tmpdirname = tempfile.mkdtemp(suffix='_mast3r_gradio_demo')

# Values forwarded unchanged to get_reconstructed_scene on every request.
image_size = 512
silent = True

# Used both for gradio.Blocks(delete_cache=...) and as an argument of
# get_reconstructed_scene.
gradio_delete_cache = 7200
|
39 |
+
|
40 |
+
|
41 |
+
class FileState:
    """Owns a single output file and deletes it when the object is finalized.

    Args:
        outfile_name: Path of the file to delete on finalization, or ``None``
            when there is nothing to clean up.
    """

    def __init__(self, outfile_name=None):
        self.outfile_name = outfile_name

    def __del__(self):
        """Remove the owned file, if it still exists (best-effort)."""
        # getattr: __del__ also runs on instances whose __init__ never
        # completed, in which case the attribute does not exist.
        name = getattr(self, 'outfile_name', None)
        if name is None:
            return
        try:
            if os.path.isfile(name):
                os.remove(name)
                self.outfile_name = None
        except OSError:
            # The file may vanish between the check and the removal, or be
            # locked; a finalizer cannot propagate errors anyway, so stay
            # silent instead of emitting "Exception ignored in __del__".
            pass
|
49 |
+
|
50 |
+
|
51 |
+
@spaces.GPU(duration=180)
def local_get_reconstructed_scene(filelist, min_conf_thr, matching_conf_thr,
                                  as_pointcloud, cam_size,
                                  shared_intrinsics, **kw):
    """Run get_reconstructed_scene with this demo's fixed optimisation settings.

    The learning rates, iteration counts, optimisation level, masking flags and
    scene-graph layout are chosen here; the remaining arguments come from the
    UI widgets and are forwarded unchanged.

    Args:
        filelist: Uploaded input files (value of the ``inputfiles`` widget).
        min_conf_thr: Forwarded to get_reconstructed_scene.
        matching_conf_thr: Forwarded to get_reconstructed_scene.
        as_pointcloud: Forwarded to get_reconstructed_scene.
        cam_size: Forwarded to get_reconstructed_scene.
        shared_intrinsics: Forwarded to get_reconstructed_scene.
        **kw: Extra keyword arguments forwarded to get_reconstructed_scene.

    Returns:
        A ``(FileState, outfile)`` pair: the FileState takes over the scene's
        ``outfile_name``; ``outfile`` is the second value returned by
        get_reconstructed_scene.

    Raises:
        gradio.Error: If no input file was provided.
    """
    if not filelist:
        # Without this guard len(filelist) fails with an opaque TypeError
        # when the user hits "Run" before uploading anything.
        raise gradio.Error("Please upload at least one image before hitting the run button.")

    lr1 = 0.07
    niter1 = 500
    lr2 = 0.014
    niter2 = 200
    optim_level = 'refine'
    mask_sky, clean_depth, transparent_cams = False, True, False

    n_images = len(filelist)
    if n_images < 5:
        # Small collections: pair every image with every other one.
        scenegraph_type = 'complete'
        winsize = 1
    else:
        # Larger collections: logarithmic window, capped at 5.
        scenegraph_type = 'logwin'
        half_size = math.ceil((n_images - 1) / 2)
        # log2 is exact on powers of two, so ceil() cannot overshoot.
        max_winsize = max(1, math.ceil(math.log2(half_size)))
        winsize = min(5, max_winsize)
    refid = 0
    win_cyclic = False

    scene_state, outfile = get_reconstructed_scene(tmpdirname, gradio_delete_cache, model, device, silent, image_size, None,
                                                   filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
                                                   as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
                                                   win_cyclic, refid, TSDF_thresh=0, shared_intrinsics=shared_intrinsics, **kw)

    # Hand the output file over to a FileState, then clear the name on the
    # scene state before dropping it (presumably so the scene state does not
    # delete the file itself -- confirm against mast3r.demo).
    filestate = FileState(scene_state.outfile_name)
    scene_state.outfile_name = None
    del scene_state
    return filestate, outfile
|
79 |
+
|
80 |
+
|
81 |
+
def run_example(snapshot, matching_conf_thr, min_conf_thr, cam_size, as_pointcloud, shared_intrinsics, filelist, **kw):
    """Adapter between the gradio.Examples input order and local_get_reconstructed_scene.

    ``snapshot`` is accepted so the signature lines up with the Examples
    inputs; it is not used. All other arguments are reordered and passed on.
    Returns whatever local_get_reconstructed_scene returns.
    """
    scene_args = (filelist, min_conf_thr, matching_conf_thr,
                  as_pointcloud, cam_size, shared_intrinsics)
    return local_get_reconstructed_scene(*scene_args, **kw)
|
83 |
+
|
84 |
+
# --- Gradio UI definition, launch and cleanup ---------------------------------
css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
title = "MASt3R Demo"
with gradio.Blocks(css=css, title=title, delete_cache=(gradio_delete_cache, gradio_delete_cache)) as demo:
    # Holds the FileState returned by the reconstruction callbacks.
    filestate = gradio.State(None)
    gradio.HTML('<h2 style="text-align: center;">3D Reconstruction with MASt3R</h2>')
    gradio.HTML('<p>Upload one or multiple images (wait for them to be fully uploaded before hitting the run button). '
                'We tested with up to 18 images before running into the allocation timeout - set at 3 minutes but your mileage may vary. '
                'At the very bottom of this page, you will find an example. If you click on it, it will pull the 3D reconstruction from 7 images of the small Naver Labs Europe tower from cache. '
                'If you want to try larger image collections, you can find the more complete version of this demo that you can run locally '
                'and more details about the method at <a href="https://github.com/naver/mast3r">github.com/naver/mast3r</a>. '
                'The checkpoint used in this demo is available at <a href="https://huggingface.co/naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric">huggingface.co/naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric</a>.</p>')
    with gradio.Column():
        inputfiles = gradio.File(file_count="multiple")
        # Hidden image: only there so the example row has a thumbnail input.
        snapshot = gradio.Image(None, visible=False)
        with gradio.Row():
            matching_conf_thr = gradio.Slider(label="Matching Confidence Thr", value=2.,
                                              minimum=0., maximum=30., step=0.1,
                                              info="Before Fallback to Regr3D!")
            # adjust the confidence threshold
            min_conf_thr = gradio.Slider(label="min_conf_thr", value=1.5, minimum=0.0, maximum=10, step=0.1)
            # adjust the camera size in the output pointcloud
            cam_size = gradio.Slider(label="cam_size", value=0.2, minimum=0.001, maximum=1.0, step=0.001)
        with gradio.Row():
            as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud")
            shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics",
                                                info="Only optimize one set of intrinsics for all views")
        run_btn = gradio.Button("Run")
        outmodel = gradio.Model3D()

        # Each example row follows the order of `inputs` below, which is the
        # parameter order of run_example.
        examples = gradio.Examples(
            examples=[
                [
                    os.path.join(HERE_PATH, 'mast3r/assets/NLE_tower/FF5599FD-768B-431A-AB83-BDA5FB44CB9D-83120-000041DADDE35483.jpg'),
                    0.0, 1.5, 0.2, True, False,
                    [os.path.join(HERE_PATH, 'mast3r/assets/NLE_tower/01D90321-69C8-439F-B0B0-E87E7634741C-83120-000041DAE419D7AE.jpg'),
                     os.path.join(
                         HERE_PATH, 'mast3r/assets/NLE_tower/1AD85EF5-B651-4291-A5C0-7BDB7D966384-83120-000041DADF639E09.jpg'),
                     os.path.join(
                         HERE_PATH, 'mast3r/assets/NLE_tower/28EDBB63-B9F9-42FB-AC86-4852A33ED71B-83120-000041DAF22407A1.jpg'),
                     os.path.join(
                         HERE_PATH, 'mast3r/assets/NLE_tower/91E9B685-7A7D-42D7-B933-23A800EE4129-83120-000041DAE12C8176.jpg'),
                     os.path.join(
                         HERE_PATH, 'mast3r/assets/NLE_tower/2679C386-1DC0-4443-81B5-93D7EDE4AB37-83120-000041DADB2EA917.jpg'),
                     os.path.join(
                         HERE_PATH, 'mast3r/assets/NLE_tower/CDBBD885-54C3-4EB4-9181-226059A60EE0-83120-000041DAE0C3D612.jpg'),
                     os.path.join(HERE_PATH, 'mast3r/assets/NLE_tower/FF5599FD-768B-431A-AB83-BDA5FB44CB9D-83120-000041DADDE35483.jpg')]
                ]
            ],
            inputs=[snapshot, matching_conf_thr, min_conf_thr, cam_size, as_pointcloud, shared_intrinsics, inputfiles],
            outputs=[filestate, outmodel],
            fn=run_example,
            cache_examples="lazy",
        )

        # events
        run_btn.click(fn=local_get_reconstructed_scene,
                      inputs=[inputfiles, min_conf_thr, matching_conf_thr,
                              as_pointcloud,
                              cam_size, shared_intrinsics],
                      outputs=[filestate, outmodel])

try:
    demo.launch(show_error=True, share=None, server_name=None, server_port=None)
finally:
    # Always remove the scratch directory, even when launch() raises.
    # ignore_errors keeps a cleanup failure from masking the launch error.
    shutil.rmtree(tmpdirname, ignore_errors=True)
|