lemonaddie commited on
Commit
7c4a89c
1 Parent(s): 74f851f
Files changed (1) hide show
  1. app1.py +432 -0
app1.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import functools
3
+ import os
4
+ import shutil
5
+ import sys
6
+
7
+ import git
8
+ import gradio as gr
9
+ import numpy as np
10
+ import torch as torch
11
+ from PIL import Image
12
+
13
+ from gradio_imageslider import ImageSlider
14
+
15
+
16
def process(
    pipe,
    path_input,
    ensemble_size,
    denoise_steps,
    processing_res,
    path_out_16bit=None,
    path_out_fp32=None,
    path_out_vis=None,
    _input_3d_plane_near=None,
    _input_3d_plane_far=None,
    _input_3d_embossing=None,
    _input_3d_filter_size=None,
    _input_3d_frame_near=None,
):
    """Estimate depth for one image and save the results next to the input.

    If ``path_out_vis`` is already provided (e.g. a pre-computed example), no
    inference is run and the supplied paths are returned unchanged.

    Otherwise the image at ``path_input`` is fed through ``pipe`` and three
    files are written into ``<input path without extension>_output/``:

    * ``<name>_depth_fp32.npy``    -- raw prediction saved with ``np.save``
    * ``<name>_depth_16bit.png``   -- prediction scaled by 65535 as ``uint16``
    * ``<name>_depth_colored.png`` -- the colorized image produced by ``pipe``

    Args:
        pipe: Callable pipeline; its result must expose ``depth_np`` and
            ``depth_colored``.
        path_input: Filesystem path of the input image.
        ensemble_size: Forwarded to ``pipe`` as ``ensemble_size``.
        denoise_steps: Forwarded to ``pipe`` as ``denoising_steps``.
        processing_res: Forwarded to ``pipe`` as ``processing_res``.
        path_out_16bit: Optional pre-computed 16-bit PNG path.
        path_out_fp32: Optional pre-computed fp32 ``.npy`` path.
        path_out_vis: Optional pre-computed colored PNG path; when not
            ``None`` it triggers the early return described above.
        _input_3d_plane_near, _input_3d_plane_far, _input_3d_embossing,
        _input_3d_filter_size, _input_3d_frame_near: Accepted so the function
            can be called with the full UI input list; unused here.

    Returns:
        A tuple ``(slider_paths, file_paths)`` where ``slider_paths`` is
        ``[path_out_16bit, path_out_vis]`` and ``file_paths`` is
        ``[path_out_16bit, path_out_fp32, path_out_vis]``.
    """
    # Pre-computed outputs were supplied: skip inference entirely.
    if path_out_vis is not None:
        return (
            [path_out_16bit, path_out_vis],
            [path_out_16bit, path_out_fp32, path_out_vis],
        )

    # Use a context manager so the underlying image file handle is released
    # as soon as the pipeline has consumed the image.
    with Image.open(path_input) as input_image:
        pipe_out = pipe(
            input_image,
            ensemble_size=ensemble_size,
            denoising_steps=denoise_steps,
            processing_res=processing_res,
            # Native resolution (0) runs one image at a time; otherwise 0 is
            # passed through -- presumably "let the pipeline choose", TODO confirm.
            batch_size=1 if processing_res == 0 else 0,
            show_progress_bar=True,
        )

    depth_pred = pipe_out.depth_np
    depth_colored = pipe_out.depth_colored
    # assumes depth_pred is normalized to [0, 1] -- TODO confirm against the pipeline
    depth_16bit = (depth_pred * 65535.0).astype(np.uint16)

    path_output_dir = os.path.splitext(path_input)[0] + "_output"
    os.makedirs(path_output_dir, exist_ok=True)

    name_base = os.path.splitext(os.path.basename(path_input))[0]
    path_out_fp32 = os.path.join(path_output_dir, f"{name_base}_depth_fp32.npy")
    path_out_16bit = os.path.join(path_output_dir, f"{name_base}_depth_16bit.png")
    path_out_vis = os.path.join(path_output_dir, f"{name_base}_depth_colored.png")

    np.save(path_out_fp32, depth_pred)
    Image.fromarray(depth_16bit).save(path_out_16bit, mode="I;16")
    depth_colored.save(path_out_vis)

    return (
        [path_out_16bit, path_out_vis],
        [path_out_16bit, path_out_fp32, path_out_vis],
    )
