Nadine Rueegg committed
Commit eb37a1f
Parent: 29e68bd

enable skipping ttopt and add more example images

datasets/test_image_crops/Picture25.jpg DELETED
Binary file (5.86 kB)
 
datasets/test_image_crops/n02087394-Rhodesian_ridgeback_n02087394_7480.png ADDED

Git LFS Details

  • SHA256: 0683be3ff95740b0aa24def3007b3b385657f2864c4668b3ea226af0a08c01df
  • Pointer size: 131 Bytes
  • Size of remote file: 142 kB
datasets/test_image_crops/n02088238-basset_n02088238_13799.png ADDED

Git LFS Details

  • SHA256: f6be3590c14009f5342cac8f469abc14253371ac69e93b24841a76564532002d
  • Pointer size: 131 Bytes
  • Size of remote file: 160 kB
datasets/test_image_crops/n02091134-whippet_n02091134_18902.png ADDED

Git LFS Details

  • SHA256: 3818474079cf32cc57175056051ffd8ac9bca3ceb0c1a9f2d8e4beb9b5e3bb4f
  • Pointer size: 131 Bytes
  • Size of remote file: 176 kB
datasets/test_image_crops/n02093991-Irish_terrier_n02093991_114.png ADDED

Git LFS Details

  • SHA256: b633d03c7b6c604032d3cf60355658605f7529cf64e9d98d520ce11d416aa65c
  • Pointer size: 131 Bytes
  • Size of remote file: 126 kB
datasets/test_image_crops/n02100236-German_short-haired_pointer_n02100236_5505.png ADDED

Git LFS Details

  • SHA256: 67182f410274a4f4f0f237d68289a45f9af98d98552dc00d7382c71b79de198e
  • Pointer size: 131 Bytes
  • Size of remote file: 104 kB
datasets/test_image_crops/n02110185-Siberian_husky_n02110185_8154.png ADDED

Git LFS Details

  • SHA256: ef5a13020f44acbc91f8a7c0dfb4b197f83d73175816f468addabfa8d78cbcd2
  • Pointer size: 130 Bytes
  • Size of remote file: 59.6 kB
datasets/test_image_crops/n02110806-basenji_n02110806_5688.png ADDED

Git LFS Details

  • SHA256: 0674998bfc82e7bd59377a4f44d27e9480a1349743ba87a11af7d2cbc4886db5
  • Pointer size: 131 Bytes
  • Size of remote file: 151 kB
datasets/test_image_crops/z_dog-parks-in-austin-850x520.jpg ADDED

Git LFS Details

  • SHA256: 02d3ff3c6280f76c5999bd64317e717df2a16d65c5b978d9f28da6d17963a42a
  • Pointer size: 131 Bytes
  • Size of remote file: 130 kB
datasets/test_image_crops/z_dog_with_red_ball.jpeg ADDED
scripts/gradio_demo.py CHANGED
@@ -247,7 +247,7 @@ if not os.path.exists(save_imgs_path):
 
 
 
-def run_bite_inference(input_image, bbox=None):
+def run_bite_inference(input_image, bbox=None, apply_ttopt=True):
 
     with open(loss_weight_path, 'r') as j:
         losses = json.loads(j.read())
@@ -386,8 +386,9 @@ def run_bite_inference(input_image, bbox=None):
     target_gc_class = sm(res['vertexwise_ground_contact'][ind_img, :, :])[None, :, 1] # values between 0 and 1
     target_gc_class_remeshed = torch.einsum('ij,aij->ai', remeshing_relevant_barys, target_gc_class[:, remeshing_relevant_faces].to(device=device, dtype=torch.float32))
     target_gc_class_remeshed_prep = torch.round(target_gc_class_remeshed).to(torch.long)
-    vert_colors = np.repeat(255*target_gc_class.detach().cpu().numpy()[0, :, None], 3, 1)
-    vert_colors[:, 2] = 255
+    # vert_colors = np.repeat(255*target_gc_class.detach().cpu().numpy()[0, :, None], 3, 1)
+    # vert_colors[:, 2] = 255
+    vert_colors = np.ones_like(np.repeat(target_gc_class.detach().cpu().numpy()[0, :, None], 3, 1)) * 255
     faces_prep = smal.faces.unsqueeze(0).expand((batch_size, -1, -1))
     # prepare target silhouette and keypoints, from stacked hourglass predictions
     target_hg_silh = res['hg_silh_prep'][ind_img, :, :].detach()
@@ -405,152 +406,165 @@ def run_bite_inference(input_image, bbox=None):
     ignore_pose_optimization = False
 
 
-    ##########################################################################################################
-    # start optimizing for this image
-    n_iter = 301 # how many iterations are desired? (+1)
-    loop = tqdm(range(n_iter))
-    per_loop_lst = []
-    list_error_procrustes = []
-    for i in loop:
-        # for the first 150 iterations steps we don't allow vertex shifts
-        if i == 0:
-            current_i = 0
-            if ignore_pose_optimization:
-                current_optimizer = nopose_optimizer
-            else:
-                current_optimizer = optimizer
-            current_scheduler = scheduler
-            current_weight_name = 'weight'
-        # after 150 iteration steps we start with vertex shifts
-        elif i == 150:
-            current_i = 0
-            if ignore_pose_optimization:
-                current_optimizer = nopose_optimizer_vshift
-            else:
-                current_optimizer = optimizer_vshift
-            current_scheduler = scheduler_vshift
-            current_weight_name = 'weight_vshift'
-            # set up arap loss
-            if losses["arap"]['weight_vshift'] > 0.0:
-                with torch.no_grad():
-                    torch_mesh_comparison = Meshes(smal_verts.detach(), faces_prep.detach())
-                    arap_loss = Arap_Loss(meshes=torch_mesh_comparison, device=device)
-            # is there a laplacian loss similar as in coarse-to-fine?
-            if losses["lapctf"]['weight_vshift'] > 0.0:
-                torch_verts_comparison = smal_verts.detach().clone()
-                smal_model_type_downsampling = '39dogs_norm'
-                smal_downsampling_npz_name = 'mesh_downsampling_' + os.path.basename(SMAL_MODEL_CONFIG[smal_model_type_downsampling]['smal_model_path']).replace('.pkl', '_template.npz')
-                smal_downsampling_npz_path = os.path.join(root_smal_downsampling, smal_downsampling_npz_name)
-                data = np.load(smal_downsampling_npz_path, encoding='latin1', allow_pickle=True)
-                adjmat = data['A'][0]
-                laplacian_ctf = LaplacianCTF(adjmat, device=device)
-        else:
-            pass
-
-
-        current_optimizer.zero_grad()
-
+    if not apply_ttopt:
         # get 3d smal model
         optimed_pose_with_glob = get_optimed_pose_with_glob(optimed_orient_6d, optimed_pose_6d)
         optimed_trans = torch.cat((optimed_trans_xy, optimed_trans_z), dim=1)
         smal_verts, keyp_3d, _ = smal(beta=optimed_betas, betas_limbs=optimed_betas_limbs, pose=optimed_pose_with_glob, vert_off_compact=optimed_vert_off_compact, trans=optimed_trans, keyp_conf='olive', get_skin=True)
 
-        # render silhouette and keypoints
-        pred_silh_images, pred_keyp_raw = silh_renderer(vertices=smal_verts, points=keyp_3d, faces=faces_prep, focal_lengths=optimed_camera_flength)
-        pred_keyp = pred_keyp_raw[:, :24, :]
-
-        # save silhouette reprojection visualization
-        if i==0:
-            img_silh = Image.fromarray(np.uint8(255*pred_silh_images[0, 0, :, :].detach().cpu().numpy())).convert('RGB')
-            img_silh.save(root_out_path_details + name + '_silh_ainit.png')
-            my_mesh_tri = trimesh.Trimesh(vertices=smal_verts[0, ...].detach().cpu().numpy(), faces=faces_prep[0, ...].detach().cpu().numpy(), process=False, maintain_order=True)
-            my_mesh_tri.export(root_out_path_details + name + '_res_ainit.obj')
-
-        # silhouette loss
-        diff_silh = torch.abs(pred_silh_images[0, 0, :, :] - target_hg_silh)
-        losses['silhouette']['value'] = diff_silh.mean()
-
-        # keypoint_loss
-        output_kp_resh = (pred_keyp[0, :, :]).reshape((-1, 2))
-        losses['keyp']['value'] = ((((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt() * \
-            weights_resh[weights_resh>0])*keyp_w_resh[weights_resh>0]).sum() / \
-            max((weights_resh[weights_resh>0]*keyp_w_resh[weights_resh>0]).sum(), 1e-5)
-        # losses['keyp']['value'] = ((((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt()*weights_resh[weights_resh>0])*keyp_w_resh[weights_resh>0]).sum() / max((weights_resh[weights_resh>0]*keyp_w_resh[weights_resh>0]).sum(), 1e-5)
-
-        # pose priors on refined pose
-        losses['pose_legs_side']['value'] = leg_sideway_error(optimed_pose_with_glob)
-        losses['pose_legs_tors']['value'] = leg_torsion_error(optimed_pose_with_glob)
-        losses['pose_tail_side']['value'] = tail_sideway_error(optimed_pose_with_glob)
-        losses['pose_tail_tors']['value'] = tail_torsion_error(optimed_pose_with_glob)
-        losses['pose_spine_side']['value'] = spine_sideway_error(optimed_pose_with_glob)
-        losses['pose_spine_tors']['value'] = spine_torsion_error(optimed_pose_with_glob)
-
-        # ground contact loss
-        sel_verts = torch.index_select(smal_verts, dim=1, index=remeshing_relevant_faces.reshape((-1))).reshape((batch_size, remeshing_relevant_faces.shape[0], 3, 3))
-        verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts)
-
-        # gc_errors_plane, gc_errors_under_plane = calculate_plane_errors_batch(verts_remeshed, target_gc_class_remeshed_prep, target_dict['has_gc'], target_dict['has_gc_is_touching'])
-        gc_errors_plane, gc_errors_under_plane = calculate_plane_errors_batch(verts_remeshed, target_gc_class_remeshed_prep, isflat, istouching)
-
-        losses['gc_plane']['value'] = torch.mean(gc_errors_plane)
-        losses['gc_belowplane']['value'] = torch.mean(gc_errors_under_plane)
-
-        # edge length of the predicted mesh
-        if (losses["edge"][current_weight_name] + losses["normal"][ current_weight_name] + losses["laplacian"][ current_weight_name]) > 0:
-            torch_mesh = Meshes(smal_verts, faces_prep.detach())
-            losses["edge"]['value'] = mesh_edge_loss(torch_mesh)
-            # mesh normal consistency
-            losses["normal"]['value'] = mesh_normal_consistency(torch_mesh)
-            # mesh laplacian smoothing
-            losses["laplacian"]['value'] = mesh_laplacian_smoothing(torch_mesh, method="uniform")
-
-        # arap loss
-        if losses["arap"][current_weight_name] > 0.0:
-            torch_mesh = Meshes(smal_verts, faces_prep.detach())
-            losses["arap"]['value'] = arap_loss(torch_mesh)
-
-        # laplacian loss for comparison (from coarse-to-fine paper)
-        if losses["lapctf"][current_weight_name] > 0.0:
-            verts_refine = smal_verts
-            loss_almost_arap, loss_smooth = laplacian_ctf(verts_refine, torch_verts_comparison)
-            losses["lapctf"]['value'] = loss_almost_arap
-
-        # Weighted sum of the losses
-        total_loss = 0.0
-        for k in ['keyp', 'silhouette', 'pose_legs_side', 'pose_legs_tors', 'pose_tail_side', 'pose_tail_tors', 'pose_spine_tors', 'pose_spine_side', 'gc_plane', 'gc_belowplane', 'edge', 'normal', 'laplacian', 'arap', 'lapctf']:
-            if losses[k][current_weight_name] > 0.0:
-                total_loss += losses[k]['value'] * losses[k][current_weight_name]
-
-        # calculate gradient and make optimization step
-        total_loss.backward(retain_graph=True) #
-        current_optimizer.step()
-        current_scheduler.step(total_loss)
-        loop.set_description(f"Body Fitting = {total_loss.item():.3f}")
-
-        # save the result three times (0, 150, 300)
-        if i % 150 == 0:
-            # save silhouette image
-            img_silh = Image.fromarray(np.uint8(255*pred_silh_images[0, 0, :, :].detach().cpu().numpy())).convert('RGB')
-            img_silh.save(root_out_path_details + name + '_silh_e' + format(i, '03d') + '.png')
-            # save image overlay
-            visualizations = silh_renderer.get_visualization_nograd(smal_verts, faces_prep, optimed_camera_flength, color=0)
-            pred_tex = visualizations[0, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
-            # out_path = root_out_path_details + name + '_tex_pred_e' + format(i, '03d') + '.png'
-            # plt.imsave(out_path, pred_tex)
-            input_image_np = img_inp.copy()
-            im_masked = cv2.addWeighted(input_image_np,0.2,pred_tex,0.8,0)
-            pred_tex_max = np.max(pred_tex, axis=2)
-            im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :]
-            out_path = root_out_path + name + '_comp_pred_e' + format(i, '03d') + '.png'
-            plt.imsave(out_path, im_masked)
-            # save mesh
-            my_mesh_tri = trimesh.Trimesh(vertices=smal_verts[0, ...].detach().cpu().numpy(), faces=faces_prep[0, ...].detach().cpu().numpy(), process=False, maintain_order=True)
-            my_mesh_tri.visual.vertex_colors = vert_colors
-            my_mesh_tri.export(root_out_path + name + '_res_e' + format(i, '03d') + '.obj')
-            # save focal length (together with the mesh this is enough to create an overlay in blender)
-            out_file_flength = root_out_path_details + name + '_flength_e' + format(i, '03d') # + '.npz'
-            np.save(out_file_flength, optimed_camera_flength.detach().cpu().numpy())
-        current_i += 1
+        # save mesh
+        my_mesh_tri = trimesh.Trimesh(vertices=smal_verts[0, ...].detach().cpu().numpy(), faces=faces_prep[0, ...].detach().cpu().numpy(), process=False, maintain_order=True)
+        my_mesh_tri.visual.vertex_colors = vert_colors
+        my_mesh_tri.export(root_out_path + name + '_res_e000' + '.obj')
+
+    else:
+
+        ##########################################################################################################
+        # start optimizing for this image
+        n_iter = 301 # how many iterations are desired? (+1)
+        loop = tqdm(range(n_iter))
+        per_loop_lst = []
+        list_error_procrustes = []
+        for i in loop:
+            # for the first 150 iterations steps we don't allow vertex shifts
+            if i == 0:
+                current_i = 0
+                if ignore_pose_optimization:
+                    current_optimizer = nopose_optimizer
+                else:
+                    current_optimizer = optimizer
+                current_scheduler = scheduler
+                current_weight_name = 'weight'
+            # after 150 iteration steps we start with vertex shifts
+            elif i == 150:
+                current_i = 0
+                if ignore_pose_optimization:
+                    current_optimizer = nopose_optimizer_vshift
+                else:
+                    current_optimizer = optimizer_vshift
+                current_scheduler = scheduler_vshift
+                current_weight_name = 'weight_vshift'
+                # set up arap loss
+                if losses["arap"]['weight_vshift'] > 0.0:
+                    with torch.no_grad():
+                        torch_mesh_comparison = Meshes(smal_verts.detach(), faces_prep.detach())
+                        arap_loss = Arap_Loss(meshes=torch_mesh_comparison, device=device)
+                # is there a laplacian loss similar as in coarse-to-fine?
+                if losses["lapctf"]['weight_vshift'] > 0.0:
+                    torch_verts_comparison = smal_verts.detach().clone()
+                    smal_model_type_downsampling = '39dogs_norm'
+                    smal_downsampling_npz_name = 'mesh_downsampling_' + os.path.basename(SMAL_MODEL_CONFIG[smal_model_type_downsampling]['smal_model_path']).replace('.pkl', '_template.npz')
+                    smal_downsampling_npz_path = os.path.join(root_smal_downsampling, smal_downsampling_npz_name)
+                    data = np.load(smal_downsampling_npz_path, encoding='latin1', allow_pickle=True)
+                    adjmat = data['A'][0]
+                    laplacian_ctf = LaplacianCTF(adjmat, device=device)
+            else:
+                pass
+
+
+            current_optimizer.zero_grad()
+
+            # get 3d smal model
+            optimed_pose_with_glob = get_optimed_pose_with_glob(optimed_orient_6d, optimed_pose_6d)
+            optimed_trans = torch.cat((optimed_trans_xy, optimed_trans_z), dim=1)
+            smal_verts, keyp_3d, _ = smal(beta=optimed_betas, betas_limbs=optimed_betas_limbs, pose=optimed_pose_with_glob, vert_off_compact=optimed_vert_off_compact, trans=optimed_trans, keyp_conf='olive', get_skin=True)
+
+            # render silhouette and keypoints
+            pred_silh_images, pred_keyp_raw = silh_renderer(vertices=smal_verts, points=keyp_3d, faces=faces_prep, focal_lengths=optimed_camera_flength)
+            pred_keyp = pred_keyp_raw[:, :24, :]
+
+            # save silhouette reprojection visualization
+            if i==0:
+                img_silh = Image.fromarray(np.uint8(255*pred_silh_images[0, 0, :, :].detach().cpu().numpy())).convert('RGB')
+                img_silh.save(root_out_path_details + name + '_silh_ainit.png')
+                my_mesh_tri = trimesh.Trimesh(vertices=smal_verts[0, ...].detach().cpu().numpy(), faces=faces_prep[0, ...].detach().cpu().numpy(), process=False, maintain_order=True)
+                my_mesh_tri.export(root_out_path_details + name + '_res_ainit.obj')
+
+            # silhouette loss
+            diff_silh = torch.abs(pred_silh_images[0, 0, :, :] - target_hg_silh)
+            losses['silhouette']['value'] = diff_silh.mean()
+
+            # keypoint_loss
+            output_kp_resh = (pred_keyp[0, :, :]).reshape((-1, 2))
+            losses['keyp']['value'] = ((((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt() * \
+                weights_resh[weights_resh>0])*keyp_w_resh[weights_resh>0]).sum() / \
+                max((weights_resh[weights_resh>0]*keyp_w_resh[weights_resh>0]).sum(), 1e-5)
+            # losses['keyp']['value'] = ((((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt()*weights_resh[weights_resh>0])*keyp_w_resh[weights_resh>0]).sum() / max((weights_resh[weights_resh>0]*keyp_w_resh[weights_resh>0]).sum(), 1e-5)
+
+            # pose priors on refined pose
+            losses['pose_legs_side']['value'] = leg_sideway_error(optimed_pose_with_glob)
+            losses['pose_legs_tors']['value'] = leg_torsion_error(optimed_pose_with_glob)
+            losses['pose_tail_side']['value'] = tail_sideway_error(optimed_pose_with_glob)
+            losses['pose_tail_tors']['value'] = tail_torsion_error(optimed_pose_with_glob)
+            losses['pose_spine_side']['value'] = spine_sideway_error(optimed_pose_with_glob)
+            losses['pose_spine_tors']['value'] = spine_torsion_error(optimed_pose_with_glob)
+
+            # ground contact loss
+            sel_verts = torch.index_select(smal_verts, dim=1, index=remeshing_relevant_faces.reshape((-1))).reshape((batch_size, remeshing_relevant_faces.shape[0], 3, 3))
+            verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts)
+
+            # gc_errors_plane, gc_errors_under_plane = calculate_plane_errors_batch(verts_remeshed, target_gc_class_remeshed_prep, target_dict['has_gc'], target_dict['has_gc_is_touching'])
+            gc_errors_plane, gc_errors_under_plane = calculate_plane_errors_batch(verts_remeshed, target_gc_class_remeshed_prep, isflat, istouching)
+
+            losses['gc_plane']['value'] = torch.mean(gc_errors_plane)
+            losses['gc_belowplane']['value'] = torch.mean(gc_errors_under_plane)
+
+            # edge length of the predicted mesh
+            if (losses["edge"][current_weight_name] + losses["normal"][ current_weight_name] + losses["laplacian"][ current_weight_name]) > 0:
+                torch_mesh = Meshes(smal_verts, faces_prep.detach())
+                losses["edge"]['value'] = mesh_edge_loss(torch_mesh)
+                # mesh normal consistency
+                losses["normal"]['value'] = mesh_normal_consistency(torch_mesh)
+                # mesh laplacian smoothing
+                losses["laplacian"]['value'] = mesh_laplacian_smoothing(torch_mesh, method="uniform")
+
+            # arap loss
+            if losses["arap"][current_weight_name] > 0.0:
+                torch_mesh = Meshes(smal_verts, faces_prep.detach())
+                losses["arap"]['value'] = arap_loss(torch_mesh)
+
+            # laplacian loss for comparison (from coarse-to-fine paper)
+            if losses["lapctf"][current_weight_name] > 0.0:
+                verts_refine = smal_verts
+                loss_almost_arap, loss_smooth = laplacian_ctf(verts_refine, torch_verts_comparison)
+                losses["lapctf"]['value'] = loss_almost_arap
+
+            # Weighted sum of the losses
+            total_loss = 0.0
+            for k in ['keyp', 'silhouette', 'pose_legs_side', 'pose_legs_tors', 'pose_tail_side', 'pose_tail_tors', 'pose_spine_tors', 'pose_spine_side', 'gc_plane', 'gc_belowplane', 'edge', 'normal', 'laplacian', 'arap', 'lapctf']:
+                if losses[k][current_weight_name] > 0.0:
+                    total_loss += losses[k]['value'] * losses[k][current_weight_name]
+
+            # calculate gradient and make optimization step
+            total_loss.backward(retain_graph=True) #
+            current_optimizer.step()
+            current_scheduler.step(total_loss)
+            loop.set_description(f"Body Fitting = {total_loss.item():.3f}")
+
+            # save the result three times (0, 150, 300)
+            if i % 150 == 0:
+                # save silhouette image
+                img_silh = Image.fromarray(np.uint8(255*pred_silh_images[0, 0, :, :].detach().cpu().numpy())).convert('RGB')
+                img_silh.save(root_out_path_details + name + '_silh_e' + format(i, '03d') + '.png')
+                # save image overlay
+                visualizations = silh_renderer.get_visualization_nograd(smal_verts, faces_prep, optimed_camera_flength, color=0)
+                pred_tex = visualizations[0, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
+                # out_path = root_out_path_details + name + '_tex_pred_e' + format(i, '03d') + '.png'
+                # plt.imsave(out_path, pred_tex)
+                input_image_np = img_inp.copy()
+                im_masked = cv2.addWeighted(input_image_np,0.2,pred_tex,0.8,0)
+                pred_tex_max = np.max(pred_tex, axis=2)
+                im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :]
+                out_path = root_out_path + name + '_comp_pred_e' + format(i, '03d') + '.png'
+                plt.imsave(out_path, im_masked)
+                # save mesh
+                my_mesh_tri = trimesh.Trimesh(vertices=smal_verts[0, ...].detach().cpu().numpy(), faces=faces_prep[0, ...].detach().cpu().numpy(), process=False, maintain_order=True)
+                my_mesh_tri.visual.vertex_colors = vert_colors
+                my_mesh_tri.export(root_out_path + name + '_res_e' + format(i, '03d') + '.obj')
+                # save focal length (together with the mesh this is enough to create an overlay in blender)
+                out_file_flength = root_out_path_details + name + '_flength_e' + format(i, '03d') # + '.npz'
+                np.save(out_file_flength, optimed_camera_flength.detach().cpu().numpy())
+            current_i += 1
 
     # prepare output mesh
     mesh = my_mesh_tri # all_results[0]['mesh_posed']
@@ -570,7 +584,7 @@ def run_bite_inference(input_image, bbox=None):
 # -------------------------------------------------------------------------------------------------------------------- #
 
 
-def run_complete_inference(img_path_or_img, crop_choice):
+def run_complete_inference(img_path_or_img, crop_choice, use_ttopt):
     # depending on crop_choice: run faster r-cnn or take the input image directly
     if crop_choice == "input image is cropped":
         if isinstance(img_path_or_img, str):
@@ -581,8 +595,12 @@ def run_complete_inference(img_path_or_img, crop_choice):
         output_interm_bbox = None
     else:
         output_interm_image, output_interm_bbox = run_bbox_inference(img_path_or_img.copy())
+    if use_ttopt == "enable test-time optimization":
+        apply_ttopt = True
+    else:
+        apply_ttopt = False
     # run barc inference
-    result_gltf = run_bite_inference(img_path_or_img, output_interm_bbox)
+    result_gltf = run_bite_inference(img_path_or_img, output_interm_bbox, apply_ttopt)
     # add white border to image for nicer alignment
     output_interm_image_vis = np.concatenate((255*np.ones_like(output_interm_image), output_interm_image, 255*np.ones_like(output_interm_image)), axis=1)
     return [result_gltf, result_gltf, output_interm_image_vis]
@@ -602,7 +620,8 @@ description = '''
 
 #### Description
 This is a demo for BITE (*B*eyond Priors for *I*mproved *T*hree-{D} Dog Pose *E*stimation).
-You can either submit a cropped image or choose the option to run a pretrained Faster R-CNN in order to obtain a bounding box.
+You can either submit a cropped image or choose the option to run a pretrained Faster R-CNN in order to obtain a bounding box.
+Furthermore, you have the option to skip test-time optimization, which will lead to faster calculation at the cost of less accurate results.
 Please have a look at the examples below.
 <details>
 
@@ -623,11 +642,9 @@ Please have a look at the examples below.
 #### Image Sources
 * Stanford extra image dataset
 * Images from google search engine
-* https://www.dogtrainingnation.com/wp-content/uploads/2015/02/keep-dog-training-sessions-short.jpg
+* https://www.dogtrainingnation.com/wp-content/uploads/2015/02/keep-dog-training-sessions-short.jpghttps://www.ktvb.com/article/news/local/dogs-can-now-be-off-leash-again-in-boises-ann-morrison-park-optimist-youth-sports-complex/277-609691113
 * https://thumbs.dreamstime.com/b/hund-und-seine-neue-hundeh%C3%BCtte-36757551.jpg
-* https://www.mydearwhippet.com/wp-content/uploads/2021/04/whippet-temperament-2.jpg
-* https://media.istockphoto.com/photos/ibizan-hound-at-the-shore-in-winter-picture-id1092705644?k=20&m=1092705644&s=612x612&w=0&h=ppwg92s9jI8GWnk22SOR_DWWNP8b2IUmLXSQmVey5Ss=
-
+* https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRnx2sHnnLU3zy1XnJB7BvGUR9spmAh5bxTUg&usqp=CAU
 
 </details>
 '''
@@ -644,16 +661,16 @@ random.shuffle(example_images)
 examples = []
 for img in example_images:
     if os.path.basename(img)[:2] == 'z_':
-        examples.append([img, "use Faster R-CNN to get a bounding box"])
+        examples.append([img, "use Faster R-CNN to get a bounding box", "enable test-time optimization"])
     else:
-        examples.append([img, "input image is cropped"])
+        examples.append([img, "input image is cropped", "enable test-time optimization"])
 
 demo = gr.Interface(
     fn=run_complete_inference,
     description=description,
-    # inputs=gr.Image(type="filepath", label="Input Image"),
     inputs=[gr.Image(label="Input Image"),
             gr.Radio(["input image is cropped", "use Faster R-CNN to get a bounding box"], value="use Faster R-CNN to get a bounding box", label="Crop Choice"),
+            gr.Radio(["enable test-time optimization", "skip test-time optimization"], value="enable test-time optimization", label="Test Time Optimization"),
            ],
     outputs=[
         gr.Model3D(
@@ -665,8 +682,8 @@ demo = gr.Interface(
     examples=examples,
     thumbnail="bite_thumbnail.png",
    allow_flagging="never",
-    cache_examples=True,
+    cache_examples=False, # True,
     examples_per_page=14,
 )
 
-demo.launch(share=True)
+demo.launch() # share=True)
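
The wiring added above reduces to one small pattern: a third gr.Radio input whose string value is mapped to a boolean apply_ttopt flag, which then decides whether the roughly 300-step fitting loop runs or the network-regressed mesh is exported directly. The stand-alone sketch below mirrors only that pattern; the infer function, its body, and the return strings are placeholders rather than repository code, while the radio labels are taken from the diff.

    import gradio as gr

    def infer(image, ttopt_choice):
        # map the radio string to a boolean, as run_complete_inference now does
        apply_ttopt = (ttopt_choice == "enable test-time optimization")
        if not apply_ttopt:
            # single forward pass: faster, less accurate (placeholder output)
            return "network prediction only"
        # iterative refinement: slower, more accurate (placeholder output)
        return "prediction refined by test-time optimization"

    demo = gr.Interface(
        fn=infer,
        inputs=[gr.Image(label="Input Image"),
                gr.Radio(["enable test-time optimization", "skip test-time optimization"],
                         value="enable test-time optimization", label="Test Time Optimization")],
        outputs=gr.Textbox(),
    )

    if __name__ == "__main__":
        demo.launch()

With the flag disabled, the demo exports the mesh from the initial prediction (the new '_res_e000.obj' path); with it enabled, the full 301-iteration optimization runs as before.
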
src/combined_model/model_shape_v7_withref_withgraphcnn.py CHANGED
@@ -304,7 +304,7 @@ class ModelRefinement(nn.Module):
         num_downsampling = 1
         smal_model_type = '39dogs_norm'
         smal = SMAL(smal_model_type=smal_model_type, template_name='neutral')
-        ROOT_smal_downsampling = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/graphcmr/data/'
+        ROOT_smal_downsampling = os.path.join(os.path.dirname(__file__), './../../data/graphcmr_data/')
         smal_downsampling_npz_name = 'mesh_downsampling_' + os.path.basename(SMAL_MODEL_CONFIG[smal_model_type]['smal_model_path']).replace('.pkl', '_template.npz')
         smal_downsampling_npz_path = ROOT_smal_downsampling + smal_downsampling_npz_name # 'data/mesh_downsampling.npz'
         self.my_custom_smal_dog_mesh = Mesh(filename=smal_downsampling_npz_path, num_downsampling=num_downsampling, nsize=1, body_model=smal) # , device=device)
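
This second change swaps a hard-coded cluster path for one derived from the module's own location, so the downsampling data is found inside any checkout of the repository. A minimal, hypothetical sketch of the same idiom follows; the directory layout ('data/graphcmr_data') comes from the diff, the file name is a placeholder, and the normpath call is an optional extra rather than part of the commit.

    import os

    # resolve the data directory relative to this source file instead of an absolute path
    ROOT_smal_downsampling = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', 'data', 'graphcmr_data')
    ROOT_smal_downsampling = os.path.normpath(ROOT_smal_downsampling) + os.sep

    npz_name = 'mesh_downsampling_example_template.npz'  # placeholder file name
    npz_path = ROOT_smal_downsampling + npz_name         # concatenation works because the root ends with a separator

Because the original line builds the final path by plain string concatenation, the trailing separator on ROOT_smal_downsampling matters; joining with os.path.join would make that implicit.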