yocabon commited on
Commit
7494687
1 Parent(s): e83109b

try without partial

Browse files
Files changed (1) hide show
  1. app.py +137 -11
app.py CHANGED
@@ -10,16 +10,13 @@ import sys
10
  import os.path as path
11
  import torch
12
  import tempfile
 
13
 
14
  HERE_PATH = path.normpath(path.dirname(__file__)) # noqa
15
  MASt3R_REPO_PATH = path.normpath(path.join(HERE_PATH, './mast3r')) # noqa
16
  sys.path.insert(0, MASt3R_REPO_PATH) # noqa
17
 
18
- import mast3r.demo
19
- mast3r.demo.get_reconstructed_scene = spaces.GPU(mast3r.demo.get_reconstructed_scene)
20
- mast3r.demo.get_3D_model_from_scene = spaces.GPU(mast3r.demo.get_3D_model_from_scene)
21
-
22
- from mast3r.demo import main_demo
23
  from mast3r.model import AsymmetricMASt3R
24
  from mast3r.utils.misc import hash_md5
25
 
@@ -35,9 +32,138 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
35
  model = AsymmetricMASt3R.from_pretrained(weights_path).to(device)
36
  chkpt_tag = hash_md5(weights_path)
37
 
38
- # mast3r will write the 3D model inside tmpdirname/chkpt_tag
39
- with tempfile.TemporaryDirectory(suffix='_mast3r_gradio_demo') as tmpdirname:
40
- cache_path = os.path.join(tmpdirname, chkpt_tag)
41
- os.makedirs(cache_path, exist_ok=True)
42
- main_demo(tmpdirname, model, device, 512, server_name=None, server_port=None,
43
- silent=True, share=None, gradio_delete_cache=7200)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  import os.path as path
11
  import torch
12
  import tempfile
13
+ import gradio
14
 
15
  HERE_PATH = path.normpath(path.dirname(__file__)) # noqa
16
  MASt3R_REPO_PATH = path.normpath(path.join(HERE_PATH, './mast3r')) # noqa
17
  sys.path.insert(0, MASt3R_REPO_PATH) # noqa
18
 
19
+ from mast3r.demo import get_reconstructed_scene, get_3D_model_from_scene, set_scenegraph_options
 
 
 
 
20
  from mast3r.model import AsymmetricMASt3R
21
  from mast3r.utils.misc import hash_md5
22
 
 
32
  model = AsymmetricMASt3R.from_pretrained(weights_path).to(device)
33
  chkpt_tag = hash_md5(weights_path)
34
 