68
+
69
+
70
+
71
def run_demo_server(pipe):
    """Build the Gradio demo UI around ``pipe`` and start serving it.

    The UI has two parts: a depth-estimation row (input image, advanced
    options, result slider and downloadable files) and a "3D printing" row
    that is declared with ``render=False`` and rendered after the examples.

    Args:
        pipe: Pipeline object bound as the first argument of ``process``.

    Side effects:
        Sets the ``GRADIO_ALLOW_FLAGGING`` environment variable to ``"never"``
        and launches the app on ``0.0.0.0:7860``.
    """
    process_pipe = functools.partial(process, pipe)
    os.environ["GRADIO_ALLOW_FLAGGING"] = "never"

    with gr.Blocks(
        analytics_enabled=False,
        title="Marigold Depth Estimation",
        css="""
            #download {
                height: 118px;
            }
            .slider .inner {
                width: 5px;
                background: #FFF;
            }
            .viewport {
                aspect-ratio: 4/3;
            }
        """,
    ) as demo:
        gr.Markdown(
            """
            <h1 align="center">Marigold Depth Estimation</h1>
            <p align="center">
            <a title="Website" href="https://marigoldmonodepth.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
            </a>
            <a title="arXiv" href="https://arxiv.org/abs/2312.02145" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
            </a>
            <a title="Github" href="https://github.com/prs-eth/marigold" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://img.shields.io/github/stars/prs-eth/marigold?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
            </a>
            <a title="Social" href="https://twitter.com/antonobukhov1" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
                <img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
            </a>
            </p>
            <p align="justify">
                Marigold is the new state-of-the-art depth estimator for images in the wild.
                Upload your image into the <b>left</b> side, or click any of the <b>examples</b> below.
                The result will be computed and appear on the <b>right</b> in the output comparison window.
                <b style="color: red;">NEW</b>: Scroll down to the new 3D printing part of the demo!
            </p>
        """
        )

        # ---- Depth estimation: inputs on the left, results on the right ----
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image",
                    type="filepath",
                )
                with gr.Accordion("Advanced options", open=False):
                    ensemble_size = gr.Slider(
                        label="Ensemble size",
                        minimum=1,
                        maximum=20,
                        step=1,
                        value=10,
                    )
                    denoise_steps = gr.Slider(
                        label="Number of denoising steps",
                        minimum=1,
                        maximum=20,
                        step=1,
                        value=10,
                    )
                    processing_res = gr.Radio(
                        [
                            ("Native", 0),
                            ("Recommended", 768),
                        ],
                        label="Processing resolution",
                        value=768,
                    )
                # Hidden file slots: they carry pre-computed output paths
                # (filled by the examples) into ``process``.
                input_output_16bit = gr.File(
                    label="Predicted depth (16-bit)",
                    visible=False,
                )
                input_output_fp32 = gr.File(
                    label="Predicted depth (32-bit)",
                    visible=False,
                )
                input_output_vis = gr.File(
                    label="Predicted depth (red-near, blue-far)",
                    visible=False,
                )
                with gr.Row():
                    submit_btn = gr.Button(value="Compute Depth", variant="primary")
                    clear_btn = gr.Button(value="Clear")
            with gr.Column():
                output_slider = ImageSlider(
                    label="Predicted depth (red-near, blue-far)",
                    type="filepath",
                    show_download_button=True,
                    show_share_button=True,
                    interactive=False,
                    elem_classes="slider",
                    position=0.25,
                )
                files = gr.Files(
                    label="Depth outputs",
                    elem_id="download",
                    interactive=False,
                )

        # ---- 3D printing section: declared now, rendered after the examples ----
        demo_3d_header = gr.Markdown(
            """
            <h3 align="center">3D Printing Depth Maps</h3>
            <p align="justify">
                This part of the demo uses Marigold depth maps estimated in the previous step to create a
                3D-printable model. The models are watertight, with correct normals, and exported in the STL format.
                We recommended creating the first model with the default parameters and iterating on it until the best
                result (see Pro Tips below).
            </p>
            """,
            render=False,
        )

        demo_3d = gr.Row(render=False)
        with demo_3d:
            with gr.Column():
                with gr.Accordion("3D printing demo: Main options", open=True):
                    plane_near = gr.Slider(
                        label="Relative position of the near plane (between 0 and 1)",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.001,
                        value=0.0,
                    )
                    plane_far = gr.Slider(
                        label="Relative position of the far plane (between near and 1)",
                        minimum=0.0,
                        maximum=1.0,
                        step=0.001,
                        value=1.0,
                    )
                    embossing = gr.Slider(
                        label="Embossing level",
                        minimum=0,
                        maximum=100,
                        step=1,
                        value=20,
                    )
                with gr.Accordion("3D printing demo: Advanced options", open=False):
                    size_longest_px = gr.Slider(
                        label="Size (px) of the longest side",
                        minimum=256,
                        maximum=1024,
                        step=256,
                        value=512,
                    )
                    size_longest_cm = gr.Slider(
                        label="Size (cm) of the longest side",
                        minimum=1,
                        maximum=100,
                        step=1,
                        value=10,
                    )
                    filter_size = gr.Slider(
                        label="Size (px) of the smoothing filter",
                        minimum=1,
                        maximum=5,
                        step=2,
                        value=3,
                    )
                    frame_thickness = gr.Slider(
                        label="Frame thickness",
                        minimum=0,
                        maximum=100,
                        step=1,
                        value=5,
                    )
                    frame_near = gr.Slider(
                        label="Frame's near plane offset",
                        minimum=-100,
                        maximum=100,
                        step=1,
                        value=1,
                    )
                    frame_far = gr.Slider(
                        label="Frame's far plane offset",
                        minimum=1,
                        maximum=10,
                        step=1,
                        value=1,
                    )
                # NOTE(review): neither 3D button has a click handler attached
                # anywhere in this function.
                with gr.Row():
                    submit_3d = gr.Button(value="Create 3D", variant="primary")
                    clear_3d = gr.Button(value="Clear 3D")
                gr.Markdown(
                    """
                    <h5 align="center">Pro Tips</h5>
                    <ol>
                        <li><b>Re-render with new parameters</b>: Click "Clear 3D" and then "Create 3D".</li>
                        <li><b>Adjust 3D scale and cut-off focus</b>: Set the frame's near plane offset to the
                            minimum and use 3D preview to evaluate depth scaling. Repeat until the scale is correct and
                            everything important is in the focus. Set the optimal value for frame's near
                            plane offset as a last step.</li>
                        <li><b>Increase details</b>: Decrease size of the smoothing filter (also increases noise).</li>
                    </ol>
                    """
                )

            with gr.Column():
                viewer_3d = gr.Model3D(
                    camera_position=(75.0, 90.0, 1.25),
                    elem_classes="viewport",
                    label="3D preview (low-res, relief highlight)",
                    interactive=False,
                )
                files_3d = gr.Files(
                    label="3D model outputs (high-res)",
                    elem_id="download",
                    interactive=False,
                )

        blocks_settings_depth = [ensemble_size, denoise_steps, processing_res]
        blocks_settings_3d = [
            plane_near,
            plane_far,
            embossing,
            size_longest_px,
            size_longest_cm,
            filter_size,
            frame_thickness,
            frame_near,
            frame_far,
        ]
        blocks_settings = blocks_settings_depth + blocks_settings_3d
        # Snapshot the initial value of every setting, in the same order as
        # ``blocks_settings``, so "Clear" can restore them without relying on
        # the components' private ``_id`` attribute.
        default_values = [block.value for block in blocks_settings]

        # Order must match the positional parameters of ``process`` (after ``pipe``).
        inputs = [
            input_image,
            ensemble_size,
            denoise_steps,
            processing_res,
            input_output_16bit,
            input_output_fp32,
            input_output_vis,
            plane_near,
            plane_far,
            embossing,
            filter_size,
            frame_near,
        ]
        outputs = [
            submit_btn,
            input_image,
            output_slider,
            files,
        ]

        def submit_depth_fn(*args):
            """Run ``process`` and lock the submit button and input image."""
            out = list(process_pipe(*args))
            out = [gr.Button(interactive=False), gr.Image(interactive=False)] + out
            return out

        submit_btn.click(
            fn=submit_depth_fn,
            inputs=inputs,
            outputs=outputs,
            concurrency_limit=1,
        )

        gr.Examples(
            fn=submit_depth_fn,
            examples=[
                [
                    "files/bee.jpg",
                    10,  # ensemble_size
                    10,  # denoise_steps
                    768,  # processing_res
                    "files/bee_depth_16bit.png",
                    "files/bee_depth_fp32.npy",
                    "files/bee_depth_colored.png",
                    0.0,  # plane_near
                    0.5,  # plane_far
                    20,  # embossing
                    3,  # filter_size
                    0,  # frame_near
                ],
            ],
            inputs=inputs,
            outputs=outputs,
            cache_examples=True,
        )

        demo_3d_header.render()
        demo_3d.render()

        # Components that "Clear" simply empties (set to None).
        cleared_components = [
            input_output_16bit,
            input_output_fp32,
            input_output_vis,
            output_slider,
            files,
            viewer_3d,
            files_3d,
        ]

        def clear_fn():
            """Restore default settings, re-enable inputs and empty all results."""
            return (
                default_values
                + [
                    gr.Button(interactive=True),  # submit_btn
                    gr.Button(interactive=True),  # submit_3d
                    gr.Image(value=None, interactive=True),  # input_image
                ]
                # One None per cleared component, derived from the list so the
                # count cannot drift from the outputs below.
                + [None] * len(cleared_components)
            )

        clear_btn.click(
            fn=clear_fn,
            inputs=[],
            outputs=blocks_settings
            + [
                submit_btn,
                submit_3d,
                input_image,
            ]
            + cleared_components,
        )

        demo.queue(
            api_open=False,
        ).launch(
            server_name="0.0.0.0",
            server_port=7860,
        )
