GonzaloMG committed on
Commit
58cc205
1 Parent(s): 4e67073

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -26
app.py CHANGED
@@ -15,9 +15,11 @@ import tempfile
15
  from gradio_imageslider import ImageSlider
16
  from huggingface_hub import hf_hub_download
17
 
18
- from Marigold.marigold import MarigoldPipeline
19
- from diffusers import AutoencoderKL, DDIMScheduler, UNet2DConditionModel
20
- from transformers import CLIPTextModel, CLIPTokenizer
 
 
21
 
22
  css = """
23
  #img-display-container {
@@ -51,35 +53,67 @@ title = "# End-to-End Fine-Tuned GeoWizard"
51
  description = """ Please refer to our [paper](https://arxiv.org/abs/2409.11355) and [GitHub](https://vision.rwth-aachen.de/diffusion-e2e-ft) for more details."""
52
 
53
  @spaces.GPU
54
- def predict_depth(image, processing_res_choice):
55
  with torch.no_grad():
56
- pipe_out = pipe(image, denoising_steps=1, ensemble_size=1, noise="zeros", normals=False, processing_res=processing_res_choice, match_input_res=True)
57
- pred = pipe_out.depth_np
58
- pred_colored = pipe_out.depth_colored
59
- return pred, pred_colored
 
 
 
 
60
 
61
  with gr.Blocks(css=css) as demo:
62
  gr.Markdown(title)
63
  gr.Markdown(description)
64
- gr.Markdown("### Depth Prediction demo")
65
 
66
  with gr.Row():
67
  input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
68
- depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
69
-
70
- with gr.Row():
71
- submit = gr.Button(value="Compute Depth")
72
  processing_res_choice = gr.Radio(
73
  [
74
- ("Recommended (768)", 768),
75
  ("Native", 0),
 
76
  ],
77
  label="Processing resolution",
78
- value=768,
79
  )
 
 
 
 
 
 
 
 
 
 
80
 
81
- gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download",)
82
- raw_file = gr.File(label="Raw Depth Data (.npy)", elem_id="download")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  cmap = matplotlib.colormaps.get_cmap('Spectral_r')
85
 
@@ -88,32 +122,62 @@ with gr.Blocks(css=css) as demo:
88
  if image is None:
89
  print("No image uploaded.")
90
  return None
91
-
92
  pil_image = Image.fromarray(image.astype('uint8'))
93
- depth_npy, depth_colored = predict_depth(pil_image, processing_res_choice)
94
 
95
- # Save the npy data (raw depth map)
96
  tmp_npy_depth = tempfile.NamedTemporaryFile(suffix='.npy', delete=False)
97
- np.save(tmp_npy_depth.name, depth_npy)
 
 
98
 
99
  # Save the grayscale depth map
100
- depth_gray = (depth_npy * 65535.0).astype(np.uint16)
101
  tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
102
  Image.fromarray(depth_gray).save(tmp_gray_depth.name, mode="I;16")
103
 
104
- # Save the colored depth map
105
  tmp_colored_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
106
  depth_colored.save(tmp_colored_depth.name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
- return [(image, depth_colored), tmp_gray_depth.name, tmp_npy_depth.name]
109
 
110
- submit.click(on_submit, inputs=[input_image, processing_res_choice], outputs=[depth_image_slider, gray_depth_file, raw_file])
111
 
112
  example_files = os.listdir('assets/examples')
113
  example_files.sort()
114
  example_files = [os.path.join('assets/examples', filename) for filename in example_files]
115
  example_files = [[image, 768] for image in example_files]
116
- examples = gr.Examples(examples=example_files, inputs=[input_image, processing_res_choice], outputs=[depth_image_slider, gray_depth_file, raw_file], fn=on_submit)
117
 
118
 
119
  if __name__ == '__main__':
 
15
  from gradio_imageslider import ImageSlider
16
  from huggingface_hub import hf_hub_download
17
 
18
+ from geowizard.models.geowizard_pipeline import DepthNormalEstimationPipeline
19
+ from Geowizard.geowizard.models.unet_2d_condition import UNet2DConditionModel
20
+ from diffusers import DDIMScheduler, AutoencoderKL
21
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
22
+
23
 
24
  css = """
25
  #img-display-container {
 
53
  description = """ Please refer to our [paper](https://arxiv.org/abs/2409.11355) and [GitHub](https://vision.rwth-aachen.de/diffusion-e2e-ft) for more details."""
54
 
55
  @spaces.GPU
56
+ def predict(image, processing_res_choice):
57
  with torch.no_grad():
58
+ pipe_out = pipe(image, denoising_steps=1, ensemble_size=1, noise="zeros", processing_res=processing_res_choice, match_input_res=True)
59
+ # depth
60
+ depth_pred = pipe_out.depth_np
61
+ depth_colored = pipe_out.depth_colored
62
+ # normals
63
+ normal_pred = pipe_out.normal_np
64
+ normal_colored = pipe_out.normal_colored
65
+ return depth_pred, depth_colored, normal_pred, normal_colored
66
 
67
  with gr.Blocks(css=css) as demo:
68
  gr.Markdown(title)
69
  gr.Markdown(description)
70
+ gr.Markdown("### Depth and Normals Prediction demo")
71
 
72
  with gr.Row():
73
  input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
 
 
 
 
74
  processing_res_choice = gr.Radio(
75
  [
 
76
  ("Native", 0),
77
+ ("Recommended", 768),
78
  ],
79
  label="Processing resolution",
80
+ value=0,
81
  )
82
+ model_choice = gr.Dropdown(
83
+ list(models.keys()), label="Select Model", value=list(models.keys())[0]
84
+ )
85
+
86
+ submit = gr.Button(value="Compute Depth and Normals")
87
+
88
+ with gr.Row():
89
+ depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
90
+ normal_image_slider = ImageSlider(label="Normal Map with Slider View", elem_id='normal-display-output', position=0.5)
91
+
92
 
93
+ colored_depth_file = gr.File(label="Colored Depth Image", elem_id="download")
94
+ gray_depth_file = gr.File(label="Grayscale Depth Map", elem_id="download")
95
+ raw_depth_file = gr.File(label="Raw Depth Data (.npy)", elem_id="download")
96
+ colored_normal_file = gr.File(label="Colored Normal Image", elem_id="download")
97
+ raw_normal_file = gr.File(label="Raw Normal Data (.npy)", elem_id="download")
98
+
99
+
100
+ # with gr.Row():
101
+ # input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
102
+ # depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
103
+
104
+ # with gr.Row():
105
+ # submit = gr.Button(value="Compute Depth")
106
+ # processing_res_choice = gr.Radio(
107
+ # [
108
+ # ("Recommended (768)", 768),
109
+ # ("Native", 0),
110
+ # ],
111
+ # label="Processing resolution",
112
+ # value=768,
113
+ # )
114
+
115
+ # gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download",)
116
+ # raw_file = gr.File(label="Raw Depth Data (.npy)", elem_id="download")
117
 
118
  cmap = matplotlib.colormaps.get_cmap('Spectral_r')
119
 
 
122
  if image is None:
123
  print("No image uploaded.")
124
  return None
125
+
126
  pil_image = Image.fromarray(image.astype('uint8'))
127
+ depth_pred, depth_colored, normal_pred, normal_colored = predict(pil_image, processing_res, model_choice, current_model)
128
 
129
+ # Save depth and normals npy data
130
  tmp_npy_depth = tempfile.NamedTemporaryFile(suffix='.npy', delete=False)
131
+ np.save(tmp_npy_depth.name, depth_pred)
132
+ tmp_npy_normal = tempfile.NamedTemporaryFile(suffix='.npy', delete=False)
133
+ np.save(tmp_npy_normal.name, normal_pred)
134
 
135
  # Save the grayscale depth map
136
+ depth_gray = (depth_pred * 65535.0).astype(np.uint16)
137
  tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
138
  Image.fromarray(depth_gray).save(tmp_gray_depth.name, mode="I;16")
139
 
140
+ # Save the colored depth and normals maps
141
  tmp_colored_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
142
  depth_colored.save(tmp_colored_depth.name)
143
+ tmp_colored_normal = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
144
+ normal_colored.save(tmp_colored_normal.name)
145
+
146
+ return (
147
+ (pil_image, depth_colored), # For ImageSlider: (base image, overlay image)
148
+ (pil_image, normal_colored), # For gr.Image
149
+ tmp_colored_depth.name, # File outputs
150
+ tmp_gray_depth.name,
151
+ tmp_npy_depth.name,
152
+ tmp_colored_normal.name,
153
+ tmp_npy_normal.name
154
+ )
155
+
156
+ # pil_image = Image.fromarray(image.astype('uint8'))
157
+ # depth_npy, depth_colored = predict_depth(pil_image, processing_res_choice)
158
+
159
+ # # Save the npy data (raw depth map)
160
+ # tmp_npy_depth = tempfile.NamedTemporaryFile(suffix='.npy', delete=False)
161
+ # np.save(tmp_npy_depth.name, depth_npy)
162
+
163
+ # # Save the grayscale depth map
164
+ # depth_gray = (depth_npy * 65535.0).astype(np.uint16)
165
+ # tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
166
+ # Image.fromarray(depth_gray).save(tmp_gray_depth.name, mode="I;16")
167
+
168
+ # # Save the colored depth map
169
+ # tmp_colored_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
170
+ # depth_colored.save(tmp_colored_depth.name)
171
 
172
+ # return [(image, depth_colored), tmp_gray_depth.name, tmp_npy_depth.name]
173
 
174
+ submit.click(on_submit, inputs=[input_image, processing_res_choice], outputs=[depth_image_slider,normal_image_slider,colored_depth_file,gray_depth_file,raw_depth_file,colored_normal_file,raw_normal_file])
175
 
176
  example_files = os.listdir('assets/examples')
177
  example_files.sort()
178
  example_files = [os.path.join('assets/examples', filename) for filename in example_files]
179
  example_files = [[image, 768] for image in example_files]
180
+ examples = gr.Examples(examples=example_files, inputs=[input_image, processing_res_choice], outputs=[depth_image_slider,normal_image_slider,colored_depth_file,gray_depth_file,raw_depth_file,colored_normal_file,raw_normal_file], fn=on_submit)
181
 
182
 
183
  if __name__ == '__main__':