###########################################################################################
# Code based on the Hugging Face Space of Depth Anything v2
# https://huggingface.co/spaces/depth-anything/Depth-Anything-V2/blob/main/app.py
###########################################################################################

import gradio as gr
import numpy as np
import os
import tempfile
import torch
from PIL import Image
import spaces
from gradio_imageslider import ImageSlider

from Marigold.marigold import MarigoldPipeline
from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

css = """
#img-display-container {
    max-height: 100vh;
}
#img-display-input {
    max-height: 80vh;
}
#img-display-output {
    max-height: 80vh;
}
#download {
    height: 62px;
}
"""
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.float32
variant = None
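# Load the components of the end-to-end fine-tuned Marigold normals checkpoint.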
checkpoint_path = "GonzaloMG/marigold-e2e-ft-normals"
unet         = UNet2DConditionModel.from_pretrained(checkpoint_path, subfolder="unet")
vae          = AutoencoderKL.from_pretrained(checkpoint_path, subfolder="vae")
text_encoder = CLIPTextModel.from_pretrained(checkpoint_path, subfolder="text_encoder")
tokenizer    = CLIPTokenizer.from_pretrained(checkpoint_path, subfolder="tokenizer")
scheduler    = DDIMScheduler.from_pretrained(checkpoint_path, subfolder="scheduler", timestep_spacing="trailing")
pipe = MarigoldPipeline.from_pretrained(
    checkpoint_path,
    unet=unet,
    vae=vae,
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    variant=variant,
    torch_dtype=dtype,
)
pipe = pipe.to(DEVICE)
pipe.unet.eval()


title = "# End-to-End Fine-Tuned Marigold for Normals Estimation"
description = """Please refer to our [paper](https://arxiv.org/abs/2409.11355) and [GitHub](https://vision.rwth-aachen.de/diffusion-e2e-ft) for more details."""

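# The fine-tuned model is deterministic: a single denoising step from zero noise
# (noise="zeros") replaces Marigold's multi-step, multi-sample ensembling.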
@spaces.GPU
def predict_normals(image, processing_res_choice):
    with torch.no_grad():
        pipe_out = pipe(image, denoising_steps=1, ensemble_size=1, noise="zeros", normals=True, processing_res=processing_res_choice, match_input_res=True)
    pred = pipe_out.normal_np
    pred_colored = pipe_out.normal_colored
    return pred, pred_colored

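# Gradio UI: image input, a before/after slider for the predicted normals,
# and a download widget for the raw data.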
with gr.Blocks(css=css) as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown("### Normals Prediction demo")

    with gr.Row():
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
        normals_image_slider = ImageSlider(label="Surface Normals with Slider View", elem_id='img-display-output', position=0.5)

    with gr.Row():
        submit = gr.Button(value="Compute Normals")
        processing_res_choice = gr.Radio(
                [
                    ("Recommended (768)", 768),
                    ("Native", 0),
                ],
                label="Processing resolution",
                value=768,
            )

    raw_file = gr.File(label="Raw Normals Data (.npy)", elem_id="download")

    def on_submit(image, processing_res_choice):
        if image is None:
            print("No image uploaded.")
            # Two outputs are registered, so return a value for each.
            return None, None

        pil_image = Image.fromarray(image.astype('uint8'))
        normal_npy, normal_colored = predict_normals(pil_image, processing_res_choice)

        # Save the npy data (raw normals)
        tmp_npy_normal = tempfile.NamedTemporaryFile(suffix='.npy', delete=False)
        np.save(tmp_npy_normal.name, normal_npy)
    
        return [(image, normal_colored), tmp_npy_normal.name]

    submit.click(on_submit, inputs=[input_image, processing_res_choice], outputs=[normals_image_slider, raw_file])

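    # Example gallery: pair each bundled image with the recommended 768 px
    # processing resolution.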
    example_files = os.listdir('assets/examples')
    example_files.sort()
    example_files = [os.path.join('assets/examples', filename) for filename in example_files]
    example_files = [[image, 768] for image in example_files]
    examples = gr.Examples(examples=example_files, inputs=[input_image, processing_res_choice], outputs=[normals_image_slider, raw_file], fn=on_submit)


if __name__ == '__main__':
    demo.queue().launch(share=True)