Ryukijano commited on
Commit
75adee0
β€’
1 Parent(s): 592dcaa

Cooked mast3r demo

Browse files

πŸ§‘β€πŸ’»πŸ§‘β€πŸ’»
![NasralibyjferoGIF.gif](https://cdn-uploads.huggingface.co/production/uploads/623c636949b6a399ee11152e/yi3P8Ifqgr9W39BtzgNNV.gif)

Files changed (1) hide show
  1. app.py +144 -5
app.py CHANGED
@@ -1,7 +1,146 @@
1
- import gradio as gr
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import os
3
+ import sys
4
+ import os.path as path
5
+ import torch
6
+ import tempfile
7
+ import gradio
8
+ import shutil
9
+ import math
10
 
11
+ HERE_PATH = path.normpath(path.dirname(__file__)) # noqa
12
+ MASt3R_REPO_PATH = path.normpath(path.join(HERE_PATH, './mast3r')) # noqa
13
+ sys.path.insert(0, MASt3R_REPO_PATH) # noqa
14
 
15
+ from mast3r.demo import get_reconstructed_scene
16
+ from mast3r.model import AsymmetricMASt3R
17
+ from mast3r.utils.misc import hash_md5
18
+
19
+ import mast3r.utils.path_to_dust3r # noqa
20
+ from dust3r.demo import set_print_with_timestamp
21
+
22
+ import matplotlib.pyplot as pl
23
+ pl.ion()
24
+
25
+ # for gpu >= Ampere and pytorch >= 1.12
26
+ torch.backends.cuda.matmul.allow_tf32 = True
27
+ batch_size = 1
28
+ set_print_with_timestamp()
29
+
30
+ weights_path = "naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric"
31
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
32
+ model = AsymmetricMASt3R.from_pretrained(weights_path).to(device)
33
+ chkpt_tag = hash_md5(weights_path)
34
+
35
+ tmpdirname = tempfile.mkdtemp(suffix='_mast3r_gradio_demo')
36
+ image_size = 512
37
+ silent = True
38
+ gradio_delete_cache = 7200
39
+
40
+
41
+ class FileState:
42
+ def __init__(self, outfile_name=None):
43
+ self.outfile_name = outfile_name
44
+
45
+ def __del__(self):
46
+ if self.outfile_name is not None and os.path.isfile(self.outfile_name):
47
+ os.remove(self.outfile_name)
48
+ self.outfile_name = None
49
+
50
+
51
+ @spaces.GPU(duration=180)
52
+ def local_get_reconstructed_scene(filelist, min_conf_thr, matching_conf_thr,
53
+ as_pointcloud, cam_size,
54
+ shared_intrinsics, **kw):
55
+ lr1 = 0.07
56
+ niter1 = 500
57
+ lr2 = 0.014
58
+ niter2 = 200
59
+ optim_level = 'refine'
60
+ mask_sky, clean_depth, transparent_cams = False, True, False
61
+ if len(filelist) < 5:
62
+ scenegraph_type = 'complete'
63
+ winsize = 1
64
+ else:
65
+ scenegraph_type = 'logwin'
66
+ half_size = math.ceil((len(filelist) - 1) / 2)
67
+ max_winsize = max(1, math.ceil(math.log(half_size, 2)))
68
+ winsize = min(5, max_winsize)
69
+ refid = 0
70
+ win_cyclic = False
71
+ scene_state, outfile = get_reconstructed_scene(tmpdirname, gradio_delete_cache, model, device, silent, image_size, None,
72
+ filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
73
+ as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
74
+ win_cyclic, refid, TSDF_thresh=0, shared_intrinsics=shared_intrinsics, **kw)
75
+ filestate = FileState(scene_state.outfile_name)
76
+ scene_state.outfile_name = None
77
+ del scene_state
78
+ return filestate, outfile
79
+
80
+
81
+ def run_example(snapshot, matching_conf_thr, min_conf_thr, cam_size, as_pointcloud, shared_intrinsics, filelist, **kw):
82
+ return local_get_reconstructed_scene(filelist, min_conf_thr, matching_conf_thr, as_pointcloud, cam_size, shared_intrinsics, **kw)
83
+
84
+ css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
85
+ title = "MASt3R Demo"
86
+ with gradio.Blocks(css=css, title=title, delete_cache=(gradio_delete_cache, gradio_delete_cache)) as demo:
87
+ filestate = gradio.State(None)
88
+ gradio.HTML('<h2 style="text-align: center;">3D Reconstruction with MASt3R</h2>')
89
+ gradio.HTML('<p>Upload one or multiple images (wait for them to be fully uploaded before hitting the run button). '
90
+ 'We tested with up to 18 images before running into the allocation timeout - set at 3 minutes but your mileage may vary. '
91
+ 'At the very bottom of this page, you will find an example. If you click on it, it will pull the 3D reconstruction from 7 images of the small Naver Labs Europe tower from cache. '
92
+ 'If you want to try larger image collections, you can find the more complete version of this demo that you can run locally '
93
+ 'and more details about the method at <a href="https://github.com/naver/mast3r">github.com/naver/mast3r</a>. '
94
+ 'The checkpoint used in this demo is available at <a href="https://huggingface.co/naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric">huggingface.co/naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric</a>.</p>')
95
+ with gradio.Column():
96
+ inputfiles = gradio.File(file_count="multiple")
97
+ snapshot = gradio.Image(None, visible=False)
98
+ with gradio.Row():
99
+ matching_conf_thr = gradio.Slider(label="Matching Confidence Thr", value=2.,
100
+ minimum=0., maximum=30., step=0.1,
101
+ info="Before Fallback to Regr3D!")
102
+ # adjust the confidence threshold
103
+ min_conf_thr = gradio.Slider(label="min_conf_thr", value=1.5, minimum=0.0, maximum=10, step=0.1)
104
+ # adjust the camera size in the output pointcloud
105
+ cam_size = gradio.Slider(label="cam_size", value=0.2, minimum=0.001, maximum=1.0, step=0.001)
106
+ with gradio.Row():
107
+ as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud")
108
+ shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics",
109
+ info="Only optimize one set of intrinsics for all views")
110
+ run_btn = gradio.Button("Run")
111
+ outmodel = gradio.Model3D()
112
+
113
+ examples = gradio.Examples(
114
+ examples=[
115
+ [
116
+ os.path.join(HERE_PATH, 'mast3r/assets/NLE_tower/FF5599FD-768B-431A-AB83-BDA5FB44CB9D-83120-000041DADDE35483.jpg'),
117
+ 0.0, 1.5, 0.2, True, False,
118
+ [os.path.join(HERE_PATH, 'mast3r/assets/NLE_tower/01D90321-69C8-439F-B0B0-E87E7634741C-83120-000041DAE419D7AE.jpg'),
119
+ os.path.join(
120
+ HERE_PATH, 'mast3r/assets/NLE_tower/1AD85EF5-B651-4291-A5C0-7BDB7D966384-83120-000041DADF639E09.jpg'),
121
+ os.path.join(
122
+ HERE_PATH, 'mast3r/assets/NLE_tower/28EDBB63-B9F9-42FB-AC86-4852A33ED71B-83120-000041DAF22407A1.jpg'),
123
+ os.path.join(
124
+ HERE_PATH, 'mast3r/assets/NLE_tower/91E9B685-7A7D-42D7-B933-23A800EE4129-83120-000041DAE12C8176.jpg'),
125
+ os.path.join(
126
+ HERE_PATH, 'mast3r/assets/NLE_tower/2679C386-1DC0-4443-81B5-93D7EDE4AB37-83120-000041DADB2EA917.jpg'),
127
+ os.path.join(
128
+ HERE_PATH, 'mast3r/assets/NLE_tower/CDBBD885-54C3-4EB4-9181-226059A60EE0-83120-000041DAE0C3D612.jpg'),
129
+ os.path.join(HERE_PATH, 'mast3r/assets/NLE_tower/FF5599FD-768B-431A-AB83-BDA5FB44CB9D-83120-000041DADDE35483.jpg')]
130
+ ]
131
+ ],
132
+ inputs=[snapshot, matching_conf_thr, min_conf_thr, cam_size, as_pointcloud, shared_intrinsics, inputfiles],
133
+ outputs=[filestate, outmodel],
134
+ fn=run_example,
135
+ cache_examples="lazy",
136
+ )
137
+
138
+ # events
139
+ run_btn.click(fn=local_get_reconstructed_scene,
140
+ inputs=[inputfiles, min_conf_thr, matching_conf_thr,
141
+ as_pointcloud,
142
+ cam_size, shared_intrinsics],
143
+ outputs=[filestate, outmodel])
144
+
145
+ demo.launch(show_error=True, share=None, server_name=None, server_port=None)
146
+ shutil.rmtree(tmpdirname)