File size: 3,804 Bytes
bcec54e
 
 
 
 
2a7f680
bcec54e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import sys

# Absolute path of the 'depth' directory located next to this file.
depth_directory = os.path.abspath(os.path.join(os.path.dirname(__file__), 'depth'))
# Extend the module search path before the project imports below are executed.
sys.path.append(depth_directory)
# NOTE(review): this entry is relative, so presumably it is resolved against
# the current working directory when an import is attempted - confirm it
# points at the intended latent-diffusion checkout.
sys.path.append('./latent-diffusion')
# Change the working directory so that relative paths used later in this file
# resolve against depth_directory.
os.chdir(depth_directory)

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from depth.models_depth.model import EVPDepth
from depth.configs.train_options import TrainOptions
from depth.configs.test_options import TestOptions
import glob
import utils
import torchvision.transforms as transforms
from utils_depth.misc import colorize
from PIL import Image
import torch.nn.functional as F
import gradio as gr
import tempfile


# Custom CSS limiting the maximum height of the elements with the ids below.
# 'img-display-input' and 'img-display-output' match the elem_id values given
# to the two gr.Image components in create_demo.
# NOTE(review): nothing in the visible code passes `css` to gr.Blocks, so these
# rules appear to be unused - confirm whether `gr.Blocks(css=css)` was intended.
css = """
#img-display-container {
    max-height: 50vh;
    }
#img-display-input {
    max-height: 40vh;
    }
#img-display-output {
    max-height: 40vh;
    }

"""

def create_demo(model, device):
    """Create the depth-prediction widgets and wire up the inference callback.

    The components are created in whatever Gradio layout context is active
    when this function is called (``main`` calls it inside a ``gr.Blocks``
    tab).

    Args:
        model: Callable that receives the preprocessed image batch and returns
            a mapping containing a ``'pred_d'`` tensor.
        device: Torch device the input tensor is moved to before inference.
    """
    gr.Markdown("### Depth Prediction demo")
    with gr.Row():
        input_image = gr.Image(label="Input Image", type='pil', elem_id='img-display-input')
        depth_image = gr.Image(label="Depth Map", elem_id='img-display-output')
    raw_file = gr.File(label="16-bit raw depth, multiplier:256")
    submit = gr.Button("Submit")

    def on_submit(image):
        """Run the model on one PIL image.

        Returns:
            A two-item list: the colorized depth map and the path of a PNG
            file holding the depth values multiplied by 256 as 16-bit integers.

        Raises:
            gr.Error: If no image was provided.
        """
        if image is None:
            # Surface a readable message instead of an opaque TypeError.
            raise gr.Error("Please provide an input image before submitting.")

        image = transforms.ToTensor()(image).unsqueeze(0).to(device)
        # Remember the original size so the prediction can be resized back.
        shape = image.shape
        image = F.interpolate(image, (440, 480), mode='bilinear', align_corners=True)
        # The pad tuple is (left, right, top, bottom): add 40 rows at the top,
        # turning the 440x480 input into 480x480.
        image = F.pad(image, (0, 0, 40, 0))
        with torch.no_grad():
            pred = model(image)['pred_d']

        # Drop the 40 padded rows again, then restore the input resolution.
        pred = pred[:, :, 40:, :]
        pred = F.interpolate(pred, shape[2:], mode='bilinear', align_corners=True)
        pred_d_numpy = pred.squeeze().cpu().numpy()
        colored_depth, _, _ = colorize(pred_d_numpy, cmap='gray_r')

        # Saturate instead of wrapping around when a scaled value does not fit
        # into an unsigned 16-bit integer.
        scaled = np.clip(pred_d_numpy * 256, 0, np.iinfo(np.uint16).max)
        raw_depth = Image.fromarray(scaled.astype('uint16'))

        # Only the unique file name is needed; close the handle right away so
        # no file descriptor is leaked per request. delete=False keeps the
        # file on disk for Gradio to serve.
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
            raw_path = tmp.name
        raw_depth.save(raw_path)
        return [colored_depth, raw_path]

    submit.click(on_submit, inputs=[input_image], outputs=[depth_image, raw_file])
    gr.Examples(examples=["test_img.jpg"],
                inputs=[input_image])


def main():
    """Parse the options, load the EVP depth checkpoint and launch the demo.

    The checkpoint path is read from ``args.ckpt_dir``; the file is expected
    to contain a ``'model'`` entry holding the state dict.
    """
    # Local import: OrderedDict is not among the file's visible imports, so the
    # prefix-stripping step below would otherwise fail with a NameError.
    from collections import OrderedDict

    opt = TestOptions().initialize()
    opt.add_argument('--img_path',  type=str)
    args = opt.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EVPDepth(args=args, caption_aggregation=True)
    cudnn.benchmark = True
    model.to(device)

    # map_location allows a checkpoint saved on a GPU to be loaded on a
    # CPU-only host as well.
    model_weight = torch.load(args.ckpt_dir, map_location=device)['model']

    # Strip a leading 'module.' from the parameter names (presumably left by a
    # DataParallel-style wrapper at save time). Testing each key with
    # startswith avoids mangling keys that merely contain 'module' elsewhere
    # and also copes with an empty state dict.
    prefix = 'module.'
    model_weight = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in model_weight.items()
    )
    model.load_state_dict(model_weight, strict=False)
    model.eval()

    title = "# EVP"
    description = """Official demo for **EVP: Enhanced Visual Perception using Inverse Multi-Attentive Feature
    Refinement and Regularized Image-Text Alignment**.
    EVP is a deep learning model for metric depth estimation from a single image.
    Please refer to our [paper](https://arxiv.org/abs/2312.08548) or [github](https://github.com/Lavreniuk/EVP) for more details."""

    with gr.Blocks() as demo:
        gr.Markdown(title)
        gr.Markdown(description)
        with gr.Tab("Depth Prediction"):
            create_demo(model, device)
        # NOTE(review): the links below point at the ZoeDepth space - confirm
        # whether they should reference this demo instead.
        gr.HTML('''<br><br><br><center>You can duplicate this Space to skip the queue:<a href="https://huggingface.co/spaces/shariqfarooq/ZoeDepth?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a><br>
                <p><img src="https://visitor-badge.glitch.me/badge?page_id=shariqfarooq.zoedepth_demo_hf" alt="visitors"></p></center>''')

    demo.queue().launch(share=True)


# Start the demo only when this file is executed as a script.
if __name__ == '__main__':
    main()