388
+
389
+
390
def prefetch_hf_cache(pipe):
    """Run one minimal inference pass on the bundled example, then delete its output.

    Calls ``process`` on ``files/bee.jpg`` with the smallest settings used in
    this file (ensemble size 1, one denoising step, processing resolution 64)
    and removes the ``files/bee_output`` directory that the call creates.
    """
    warmup_image = "files/bee.jpg"
    warmup_output_dir = "files/bee_output"

    process(
        pipe,
        warmup_image,
        ensemble_size=1,
        denoise_steps=1,
        processing_res=64,
    )
    # The warm-up result itself is not needed; only the side effect of having
    # run the pipeline once is.
    shutil.rmtree(warmup_output_dir)
393
+
394
+
395
def main():
    """Fetch the pipeline code, load the checkpoint and start the demo.

    Steps:
        1. Remove any existing ``geowizard`` checkout and clone it afresh.
        2. Add the checkout to ``sys.path`` and import the pipeline class.
        3. Load the pretrained checkpoint and move it to CUDA when available,
           otherwise CPU.
        4. Try to enable xformers memory-efficient attention (best effort).
        5. Warm up the pipeline and launch the Gradio server.
    """
    REPO_URL = "https://github.com/lemonaddie/geowizard.git"
    CHECKPOINT = "lemonaddie/Geowizard"
    REPO_DIR = "geowizard"

    # Always start from a fresh clone so a stale checkout is never imported.
    if os.path.isdir(REPO_DIR):
        shutil.rmtree(REPO_DIR)

    git.Repo.clone_from(REPO_URL, REPO_DIR)
    sys.path.append(os.path.join(os.getcwd(), REPO_DIR))

    # Importable only after the clone above has been added to sys.path.
    from pipeline.depth_normal_pipeline_clip_cfg import DepthNormalEstimationPipeline

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipe = DepthNormalEstimationPipeline.from_pretrained(CHECKPOINT)
    pipe = pipe.to(device)

    # xformers is optional: a missing package or an unsupported setup must not
    # prevent the demo from starting, so any failure is reported and ignored.
    try:
        import xformers  # noqa: F401  (import only checks availability)

        pipe.enable_xformers_memory_efficient_attention()
    except Exception as exc:
        print(f"Running without xformers: {exc}", file=sys.stderr)

    prefetch_hf_cache(pipe)
    run_demo_server(pipe)
428
+
429
+
430
# Script entry point: only start the demo when executed directly, not on import.
if __name__ == "__main__":
    main()
432
+