tokenid commited on
Commit
b5dfbe4
1 Parent(s): 15de7a2

lazy cache

Browse files
Files changed (2) hide show
  1. app.py +66 -36
  2. src/pose_funcs.py +3 -3
app.py CHANGED
@@ -18,6 +18,7 @@ from src.pose_estimation import load_model_from_config, estimate_poses, estimate
18
  from src.pose_funcs import find_optimal_poses
19
  from src.utils import spherical_to_cartesian, elu_to_c2w
20
 
 
21
  if torch.cuda.is_available():
22
  _device_ = 'cuda:0'
23
  else:
@@ -139,12 +140,10 @@ def image_to_tensor(img, width=256, height=256):
139
 
140
 
141
  @spaces.GPU(duration=110)
142
- def run_pose_exploration(cam_vis, image1, image2, probe_bsz, adj_bsz, adj_iters, seed_value):
143
 
144
  seed_everything(seed_value)
145
 
146
- cam_vis.set_images([np.asarray(image1, dtype=np.uint8), np.asarray(image2, dtype=np.uint8)])
147
-
148
  image1 = image_to_tensor(image1).to(_device_)
149
  image2 = image_to_tensor(image2).to(_device_)
150
 
@@ -186,31 +185,20 @@ def run_pose_exploration(cam_vis, image1, image2, probe_bsz, adj_bsz, adj_iters,
186
  if anchor_polar is None:
187
  anchor_polar = np.pi/2
188
 
189
- xyz0 = spherical_to_cartesian((anchor_polar, 0., 4.))
190
- c2w0 = elu_to_c2w(xyz0, np.zeros(3), np.array([0., 0., 1.]))
191
-
192
- xyz1 = spherical_to_cartesian((theta + anchor_polar, 0. + azimuth, 4. + radius))
193
- c2w1 = elu_to_c2w(xyz1, np.zeros(3), np.array([0., 0., 1.]))
194
-
195
- cam_vis._poses = [c2w0, c2w1]
196
- fig = cam_vis.update_figure(5, base_radius=-1.2, font_size=16, show_background=True, show_grid=True, show_ticklabels=True)
197
-
198
  explored_sph = (theta, azimuth, radius)
199
 
200
- return anchor_polar, explored_sph, fig, gr.update(interactive=True)
201
 
202
 
203
  @spaces.GPU(duration=110)
204
- def run_pose_refinement(cam_vis, image1, image2, anchor_polar, explored_sph, refine_iters, seed_value):
205
 
206
  seed_everything(seed_value)
207
 
208
- cam_vis.set_images([np.asarray(image1, dtype=np.uint8), np.asarray(image2, dtype=np.uint8)])
 
209
 
210
- image1 = image_to_tensor(image1).to(_device_)
211
- image2 = image_to_tensor(image2).to(_device_)
212
-
213
- images = [image1, image2]
214
  images = [ img.permute(0, 2, 3, 1) for img in images ]
215
 
216
  out_poses, _, loss = find_optimal_poses(
@@ -234,10 +222,39 @@ def run_pose_refinement(cam_vis, image1, image2, anchor_polar, explored_sph, ref
234
  xyz1 = spherical_to_cartesian((theta + anchor_polar, 0. + azimuth, 4. + radius))
235
  c2w1 = elu_to_c2w(xyz1, np.zeros(3), np.array([0., 0., 1.]))
236
 
237
- cam_vis._poses = [c2w0, c2w1]
238
  fig = cam_vis.update_figure(5, base_radius=-1.2, font_size=16, show_background=True, show_grid=True, show_ticklabels=True)
239
 
240
- return final_sph, fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
 
243
  _HEADER_ = '''
@@ -267,6 +284,9 @@ def run_demo():
267
  demo = gr.Blocks(title='ID-Pose: Sparse-view Camera Pose Estimation By Inverting Diffusion Models')
268
 
269
  with demo:
 
 
 
270
  gr.Markdown(_HEADER_)
271
 
272
  with gr.Row(variant='panel'):
@@ -327,8 +347,10 @@ def run_demo():
327
  ['data/gradio_demo/circo_0.png', 'data/gradio_demo/circo_1.png'],
328
  ],
329
  inputs=[input_image1, input_image2],
 
 
330
  label='Examples (Captured)',
331
- cache_examples=False,
332
  examples_per_page=5
333
  )
334
 
@@ -342,8 +364,10 @@ def run_demo():
342
  ['data/gradio_demo/christ_0.png', 'data/gradio_demo/christ_1.png'],
343
  ],
344
  inputs=[input_image1, input_image2],
 
 
345
  label='Examples (Internet)',
346
- cache_examples=False,
347
  examples_per_page=5
348
  )
349
 
@@ -357,31 +381,37 @@ def run_demo():
357
  ['data/gradio_demo/ride_horse_0.png', 'data/gradio_demo/ride_horse_1.png'],
358
  ],
359
  inputs=[input_image1, input_image2],
 
 
360
  label='Examples (Generated)',
361
- cache_examples=False,
362
  examples_per_page=5
363
  )
364
 
365
- cam_vis = CameraVisualizer([np.eye(4), np.eye(4)], ['Image 1', 'Image 2'], ['red', 'blue'])
366
-
367
- explored_sph = gr.State()
368
- anchor_polar = gr.State()
369
- refined_sph = gr.State()
370
-
371
  run_btn.click(
372
  fn=run_preprocess,
373
  inputs=[input_image1, input_image2, preprocess_chk, seed_value],
374
  outputs=[processed_image1, processed_image2],
375
  ).success(
376
- fn=partial(run_pose_exploration, cam_vis),
377
- inputs=[processed_image1, processed_image2, probe_bsz, adj_bsz, adj_iters, seed_value],
378
- outputs=[anchor_polar, explored_sph, vis_output, refine_btn]
379
  )
380
 
381
  refine_btn.click(
382
- fn=partial(run_pose_refinement, cam_vis),
383
- inputs=[processed_image1, processed_image2, anchor_polar, explored_sph, refine_iters, seed_value],
384
- outputs=[refined_sph, vis_output]
 
 
 
 
 
 
 
 
 
 
385
  )
386
 
387
  demo.launch()
 
18
  from src.pose_funcs import find_optimal_poses
19
  from src.utils import spherical_to_cartesian, elu_to_c2w
20
 
21
+
22
  if torch.cuda.is_available():
23
  _device_ = 'cuda:0'
24
  else:
 
140
 
141
 
142
  @spaces.GPU(duration=110)
143
+ def run_pose_exploration(image1, image2, probe_bsz, adj_bsz, adj_iters, seed_value):
144
 
145
  seed_everything(seed_value)
146
 
 
 
147
  image1 = image_to_tensor(image1).to(_device_)
148
  image2 = image_to_tensor(image2).to(_device_)
149
 
 
185
  if anchor_polar is None:
186
  anchor_polar = np.pi/2
187
 
 
 
 
 
 
 
 
 
 
188
  explored_sph = (theta, azimuth, radius)
189
 
190
+ return anchor_polar, explored_sph
191
 
192
 
193
  @spaces.GPU(duration=110)
194
+ def run_pose_refinement(image1, image2, est_result, refine_iters, seed_value):
195
 
196
  seed_everything(seed_value)
197
 
198
+ anchor_polar = est_result[0]
199
+ explored_sph = est_result[1]
200
 
201
+ images = [image_to_tensor(image1).to(_device_), image_to_tensor(image2).to(_device_)]
 
 
 
202
  images = [ img.permute(0, 2, 3, 1) for img in images ]
203
 
204
  out_poses, _, loss = find_optimal_poses(
 
222
  xyz1 = spherical_to_cartesian((theta + anchor_polar, 0. + azimuth, 4. + radius))
223
  c2w1 = elu_to_c2w(xyz1, np.zeros(3), np.array([0., 0., 1.]))
224
 
225
+ cam_vis = CameraVisualizer([c2w0, c2w1], ['Image 1', 'Image 2'], ['red', 'blue'], images=[np.asarray(image1, dtype=np.uint8), np.asarray(image2, dtype=np.uint8)])
226
  fig = cam_vis.update_figure(5, base_radius=-1.2, font_size=16, show_background=True, show_grid=True, show_ticklabels=True)
227
 
228
+ return (anchor_polar, final_sph), fig
229
+
230
+
231
+ def run_example(image1, image2):
232
+
233
+ image1, image2 = run_preprocess(image1, image2, True, 0)
234
+ anchor_polar, explored_sph = run_pose_exploration(image1, image2, 16, 4, 10, 0)
235
+
236
+ return (anchor_polar, explored_sph), image1, image2
237
+
238
+
239
+ def run_or_visualize(image1, image2, probe_bsz, adj_bsz, adj_iters, seed_value, est_result):
240
+
241
+ if est_result is None:
242
+ anchor_polar, explored_sph = run_pose_exploration(image1, image2, probe_bsz, adj_bsz, adj_iters, seed_value)
243
+ else:
244
+ anchor_polar = est_result[0]
245
+ explored_sph = est_result[1]
246
+ print('Using cache result.')
247
+
248
+ xyz0 = spherical_to_cartesian((anchor_polar, 0., 4.))
249
+ c2w0 = elu_to_c2w(xyz0, np.zeros(3), np.array([0., 0., 1.]))
250
+
251
+ xyz1 = spherical_to_cartesian((explored_sph[0] + anchor_polar, 0. + explored_sph[1], 4. + explored_sph[2]))
252
+ c2w1 = elu_to_c2w(xyz1, np.zeros(3), np.array([0., 0., 1.]))
253
+
254
+ cam_vis = CameraVisualizer([c2w0, c2w1], ['Image 1', 'Image 2'], ['red', 'blue'], images=[np.asarray(image1, dtype=np.uint8), np.asarray(image2, dtype=np.uint8)])
255
+ fig = cam_vis.update_figure(5, base_radius=-1.2, font_size=16, show_background=True, show_grid=True, show_ticklabels=True)
256
+
257
+ return (anchor_polar, explored_sph), fig, gr.update(interactive=True)
258
 
259
 
260
  _HEADER_ = '''
 
284
  demo = gr.Blocks(title='ID-Pose: Sparse-view Camera Pose Estimation By Inverting Diffusion Models')
285
 
286
  with demo:
287
+
288
+ est_result = gr.JSON(visible=False)
289
+
290
  gr.Markdown(_HEADER_)
291
 
292
  with gr.Row(variant='panel'):
 
347
  ['data/gradio_demo/circo_0.png', 'data/gradio_demo/circo_1.png'],
348
  ],
349
  inputs=[input_image1, input_image2],
350
+ fn=run_example,
351
+ outputs=[est_result, processed_image1, processed_image2],
352
  label='Examples (Captured)',
353
+ cache_examples='lazy',
354
  examples_per_page=5
355
  )
356
 
 
364
  ['data/gradio_demo/christ_0.png', 'data/gradio_demo/christ_1.png'],
365
  ],
