lemonaddie committed
Commit: 86c9440
1 Parent(s): 9c7e27e

Update app_recon.py

Files changed (1):
  app_recon.py  +55 -32
app_recon.py CHANGED
@@ -55,12 +55,18 @@ from torchvision.transforms import InterpolationMode

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

- stable_diffusion_repo_path = "stabilityai/stable-diffusion-2-1-unclip"
- vae = AutoencoderKL.from_pretrained(stable_diffusion_repo_path, subfolder='vae')
- scheduler = DDIMScheduler.from_pretrained(stable_diffusion_repo_path, subfolder='scheduler')
- sd_image_variations_diffusers_path = 'lambdalabs/sd-image-variations-diffusers'
- image_encoder = CLIPVisionModelWithProjection.from_pretrained(sd_image_variations_diffusers_path, subfolder="image_encoder")
- feature_extractor = CLIPImageProcessor.from_pretrained(sd_image_variations_diffusers_path, subfolder="feature_extractor")
+ # stable_diffusion_repo_path = "stabilityai/stable-diffusion-2-1-unclip"
+ # sd_image_variations_diffusers_path = 'lambdalabs/sd-image-variations-diffusers'
+ # vae = AutoencoderKL.from_pretrained(stable_diffusion_repo_path, subfolder='vae')
+ # scheduler = DDIMScheduler.from_pretrained(stable_diffusion_repo_path, subfolder='scheduler')
+ # image_encoder = CLIPVisionModelWithProjection.from_pretrained(sd_image_variations_diffusers_path, subfolder="image_encoder")
+ # feature_extractor = CLIPImageProcessor.from_pretrained(sd_image_variations_diffusers_path, subfolder="feature_extractor")
+
+ vae = AutoencoderKL.from_pretrained("./", subfolder='vae')
+ scheduler = DDIMScheduler.from_pretrained("./", subfolder='scheduler')
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("./", subfolder="image_encoder")
+ feature_extractor = CLIPImageProcessor.from_pretrained("./", subfolder="feature_extractor")
+
unet = UNet2DConditionModel.from_pretrained('.', subfolder="unet")

pipe = DepthNormalEstimationPipeline(vae=vae,
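
Note: the hunk above stops downloading the VAE/scheduler from stabilityai/stable-diffusion-2-1-unclip and the CLIP encoder/feature extractor from lambdalabs/sd-image-variations-diffusers, and instead loads every component from subfolders committed to the Space itself. A minimal startup sanity check one could add, assuming the repo root carries the standard diffusers layout (this loop is illustrative, not part of the commit):

    import os

    # fail fast if any locally committed component folder is missing
    for sub in ("vae", "scheduler", "image_encoder", "feature_extractor", "unet"):
        assert os.path.isdir(sub), f"missing local component folder: {sub}"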
@@ -77,6 +83,16 @@ except:

pipe = pipe.to(device)

+
+ def scale_img(img):
+     width, height = img.size
+
+     if min(width, height) > 480:
+         scale = 480 / min(width, height)
+         img = img.resize((int(width*scale), int(scale*height)), Image.LANCZOS)
+
+     return img
+
def sam_init():
    #sam_checkpoint = os.path.join(os.path.dirname(__file__), "sam_pt", "sam_vit_l_0b3195.pth")
    #model_type = "vit_l"
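
The new scale_img() helper caps the shorter image side at 480 px while keeping the aspect ratio; smaller images pass through unchanged. A quick illustration, assuming Pillow (the sizes are arbitrary examples):

    from PIL import Image

    big = Image.new("RGB", (1600, 960))
    print(scale_img(big).size)    # (800, 480): shorter side clamped to 480

    small = Image.new("RGB", (320, 240))
    print(scale_img(small).size)  # (320, 240): already under the cap, untouched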
@@ -110,6 +126,7 @@ def sam_segment(predictor, input_image, *bbox_coords):
    torch.cuda.empty_cache()
    return Image.fromarray(out_image_bbox, mode='RGBA'), masks_bbox

+
@spaces.GPU
def depth_normal(img_path,
                 denoising_steps,
@@ -124,6 +141,8 @@ def depth_normal(img_path,

    img = Image.open(img_path)

+     img = scale_img(img)
+
    pipe_out = pipe(
        img,
        denoising_steps=denoising_steps,
@@ -152,16 +171,14 @@

    return depth_colored, normal_colored, [depth_path, normal_path]

- @spaces.GPU
- def reconstruction(image, files):
-
-     torch.cuda.empty_cache()

-     img = Image.open(image)
+ def seg_foreground(image_file):
+     img = Image.open(image_file)

-     width, height = img.size
+     img = scale_img(img)

-     image_rem = img.convert('RGBA').resize((width//2, height//2), Image.LANCZOS)
+     image_rem = img.convert('RGBA')
+     print("after resize ", image_rem.size)
    image_nobg = remove(image_rem, alpha_matting=True)
    arr = np.asarray(image_nobg)[:,:,-1]
    x_nonzero = np.nonzero(arr.sum(axis=0))
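
For context, the lines elided between this hunk and the next derive a SAM box prompt from the alpha channel produced by rembg's remove(). A sketch of that flow (the x_min/y_min lines sit in the elided region, so their exact form is assumed; names follow the surrounding code):

    import numpy as np

    alpha = np.asarray(image_nobg)[:, :, -1]    # matting alpha, H x W
    x_nonzero = np.nonzero(alpha.sum(axis=0))   # columns containing foreground
    y_nonzero = np.nonzero(alpha.sum(axis=1))   # rows containing foreground
    x_min, x_max = int(x_nonzero[0].min()), int(x_nonzero[0].max())
    y_min, y_max = int(y_nonzero[0].min()), int(y_nonzero[0].max())
    # (x_min, y_min, x_max, y_max) becomes the box prompt for sam_segment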
@@ -172,10 +189,21 @@ def reconstruction(image, files):
    y_max = int(y_nonzero[0].max())
    masked_image, mask = sam_segment(sam_predictor, img.convert('RGB'), x_min, y_min, x_max, y_max)

-     mask = mask[-1].resize((width, height), Image.LANCZOS)
+     mask = Image.fromarray(np.array(mask[-1]).astype(np.uint8) * 255)
+
+     return masked_image, mask
+
+ @spaces.GPU
+ def reconstruction(mask, files):
+
+     torch.cuda.empty_cache()
+
+     mask = mask[:, :, 0] > 0.5
    depth_np = np.load(files[0])
    normal_np = np.load(files[1])

+     h, w, _ = np.shape(normal_np)
+
    dir_name = os.path.dirname(os.path.realpath(files[0]))
    mask_output_temp = mask
    name_base = os.path.splitext(os.path.basename(files[0]))[0][:-6]
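
Because the mask now travels between two Gradio event handlers, seg_foreground serializes it as a 0/255 grayscale image and reconstruction re-thresholds it. A minimal round-trip sketch, assuming the gr.Image component hands the second handler an H x W x 3 uint8 array:

    import numpy as np
    from PIL import Image

    bool_mask = np.zeros((4, 4), dtype=bool)
    bool_mask[1:3, 1:3] = True

    sent = Image.fromarray(bool_mask.astype(np.uint8) * 255)  # as in seg_foreground
    arr = np.asarray(sent.convert("RGB"))                     # what the handler receives
    recovered = arr[:, :, 0] > 0.5                            # as in reconstruction
    assert (recovered == bool_mask).all()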
@@ -193,7 +221,7 @@ def reconstruction(image, files):

    torch.cuda.empty_cache()

-     return obj_path, masked_image, [ply_path]
+     return obj_path, [ply_path]

def run_demo():

@@ -278,6 +306,8 @@ def run_demo():
                depth = gr.Image(interactive=False, show_label=False)
            with gr.Column():
                normal = gr.Image(interactive=False, show_label=False)
+             with gr.Column():
+                 masked_image = gr.Image(interactive=False, label="Masked foreground.")

        with gr.Row():
            files = gr.Files(
@@ -287,29 +317,21 @@
            )

        with gr.Row():
-             recon_btn = gr.Button('(Beta) Is there a salient foreground object? If yes, Click here to Reconstruct its 3D model.', variant='primary', interactive=True)
-
+             recon_btn = gr.Button('Is there a salient foreground object? If yes, Click here to Reconstruct its 3D model.', variant='primary', interactive=True)
+
        with gr.Row():
-             with gr.Column():
-                 masked_image = gr.Image(interactive=False, height=320, label="Masked foreground.")
-             with gr.Column():
-                 reconstructed_3d = gr.Model3D(
-                     label = 'Bini post-processed 3D model', height=320, interactive=False,
+             reconstructed_3d = gr.Model3D(
+                 label = 'Bini post-processed 3D model', interactive=False
            )
-             # reconstructed_3d = gr.Files(
-             #     label = "Bini post-processed 3D model (plyfile)",
-             #     elem_id = "download",
-             #     interactive=False,
-             # )

        with gr.Row():
            reconstructed_file = gr.Files(
                label = "3D Mesh (plyfile)",
                elem_id = "download",
-                 interactive=False,
+                 interactive=False
            )

-
+         mask = gr.Image(interactive=False, label="Masked foreground.", visible=False)
        run_btn.click(fn=depth_normal,
                      inputs=[input_image, denoising_steps,
                              ensemble_size,
@@ -318,9 +340,10 @@ def run_demo():
                              domain],
                      outputs=[depth, normal, files]
                      )
-         recon_btn.click(fn=reconstruction,
-                         inputs=[input_image, files],
-                         outputs=[reconstructed_3d, masked_image, reconstructed_file]
+         recon_btn.click(fn=seg_foreground, inputs=[input_image], outputs=[masked_image, mask]
+                         ).success(fn=reconstruction,
+                                   inputs=[mask, files],
+                                   outputs=[reconstructed_3d, reconstructed_file]
                        )
    demo.queue().launch(share=True, max_threads=80)
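
The rewired button relies on Gradio event chaining: .success() fires the second handler only if the first returned without raising, and the hidden mask image ferries the segmentation result between the two steps. A self-contained sketch of the pattern, assuming a Gradio version with the .success() listener (handler names here are illustrative):

    import gradio as gr

    def segment(text):
        return text, text                    # stand-ins for (masked_image, mask)

    def reconstruct(mask):
        return f"mesh built from: {mask}"

    with gr.Blocks() as demo:
        inp = gr.Textbox()
        shown = gr.Textbox(label="Masked foreground")
        hidden = gr.Textbox(visible=False)   # plays the role of `mask`
        result = gr.Textbox(label="3D model")
        btn = gr.Button("Reconstruct")
        # reconstruct runs only when segment completes successfully
        btn.click(fn=segment, inputs=[inp], outputs=[shown, hidden]
                  ).success(fn=reconstruct, inputs=[hidden], outputs=[result])

    # demo.queue().launch()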