GonzaloMG committed on
Commit
ee2f8db
1 Parent(s): 363b008

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -24
app.py CHANGED
@@ -36,7 +36,7 @@ css = """
36
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
37
  dtype = torch.float32
38
  variant = None
39
- checkpoint_path = "GonzaloMG/marigold-e2e-ft-depth"
40
  unet = UNet2DConditionModel.from_pretrained(checkpoint_path, subfolder="unet")
41
  vae = AutoencoderKL.from_pretrained(checkpoint_path, subfolder="vae")
42
  text_encoder = CLIPTextModel.from_pretrained(checkpoint_path, subfolder="text_encoder")
@@ -55,28 +55,28 @@ pipe = pipe.to(DEVICE)
55
  pipe.unet.eval()
56
 
57
 
58
- title = "# End-to-End Fine-Tuned Marigold for Depth Estimation"
59
  description = """ Please refer to our [paper](https://arxiv.org/abs/2409.11355) and [GitHub](https://vision.rwth-aachen.de/diffusion-e2e-ft) for more details."""
60
 
61
  @spaces.GPU
62
- def predict_depth(image, processing_res_choice):
63
  with torch.no_grad():
64
- pipe_out = pipe(image, denoising_steps=1, ensemble_size=1, noise="zeros", normals=False, processing_res=processing_res_choice, match_input_res=True)
65
- pred = pipe_out.depth_np
66
- pred_colored = pipe_out.depth_colored
67
  return pred, pred_colored
68
 
69
  with gr.Blocks(css=css) as demo:
70
  gr.Markdown(title)
71
  gr.Markdown(description)
72
- gr.Markdown("### Depth Prediction demo")
73
 
74
  with gr.Row():
75
  input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
76
- depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5)
77
 
78
  with gr.Row():
79
- submit = gr.Button(value="Compute Depth")
80
  processing_res_choice = gr.Radio(
81
  [
82
  ("Recommended (768)", 768),
@@ -86,8 +86,7 @@ with gr.Blocks(css=css) as demo:
86
  value=768,
87
  )
88
 
89
- gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download",)
90
- raw_file = gr.File(label="Raw Depth Data (.npy)", elem_id="download")
91
 
92
  cmap = matplotlib.colormaps.get_cmap('Spectral_r')
93
 
@@ -98,30 +97,30 @@ with gr.Blocks(css=css) as demo:
98
  return None
99
 
100
  pil_image = Image.fromarray(image.astype('uint8'))
101
- depth_npy, depth_colored = predict_depth(pil_image, processing_res_choice)
102
 
103
- # Save the npy data (raw depth map)
104
- tmp_npy_depth = tempfile.NamedTemporaryFile(suffix='.npy', delete=False)
105
- np.save(tmp_npy_depth.name, depth_npy)
106
 
107
  # Save the grayscale depth map
108
- depth_gray = (depth_npy * 65535.0).astype(np.uint16)
109
- tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
110
- Image.fromarray(depth_gray).save(tmp_gray_depth.name, mode="I;16")
111
 
112
- # Save the colored depth map
113
- tmp_colored_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
114
- depth_colored.save(tmp_colored_depth.name)
115
 
116
- return [(image, depth_colored), tmp_gray_depth.name, tmp_npy_depth.name]
117
 
118
- submit.click(on_submit, inputs=[input_image, processing_res_choice], outputs=[depth_image_slider, gray_depth_file, raw_file])
119
 
120
  example_files = os.listdir('assets/examples')
121
  example_files.sort()
122
  example_files = [os.path.join('assets/examples', filename) for filename in example_files]
123
  example_files = [[image, 768] for image in example_files]
124
- examples = gr.Examples(examples=example_files, inputs=[input_image, processing_res_choice], outputs=[depth_image_slider, gray_depth_file, raw_file], fn=on_submit)
125
 
126
 
127
  if __name__ == '__main__':
 
36
  DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
37
  dtype = torch.float32
38
  variant = None
39
+ checkpoint_path = "GonzaloMG/marigold-e2e-ft-normals"
40
  unet = UNet2DConditionModel.from_pretrained(checkpoint_path, subfolder="unet")
41
  vae = AutoencoderKL.from_pretrained(checkpoint_path, subfolder="vae")
42
  text_encoder = CLIPTextModel.from_pretrained(checkpoint_path, subfolder="text_encoder")
 
55
  pipe.unet.eval()
56
 
57
 
58
+ title = "# End-to-End Fine-Tuned Marigold for Normals Estimation"
59
  description = """ Please refer to our [paper](https://arxiv.org/abs/2409.11355) and [GitHub](https://vision.rwth-aachen.de/diffusion-e2e-ft) for more details."""
60
 
61
  @spaces.GPU
62
+ def predict_normals(image, processing_res_choice):
63
  with torch.no_grad():
64
+ pipe_out = pipe(image, denoising_steps=1, ensemble_size=1, noise="zeros", normals=True, processing_res=processing_res_choice, match_input_res=True)
65
+ pred = pipe_out.normal_np
66
+ pred_colored = pipe_out.normal_colored
67
  return pred, pred_colored
68
 
69
  with gr.Blocks(css=css) as demo:
70
  gr.Markdown(title)
71
  gr.Markdown(description)
72
+ gr.Markdown("### Normals Prediction demo")
73
 
74
  with gr.Row():
75
  input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
76
+ normals_image_slider = ImageSlider(label="Surface Normals with Slider View", elem_id='img-display-output', position=0.5)
77
 
78
  with gr.Row():
79
+ submit = gr.Button(value="Compute Normals")
80
  processing_res_choice = gr.Radio(
81
  [
82
  ("Recommended (768)", 768),
 
86
  value=768,
87
  )
88
 
89
+ raw_file = gr.File(label="Raw Normals Data (.npy)", elem_id="download")
 
90
 
91
  cmap = matplotlib.colormaps.get_cmap('Spectral_r')
92
 
 
97
  return None
98
 
99
  pil_image = Image.fromarray(image.astype('uint8'))
100
+ normal_npy, normal_colored = predict_normals(pil_image, processing_res_choice)
101
 
102
+ # Save the npy data (raw normals)
103
+ tmp_npy_normal = tempfile.NamedTemporaryFile(suffix='.npy', delete=False)
104
+ np.save(tmp_npy_normal.name, normal_npy)
105
 
106
  # Save the grayscale depth map
107
+ # depth_gray = (depth_npy * 65535.0).astype(np.uint16)
108
+ # tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
109
+ # Image.fromarray(depth_gray).save(tmp_gray_depth.name, mode="I;16")
110
 
111
+ # Save the colored normals map
112
+ tmp_colored_normal = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
113
+ normal_colored.save(tmp_colored_normal.name)
114
 
115
+ return [(image, normal_colored), tmp_npy_normal.name]
116
 
117
+ submit.click(on_submit, inputs=[input_image, processing_res_choice], outputs=[normals_image_slider, raw_file])
118
 
119
  example_files = os.listdir('assets/examples')
120
  example_files.sort()
121
  example_files = [os.path.join('assets/examples', filename) for filename in example_files]
122
  example_files = [[image, 768] for image in example_files]
123
+ examples = gr.Examples(examples=example_files, inputs=[input_image, processing_res_choice], outputs=[normals_image_slider, raw_file], fn=on_submit)
124
 
125
 
126
  if __name__ == '__main__':