366
  inputs=[input_image1, input_image2],
367
+ fn=run_example,
368
+ outputs=[est_result, processed_image1, processed_image2],
369
  label='Examples (Internet)',
370
+ cache_examples='lazy',
371
  examples_per_page=5
372
  )
373
 
 
381
  ['data/gradio_demo/ride_horse_0.png', 'data/gradio_demo/ride_horse_1.png'],
382
  ],
383
  inputs=[input_image1, input_image2],
384
+ fn=run_example,
385
+ outputs=[est_result, processed_image1, processed_image2],
386
  label='Examples (Generated)',
387
+ cache_examples='lazy',
388
  examples_per_page=5
389
  )
390
 
 
 
 
 
 
 
391
  run_btn.click(
392
  fn=run_preprocess,
393
  inputs=[input_image1, input_image2, preprocess_chk, seed_value],
394
  outputs=[processed_image1, processed_image2],
395
  ).success(
396
+ fn=run_or_visualize,
397
+ inputs=[processed_image1, processed_image2, probe_bsz, adj_bsz, adj_iters, seed_value, est_result],
398
+ outputs=[est_result, vis_output, refine_btn]
399
  )
400
 
401
  refine_btn.click(
402
+ fn=run_pose_refinement,
403
+ inputs=[processed_image1, processed_image2, est_result, refine_iters, seed_value],
404
+ outputs=[est_result, vis_output]
405
+ )
406
+
407
+ input_image1.clear(
408
+ fn=lambda: None,
409
+ outputs=[est_result]
410
+ )
411
+
412
+ input_image2.clear(
413
+ fn=lambda: None,
414
+ outputs=[est_result]
415
  )
416
 
417
  demo.launch()
src/pose_funcs.py CHANGED
@@ -101,9 +101,9 @@ def add_pose(pose1, pose2):
101
 
102
  def create_pose_params(pose, device):
103
 
104
- theta = torch.tensor([pose[0]], requires_grad=True, device=device)
105
- azimuth = torch.tensor([pose[1]], requires_grad=True, device=device)
106
- radius = torch.tensor([pose[2]], requires_grad=True, device=device)
107
 
108
  return [theta, azimuth, radius]
109
 
 
101
 
102
  def create_pose_params(pose, device):
103
 
104
+ theta = torch.tensor([float(pose[0])], requires_grad=True, device=device)
105
+ azimuth = torch.tensor([float(pose[1])], requires_grad=True, device=device)
106
+ radius = torch.tensor([float(pose[2])], requires_grad=True, device=device)
107
 
108
  return [theta, azimuth, radius]
109