35
+ tmpdirname = "tmp/gradio"
36
+ image_size = 512
37
+ silent = True
38
+ gradio_delete_cache = 7200
39
+
40
+
41
+ @spaces.GPU()
42
+ def local_get_reconstructed_scene(current_scene_state,
43
+ filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
44
+ as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
45
+ win_cyclic, refid, TSDF_thresh, shared_intrinsics, **kw):
46
+ return get_reconstructed_scene(tmpdirname, gradio_delete_cache, model, device, silent, image_size, current_scene_state,
47
+ filelist, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
48
+ as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size, scenegraph_type, winsize,
49
+ win_cyclic, refid, TSDF_thresh, shared_intrinsics, **kw)
50
+
51
+
52
+ @spaces.GPU()
53
+ def local_get_3D_model_from_scene(scene_state, min_conf_thr=2, as_pointcloud=False, mask_sky=False,
54
+ clean_depth=False, transparent_cams=False, cam_size=0.05, TSDF_thresh=0):
55
+ return get_3D_model_from_scene(silent, scene_state, min_conf_thr, as_pointcloud, mask_sky,
56
+ clean_depth, transparent_cams, cam_size, TSDF_thresh)
57
+
58
+
59
+ recon_fun = local_get_reconstructed_scene
60
+ model_from_scene_fun = local_get_3D_model_from_scene
61
+
62
+
63
+ def get_context(delete_cache):
64
+ css = """.gradio-container {margin: 0 !important; min-width: 100%};"""
65
+ title = "MASt3R Demo"
66
+ if delete_cache:
67
+ return gradio.Blocks(css=css, title=title, delete_cache=(delete_cache, delete_cache))
68
+ else:
69
+ return gradio.Blocks(css=css, title="MASt3R Demo") # for compatibility with older versions
70
+
71
+
72
+ with get_context(gradio_delete_cache) as demo:
73
+ # scene state is save so that you can change conf_thr, cam_size... without rerunning the inference
74
+ scene = gradio.State(None)
75
+ gradio.HTML('<h2 style="text-align: center;">MASt3R Demo</h2>')
76
+ with gradio.Column():
77
+ inputfiles = gradio.File(file_count="multiple")
78
+ with gradio.Row():
79
+ with gradio.Column():
80
+ with gradio.Row():
81
+ lr1 = gradio.Slider(label="Coarse LR", value=0.07, minimum=0.01, maximum=0.2, step=0.01)
82
+ niter1 = gradio.Number(value=500, precision=0, minimum=0, maximum=10_000,
83
+ label="num_iterations", info="For coarse alignment!")
84
+ lr2 = gradio.Slider(label="Fine LR", value=0.014, minimum=0.005, maximum=0.05, step=0.001)
85
+ niter2 = gradio.Number(value=200, precision=0, minimum=0, maximum=100_000,
86
+ label="num_iterations", info="For refinement!")
87
+ optim_level = gradio.Dropdown(["coarse", "refine", "refine+depth"],
88
+ value='refine', label="OptLevel",
89
+ info="Optimization level")
90
+ with gradio.Row():
91
+ matching_conf_thr = gradio.Slider(label="Matching Confidence Thr", value=5.,
92
+ minimum=0., maximum=30., step=0.1,
93
+ info="Before Fallback to Regr3D!")
94
+ shared_intrinsics = gradio.Checkbox(value=False, label="Shared intrinsics",
95
+ info="Only optimize one set of intrinsics for all views")
96
+ scenegraph_type = gradio.Dropdown([("complete: all possible image pairs", "complete"),
97
+ ("swin: sliding window", "swin"),
98
+ ("logwin: sliding window with long range", "logwin"),
99
+ ("oneref: match one image with all", "oneref")],
100
+ value='complete', label="Scenegraph",
101
+ info="Define how to make pairs",
102
+ interactive=True)
103
+ with gradio.Column(visible=False) as win_col:
104
+ winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
105
+ minimum=1, maximum=1, step=1)
106
+ win_cyclic = gradio.Checkbox(value=False, label="Cyclic sequence")
107
+ refid = gradio.Slider(label="Scene Graph: Id", value=0,
108
+ minimum=0, maximum=0, step=1, visible=False)
109
+ run_btn = gradio.Button("Run")
110
+
111
+ with gradio.Row():
112
+ # adjust the confidence threshold
113
+ min_conf_thr = gradio.Slider(label="min_conf_thr", value=1.5, minimum=0.0, maximum=10, step=0.1)
114
+ # adjust the camera size in the output pointcloud
115
+ cam_size = gradio.Slider(label="cam_size", value=0.2, minimum=0.001, maximum=1.0, step=0.001)
116
+ TSDF_thresh = gradio.Slider(label="TSDF Threshold", value=0., minimum=0., maximum=1., step=0.01)
117
+ with gradio.Row():
118
+ as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud")
119
+ # two post process implemented
120
+ mask_sky = gradio.Checkbox(value=False, label="Mask sky")
121
+ clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps")
122
+ transparent_cams = gradio.Checkbox(value=False, label="Transparent cameras")
123
+
124
+ outmodel = gradio.Model3D()
125
+
126
+ # events
127
+ scenegraph_type.change(set_scenegraph_options,
128
+ inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
129
+ outputs=[win_col, winsize, win_cyclic, refid])
130
+ inputfiles.change(set_scenegraph_options,
131
+ inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
132
+ outputs=[win_col, winsize, win_cyclic, refid])
133
+ win_cyclic.change(set_scenegraph_options,
134
+ inputs=[inputfiles, win_cyclic, refid, scenegraph_type],
135
+ outputs=[win_col, winsize, win_cyclic, refid])
136
+ run_btn.click(fn=recon_fun,
137
+ inputs=[scene, inputfiles, optim_level, lr1, niter1, lr2, niter2, min_conf_thr, matching_conf_thr,
138
+ as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
139
+ scenegraph_type, winsize, win_cyclic, refid, TSDF_thresh, shared_intrinsics],
140
+ outputs=[scene, outmodel])
141
+ min_conf_thr.release(fn=model_from_scene_fun,
142
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
143
+ clean_depth, transparent_cams, cam_size, TSDF_thresh],
144
+ outputs=outmodel)
145
+ cam_size.change(fn=model_from_scene_fun,
146
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
147
+ clean_depth, transparent_cams, cam_size, TSDF_thresh],
148
+ outputs=outmodel)
149
+ TSDF_thresh.change(fn=model_from_scene_fun,
150
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
151
+ clean_depth, transparent_cams, cam_size, TSDF_thresh],
152
+ outputs=outmodel)
153
+ as_pointcloud.change(fn=model_from_scene_fun,
154
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
155
+ clean_depth, transparent_cams, cam_size, TSDF_thresh],
156
+ outputs=outmodel)
157
+ mask_sky.change(fn=model_from_scene_fun,
158
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
159
+ clean_depth, transparent_cams, cam_size, TSDF_thresh],
160
+ outputs=outmodel)
161
+ clean_depth.change(fn=model_from_scene_fun,
162
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
163
+ clean_depth, transparent_cams, cam_size, TSDF_thresh],
164
+ outputs=outmodel)
165
+ transparent_cams.change(model_from_scene_fun,
166
+ inputs=[scene, min_conf_thr, as_pointcloud, mask_sky,
167
+ clean_depth, transparent_cams, cam_size, TSDF_thresh],
168
+ outputs=outmodel)
169
+ demo.launch(share=None, server_name=None, server_port=None)