GonzaloMG commited on
Commit
fbad7a8
1 Parent(s): 7809a37

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -0
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ mport gradio as gr
2
+ import cv2
3
+ import matplotlib
4
+ import numpy as np
5
+ import os
6
+ from PIL import Image
7
+ import spaces
8
+ import torch
9
+ import tempfile
10
+ from gradio_imageslider import ImageSlider
11
+ from huggingface_hub import hf_hub_download
12
+
13
+ from depth_anything_v2.dpt import DepthAnythingV2
14
+
15
+ css = """
16
+ #img-display-container {
17
+ max-height: 100vh;
18
+ }
19
+ #img-display-input {
20
+ max-height: 80vh;
21
+ }
22
+ #img-display-output {
23
+ max-height: 80vh;
24
+ }
25
+ #download {
26
+ height: 62px;
27
+ }
28
+ """
29
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
30
+ model_configs = {
31
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
32
+ 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
33
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
34
+ 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
35
+ }
36
+ encoder2name = {
37
+ 'vits': 'Small',
38
+ 'vitb': 'Base',
39
+ 'vitl': 'Large',
40
+ 'vitg': 'Giant', # we are undergoing company review procedures to release our giant model checkpoint
41
+ }
42
+ encoder = 'vitl'
43
+ model_name = encoder2name[encoder]
44
+ model = DepthAnythingV2(**model_configs[encoder])
45
+ filepath = hf_hub_download(repo_id=f"depth-anything/Depth-Anything-V2-{model_name}", filename=f"depth_anything_v2_{encoder}.pth", repo_type="model")
46
+ state_dict = torch.load(filepath, map_location="cpu")
47
+ model.load_state_dict(state_dict)
48
+ model = model.to(DEVICE).eval()
49
+
50
+ title = "# Depth Anything V2"
51
+ description = """Official demo for **Depth Anything V2**.
52
+ Please refer to our [paper](https://arxiv.org/abs/2406.09414), [project page](https://depth-anything-v2.github.io), and [github](https://github.com/DepthAnything/Depth-Anything-V2) for more details."""
53
+
54
+ @spaces.GPU
55
+ def predict_depth(image):
56
+ return model.infer_image(image)
57
+
58
+ with gr.Blocks(css=css) as demo:
59
+ gr.Markdown(title)
60
+ gr.Markdown(description)
61
+ gr.Markdown("### Depth Prediction demo")
62
+
63
+ with gr.Row():
64
+ input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
65
+ depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
66
+ submit = gr.Button(value="Compute Depth")
67
+ gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download",)
68
+ raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download",)
69
+
70
+ cmap = matplotlib.colormaps.get_cmap('Spectral_r')
71
+
72
+ def on_submit(image):
73
+ original_image = image.copy()
74
+
75
+ h, w = image.shape[:2]
76
+
77
+ depth = predict_depth(image[:, :, ::-1])
78
+
79
+ raw_depth = Image.fromarray(depth.astype('uint16'))
80
+ tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
81
+ raw_depth.save(tmp_raw_depth.name)
82
+
83
+ depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
84
+ depth = depth.astype(np.uint8)
85
+ colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
86
+
87
+ gray_depth = Image.fromarray(depth)
88
+ tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
89
+ gray_depth.save(tmp_gray_depth.name)
90
+
91
+ return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name]
92
+
93
+ submit.click(on_submit, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file])
94
+
95
+ example_files = os.listdir('assets/examples')
96
+ example_files.sort()
97
+ example_files = [os.path.join('assets/examples', filename) for filename in example_files]
98
+ examples = gr.Examples(examples=example_files, inputs=[input_image], outputs=[depth_image_slider, gray_depth_file, raw_file], fn=on_submit)
99
+
100
+
101
+ if __name__ == '__main__':
102
+ demo.queue().launch(share=True)