Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -36,7 +36,7 @@ css = """
|
|
36 |
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
37 |
dtype = torch.float32
|
38 |
variant = None
|
39 |
-
checkpoint_path = "GonzaloMG/marigold-e2e-ft-
|
40 |
unet = UNet2DConditionModel.from_pretrained(checkpoint_path, subfolder="unet")
|
41 |
vae = AutoencoderKL.from_pretrained(checkpoint_path, subfolder="vae")
|
42 |
text_encoder = CLIPTextModel.from_pretrained(checkpoint_path, subfolder="text_encoder")
|
@@ -55,28 +55,28 @@ pipe = pipe.to(DEVICE)
|
|
55 |
pipe.unet.eval()
|
56 |
|
57 |
|
58 |
-
title = "# End-to-End Fine-Tuned Marigold for
|
59 |
description = """ Please refer to our [paper](https://arxiv.org/abs/2409.11355) and [GitHub](https://vision.rwth-aachen.de/diffusion-e2e-ft) for more details."""
|
60 |
|
61 |
@spaces.GPU
|
62 |
-
def
|
63 |
with torch.no_grad():
|
64 |
-
pipe_out = pipe(image, denoising_steps=1, ensemble_size=1, noise="zeros", normals=
|
65 |
-
pred = pipe_out.
|
66 |
-
pred_colored = pipe_out.
|
67 |
return pred, pred_colored
|
68 |
|
69 |
with gr.Blocks(css=css) as demo:
|
70 |
gr.Markdown(title)
|
71 |
gr.Markdown(description)
|
72 |
-
gr.Markdown("###
|
73 |
|
74 |
with gr.Row():
|
75 |
input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
|
76 |
-
|
77 |
|
78 |
with gr.Row():
|
79 |
-
submit = gr.Button(value="Compute
|
80 |
processing_res_choice = gr.Radio(
|
81 |
[
|
82 |
("Recommended (768)", 768),
|
@@ -86,8 +86,7 @@ with gr.Blocks(css=css) as demo:
|
|
86 |
value=768,
|
87 |
)
|
88 |
|
89 |
-
|
90 |
-
raw_file = gr.File(label="Raw Depth Data (.npy)", elem_id="download")
|
91 |
|
92 |
cmap = matplotlib.colormaps.get_cmap('Spectral_r')
|
93 |
|
@@ -98,30 +97,30 @@ with gr.Blocks(css=css) as demo:
|
|
98 |
return None
|
99 |
|
100 |
pil_image = Image.fromarray(image.astype('uint8'))
|
101 |
-
|
102 |
|
103 |
-
# Save the npy data (raw
|
104 |
-
|
105 |
-
np.save(
|
106 |
|
107 |
# Save the grayscale depth map
|
108 |
-
depth_gray = (depth_npy * 65535.0).astype(np.uint16)
|
109 |
-
tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
|
110 |
-
Image.fromarray(depth_gray).save(tmp_gray_depth.name, mode="I;16")
|
111 |
|
112 |
-
# Save the colored
|
113 |
-
|
114 |
-
|
115 |
|
116 |
-
return [(image,
|
117 |
|
118 |
-
submit.click(on_submit, inputs=[input_image, processing_res_choice], outputs=[
|
119 |
|
120 |
example_files = os.listdir('assets/examples')
|
121 |
example_files.sort()
|
122 |
example_files = [os.path.join('assets/examples', filename) for filename in example_files]
|
123 |
example_files = [[image, 768] for image in example_files]
|
124 |
-
examples = gr.Examples(examples=example_files, inputs=[input_image, processing_res_choice], outputs=[
|
125 |
|
126 |
|
127 |
if __name__ == '__main__':
|
|
|
36 |
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
37 |
dtype = torch.float32
|
38 |
variant = None
|
39 |
+
checkpoint_path = "GonzaloMG/marigold-e2e-ft-normals"
|
40 |
unet = UNet2DConditionModel.from_pretrained(checkpoint_path, subfolder="unet")
|
41 |
vae = AutoencoderKL.from_pretrained(checkpoint_path, subfolder="vae")
|
42 |
text_encoder = CLIPTextModel.from_pretrained(checkpoint_path, subfolder="text_encoder")
|
|
|
55 |
pipe.unet.eval()
|
56 |
|
57 |
|
58 |
+
title = "# End-to-End Fine-Tuned Marigold for Normals Estimation"
|
59 |
description = """ Please refer to our [paper](https://arxiv.org/abs/2409.11355) and [GitHub](https://vision.rwth-aachen.de/diffusion-e2e-ft) for more details."""
|
60 |
|
61 |
@spaces.GPU
|
62 |
+
def predict_normals(image, processing_res_choice):
|
63 |
with torch.no_grad():
|
64 |
+
pipe_out = pipe(image, denoising_steps=1, ensemble_size=1, noise="zeros", normals=True, processing_res=processing_res_choice, match_input_res=True)
|
65 |
+
pred = pipe_out.normal_np
|
66 |
+
pred_colored = pipe_out.normal_colored
|
67 |
return pred, pred_colored
|
68 |
|
69 |
with gr.Blocks(css=css) as demo:
|
70 |
gr.Markdown(title)
|
71 |
gr.Markdown(description)
|
72 |
+
gr.Markdown("### Normals Prediction demo")
|
73 |
|
74 |
with gr.Row():
|
75 |
input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
|
76 |
+
normals_image_slider = ImageSlider(label="Surface Normals with Slider View", elem_id='img-display-output', position=0.5)
|
77 |
|
78 |
with gr.Row():
|
79 |
+
submit = gr.Button(value="Compute Normals")
|
80 |
processing_res_choice = gr.Radio(
|
81 |
[
|
82 |
("Recommended (768)", 768),
|
|
|
86 |
value=768,
|
87 |
)
|
88 |
|
89 |
+
raw_file = gr.File(label="Raw Normals Data (.npy)", elem_id="download")
|
|
|
90 |
|
91 |
cmap = matplotlib.colormaps.get_cmap('Spectral_r')
|
92 |
|
|
|
97 |
return None
|
98 |
|
99 |
pil_image = Image.fromarray(image.astype('uint8'))
|
100 |
+
normal_npy, normal_colored = predict_normals(pil_image, processing_res_choice)
|
101 |
|
102 |
+
# Save the npy data (raw normals)
|
103 |
+
tmp_npy_normal = tempfile.NamedTemporaryFile(suffix='.npy', delete=False)
|
104 |
+
np.save(tmp_npy_normal.name, normal_npy)
|
105 |
|
106 |
# Save the grayscale depth map
|
107 |
+
# depth_gray = (depth_npy * 65535.0).astype(np.uint16)
|
108 |
+
# tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
|
109 |
+
# Image.fromarray(depth_gray).save(tmp_gray_depth.name, mode="I;16")
|
110 |
|
111 |
+
# Save the colored normals map
|
112 |
+
tmp_colored_normal = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
|
113 |
+
normal_colored.save(tmp_colored_normal.name)
|
114 |
|
115 |
+
return [(image, normal_colored), tmp_npy_normal.name]
|
116 |
|
117 |
+
submit.click(on_submit, inputs=[input_image, processing_res_choice], outputs=[normals_image_slider, raw_file])
|
118 |
|
119 |
example_files = os.listdir('assets/examples')
|
120 |
example_files.sort()
|
121 |
example_files = [os.path.join('assets/examples', filename) for filename in example_files]
|
122 |
example_files = [[image, 768] for image in example_files]
|
123 |
+
examples = gr.Examples(examples=example_files, inputs=[input_image, processing_res_choice], outputs=[normals_image_slider, raw_file], fn=on_submit)
|
124 |
|
125 |
|
126 |
if __name__ == '__main__':
|