###########################################################################################
# Code based on the Hugging Face Space of Depth Anything v2
# https://huggingface.co/spaces/depth-anything/Depth-Anything-V2/blob/main/app.py
###########################################################################################

import gradio as gr
import numpy as np
import os
from PIL import Image
import spaces
import torch
import tempfile
from gradio_imageslider import ImageSlider

from Marigold.marigold import MarigoldPipeline
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

css = """
#img-display-container {
    max-height: 100vh;
}
#img-display-input {
    max-height: 80vh;
}
#img-display-output {
    max-height: 80vh;
}
#download {
    height: 62px;
}
"""
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float32
variant = None
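# Load each component of the end-to-end fine-tuned Marigold checkpoint explicitly;
# the DDIM scheduler is re-created with trailing timestep spacing, which the
# single-step inference setting below relies on.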
checkpoint_path = "GonzaloMG/marigold-e2e-ft-depth"
unet         = UNet2DConditionModel.from_pretrained(checkpoint_path, subfolder="unet")   
vae          = AutoencoderKL.from_pretrained(checkpoint_path, subfolder="vae")  
text_encoder = CLIPTextModel.from_pretrained(checkpoint_path, subfolder="text_encoder")  
tokenizer    = CLIPTokenizer.from_pretrained(checkpoint_path, subfolder="tokenizer") 
scheduler    = DDIMScheduler.from_pretrained(checkpoint_path, timestep_spacing="trailing", subfolder="scheduler") 
pipe = MarigoldPipeline.from_pretrained(pretrained_model_name_or_path = checkpoint_path,
                                        unet=unet, 
                                        vae=vae, 
                                        scheduler=scheduler, 
                                        text_encoder=text_encoder, 
                                        tokenizer=tokenizer, 
                                        variant=variant, 
                                        torch_dtype=dtype, 
                                        )
pipe = pipe.to(DEVICE)
pipe.unet.eval()
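
# A minimal smoke test (hypothetical, for local runs): check that the pipeline
# produces a depth map of the expected shape before serving the demo.
#
#   _img = Image.fromarray(np.zeros((64, 64, 3), dtype=np.uint8))
#   _out = pipe(_img, denoising_steps=1, ensemble_size=1, noise="zeros",
#               normals=False, processing_res=64, match_input_res=True)
#   assert _out.depth_np.shape == (64, 64)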


title = "# End-to-End Fine-Tuned Marigold for Depth Estimation"
description = """ Please refer to our [paper](https://arxiv.org/abs/2409.11355) and [GitHub](https://vision.rwth-aachen.de/diffusion-e2e-ft) for more details."""
    
@spaces.GPU
def predict_depth(image, processing_res_choice):
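    # One deterministic forward pass: the end-to-end fine-tuned model predicts depth
    # in a single denoising step from a zero-noise latent, so no ensembling is needed.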
    with torch.no_grad():
        pipe_out = pipe(image, denoising_steps=1, ensemble_size=1, noise="zeros", normals=False, processing_res=processing_res_choice, match_input_res=True)
    pred = pipe_out.depth_np
    pred_colored = pipe_out.depth_colored
    return pred, pred_colored
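
# A minimal sketch of using the predictor outside the UI (hypothetical file names):
#
#   img = Image.open("assets/examples/example.jpg").convert("RGB")
#   depth, colored = predict_depth(img, 768)   # depth: numpy array scaled to [0, 1]
#   colored.save("depth_colored.png")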

with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown("### Depth Prediction demo")

    with gr.Row():
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
        depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)

    with gr.Row():
        submit = gr.Button(value="Compute Depth")
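        # In the Marigold pipeline, processing_res=0 runs inference at the input's
        # native resolution instead of resizing to a fixed edge length.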
        processing_res_choice = gr.Radio(
                [
                    ("Recommended (768)", 768),
                    ("Native", 0),
                ],
                label="Processing resolution",
                value=768,
            )

    gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
    raw_file = gr.File(label="Raw Depth Data (.npy)", elem_id="download")

    def on_submit(image, processing_res_choice):

        if image is None:
            raise gr.Error("Please upload an image first.")
    
        pil_image = Image.fromarray(image.astype('uint8'))
        depth_npy, depth_colored = predict_depth(pil_image, processing_res_choice)
    
        # Save the npy data (raw depth map)
        tmp_npy_depth = tempfile.NamedTemporaryFile(suffix='.npy', delete=False)
        np.save(tmp_npy_depth.name, depth_npy)
    
        # Save the grayscale depth map as a 16-bit PNG (depth_npy is in [0, 1])
        depth_gray = (depth_npy * 65535.0).astype(np.uint16)
        tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
        Image.fromarray(depth_gray).save(tmp_gray_depth.name)
    
        # Save the colored depth map
        tmp_colored_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
        depth_colored.save(tmp_colored_depth.name)
   
        return [(image, depth_colored), tmp_gray_depth.name, tmp_npy_depth.name]

    submit.click(on_submit, inputs=[input_image, processing_res_choice], outputs=[depth_image_slider, gray_depth_file, raw_file])

    example_files = os.listdir('assets/examples')
    example_files.sort()
    example_files = [os.path.join('assets/examples', filename) for filename in example_files]
    example_files = [[image, 768] for image in example_files]
    examples = gr.Examples(examples=example_files, inputs=[input_image, processing_res_choice], outputs=[depth_image_slider, gray_depth_file, raw_file], fn=on_submit)


if __name__ == '__main__':
    demo.queue().launch(share=True)