Nadine Rueegg commited on
Commit
753fd9a
1 Parent(s): 45abb23

initial commit with code and data

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LICENSE +60 -0
  2. README.md +7 -10
  3. packages.txt +8 -0
  4. requirements.txt +15 -0
  5. scripts/gradio_demo.py +672 -0
  6. src/__init__.py +0 -0
  7. src/bps_2d/bps_for_segmentation.py +114 -0
  8. src/combined_model/__init__.py +0 -0
  9. src/combined_model/helper.py +207 -0
  10. src/combined_model/helper3.py +17 -0
  11. src/combined_model/loss_image_to_3d_refinement.py +216 -0
  12. src/combined_model/loss_image_to_3d_withbreedrel.py +342 -0
  13. src/combined_model/loss_utils/loss_arap.py +153 -0
  14. src/combined_model/loss_utils/loss_laplacian_mesh_comparison.py +45 -0
  15. src/combined_model/loss_utils/loss_sdf.py +122 -0
  16. src/combined_model/loss_utils/loss_utils.py +191 -0
  17. src/combined_model/loss_utils/loss_utils_gc.py +179 -0
  18. src/combined_model/model_shape_v7_withref_withgraphcnn.py +927 -0
  19. src/combined_model/train_main_image_to_3d_wbr_withref.py +955 -0
  20. src/combined_model/train_main_image_to_3d_withbreedrel.py +496 -0
  21. src/configs/SMAL_configs.py +230 -0
  22. src/configs/anipose_data_info.py +74 -0
  23. src/configs/barc_cfg_defaults.py +121 -0
  24. src/configs/barc_cfg_train.yaml +24 -0
  25. src/configs/barc_loss_weights_allzeros.json +30 -0
  26. src/configs/barc_loss_weights_with3dcgloss_higherbetaloss_v2_dm39dnnv3v2.json +30 -0
  27. src/configs/data_info.py +115 -0
  28. src/configs/dataset_path_configs.py +21 -0
  29. src/configs/dog_breeds/dog_breed_class.py +170 -0
  30. src/configs/refinement_cfg_test_withvertexwisegc_csaddnonflat.yaml +23 -0
  31. src/configs/refinement_cfg_test_withvertexwisegc_csaddnonflat_crops.yaml +23 -0
  32. src/configs/refinement_cfg_train_withvertexwisegc_isflat_csmorestanding.yaml +31 -0
  33. src/configs/refinement_loss_weights_withgc_withvertexwise_addnonflat.json +20 -0
  34. src/configs/ttopt_loss_weights/bite_loss_weights_ttopt.json +77 -0
  35. src/configs/ttopt_loss_weights/ttopt_loss_weights_v2c_withlapcft_v2.json +77 -0
  36. src/graph_networks/__init__.py +0 -0
  37. src/graph_networks/graphcmr/__init__.py +0 -0
  38. src/graph_networks/graphcmr/get_downsampled_mesh_npz.py +84 -0
  39. src/graph_networks/graphcmr/graph_cnn.py +53 -0
  40. src/graph_networks/graphcmr/graph_cnn_groundcontact.py +101 -0
  41. src/graph_networks/graphcmr/graph_cnn_groundcontact_multistage.py +174 -0
  42. src/graph_networks/graphcmr/graph_cnn_groundcontact_multistage_includingresnet.py +170 -0
  43. src/graph_networks/graphcmr/graph_layers.py +125 -0
  44. src/graph_networks/graphcmr/graphcnn_coarse_to_fine_animal_pose.py +97 -0
  45. src/graph_networks/graphcmr/my_remarks.txt +11 -0
  46. src/graph_networks/graphcmr/pytorch_coma_mesh_operations.py +282 -0
  47. src/graph_networks/graphcmr/utils_mesh.py +138 -0
  48. src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh.py +245 -0
  49. src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh_forfourpaws.py +213 -0
  50. src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh_forpaws.py +317 -0
LICENSE ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ License
2
+ Software Copyright License for non-commercial scientific research purposes
3
+ Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use BITE data, model and software, (the "Data & Software"), including 3D meshes, images, videos, textures, software, scripts, and animations. By downloading and/or using the Data & Software (including downloading, cloning, installing, and any other use of the corresponding github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Data & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this License
4
+
5
+ Ownership / Licensees
6
+ The Software and the associated materials has been developed at the
7
+
8
+ Max Planck Institute for Intelligent Systems
9
+ and
10
+ ETH Zurich
11
+
12
+ Any copyright or patent right is owned by and proprietary material of the
13
+
14
+ Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (hereinafter “MPG”; MPI and MPG hereinafter collectively “Max-Planck”)
15
+
16
+ hereinafter the “Licensor”.
17
+
18
+ License Grant
19
+ Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right:
20
+
21
+ To install the Data & Software on computers owned, leased or otherwise controlled by you and/or your organization;
22
+ To use the Data & Software for the sole purpose of performing non-commercial scientific research, non-commercial education, or non-commercial artistic projects;
23
+ Any other use, in particular any use for commercial, pornographic, military, or surveillance, purposes is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Data & Software may not be used to create fake, libelous, misleading, or defamatory content of any kind excluding analyses in peer-reviewed scientific research. The Data & Software may not be reproduced, modified and/or made available in any form to any third party without Max-Planck’s prior written permission.
24
+
25
+ The Data & Software may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Software to train methods/algorithms/neural networks/etc. for commercial, pornographic, military, surveillance, or defamatory use of any kind. By downloading the Data & Software, you agree not to reverse engineer it.
26
+
27
+ No Distribution
28
+ The Data & Software and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only.
29
+
30
+ Disclaimer of Representations and Warranties
31
+ You expressly acknowledge and agree that the Data & Software results from basic research, is provided “AS IS”, may contain errors, and that any use of the Data & Software is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE DATA & SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Data & Software, (ii) that the use of the Data & Software will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Data & Software will not cause any damage of any kind to you or a third party.
32
+
33
+ Limitation of Liability
34
+ Because this Data & Software License Agreement qualifies as a donation, according to Section 521 of the German Civil Code (Bürgerliches Gesetzbuch – BGB) Licensor as a donor is liable for intent and gross negligence only. If the Licensor fraudulently conceals a legal or material defect, they are obliged to compensate the Licensee for the resulting damage.
35
+ Licensor shall be liable for loss of data only up to the amount of typical recovery costs which would have arisen had proper and regular data backup measures been taken. For the avoidance of doubt Licensor shall be liable in accordance with the German Product Liability Act in the event of product liability. The foregoing applies also to Licensor’s legal representatives or assistants in performance. Any further liability shall be excluded.
36
+ Patent claims generated through the usage of the Data & Software cannot be directed towards the copyright holders.
37
+ The Data & Software is provided in the state of development the licensor defines. If modified or extended by Licensee, the Licensor makes no claims about the fitness of the Data & Software and is not responsible for any problems such modifications cause.
38
+
39
+ No Maintenance Services
40
+ You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Data & Software. Licensor nevertheless reserves the right to update, modify, or discontinue the Data & Software at any time.
41
+
42
+ Defects of the Data & Software must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification or publication.
43
+
44
+ Publications using the Data & Software
45
+ You acknowledge that the Data & Software is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Data & Software.
46
+
47
+ Citation:
48
+
49
+
50
+ @inproceedings{BITE:2023,
51
+ title = {BITE: Beyond priors for Improved Three-D dog pose Estimation},
52
+ author = {Rueegg, Nadine and Tripathi, Shashank and Schindler, Konrad and Black, Michael J. and Zuffi, Silvia},
53
+ booktitle = {under review},
54
+ year = {2023}
55
+ url = {https://bite.is.tue.mpg.de}
56
+ }
57
+ Commercial licensing opportunities
58
+ For commercial uses of the Data & Software, please send email to [email protected]
59
+
60
+ This Agreement shall be governed by the laws of the Federal Republic of Germany except for the UN Sales Convention.
README.md CHANGED
@@ -1,12 +1,9 @@
1
- ---
2
- title: Bite Gradio
3
- emoji: 👀
4
- colorFrom: blue
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 3.35.2
8
- app_file: app.py
9
  pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ title: BITE
2
+ emoji: 🐩 🐶 🐕
3
+ colorFrom: pink
4
+ colorTo: green
 
5
  sdk: gradio
6
+ sdk_version: 3.0.2
7
+ app_file: ./scripts/gradio_demo.py
8
  pinned: false
9
+ python_version: 3.7.6
 
 
packages.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ libgl1
2
+ unzip
3
+ ffmpeg
4
+ libsm6
5
+ libxext6
6
+ libgl1-mesa-dri
7
+ libegl1-mesa
8
+ libgbm1
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==1.6.0
2
+ torchvision==0.7.0
3
+ pytorch3d==0.2.5
4
+ kornia==0.4.0
5
+ matplotlib
6
+ opencv-python
7
+ trimesh
8
+ scipy
9
+ chumpy
10
+ pymp
11
+ importlib-resources
12
+ pycocotools
13
+ openpyxl
14
+ dominate
15
+ git+https://github.com/runa91/FrEIA.git
scripts/gradio_demo.py ADDED
@@ -0,0 +1,672 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # aenv_new_icon_2
3
+
4
+ # was used for ttoptv6_sketchfab_v16: python src/test_time_optimization/ttopt_fromref_v6_sketchfab.py --workers 12 --save-images True --config refinement_cfg_visualization_withgc_withvertexwisegc_isflat.yaml --model-file-complete=cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0/checkpoint.pth.tar --sketchfab 1
5
+
6
+ # for stanext images:
7
+ # python scripts/gradio.py --workers 12 --config refinement_cfg_test_withvertexwisegc_csaddnonflat.yaml --model-file-complete=cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0/checkpoint.pth.tar -s ttopt_vtest1
8
+ # for all images from the folder datasets/test_image_crops:
9
+ # python scripts/gradio.py --workers 12 --config refinement_cfg_test_withvertexwisegc_csaddnonflat_crops.yaml --model-file-complete=cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0/checkpoint.pth.tar -s ttopt_vtest2
10
+
11
+ '''import os
12
+ os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
13
+ os.environ["CUDA_VISIBLE_DEVICES"]="0"
14
+ try:
15
+ # os.system("pip install --upgrade torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html")
16
+ os.system("pip install --upgrade torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/cu101/torch_stable.html")
17
+ except Exception as e:
18
+ print(e)'''
19
+
20
+ import argparse
21
+ import os.path
22
+ import json
23
+ import numpy as np
24
+ import pickle as pkl
25
+ import csv
26
+ from distutils.util import strtobool
27
+ import torch
28
+ from torch import nn
29
+ import torch.backends.cudnn
30
+ from torch.nn import DataParallel
31
+ from torch.utils.data import DataLoader
32
+ from collections import OrderedDict
33
+ import glob
34
+ from tqdm import tqdm
35
+ from dominate import document
36
+ from dominate.tags import *
37
+ from PIL import Image
38
+ from matplotlib import pyplot as plt
39
+ import trimesh
40
+ import cv2
41
+ import shutil
42
+ import random
43
+ import gradio as gr
44
+
45
+ import torchvision
46
+ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
47
+ import torchvision.transforms as T
48
+ from pytorch3d.structures import Meshes
49
+ from pytorch3d.loss import mesh_edge_loss, mesh_laplacian_smoothing, mesh_normal_consistency
50
+
51
+
52
+ import sys
53
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
54
+
55
+ from combined_model.train_main_image_to_3d_wbr_withref import do_validation_epoch
56
+ from combined_model.model_shape_v7_withref_withgraphcnn import ModelImageTo3d_withshape_withproj
57
+
58
+ from configs.barc_cfg_defaults import get_cfg_defaults, update_cfg_global_with_yaml, get_cfg_global_updated
59
+
60
+ from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, rotmat_to_rot6d
61
+ from stacked_hourglass.datasets.utils_dataset_selection import get_evaluation_dataset, get_sketchfab_evaluation_dataset, get_crop_evaluation_dataset, get_norm_dict, get_single_crop_dataset_from_image
62
+
63
+ from test_time_optimization.bite_inference_model_for_ttopt import BITEInferenceModel
64
+ from smal_pytorch.smal_model.smal_torch_new import SMAL
65
+ from configs.SMAL_configs import SMAL_MODEL_CONFIG
66
+ from smal_pytorch.renderer.differentiable_renderer import SilhRenderer
67
+ from test_time_optimization.utils.utils_ttopt import reset_loss_values, get_optimed_pose_with_glob
68
+
69
+ from combined_model.loss_utils.loss_utils import leg_sideway_error, leg_torsion_error, tail_sideway_error, tail_torsion_error, spine_torsion_error, spine_sideway_error
70
+ from combined_model.loss_utils.loss_utils_gc import LossGConMesh, calculate_plane_errors_batch
71
+ from combined_model.loss_utils.loss_arap import Arap_Loss
72
+ from combined_model.loss_utils.loss_laplacian_mesh_comparison import LaplacianCTF # (coarse to fine animal)
73
+ from graph_networks import graphcmr # .utils_mesh import Mesh
74
+ from stacked_hourglass.utils.visualization import save_input_image_with_keypoints, save_input_image
75
+
76
+ random.seed(0)
77
+
78
+ print(
79
+ "torch: ", torch.__version__,
80
+ "\ntorchvision: ", torchvision.__version__,
81
+ )
82
+
83
+
84
+ def get_prediction(model, img_path_or_img, confidence=0.5):
85
+ """
86
+ see https://haochen23.github.io/2020/04/object-detection-faster-rcnn.html#.YsMCm4TP3-g
87
+ get_prediction
88
+ parameters:
89
+ - img_path - path of the input image
90
+ - confidence - threshold value for prediction score
91
+ method:
92
+ - Image is obtained from the image path
93
+ - the image is converted to image tensor using PyTorch's Transforms
94
+ - image is passed through the model to get the predictions
95
+ - class, box coordinates are obtained, but only prediction score > threshold
96
+ are chosen.
97
+ """
98
+ if isinstance(img_path_or_img, str):
99
+ img = Image.open(img_path_or_img).convert('RGB')
100
+ else:
101
+ img = img_path_or_img
102
+ transform = T.Compose([T.ToTensor()])
103
+ img = transform(img)
104
+ pred = model([img])
105
+ # pred_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())]
106
+ pred_class = list(pred[0]['labels'].numpy())
107
+ pred_boxes = [[(int(i[0]), int(i[1])), (int(i[2]), int(i[3]))] for i in list(pred[0]['boxes'].detach().numpy())]
108
+ pred_score = list(pred[0]['scores'].detach().numpy())
109
+ try:
110
+ pred_t = [pred_score.index(x) for x in pred_score if x>confidence][-1]
111
+ pred_boxes = pred_boxes[:pred_t+1]
112
+ pred_class = pred_class[:pred_t+1]
113
+ return pred_boxes, pred_class, pred_score
114
+ except:
115
+ print('no bounding box with a score that is high enough found! -> work on full image')
116
+ return None, None, None
117
+
118
+
119
+ def detect_object(model, img_path_or_img, confidence=0.5, rect_th=2, text_size=0.5, text_th=1):
120
+ """
121
+ see https://haochen23.github.io/2020/04/object-detection-faster-rcnn.html#.YsMCm4TP3-g
122
+ object_detection_api
123
+ parameters:
124
+ - img_path_or_img - path of the input image
125
+ - confidence - threshold value for prediction score
126
+ - rect_th - thickness of bounding box
127
+ - text_size - size of the class label text
128
+ - text_th - thichness of the text
129
+ method:
130
+ - prediction is obtained from get_prediction method
131
+ - for each prediction, bounding box is drawn and text is written
132
+ with opencv
133
+ - the final image is displayed
134
+ """
135
+ boxes, pred_cls, pred_scores = get_prediction(model, img_path_or_img, confidence)
136
+ if isinstance(img_path_or_img, str):
137
+ img = cv2.imread(img_path_or_img)
138
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
139
+ else:
140
+ img = img_path_or_img
141
+ is_first = True
142
+ bbox = None
143
+ if boxes is not None:
144
+ for i in range(len(boxes)):
145
+ cls = pred_cls[i]
146
+ if cls == 18 and bbox is None:
147
+ cv2.rectangle(img, boxes[i][0], boxes[i][1],color=(0, 255, 0), thickness=rect_th)
148
+ # cv2.putText(img, pred_cls[i], boxes[i][0], cv2.FONT_HERSHEY_SIMPLEX, text_size, (0,255,0),thickness=text_th)
149
+ # cv2.putText(img, str(pred_scores[i]), boxes[i][0], cv2.FONT_HERSHEY_SIMPLEX, text_size, (0,255,0),thickness=text_th)
150
+ bbox = boxes[i]
151
+ return img, bbox
152
+
153
+
154
+ # -------------------------------------------------------------------------------------------------------------------- #
155
+ model_bbox = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
156
+ model_bbox.eval()
157
+
158
+ def run_bbox_inference(input_image):
159
+ # load configs
160
+ cfg = get_cfg_global_updated()
161
+ out_path = os.path.join(cfg.paths.ROOT_OUT_PATH, 'gradio_examples', 'test2.png')
162
+ img, bbox = detect_object(model=model_bbox, img_path_or_img=input_image, confidence=0.5)
163
+ fig = plt.figure() # plt.figure(figsize=(20,30))
164
+ plt.imsave(out_path, img)
165
+ return img, bbox
166
+
167
+
168
+
169
+ # -------------------------------------------------------------------------------------------------------------------- #
170
+ # python scripts/gradio.py --workers 12 --config refinement_cfg_test_withvertexwisegc_csaddnonflat.yaml --model-file-complete=cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0/checkpoint.pth.tar
171
+ args_config = "refinement_cfg_test_withvertexwisegc_csaddnonflat.yaml"
172
+ args_model_file_complete = "cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0/checkpoint.pth.tar"
173
+ args_suffix = "ttopt_v0"
174
+ args_loss_weight_ttopt_path = "bite_loss_weights_ttopt.json"
175
+ args_workers = 12
176
+ # -------------------------------------------------------------------------------------------------------------------- #
177
+
178
+
179
+
180
+ # load configs
181
+ # step 1: load default configs
182
+ # step 2: load updates from .yaml file
183
+ path_config = os.path.join(get_cfg_defaults().barc_dir, 'src', 'configs', args_config)
184
+ update_cfg_global_with_yaml(path_config)
185
+ cfg = get_cfg_global_updated()
186
+
187
+ # define path to load the trained model
188
+ path_model_file_complete = os.path.join(cfg.paths.ROOT_CHECKPOINT_PATH, args_model_file_complete)
189
+
190
+ # define and create paths to save results
191
+ out_sub_name = cfg.data.VAL_OPT + '_' + cfg.data.DATASET + '_' + args_suffix + '/'
192
+ root_out_path = os.path.join(os.path.dirname(path_model_file_complete).replace(cfg.paths.ROOT_CHECKPOINT_PATH, cfg.paths.ROOT_OUT_PATH + 'results_gradio/'), out_sub_name)
193
+ root_out_path_details = root_out_path + 'details/'
194
+ if not os.path.exists(root_out_path): os.makedirs(root_out_path)
195
+ if not os.path.exists(root_out_path_details): os.makedirs(root_out_path_details)
196
+ print('root_out_path: ' + root_out_path)
197
+
198
+ # other paths
199
+ root_data_path = os.path.join(os.path.dirname(__file__), '../', 'data')
200
+ # downsampling as used in graph neural network
201
+ root_smal_downsampling = os.path.join(root_data_path, 'graphcmr_data')
202
+ # remeshing as used for ground contact
203
+ remeshing_path = os.path.join(root_data_path, 'smal_data_remeshed', 'uniform_surface_sampling', 'my_smpl_39dogsnorm_Jr_4_dog_remesh4000_info.pkl')
204
+
205
+ loss_weight_path = os.path.join(os.path.dirname(__file__), '../', 'src', 'configs', 'ttopt_loss_weights', args_loss_weight_ttopt_path)
206
+ print(loss_weight_path)
207
+
208
+
209
+ # Select the hardware device to use for training.
210
+ if torch.cuda.is_available() and cfg.device=='cuda':
211
+ device = torch.device('cuda', torch.cuda.current_device())
212
+ torch.backends.cudnn.benchmark = False # True
213
+ else:
214
+ device = torch.device('cpu')
215
+
216
+ print('structure_pose_net: ' + cfg.params.STRUCTURE_POSE_NET)
217
+ print('refinement network type: ' + cfg.params.REF_NET_TYPE)
218
+ print('smal_model_type: ' + cfg.smal.SMAL_MODEL_TYPE)
219
+
220
+ # prepare complete model
221
+ norm_dict = get_norm_dict(data_info=None, device=device)
222
+ bite_model = BITEInferenceModel(cfg, path_model_file_complete, norm_dict)
223
+ smal_model_type = bite_model.smal_model_type
224
+ logscale_part_list = SMAL_MODEL_CONFIG[smal_model_type]['logscale_part_list'] # ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l']
225
+ smal = SMAL(smal_model_type=smal_model_type, template_name='neutral', logscale_part_list=logscale_part_list).to(device)
226
+ silh_renderer = SilhRenderer(image_size=256).to(device)
227
+
228
+ # load loss modules -> not necessary!
229
+ # loss_module = Loss(smal_model_type=cfg.smal.SMAL_MODEL_TYPE, data_info=StanExt.DATA_INFO, nf_version=cfg.params.NF_VERSION).to(device)
230
+ # loss_module_ref = LossRef(smal_model_type=cfg.smal.SMAL_MODEL_TYPE, data_info=StanExt.DATA_INFO, nf_version=cfg.params.NF_VERSION).to(device)
231
+
232
+ # remeshing utils
233
+ with open(remeshing_path, 'rb') as fp:
234
+ remeshing_dict = pkl.load(fp)
235
+ remeshing_relevant_faces = torch.tensor(remeshing_dict['smal_faces'][remeshing_dict['faceid_closest']], dtype=torch.long, device=device)
236
+ remeshing_relevant_barys = torch.tensor(remeshing_dict['barys_closest'], dtype=torch.float32, device=device)
237
+
238
+
239
+
240
+
241
+ # create path for output files
242
+ save_imgs_path = os.path.join(cfg.paths.ROOT_OUT_PATH, 'gradio_examples')
243
+ if not os.path.exists(save_imgs_path):
244
+ os.makedirs(save_imgs_path)
245
+
246
+
247
+
248
+
249
+
250
+ def run_bite_inference(input_image, bbox=None):
251
+
252
+ with open(loss_weight_path, 'r') as j:
253
+ losses = json.loads(j.read())
254
+ shutil.copyfile(loss_weight_path, root_out_path_details + os.path.basename(loss_weight_path))
255
+ print(losses)
256
+
257
+ # prepare dataset and dataset loader
258
+ val_dataset, val_loader, len_val_dataset, test_name_list, stanext_data_info, stanext_acc_joints = get_single_crop_dataset_from_image(input_image, bbox=bbox)
259
+
260
+ # summarize information for normalization
261
+ norm_dict = get_norm_dict(stanext_data_info, device)
262
+ # get keypoint weights
263
+ keypoint_weights = torch.tensor(stanext_data_info.keypoint_weights, dtype=torch.float)[None, :].to(device)
264
+
265
+
266
+ # prepare progress bar
267
+ iterable = enumerate(val_loader) # the length of this iterator should be 1
268
+ progress = None
269
+ if True: # not quiet:
270
+ progress = tqdm(iterable, desc='Train', total=len(val_loader), ascii=True, leave=False)
271
+ iterable = progress
272
+ ind_img_tot = 0
273
+
274
+ for i, (input, target_dict) in iterable:
275
+ batch_size = input.shape[0]
276
+ # prepare variables, put them on the right device
277
+ for key in target_dict.keys():
278
+ if key == 'breed_index':
279
+ target_dict[key] = target_dict[key].long().to(device)
280
+ elif key in ['index', 'pts', 'tpts', 'target_weight', 'silh', 'silh_distmat_tofg', 'silh_distmat_tobg', 'sim_breed_index', 'img_border_mask']:
281
+ target_dict[key] = target_dict[key].float().to(device)
282
+ elif key == 'has_seg':
283
+ target_dict[key] = target_dict[key].to(device)
284
+ else:
285
+ pass
286
+ input = input.float().to(device)
287
+
288
+ # get starting values for the optimization
289
+ preds_dict = bite_model.get_all_results(input)
290
+ # res_normal_and_ref = bite_model.get_selected_results(preds_dict=preds_dict, result_networks=['normal', 'ref'])
291
+ res = bite_model.get_selected_results(preds_dict=preds_dict, result_networks=['ref'])['ref']
292
+ bs = res['pose_rotmat'].shape[0]
293
+ all_pose_6d = rotmat_to_rot6d(res['pose_rotmat'][:, None, 1:, :, :].clone().reshape((-1, 3, 3))).reshape((bs, -1, 6)) # [bs, 34, 6]
294
+ all_orient_6d = rotmat_to_rot6d(res['pose_rotmat'][:, None, :1, :, :].clone().reshape((-1, 3, 3))).reshape((bs, -1, 6)) # [bs, 1, 6]
295
+
296
+
297
+ ind_img = 0
298
+ name = (test_name_list[target_dict['index'][ind_img].long()]).replace('/', '__').split('.')[0]
299
+
300
+ print('ind_img_tot: ' + str(ind_img_tot) + ' -> ' + name)
301
+ ind_img_tot += 1
302
+ batch_size = 1
303
+
304
+ # save initial visualizations
305
+ # save the image with keypoints as predicted by the stacked hourglass
306
+ pred_unp_prep = torch.cat((res['hg_keyp_256'][ind_img, :, :].detach(), res['hg_keyp_scores'][ind_img, :, :]), 1)
307
+ inp_img = input[ind_img, :, :, :].detach().clone()
308
+ out_path = root_out_path + name + '_hg_key.png'
309
+ save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.01, print_scores=True, ratio_in_out=1.0) # threshold=0.3
310
+ # save the input image
311
+ img_inp = input[ind_img, :, :, :].clone()
312
+ for t, m, s in zip(img_inp, stanext_data_info.rgb_mean, stanext_data_info.rgb_stddev): t.add_(m) # inverse to transforms.color_normalize()
313
+ img_inp = img_inp.detach().cpu().numpy().transpose(1, 2, 0)
314
+ img_init = Image.fromarray(np.uint8(255*img_inp)).convert('RGB')
315
+ img_init.save(root_out_path_details + name + '_img_ainit.png')
316
+ # save ground truth silhouette (for visualization only, it is not used during the optimization)
317
+ target_img_silh = Image.fromarray(np.uint8(255*target_dict['silh'][ind_img, :, :].detach().cpu().numpy())).convert('RGB')
318
+ target_img_silh.save(root_out_path_details + name + '_target_silh.png')
319
+ # save the silhouette as predicted by the stacked hourglass
320
+ hg_img_silh = Image.fromarray(np.uint8(255*res['hg_silh_prep'][ind_img, :, :].detach().cpu().numpy())).convert('RGB')
321
+ hg_img_silh.save(root_out_path + name + '_hg_silh.png')
322
+
323
+ # initialize the variables over which we want to optimize
324
+ optimed_pose_6d = all_pose_6d[ind_img, None, :, :].to(device).clone().detach().requires_grad_(True)
325
+ optimed_orient_6d = all_orient_6d[ind_img, None, :, :].to(device).clone().detach().requires_grad_(True) # [1, 1, 6]
326
+ optimed_betas = res['betas'][ind_img, None, :].to(device).clone().detach().requires_grad_(True) # [1,30]
327
+ optimed_trans_xy = res['trans'][ind_img, None, :2].to(device).clone().detach().requires_grad_(True)
328
+ optimed_trans_z =res['trans'][ind_img, None, 2:3].to(device).clone().detach().requires_grad_(True)
329
+ optimed_camera_flength = res['flength'][ind_img, None, :].to(device).clone().detach().requires_grad_(True) # [1,1]
330
+ n_vert_comp = 2*smal.n_center + 3*smal.n_left
331
+ optimed_vert_off_compact = torch.tensor(np.zeros((batch_size, n_vert_comp)), dtype=torch.float,
332
+ device=device,
333
+ requires_grad=True)
334
+ assert len(logscale_part_list) == 7
335
+ new_betas_limb_lengths = res['betas_limbs'][ind_img, None, :]
336
+ optimed_betas_limbs = new_betas_limb_lengths.to(device).clone().detach().requires_grad_(True) # [1,7]
337
+
338
+ # define the optimizers
339
+ optimizer = torch.optim.SGD(
340
+ # [optimed_pose, optimed_trans_xy, optimed_betas, optimed_betas_limbs, optimed_orient, optimed_vert_off_compact],
341
+ [optimed_camera_flength, optimed_trans_z, optimed_trans_xy, optimed_pose_6d, optimed_orient_6d, optimed_betas, optimed_betas_limbs],
342
+ lr=5*1e-4, # 1e-3,
343
+ momentum=0.9)
344
+ optimizer_vshift = torch.optim.SGD(
345
+ [optimed_camera_flength, optimed_trans_z, optimed_trans_xy, optimed_pose_6d, optimed_orient_6d, optimed_betas, optimed_betas_limbs, optimed_vert_off_compact],
346
+ lr=1e-4, # 1e-4,
347
+ momentum=0.9)
348
+ nopose_optimizer = torch.optim.SGD(
349
+ # [optimed_pose, optimed_trans_xy, optimed_betas, optimed_betas_limbs, optimed_orient, optimed_vert_off_compact],
350
+ [optimed_camera_flength, optimed_trans_z, optimed_trans_xy, optimed_orient_6d, optimed_betas, optimed_betas_limbs],
351
+ lr=5*1e-4, # 1e-3,
352
+ momentum=0.9)
353
+ nopose_optimizer_vshift = torch.optim.SGD(
354
+ [optimed_camera_flength, optimed_trans_z, optimed_trans_xy, optimed_orient_6d, optimed_betas, optimed_betas_limbs, optimed_vert_off_compact],
355
+ lr=1e-4, # 1e-4,
356
+ momentum=0.9)
357
+ # define schedulers
358
+ patience = 5
359
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
360
+ optimizer,
361
+ mode='min',
362
+ factor=0.5,
363
+ verbose=0,
364
+ min_lr=1e-5,
365
+ patience=patience)
366
+ scheduler_vshift = torch.optim.lr_scheduler.ReduceLROnPlateau(
367
+ optimizer_vshift,
368
+ mode='min',
369
+ factor=0.5,
370
+ verbose=0,
371
+ min_lr=1e-5,
372
+ patience=patience)
373
+
374
+ # set all loss values to 0
375
+ losses = reset_loss_values(losses)
376
+
377
+ # prepare all the target labels: keypoints, silhouette, ground contact, ...
378
+ with torch.no_grad():
379
+ thr_kp = 0.2
380
+ kp_weights = res['hg_keyp_scores']
381
+ kp_weights[res['hg_keyp_scores']<thr_kp] = 0
382
+ weights_resh = kp_weights[ind_img, None, :, :].reshape((-1)) # target_dict['tpts'][:, :, 2].reshape((-1))
383
+ keyp_w_resh = keypoint_weights.repeat((batch_size, 1)).reshape((-1))
384
+ # prepare predicted ground contact labels
385
+ sm = nn.Softmax(dim=1)
386
+ target_gc_class = sm(res['vertexwise_ground_contact'][ind_img, :, :])[None, :, 1] # values between 0 and 1
387
+ target_gc_class_remeshed = torch.einsum('ij,aij->ai', remeshing_relevant_barys, target_gc_class[:, remeshing_relevant_faces].to(device=device, dtype=torch.float32))
388
+ target_gc_class_remeshed_prep = torch.round(target_gc_class_remeshed).to(torch.long)
389
+ vert_colors = np.repeat(255*target_gc_class.detach().cpu().numpy()[0, :, None], 3, 1)
390
+ vert_colors[:, 2] = 255
391
+ faces_prep = smal.faces.unsqueeze(0).expand((batch_size, -1, -1))
392
+ # prepare target silhouette and keypoints, from stacked hourglass predictions
393
+ target_hg_silh = res['hg_silh_prep'][ind_img, :, :].detach()
394
+ target_kp_resh = res['hg_keyp_256'][ind_img, None, :, :].reshape((-1, 2)).detach()
395
+ # find out if ground contact constraints should be used for the image at hand
396
+ # print('is flat: ' + str(res['isflat_prep'][ind_img]))
397
+ if res['isflat_prep'][ind_img] >= 0.5: # threshold should probably be set higher
398
+ isflat = [True]
399
+ else:
400
+ isflat = [False]
401
+ if target_gc_class_remeshed_prep.sum() > 3:
402
+ istouching = [True]
403
+ else:
404
+ istouching = [False]
405
+ ignore_pose_optimization = False
406
+
407
+
408
+ ##########################################################################################################
409
+ # start optimizing for this image
410
+ n_iter = 301 # how many iterations are desired? (+1)
411
+ loop = tqdm(range(n_iter))
412
+ per_loop_lst = []
413
+ list_error_procrustes = []
414
+ for i in loop:
415
+ # for the first 150 iterations steps we don't allow vertex shifts
416
+ if i == 0:
417
+ current_i = 0
418
+ if ignore_pose_optimization:
419
+ current_optimizer = nopose_optimizer
420
+ else:
421
+ current_optimizer = optimizer
422
+ current_scheduler = scheduler
423
+ current_weight_name = 'weight'
424
+ # after 150 iteration steps we start with vertex shifts
425
+ elif i == 150:
426
+ current_i = 0
427
+ if ignore_pose_optimization:
428
+ current_optimizer = nopose_optimizer_vshift
429
+ else:
430
+ current_optimizer = optimizer_vshift
431
+ current_scheduler = scheduler_vshift
432
+ current_weight_name = 'weight_vshift'
433
+ # set up arap loss
434
+ if losses["arap"]['weight_vshift'] > 0.0:
435
+ with torch.no_grad():
436
+ torch_mesh_comparison = Meshes(smal_verts.detach(), faces_prep.detach())
437
+ arap_loss = Arap_Loss(meshes=torch_mesh_comparison, device=device)
438
+ # is there a laplacian loss similar as in coarse-to-fine?
439
+ if losses["lapctf"]['weight_vshift'] > 0.0:
440
+ torch_verts_comparison = smal_verts.detach().clone()
441
+ smal_model_type_downsampling = '39dogs_norm'
442
+ smal_downsampling_npz_name = 'mesh_downsampling_' + os.path.basename(SMAL_MODEL_CONFIG[smal_model_type_downsampling]['smal_model_path']).replace('.pkl', '_template.npz')
443
+ smal_downsampling_npz_path = os.path.join(root_smal_downsampling, smal_downsampling_npz_name)
444
+ data = np.load(smal_downsampling_npz_path, encoding='latin1', allow_pickle=True)
445
+ adjmat = data['A'][0]
446
+ laplacian_ctf = LaplacianCTF(adjmat, device=device)
447
+ else:
448
+ pass
449
+
450
+
451
+ current_optimizer.zero_grad()
452
+
453
+ # get 3d smal model
454
+ optimed_pose_with_glob = get_optimed_pose_with_glob(optimed_orient_6d, optimed_pose_6d)
455
+ optimed_trans = torch.cat((optimed_trans_xy, optimed_trans_z), dim=1)
456
+ smal_verts, keyp_3d, _ = smal(beta=optimed_betas, betas_limbs=optimed_betas_limbs, pose=optimed_pose_with_glob, vert_off_compact=optimed_vert_off_compact, trans=optimed_trans, keyp_conf='olive', get_skin=True)
457
+
458
+ # render silhouette and keypoints
459
+ pred_silh_images, pred_keyp_raw = silh_renderer(vertices=smal_verts, points=keyp_3d, faces=faces_prep, focal_lengths=optimed_camera_flength)
460
+ pred_keyp = pred_keyp_raw[:, :24, :]
461
+
462
+ # save silhouette reprojection visualization
463
+ if i==0:
464
+ img_silh = Image.fromarray(np.uint8(255*pred_silh_images[0, 0, :, :].detach().cpu().numpy())).convert('RGB')
465
+ img_silh.save(root_out_path_details + name + '_silh_ainit.png')
466
+ my_mesh_tri = trimesh.Trimesh(vertices=smal_verts[0, ...].detach().cpu().numpy(), faces=faces_prep[0, ...].detach().cpu().numpy(), process=False, maintain_order=True)
467
+ my_mesh_tri.export(root_out_path_details + name + '_res_ainit.obj')
468
+
469
+ # silhouette loss
470
+ diff_silh = torch.abs(pred_silh_images[0, 0, :, :] - target_hg_silh)
471
+ losses['silhouette']['value'] = diff_silh.mean()
472
+
473
+ # keypoint_loss
474
+ output_kp_resh = (pred_keyp[0, :, :]).reshape((-1, 2))
475
+ losses['keyp']['value'] = ((((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt() * \
476
+ weights_resh[weights_resh>0])*keyp_w_resh[weights_resh>0]).sum() / \
477
+ max((weights_resh[weights_resh>0]*keyp_w_resh[weights_resh>0]).sum(), 1e-5)
478
+ # losses['keyp']['value'] = ((((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt()*weights_resh[weights_resh>0])*keyp_w_resh[weights_resh>0]).sum() / max((weights_resh[weights_resh>0]*keyp_w_resh[weights_resh>0]).sum(), 1e-5)
479
+
480
+ # pose priors on refined pose
481
+ losses['pose_legs_side']['value'] = leg_sideway_error(optimed_pose_with_glob)
482
+ losses['pose_legs_tors']['value'] = leg_torsion_error(optimed_pose_with_glob)
483
+ losses['pose_tail_side']['value'] = tail_sideway_error(optimed_pose_with_glob)
484
+ losses['pose_tail_tors']['value'] = tail_torsion_error(optimed_pose_with_glob)
485
+ losses['pose_spine_side']['value'] = spine_sideway_error(optimed_pose_with_glob)
486
+ losses['pose_spine_tors']['value'] = spine_torsion_error(optimed_pose_with_glob)
487
+
488
+ # ground contact loss
489
+ sel_verts = torch.index_select(smal_verts, dim=1, index=remeshing_relevant_faces.reshape((-1))).reshape((batch_size, remeshing_relevant_faces.shape[0], 3, 3))
490
+ verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts)
491
+
492
+ # gc_errors_plane, gc_errors_under_plane = calculate_plane_errors_batch(verts_remeshed, target_gc_class_remeshed_prep, target_dict['has_gc'], target_dict['has_gc_is_touching'])
493
+ gc_errors_plane, gc_errors_under_plane = calculate_plane_errors_batch(verts_remeshed, target_gc_class_remeshed_prep, isflat, istouching)
494
+
495
+ losses['gc_plane']['value'] = torch.mean(gc_errors_plane)
496
+ losses['gc_belowplane']['value'] = torch.mean(gc_errors_under_plane)
497
+
498
+ # edge length of the predicted mesh
499
+ if (losses["edge"][current_weight_name] + losses["normal"][ current_weight_name] + losses["laplacian"][ current_weight_name]) > 0:
500
+ torch_mesh = Meshes(smal_verts, faces_prep.detach())
501
+ losses["edge"]['value'] = mesh_edge_loss(torch_mesh)
502
+ # mesh normal consistency
503
+ losses["normal"]['value'] = mesh_normal_consistency(torch_mesh)
504
+ # mesh laplacian smoothing
505
+ losses["laplacian"]['value'] = mesh_laplacian_smoothing(torch_mesh, method="uniform")
506
+
507
+ # arap loss
508
+ if losses["arap"][current_weight_name] > 0.0:
509
+ torch_mesh = Meshes(smal_verts, faces_prep.detach())
510
+ losses["arap"]['value'] = arap_loss(torch_mesh)
511
+
512
+ # laplacian loss for comparison (from coarse-to-fine paper)
513
+ if losses["lapctf"][current_weight_name] > 0.0:
514
+ verts_refine = smal_verts
515
+ loss_almost_arap, loss_smooth = laplacian_ctf(verts_refine, torch_verts_comparison)
516
+ losses["lapctf"]['value'] = loss_almost_arap
517
+
518
+ # Weighted sum of the losses
519
+ total_loss = 0.0
520
+ for k in ['keyp', 'silhouette', 'pose_legs_side', 'pose_legs_tors', 'pose_tail_side', 'pose_tail_tors', 'pose_spine_tors', 'pose_spine_side', 'gc_plane', 'gc_belowplane', 'edge', 'normal', 'laplacian', 'arap', 'lapctf']:
521
+ if losses[k][current_weight_name] > 0.0:
522
+ total_loss += losses[k]['value'] * losses[k][current_weight_name]
523
+
524
+ # calculate gradient and make optimization step
525
+ total_loss.backward(retain_graph=True) #
526
+ current_optimizer.step()
527
+ current_scheduler.step(total_loss)
528
+ loop.set_description(f"Body Fitting = {total_loss.item():.3f}")
529
+
530
+ # save the result three times (0, 150, 300)
531
+ if i % 150 == 0:
532
+ # save silhouette image
533
+ img_silh = Image.fromarray(np.uint8(255*pred_silh_images[0, 0, :, :].detach().cpu().numpy())).convert('RGB')
534
+ img_silh.save(root_out_path_details + name + '_silh_e' + format(i, '03d') + '.png')
535
+ # save image overlay
536
+ visualizations = silh_renderer.get_visualization_nograd(smal_verts, faces_prep, optimed_camera_flength, color=0)
537
+ pred_tex = visualizations[0, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
538
+ # out_path = root_out_path_details + name + '_tex_pred_e' + format(i, '03d') + '.png'
539
+ # plt.imsave(out_path, pred_tex)
540
+ input_image_np = img_inp.copy()
541
+ im_masked = cv2.addWeighted(input_image_np,0.2,pred_tex,0.8,0)
542
+ pred_tex_max = np.max(pred_tex, axis=2)
543
+ im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :]
544
+ out_path = root_out_path + name + '_comp_pred_e' + format(i, '03d') + '.png'
545
+ plt.imsave(out_path, im_masked)
546
+ # save mesh
547
+ my_mesh_tri = trimesh.Trimesh(vertices=smal_verts[0, ...].detach().cpu().numpy(), faces=faces_prep[0, ...].detach().cpu().numpy(), process=False, maintain_order=True)
548
+ my_mesh_tri.visual.vertex_colors = vert_colors
549
+ my_mesh_tri.export(root_out_path + name + '_res_e' + format(i, '03d') + '.obj')
550
+ # save focal length (together with the mesh this is enough to create an overlay in blender)
551
+ out_file_flength = root_out_path_details + name + '_flength_e' + format(i, '03d') # + '.npz'
552
+ np.save(out_file_flength, optimed_camera_flength.detach().cpu().numpy())
553
+ current_i += 1
554
+
555
+ # prepare output mesh
556
+ mesh = my_mesh_tri # all_results[0]['mesh_posed']
557
+ mesh.apply_transform([[-1, 0, 0, 0],
558
+ [0, -1, 0, 0],
559
+ [0, 0, 1, 1],
560
+ [0, 0, 0, 1]])
561
+ result_path = os.path.join(save_imgs_path, test_name_list[0] + '_z')
562
+ mesh.export(file_obj=result_path + '.glb')
563
+ result_gltf = result_path + '.glb'
564
+ return result_gltf
565
+
566
+
567
+
568
+
569
+
570
+ # -------------------------------------------------------------------------------------------------------------------- #
571
+
572
+
573
+ def run_complete_inference(img_path_or_img, crop_choice):
574
+ # depending on crop_choice: run faster r-cnn or take the input image directly
575
+ if crop_choice == "input image is cropped":
576
+ if isinstance(img_path_or_img, str):
577
+ img = cv2.imread(img_path_or_img)
578
+ output_interm_image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
579
+ else:
580
+ output_interm_image = img_path_or_img
581
+ output_interm_bbox = None
582
+ else:
583
+ output_interm_image, output_interm_bbox = run_bbox_inference(img_path_or_img.copy())
584
+ # run barc inference
585
+ result_gltf = run_bite_inference(img_path_or_img, output_interm_bbox)
586
+ # add white border to image for nicer alignment
587
+ output_interm_image_vis = np.concatenate((255*np.ones_like(output_interm_image), output_interm_image, 255*np.ones_like(output_interm_image)), axis=1)
588
+ return [result_gltf, result_gltf, output_interm_image_vis]
589
+
590
+
591
+
592
+
593
+ ########################################################################################################################
594
+
595
+ # see: https://huggingface.co/spaces/radames/PIFu-Clothed-Human-Digitization/blob/main/PIFu/spaces.py
596
+
597
+ description = '''
598
+ # BITE
599
+
600
+ #### Project Page
601
+ * https://bite.is.tue.mpg.de/
602
+
603
+ #### Description
604
+ This is a demo for BITE (*B*eyond Priors for *I*mproved *T*hree-{D} Dog Pose *E*stimation).
605
+ You can either submit a cropped image or choose the option to run a pretrained Faster R-CNN in order to obtain a bounding box.
606
+ Please have a look at the examples below.
607
+ <details>
608
+
609
+ <summary>More</summary>
610
+
611
+ #### Citation
612
+
613
+ ```
614
+ @inproceedings{bite2023rueegg,
615
+ title = {{BITE}: Beyond Priors for Improved Three-{D} Dog Pose Estimation},
616
+ author = {R\"uegg, Nadine and Tripathi, Shashank and Schindler, Konrad and Black, Michael J. and Zuffi, Silvia},
617
+ booktitle = {IEEE/CVF Conf.~on Computer Vision and Pattern Recognition (CVPR)},
618
+ pages = {8867-8876},
619
+ year = {2023},
620
+ }
621
+ ```
622
+
623
+ #### Image Sources
624
+ * Stanford extra image dataset
625
+ * Images from google search engine
626
+ * https://www.dogtrainingnation.com/wp-content/uploads/2015/02/keep-dog-training-sessions-short.jpg
627
+ * https://thumbs.dreamstime.com/b/hund-und-seine-neue-hundeh%C3%BCtte-36757551.jpg
628
+ * https://www.mydearwhippet.com/wp-content/uploads/2021/04/whippet-temperament-2.jpg
629
+ * https://media.istockphoto.com/photos/ibizan-hound-at-the-shore-in-winter-picture-id1092705644?k=20&m=1092705644&s=612x612&w=0&h=ppwg92s9jI8GWnk22SOR_DWWNP8b2IUmLXSQmVey5Ss=
630
+
631
+
632
+ </details>
633
+ '''
634
+
635
+
636
+
637
+
638
+
639
+
640
+ example_images = sorted(glob.glob(os.path.join(os.path.dirname(__file__), '../', 'datasets', 'test_image_crops', '*.jpg')) + glob.glob(os.path.join(os.path.dirname(__file__), '../', 'datasets', 'test_image_crops', '*.png')))
641
+ random.shuffle(example_images)
642
+ # example_images.reverse()
643
+ # examples = [[img, "input image is cropped"] for img in example_images]
644
+ examples = []
645
+ for img in example_images:
646
+ if os.path.basename(img)[:2] == 'z_':
647
+ examples.append([img, "use Faster R-CNN to get a bounding box"])
648
+ else:
649
+ examples.append([img, "input image is cropped"])
650
+
651
+ demo = gr.Interface(
652
+ fn=run_complete_inference,
653
+ description=description,
654
+ # inputs=gr.Image(type="filepath", label="Input Image"),
655
+ inputs=[gr.Image(label="Input Image"),
656
+ gr.Radio(["input image is cropped", "use Faster R-CNN to get a bounding box"], value="use Faster R-CNN to get a bounding box", label="Crop Choice"),
657
+ ],
658
+ outputs=[
659
+ gr.Model3D(
660
+ clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
661
+ gr.File(label="Download 3D Model"),
662
+ gr.Image(label="Bounding Box (Faster R-CNN prediction)"),
663
+
664
+ ],
665
+ examples=examples,
666
+ thumbnail="bite_thumbnail.png",
667
+ allow_flagging="never",
668
+ cache_examples=True,
669
+ examples_per_page=14,
670
+ )
671
+
672
+ demo.launch(share=True)
src/__init__.py ADDED
File without changes
src/bps_2d/bps_for_segmentation.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # code idea from https://github.com/sergeyprokudin/bps
3
+
4
+ import os
5
+ import numpy as np
6
+ from PIL import Image
7
+ import time
8
+ import scipy
9
+ import scipy.spatial
10
+ import pymp
11
+
12
+
13
+ #####################
14
+ QUERY_POINTS = np.asarray([30, 34, 31, 55, 29, 84, 35, 108, 34, 145, 29, 171, 27,
15
+ 196, 29, 228, 58, 35, 61, 55, 57, 83, 56, 109, 63, 148, 58, 164, 57, 197, 60,
16
+ 227, 81, 26, 87, 58, 85, 87, 89, 117, 86, 142, 89, 172, 84, 197, 88, 227, 113,
17
+ 32, 116, 58, 112, 88, 118, 113, 109, 147, 114, 173, 119, 201, 113, 229, 139,
18
+ 29, 141, 59, 142, 93, 139, 117, 146, 147, 141, 173, 142, 201, 143, 227, 170,
19
+ 26, 173, 59, 166, 90, 174, 117, 176, 141, 169, 175, 167, 198, 172, 227, 198,
20
+ 30, 195, 59, 204, 85, 198, 116, 195, 140, 198, 175, 194, 193, 199, 227, 221,
21
+ 26, 223, 57, 227, 83, 227, 113, 227, 140, 226, 173, 230, 196, 228, 229]).reshape((64, 2))
22
+ #####################
23
+
24
+ class SegBPS():
25
+
26
+ def __init__(self, query_points=QUERY_POINTS, size=256):
27
+ self.size = size
28
+ self.query_points = query_points
29
+ row, col = np.indices((self.size, self.size))
30
+ self.indices_rc = np.stack((row, col), axis=2) # (256, 256, 2)
31
+ self.pts_aranged = np.arange(64)
32
+ return
33
+
34
+ def _do_kdtree(self, combined_x_y_arrays, points):
35
+ # see https://stackoverflow.com/questions/10818546/finding-index-of-nearest-
36
+ # point-in-numpy-arrays-of-x-and-y-coordinates
37
+ mytree = scipy.spatial.cKDTree(combined_x_y_arrays)
38
+ dist, indexes = mytree.query(points)
39
+ return indexes
40
+
41
+ def calculate_bps_points(self, seg, thr=0.5, vis=False, out_path=None):
42
+ # seg: input segmentation image of shape (256, 256) with values between 0 and 1
43
+ query_val = seg[self.query_points[:, 0], self.query_points[:, 1]]
44
+ pts_fg = self.pts_aranged[query_val>=thr]
45
+ pts_bg = self.pts_aranged[query_val<thr]
46
+ candidate_inds_bg = self.indices_rc[seg<thr]
47
+ candidate_inds_fg = self.indices_rc[seg>=thr]
48
+ if candidate_inds_bg.shape[0] == 0:
49
+ candidate_inds_bg = np.ones((1, 2)) * 128 # np.zeros((1, 2))
50
+ if candidate_inds_fg.shape[0] == 0:
51
+ candidate_inds_fg = np.ones((1, 2)) * 128 # np.zeros((1, 2))
52
+ # calculate nearest points
53
+ all_nearest_points = np.zeros((64, 2))
54
+ all_nearest_points[pts_fg, :] = candidate_inds_bg[self._do_kdtree(candidate_inds_bg, self.query_points[pts_fg, :]), :]
55
+ all_nearest_points[pts_bg, :] = candidate_inds_fg[self._do_kdtree(candidate_inds_fg, self.query_points[pts_bg, :]), :]
56
+ all_nearest_points_01 = all_nearest_points / 255.
57
+ if vis:
58
+ self.visualize_result(seg, all_nearest_points, out_path=out_path)
59
+ return all_nearest_points_01
60
+
61
+ def calculate_bps_points_batch(self, seg_batch, thr=0.5, vis=False, out_path=None):
62
+ # seg_batch: input segmentation image of shape (bs, 256, 256) with values between 0 and 1
63
+ bs = seg_batch.shape[0]
64
+ all_nearest_points_01_batch = np.zeros((bs, self.query_points.shape[0], 2))
65
+ for ind in range(0, bs): # 0.25
66
+ seg = seg_batch[ind, :, :]
67
+ all_nearest_points_01 = self.calculate_bps_points(seg, thr=thr, vis=vis, out_path=out_path)
68
+ all_nearest_points_01_batch[ind, :, :] = all_nearest_points_01
69
+ return all_nearest_points_01_batch
70
+
71
+ def visualize_result(self, seg, all_nearest_points, out_path=None):
72
+ import matplotlib as mpl
73
+ mpl.use('Agg')
74
+ import matplotlib.pyplot as plt
75
+ # img: (256, 256, 3)
76
+ img = (np.stack((seg, seg, seg), axis=2) * 155).astype(np.int)
77
+ if out_path is None:
78
+ ind_img = 0
79
+ out_path = '../test_img' + str(ind_img) + '.png'
80
+ fig, ax = plt.subplots()
81
+ plt.imshow(img)
82
+ plt.gca().set_axis_off()
83
+ plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0)
84
+ plt.margins(0,0)
85
+ ratio_in_out = 1 # 255
86
+ for idx, (y, x) in enumerate(self.query_points):
87
+ x = int(x*ratio_in_out)
88
+ y = int(y*ratio_in_out)
89
+ plt.scatter([x], [y], marker="x", s=50)
90
+ x2 = int(all_nearest_points[idx, 1])
91
+ y2 = int(all_nearest_points[idx, 0])
92
+ plt.scatter([x2], [y2], marker="o", s=50)
93
+ plt.plot([x, x2], [y, y2])
94
+ plt.savefig(out_path, bbox_inches='tight', pad_inches=0)
95
+ plt.close()
96
+ return
97
+
98
+
99
+
100
+
101
+
102
+ if __name__ == "__main__":
103
+ ind_img = 2 # 4
104
+ path_seg_top = '...../pytorch-stacked-hourglass/results/dogs_hg8_ks_24_v1/test/'
105
+ path_seg = os.path.join(path_seg_top, 'seg_big_' + str(ind_img) + '.png')
106
+ img = np.asarray(Image.open(path_seg))
107
+ # min is 0.004, max is 0.9
108
+ # low values are background, high values are foreground
109
+ seg = img[:, :, 1] / 255.
110
+ # calculate points
111
+ bps = SegBPS()
112
+ bps.calculate_bps_points(seg, thr=0.5, vis=False, out_path=None)
113
+
114
+
src/combined_model/__init__.py ADDED
File without changes
src/combined_model/helper.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.backends.cudnn
5
+ import torch.nn.parallel
6
+ from tqdm import tqdm
7
+ import os
8
+ import pathlib
9
+ from matplotlib import pyplot as plt
10
+ import cv2
11
+ import numpy as np
12
+ import torch
13
+ import trimesh
14
+
15
+ import sys
16
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
17
+ from stacked_hourglass.utils.evaluation import accuracy, AverageMeter, final_preds, get_preds, get_preds_soft
18
+ from stacked_hourglass.utils.visualization import save_input_image_with_keypoints, save_input_image
19
+ from metrics.metrics import Metrics
20
+ from configs.SMAL_configs import EVAL_KEYPOINTS, KEYPOINT_GROUPS
21
+
22
+
23
+ # GOAL: have all the functions from the validation and visual epoch together
24
+
25
+
26
+ '''
27
+ save_imgs_path = ...
28
+ prefix = ''
29
+ input # this is the image
30
+ data_info
31
+ target_dict
32
+ render_all
33
+ model
34
+
35
+
36
+ vertices_smal = output_reproj['vertices_smal']
37
+ flength = output_unnorm['flength']
38
+ hg_keyp_norm = output['keypoints_norm']
39
+ hg_keyp_scores = output['keypoints_scores']
40
+ betas = output_reproj['betas']
41
+ betas_limbs = output_reproj['betas_limbs']
42
+ zz = output_reproj['z']
43
+ pose_rotmat = output_unnorm['pose_rotmat']
44
+ trans = output_unnorm['trans']
45
+ pred_keyp = output_reproj['keyp_2d']
46
+ pred_silh = output_reproj['silh']
47
+ '''
48
+
49
+ #################################################
50
+
51
+ def eval_save_visualizations_and_meshes(model, input, data_info, target_dict, test_name_list, vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat, trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, render_all=False):
52
+ device = input.device
53
+ curr_batch_size = input.shape[0]
54
+ # render predicted 3d models
55
+ visualizations = model.render_vis_nograd(vertices=vertices_smal,
56
+ focal_lengths=flength,
57
+ color=0) # color=2)
58
+ for ind_img in range(len(target_dict['index'])):
59
+ try:
60
+ # import pdb; pdb.set_trace()
61
+ if test_name_list is not None:
62
+ img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_')
63
+ img_name = img_name.split('.')[0]
64
+ else:
65
+ img_name = str(index) + '_' + str(ind_img)
66
+ # save image with predicted keypoints
67
+ out_path = save_imgs_path + '/keypoints_pred_' + img_name + '.png'
68
+ pred_unp = (hg_keyp_norm[ind_img, :, :] + 1.) / 2 * (data_info.image_size - 1)
69
+ pred_unp_maxval = hg_keyp_scores[ind_img, :, :]
70
+ pred_unp_prep = torch.cat((pred_unp, pred_unp_maxval), 1)
71
+ inp_img = input[ind_img, :, :, :].detach().clone()
72
+ save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.1, print_scores=True, ratio_in_out=1.0) # threshold=0.3
73
+ # save predicted 3d model (front view)
74
+ pred_tex = visualizations[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
75
+ pred_tex_max = np.max(pred_tex, axis=2)
76
+ out_path = save_imgs_path + '/' + prefix + 'tex_pred_' + img_name + '.png'
77
+ plt.imsave(out_path, pred_tex)
78
+ input_image = input[ind_img, :, :, :].detach().clone()
79
+ for t, m, s in zip(input_image, data_info.rgb_mean, data_info.rgb_stddev): t.add_(m)
80
+ input_image_np = input_image.detach().cpu().numpy().transpose(1, 2, 0)
81
+ im_masked = cv2.addWeighted(input_image_np,0.2,pred_tex,0.8,0)
82
+ im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :]
83
+ out_path = save_imgs_path + '/' + prefix + 'comp_pred_' + img_name + '.png'
84
+ plt.imsave(out_path, im_masked)
85
+ # save predicted 3d model (side view)
86
+ vertices_cent = vertices_smal - vertices_smal.mean(dim=1)[:, None, :]
87
+ roll = np.pi / 2 * torch.ones(1).float().to(device)
88
+ pitch = np.pi / 2 * torch.ones(1).float().to(device)
89
+ tensor_0 = torch.zeros(1).float().to(device)
90
+ tensor_1 = torch.ones(1).float().to(device)
91
+ RX = torch.stack([torch.stack([tensor_1, tensor_0, tensor_0]), torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3,3)
92
+ RY = torch.stack([
93
+ torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]),
94
+ torch.stack([tensor_0, tensor_1, tensor_0]),
95
+ torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3,3)
96
+ vertices_rot = (torch.matmul(RY, vertices_cent.reshape((-1, 3))[:, :, None])).reshape((curr_batch_size, -1, 3))
97
+ vertices_rot[:, :, 2] = vertices_rot[:, :, 2] + torch.ones_like(vertices_rot[:, :, 2]) * 20 # 18 # *16
98
+
99
+ visualizations_rot = model.render_vis_nograd(vertices=vertices_rot,
100
+ focal_lengths=flength,
101
+ color=0) # 2)
102
+ pred_tex = visualizations_rot[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
103
+ pred_tex_max = np.max(pred_tex, axis=2)
104
+ out_path = save_imgs_path + '/' + prefix + 'rot_tex_pred_' + img_name + '.png'
105
+ plt.imsave(out_path, pred_tex)
106
+ if render_all:
107
+ # save input image
108
+ inp_img = input[ind_img, :, :, :].detach().clone()
109
+ out_path = save_imgs_path + '/image_' + img_name + '.png'
110
+ save_input_image(inp_img, out_path)
111
+ # save mesh
112
+ V_posed = vertices_smal[ind_img, :, :].detach().cpu().numpy()
113
+ Faces = model.smal.f
114
+ mesh_posed = trimesh.Trimesh(vertices=V_posed, faces=Faces, process=False, maintain_order=True)
115
+ mesh_posed.export(save_imgs_path + '/' + prefix + 'mesh_posed_' + img_name + '.obj')
116
+ except:
117
+ print('dont save an image')
118
+
119
+ ############
120
+
121
+ def eval_prepare_pck_and_iou(model, input, data_info, target_dict, test_name_list, vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat, trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, pck_thresh, progress=None, skip_pck_and_iou=False):
122
+ preds = {}
123
+ preds['betas'] = betas.cpu().detach().numpy()
124
+ preds['betas_limbs'] = betas_limbs.cpu().detach().numpy()
125
+ preds['z'] = zz.cpu().detach().numpy()
126
+ preds['pose_rotmat'] = pose_rotmat.cpu().detach().numpy()
127
+ preds['flength'] = flength.cpu().detach().numpy()
128
+ preds['trans'] = trans.cpu().detach().numpy()
129
+ preds['breed_index'] = target_dict['breed_index'].cpu().detach().numpy().reshape((-1))
130
+ img_names = []
131
+ for ind_img2 in range(0, betas.shape[0]):
132
+ if test_name_list is not None:
133
+ img_name2 = test_name_list[int(target_dict['index'][ind_img2].cpu().detach().numpy())].replace('/', '_')
134
+ img_name2 = img_name2.split('.')[0]
135
+ else:
136
+ img_name2 = str(index) + '_' + str(ind_img2)
137
+ img_names.append(img_name2)
138
+ preds['image_names'] = img_names
139
+ if not skip_pck_and_iou:
140
+ # prepare keypoints for PCK calculation - predicted as well as ground truth
141
+ # pred_keyp = output_reproj['keyp_2d'] # 256
142
+ gt_keypoints_256 = target_dict['tpts'][:, :, :2] / 64. * (256. - 1)
143
+ # gt_keypoints_norm = gt_keypoints_256 / 256 / 0.5 - 1
144
+ gt_keypoints = torch.cat((gt_keypoints_256, target_dict['tpts'][:, :, 2:3]), dim=2) # gt_keypoints_norm
145
+ # prepare silhouette for IoU calculation - predicted as well as ground truth
146
+ has_seg = target_dict['has_seg']
147
+ img_border_mask = target_dict['img_border_mask'][:, 0, :, :]
148
+ gtseg = target_dict['silh']
149
+ synth_silhouettes = pred_silh[:, 0, :, :] # output_reproj['silh']
150
+ synth_silhouettes[synth_silhouettes>0.5] = 1
151
+ synth_silhouettes[synth_silhouettes<0.5] = 0
152
+ # calculate PCK as well as IoU (similar to WLDO)
153
+ preds['acc_PCK'] = Metrics.PCK(
154
+ pred_keyp, gt_keypoints,
155
+ gtseg, has_seg, idxs=EVAL_KEYPOINTS,
156
+ thresh_range=[pck_thresh], # [0.15],
157
+ )
158
+ preds['acc_IOU'] = Metrics.IOU(
159
+ synth_silhouettes, gtseg,
160
+ img_border_mask, mask=has_seg
161
+ )
162
+ for group, group_kps in KEYPOINT_GROUPS.items():
163
+ preds[f'{group}_PCK'] = Metrics.PCK(
164
+ pred_keyp, gt_keypoints, gtseg, has_seg,
165
+ thresh_range=[pck_thresh], # [0.15],
166
+ idxs=group_kps
167
+ )
168
+ return preds
169
+
170
+
171
+ # preds['acc_PCK'] = Metrics.PCK(pred_keyp, gt_keypoints, gtseg, has_seg, idxs=EVAL_KEYPOINTS, thresh_range=[pck_thresh])
172
+ # preds['acc_IOU'] = Metrics.IOU(synth_silhouettes, gtseg, img_border_mask, mask=has_seg)
173
+ #############################
174
+
175
+ def eval_add_preds_to_summary(summary, preds, my_step, batch_size, curr_batch_size, skip_pck_and_iou=False):
176
+ if not skip_pck_and_iou:
177
+ if not (preds['acc_PCK'].data.cpu().numpy().shape == (summary['pck'][my_step * batch_size:my_step * batch_size + curr_batch_size]).shape):
178
+ import pdb; pdb.set_trace()
179
+ summary['pck'][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['acc_PCK'].data.cpu().numpy()
180
+ summary['acc_sil_2d'][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['acc_IOU'].data.cpu().numpy()
181
+ for part in summary['pck_by_part']:
182
+ summary['pck_by_part'][part][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds[f'{part}_PCK'].data.cpu().numpy()
183
+ summary['betas'][my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['betas']
184
+ summary['betas_limbs'][my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['betas_limbs']
185
+ summary['z'][my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['z']
186
+ summary['pose_rotmat'][my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['pose_rotmat']
187
+ summary['flength'][my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['flength']
188
+ summary['trans'][my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['trans']
189
+ summary['breed_indices'][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['breed_index']
190
+ summary['image_names'].extend(preds['image_names'])
191
+ return
192
+
193
+
194
+ def get_triangle_faces_from_pyvista_poly(poly):
195
+ """Fetch all triangle faces."""
196
+ stream = poly.faces
197
+ tris = []
198
+ i = 0
199
+ while i < len(stream):
200
+ n = stream[i]
201
+ if n != 3:
202
+ i += n + 1
203
+ continue
204
+ stop = i + n + 1
205
+ tris.append(stream[i+1:stop])
206
+ i = stop
207
+ return np.array(tris)
src/combined_model/helper3.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+
4
+ def get_triangle_faces_from_pyvista_poly(poly):
5
+ """Fetch all triangle faces."""
6
+ stream = poly.faces
7
+ tris = []
8
+ i = 0
9
+ while i < len(stream):
10
+ n = stream[i]
11
+ if n != 3:
12
+ i += n + 1
13
+ continue
14
+ stop = i + n + 1
15
+ tris.append(stream[i+1:stop])
16
+ i = stop
17
+ return np.array(tris)
src/combined_model/loss_image_to_3d_refinement.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import torch
4
+ import numpy as np
5
+ import pickle as pkl
6
+
7
+ import os
8
+ import sys
9
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
10
+ # from priors.pose_prior_35 import Prior
11
+ # from priors.tiger_pose_prior.tiger_pose_prior import GaussianMixturePrior
12
+ from priors.normalizing_flow_prior.normalizing_flow_prior import NormalizingFlowPrior
13
+ from priors.shape_prior import ShapePrior
14
+ from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, batch_rot2aa, geodesic_loss_R
15
+ from combined_model.loss_utils.loss_utils import leg_sideway_error, leg_torsion_error, tail_sideway_error, tail_torsion_error, spine_torsion_error, spine_sideway_error
16
+ from combined_model.loss_utils.loss_utils_gc import LossGConMesh, calculate_plane_errors_batch
17
+
18
+ from priors.shape_prior import ShapePrior
19
+ from configs.SMAL_configs import SMAL_MODEL_CONFIG
20
+
21
+ from priors.helper_3dcgmodel_loss import load_dog_betas_for_3dcgmodel_loss
22
+
23
+
24
+ class LossRef(torch.nn.Module):
25
+ def __init__(self, smal_model_type, data_info, nf_version=None):
26
+ super(LossRef, self).__init__()
27
+ self.criterion_regr = torch.nn.MSELoss() # takes the mean
28
+ self.criterion_class = torch.nn.CrossEntropyLoss()
29
+
30
+ class_weights_isflat = torch.tensor([12, 2])
31
+ self.criterion_class_isflat = torch.nn.CrossEntropyLoss(weight=class_weights_isflat)
32
+ self.criterion_l1 = torch.nn.L1Loss()
33
+ self.geodesic_loss = geodesic_loss_R(reduction='mean')
34
+ self.gc_loss_on_mesh = LossGConMesh()
35
+ self.data_info = data_info
36
+ self.smal_model_type = smal_model_type
37
+ self.register_buffer('keypoint_weights', torch.tensor(data_info.keypoint_weights)[None, :])
38
+ # if nf_version is not None:
39
+ # self.normalizing_flow_pose_prior = NormalizingFlowPrior(nf_version=nf_version)
40
+
41
+ self.smal_model_data_path = SMAL_MODEL_CONFIG[self.smal_model_type]['smal_model_data_path']
42
+ self.shape_prior = ShapePrior(self.smal_model_data_path) # here we just need mean and cov
43
+
44
+ remeshing_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data_remeshed/uniform_surface_sampling/my_smpl_39dogsnorm_Jr_4_dog_remesh4000_info.pkl'
45
+ with open(remeshing_path, 'rb') as fp:
46
+ self.remeshing_dict = pkl.load(fp)
47
+ self.remeshing_relevant_faces = torch.tensor(self.remeshing_dict['smal_faces'][self.remeshing_dict['faceid_closest']], dtype=torch.long)
48
+ self.remeshing_relevant_barys = torch.tensor(self.remeshing_dict['barys_closest'], dtype=torch.float32)
49
+
50
+
51
+
52
+ # load 3d data for the unity dogs (an optional shape prior for 11 breeds)
53
+ self.unity_smal_shape_prior_dogs = SMAL_MODEL_CONFIG[self.smal_model_type]['unity_smal_shape_prior_dogs']
54
+ if self.unity_smal_shape_prior_dogs is not None:
55
+ self.dog_betas_unity = load_dog_betas_for_3dcgmodel_loss(self.unity_smal_shape_prior_dogs, self.smal_model_type)
56
+ else:
57
+ self.dog_betas_unity = None
58
+
59
+
60
+
61
+
62
+
63
+
64
+
65
+ def forward(self, output_ref, output_ref_comp, target_dict, weight_dict_ref):
66
+ # output_reproj: ['vertices_smal', 'keyp_3d', 'keyp_2d', 'silh_image']
67
+ # target_dict: ['index', 'center', 'scale', 'pts', 'tpts', 'target_weight']
68
+ batch_size = output_ref['keyp_2d'].shape[0]
69
+ loss_dict_temp = {}
70
+
71
+ # loss on reprojected keypoints
72
+ output_kp_resh = (output_ref['keyp_2d']).reshape((-1, 2))
73
+ target_kp_resh = (target_dict['tpts'][:, :, :2] / 64. * (256. - 1)).reshape((-1, 2))
74
+ weights_resh = target_dict['tpts'][:, :, 2].reshape((-1))
75
+ keyp_w_resh = self.keypoint_weights.repeat((batch_size, 1)).reshape((-1))
76
+ loss_dict_temp['keyp_ref'] = ((((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt()*weights_resh[weights_resh>0])*keyp_w_resh[weights_resh>0]).sum() / \
77
+ max((weights_resh[weights_resh>0]*keyp_w_resh[weights_resh>0]).sum(), 1e-5)
78
+
79
+ # loss on reprojected silhouette
80
+ assert output_ref['silh'].shape == (target_dict['silh'][:, None, :, :]).shape
81
+ silh_loss_type = 'default'
82
+ if silh_loss_type == 'default':
83
+ with torch.no_grad():
84
+ thr_silh = 20
85
+ diff = torch.norm(output_kp_resh - target_kp_resh, dim=1)
86
+ diff_x = diff.reshape((batch_size, -1))
87
+ weights_resh_x = weights_resh.reshape((batch_size, -1))
88
+ unweighted_kp_mean_dist = (diff_x * weights_resh_x).sum(dim=1) / ((weights_resh_x).sum(dim=1)+1e-6)
89
+ loss_silh_bs = ((output_ref['silh'] - target_dict['silh'][:, None, :, :]) ** 2).sum(axis=3).sum(axis=2).sum(axis=1) / (output_ref['silh'].shape[2]*output_ref['silh'].shape[3])
90
+ loss_dict_temp['silh_ref'] = loss_silh_bs[unweighted_kp_mean_dist<thr_silh].sum() / batch_size
91
+ else:
92
+ print('silh_loss_type: ' + silh_loss_type)
93
+ raise ValueError
94
+
95
+ # regularization: losses on difference between previous prediction and refinement
96
+ loss_dict_temp['reg_trans'] = self.criterion_l1(output_ref_comp['ref_trans_notnorm'], output_ref_comp['old_trans_notnorm'].detach()) * 3
97
+ loss_dict_temp['reg_flength'] = self.criterion_l1(output_ref_comp['ref_flength_notnorm'], output_ref_comp['old_flength_notnorm'].detach()) * 1
98
+ loss_dict_temp['reg_pose'] = self.geodesic_loss(output_ref_comp['ref_pose_rotmat'], output_ref_comp['old_pose_rotmat'].detach()) * 35 * 6
99
+
100
+ # pose priors on refined pose
101
+ loss_dict_temp['pose_legs_side'] = leg_sideway_error(output_ref['pose_rotmat'])
102
+ loss_dict_temp['pose_legs_tors'] = leg_torsion_error(output_ref['pose_rotmat'])
103
+ loss_dict_temp['pose_tail_side'] = tail_sideway_error(output_ref['pose_rotmat'])
104
+ loss_dict_temp['pose_tail_tors'] = tail_torsion_error(output_ref['pose_rotmat'])
105
+ loss_dict_temp['pose_spine_side'] = spine_sideway_error(output_ref['pose_rotmat'])
106
+ loss_dict_temp['pose_spine_tors'] = spine_torsion_error(output_ref['pose_rotmat'])
107
+
108
+ # loss to predict ground contact per vertex
109
+ # import pdb; pdb.set_trace()
110
+ if 'gc_vertexwise' in weight_dict_ref.keys():
111
+ # import pdb; pdb.set_trace()
112
+ device = output_ref['vertexwise_ground_contact'].device
113
+ pred_gc = output_ref['vertexwise_ground_contact']
114
+ loss_dict_temp['gc_vertexwise'] = self.gc_loss_on_mesh(pred_gc, target_dict['gc'].to(device=device, dtype=torch.long), target_dict['has_gc'], loss_type_gcmesh='ce')
115
+
116
+ keep_smal_mesh = False
117
+ if 'gc_plane' in weight_dict_ref.keys():
118
+ if weight_dict_ref['gc_plane'] > 0:
119
+ if keep_smal_mesh:
120
+ target_gc_class = target_dict['gc'][:, :, 0]
121
+ gc_errors_plane = calculate_plane_errors_batch(output_ref['vertices_smal'], target_gc_class, target_dict['has_gc'], target_dict['has_gc_is_touching'])
122
+ loss_dict_temp['gc_plane'] = torch.mean(gc_errors_plane)
123
+ else: # use a uniformly sampled mesh
124
+ target_gc_class = target_dict['gc'][:, :, 0]
125
+ device = output_ref['vertices_smal'].device
126
+ remeshing_relevant_faces = self.remeshing_relevant_faces.to(device)
127
+ remeshing_relevant_barys = self.remeshing_relevant_barys.to(device)
128
+
129
+ bs = output_ref['vertices_smal'].shape[0]
130
+ # verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, output_ref['vertices_smal'][:, self.remeshing_relevant_faces])
131
+ # sel_verts_comparison = output_ref['vertices_smal'][:, self.remeshing_relevant_faces]
132
+ # verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts_comparison)
133
+ sel_verts = torch.index_select(output_ref['vertices_smal'], dim=1, index=remeshing_relevant_faces.reshape((-1))).reshape((bs, remeshing_relevant_faces.shape[0], 3, 3))
134
+ verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts)
135
+ target_gc_class_remeshed = torch.einsum('ij,aij->ai', remeshing_relevant_barys, target_gc_class[:, self.remeshing_relevant_faces].to(device=device, dtype=torch.float32))
136
+ target_gc_class_remeshed_prep = torch.round(target_gc_class_remeshed).to(torch.long)
137
+ gc_errors_plane, gc_errors_under_plane = calculate_plane_errors_batch(verts_remeshed, target_gc_class_remeshed_prep, target_dict['has_gc'], target_dict['has_gc_is_touching'])
138
+ loss_dict_temp['gc_plane'] = torch.mean(gc_errors_plane)
139
+ loss_dict_temp['gc_blowplane'] = torch.mean(gc_errors_under_plane)
140
+
141
+ # error on classification if the ground plane is flat
142
+ if 'gc_isflat' in weight_dict_ref.keys():
143
+ # import pdb; pdb.set_trace()
144
+ self.criterion_class_isflat.to(device)
145
+ loss_dict_temp['gc_isflat'] = self.criterion_class(output_ref['isflat'], target_dict['isflat'].to(device))
146
+
147
+ # if we refine the shape WITHIN the refinement newtork (shaperef_type is not inexistent)
148
+ # shape regularization
149
+ # 'smal': loss on betas (pca coefficients), betas should be close to 0
150
+ # 'limbs...' loss on selected betas_limbs
151
+ device = output_ref_comp['ref_trans_notnorm'].device
152
+ loss_shape_weighted_list = [torch.zeros((1), device=device).mean()]
153
+ if 'shape_options' in weight_dict_ref.keys():
154
+ for ind_sp, sp in enumerate(weight_dict_ref['shape_options']):
155
+ weight_sp = weight_dict_ref['shape'][ind_sp]
156
+ # self.logscale_part_list = ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l']
157
+ if sp == 'smal':
158
+ loss_shape_tmp = self.shape_prior(output_ref['betas'])
159
+ elif sp == 'limbs':
160
+ loss_shape_tmp = torch.mean((output_ref['betas_limbs'])**2)
161
+ elif sp == 'limbs7':
162
+ limb_coeffs_list = [0.01, 1, 0.1, 1, 1, 0.1, 2]
163
+ limb_coeffs = torch.tensor(limb_coeffs_list).to(torch.float32).to(target_dict['tpts'].device)
164
+ loss_shape_tmp = torch.mean((output_ref['betas_limbs'] * limb_coeffs[None, :])**2)
165
+ else:
166
+ raise NotImplementedError
167
+ loss_shape_weighted_list.append(weight_sp * loss_shape_tmp)
168
+ loss_shape_weighted = torch.stack((loss_shape_weighted_list)).sum()
169
+
170
+
171
+
172
+
173
+
174
+ # 3D loss for dogs for which we have a unity model or toy figure
175
+ loss_dict_temp['models3d'] = torch.zeros((1), device=device).mean().to(output_ref['betas'].device)
176
+ if 'models3d' in weight_dict_ref.keys():
177
+ if weight_dict_ref['models3d'] > 0:
178
+ assert (self.dog_betas_unity is not None)
179
+ if weight_dict_ref['models3d'] > 0:
180
+ for ind_dog in range(target_dict['breed_index'].shape[0]):
181
+ breed_index = np.asscalar(target_dict['breed_index'][ind_dog].detach().cpu().numpy())
182
+ if breed_index in self.dog_betas_unity.keys():
183
+ betas_target = self.dog_betas_unity[breed_index][:output_ref['betas'].shape[1]].to(output_ref['betas'].device)
184
+ betas_output = output_ref['betas'][ind_dog, :]
185
+ betas_limbs_output = output_ref['betas_limbs'][ind_dog, :]
186
+ loss_dict_temp['models3d'] += ((betas_limbs_output**2).sum() + ((betas_output-betas_target)**2).sum()) / (output_ref['betas'].shape[1] + output_ref['betas_limbs'].shape[1])
187
+ else:
188
+ weight_dict_ref['models3d'] = 0.0
189
+ else:
190
+ weight_dict_ref['models3d'] = 0.0
191
+
192
+
193
+
194
+
195
+
196
+
197
+
198
+
199
+
200
+
201
+
202
+ # weight the losses
203
+ loss = torch.zeros((1)).mean().to(device=output_ref['keyp_2d'].device, dtype=output_ref['keyp_2d'].dtype)
204
+ loss_dict = {}
205
+ for loss_name in weight_dict_ref.keys():
206
+ if not loss_name in ['shape', 'shape_options']:
207
+ if weight_dict_ref[loss_name] > 0:
208
+ loss_weighted = loss_dict_temp[loss_name] * weight_dict_ref[loss_name]
209
+ loss_dict[loss_name] = loss_weighted.item()
210
+ loss += loss_weighted
211
+ loss += loss_shape_weighted
212
+ loss_dict['loss'] = loss.item()
213
+
214
+ return loss, loss_dict
215
+
216
+
src/combined_model/loss_image_to_3d_withbreedrel.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import torch
4
+ import numpy as np
5
+ import pickle as pkl
6
+
7
+ import os
8
+ import sys
9
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
10
+ # from priors.pose_prior_35 import Prior
11
+ # from priors.tiger_pose_prior.tiger_pose_prior import GaussianMixturePrior
12
+ from priors.normalizing_flow_prior.normalizing_flow_prior import NormalizingFlowPrior
13
+ from priors.shape_prior import ShapePrior
14
+ from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, batch_rot2aa
15
+ # from configs.SMAL_configs import SMAL_MODEL_DATA_PATH, UNITY_SMAL_SHAPE_PRIOR_DOGS, SMAL_MODEL_TYPE
16
+ from configs.SMAL_configs import SMAL_MODEL_CONFIG
17
+
18
+ from priors.helper_3dcgmodel_loss import load_dog_betas_for_3dcgmodel_loss
19
+ from combined_model.loss_utils.loss_utils_gc import calculate_plane_errors_batch
20
+
21
+
22
+
23
+ class Loss(torch.nn.Module):
24
+ def __init__(self, smal_model_type, data_info, nf_version=None):
25
+ super(Loss, self).__init__()
26
+ self.criterion_regr = torch.nn.MSELoss() # takes the mean
27
+ self.criterion_class = torch.nn.CrossEntropyLoss()
28
+ self.data_info = data_info
29
+ self.register_buffer('keypoint_weights', torch.tensor(data_info.keypoint_weights)[None, :])
30
+ self.l_anchor = None
31
+ self.l_pos = None
32
+ self.l_neg = None
33
+ self.smal_model_type = smal_model_type
34
+ self.smal_model_data_path = SMAL_MODEL_CONFIG[self.smal_model_type]['smal_model_data_path']
35
+ self.unity_smal_shape_prior_dogs = SMAL_MODEL_CONFIG[self.smal_model_type]['unity_smal_shape_prior_dogs']
36
+
37
+ if nf_version is not None:
38
+ self.normalizing_flow_pose_prior = NormalizingFlowPrior(nf_version=nf_version)
39
+ self.shape_prior = ShapePrior(self.smal_model_data_path) # here we just need mean and cov
40
+ self.criterion_triplet = torch.nn.TripletMarginLoss(margin=1)
41
+
42
+ # load 3d data for the unity dogs (an optional shape prior for 11 breeds)
43
+ if self.unity_smal_shape_prior_dogs is not None:
44
+ self.dog_betas_unity = load_dog_betas_for_3dcgmodel_loss(self.unity_smal_shape_prior_dogs, self.smal_model_type)
45
+ else:
46
+ self.dog_betas_unity = None
47
+
48
+ remeshing_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data_remeshed/uniform_surface_sampling/my_smpl_39dogsnorm_Jr_4_dog_remesh4000_info.pkl'
49
+ with open(remeshing_path, 'rb') as fp:
50
+ self.remeshing_dict = pkl.load(fp)
51
+ self.remeshing_relevant_faces = torch.tensor(self.remeshing_dict['smal_faces'][self.remeshing_dict['faceid_closest']], dtype=torch.long)
52
+ self.remeshing_relevant_barys = torch.tensor(self.remeshing_dict['barys_closest'], dtype=torch.float32)
53
+
54
+
55
+ def prepare_anchor_pos_neg(self, batch_size, device):
56
+ l0 = np.arange(0, batch_size, 2)
57
+ l_anchor = []
58
+ l_pos = []
59
+ l_neg = []
60
+ for ind in l0:
61
+ xx = set(np.arange(0, batch_size))
62
+ xx.discard(ind)
63
+ xx.discard(ind+1)
64
+ for ind2 in xx:
65
+ if ind2 % 2 == 0:
66
+ l_anchor.append(ind)
67
+ l_pos.append(ind + 1)
68
+ else:
69
+ l_anchor.append(ind + 1)
70
+ l_pos.append(ind)
71
+ l_neg.append(ind2)
72
+ self.l_anchor = torch.Tensor(l_anchor).to(torch.int64).to(device)
73
+ self.l_pos = torch.Tensor(l_pos).to(torch.int64).to(device)
74
+ self.l_neg = torch.Tensor(l_neg).to(torch.int64).to(device)
75
+ return
76
+
77
+
78
+ def forward(self, output_reproj, target_dict, weight_dict=None):
79
+
80
+ # output_reproj: ['vertices_smal', 'keyp_3d', 'keyp_2d', 'silh_image']
81
+ # target_dict: ['index', 'center', 'scale', 'pts', 'tpts', 'target_weight']
82
+ batch_size = output_reproj['keyp_2d'].shape[0]
83
+ device = output_reproj['keyp_2d'].device
84
+
85
+ # loss on reprojected keypoints
86
+ output_kp_resh = (output_reproj['keyp_2d']).reshape((-1, 2))
87
+ target_kp_resh = (target_dict['tpts'][:, :, :2] / 64. * (256. - 1)).reshape((-1, 2))
88
+ weights_resh = target_dict['tpts'][:, :, 2].reshape((-1))
89
+ keyp_w_resh = self.keypoint_weights.repeat((batch_size, 1)).reshape((-1))
90
+ loss_keyp = ((((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt()*weights_resh[weights_resh>0])*keyp_w_resh[weights_resh>0]).sum() / \
91
+ max((weights_resh[weights_resh>0]*keyp_w_resh[weights_resh>0]).sum(), 1e-5)
92
+
93
+ # loss on reprojected silhouette
94
+ assert output_reproj['silh'].shape == (target_dict['silh'][:, None, :, :]).shape
95
+ silh_loss_type = 'default'
96
+ if silh_loss_type == 'default':
97
+ with torch.no_grad():
98
+ thr_silh = 20
99
+ diff = torch.norm(output_kp_resh - target_kp_resh, dim=1)
100
+ diff_x = diff.reshape((batch_size, -1))
101
+ weights_resh_x = weights_resh.reshape((batch_size, -1))
102
+ unweighted_kp_mean_dist = (diff_x * weights_resh_x).sum(dim=1) / ((weights_resh_x).sum(dim=1)+1e-6)
103
+ loss_silh_bs = ((output_reproj['silh'] - target_dict['silh'][:, None, :, :]) ** 2).sum(axis=3).sum(axis=2).sum(axis=1) / (output_reproj['silh'].shape[2]*output_reproj['silh'].shape[3])
104
+ loss_silh = loss_silh_bs[unweighted_kp_mean_dist<thr_silh].sum() / batch_size
105
+ else:
106
+ print('silh_loss_type: ' + silh_loss_type)
107
+ raise ValueError
108
+
109
+ # shape regularization
110
+ # 'smal': loss on betas (pca coefficients), betas should be close to 0
111
+ # 'limbs...' loss on selected betas_limbs
112
+ loss_shape_weighted_list = [torch.zeros((1), device=device).mean().to(output_reproj['keyp_2d'].device)]
113
+ for ind_sp, sp in enumerate(weight_dict['shape_options']):
114
+ weight_sp = weight_dict['shape'][ind_sp]
115
+ # self.logscale_part_list = ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l']
116
+ if sp == 'smal':
117
+ loss_shape_tmp = self.shape_prior(output_reproj['betas'])
118
+ elif sp == 'limbs':
119
+ loss_shape_tmp = torch.mean((output_reproj['betas_limbs'])**2)
120
+ elif sp == 'limbs7':
121
+ limb_coeffs_list = [0.01, 1, 0.1, 1, 1, 0.1, 2]
122
+ limb_coeffs = torch.tensor(limb_coeffs_list).to(torch.float32).to(target_dict['tpts'].device)
123
+ loss_shape_tmp = torch.mean((output_reproj['betas_limbs'] * limb_coeffs[None, :])**2)
124
+ else:
125
+ raise NotImplementedError
126
+ loss_shape_weighted_list.append(weight_sp * loss_shape_tmp)
127
+ loss_shape_weighted = torch.stack((loss_shape_weighted_list)).sum()
128
+
129
+ # 3D loss for dogs for which we have a unity model or toy figure
130
+ loss_models3d = torch.zeros((1), device=device).mean().to(output_reproj['betas'].device)
131
+ if 'models3d' in weight_dict.keys():
132
+ if weight_dict['models3d'] > 0:
133
+ assert (self.dog_betas_unity is not None)
134
+ if weight_dict['models3d'] > 0:
135
+ for ind_dog in range(target_dict['breed_index'].shape[0]):
136
+ breed_index = np.asscalar(target_dict['breed_index'][ind_dog].detach().cpu().numpy())
137
+ if breed_index in self.dog_betas_unity.keys():
138
+ betas_target = self.dog_betas_unity[breed_index][:output_reproj['betas'].shape[1]].to(output_reproj['betas'].device)
139
+ betas_output = output_reproj['betas'][ind_dog, :]
140
+ betas_limbs_output = output_reproj['betas_limbs'][ind_dog, :]
141
+ loss_models3d += ((betas_limbs_output**2).sum() + ((betas_output-betas_target)**2).sum()) / (output_reproj['betas'].shape[1] + output_reproj['betas_limbs'].shape[1])
142
+ else:
143
+ weight_dict['models3d'] = 0.0
144
+ else:
145
+ weight_dict['models3d'] = 0.0
146
+
147
+ # shape resularization loss on shapedirs
148
+ # -> in the current version shapedirs are kept fixed, so we don't need those losses
149
+ if weight_dict['shapedirs'] > 0:
150
+ raise NotImplementedError
151
+ else:
152
+ loss_shapedirs = torch.zeros((1), device=device).mean().to(output_reproj['betas'].device)
153
+
154
+ # prior on back joints (not used in cvpr 2022 paper)
155
+ # -> elementwise MSE loss on all 6 coefficients of 6d rotation representation
156
+ if 'pose_0' in weight_dict.keys():
157
+ if weight_dict['pose_0'] > 0:
158
+ pred_pose_rot6d = output_reproj['pose_rot6d']
159
+ w_rj_np = np.zeros((pred_pose_rot6d.shape[1]))
160
+ w_rj_np[[2, 3, 4, 5]] = 1.0 # back
161
+ w_rj = torch.tensor(w_rj_np).to(torch.float32).to(pred_pose_rot6d.device)
162
+ zero_rot = torch.tensor([1, 0, 0, 1, 0, 0]).to(pred_pose_rot6d.device).to(torch.float32)[None, None, :].repeat((batch_size, pred_pose_rot6d.shape[1], 1))
163
+ loss_pose = self.criterion_regr(pred_pose_rot6d*w_rj[None, :, None], zero_rot*w_rj[None, :, None])
164
+ else:
165
+ loss_pose = torch.zeros((1), device=device).mean()
166
+
167
+ # pose prior
168
+ # -> we did experiment with different pose priors, for example:
169
+ # * similart to SMALify (https://github.com/benjiebob/SMALify/blob/master/smal_fitter/smal_fitter.py,
170
+ # https://github.com/benjiebob/SMALify/blob/master/smal_fitter/priors/pose_prior_35.py)
171
+ # * vae
172
+ # * normalizing flow pose prior
173
+ # -> our cvpr 2022 paper uses the normalizing flow pose prior as implemented below
174
+ if 'poseprior' in weight_dict.keys():
175
+ if weight_dict['poseprior'] > 0:
176
+ pred_pose_rot6d = output_reproj['pose_rot6d']
177
+ pred_pose = rot6d_to_rotmat(pred_pose_rot6d.reshape((-1, 6))).reshape((batch_size, -1, 3, 3))
178
+ if 'normalizing_flow_tiger' in weight_dict['poseprior_options']:
179
+ if output_reproj['normflow_z'] is not None:
180
+ loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss_from_z(output_reproj['normflow_z'], type='square')
181
+ else:
182
+ loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss(pred_pose_rot6d, type='square')
183
+ elif 'normalizing_flow_tiger_logprob' in weight_dict['poseprior_options']:
184
+ if output_reproj['normflow_z'] is not None:
185
+ loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss_from_z(output_reproj['normflow_z'], type='neg_log_prob')
186
+ else:
187
+ loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss(pred_pose_rot6d, type='neg_log_prob')
188
+ else:
189
+ raise NotImplementedError
190
+ else:
191
+ loss_poseprior = torch.zeros((1), device=device).mean()
192
+ else:
193
+ weight_dict['poseprior'] = 0
194
+ loss_poseprior = torch.zeros((1), device=device).mean()
195
+
196
+ # add a prior which penalizes side-movement angles for legs
197
+ if 'poselegssidemovement' in weight_dict.keys():
198
+ if weight_dict['poselegssidemovement'] > 0:
199
+ use_pose_legs_side_loss = True
200
+ else:
201
+ use_pose_legs_side_loss = False
202
+ else:
203
+ use_pose_legs_side_loss = False
204
+ if use_pose_legs_side_loss:
205
+ leg_indices_right = np.asarray([7, 8, 9, 10, 17, 18, 19, 20]) # front, back
206
+ leg_indices_left = np.asarray([11, 12, 13, 14, 21, 22, 23, 24]) # front, back
207
+ vec = torch.zeros((3, 1)).to(device=pred_pose.device, dtype=pred_pose.dtype)
208
+ vec[2] = -1
209
+ x0_rotmat = pred_pose
210
+ x0_rotmat_legs_left = x0_rotmat[:, leg_indices_left, :, :]
211
+ x0_rotmat_legs_right = x0_rotmat[:, leg_indices_right, :, :]
212
+ x0_legs_left = x0_rotmat_legs_left.reshape((-1, 3, 3))@vec
213
+ x0_legs_right = x0_rotmat_legs_right.reshape((-1, 3, 3))@vec
214
+ eps=0 # 1e-7
215
+ # use the component of the vector which points to the side
216
+ loss_poselegssidemovement = (x0_legs_left[:, 1]**2).mean() + (x0_legs_right[:, 1]**2).mean()
217
+ else:
218
+ loss_poselegssidemovement = torch.zeros((1), device=device).mean()
219
+ weight_dict['poselegssidemovement'] = 0
220
+
221
+ # dog breed classification loss
222
+ dog_breed_gt = target_dict['breed_index']
223
+ dog_breed_pred = output_reproj['dog_breed']
224
+ loss_class = self.criterion_class(dog_breed_pred, dog_breed_gt)
225
+
226
+ # dog breed relationship loss
227
+ # -> we did experiment with many other options, but none was significantly better
228
+ if '4' in weight_dict['breed_options']: # we have pairs of dogs of the same breed
229
+ if weight_dict['breed'] > 0:
230
+ assert output_reproj['dog_breed'].shape[0] == 12
231
+ # assert weight_dict['breed'] > 0
232
+ z = output_reproj['z']
233
+ # go through all pairs and compare them to each other sample
234
+ if self.l_anchor is None:
235
+ self.prepare_anchor_pos_neg(batch_size, z.device)
236
+ anchor = torch.index_select(z, 0, self.l_anchor)
237
+ positive = torch.index_select(z, 0, self.l_pos)
238
+ negative = torch.index_select(z, 0, self.l_neg)
239
+ loss_breed = self.criterion_triplet(anchor, positive, negative)
240
+ else:
241
+ loss_breed = torch.zeros((1), device=device).mean()
242
+ else:
243
+ loss_breed = torch.zeros((1), device=device).mean()
244
+
245
+ # regularizarion for focal length
246
+ loss_flength_near_mean = torch.mean(output_reproj['flength']**2)
247
+ loss_flength = loss_flength_near_mean
248
+
249
+ # bodypart segmentation loss
250
+ if 'partseg' in weight_dict.keys():
251
+ if weight_dict['partseg'] > 0:
252
+ raise NotImplementedError
253
+ else:
254
+ loss_partseg = torch.zeros((1), device=device).mean()
255
+ else:
256
+ weight_dict['partseg'] = 0
257
+ loss_partseg = torch.zeros((1), device=device).mean()
258
+
259
+
260
+ # NEW: ground contact loss for main network
261
+ keep_smal_mesh = False
262
+ if 'gc_plane' in weight_dict.keys():
263
+ if weight_dict['gc_plane'] > 0:
264
+ if keep_smal_mesh:
265
+ target_gc_class = target_dict['gc'][:, :, 0]
266
+ gc_errors_plane = calculate_plane_errors_batch(output_reproj['vertices_smal'], target_gc_class, target_dict['has_gc'], target_dict['has_gc_is_touching'])
267
+ loss_gc_plane = torch.mean(gc_errors_plane)
268
+ else: # use a uniformly sampled mesh
269
+ target_gc_class = target_dict['gc'][:, :, 0]
270
+ device = output_reproj['vertices_smal'].device
271
+ remeshing_relevant_faces = self.remeshing_relevant_faces.to(device)
272
+ remeshing_relevant_barys = self.remeshing_relevant_barys.to(device)
273
+
274
+ bs = output_reproj['vertices_smal'].shape[0]
275
+ # verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, output_reproj['vertices_smal'][:, self.remeshing_relevant_faces])
276
+ # sel_verts_comparison = output_reproj['vertices_smal'][:, self.remeshing_relevant_faces]
277
+ # verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts_comparison)
278
+ sel_verts = torch.index_select(output_reproj['vertices_smal'], dim=1, index=remeshing_relevant_faces.reshape((-1))).reshape((bs, remeshing_relevant_faces.shape[0], 3, 3))
279
+ verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts)
280
+ target_gc_class_remeshed = torch.einsum('ij,aij->ai', remeshing_relevant_barys, target_gc_class[:, self.remeshing_relevant_faces].to(device=device, dtype=torch.float32))
281
+ target_gc_class_remeshed_prep = torch.round(target_gc_class_remeshed).to(torch.long)
282
+ gc_errors_plane, gc_errors_under_plane = calculate_plane_errors_batch(verts_remeshed, target_gc_class_remeshed_prep, target_dict['has_gc'], target_dict['has_gc_is_touching'])
283
+ loss_gc_plane = torch.mean(gc_errors_plane)
284
+ loss_gc_belowplane = torch.mean(gc_errors_under_plane)
285
+ # loss_dict_temp['gc_plane'] = torch.mean(gc_errors_plane)
286
+ else:
287
+ loss_gc_plane = torch.zeros((1), device=device).mean()
288
+ loss_gc_belowplane = torch.zeros((1), device=device).mean()
289
+ else:
290
+ loss_gc_plane = torch.zeros((1), device=device).mean()
291
+ loss_gc_belowplane = torch.zeros((1), device=device).mean()
292
+ weight_dict['gc_plane'] = 0
293
+ weight_dict['gc_belowplane'] = 0
294
+
295
+
296
+
297
+ # weight and combine losses
298
+ loss_keyp_weighted = loss_keyp * weight_dict['keyp']
299
+ loss_silh_weighted = loss_silh * weight_dict['silh']
300
+ loss_shapedirs_weighted = loss_shapedirs * weight_dict['shapedirs']
301
+ loss_pose_weighted = loss_pose * weight_dict['pose_0']
302
+ loss_class_weighted = loss_class * weight_dict['class']
303
+ loss_breed_weighted = loss_breed * weight_dict['breed']
304
+ loss_flength_weighted = loss_flength * weight_dict['flength']
305
+ loss_poseprior_weighted = loss_poseprior * weight_dict['poseprior']
306
+ loss_partseg_weighted = loss_partseg * weight_dict['partseg']
307
+ loss_models3d_weighted = loss_models3d * weight_dict['models3d']
308
+ loss_poselegssidemovement_weighted = loss_poselegssidemovement * weight_dict['poselegssidemovement']
309
+
310
+ loss_gc_plane_weighted = loss_gc_plane * weight_dict['gc_plane']
311
+ loss_gc_belowplane_weighted = loss_gc_belowplane * weight_dict['gc_belowplane']
312
+
313
+
314
+ ####################################################################################################
315
+ loss = loss_keyp_weighted + loss_silh_weighted + loss_shape_weighted + loss_pose_weighted + loss_class_weighted + \
316
+ loss_shapedirs_weighted + loss_breed_weighted + loss_flength_weighted + loss_poseprior_weighted + \
317
+ loss_partseg_weighted + loss_models3d_weighted + loss_poselegssidemovement_weighted + \
318
+ loss_gc_plane_weighted + loss_gc_belowplane_weighted
319
+ ####################################################################################################
320
+
321
+ loss_dict = {'loss': loss.item(),
322
+ 'loss_keyp_weighted': loss_keyp_weighted.item(), \
323
+ 'loss_silh_weighted': loss_silh_weighted.item(), \
324
+ 'loss_shape_weighted': loss_shape_weighted.item(), \
325
+ 'loss_shapedirs_weighted': loss_shapedirs_weighted.item(), \
326
+ 'loss_pose0_weighted': loss_pose_weighted.item(), \
327
+ 'loss_class_weighted': loss_class_weighted.item(), \
328
+ 'loss_breed_weighted': loss_breed_weighted.item(), \
329
+ 'loss_flength_weighted': loss_flength_weighted.item(), \
330
+ 'loss_poseprior_weighted': loss_poseprior_weighted.item(), \
331
+ 'loss_partseg_weighted': loss_partseg_weighted.item(), \
332
+ 'loss_models3d_weighted': loss_models3d_weighted.item(), \
333
+ 'loss_poselegssidemovement_weighted': loss_poselegssidemovement_weighted.item(), \
334
+ 'loss_gc_plane_weighted': loss_gc_plane_weighted.item(), \
335
+ 'loss_gc_belowplane_weighted': loss_gc_belowplane_weighted.item()
336
+ }
337
+
338
+ return loss, loss_dict
339
+
340
+
341
+
342
+
src/combined_model/loss_utils/loss_arap.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ # code from https://raw.githubusercontent.com/yufu-wang/aves/main/optimization/loss_arap.py
4
+
5
+
6
+ class Arap_Loss():
7
+ '''
8
+ Pytorch implementaion: As-rigid-as-possible loss class
9
+
10
+ '''
11
+
12
+ def __init__(self, meshes, device='cpu', vertex_w=None):
13
+
14
+ with torch.no_grad(): # new nadine
15
+
16
+ self.device = device
17
+ self.bn = len(meshes)
18
+
19
+ # get lapacian cotangent matrix
20
+ L = self.get_laplacian_cot(meshes)
21
+ self.wij = L.values().clone()
22
+ self.wij[self.wij<0] = 0.
23
+
24
+ # get ajacency matrix
25
+ V = meshes.num_verts_per_mesh().sum()
26
+ edges_packed = meshes.edges_packed()
27
+ e0, e1 = edges_packed.unbind(1)
28
+ idx01 = torch.stack([e0, e1], dim=1)
29
+ idx10 = torch.stack([e1, e0], dim=1)
30
+ idx = torch.cat([idx01, idx10], dim=0).t()
31
+
32
+ ones = torch.ones(idx.shape[1], dtype=torch.float32).to(device)
33
+ A = torch.sparse.FloatTensor(idx, ones, (V, V))
34
+ self.deg = torch.sparse.sum(A, dim=1).to_dense().long()
35
+ self.idx = self.sort_idx(idx)
36
+
37
+ # get edges of default mesh
38
+ self.eij = self.get_edges(meshes)
39
+
40
+ # get per vertex regularization strength
41
+ self.vertex_w = vertex_w
42
+
43
+
44
+ def __call__(self, new_meshes):
45
+ new_meshes._compute_packed()
46
+
47
+ optimal_R = self.step_1(new_meshes)
48
+ arap_loss = self.step_2(optimal_R, new_meshes)
49
+ return arap_loss
50
+
51
+
52
+ def step_1(self, new_meshes):
53
+ bn = self.bn
54
+ eij = self.eij.view(bn, -1, 3).cpu()
55
+
56
+ with torch.no_grad():
57
+ eij_ = self.get_edges(new_meshes)
58
+
59
+ eij_ = eij_.view(bn, -1, 3).cpu()
60
+ wij = self.wij.view(bn, -1).cpu()
61
+
62
+ deg_1 = self.deg.view(bn, -1)[0].cpu() # assuming same topology
63
+ S = torch.zeros([bn, len(deg_1), 3, 3])
64
+ for i in range(len(deg_1)):
65
+ start, end = deg_1[:i].sum(), deg_1[:i+1].sum()
66
+ P = eij[:, start : end]
67
+ P_ = eij_[:, start : end]
68
+ D = wij[:, start : end]
69
+ D = torch.diag_embed(D)
70
+ S[:, i] = P.transpose(-2,-1) @ D @ P_
71
+
72
+ S = S.view(-1, 3, 3)
73
+
74
+ u, _, v = torch.svd(S)
75
+ R = v @ u.transpose(-2, -1)
76
+ det = torch.det(R)
77
+
78
+ u[det<0, :, -1] *= -1
79
+ R = v @ u.transpose(-2, -1)
80
+ R = R.to(self.device)
81
+
82
+ return R
83
+
84
+
85
+ def step_2(self, R, new_meshes):
86
+ R = torch.repeat_interleave(R, self.deg, dim=0)
87
+ Reij = R @ self.eij.unsqueeze(2)
88
+ Reij = Reij.squeeze()
89
+
90
+ eij_ = self.get_edges(new_meshes)
91
+ arap_loss = self.wij * (eij_ - Reij).norm(dim=1)
92
+
93
+ if self.vertex_w is not None:
94
+ vertex_w = torch.repeat_interleave(self.vertex_w, self.deg, dim=0)
95
+ arap_loss = arap_loss * vertex_w
96
+
97
+ arap_loss = arap_loss.sum() / self.bn
98
+
99
+ return arap_loss
100
+
101
+
102
+ def get_edges(self, meshes):
103
+ verts_packed = meshes.verts_packed()
104
+ vi = torch.repeat_interleave(verts_packed, self.deg, dim=0)
105
+ vj = verts_packed[self.idx[1]]
106
+ eij = vi - vj
107
+ return eij
108
+
109
+
110
+ def sort_idx(self, idx):
111
+ _, order = (idx[0] + idx[1]*1e-6).sort()
112
+
113
+ return idx[:, order]
114
+
115
+
116
+ def get_laplacian_cot(self, meshes):
117
+ '''
118
+ Routine modified from :
119
+ pytorch3d/loss/mesh_laplacian_smoothing.py
120
+ '''
121
+ verts_packed = meshes.verts_packed()
122
+ faces_packed = meshes.faces_packed()
123
+ V, F = verts_packed.shape[0], faces_packed.shape[0]
124
+
125
+ face_verts = verts_packed[faces_packed]
126
+ v0, v1, v2 = face_verts[:,0], face_verts[:,1], face_verts[:,2]
127
+
128
+ A = (v1-v2).norm(dim=1)
129
+ B = (v0-v2).norm(dim=1)
130
+ C = (v0-v1).norm(dim=1)
131
+
132
+ s = 0.5 * (A+B+C)
133
+ area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=1e-12).sqrt()
134
+
135
+ A2, B2, C2 = A * A, B * B, C * C
136
+ cota = (B2 + C2 - A2) / area
137
+ cotb = (A2 + C2 - B2) / area
138
+ cotc = (A2 + B2 - C2) / area
139
+ cot = torch.stack([cota, cotb, cotc], dim=1)
140
+ cot /= 4.0
141
+
142
+ ii = faces_packed[:, [1,2,0]]
143
+ jj = faces_packed[:, [2,0,1]]
144
+ idx = torch.stack([ii, jj], dim=0).view(2, F*3)
145
+ L = torch.sparse.FloatTensor(idx, cot.view(-1), (V, V))
146
+ L += L.t()
147
+ L = L.coalesce()
148
+ L /= 2.0 # normalized according to arap paper
149
+
150
+ return L
151
+
152
+
153
+
src/combined_model/loss_utils/loss_laplacian_mesh_comparison.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # code from: https://github.com/chaneyddtt/Coarse-to-fine-3D-Animal/blob/main/util/loss_utils.py
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+
7
+ # Laplacian loss, calculate the Laplacian coordiante of both coarse and refined vertices and then compare the difference
8
+ class LaplacianCTF(torch.nn.Module):
9
+ def __init__(self, adjmat, device):
10
+ '''
11
+ Args:
12
+ adjmat: adjacency matrix of the input graph data
13
+ device: specify device for training
14
+ '''
15
+ super(LaplacianCTF, self).__init__()
16
+ adjmat.data = np.ones_like(adjmat.data)
17
+ adjmat = torch.from_numpy(adjmat.todense()).float()
18
+ dg = torch.sum(adjmat, dim=-1)
19
+ dg_m = torch.diag(dg)
20
+ ls = dg_m - adjmat
21
+ self.ls = ls.unsqueeze(0).to(device) # Should be normalized by the diagonal elements according to
22
+ # the origial definition, this one also works fine.
23
+
24
+ def forward(self, verts_pred, verts_gt, smooth=False):
25
+ verts_pred = torch.matmul(self.ls, verts_pred)
26
+ verts_gt = torch.matmul(self.ls, verts_gt)
27
+ loss = torch.norm(verts_pred - verts_gt, dim=-1).mean()
28
+ if smooth:
29
+ loss_smooth = torch.norm(torch.matmul(self.ls, verts_pred), dim=-1).mean()
30
+ return loss, loss_smooth
31
+ return loss, None
32
+
33
+
34
+
35
+
36
+ #
37
+ # read the adjacency matrix, which will used in the Laplacian regularizer
38
+ # data = np.load('./data/mesh_down_sampling_4.npz', encoding='latin1', allow_pickle=True)
39
+ # adjmat = data['A'][0]
40
+ # laplacianloss = Laplacian(adjmat, device)
41
+ #
42
+ # verts_clone = verts.detach().clone()
43
+ # loss_arap, loss_smooth = laplacianloss(verts_refine, verts_clone)
44
+ # loss_arap = args.w_arap * loss_arap
45
+ #
src/combined_model/loss_utils/loss_sdf.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # code from: https://github.com/chaneyddtt/Coarse-to-fine-3D-Animal/blob/main/util/loss_sdf.py
3
+
4
+ import torch
5
+ import numpy as np
6
+ from scipy.ndimage import distance_transform_edt as distance
7
+ from skimage import segmentation as skimage_seg
8
+ import matplotlib.pyplot as plt
9
+
10
+
11
+ def dice_loss(score, target):
12
+ # implemented from paper https://arxiv.org/pdf/1606.04797.pdf
13
+ target = target.float()
14
+ smooth = 1e-5
15
+ intersect = torch.sum(score * target)
16
+ y_sum = torch.sum(target * target)
17
+ z_sum = torch.sum(score * score)
18
+ loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
19
+ loss = 1 - loss
20
+ return loss
21
+
22
+
23
+ class tversky_loss(torch.nn.Module):
24
+ # implemented from https://arxiv.org/pdf/1706.05721.pdf
25
+ def __init__(self, alpha, beta):
26
+ '''
27
+ Args:
28
+ alpha: coefficient for false positive prediction
29
+ beta: coefficient for false negtive prediction
30
+ '''
31
+ super(tversky_loss, self).__init__()
32
+ self.alpha = alpha
33
+ self.beta = beta
34
+
35
+ def __call__(self, score, target):
36
+ target = target.float()
37
+ smooth = 1e-5
38
+ tp = torch.sum(score * target)
39
+ fn = torch.sum(target * (1 - score))
40
+ fp = torch.sum((1-target) * score)
41
+ loss = (tp + smooth) / (tp + self.alpha * fp + self.beta * fn + smooth)
42
+ loss = 1 - loss
43
+ return loss
44
+
45
+
46
+ def compute_sdf1_1(img_gt, out_shape):
47
+ """
48
+ compute the normalized signed distance map of binary mask
49
+ input: segmentation, shape = (batch_size, x, y, z)
50
+ output: the Signed Distance Map (SDM)
51
+ sdf(x) = 0; x in segmentation boundary
52
+ -inf|x-y|; x in segmentation
53
+ +inf|x-y|; x out of segmentation
54
+ normalize sdf to [-1, 1]
55
+ """
56
+
57
+ img_gt = img_gt.astype(np.uint8)
58
+
59
+ normalized_sdf = np.zeros(out_shape)
60
+
61
+ for b in range(out_shape[0]): # batch size
62
+ # ignore background
63
+ for c in range(1, out_shape[1]):
64
+ posmask = img_gt[b]
65
+ negmask = 1-posmask
66
+ posdis = distance(posmask)
67
+ negdis = distance(negmask)
68
+ boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
69
+ sdf = (negdis-np.min(negdis))/(np.max(negdis)-np.min(negdis)) - (posdis-np.min(posdis))/(np.max(posdis)-np.min(posdis))
70
+ sdf[boundary==1] = 0
71
+ normalized_sdf[b][c] = sdf
72
+ assert np.min(sdf) == -1.0, print(np.min(posdis), np.min(negdis), np.max(posdis), np.max(negdis))
73
+ assert np.max(sdf) == 1.0, print(np.min(posdis), np.min(negdis), np.max(posdis), np.max(negdis))
74
+
75
+ return normalized_sdf
76
+
77
+
78
+ def compute_sdf(img_gt, out_shape):
79
+ """
80
+ compute the signed distance map of binary mask
81
+ input: segmentation, shape = (batch_size, x, y, z)
82
+ output: the Signed Distance Map (SDM)
83
+ sdf(x) = 0; x in segmentation boundary
84
+ -inf|x-y|; x in segmentation
85
+ +inf|x-y|; x out of segmentation
86
+ """
87
+
88
+ img_gt = img_gt.astype(np.uint8)
89
+
90
+ gt_sdf = np.zeros(out_shape)
91
+ debug = False
92
+ for b in range(out_shape[0]): # batch size
93
+ for c in range(0, out_shape[1]):
94
+ posmask = img_gt[b]
95
+ negmask = 1-posmask
96
+ posdis = distance(posmask)
97
+ negdis = distance(negmask)
98
+ boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
99
+ sdf = negdis - posdis
100
+ sdf[boundary==1] = 0
101
+ gt_sdf[b][c] = sdf
102
+ if debug:
103
+ plt.figure()
104
+ plt.subplot(1, 2, 1), plt.imshow(img_gt[b, 0, :, :]), plt.colorbar()
105
+ plt.subplot(1, 2, 2), plt.imshow(gt_sdf[b, 0, :, :]), plt.colorbar()
106
+ plt.show()
107
+
108
+ return gt_sdf
109
+
110
+
111
+ def boundary_loss(output, gt):
112
+ """
113
+ compute boundary loss for binary segmentation
114
+ input: outputs_soft: softmax results, shape=(b,2,x,y,z)
115
+ gt_sdf: sdf of ground truth (can be original or normalized sdf); shape=(b,2,x,y,z)
116
+ output: boundary_loss; sclar
117
+ adopted from http://proceedings.mlr.press/v102/kervadec19a/kervadec19a.pdf
118
+ """
119
+ multipled = torch.einsum('bcxy, bcxy->bcxy', output, gt)
120
+ bd_loss = multipled.mean()
121
+
122
+ return bd_loss
src/combined_model/loss_utils/loss_utils.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import numpy as np
4
+
5
+
6
+ '''
7
+ def keyp_rep_error_l1(smpl_keyp_2d, keyp_hourglass, keyp_hourglass_scores, thr_kp=0.3):
8
+ # step 1: make sure that the hg prediction and barc are close
9
+ with torch.no_grad():
10
+ kp_weights = keyp_hourglass_scores
11
+ kp_weights[keyp_hourglass_scores<thr_kp] = 0
12
+ loss_keyp_rep = torch.mean((torch.abs((smpl_keyp_2d - keyp_hourglass)/512)).sum(dim=2)*kp_weights[:, :, 0])
13
+ return loss_keyp_rep
14
+
15
+ def keyp_rep_error(smpl_keyp_2d, keyp_hourglass, keyp_hourglass_scores, thr_kp=0.3):
16
+ # step 1: make sure that the hg prediction and barc are close
17
+ with torch.no_grad():
18
+ kp_weights = keyp_hourglass_scores
19
+ kp_weights[keyp_hourglass_scores<thr_kp] = 0
20
+ # losses['kp_reproj']['value'] = torch.mean((((smpl_keyp_2d - keyp_reproj_init)/512)**2).sum(dim=2)*kp_weights[:, :, 0])
21
+ loss_keyp_rep = torch.mean((((smpl_keyp_2d - keyp_hourglass)/512)**2).sum(dim=2)*kp_weights[:, :, 0])
22
+ return loss_keyp_rep
23
+ '''
24
+
25
+ def leg_sideway_error(optimed_pose_with_glob):
26
+ assert optimed_pose_with_glob.shape[1] == 35
27
+ leg_indices_right = np.asarray([7, 8, 9, 10, 17, 18, 19, 20]) # front, back
28
+ leg_indices_left = np.asarray([11, 12, 13, 14, 21, 22, 23, 24]) # front, back
29
+ # leg_indices_right = np.asarray([8, 9, 10, 18, 19, 20]) # front, back
30
+ # leg_indices_left = np.asarray([12, 13, 14, 22, 23, 24]) # front, back
31
+ x0_rotmat = optimed_pose_with_glob # (1, 35, 3, 3)
32
+ x0_rotmat_legs_left = x0_rotmat[:, leg_indices_left, :, :]
33
+ x0_rotmat_legs_right = x0_rotmat[:, leg_indices_right, :, :]
34
+ vec = torch.zeros((3, 1)).to(device=optimed_pose_with_glob.device, dtype=optimed_pose_with_glob.dtype)
35
+ vec[2] = -1
36
+ x0_legs_left = x0_rotmat_legs_left.reshape((-1, 3, 3))@vec
37
+ x0_legs_right = x0_rotmat_legs_right.reshape((-1, 3, 3))@vec
38
+ loss_pose_legs_side = (x0_legs_left[:, 1]**2).mean() + (x0_legs_right[:, 1]**2).mean()
39
+ return loss_pose_legs_side
40
+
41
+
42
+ def leg_torsion_error(optimed_pose_with_glob):
43
+ leg_indices_right = np.asarray([7, 8, 9, 10, 17, 18, 19, 20]) # front, back
44
+ leg_indices_left = np.asarray([11, 12, 13, 14, 21, 22, 23, 24]) # front, back
45
+ x0_rotmat = optimed_pose_with_glob # (1, 35, 3, 3)
46
+ x0_rotmat_legs_left = x0_rotmat[:, leg_indices_left, :, :]
47
+ x0_rotmat_legs_right = x0_rotmat[:, leg_indices_right, :, :]
48
+ vec_x = torch.zeros((3, 1)).to(device=optimed_pose_with_glob.device, dtype=optimed_pose_with_glob.dtype)
49
+ vec_x[0] = 1 # in x direction
50
+ x_x_legs_left = x0_rotmat_legs_left.reshape((-1, 3, 3))@vec_x
51
+ x_x_legs_right = x0_rotmat_legs_right.reshape((-1, 3, 3))@vec_x
52
+ loss_pose_legs_torsion = (x_x_legs_left[:, 1]**2).mean() + (x_x_legs_right[:, 1]**2).mean()
53
+ return loss_pose_legs_torsion
54
+
55
+
56
+ def frontleg_walkingdir_error(optimed_pose_with_glob):
57
+ # this prior should only be used for standing poses!
58
+ leg_indices_right = np.asarray([7, 8, 9, 10]) # front, back
59
+ leg_indices_left = np.asarray([11, 12, 13, 14]) # front, back
60
+ relevant_back_indices = np.asarray([1, 2, 3, 4, 5, 6]) # np.asarray([6]) # back joint in the front
61
+ x0_rotmat = optimed_pose_with_glob # (1, 35, 3, 3)
62
+ x0_rotmat_legs_left = x0_rotmat[:, leg_indices_left, :, :]
63
+ x0_rotmat_legs_right = x0_rotmat[:, leg_indices_right, :, :]
64
+ x0_rotmat_back = x0_rotmat[:, relevant_back_indices, :, :]
65
+ vec = torch.zeros((3, 1)).to(device=optimed_pose_with_glob.device, dtype=optimed_pose_with_glob.dtype)
66
+ vec[2] = -1 # vector down
67
+ x0_legs_left = x0_rotmat_legs_left.reshape((-1, 3, 3))@vec
68
+ x0_legs_right = x0_rotmat_legs_right.reshape((-1, 3, 3))@vec
69
+ x0_back = x0_rotmat_back.reshape((-1, 3, 3))@vec
70
+ loss_pose_legs_side = (x0_legs_left[:, 0]**2).mean() + (x0_legs_right[:, 0]**2).mean() + (x0_back[:, 0]**2).mean() # penalize movement to front
71
+ return loss_pose_legs_side
72
+
73
+
74
+ def tail_sideway_error(optimed_pose_with_glob):
75
+ tail_indices = np.asarray([25, 26, 27, 28, 29, 30, 31])
76
+ x0_rotmat = optimed_pose_with_glob # (1, 35, 3, 3)
77
+ x0_rotmat_tail = x0_rotmat[:, tail_indices, :, :]
78
+ vec = torch.zeros((3, 1)).to(device=optimed_pose_with_glob.device, dtype=optimed_pose_with_glob.dtype)
79
+ '''vec[2] = -1
80
+ x0_tail = x0_rotmat_tail.reshape((-1, 3, 3))@vec
81
+ loss_pose_tail_side = (x0_tail[:, 1]**2).mean()'''
82
+ vec[0] = -1
83
+ x0_tail = x0_rotmat_tail.reshape((-1, 3, 3))@vec
84
+ loss_pose_tail_side = (x0_tail[:, 1]**2).mean()
85
+ return loss_pose_tail_side
86
+
87
+
88
+ def tail_torsion_error(optimed_pose_with_glob):
89
+ tail_indices = np.asarray([25, 26, 27, 28, 29, 30, 31])
90
+ x0_rotmat = optimed_pose_with_glob # (1, 35, 3, 3)
91
+ x0_rotmat_tail = x0_rotmat[:, tail_indices, :, :]
92
+ vec_x = torch.zeros((3, 1)).to(device=optimed_pose_with_glob.device, dtype=optimed_pose_with_glob.dtype)
93
+ '''vec_x[0] = 1 # in x direction
94
+ x_x_tail = x0_rotmat_tail.reshape((-1, 3, 3))@vec_x
95
+ loss_pose_tail_torsion = (x_x_tail[:, 1]**2).mean()'''
96
+ vec_x[2] = 1 # in y direction
97
+ x_x_tail = x0_rotmat_tail.reshape((-1, 3, 3))@vec_x
98
+ loss_pose_tail_torsion = (x_x_tail[:, 1]**2).mean()
99
+ return loss_pose_tail_torsion
100
+
101
+
102
+ def spine_sideway_error(optimed_pose_with_glob):
103
+ tail_indices = np.asarray([1, 2, 3, 4, 5, 6]) # was wrong
104
+ x0_rotmat = optimed_pose_with_glob # (1, 35, 3, 3)
105
+ x0_rotmat_tail = x0_rotmat[:, tail_indices, :, :]
106
+ vec = torch.zeros((3, 1)).to(device=optimed_pose_with_glob.device, dtype=optimed_pose_with_glob.dtype)
107
+ vec[0] = -1
108
+ x0_tail = x0_rotmat_tail.reshape((-1, 3, 3))@vec
109
+ loss_pose_tail_side = (x0_tail[:, 1]**2).mean()
110
+ return loss_pose_tail_side
111
+
112
+
113
+ def spine_torsion_error(optimed_pose_with_glob):
114
+ tail_indices = np.asarray([1, 2, 3, 4, 5, 6])
115
+ x0_rotmat = optimed_pose_with_glob # (1, 35, 3, 3)
116
+ x0_rotmat_tail = x0_rotmat[:, tail_indices, :, :]
117
+ vec_x = torch.zeros((3, 1)).to(device=optimed_pose_with_glob.device, dtype=optimed_pose_with_glob.dtype)
118
+ vec_x[2] = 1 # vec_x[0] = 1 # in z direction
119
+ x_x_tail = x0_rotmat_tail.reshape((-1, 3, 3))@vec_x
120
+ loss_pose_tail_torsion = (x_x_tail[:, 1]**2).mean() # (x_x_tail[:, 1]**2).mean()
121
+ return loss_pose_tail_torsion
122
+
123
+
124
+ def fit_plane(points_npx3):
125
+ # remarks:
126
+ # visualization of the plane: debug_code/curve_fitting_v2.py
127
+ # theory: https://www.ltu.se/cms_fs/1.51590!/svd-fitting.pdf
128
+ # remark: torch.svd is depreciated
129
+ # new plane equation:
130
+ # a(x−x0)+b(y−y0)+c(z−z0)=0
131
+ # ax+by+cz=d with d=ax0+by0+cz0
132
+ # z = (d-ax-by)/c
133
+ # here:
134
+ # a, b, c describe the plane normal
135
+ # d can be calculated (from a, b, c, x0, y0, z0)
136
+ # (x0, y0, z0) are the coordinates of a point on the
137
+ # plane, for example points_centroid
138
+ # (x, y, z) are the coordinates of a query point on the plane
139
+ #
140
+ # points_npx3: (n_points, 3)
141
+ # REMARK: this loss is not yet for batches!
142
+ # import pdb; pdb.set_trace()
143
+ # print('this loss is not yet for batches!')
144
+ assert (points_npx3.ndim == 2)
145
+ assert (points_npx3.shape[1] == 3)
146
+ points = torch.transpose(points_npx3, 0, 1) # (3, n_points)
147
+ points_centroid = torch.mean(points, dim=1)
148
+ input_svd = points - points_centroid[:, None]
149
+ U_svd, sigma_svd, V_svd = torch.svd(input_svd, compute_uv=True)
150
+ plane_normal = U_svd[:, 2]
151
+ plane_squaredsumofdists = sigma_svd[2]
152
+ error = plane_squaredsumofdists
153
+ return points_centroid, plane_normal, error
154
+
155
+
156
+ def paws_to_groundplane_error(vertices, return_details=False):
157
+ # list of feet vertices (some of them)
158
+ # remark: we did annotate left indices and find the right insices using sym_ids_dict
159
+ # REMARK: this loss is not yet for batches!
160
+ # import pdb; pdb.set_trace()
161
+ # print('this loss is not yet for batches!')
162
+ list_back_left = [1524, 1517, 1512, 1671, 1678, 1664, 1956, 1680, 1685, 1602, 1953, 1569]
163
+ list_front_left = [1331, 1327, 1332, 1764, 1767, 1747, 1779, 1789, 1944, 1339, 1323, 1420]
164
+ list_back_right = [3476, 3469, 3464, 3623, 3630, 3616, 3838, 3632, 3637, 3554, 3835, 3521]
165
+ list_front_right = [3283, 3279, 3284, 3715, 3718, 3698, 3730, 3740, 3826, 3291, 3275, 3372]
166
+ assert vertices.shape[0] == 3889
167
+ assert vertices.shape[1] == 3
168
+ all_paw_vert_idxs = list_back_left + list_front_left + list_back_right + list_front_right
169
+ verts_paws = vertices[all_paw_vert_idxs, :]
170
+ plane_centroid, plane_normal, error = fit_plane(verts_paws)
171
+ if return_details:
172
+ return plane_centroid, plane_normal, error
173
+ else:
174
+ return error
175
+
176
+ def groundcontact_error(vertices, gclabels, return_details=False):
177
+ # import pdb; pdb.set_trace()
178
+ # REMARK: this loss is not yet for batches!
179
+ import pdb; pdb.set_trace()
180
+ print('this loss is not yet for batches!')
181
+ assert vertices.shape[0] == 3889
182
+ assert vertices.shape[1] == 3
183
+ verts_gc = vertices[gclabels, :]
184
+ plane_centroid, plane_normal, error = fit_plane(verts_gc)
185
+ if return_details:
186
+ return plane_centroid, plane_normal, error
187
+ else:
188
+ return error
189
+
190
+
191
+
src/combined_model/loss_utils/loss_utils_gc.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+
4
+
5
+
6
+
7
+
8
+ class LossGConMesh(torch.nn.Module):
9
+ def __init__(self , n_verts=3889):
10
+ super(LossGConMesh, self).__init__()
11
+ self.n_verts = n_verts
12
+ self.criterion_class = torch.nn.CrossEntropyLoss(reduction='mean')
13
+
14
+ def forward(self, pred_gc, target_gc, has_gc, loss_type_gcmesh='ce'):
15
+ # pred_gc has shape (bs, n_verts, 2)
16
+ # target_gc has shape (bs, n_verts, 3)
17
+ # with [first: no-contact=0 contact=1
18
+ # second: index of closest vertex with opposite label
19
+ # third: dist to that closest vertex]
20
+ target_gc_class = target_gc[:, :, 0]
21
+ target_gc_nearoppvert_ind = target_gc[:, :, 1]
22
+ target_gc_nearoppvert_dist = target_gc[:, :, 2]
23
+ # bs = pred_gc.shape[0]
24
+ bs = has_gc.sum()
25
+
26
+ if loss_type_gcmesh == 'ce': # cross entropy
27
+ # import pdb; pdb.set_trace()
28
+
29
+ # classification_loss = self.criterion_class(pred_gc.reshape((bs*self.n_verts, 2)), target_gc_class.reshape((bs*self.n_verts)))
30
+ classification_loss = self.criterion_class(pred_gc[has_gc==True, ...].reshape((bs*self.n_verts, 2)), target_gc_class[has_gc==True, ...].reshape((bs*self.n_verts)))
31
+ loss = classification_loss
32
+ else:
33
+ raise ValueError
34
+
35
+ return loss
36
+
37
+
38
+
39
+
40
+
41
+
42
+ def calculate_plane_errors_batch(vertices, target_gc_class, target_has_gc, has_gc_is_touching, return_error_under_plane=True):
43
+ # remarks:
44
+ # visualization of the plane: debug_code/curve_fitting_v2.py
45
+ # theory: https://www.ltu.se/cms_fs/1.51590!/svd-fitting.pdf
46
+ # remark: torch.svd is depreciated
47
+ # new plane equation:
48
+ # a(x−x0)+b(y−y0)+c(z−z0)=0
49
+ # ax+by+cz=d with d=ax0+by0+cz0
50
+ # z = (d-ax-by)/c
51
+ # here:
52
+ # a, b, c describe the plane normal
53
+ # d can be calculated (from a, b, c, x0, y0, z0)
54
+ # (x0, y0, z0) are the coordinates of a point on the
55
+ # plane, for example points_centroid
56
+ # (x, y, z) are the coordinates of a query point on the plane
57
+ #
58
+ # input:
59
+ # vertices: (bs, 3889, 3)
60
+ # target_gc_class: (bs, 3889)
61
+ #
62
+ bs = vertices.shape[0]
63
+ error_list = []
64
+ error_under_plane_list = []
65
+
66
+ for ind_b in range(bs):
67
+ if target_has_gc[ind_b] == 1 and has_gc_is_touching[ind_b] == 1:
68
+ try:
69
+ points_npx3 = vertices[ind_b, target_gc_class[ind_b, :]==1, :]
70
+ points = torch.transpose(points_npx3, 0, 1) # (3, n_points)
71
+ points_centroid = torch.mean(points, dim=1)
72
+ input_svd = points - points_centroid[:, None]
73
+ # U_svd, sigma_svd, V_svd = torch.svd(input_svd, compute_uv=True)
74
+ # plane_normal = U_svd[:, 2]
75
+ # _, sigma_svd, _ = torch.svd(input_svd, compute_uv=False)
76
+ # _, sigma_svd, _ = torch.svd(input_svd, compute_uv=True)
77
+ U_svd, sigma_svd, V_svd = torch.svd(input_svd, compute_uv=True)
78
+ plane_squaredsumofdists = sigma_svd[2]
79
+ error_list.append(plane_squaredsumofdists)
80
+
81
+ if return_error_under_plane:
82
+ # plane information
83
+ # plane_centroid = points_centroid
84
+ plane_normal = U_svd[:, 2]
85
+
86
+ # non-plane points
87
+ nonplane_points_npx3 = vertices[ind_b, target_gc_class[ind_b, :]==0, :] # (n_points_3)
88
+ nonplane_points = torch.transpose(nonplane_points_npx3, 0, 1) # (3, n_points)
89
+ nonplane_points_centered = nonplane_points - points_centroid[:, None]
90
+
91
+ nonplane_points_projected = torch.matmul(plane_normal[None, :], nonplane_points_centered) # plane normal already has length 1
92
+
93
+ if nonplane_points_projected.sum() > 0:
94
+ # bug corrected 07.11.22
95
+ # error_under_plane = nonplane_points_projected[nonplane_points_projected<0].sum() / 100
96
+ error_under_plane = - nonplane_points_projected[nonplane_points_projected<0].sum() / 100
97
+ else:
98
+ error_under_plane = nonplane_points_projected[nonplane_points_projected>0].sum() / 100
99
+ error_under_plane_list.append(error_under_plane)
100
+ except:
101
+ print('was not able to calculate plane error for this image')
102
+ error_list.append(torch.zeros((1), dtype=vertices.dtype, device=vertices.device)[0])
103
+ error_under_plane_list.append(torch.zeros((1), dtype=vertices.dtype, device=vertices.device)[0])
104
+ else:
105
+ error_list.append(torch.zeros((1), dtype=vertices.dtype, device=vertices.device)[0])
106
+ error_under_plane_list.append(torch.zeros((1), dtype=vertices.dtype, device=vertices.device)[0])
107
+ errors = torch.stack(error_list, dim=0)
108
+ errors_under_plane = torch.stack(error_under_plane_list, dim=0)
109
+
110
+ if return_error_under_plane:
111
+ return errors, errors_under_plane
112
+ else:
113
+ return errors
114
+
115
+
116
+
117
+ # def calculate_vertex_wise_labeling_error():
118
+ # vertexwise_ground_contact
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+
130
+
131
+
132
+ '''
133
+
134
+ def paws_to_groundplane_error_batch(vertices, return_details=False):
135
+ # list of feet vertices (some of them)
136
+ # remark: we did annotate left indices and find the right insices using sym_ids_dict
137
+ # REMARK: this loss is not yet for batches!
138
+ import pdb; pdb.set_trace()
139
+ print('this loss is not yet for batches!')
140
+ list_back_left = [1524, 1517, 1512, 1671, 1678, 1664, 1956, 1680, 1685, 1602, 1953, 1569]
141
+ list_front_left = [1331, 1327, 1332, 1764, 1767, 1747, 1779, 1789, 1944, 1339, 1323, 1420]
142
+ list_back_right = [3476, 3469, 3464, 3623, 3630, 3616, 3838, 3632, 3637, 3554, 3835, 3521]
143
+ list_front_right = [3283, 3279, 3284, 3715, 3718, 3698, 3730, 3740, 3826, 3291, 3275, 3372]
144
+ assert vertices.shape[0] == 3889
145
+ assert vertices.shape[1] == 3
146
+ all_paw_vert_idxs = list_back_left + list_front_left + list_back_right + list_front_right
147
+ verts_paws = vertices[all_paw_vert_idxs, :]
148
+ plane_centroid, plane_normal, error = fit_plane_batch(verts_paws)
149
+ if return_details:
150
+ return plane_centroid, plane_normal, error
151
+ else:
152
+ return error
153
+
154
+ def paws_to_groundplane_error_batch_new(vertices, return_details=False):
155
+ # list of feet vertices (some of them)
156
+ # remark: we did annotate left indices and find the right insices using sym_ids_dict
157
+ # REMARK: this loss is not yet for batches!
158
+ import pdb; pdb.set_trace()
159
+ print('this loss is not yet for batches!')
160
+ list_back_left = [1524, 1517, 1512, 1671, 1678, 1664, 1956, 1680, 1685, 1602, 1953, 1569]
161
+ list_front_left = [1331, 1327, 1332, 1764, 1767, 1747, 1779, 1789, 1944, 1339, 1323, 1420]
162
+ list_back_right = [3476, 3469, 3464, 3623, 3630, 3616, 3838, 3632, 3637, 3554, 3835, 3521]
163
+ list_front_right = [3283, 3279, 3284, 3715, 3718, 3698, 3730, 3740, 3826, 3291, 3275, 3372]
164
+ assert vertices.shape[0] == 3889
165
+ assert vertices.shape[1] == 3
166
+ all_paw_vert_idxs = list_back_left + list_front_left + list_back_right + list_front_right
167
+ verts_paws = vertices[all_paw_vert_idxs, :]
168
+ plane_centroid, plane_normal, error = fit_plane_batch(verts_paws)
169
+ print('this loss is not yet for batches!')
170
+ points = torch.transpose(points_npx3, 0, 1) # (3, n_points)
171
+ points_centroid = torch.mean(points, dim=1)
172
+ input_svd = points - points_centroid[:, None]
173
+ U_svd, sigma_svd, V_svd = torch.svd(input_svd, compute_uv=True)
174
+ plane_normal = U_svd[:, 2]
175
+ plane_squaredsumofdists = sigma_svd[2]
176
+ error = plane_squaredsumofdists
177
+ print('error: ' + str(error.item()))
178
+ return error
179
+ '''
src/combined_model/model_shape_v7_withref_withgraphcnn.py ADDED
@@ -0,0 +1,927 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import pickle as pkl
3
+ import numpy as np
4
+ import torchvision.models as models
5
+ from torchvision import transforms
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn.parameter import Parameter
9
+ from kornia.geometry.subpix import dsnt # kornia 0.4.0
10
+
11
+ import os
12
+ import sys
13
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
14
+ from stacked_hourglass.utils.evaluation import get_preds_soft
15
+ from stacked_hourglass import hg1, hg2, hg8
16
+ from lifting_to_3d.linear_model import LinearModelComplete, LinearModel
17
+ from lifting_to_3d.inn_model_for_shape import INNForShape
18
+ from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, rotmat_to_rot6d
19
+ from smal_pytorch.smal_model.smal_torch_new import SMAL
20
+ from smal_pytorch.renderer.differentiable_renderer import SilhRenderer
21
+ from bps_2d.bps_for_segmentation import SegBPS
22
+ # from configs.SMAL_configs import SMAL_MODEL_DATA_PATH as SHAPE_PRIOR
23
+ from configs.SMAL_configs import SMAL_MODEL_CONFIG
24
+ from configs.SMAL_configs import MEAN_DOG_BONE_LENGTHS_NO_RED, VERTEX_IDS_TAIL
25
+
26
+ # NEW: for graph cnn part
27
+ from smal_pytorch.smal_model.smal_torch_new import SMAL
28
+ from configs.SMAL_configs import SMAL_MODEL_CONFIG
29
+ from graph_networks.graphcmr.utils_mesh import Mesh
30
+ from graph_networks.graphcmr.graph_cnn_groundcontact_multistage import GraphCNNMS
31
+
32
+
33
+
34
+
35
+ class SmallLinear(nn.Module):
36
+ def __init__(self, input_size=64, output_size=30, linear_size=128):
37
+ super(SmallLinear, self).__init__()
38
+ self.relu = nn.ReLU(inplace=True)
39
+ self.w1 = nn.Linear(input_size, linear_size)
40
+ self.w2 = nn.Linear(linear_size, linear_size)
41
+ self.w3 = nn.Linear(linear_size, output_size)
42
+ def forward(self, x):
43
+ # pre-processing
44
+ y = self.w1(x)
45
+ y = self.relu(y)
46
+ y = self.w2(y)
47
+ y = self.relu(y)
48
+ y = self.w3(y)
49
+ return y
50
+
51
+
52
+ class MyConv1d(nn.Module):
53
+ def __init__(self, input_size=37, output_size=30, start=True):
54
+ super(MyConv1d, self).__init__()
55
+ self.input_size = input_size
56
+ self.output_size = output_size
57
+ self.start = start
58
+ self.weight = Parameter(torch.ones((self.output_size)))
59
+ self.bias = Parameter(torch.zeros((self.output_size)))
60
+ def forward(self, x):
61
+ # pre-processing
62
+ if self.start:
63
+ y = x[:, :self.output_size]
64
+ else:
65
+ y = x[:, -self.output_size:]
66
+ y = y * self.weight[None, :] + self.bias[None, :]
67
+ return y
68
+
69
+
70
+ class ModelShapeAndBreed(nn.Module):
71
+ def __init__(self, smal_model_type, n_betas=10, n_betas_limbs=13, n_breeds=121, n_z=512, structure_z_to_betas='default'):
72
+ super(ModelShapeAndBreed, self).__init__()
73
+ self.n_betas = n_betas
74
+ self.n_betas_limbs = n_betas_limbs # n_betas_logscale
75
+ self.n_breeds = n_breeds
76
+ self.structure_z_to_betas = structure_z_to_betas
77
+ if self.structure_z_to_betas == '1dconv':
78
+ if not (n_z == self.n_betas+self.n_betas_limbs):
79
+ raise ValueError
80
+ self.smal_model_type = smal_model_type
81
+ # shape branch
82
+ self.resnet = models.resnet34(pretrained=False)
83
+ # replace the first layer
84
+ n_in = 3 + 1
85
+ self.resnet.conv1 = nn.Conv2d(n_in, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
86
+ # replace the last layer
87
+ self.resnet.fc = nn.Linear(512, n_z)
88
+ # softmax
89
+ self.soft_max = torch.nn.Softmax(dim=1)
90
+ # fc network (and other versions) to connect z with betas
91
+ p_dropout = 0.2
92
+ if self.structure_z_to_betas == 'default':
93
+ self.linear_betas = LinearModel(linear_size=1024,
94
+ num_stage=1,
95
+ p_dropout=p_dropout,
96
+ input_size=n_z,
97
+ output_size=self.n_betas)
98
+ self.linear_betas_limbs = LinearModel(linear_size=1024,
99
+ num_stage=1,
100
+ p_dropout=p_dropout,
101
+ input_size=n_z,
102
+ output_size=self.n_betas_limbs)
103
+ elif self.structure_z_to_betas == 'lin':
104
+ self.linear_betas = nn.Linear(n_z, self.n_betas)
105
+ self.linear_betas_limbs = nn.Linear(n_z, self.n_betas_limbs)
106
+ elif self.structure_z_to_betas == 'fc_0':
107
+ self.linear_betas = SmallLinear(linear_size=128, # 1024,
108
+ input_size=n_z,
109
+ output_size=self.n_betas)
110
+ self.linear_betas_limbs = SmallLinear(linear_size=128, # 1024,
111
+ input_size=n_z,
112
+ output_size=self.n_betas_limbs)
113
+ elif structure_z_to_betas == 'fc_1':
114
+ self.linear_betas = LinearModel(linear_size=64, # 1024,
115
+ num_stage=1,
116
+ p_dropout=0,
117
+ input_size=n_z,
118
+ output_size=self.n_betas)
119
+ self.linear_betas_limbs = LinearModel(linear_size=64, # 1024,
120
+ num_stage=1,
121
+ p_dropout=0,
122
+ input_size=n_z,
123
+ output_size=self.n_betas_limbs)
124
+ elif self.structure_z_to_betas == '1dconv':
125
+ self.linear_betas = MyConv1d(n_z, self.n_betas, start=True)
126
+ self.linear_betas_limbs = MyConv1d(n_z, self.n_betas_limbs, start=False)
127
+ elif self.structure_z_to_betas == 'inn':
128
+ self.linear_betas_and_betas_limbs = INNForShape(self.n_betas, self.n_betas_limbs, betas_scale=1.0, betas_limbs_scale=1.0)
129
+ else:
130
+ raise ValueError
131
+ # network to connect latent shape vector z with dog breed classification
132
+ self.linear_breeds = LinearModel(linear_size=1024, # 1024,
133
+ num_stage=1,
134
+ p_dropout=p_dropout,
135
+ input_size=n_z,
136
+ output_size=self.n_breeds)
137
+ # shape multiplicator
138
+ self.shape_multiplicator_np = np.ones(self.n_betas)
139
+ with open(SMAL_MODEL_CONFIG[self.smal_model_type]['smal_model_data_path'], 'rb') as file:
140
+ u = pkl._Unpickler(file)
141
+ u.encoding = 'latin1'
142
+ res = u.load()
143
+ # shape predictions are centered around the mean dog of our dog model
144
+ if 'dog_cluster_mean' in res.keys():
145
+ self.betas_mean_np = res['dog_cluster_mean']
146
+ else:
147
+ assert res['cluster_means'].shape[0]==1
148
+ self.betas_mean_np = res['cluster_means'][0, :]
149
+
150
+
151
+ def forward(self, img, seg_raw=None, seg_prep=None):
152
+ # img is the network input image
153
+ # seg_raw is before softmax and subtracting 0.5
154
+ # seg_prep would be the prepared_segmentation
155
+ if seg_prep is None:
156
+ seg_prep = self.soft_max(seg_raw)[:, 1:2, :, :] - 0.5
157
+ input_img_and_seg = torch.cat((img, seg_prep), axis=1)
158
+ res_output = self.resnet(input_img_and_seg)
159
+ dog_breed_output = self.linear_breeds(res_output)
160
+ if self.structure_z_to_betas == 'inn':
161
+ shape_output_orig, shape_limbs_output_orig = self.linear_betas_and_betas_limbs(res_output)
162
+ else:
163
+ shape_output_orig = self.linear_betas(res_output) * 0.1
164
+ betas_mean = torch.tensor(self.betas_mean_np).float().to(img.device)
165
+ shape_output = shape_output_orig + betas_mean[None, 0:self.n_betas]
166
+ shape_limbs_output_orig = self.linear_betas_limbs(res_output)
167
+ shape_limbs_output = shape_limbs_output_orig * 0.1
168
+ output_dict = {'z': res_output,
169
+ 'breeds': dog_breed_output,
170
+ 'betas': shape_output_orig,
171
+ 'betas_limbs': shape_limbs_output_orig}
172
+ return output_dict
173
+
174
+
175
+
176
+ class LearnableShapedirs(nn.Module):
177
+ def __init__(self, sym_ids_dict, shapedirs_init, n_betas, n_betas_fixed=10):
178
+ super(LearnableShapedirs, self).__init__()
179
+ # shapedirs_init = self.smal.shapedirs.detach()
180
+ self.n_betas = n_betas
181
+ self.n_betas_fixed = n_betas_fixed
182
+ self.sym_ids_dict = sym_ids_dict
183
+ sym_left_ids = self.sym_ids_dict['left']
184
+ sym_right_ids = self.sym_ids_dict['right']
185
+ sym_center_ids = self.sym_ids_dict['center']
186
+ self.n_center = sym_center_ids.shape[0]
187
+ self.n_left = sym_left_ids.shape[0]
188
+ self.n_sd = self.n_betas - self.n_betas_fixed # number of learnable shapedirs
189
+ # get indices to go from half_shapedirs to shapedirs
190
+ inds_back = np.zeros((3889))
191
+ for ind in range(0, sym_center_ids.shape[0]):
192
+ ind_in_forward = sym_center_ids[ind]
193
+ inds_back[ind_in_forward] = ind
194
+ for ind in range(0, sym_left_ids.shape[0]):
195
+ ind_in_forward = sym_left_ids[ind]
196
+ inds_back[ind_in_forward] = sym_center_ids.shape[0] + ind
197
+ for ind in range(0, sym_right_ids.shape[0]):
198
+ ind_in_forward = sym_right_ids[ind]
199
+ inds_back[ind_in_forward] = sym_center_ids.shape[0] + sym_left_ids.shape[0] + ind
200
+ self.register_buffer('inds_back_torch', torch.Tensor(inds_back).long())
201
+ # self.smal.shapedirs: (51, 11667)
202
+ # shapedirs: (3889, 3, n_sd)
203
+ # shapedirs_half: (2012, 3, n_sd)
204
+ sd = shapedirs_init[:self.n_betas, :].permute((1, 0)).reshape((-1, 3, self.n_betas))
205
+ self.register_buffer('sd', sd)
206
+ sd_center = sd[sym_center_ids, :, self.n_betas_fixed:]
207
+ sd_left = sd[sym_left_ids, :, self.n_betas_fixed:]
208
+ self.register_parameter('learnable_half_shapedirs_c0', torch.nn.Parameter(sd_center[:, 0, :].detach()))
209
+ self.register_parameter('learnable_half_shapedirs_c2', torch.nn.Parameter(sd_center[:, 2, :].detach()))
210
+ self.register_parameter('learnable_half_shapedirs_l0', torch.nn.Parameter(sd_left[:, 0, :].detach()))
211
+ self.register_parameter('learnable_half_shapedirs_l1', torch.nn.Parameter(sd_left[:, 1, :].detach()))
212
+ self.register_parameter('learnable_half_shapedirs_l2', torch.nn.Parameter(sd_left[:, 2, :].detach()))
213
+ def forward(self):
214
+ device = self.learnable_half_shapedirs_c0.device
215
+ half_shapedirs_center = torch.stack((self.learnable_half_shapedirs_c0, \
216
+ torch.zeros((self.n_center, self.n_sd)).to(device), \
217
+ self.learnable_half_shapedirs_c2), axis=1)
218
+ half_shapedirs_left = torch.stack((self.learnable_half_shapedirs_l0, \
219
+ self.learnable_half_shapedirs_l1, \
220
+ self.learnable_half_shapedirs_l2), axis=1)
221
+ half_shapedirs_right = torch.stack((self.learnable_half_shapedirs_l0, \
222
+ - self.learnable_half_shapedirs_l1, \
223
+ self.learnable_half_shapedirs_l2), axis=1)
224
+ half_shapedirs_tot = torch.cat((half_shapedirs_center, half_shapedirs_left, half_shapedirs_right))
225
+ shapedirs = torch.index_select(half_shapedirs_tot, dim=0, index=self.inds_back_torch)
226
+ shapedirs_complete = torch.cat((self.sd[:, :, :self.n_betas_fixed], shapedirs), axis=2) # (3889, 3, n_sd)
227
+ shapedirs_complete_prepared = torch.cat((self.sd[:, :, :10], shapedirs), axis=2).reshape((-1, 30)).permute((1, 0)) # (n_sd, 11667)
228
+ return shapedirs_complete, shapedirs_complete_prepared
229
+
230
+
231
+ class ModelRefinement(nn.Module):
232
+ def __init__(self, n_betas=10, n_betas_limbs=7, n_breeds=121, n_keyp=20, n_joints=35, ref_net_type='add', graphcnn_type='inexistent', isflat_type='inexistent', shaperef_type='inexistent'):
233
+ super(ModelRefinement, self).__init__()
234
+ self.n_betas = n_betas
235
+ self.n_betas_limbs = n_betas_limbs
236
+ self.n_breeds = n_breeds
237
+ self.n_keyp = n_keyp
238
+ self.n_joints = n_joints
239
+ self.n_out_seg = 256
240
+ self.n_out_keyp = 256
241
+ self.n_out_enc = 256
242
+ self.linear_size = 1024
243
+ self.linear_size_small = 128
244
+ self.ref_net_type = ref_net_type
245
+ self.graphcnn_type = graphcnn_type
246
+ self.isflat_type = isflat_type
247
+ self.shaperef_type = shaperef_type
248
+ p_dropout = 0.2
249
+ # --- segmentation encoder
250
+ if self.ref_net_type in ['multrot_res34', 'multrot01all_res34']:
251
+ self.ref_res = models.resnet34(pretrained=False)
252
+ else:
253
+ self.ref_res = models.resnet18(pretrained=False)
254
+ # replace the first layer
255
+ self.ref_res.conv1 = nn.Conv2d(2, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
256
+ # replace the last layer
257
+ self.ref_res.fc = nn.Linear(512, self.n_out_seg)
258
+ # softmax
259
+ self.soft_max = torch.nn.Softmax(dim=1)
260
+ # --- keypoint encoder
261
+ self.linear_keyp = LinearModel(linear_size=self.linear_size,
262
+ num_stage=1,
263
+ p_dropout=p_dropout,
264
+ input_size=n_keyp*2*2,
265
+ output_size=self.n_out_keyp)
266
+ # --- decoder
267
+ self.linear_combined = LinearModel(linear_size=self.linear_size,
268
+ num_stage=1,
269
+ p_dropout=p_dropout,
270
+ input_size=self.n_out_seg+self.n_out_keyp,
271
+ output_size=self.n_out_enc)
272
+ # output info
273
+ pose = {'name': 'pose', 'n': self.n_joints*6, 'out_shape':[self.n_joints, 6]}
274
+ trans = {'name': 'trans_notnorm', 'n': 3}
275
+ cam = {'name': 'flength_notnorm', 'n': 1}
276
+ betas = {'name': 'betas', 'n': self.n_betas}
277
+ betas_limbs = {'name': 'betas_limbs', 'n': self.n_betas_limbs}
278
+ if self.shaperef_type=='inexistent':
279
+ self.output_info = [pose, trans, cam] # , betas]
280
+ else:
281
+ self.output_info = [pose, trans, cam, betas, betas_limbs]
282
+ # output branches
283
+ self.output_info_linear_models = []
284
+ for ind_el, element in enumerate(self.output_info):
285
+ n_in = self.n_out_enc + element['n']
286
+ self.output_info_linear_models.append(LinearModel(linear_size=self.linear_size,
287
+ num_stage=1,
288
+ p_dropout=p_dropout,
289
+ input_size=n_in,
290
+ output_size=element['n']))
291
+ element['linear_model_index'] = ind_el
292
+ self.output_info_linear_models = nn.ModuleList(self.output_info_linear_models)
293
+ # new: predict if the ground is flat
294
+ if not self.isflat_type=='inexistent':
295
+ self.linear_isflat = LinearModel(linear_size=self.linear_size_small,
296
+ num_stage=1,
297
+ p_dropout=p_dropout,
298
+ input_size=self.n_out_enc,
299
+ output_size=2) # answer is just yes or no
300
+
301
+
302
+ # new for ground contact prediction: graph cnn
303
+ if not self.graphcnn_type=='inexistent':
304
+ num_downsampling = 1
305
+ smal_model_type = '39dogs_norm'
306
+ smal = SMAL(smal_model_type=smal_model_type, template_name='neutral')
307
+ ROOT_smal_downsampling = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/graphcmr/data/'
308
+ smal_downsampling_npz_name = 'mesh_downsampling_' + os.path.basename(SMAL_MODEL_CONFIG[smal_model_type]['smal_model_path']).replace('.pkl', '_template.npz')
309
+ smal_downsampling_npz_path = ROOT_smal_downsampling + smal_downsampling_npz_name # 'data/mesh_downsampling.npz'
310
+ self.my_custom_smal_dog_mesh = Mesh(filename=smal_downsampling_npz_path, num_downsampling=num_downsampling, nsize=1, body_model=smal) # , device=device)
311
+ # create GraphCNN
312
+ num_layers = 2 # <= len(my_custom_mesh._A)-1
313
+ n_resnet_out = self.n_out_enc # 256
314
+ num_channels = 256 # 512
315
+ self.graph_cnn = GraphCNNMS(mesh=self.my_custom_smal_dog_mesh,
316
+ num_downsample = num_downsampling,
317
+ num_layers = num_layers,
318
+ n_resnet_out = n_resnet_out,
319
+ num_channels = num_channels) # .to(device)
320
+
321
+
322
+
323
+ def forward(self, keyp_sh, keyp_pred, in_pose_3x3, in_trans_notnorm, in_cam_notnorm, in_betas, in_betas_limbs, seg_pred_prep=None, seg_sh_raw=None, seg_sh_prep=None):
324
+ # img is the network input image
325
+ # seg_raw is before softmax and subtracting 0.5
326
+ # seg_prep would be the prepared_segmentation
327
+ batch_size = in_pose_3x3.shape[0]
328
+ device = in_pose_3x3.device
329
+ dtype = in_pose_3x3.dtype
330
+ # --- segmentation encoder
331
+ if seg_sh_prep is None:
332
+ seg_sh_prep = self.soft_max(seg_sh_raw)[:, 1:2, :, :] - 0.5 # class 1 is the dog
333
+ input_seg_conc = torch.cat((seg_sh_prep, seg_pred_prep), axis=1)
334
+ network_output_seg = self.ref_res(input_seg_conc)
335
+ # --- keypoint encoder
336
+ keyp_conc = torch.cat((keyp_sh.reshape((-1, keyp_sh.shape[1]*keyp_sh.shape[2])), keyp_pred.reshape((-1, keyp_sh.shape[1]*keyp_sh.shape[2]))), axis=1)
337
+ network_output_keyp = self.linear_keyp(keyp_conc)
338
+ # --- decoder
339
+ x = torch.cat((network_output_seg, network_output_keyp), axis=1)
340
+ y_comb = self.linear_combined(x)
341
+ in_pose_6d = rotmat_to_rot6d(in_pose_3x3.reshape((-1, 3, 3))).reshape((in_pose_3x3.shape[0], -1, 6))
342
+ in_dict = {'pose': in_pose_6d,
343
+ 'trans_notnorm': in_trans_notnorm,
344
+ 'flength_notnorm': in_cam_notnorm,
345
+ 'betas': in_betas,
346
+ 'betas_limbs': in_betas_limbs}
347
+ results = {}
348
+ for element in self.output_info:
349
+ # import pdb; pdb.set_trace()
350
+
351
+ linear_model = self.output_info_linear_models[element['linear_model_index']]
352
+ y = torch.cat((y_comb, in_dict[element['name']].reshape((-1, element['n']))), axis=1)
353
+ if 'out_shape' in element.keys():
354
+ if element['name'] == 'pose':
355
+ if self.ref_net_type in ['multrot', 'multrot01', 'multrot01all', 'multrotxx', 'multrot_res34', 'multrot01all_res34']: # if self.ref_net_type == 'multrot' or self.ref_net_type == 'multrot_res34':
356
+ # multiply the rotations with each other -> just predict a correction
357
+ # the correction should be initialized as identity
358
+ # res_pose_out = (linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1])) + in_dict[element['name']]
359
+ identity_rot6d = torch.tensor(([1., 0., 0., 1., 0., 0.])).repeat((in_pose_3x3.shape[0]*in_pose_3x3.shape[1], 1)).to(device=device, dtype=dtype)
360
+ if self.ref_net_type in ['multrot01', 'multrot01all', 'multrot01all_res34']:
361
+ res_pose_out = identity_rot6d + 0.1*(linear_model(y)).reshape((-1, element['out_shape'][1]))
362
+ elif self.ref_net_type == 'multrotxx':
363
+ res_pose_out = identity_rot6d + 0.0*(linear_model(y)).reshape((-1, element['out_shape'][1]))
364
+ else:
365
+ res_pose_out = identity_rot6d + (linear_model(y)).reshape((-1, element['out_shape'][1]))
366
+ res_pose_rotmat = rot6d_to_rotmat(res_pose_out.reshape((-1, 6))) # (bs*35, 3, 3) .reshape((batch_size, -1, 3, 3))
367
+ res_tot_rotmat = torch.bmm(res_pose_rotmat.reshape((-1, 3, 3)), in_pose_3x3.reshape((-1, 3, 3))).reshape((batch_size, -1, 3, 3)) # (bs, 5, 3, 3)
368
+ results['pose_rotmat'] = res_tot_rotmat
369
+ elif self.ref_net_type == 'add':
370
+ res_6d = (linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1])) + in_dict['pose']
371
+ results['pose_rotmat'] = rot6d_to_rotmat(res_6d.reshape((-1, 6))).reshape((batch_size, -1, 3, 3))
372
+ else:
373
+ raise ValueError
374
+ else:
375
+ if self.ref_net_type in ['multrot01all', 'multrot01all_res34']:
376
+ results[element['name']] = (0.1*linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1])) + in_dict[element['name']]
377
+ else:
378
+ results[element['name']] = (linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1])) + in_dict[element['name']]
379
+ else:
380
+ if self.ref_net_type in ['multrot01all', 'multrot01all_res34']:
381
+ results[element['name']] = 0.1*linear_model(y) + in_dict[element['name']]
382
+ else:
383
+ results[element['name']] = linear_model(y) + in_dict[element['name']]
384
+
385
+ # add prediction if ground is flat
386
+ if not self.isflat_type=='inexistent':
387
+ isflat = self.linear_isflat(y_comb)
388
+ results['isflat'] = isflat
389
+
390
+ # add graph cnn
391
+ if not self.graphcnn_type=='inexistent':
392
+ ground_contact_downsampled, ground_cntact_all_stages_output = self.graph_cnn(y_comb)
393
+ ground_contact = self.my_custom_smal_dog_mesh.upsample(ground_contact_downsampled.transpose(1,2))
394
+ results['vertexwise_ground_contact'] = ground_contact
395
+
396
+ return results
397
+
398
+
399
+
400
+
401
+ class ModelImageToBreed(nn.Module):
402
+ def __init__(self, smal_model_type, arch='hg8', n_joints=35, n_classes=20, n_partseg=15, n_keyp=20, n_bones=24, n_betas=10, n_betas_limbs=7, n_breeds=121, image_size=256, n_z=512, thr_keyp_sc=None, add_partseg=True):
403
+ super(ModelImageToBreed, self).__init__()
404
+ self.n_classes = n_classes
405
+ self.n_partseg = n_partseg
406
+ self.n_betas = n_betas
407
+ self.n_betas_limbs = n_betas_limbs
408
+ self.n_keyp = n_keyp
409
+ self.n_bones = n_bones
410
+ self.n_breeds = n_breeds
411
+ self.image_size = image_size
412
+ self.upsample_seg = True
413
+ self.threshold_scores = thr_keyp_sc
414
+ self.n_z = n_z
415
+ self.add_partseg = add_partseg
416
+ self.smal_model_type = smal_model_type
417
+ # ------------------------------ STACKED HOUR GLASS ------------------------------
418
+ if arch == 'hg8':
419
+ self.stacked_hourglass = hg8(pretrained=False, num_classes=self.n_classes, num_partseg=self.n_partseg, upsample_seg=self.upsample_seg, add_partseg=self.add_partseg)
420
+ else:
421
+ raise Exception('unrecognised model architecture: ' + arch)
422
+ # ------------------------------ SHAPE AND BREED MODEL ------------------------------
423
+ self.breed_model = ModelShapeAndBreed(smal_model_type=self.smal_model_type, n_betas=self.n_betas, n_betas_limbs=self.n_betas_limbs, n_breeds=self.n_breeds, n_z=self.n_z)
424
+ def forward(self, input_img, norm_dict=None, bone_lengths_prepared=None, betas=None):
425
+ batch_size = input_img.shape[0]
426
+ device = input_img.device
427
+ # ------------------------------ STACKED HOUR GLASS ------------------------------
428
+ hourglass_out_dict = self.stacked_hourglass(input_img)
429
+ last_seg = hourglass_out_dict['seg_final']
430
+ last_heatmap = hourglass_out_dict['out_list_kp'][-1]
431
+ # - prepare keypoints (from heatmap)
432
+ # normalize predictions -> from logits to probability distribution
433
+ # last_heatmap_norm = dsnt.spatial_softmax2d(last_heatmap, temperature=torch.tensor(1))
434
+ # keypoints = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=False) + 1 # (bs, 20, 2)
435
+ # keypoints_norm = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=True) # (bs, 20, 2)
436
+ keypoints_norm, scores = get_preds_soft(last_heatmap, return_maxval=True, norm_coords=True)
437
+ if self.threshold_scores is not None:
438
+ scores[scores>self.threshold_scores] = 1.0
439
+ scores[scores<=self.threshold_scores] = 0.0
440
+ # ------------------------------ SHAPE AND BREED MODEL ------------------------------
441
+ # breed_model takes as input the image as well as the predicted segmentation map
442
+ # -> we need to split up ModelImageTo3d, such that we can use the silhouette
443
+ resnet_output = self.breed_model(img=input_img, seg_raw=last_seg)
444
+ pred_breed = resnet_output['breeds'] # (bs, n_breeds)
445
+ pred_betas = resnet_output['betas']
446
+ pred_betas_limbs = resnet_output['betas_limbs']
447
+ small_output = {'keypoints_norm': keypoints_norm,
448
+ 'keypoints_scores': scores}
449
+ small_output_reproj = {'betas': pred_betas,
450
+ 'betas_limbs': pred_betas_limbs,
451
+ 'dog_breed': pred_breed}
452
+ return small_output, None, small_output_reproj
453
+
454
+ class ModelImageTo3d_withshape_withproj(nn.Module):
455
+ def __init__(self, smal_model_type, smal_keyp_conf=None, arch='hg8', num_stage_comb=2, num_stage_heads=1, num_stage_heads_pose=1, trans_sep=False, n_joints=35, n_classes=20, n_partseg=15, n_keyp=20, n_bones=24, n_betas=10, n_betas_limbs=6, n_breeds=121, image_size=256, n_z=512, n_segbps=64*2, thr_keyp_sc=None, add_z_to_3d_input=True, add_segbps_to_3d_input=False, add_partseg=True, silh_no_tail=True, fix_flength=False, render_partseg=False, structure_z_to_betas='default', structure_pose_net='default', nf_version=None, ref_net_type='add', ref_detach_shape=True, graphcnn_type='inexistent', isflat_type='inexistent', shaperef_type='inexistent'):
456
+ super(ModelImageTo3d_withshape_withproj, self).__init__()
457
+ self.n_classes = n_classes
458
+ self.n_partseg = n_partseg
459
+ self.n_betas = n_betas
460
+ self.n_betas_limbs = n_betas_limbs
461
+ self.n_keyp = n_keyp
462
+ self.n_joints = n_joints
463
+ self.n_bones = n_bones
464
+ self.n_breeds = n_breeds
465
+ self.image_size = image_size
466
+ self.threshold_scores = thr_keyp_sc
467
+ self.upsample_seg = True
468
+ self.silh_no_tail = silh_no_tail
469
+ self.add_z_to_3d_input = add_z_to_3d_input
470
+ self.add_segbps_to_3d_input = add_segbps_to_3d_input
471
+ self.add_partseg = add_partseg
472
+ self.ref_net_type = ref_net_type
473
+ self.ref_detach_shape = ref_detach_shape
474
+ self.graphcnn_type = graphcnn_type
475
+ self.isflat_type = isflat_type
476
+ self.shaperef_type = shaperef_type
477
+ assert (not self.add_segbps_to_3d_input) or (not self.add_z_to_3d_input)
478
+ self.n_z = n_z
479
+ if add_segbps_to_3d_input:
480
+ self.n_segbps = n_segbps # 64
481
+ self.segbps_model = SegBPS()
482
+ else:
483
+ self.n_segbps = 0
484
+ self.fix_flength = fix_flength
485
+ self.render_partseg = render_partseg
486
+ self.structure_z_to_betas = structure_z_to_betas
487
+ self.structure_pose_net = structure_pose_net
488
+ assert self.structure_pose_net in ['default', 'vae', 'normflow']
489
+ self.nf_version = nf_version
490
+ self.smal_model_type = smal_model_type
491
+ assert (smal_keyp_conf is not None)
492
+ self.smal_keyp_conf = smal_keyp_conf
493
+ self.register_buffer('betas_zeros', torch.zeros((1, self.n_betas)))
494
+ self.register_buffer('mean_dog_bone_lengths', torch.tensor(MEAN_DOG_BONE_LENGTHS_NO_RED, dtype=torch.float32))
495
+ p_dropout = 0.2 # 0.5
496
+ # ------------------------------ SMAL MODEL ------------------------------
497
+ self.smal = SMAL(smal_model_type=self.smal_model_type, template_name='neutral')
498
+ print('SMAL model type: ' + self.smal.smal_model_type)
499
+ # New for rendering without tail
500
+ f_np = self.smal.faces.detach().cpu().numpy()
501
+ self.f_no_tail_np = f_np[np.isin(f_np[:,:], VERTEX_IDS_TAIL).sum(axis=1)==0, :]
502
+ # in theory we could optimize for improved shapedirs, but we do not do that
503
+ # -> would need to implement regularizations
504
+ # -> there are better ways than changing the shapedirs
505
+ self.model_learnable_shapedirs = LearnableShapedirs(self.smal.sym_ids_dict, self.smal.shapedirs.detach(), self.n_betas, 10)
506
+ # ------------------------------ STACKED HOUR GLASS ------------------------------
507
+ if arch == 'hg8':
508
+ self.stacked_hourglass = hg8(pretrained=False, num_classes=self.n_classes, num_partseg=self.n_partseg, upsample_seg=self.upsample_seg, add_partseg=self.add_partseg)
509
+ else:
510
+ raise Exception('unrecognised model architecture: ' + arch)
511
+ # ------------------------------ SHAPE AND BREED MODEL ------------------------------
512
+ self.breed_model = ModelShapeAndBreed(self.smal_model_type, n_betas=self.n_betas, n_betas_limbs=self.n_betas_limbs, n_breeds=self.n_breeds, n_z=self.n_z, structure_z_to_betas=self.structure_z_to_betas)
513
+ # ------------------------------ LINEAR 3D MODEL ------------------------------
514
+ # 3d model -> from image to 3d parameters {2d keypoints from heatmap, pose, trans, flength}
515
+ self.soft_max = torch.nn.Softmax(dim=1)
516
+ input_size = self.n_keyp*3 + self.n_bones
517
+ self.model_3d = LinearModelComplete(linear_size=1024,
518
+ num_stage_comb=num_stage_comb,
519
+ num_stage_heads=num_stage_heads,
520
+ num_stage_heads_pose=num_stage_heads_pose,
521
+ trans_sep=trans_sep,
522
+ p_dropout=p_dropout, # 0.5,
523
+ input_size=input_size,
524
+ intermediate_size=1024,
525
+ output_info=None,
526
+ n_joints=self.n_joints,
527
+ n_z=self.n_z,
528
+ add_z_to_3d_input=self.add_z_to_3d_input,
529
+ n_segbps=self.n_segbps,
530
+ add_segbps_to_3d_input=self.add_segbps_to_3d_input,
531
+ structure_pose_net=self.structure_pose_net,
532
+ nf_version = self.nf_version)
533
+ # ------------------------------ RENDERING ------------------------------
534
+ self.silh_renderer = SilhRenderer(image_size)
535
+ # ------------------------------ REFINEMENT -----------------------------
536
+ self.refinement_model = ModelRefinement(n_betas=self.n_betas, n_betas_limbs=self.n_betas_limbs, n_breeds=self.n_breeds, n_keyp=self.n_keyp, n_joints=self.n_joints, ref_net_type=self.ref_net_type, graphcnn_type=self.graphcnn_type, isflat_type=self.isflat_type, shaperef_type=self.shaperef_type)
537
+
538
+
539
+ def forward(self, input_img, norm_dict=None, bone_lengths_prepared=None, betas=None):
540
+ batch_size = input_img.shape[0]
541
+ device = input_img.device
542
+ # ------------------------------ STACKED HOUR GLASS ------------------------------
543
+ hourglass_out_dict = self.stacked_hourglass(input_img)
544
+ last_seg = hourglass_out_dict['seg_final']
545
+ last_heatmap = hourglass_out_dict['out_list_kp'][-1]
546
+ # - prepare keypoints (from heatmap)
547
+ # normalize predictions -> from logits to probability distribution
548
+ # last_heatmap_norm = dsnt.spatial_softmax2d(last_heatmap, temperature=torch.tensor(1))
549
+ # keypoints = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=False) + 1 # (bs, 20, 2)
550
+ # keypoints_norm = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=True) # (bs, 20, 2)
551
+ keypoints_norm, scores = get_preds_soft(last_heatmap, return_maxval=True, norm_coords=True)
552
+ if self.threshold_scores is not None:
553
+ scores[scores>self.threshold_scores] = 1.0
554
+ scores[scores<=self.threshold_scores] = 0.0
555
+ # ------------------------------ LEARNABLE SHAPE MODEL ------------------------------
556
+ # in our cvpr 2022 paper we do not change the shapedirs
557
+ # learnable_sd_complete has shape (3889, 3, n_sd)
558
+ # learnable_sd_complete_prepared has shape (n_sd, 11667)
559
+ learnable_sd_complete, learnable_sd_complete_prepared = self.model_learnable_shapedirs()
560
+ shapedirs_sel = learnable_sd_complete_prepared # None
561
+ # ------------------------------ SHAPE AND BREED MODEL ------------------------------
562
+ # breed_model takes as input the image as well as the predicted segmentation map
563
+ # -> we need to split up ModelImageTo3d, such that we can use the silhouette
564
+ resnet_output = self.breed_model(img=input_img, seg_raw=last_seg)
565
+ pred_breed = resnet_output['breeds'] # (bs, n_breeds)
566
+ pred_z = resnet_output['z']
567
+ # - prepare shape
568
+ pred_betas = resnet_output['betas']
569
+ pred_betas_limbs = resnet_output['betas_limbs']
570
+ # - calculate bone lengths
571
+ with torch.no_grad():
572
+ use_mean_bone_lengths = False
573
+ if use_mean_bone_lengths:
574
+ bone_lengths_prepared = torch.cat(batch_size*[self.mean_dog_bone_lengths.reshape((1, -1))])
575
+ else:
576
+ assert (bone_lengths_prepared is None)
577
+ bone_lengths_prepared = self.smal.caclulate_bone_lengths(pred_betas, pred_betas_limbs, shapedirs_sel=shapedirs_sel, short=True)
578
+ # ------------------------------ LINEAR 3D MODEL ------------------------------
579
+ # 3d model -> from image to 3d parameters {2d keypoints from heatmap, pose, trans, flength}
580
+ # prepare input for 2d-to-3d network
581
+ keypoints_prepared = torch.cat((keypoints_norm, scores), axis=2)
582
+ if bone_lengths_prepared is None:
583
+ bone_lengths_prepared = torch.cat(batch_size*[self.mean_dog_bone_lengths.reshape((1, -1))])
584
+ # should we add silhouette to 3d input? should we add z?
585
+ if self.add_segbps_to_3d_input:
586
+ seg_raw = last_seg
587
+ seg_prep_bps = self.soft_max(seg_raw)[:, 1, :, :] # class 1 is the dog
588
+ with torch.no_grad():
589
+ seg_prep_np = seg_prep_bps.detach().cpu().numpy()
590
+ bps_output_np = self.segbps_model.calculate_bps_points_batch(seg_prep_np) # (bs, 64, 2)
591
+ bps_output = torch.tensor(bps_output_np, dtype=torch.float32).to(device).reshape((batch_size, -1))
592
+ bps_output_prep = bps_output * 2. - 1
593
+ input_vec_keyp_bones = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1)
594
+ input_vec = torch.cat((input_vec_keyp_bones, bps_output_prep), dim=1)
595
+ elif self.add_z_to_3d_input:
596
+ # we do not use this in our cvpr 2022 version
597
+ input_vec_keyp_bones = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1)
598
+ input_vec_additional = pred_z
599
+ input_vec = torch.cat((input_vec_keyp_bones, input_vec_additional), dim=1)
600
+ else:
601
+ input_vec = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1)
602
+ # predict 3d parameters (those are normalized, we need to correct mean and std in a next step)
603
+ output = self.model_3d(input_vec)
604
+ # add predicted keypoints to the output dict
605
+ output['keypoints_norm'] = keypoints_norm
606
+ output['keypoints_scores'] = scores
607
+ # add predicted segmentation to output dictc
608
+ output['seg_hg'] = hourglass_out_dict['seg_final']
609
+ # - denormalize 3d parameters -> so far predictions were normalized, now we denormalize them again
610
+ pred_trans = output['trans'] * norm_dict['trans_std'][None, :] + norm_dict['trans_mean'][None, :] # (bs, 3)
611
+ if self.structure_pose_net == 'default':
612
+ pred_pose_rot6d = output['pose'] + norm_dict['pose_rot6d_mean'][None, :]
613
+ elif self.structure_pose_net == 'normflow':
614
+ pose_rot6d_mean_zeros = torch.zeros_like(norm_dict['pose_rot6d_mean'][None, :])
615
+ pose_rot6d_mean_zeros[:, 0, :] = norm_dict['pose_rot6d_mean'][None, 0, :]
616
+ pred_pose_rot6d = output['pose'] + pose_rot6d_mean_zeros
617
+ else:
618
+ pose_rot6d_mean_zeros = torch.zeros_like(norm_dict['pose_rot6d_mean'][None, :])
619
+ pose_rot6d_mean_zeros[:, 0, :] = norm_dict['pose_rot6d_mean'][None, 0, :]
620
+ pred_pose_rot6d = output['pose'] + pose_rot6d_mean_zeros
621
+ pred_pose_reshx33 = rot6d_to_rotmat(pred_pose_rot6d.reshape((-1, 6)))
622
+ pred_pose = pred_pose_reshx33.reshape((batch_size, -1, 3, 3))
623
+ pred_pose_rot6d = rotmat_to_rot6d(pred_pose_reshx33).reshape((batch_size, -1, 6))
624
+
625
+ if self.fix_flength:
626
+ output['flength'] = torch.zeros_like(output['flength'])
627
+ pred_flength = torch.ones_like(output['flength'])*2100 # norm_dict['flength_mean'][None, :]
628
+ else:
629
+ pred_flength_orig = output['flength'] * norm_dict['flength_std'][None, :] + norm_dict['flength_mean'][None, :] # (bs, 1)
630
+ pred_flength = pred_flength_orig.clone() # torch.abs(pred_flength_orig)
631
+ pred_flength[pred_flength_orig<=0] = norm_dict['flength_mean'][None, :]
632
+
633
+ # ------------------------------ RENDERING ------------------------------
634
+ # get 3d model (SMAL)
635
+ V, keyp_green_3d, _ = self.smal(beta=pred_betas, betas_limbs=pred_betas_limbs, pose=pred_pose, trans=pred_trans, get_skin=True, keyp_conf=self.smal_keyp_conf, shapedirs_sel=shapedirs_sel)
636
+ keyp_3d = keyp_green_3d[:, :self.n_keyp, :] # (bs, 20, 3)
637
+ # render silhouette
638
+ faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1))
639
+ if not self.silh_no_tail:
640
+ pred_silh_images, pred_keyp = self.silh_renderer(vertices=V,
641
+ points=keyp_3d, faces=faces_prep, focal_lengths=pred_flength)
642
+ else:
643
+ faces_no_tail_prep = torch.tensor(self.f_no_tail_np).to(device).expand((batch_size, -1, -1))
644
+ pred_silh_images, pred_keyp = self.silh_renderer(vertices=V,
645
+ points=keyp_3d, faces=faces_no_tail_prep, focal_lengths=pred_flength)
646
+ # get torch 'Meshes'
647
+ torch_meshes = self.silh_renderer.get_torch_meshes(vertices=V, faces=faces_prep)
648
+
649
+ # render body parts (not part of cvpr 2022 version)
650
+ if self.render_partseg:
651
+ raise NotImplementedError
652
+ else:
653
+ partseg_images = None
654
+ partseg_images_hg = None
655
+
656
+
657
+ # ------------------------------ REFINEMENT MODEL ------------------------------
658
+
659
+ # refinement model
660
+ pred_keyp_norm = (pred_keyp.detach() / (self.image_size - 1) - 0.5)*2
661
+ '''output_ref = self.refinement_model(keypoints_norm.detach(), pred_keyp_norm, \
662
+ seg_sh_raw=last_seg[:, :, :, :].detach(), seg_pred_prep=pred_silh_images[:, :, :, :].detach()-0.5, \
663
+ in_pose=output['pose'].detach(), in_trans=output['trans'].detach(), in_cam=output['flength'].detach(), in_betas=pred_betas.detach())'''
664
+ output_ref = self.refinement_model(keypoints_norm.detach(), pred_keyp_norm, \
665
+ seg_sh_raw=last_seg[:, :, :, :].detach(), seg_pred_prep=pred_silh_images[:, :, :, :].detach()-0.5, \
666
+ in_pose_3x3=pred_pose.detach(), in_trans_notnorm=output['trans'].detach(), in_cam_notnorm=output['flength'].detach(), in_betas=pred_betas.detach(), in_betas_limbs=pred_betas_limbs.detach())
667
+ # a better alternative would be to submit pred_pose_reshx33
668
+
669
+
670
+
671
+ # nothing changes for betas or shapedirs or z ##################### should probably not be detached in the end
672
+ if self.shaperef_type == 'inexistent':
673
+ if self.ref_detach_shape:
674
+ output_ref['betas'] = pred_betas.detach()
675
+ output_ref['betas_limbs'] = pred_betas_limbs.detach()
676
+ output_ref['z'] = pred_z.detach()
677
+ output_ref['shapedirs'] = shapedirs_sel.detach()
678
+ else:
679
+ output_ref['betas'] = pred_betas
680
+ output_ref['betas_limbs'] = pred_betas_limbs
681
+ output_ref['z'] = pred_z
682
+ output_ref['shapedirs'] = shapedirs_sel
683
+ else:
684
+ assert ('betas' in output_ref.keys())
685
+ assert ('betas_limbs' in output_ref.keys())
686
+ output_ref['shapedirs'] = shapedirs_sel
687
+
688
+
689
+ # we denormalize flength and trans, but pose is handled differently
690
+ if self.fix_flength:
691
+ output_ref['flength_notnorm'] = torch.zeros_like(output['flength'])
692
+ ref_pred_flength = torch.ones_like(output['flength_notnorm'])*2100 # norm_dict['flength_mean'][None, :]
693
+ raise ValueError # not sure if we want to have a fixed flength in refinement
694
+ else:
695
+ ref_pred_flength_orig = output_ref['flength_notnorm'] * norm_dict['flength_std'][None, :] + norm_dict['flength_mean'][None, :] # (bs, 1)
696
+ ref_pred_flength = ref_pred_flength_orig.clone() # torch.abs(pred_flength_orig)
697
+ ref_pred_flength[ref_pred_flength_orig<=0] = norm_dict['flength_mean'][None, :]
698
+ ref_pred_trans = output_ref['trans_notnorm'] * norm_dict['trans_std'][None, :] + norm_dict['trans_mean'][None, :] # (bs, 3)
699
+
700
+
701
+
702
+
703
+ # ref_pred_pose_rot6d = output_ref['pose']
704
+ # ref_pred_pose_reshx33 = rot6d_to_rotmat(output_ref['pose'].reshape((-1, 6))).reshape((batch_size, -1, 3, 3))
705
+ ref_pred_pose_reshx33 = output_ref['pose_rotmat'].reshape((batch_size, -1, 3, 3))
706
+ ref_pred_pose_rot6d = rotmat_to_rot6d(ref_pred_pose_reshx33.reshape((-1, 3, 3))).reshape((batch_size, -1, 6))
707
+
708
+ ref_V, ref_keyp_green_3d, _ = self.smal(beta=output_ref['betas'], betas_limbs=output_ref['betas_limbs'],
709
+ pose=ref_pred_pose_reshx33, trans=ref_pred_trans, get_skin=True, keyp_conf=self.smal_keyp_conf,
710
+ shapedirs_sel=output_ref['shapedirs'])
711
+ ref_keyp_3d = ref_keyp_green_3d[:, :self.n_keyp, :] # (bs, 20, 3)
712
+
713
+ if not self.silh_no_tail:
714
+ faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1))
715
+ ref_pred_silh_images, ref_pred_keyp = self.silh_renderer(vertices=ref_V,
716
+ points=ref_keyp_3d, faces=faces_prep, focal_lengths=ref_pred_flength)
717
+ else:
718
+ faces_no_tail_prep = torch.tensor(self.f_no_tail_np).to(device).expand((batch_size, -1, -1))
719
+ ref_pred_silh_images, ref_pred_keyp = self.silh_renderer(vertices=ref_V,
720
+ points=ref_keyp_3d, faces=faces_no_tail_prep, focal_lengths=ref_pred_flength)
721
+
722
+ output_ref_unnorm = {'vertices_smal': ref_V,
723
+ 'keyp_3d': ref_keyp_3d,
724
+ 'keyp_2d': ref_pred_keyp,
725
+ 'silh': ref_pred_silh_images,
726
+ 'trans': ref_pred_trans,
727
+ 'flength': ref_pred_flength,
728
+ 'betas': output_ref['betas'],
729
+ 'betas_limbs': output_ref['betas_limbs'],
730
+ # 'z': output_ref['z'],
731
+ 'pose_rot6d': ref_pred_pose_rot6d,
732
+ 'pose_rotmat': ref_pred_pose_reshx33}
733
+ # 'shapedirs': shapedirs_sel}
734
+
735
+ if not self.graphcnn_type == 'inexistent':
736
+ output_ref_unnorm['vertexwise_ground_contact'] = output_ref['vertexwise_ground_contact']
737
+ if not self.isflat_type=='inexistent':
738
+ output_ref_unnorm['isflat'] = output_ref['isflat']
739
+ if self.shaperef_type == 'inexistent':
740
+ output_ref_unnorm['z'] = output_ref['z']
741
+
742
+ # REMARK: we will want to have the predicted differences, for pose this would
743
+ # be a rotation matrix, ...
744
+ # -> TODO: adjust output_orig_ref_comparison
745
+ output_orig_ref_comparison = {#'pose': output['pose'].detach(),
746
+ #'trans': output['trans'].detach(),
747
+ #'flength': output['flength'].detach(),
748
+ # 'pose': output['pose'],
749
+ 'old_pose_rotmat': pred_pose_reshx33,
750
+ 'old_trans_notnorm': output['trans'],
751
+ 'old_flength_notnorm': output['flength'],
752
+ # 'ref_pose': output_ref['pose'],
753
+ 'ref_pose_rotmat': ref_pred_pose_reshx33,
754
+ 'ref_trans_notnorm': output_ref['trans_notnorm'],
755
+ 'ref_flength_notnorm': output_ref['flength_notnorm']}
756
+
757
+
758
+
759
+ # ------------------------------ PREPARE OUTPUT ------------------------------
760
+ # create output dictionarys
761
+ # output: contains all output from model_image_to_3d
762
+ # output_unnorm: same as output, but normalizations are undone
763
+ # output_reproj: smal output and reprojected keypoints as well as silhouette
764
+ keypoints_heatmap_256 = (output['keypoints_norm'] / 2. + 0.5) * (self.image_size - 1)
765
+ output_unnorm = {'pose_rotmat': pred_pose,
766
+ 'flength': pred_flength,
767
+ 'trans': pred_trans,
768
+ 'keypoints':keypoints_heatmap_256}
769
+ output_reproj = {'vertices_smal': V,
770
+ 'torch_meshes': torch_meshes,
771
+ 'keyp_3d': keyp_3d,
772
+ 'keyp_2d': pred_keyp,
773
+ 'silh': pred_silh_images,
774
+ 'betas': pred_betas,
775
+ 'betas_limbs': pred_betas_limbs,
776
+ 'pose_rot6d': pred_pose_rot6d, # used for pose prior...
777
+ 'dog_breed': pred_breed,
778
+ 'shapedirs': shapedirs_sel,
779
+ 'z': pred_z,
780
+ 'flength_unnorm': pred_flength,
781
+ 'flength': output['flength'],
782
+ 'partseg_images_rend': partseg_images,
783
+ 'partseg_images_hg_nograd': partseg_images_hg,
784
+ 'normflow_z': output['normflow_z']}
785
+
786
+ return output, output_unnorm, output_reproj, output_ref_unnorm, output_orig_ref_comparison
787
+
788
+
789
+ def forward_with_multiple_refinements(self, input_img, norm_dict=None, bone_lengths_prepared=None, betas=None):
790
+
791
+ # import pdb; pdb.set_trace()
792
+
793
+ # run normal network part
794
+ output, output_unnorm, output_reproj, output_ref_unnorm, output_orig_ref_comparison = self.forward(input_img, norm_dict=norm_dict, bone_lengths_prepared=bone_lengths_prepared, betas=betas)
795
+
796
+ # prepare input for second refinement stage
797
+ batch_size = output['keypoints_norm'].shape[0]
798
+ keypoints_norm = output['keypoints_norm']
799
+ pred_keyp_norm = (output_ref_unnorm['keyp_2d'].detach() / (self.image_size - 1) - 0.5)*2
800
+
801
+ last_seg = output['seg_hg']
802
+ pred_silh_images = output_ref_unnorm['silh'].detach()
803
+
804
+ trans_notnorm = output_orig_ref_comparison['ref_trans_notnorm']
805
+ flength_notnorm = output_orig_ref_comparison['ref_flength_notnorm']
806
+ # trans_notnorm = output_orig_ref_comparison['ref_pose_rotmat']
807
+ pred_pose = output_ref_unnorm['pose_rotmat'].reshape((batch_size, -1, 3, 3))
808
+
809
+ # run second refinement step
810
+ output_ref_new = self.refinement_model(keypoints_norm.detach(), pred_keyp_norm, \
811
+ seg_sh_raw=last_seg[:, :, :, :].detach(), seg_pred_prep=pred_silh_images[:, :, :, :].detach()-0.5, \
812
+ in_pose_3x3=pred_pose.detach(), in_trans_notnorm=trans_notnorm.detach(), in_cam_notnorm=flength_notnorm.detach(), \
813
+ in_betas=output_ref_unnorm['betas'].detach(), in_betas_limbs=output_ref_unnorm['betas_limbs'].detach())
814
+ # output_ref_new = self.refinement_model(keypoints_norm.detach(), pred_keyp_norm, seg_sh_raw=last_seg[:, :, :, :].detach(), seg_pred_prep=pred_silh_images[:, :, :, :].detach()-0.5, in_pose_3x3=pred_pose.detach(), in_trans_notnorm=trans_notnorm.detach(), in_cam_notnorm=flength_notnorm.detach(), in_betas=output_ref_unnorm['betas'].detach(), in_betas_limbs=output_ref_unnorm['betas_limbs'].detach())
815
+
816
+
817
+ # new shape
818
+ if self.shaperef_type == 'inexistent':
819
+ if self.ref_detach_shape:
820
+ output_ref_new['betas'] = output_ref_unnorm['betas'].detach()
821
+ output_ref_new['betas_limbs'] = output_ref_unnorm['betas_limbs'].detach()
822
+ output_ref_new['z'] = output_ref_unnorm['z'].detach()
823
+ output_ref_new['shapedirs'] = output_reproj['shapedirs'].detach()
824
+ else:
825
+ output_ref_new['betas'] = output_ref_unnorm['betas']
826
+ output_ref_new['betas_limbs'] = output_ref_unnorm['betas_limbs']
827
+ output_ref_new['z'] = output_ref_unnorm['z']
828
+ output_ref_new['shapedirs'] = output_reproj['shapedirs']
829
+ else:
830
+ assert ('betas' in output_ref_new.keys())
831
+ assert ('betas_limbs' in output_ref_new.keys())
832
+ output_ref_new['shapedirs'] = output_reproj['shapedirs']
833
+
834
+ # we denormalize flength and trans, but pose is handled differently
835
+ if self.fix_flength:
836
+ raise ValueError # not sure if we want to have a fixed flength in refinement
837
+ else:
838
+ ref_pred_flength_orig = output_ref_new['flength_notnorm'] * norm_dict['flength_std'][None, :] + norm_dict['flength_mean'][None, :] # (bs, 1)
839
+ ref_pred_flength = ref_pred_flength_orig.clone() # torch.abs(pred_flength_orig)
840
+ ref_pred_flength[ref_pred_flength_orig<=0] = norm_dict['flength_mean'][None, :]
841
+ ref_pred_trans = output_ref_new['trans_notnorm'] * norm_dict['trans_std'][None, :] + norm_dict['trans_mean'][None, :] # (bs, 3)
842
+
843
+
844
+ ref_pred_pose_reshx33 = output_ref_new['pose_rotmat'].reshape((batch_size, -1, 3, 3))
845
+ ref_pred_pose_rot6d = rotmat_to_rot6d(ref_pred_pose_reshx33.reshape((-1, 3, 3))).reshape((batch_size, -1, 6))
846
+
847
+ ref_V, ref_keyp_green_3d, _ = self.smal(beta=output_ref_new['betas'], betas_limbs=output_ref_new['betas_limbs'],
848
+ pose=ref_pred_pose_reshx33, trans=ref_pred_trans, get_skin=True, keyp_conf=self.smal_keyp_conf,
849
+ shapedirs_sel=output_ref_new['shapedirs'])
850
+
851
+ # ref_V, ref_keyp_green_3d, _ = self.smal(beta=output_ref_new['betas'], betas_limbs=output_ref_new['betas_limbs'], pose=ref_pred_pose_reshx33, trans=ref_pred_trans, get_skin=True, keyp_conf=self.smal_keyp_conf, shapedirs_sel=output_ref_new['shapedirs'])
852
+ ref_keyp_3d = ref_keyp_green_3d[:, :self.n_keyp, :] # (bs, 20, 3)
853
+
854
+ if not self.silh_no_tail:
855
+ faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1))
856
+ ref_pred_silh_images, ref_pred_keyp = self.silh_renderer(vertices=ref_V,
857
+ points=ref_keyp_3d, faces=faces_prep, focal_lengths=ref_pred_flength)
858
+ else:
859
+ faces_no_tail_prep = torch.tensor(self.f_no_tail_np).to(device).expand((batch_size, -1, -1))
860
+ ref_pred_silh_images, ref_pred_keyp = self.silh_renderer(vertices=ref_V,
861
+ points=ref_keyp_3d, faces=faces_no_tail_prep, focal_lengths=ref_pred_flength)
862
+
863
+ output_ref_unnorm_new = {'vertices_smal': ref_V,
864
+ 'keyp_3d': ref_keyp_3d,
865
+ 'keyp_2d': ref_pred_keyp,
866
+ 'silh': ref_pred_silh_images,
867
+ 'trans': ref_pred_trans,
868
+ 'flength': ref_pred_flength,
869
+ 'betas': output_ref_new['betas'],
870
+ 'betas_limbs': output_ref_new['betas_limbs'],
871
+ 'pose_rot6d': ref_pred_pose_rot6d,
872
+ 'pose_rotmat': ref_pred_pose_reshx33}
873
+
874
+ if not self.graphcnn_type == 'inexistent':
875
+ output_ref_unnorm_new['vertexwise_ground_contact'] = output_ref_new['vertexwise_ground_contact']
876
+ if not self.isflat_type=='inexistent':
877
+ output_ref_unnorm_new['isflat'] = output_ref_new['isflat']
878
+ if self.shaperef_type == 'inexistent':
879
+ output_ref_unnorm_new['z'] = output_ref_new['z']
880
+
881
+ output_orig_ref_comparison_new = {'ref_pose_rotmat': ref_pred_pose_reshx33,
882
+ 'ref_trans_notnorm': output_ref_new['trans_notnorm'],
883
+ 'ref_flength_notnorm': output_ref_new['flength_notnorm']}
884
+
885
+ results = {
886
+ 'output': output,
887
+ 'output_unnorm': output_unnorm,
888
+ 'output_reproj':output_reproj,
889
+ 'output_ref_unnorm': output_ref_unnorm,
890
+ 'output_orig_ref_comparison':output_orig_ref_comparison,
891
+ 'output_ref_unnorm_new': output_ref_unnorm_new,
892
+ 'output_orig_ref_comparison_new': output_orig_ref_comparison_new}
893
+ return results
894
+
895
+
896
+
897
+
898
+
899
+
900
+
901
+
902
+
903
+
904
+
905
+
906
+
907
+
908
+
909
+
910
+
911
+
912
+
913
+
914
+
915
+ def render_vis_nograd(self, vertices, focal_lengths, color=0):
916
+ # this function is for visualization only
917
+ # vertices: (bs, n_verts, 3)
918
+ # focal_lengths: (bs, 1)
919
+ # color: integer, either 0 or 1
920
+ # returns a torch tensor of shape (bs, image_size, image_size, 3)
921
+ with torch.no_grad():
922
+ batch_size = vertices.shape[0]
923
+ faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1))
924
+ visualizations = self.silh_renderer.get_visualization_nograd(vertices,
925
+ faces_prep, focal_lengths, color=color)
926
+ return visualizations
927
+
src/combined_model/train_main_image_to_3d_wbr_withref.py ADDED
@@ -0,0 +1,955 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.backends.cudnn
5
+ import torch.nn.parallel
6
+ from tqdm import tqdm
7
+ import os
8
+ import pathlib
9
+ from matplotlib import pyplot as plt
10
+ import cv2
11
+ import numpy as np
12
+ import torch
13
+ import trimesh
14
+ import pickle as pkl
15
+ import csv
16
+ from scipy.spatial.transform import Rotation as R_sc
17
+
18
+
19
+ import sys
20
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
21
+ from stacked_hourglass.utils.evaluation import accuracy, AverageMeter, final_preds, get_preds, get_preds_soft
22
+ from stacked_hourglass.utils.visualization import save_input_image_with_keypoints, save_input_image
23
+ from metrics.metrics import Metrics
24
+ from configs.SMAL_configs import EVAL_KEYPOINTS, KEYPOINT_GROUPS, SMAL_KEYPOINT_NAMES_FOR_3D_EVAL, SMAL_KEYPOINT_INDICES_FOR_3D_EVAL, SMAL_KEYPOINT_WHICHTOUSE_FOR_3D_EVAL
25
+ from combined_model.helper import eval_save_visualizations_and_meshes, eval_prepare_pck_and_iou, eval_add_preds_to_summary
26
+
27
+ from smal_pytorch.smal_model.smal_torch_new import SMAL # for gc visualization
28
+ from src.combined_model.loss_utils.loss_utils import fit_plane
29
+ # from src.evaluation.sketchfab_evaluation.alignment_utils.calculate_v2v_error_release import compute_similarity_transform
30
+ # from src.evaluation.sketchfab_evaluation.alignment_utils.calculate_alignment_error import calculate_alignemnt_errors
31
+
32
+ # ---------------------------------------------------------------------------------------------------------------------------
33
+ def do_training_epoch(train_loader, model, loss_module, loss_module_ref, device, data_info, optimiser, quiet=False, acc_joints=None, weight_dict=None, weight_dict_ref=None):
34
+ losses = AverageMeter()
35
+ losses_keyp = AverageMeter()
36
+ losses_silh = AverageMeter()
37
+ losses_shape = AverageMeter()
38
+ losses_pose = AverageMeter()
39
+ losses_class = AverageMeter()
40
+ losses_breed = AverageMeter()
41
+ losses_partseg = AverageMeter()
42
+ losses_ref_keyp = AverageMeter()
43
+ losses_ref_silh = AverageMeter()
44
+ losses_ref_pose = AverageMeter()
45
+ losses_ref_reg = AverageMeter()
46
+ accuracies = AverageMeter()
47
+ # Put the model in training mode.
48
+ model.train()
49
+ # prepare progress bar
50
+ iterable = enumerate(train_loader)
51
+ progress = None
52
+ if not quiet:
53
+ progress = tqdm(iterable, desc='Train', total=len(train_loader), ascii=True, leave=False)
54
+ iterable = progress
55
+ # information for normalization
56
+ norm_dict = {
57
+ 'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device),
58
+ 'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device),
59
+ 'trans_std': torch.from_numpy(data_info.trans_std).float().to(device),
60
+ 'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device),
61
+ 'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)}
62
+ # prepare variables, put them on the right device
63
+ for i, (input, target_dict) in iterable:
64
+ batch_size = input.shape[0]
65
+ for key in target_dict.keys():
66
+ if key == 'breed_index':
67
+ target_dict[key] = target_dict[key].long().to(device)
68
+ elif key in ['index', 'pts', 'tpts', 'target_weight', 'silh', 'silh_distmat_tofg', 'silh_distmat_tobg', 'sim_breed_index', 'img_border_mask']:
69
+ target_dict[key] = target_dict[key].float().to(device)
70
+ elif key in ['has_seg', 'gc']:
71
+ target_dict[key] = target_dict[key].to(device)
72
+ else:
73
+ pass
74
+ input = input.float().to(device)
75
+
76
+ # ----------------------- do training step -----------------------
77
+ assert model.training, 'model must be in training mode.'
78
+ with torch.enable_grad():
79
+ # ----- forward pass -----
80
+ output, output_unnorm, output_reproj, output_ref, output_ref_comp = model(input, norm_dict=norm_dict)
81
+ # ----- loss -----
82
+ # --- from main network
83
+ loss, loss_dict = loss_module(output_reproj=output_reproj,
84
+ target_dict=target_dict,
85
+ weight_dict=weight_dict)
86
+ # ---from refinement network
87
+ loss_ref, loss_dict_ref = loss_module_ref(output_ref=output_ref,
88
+ output_ref_comp=output_ref_comp,
89
+ target_dict=target_dict,
90
+ weight_dict_ref=weight_dict_ref)
91
+ loss_total = loss + loss_ref
92
+ # ----- backward pass and parameter update -----
93
+ optimiser.zero_grad()
94
+ loss_total.backward()
95
+ optimiser.step()
96
+ # ----------------------------------------------------------------
97
+
98
+ # prepare losses for progress bar
99
+ bs_fake = 1 # batch_size
100
+ losses.update(loss_dict['loss'] + loss_dict_ref['loss'], bs_fake)
101
+ losses_keyp.update(loss_dict['loss_keyp_weighted'], bs_fake)
102
+ losses_silh.update(loss_dict['loss_silh_weighted'], bs_fake)
103
+ losses_shape.update(loss_dict['loss_shape_weighted'], bs_fake)
104
+ losses_pose.update(loss_dict['loss_poseprior_weighted'], bs_fake)
105
+ losses_class.update(loss_dict['loss_class_weighted'], bs_fake)
106
+ losses_breed.update(loss_dict['loss_breed_weighted'], bs_fake)
107
+ losses_partseg.update(loss_dict['loss_partseg_weighted'], bs_fake)
108
+ losses_ref_keyp.update(loss_dict_ref['keyp_ref'], bs_fake)
109
+ losses_ref_silh.update(loss_dict_ref['silh_ref'], bs_fake)
110
+ loss_ref_pose = 0
111
+ for l_name in ['pose_legs_side', 'pose_legs_tors', 'pose_tail_side', 'pose_tail_tors', 'pose_spine_side', 'pose_spine_tors']:
112
+ if l_name in loss_dict_ref.keys():
113
+ loss_ref_pose += loss_dict_ref[l_name]
114
+ losses_ref_pose.update(loss_ref_pose, bs_fake)
115
+ loss_ref_reg = 0
116
+ for l_name in ['reg_trans', 'reg_flength', 'reg_pose']:
117
+ if l_name in loss_dict_ref.keys():
118
+ loss_ref_reg += loss_dict_ref[l_name]
119
+ losses_ref_reg.update(loss_ref_reg, bs_fake)
120
+ acc = - loss_dict['loss_keyp_weighted'] # this will be used to keep track of the 'best model'
121
+ accuracies.update(acc, bs_fake)
122
+ # Show losses as part of the progress bar.
123
+ if progress is not None:
124
+ my_string = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_partseg: {loss_partseg:0.4f}, loss_shape: {loss_shape:0.4f}, loss_pose: {loss_pose:0.4f}, loss_class: {loss_class:0.4f}, loss_breed: {loss_breed:0.4f}, loss_ref_keyp: {loss_ref_keyp:0.4f}, loss_ref_silh: {loss_ref_silh:0.4f}, loss_ref_pose: {loss_ref_pose:0.4f}, loss_ref_reg: {loss_ref_reg:0.4f}'.format(
125
+ loss=losses.avg,
126
+ loss_keyp=losses_keyp.avg,
127
+ loss_silh=losses_silh.avg,
128
+ loss_shape=losses_shape.avg,
129
+ loss_pose=losses_pose.avg,
130
+ loss_class=losses_class.avg,
131
+ loss_breed=losses_breed.avg,
132
+ loss_partseg=losses_partseg.avg,
133
+ loss_ref_keyp=losses_ref_keyp.avg,
134
+ loss_ref_silh=losses_ref_silh.avg,
135
+ loss_ref_pose=losses_ref_pose.avg,
136
+ loss_ref_reg=losses_ref_reg.avg)
137
+ my_string_short = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_ref_keyp: {loss_ref_keyp:0.4f}, loss_ref_silh: {loss_ref_silh:0.4f}, loss_ref_pose: {loss_ref_pose:0.4f}, loss_ref_reg: {loss_ref_reg:0.4f}'.format(
138
+ loss=losses.avg,
139
+ loss_keyp=losses_keyp.avg,
140
+ loss_silh=losses_silh.avg,
141
+ loss_ref_keyp=losses_ref_keyp.avg,
142
+ loss_ref_silh=losses_ref_silh.avg,
143
+ loss_ref_pose=losses_ref_pose.avg,
144
+ loss_ref_reg=losses_ref_reg.avg)
145
+ progress.set_postfix_str(my_string_short)
146
+
147
+ return my_string, accuracies.avg
148
+
149
+
150
+ # ---------------------------------------------------------------------------------------------------------------------------
151
+ def do_validation_epoch(val_loader, model, loss_module, loss_module_ref, device, data_info, flip=False, quiet=False, acc_joints=None, save_imgs_path=None, weight_dict=None, weight_dict_ref=None, metrics=None, val_opt='default', test_name_list=None, render_all=False, pck_thresh=0.15, len_dataset=None):
152
+ losses = AverageMeter()
153
+ losses_keyp = AverageMeter()
154
+ losses_silh = AverageMeter()
155
+ losses_shape = AverageMeter()
156
+ losses_pose = AverageMeter()
157
+ losses_class = AverageMeter()
158
+ losses_breed = AverageMeter()
159
+ losses_partseg = AverageMeter()
160
+ losses_ref_keyp = AverageMeter()
161
+ losses_ref_silh = AverageMeter()
162
+ losses_ref_pose = AverageMeter()
163
+ losses_ref_reg = AverageMeter()
164
+ accuracies = AverageMeter()
165
+ if save_imgs_path is not None:
166
+ pathlib.Path(save_imgs_path).mkdir(parents=True, exist_ok=True)
167
+ # Put the model in evaluation mode.
168
+ model.eval()
169
+ # prepare progress bar
170
+ iterable = enumerate(val_loader)
171
+ progress = None
172
+ if not quiet:
173
+ progress = tqdm(iterable, desc='Valid', total=len(val_loader), ascii=True, leave=False)
174
+ iterable = progress
175
+ # summarize information for normalization
176
+ norm_dict = {
177
+ 'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device),
178
+ 'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device),
179
+ 'trans_std': torch.from_numpy(data_info.trans_std).float().to(device),
180
+ 'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device),
181
+ 'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)}
182
+ batch_size = val_loader.batch_size
183
+
184
+ return_mesh_with_gt_groundplane = True
185
+ if return_mesh_with_gt_groundplane:
186
+ remeshing_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data_remeshed/uniform_surface_sampling/my_smpl_39dogsnorm_Jr_4_dog_remesh4000_info.pkl'
187
+ with open(remeshing_path, 'rb') as fp:
188
+ remeshing_dict = pkl.load(fp)
189
+ remeshing_relevant_faces = torch.tensor(remeshing_dict['smal_faces'][remeshing_dict['faceid_closest']], dtype=torch.long, device=device)
190
+ remeshing_relevant_barys = torch.tensor(remeshing_dict['barys_closest'], dtype=torch.float32, device=device)
191
+
192
+
193
+ # from smal_pytorch.smal_model.smal_torch_new import SMAL
194
+ print('start: load smal default model (barc), but only for vertices')
195
+ smal = SMAL()
196
+ print('end: load smal default model (barc), but only for vertices')
197
+ smal_template_verts = smal.v_template.detach().cpu().numpy()
198
+ smal_faces = smal.faces.detach().cpu().numpy()
199
+
200
+
201
+ my_step = 0
202
+ for index, (input, target_dict) in iterable:
203
+
204
+ # prepare variables, put them on the right device
205
+ curr_batch_size = input.shape[0]
206
+ for key in target_dict.keys():
207
+ if key == 'breed_index':
208
+ target_dict[key] = target_dict[key].long().to(device)
209
+ elif key in ['index', 'pts', 'tpts', 'target_weight', 'silh', 'silh_distmat_tofg', 'silh_distmat_tobg', 'sim_breed_index', 'img_border_mask']:
210
+ target_dict[key] = target_dict[key].float().to(device)
211
+ elif key in ['has_seg', 'gc']:
212
+ target_dict[key] = target_dict[key].to(device)
213
+ else:
214
+ pass
215
+ input = input.float().to(device)
216
+
217
+ # ----------------------- do validation step -----------------------
218
+ with torch.no_grad():
219
+ # ----- forward pass -----
220
+ # output: (['pose', 'flength', 'trans', 'keypoints_norm', 'keypoints_scores'])
221
+ # output_unnorm: (['pose_rotmat', 'flength', 'trans', 'keypoints'])
222
+ # output_reproj: (['vertices_smal', 'torch_meshes', 'keyp_3d', 'keyp_2d', 'silh', 'betas', 'pose_rot6d', 'dog_breed', 'shapedirs', 'z', 'flength_unnorm', 'flength'])
223
+ # target_dict: (['index', 'center', 'scale', 'pts', 'tpts', 'target_weight', 'breed_index', 'sim_breed_index', 'ind_dataset', 'silh'])
224
+ output, output_unnorm, output_reproj, output_ref, output_ref_comp = model(input, norm_dict=norm_dict)
225
+ # ----- loss -----
226
+ if metrics == 'no_loss':
227
+ # --- from main network
228
+ loss, loss_dict = loss_module(output_reproj=output_reproj,
229
+ target_dict=target_dict,
230
+ weight_dict=weight_dict)
231
+ # ---from refinement network
232
+ loss_ref, loss_dict_ref = loss_module_ref(output_ref=output_ref,
233
+ output_ref_comp=output_ref_comp,
234
+ target_dict=target_dict,
235
+ weight_dict_ref=weight_dict_ref)
236
+ loss_total = loss + loss_ref
237
+
238
+ # ----------------------------------------------------------------
239
+
240
+
241
+ for result_network in ['normal', 'ref']:
242
+ # variabled that are not refined
243
+ hg_keyp_norm = output['keypoints_norm']
244
+ hg_keyp_scores = output['keypoints_scores']
245
+ betas = output_reproj['betas']
246
+ betas_limbs = output_reproj['betas_limbs']
247
+ zz = output_reproj['z']
248
+ if result_network == 'normal':
249
+ # STEP 1: normal network
250
+ vertices_smal = output_reproj['vertices_smal']
251
+ flength = output_unnorm['flength']
252
+ pose_rotmat = output_unnorm['pose_rotmat']
253
+ trans = output_unnorm['trans']
254
+ pred_keyp = output_reproj['keyp_2d']
255
+ pred_silh = output_reproj['silh']
256
+ prefix = 'normal_'
257
+ else:
258
+ # STEP 1: refinement network
259
+ vertices_smal = output_ref['vertices_smal']
260
+ flength = output_ref['flength']
261
+ pose_rotmat = output_ref['pose_rotmat']
262
+ trans = output_ref['trans']
263
+ pred_keyp = output_ref['keyp_2d']
264
+ pred_silh = output_ref['silh']
265
+ prefix = 'ref_'
266
+ if return_mesh_with_gt_groundplane and 'gc' in target_dict.keys():
267
+ bs = vertices_smal.shape[0]
268
+ target_gc_class = target_dict['gc'][:, :, 0]
269
+ sel_verts = torch.index_select(output_ref['vertices_smal'], dim=1, index=remeshing_relevant_faces.reshape((-1))).reshape((bs, remeshing_relevant_faces.shape[0], 3, 3))
270
+ verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts)
271
+ target_gc_class_remeshed = torch.einsum('ij,aij->ai', remeshing_relevant_barys, target_gc_class[:, remeshing_relevant_faces].to(device=device, dtype=torch.float32))
272
+ target_gc_class_remeshed_prep = torch.round(target_gc_class_remeshed).to(torch.long)
273
+
274
+
275
+
276
+
277
+
278
+ # import pdb; pdb.set_trace()
279
+
280
+ # new for vertex wise ground contact
281
+ if (not model.graphcnn_type == 'inexistent') and (save_imgs_path is not None):
282
+ # import pdb; pdb.set_trace()
283
+
284
+ sm = torch.nn.Softmax(dim=2)
285
+ ground_contact_probs = sm(output_ref['vertexwise_ground_contact'])
286
+
287
+ for ind_img in range(ground_contact_probs.shape[0]):
288
+ # ind_img = 0
289
+ if test_name_list is not None:
290
+ img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_')
291
+ img_name = img_name.split('.')[0]
292
+ else:
293
+ img_name = str(index) + '_' + str(ind_img)
294
+ out_path_gcmesh = save_imgs_path + '/' + prefix + 'gcmesh_' + img_name + '.obj'
295
+
296
+ gc_prob = ground_contact_probs[ind_img, :, 1] # contact probability
297
+ vert_colors = np.repeat(255*gc_prob.detach().cpu().numpy()[:, None], 3, 1)
298
+ my_mesh = trimesh.Trimesh(vertices=smal_template_verts, faces=smal_faces, process=False, maintain_order=True)
299
+ my_mesh.visual.vertex_colors = vert_colors
300
+ save_gc_mesh = True # False
301
+ if save_gc_mesh:
302
+ my_mesh.export(out_path_gcmesh)
303
+
304
+ '''
305
+ input_image = input[ind_img, :, :, :].detach().clone()
306
+ for t, m, s in zip(input_image, data_info.rgb_mean,data_info.rgb_stddev): t.add_(m)
307
+ input_image_np = input_image.detach().cpu().numpy().transpose(1, 2, 0)
308
+ out_path = save_debug_path + 'b' + str(ind_img) +'_input.png'
309
+ plt.imsave(out_path, input_image_np)
310
+ '''
311
+
312
+ # -------------------------------------
313
+
314
+ # import pdb; pdb.set_trace()
315
+
316
+
317
+ '''
318
+ target_gc_class = target_dict['gc'][ind_img, :, 0]
319
+
320
+ current_vertices_smal = vertices_smal[ind_img, :, :]
321
+
322
+ points_centroid, plane_normal, error = fit_plane(current_vertices_smal[target_gc_class==1, :])
323
+ '''
324
+
325
+ # calculate ground plane
326
+ # (see /is/cluster/work/nrueegg/icon_pifu_related/ICON/debug_code/curve_fitting_v2.py)
327
+ if return_mesh_with_gt_groundplane and 'gc' in target_dict.keys():
328
+
329
+ current_verts_remeshed = verts_remeshed[ind_img, :, :]
330
+ current_target_gc_class_remeshed_prep = target_gc_class_remeshed_prep[ind_img, ...]
331
+
332
+ if current_target_gc_class_remeshed_prep.sum() > 3:
333
+ points_on_plane = current_verts_remeshed[current_target_gc_class_remeshed_prep==1, :]
334
+ data_centroid, plane_normal, error = fit_plane(points_on_plane)
335
+ nonplane_points_centered = current_verts_remeshed[current_target_gc_class_remeshed_prep==0, :] - data_centroid[None, :]
336
+ nonplane_points_projected = torch.matmul(plane_normal[None, :], nonplane_points_centered.transpose(0,1))
337
+
338
+ if nonplane_points_projected.sum() > 0: # plane normal points towards the animal
339
+ plane_normal = plane_normal.detach().cpu().numpy()
340
+ else:
341
+ plane_normal = - plane_normal.detach().cpu().numpy()
342
+ data_centroid = data_centroid.detach().cpu().numpy()
343
+
344
+
345
+
346
+ # import pdb; pdb.set_trace()
347
+
348
+
349
+ desired_plane_normal_vector = np.asarray([[0, -1, 0]])
350
+ # new approach: use cross product
351
+ rotation_axis = np.cross(plane_normal, desired_plane_normal_vector) # np.cross(plane_normal, desired_plane_normal_vector)
352
+ lengt_rotation_axis = np.linalg.norm(rotation_axis) # = sin(alpha) (because vectors have unit length)
353
+ angle = np.sin(lengt_rotation_axis)
354
+ rot = R_sc.from_rotvec(angle * rotation_axis * 1/lengt_rotation_axis)
355
+ rot_mat = rot[0].as_matrix()
356
+ rot_upsidedown = R_sc.from_rotvec(np.pi * np.asarray([[1, 0, 0]]))
357
+ # rot_upsidedown[0].apply(rot[0].apply(plane_normal))
358
+ current_vertices_smal = vertices_smal[ind_img, :, :].detach().cpu().numpy()
359
+ new_smal_vertices = rot_upsidedown[0].apply(rot[0].apply(current_vertices_smal - data_centroid[None, :]))
360
+ my_mesh = trimesh.Trimesh(vertices=new_smal_vertices, faces=smal_faces, process=False, maintain_order=True)
361
+ vert_colors[:, 2] = 255
362
+ my_mesh.visual.vertex_colors = vert_colors
363
+ out_path_gc_rotated = save_imgs_path + '/' + prefix + 'gc_rotated_' + img_name + '_new.obj'
364
+ my_mesh.export(out_path_gc_rotated)
365
+
366
+
367
+
368
+
369
+
370
+
371
+ '''# rot = R_sc.align_vectors(plane_normal.reshape((1, -1)), desired_plane_normal_vector)
372
+ desired_plane_normal_vector = np.asarray([[0, 1, 0]])
373
+
374
+ rot = R_sc.align_vectors(desired_plane_normal_vector, plane_normal.reshape((1, -1))) # inv
375
+ rot_mat = rot[0].as_matrix()
376
+
377
+
378
+ current_vertices_smal = vertices_smal[ind_img, :, :].detach().cpu().numpy()
379
+ new_smal_vertices = rot[0].apply((current_vertices_smal - data_centroid[None, :]))
380
+
381
+ my_mesh = trimesh.Trimesh(vertices=new_smal_vertices, faces=smal_faces, process=False, maintain_order=True)
382
+ my_mesh.visual.vertex_colors = vert_colors
383
+ out_path_gc_rotated = save_imgs_path + '/' + prefix + 'gc_rotated_' + img_name + '_y.obj'
384
+ my_mesh.export(out_path_gc_rotated)
385
+ '''
386
+
387
+
388
+
389
+
390
+
391
+
392
+
393
+
394
+
395
+ # ----
396
+
397
+
398
+ # -------------------------------------
399
+
400
+
401
+
402
+
403
+ if index == 0:
404
+ if len_dataset is None:
405
+ len_data = val_loader.batch_size * len(val_loader) # 1703
406
+ else:
407
+ len_data = len_dataset
408
+ if metrics == 'all' or metrics == 'no_loss':
409
+ if result_network == 'normal':
410
+ summaries = {'normal': dict(), 'ref': dict()}
411
+ summary = summaries['normal']
412
+ else:
413
+ summary = summaries['ref']
414
+ summary['pck'] = np.zeros((len_data))
415
+ summary['pck_by_part'] = {group:np.zeros((len_data)) for group in KEYPOINT_GROUPS}
416
+ summary['acc_sil_2d'] = np.zeros(len_data)
417
+ summary['betas'] = np.zeros((len_data,betas.shape[1]))
418
+ summary['betas_limbs'] = np.zeros((len_data, betas_limbs.shape[1]))
419
+ summary['z'] = np.zeros((len_data, zz.shape[1]))
420
+ summary['pose_rotmat'] = np.zeros((len_data, pose_rotmat.shape[1], 3, 3))
421
+ summary['flength'] = np.zeros((len_data, flength.shape[1]))
422
+ summary['trans'] = np.zeros((len_data, trans.shape[1]))
423
+ summary['breed_indices'] = np.zeros((len_data))
424
+ summary['image_names'] = [] # len_data * [None]
425
+ else:
426
+ if result_network == 'normal':
427
+ summary = summaries['normal']
428
+ else:
429
+ summary = summaries['ref']
430
+
431
+ if save_imgs_path is not None:
432
+ eval_save_visualizations_and_meshes(model, input, data_info, target_dict, test_name_list, vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat, trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, render_all=render_all)
433
+
434
+ if metrics == 'all' or metrics == 'no_loss':
435
+ preds = eval_prepare_pck_and_iou(model, input, data_info, target_dict, test_name_list, vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat, trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, pck_thresh, progress=progress)
436
+ # add results for all images in this batch to lists
437
+ curr_batch_size = pred_keyp.shape[0]
438
+ eval_add_preds_to_summary(summary, preds, my_step, batch_size, curr_batch_size)
439
+ else:
440
+ # measure accuracy and record loss
441
+ bs_fake = 1 # batch_size
442
+ # import pdb; pdb.set_trace()
443
+
444
+
445
+ # save_imgs_path + '/' + prefix + 'rot_tex_pred_' + img_name + '.png'
446
+ # import pdb; pdb.set_trace()
447
+ '''
448
+ for ind_img in range(len(target_dict['index'])):
449
+ try:
450
+ if test_name_list is not None:
451
+ img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_')
452
+ img_name = img_name.split('.')[0]
453
+ else:
454
+ img_name = str(index) + '_' + str(ind_img)
455
+ all_image_names = ['keypoints_pred_' + img_name + '.png', 'normal_comp_pred_' + img_name + '.png', 'normal_rot_tex_pred_' + img_name + '.png', 'ref_comp_pred_' + img_name + '.png', 'ref_rot_tex_pred_' + img_name + '.png']
456
+ all_saved_images = []
457
+ for sub_img_name in all_image_names:
458
+ saved_img = cv2.imread(save_imgs_path + '/' + sub_img_name)
459
+ if not (saved_img.shape[0] == 256 and saved_img.shape[1] == 256):
460
+ saved_img = cv2.resize(saved_img, (256, 256))
461
+ all_saved_images.append(saved_img)
462
+ final_image = np.concatenate(all_saved_images, axis=1)
463
+ save_imgs_path_sum = save_imgs_path.replace('test_', 'summary_test_')
464
+ if not os.path.exists(save_imgs_path_sum): os.makedirs(save_imgs_path_sum)
465
+ final_image_path = save_imgs_path_sum + '/summary_' + img_name + '.png'
466
+ cv2.imwrite(final_image_path, final_image)
467
+ except:
468
+ print('dont save a summary image')
469
+ '''
470
+
471
+
472
+ bs_fake = 1
473
+ if metrics == 'all' or metrics == 'no_loss':
474
+ # update progress bar
475
+ if progress is not None:
476
+ '''my_string = "PCK: {0:.2f}, IOU: {1:.2f}".format(
477
+ pck[:(my_step * batch_size + curr_batch_size)].mean(),
478
+ acc_sil_2d[:(my_step * batch_size + curr_batch_size)].mean())'''
479
+ my_string = "normal_PCK: {0:.2f}, normal_IOU: {1:.2f}, ref_PCK: {2:.2f}, ref_IOU: {3:.2f}".format(
480
+ summaries['normal']['pck'][:(my_step * batch_size + curr_batch_size)].mean(),
481
+ summaries['normal']['acc_sil_2d'][:(my_step * batch_size + curr_batch_size)].mean(),
482
+ summaries['ref']['pck'][:(my_step * batch_size + curr_batch_size)].mean(),
483
+ summaries['ref']['acc_sil_2d'][:(my_step * batch_size + curr_batch_size)].mean())
484
+ progress.set_postfix_str(my_string)
485
+ else:
486
+ losses.update(loss_dict['loss'] + loss_dict_ref['loss'], bs_fake)
487
+ losses_keyp.update(loss_dict['loss_keyp_weighted'], bs_fake)
488
+ losses_silh.update(loss_dict['loss_silh_weighted'], bs_fake)
489
+ losses_shape.update(loss_dict['loss_shape_weighted'], bs_fake)
490
+ losses_pose.update(loss_dict['loss_poseprior_weighted'], bs_fake)
491
+ losses_class.update(loss_dict['loss_class_weighted'], bs_fake)
492
+ losses_breed.update(loss_dict['loss_breed_weighted'], bs_fake)
493
+ losses_partseg.update(loss_dict['loss_partseg_weighted'], bs_fake)
494
+ losses_ref_keyp.update(loss_dict_ref['keyp_ref'], bs_fake)
495
+ losses_ref_silh.update(loss_dict_ref['silh_ref'], bs_fake)
496
+ loss_ref_pose = 0
497
+ for l_name in ['pose_legs_side', 'pose_legs_tors', 'pose_tail_side', 'pose_tail_tors', 'pose_spine_side', 'pose_spine_tors']:
498
+ loss_ref_pose += loss_dict_ref[l_name]
499
+ losses_ref_pose.update(loss_ref_pose, bs_fake)
500
+ loss_ref_reg = 0
501
+ for l_name in ['reg_trans', 'reg_flength', 'reg_pose']:
502
+ loss_ref_reg += loss_dict_ref[l_name]
503
+ losses_ref_reg.update(loss_ref_reg, bs_fake)
504
+ acc = - loss_dict['loss_keyp_weighted'] # this will be used to keep track of the 'best model'
505
+ accuracies.update(acc, bs_fake)
506
+ # Show losses as part of the progress bar.
507
+ if progress is not None:
508
+ my_string = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_partseg: {loss_partseg:0.4f}, loss_shape: {loss_shape:0.4f}, loss_pose: {loss_pose:0.4f}, loss_class: {loss_class:0.4f}, loss_breed: {loss_breed:0.4f}, loss_ref_keyp: {loss_ref_keyp:0.4f}, loss_ref_silh: {loss_ref_silh:0.4f}, loss_ref_pose: {loss_ref_pose:0.4f}, loss_ref_reg: {loss_ref_reg:0.4f}'.format(
509
+ loss=losses.avg,
510
+ loss_keyp=losses_keyp.avg,
511
+ loss_silh=losses_silh.avg,
512
+ loss_shape=losses_shape.avg,
513
+ loss_pose=losses_pose.avg,
514
+ loss_class=losses_class.avg,
515
+ loss_breed=losses_breed.avg,
516
+ loss_partseg=losses_partseg.avg,
517
+ loss_ref_keyp=losses_ref_keyp.avg,
518
+ loss_ref_silh=losses_ref_silh.avg,
519
+ loss_ref_pose=losses_ref_pose.avg,
520
+ loss_ref_reg=losses_ref_reg.avg)
521
+ my_string_short = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_ref_keyp: {loss_ref_keyp:0.4f}, loss_ref_silh: {loss_ref_silh:0.4f}, loss_ref_pose: {loss_ref_pose:0.4f}, loss_ref_reg: {loss_ref_reg:0.4f}'.format(
522
+ loss=losses.avg,
523
+ loss_keyp=losses_keyp.avg,
524
+ loss_silh=losses_silh.avg,
525
+ loss_ref_keyp=losses_ref_keyp.avg,
526
+ loss_ref_silh=losses_ref_silh.avg,
527
+ loss_ref_pose=losses_ref_pose.avg,
528
+ loss_ref_reg=losses_ref_reg.avg)
529
+ progress.set_postfix_str(my_string_short)
530
+ my_step += 1
531
+ if metrics == 'all':
532
+ return my_string, summaries # summary
533
+ elif metrics == 'no_loss':
534
+ return my_string, np.average(np.asarray(summaries['ref']['acc_sil_2d'])) # np.average(np.asarray(summary['acc_sil_2d']))
535
+ else:
536
+ return my_string, accuracies.avg
537
+
538
+
539
+ # ---------------------------------------------------------------------------------------------------------------------------
540
+ def do_visual_epoch(val_loader, model, device, data_info, flip=False, quiet=False, acc_joints=None, save_imgs_path=None, weight_dict=None, weight_dict_ref=None, metrics=None, val_opt='default', test_name_list=None, render_all=False, pck_thresh=0.15, return_results=False, len_dataset=None):
541
+ if save_imgs_path is not None:
542
+ pathlib.Path(save_imgs_path).mkdir(parents=True, exist_ok=True)
543
+ all_results = []
544
+
545
+ # Put the model in evaluation mode.
546
+ model.eval()
547
+
548
+ iterable = enumerate(val_loader)
549
+
550
+ # information for normalization
551
+ norm_dict = {
552
+ 'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device),
553
+ 'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device),
554
+ 'trans_std': torch.from_numpy(data_info.trans_std).float().to(device),
555
+ 'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device),
556
+ 'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)}
557
+
558
+
559
+ return_mesh_with_gt_groundplane = True
560
+ if return_mesh_with_gt_groundplane:
561
+ remeshing_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data_remeshed/uniform_surface_sampling/my_smpl_39dogsnorm_Jr_4_dog_remesh4000_info.pkl'
562
+ with open(remeshing_path, 'rb') as fp:
563
+ remeshing_dict = pkl.load(fp)
564
+ remeshing_relevant_faces = torch.tensor(remeshing_dict['smal_faces'][remeshing_dict['faceid_closest']], dtype=torch.long, device=device)
565
+ remeshing_relevant_barys = torch.tensor(remeshing_dict['barys_closest'], dtype=torch.float32, device=device)
566
+
567
+ # from smal_pytorch.smal_model.smal_torch_new import SMAL
568
+ print('start: load smal default model (barc), but only for vertices')
569
+ smal = SMAL()
570
+ print('end: load smal default model (barc), but only for vertices')
571
+ smal_template_verts = smal.v_template.detach().cpu().numpy()
572
+ smal_faces = smal.faces.detach().cpu().numpy()
573
+
574
+ file_alignment_errors = open(save_imgs_path + '/a_ref_procrustes_alignmnet_errors.txt', 'a') # append mode
575
+ file_alignment_errors.write(" ----------- start evaluation ------------- \n ")
576
+
577
+ csv_file_alignment_errors = open(save_imgs_path + '/a_ref_procrustes_alignmnet_errors.csv', 'w') # write mode
578
+ fieldnames = ['name', 'error']
579
+ writer = csv.DictWriter(csv_file_alignment_errors, fieldnames=fieldnames)
580
+ writer.writeheader()
581
+
582
+ my_step = 0
583
+ for index, (input, target_dict) in iterable:
584
+ batch_size = input.shape[0]
585
+ input = input.float().to(device)
586
+ partial_results = {}
587
+
588
+ # ----------------------- do visualization step -----------------------
589
+ with torch.no_grad():
590
+ output, output_unnorm, output_reproj, output_ref, output_ref_comp = model(input, norm_dict=norm_dict)
591
+
592
+
593
+ # import pdb; pdb.set_trace()
594
+
595
+
596
+ sm = torch.nn.Softmax(dim=2)
597
+ ground_contact_probs = sm(output_ref['vertexwise_ground_contact'])
598
+
599
+ for result_network in ['normal', 'ref']:
600
+ # variabled that are not refined
601
+ hg_keyp_norm = output['keypoints_norm']
602
+ hg_keyp_scores = output['keypoints_scores']
603
+ betas = output_reproj['betas']
604
+ betas_limbs = output_reproj['betas_limbs']
605
+ zz = output_reproj['z']
606
+ if result_network == 'normal':
607
+ # STEP 1: normal network
608
+ vertices_smal = output_reproj['vertices_smal']
609
+ flength = output_unnorm['flength']
610
+ pose_rotmat = output_unnorm['pose_rotmat']
611
+ trans = output_unnorm['trans']
612
+ pred_keyp = output_reproj['keyp_2d']
613
+ pred_silh = output_reproj['silh']
614
+ prefix = 'normal_'
615
+ else:
616
+ # STEP 1: refinement network
617
+ vertices_smal = output_ref['vertices_smal']
618
+ flength = output_ref['flength']
619
+ pose_rotmat = output_ref['pose_rotmat']
620
+ trans = output_ref['trans']
621
+ pred_keyp = output_ref['keyp_2d']
622
+ pred_silh = output_ref['silh']
623
+ prefix = 'ref_'
624
+
625
+ bs = vertices_smal.shape[0]
626
+ # target_gc_class = target_dict['gc'][:, :, 0]
627
+ target_gc_class = torch.round(ground_contact_probs).long()[:, :, 1]
628
+ sel_verts = torch.index_select(output_ref['vertices_smal'], dim=1, index=remeshing_relevant_faces.reshape((-1))).reshape((bs, remeshing_relevant_faces.shape[0], 3, 3))
629
+ verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts)
630
+ target_gc_class_remeshed = torch.einsum('ij,aij->ai', remeshing_relevant_barys, target_gc_class[:, remeshing_relevant_faces].to(device=device, dtype=torch.float32))
631
+ target_gc_class_remeshed_prep = torch.round(target_gc_class_remeshed).to(torch.long)
632
+
633
+
634
+
635
+
636
+ # index = i
637
+ # ind_img = 0
638
+ for ind_img in range(batch_size): # range(min(12, batch_size)): # range(12): # [0]: #range(0, batch_size):
639
+
640
+ # ind_img = 0
641
+ if test_name_list is not None:
642
+ img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_')
643
+ img_name = img_name.split('.')[0]
644
+ else:
645
+ img_name = str(index) + '_' + str(ind_img)
646
+ out_path_gcmesh = save_imgs_path + '/' + prefix + 'gcmesh_' + img_name + '.obj'
647
+
648
+ gc_prob = ground_contact_probs[ind_img, :, 1] # contact probability
649
+ vert_colors = np.repeat(255*gc_prob.detach().cpu().numpy()[:, None], 3, 1)
650
+ my_mesh = trimesh.Trimesh(vertices=smal_template_verts, faces=smal_faces, process=False, maintain_order=True)
651
+ my_mesh.visual.vertex_colors = vert_colors
652
+ save_gc_mesh = False
653
+ if save_gc_mesh:
654
+ my_mesh.export(out_path_gcmesh)
655
+
656
+ current_verts_remeshed = verts_remeshed[ind_img, :, :]
657
+ current_target_gc_class_remeshed_prep = target_gc_class_remeshed_prep[ind_img, ...]
658
+
659
+ if current_target_gc_class_remeshed_prep.sum() > 3:
660
+ points_on_plane = current_verts_remeshed[current_target_gc_class_remeshed_prep==1, :]
661
+ data_centroid, plane_normal, error = fit_plane(points_on_plane)
662
+ nonplane_points_centered = current_verts_remeshed[current_target_gc_class_remeshed_prep==0, :] - data_centroid[None, :]
663
+ nonplane_points_projected = torch.matmul(plane_normal[None, :], nonplane_points_centered.transpose(0,1))
664
+
665
+ if nonplane_points_projected.sum() > 0: # plane normal points towards the animal
666
+ plane_normal = plane_normal.detach().cpu().numpy()
667
+ else:
668
+ plane_normal = - plane_normal.detach().cpu().numpy()
669
+ data_centroid = data_centroid.detach().cpu().numpy()
670
+
671
+
672
+
673
+ # import pdb; pdb.set_trace()
674
+
675
+
676
+ desired_plane_normal_vector = np.asarray([[0, -1, 0]])
677
+ # new approach: use cross product
678
+ rotation_axis = np.cross(plane_normal, desired_plane_normal_vector) # np.cross(plane_normal, desired_plane_normal_vector)
679
+ lengt_rotation_axis = np.linalg.norm(rotation_axis) # = sin(alpha) (because vectors have unit length)
680
+ angle = np.sin(lengt_rotation_axis)
681
+ rot = R_sc.from_rotvec(angle * rotation_axis * 1/lengt_rotation_axis)
682
+ rot_mat = rot[0].as_matrix()
683
+ rot_upsidedown = R_sc.from_rotvec(np.pi * np.asarray([[1, 0, 0]]))
684
+ # rot_upsidedown[0].apply(rot[0].apply(plane_normal))
685
+ current_vertices_smal = vertices_smal[ind_img, :, :].detach().cpu().numpy()
686
+ new_smal_vertices = rot_upsidedown[0].apply(rot[0].apply(current_vertices_smal - data_centroid[None, :]))
687
+ my_mesh = trimesh.Trimesh(vertices=new_smal_vertices, faces=smal_faces, process=False, maintain_order=True)
688
+ vert_colors[:, 2] = 255
689
+ my_mesh.visual.vertex_colors = vert_colors
690
+ out_path_gc_rotated = save_imgs_path + '/' + prefix + 'gc_rotated_' + img_name + '_new.obj'
691
+ my_mesh.export(out_path_gc_rotated)
692
+
693
+
694
+
695
+ '''
696
+ import pdb; pdb.set_trace()
697
+
698
+ from src.evaluation.registration import preprocess_point_cloud, o3d_ransac, draw_registration_result
699
+ import open3d as o3d
700
+ import copy
701
+
702
+
703
+ mesh_gt_path = target_dict['mesh_path'][ind_img]
704
+ mesh_gt = o3d.io.read_triangle_mesh(mesh_gt_path)
705
+
706
+ mesh_gt_verts = np.asarray(mesh_gt.vertices)
707
+ mesh_gt_faces = np.asarray(mesh_gt.triangles)
708
+ diag_gt = np.sqrt(sum((mesh_gt_verts.max(axis=0) - mesh_gt_verts.min(axis=0))**2))
709
+
710
+ mesh_pred_verts = np.asarray(new_smal_vertices)
711
+ mesh_pred_faces = np.asarray(smal_faces)
712
+ diag_pred = np.sqrt(sum((mesh_pred_verts.max(axis=0) - mesh_pred_verts.min(axis=0))**2))
713
+ mesh_pred = o3d.geometry.TriangleMesh()
714
+ mesh_pred.vertices = o3d.utility.Vector3dVector(mesh_pred_verts)
715
+ mesh_pred.triangles = o3d.utility.Vector3iVector(mesh_pred_faces)
716
+
717
+ # center the predicted mesh around 0
718
+ trans = - mesh_pred_verts.mean(axis=0)
719
+ mesh_pred_verts_new = mesh_pred_verts + trans
720
+ # change the size of the predicted mesh
721
+ mesh_pred_verts_new = mesh_pred_verts_new * diag_gt / diag_pred
722
+
723
+ # transform the predicted mesh (rough alignment)
724
+ mesh_pred_new = copy.deepcopy(mesh_pred)
725
+ mesh_pred_new.vertices = o3d.utility.Vector3dVector(np.asarray(mesh_pred_verts_new)) # normals should not have changed
726
+ voxel_size = 0.01 # 0.5
727
+ distance_threshold = 0.015 # 0.005 # 0.02 # 1.0
728
+ result, src_down, src_fpfh, dst_down, dst_fpfh = o3d_ransac(mesh_pred_new, mesh_gt, voxel_size=voxel_size, distance_threshold=distance_threshold, return_all=True)
729
+ transform = result.transformation
730
+ mesh_pred_transf = copy.deepcopy(mesh_pred_new).transform(transform)
731
+
732
+ out_path_pred_transf = save_imgs_path + '/' + prefix + 'alignment_initial_' + img_name + '.obj'
733
+ o3d.io.write_triangle_mesh(out_path_pred_transf, mesh_pred_transf)
734
+
735
+ # img_name_part = img_name.split(img_name.split('_')[-1] + '_')[0]
736
+ # out_path_gt = save_imgs_path + '/' + prefix + 'ground_truth_' + img_name_part + '.obj'
737
+ # o3d.io.write_triangle_mesh(out_path_gt, mesh_gt)
738
+
739
+
740
+ trans_init = transform
741
+ threshold = 0.02 # 0.1 # 0.02
742
+
743
+ n_points = 10000
744
+ src = mesh_pred_new.sample_points_uniformly(number_of_points=n_points)
745
+ dst = mesh_gt.sample_points_uniformly(number_of_points=n_points)
746
+
747
+ # reg_p2p = o3d.pipelines.registration.registration_icp(src_down, dst_down, threshold, trans_init, o3d.pipelines.registration.TransformationEstimationPointToPoint(), o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=2000))
748
+ reg_p2p = o3d.pipelines.registration.registration_icp(src, dst, threshold, trans_init, o3d.pipelines.registration.TransformationEstimationPointToPoint(), o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=2000))
749
+
750
+ # mesh_pred_transf_refined = copy.deepcopy(mesh_pred_new).transform(reg_p2p.transformation)
751
+ # out_path_pred_transf_refined = save_imgs_path + '/' + prefix + 'alignment_final_' + img_name + '.obj'
752
+ # o3d.io.write_triangle_mesh(out_path_pred_transf_refined, mesh_pred_transf_refined)
753
+
754
+
755
+ aligned_mesh_final = trimesh.Trimesh(mesh_pred_new.vertices, mesh_pred_new.triangles, vertex_colors=[0, 255, 0])
756
+ gt_mesh = trimesh.Trimesh(mesh_gt.vertices, mesh_gt.triangles, vertex_colors=[255, 0, 0])
757
+ scene = trimesh.Scene([aligned_mesh_final, gt_mesh])
758
+ out_path_alignment_with_gt = save_imgs_path + '/' + prefix + 'alignment_with_gt_' + img_name + '.obj'
759
+
760
+ scene.export(out_path_alignment_with_gt)
761
+ '''
762
+
763
+ # import pdb; pdb.set_trace()
764
+
765
+
766
+ # SMAL_KEYPOINT_NAMES_FOR_3D_EVAL # 17 keypoints
767
+ # prepare target
768
+ target_keyp_isvalid = target_dict['keypoints_3d'][ind_img, :, 3].detach().cpu().numpy()
769
+ keyp_to_use = (np.asarray(SMAL_KEYPOINT_WHICHTOUSE_FOR_3D_EVAL)==1)*(target_keyp_isvalid==1)
770
+ target_keyp_raw = target_dict['keypoints_3d'][ind_img, :, :3].detach().cpu().numpy()
771
+ target_keypoints = target_keyp_raw[keyp_to_use, :]
772
+ target_pointcloud = target_dict['pointcloud_points'][ind_img, :, :].detach().cpu().numpy()
773
+ # prepare prediction
774
+ pred_keypoints_raw = output_ref['vertices_smal'][ind_img, SMAL_KEYPOINT_INDICES_FOR_3D_EVAL, :].detach().cpu().numpy()
775
+ pred_keypoints = pred_keypoints_raw[keyp_to_use, :]
776
+ pred_pointcloud = verts_remeshed[ind_img, :, :].detach().cpu().numpy()
777
+
778
+
779
+
780
+
781
+ '''
782
+ pred_keypoints_transf, pred_pointcloud_transf, procrustes_params = compute_similarity_transform(pred_keypoints, target_keypoints, num_joints=None, verts=pred_pointcloud)
783
+ pa_error = np.sqrt(np.sum((target_keypoints - pred_keypoints_transf) ** 2, axis=1))
784
+ error_procrustes = np.mean(pa_error)
785
+
786
+
787
+ col_target = np.zeros((target_pointcloud.shape[0], 3), dtype=np.uint8)
788
+ col_target[:, 0] = 255
789
+ col_pred = np.zeros((pred_pointcloud_transf.shape[0], 3), dtype=np.uint8)
790
+ col_pred[:, 1] = 255
791
+ pc = trimesh.points.PointCloud(np.concatenate((target_pointcloud, pred_pointcloud_transf)), colors=np.concatenate((col_target, col_pred)))
792
+ out_path_pc = save_imgs_path + '/' + prefix + 'pointclouds_aligned_' + img_name + '.obj'
793
+ pc.export(out_path_pc)
794
+
795
+ print(target_dict['mesh_path'][ind_img])
796
+ print(error_procrustes)
797
+ file_alignment_errors.write(target_dict['mesh_path'][ind_img] + '\n')
798
+ file_alignment_errors.write('error: ' + str(error_procrustes) + ' \n')
799
+
800
+ writer.writerow({'name': (target_dict['mesh_path'][ind_img]).split('/')[-1], 'error': str(error_procrustes)})
801
+
802
+ # import pdb; pdb.set_trace()
803
+ # alignment_dict = calculate_alignemnt_errors(output_ref['vertices_smal'][ind_img, :, :], target_dict['keypoints_3d'][ind_img, :, :], target_dict['pointcloud_points'][ind_img, :, :])
804
+ # file_alignment_errors.write('error: ' + str(alignment_dict['error_procrustes']) + ' \n')
805
+ '''
806
+
807
+
808
+
809
+
810
+
811
+
812
+ if index == 0:
813
+ if len_dataset is None:
814
+ len_data = val_loader.batch_size * len(val_loader) # 1703
815
+ else:
816
+ len_data = len_dataset
817
+ if result_network == 'normal':
818
+ summaries = {'normal': dict(), 'ref': dict()}
819
+ summary = summaries['normal']
820
+ else:
821
+ summary = summaries['ref']
822
+ summary['pck'] = np.zeros((len_data))
823
+ summary['pck_by_part'] = {group:np.zeros((len_data)) for group in KEYPOINT_GROUPS}
824
+ summary['acc_sil_2d'] = np.zeros(len_data)
825
+ summary['betas'] = np.zeros((len_data,betas.shape[1]))
826
+ summary['betas_limbs'] = np.zeros((len_data, betas_limbs.shape[1]))
827
+ summary['z'] = np.zeros((len_data, zz.shape[1]))
828
+ summary['pose_rotmat'] = np.zeros((len_data, pose_rotmat.shape[1], 3, 3))
829
+ summary['flength'] = np.zeros((len_data, flength.shape[1]))
830
+ summary['trans'] = np.zeros((len_data, trans.shape[1]))
831
+ summary['breed_indices'] = np.zeros((len_data))
832
+ summary['image_names'] = [] # len_data * [None]
833
+ # ['vertices_smal'] = np.zeros((len_data, vertices_smal.shape[1], 3))
834
+ else:
835
+ if result_network == 'normal':
836
+ summary = summaries['normal']
837
+ else:
838
+ summary = summaries['ref']
839
+
840
+
841
+ # import pdb; pdb.set_trace()
842
+
843
+
844
+ eval_save_visualizations_and_meshes(model, input, data_info, target_dict, test_name_list, vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat, trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, render_all=render_all)
845
+
846
+
847
+ preds = eval_prepare_pck_and_iou(model, input, data_info, target_dict, test_name_list, vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat, trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, pck_thresh=None, skip_pck_and_iou=True)
848
+ # add results for all images in this batch to lists
849
+ curr_batch_size = pred_keyp.shape[0]
850
+ eval_add_preds_to_summary(summary, preds, my_step, batch_size, curr_batch_size, skip_pck_and_iou=True)
851
+
852
+ # summary['vertices_smal'][my_step * batch_size:my_step * batch_size + curr_batch_size] = vertices_smal.detach().cpu().numpy()
853
+
854
+
855
+
856
+
857
+
858
+
859
+
860
+
861
+
862
+
863
+
864
+
865
+
866
+
867
+
868
+ '''
869
+ try:
870
+ if test_name_list is not None:
871
+ img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_')
872
+ img_name = img_name.split('.')[0]
873
+ else:
874
+ img_name = str(index) + '_' + str(ind_img)
875
+ partial_results['img_name'] = img_name
876
+ visualizations = model.render_vis_nograd(vertices=output_reproj['vertices_smal'],
877
+ focal_lengths=output_unnorm['flength'],
878
+ color=0) # 2)
879
+ # save image with predicted keypoints
880
+ pred_unp = (output['keypoints_norm'][ind_img, :, :] + 1.) / 2 * (data_info.image_size - 1)
881
+ pred_unp_maxval = output['keypoints_scores'][ind_img, :, :]
882
+ pred_unp_prep = torch.cat((pred_unp, pred_unp_maxval), 1)
883
+ inp_img = input[ind_img, :, :, :].detach().clone()
884
+ if save_imgs_path is not None:
885
+ out_path = save_imgs_path + '/keypoints_pred_' + img_name + '.png'
886
+ save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.1, print_scores=True, ratio_in_out=1.0) # threshold=0.3
887
+ # save predicted 3d model
888
+ # (1) front view
889
+ pred_tex = visualizations[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
890
+ pred_tex_max = np.max(pred_tex, axis=2)
891
+ partial_results['tex_pred'] = pred_tex
892
+ if save_imgs_path is not None:
893
+ out_path = save_imgs_path + '/tex_pred_' + img_name + '.png'
894
+ plt.imsave(out_path, pred_tex)
895
+ input_image = input[ind_img, :, :, :].detach().clone()
896
+ for t, m, s in zip(input_image, data_info.rgb_mean, data_info.rgb_stddev): t.add_(m)
897
+ input_image_np = input_image.detach().cpu().numpy().transpose(1, 2, 0)
898
+ im_masked = cv2.addWeighted(input_image_np,0.2,pred_tex,0.8,0)
899
+ im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :]
900
+ partial_results['comp_pred'] = im_masked
901
+ if save_imgs_path is not None:
902
+ out_path = save_imgs_path + '/comp_pred_' + img_name + '.png'
903
+ plt.imsave(out_path, im_masked)
904
+ # (2) side view
905
+ vertices_cent = output_reproj['vertices_smal'] - output_reproj['vertices_smal'].mean(dim=1)[:, None, :]
906
+ roll = np.pi / 2 * torch.ones(1).float().to(device)
907
+ pitch = np.pi / 2 * torch.ones(1).float().to(device)
908
+ tensor_0 = torch.zeros(1).float().to(device)
909
+ tensor_1 = torch.ones(1).float().to(device)
910
+ RX = torch.stack([torch.stack([tensor_1, tensor_0, tensor_0]), torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3,3)
911
+ RY = torch.stack([
912
+ torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]),
913
+ torch.stack([tensor_0, tensor_1, tensor_0]),
914
+ torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3,3)
915
+ vertices_rot = (torch.matmul(RY, vertices_cent.reshape((-1, 3))[:, :, None])).reshape((batch_size, -1, 3))
916
+ vertices_rot[:, :, 2] = vertices_rot[:, :, 2] + torch.ones_like(vertices_rot[:, :, 2]) * 20 # 18 # *16
917
+ visualizations_rot = model.render_vis_nograd(vertices=vertices_rot,
918
+ focal_lengths=output_unnorm['flength'],
919
+ color=0) # 2)
920
+ pred_tex = visualizations_rot[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
921
+ pred_tex_max = np.max(pred_tex, axis=2)
922
+ partial_results['rot_tex_pred'] = pred_tex
923
+ if save_imgs_path is not None:
924
+ out_path = save_imgs_path + '/rot_tex_pred_' + img_name + '.png'
925
+ plt.imsave(out_path, pred_tex)
926
+ render_all = True
927
+ if render_all:
928
+ # save input image
929
+ inp_img = input[ind_img, :, :, :].detach().clone()
930
+ if save_imgs_path is not None:
931
+ out_path = save_imgs_path + '/image_' + img_name + '.png'
932
+ save_input_image(inp_img, out_path)
933
+ # save posed mesh
934
+ V_posed = output_reproj['vertices_smal'][ind_img, :, :].detach().cpu().numpy()
935
+ Faces = model.smal.f
936
+ mesh_posed = trimesh.Trimesh(vertices=V_posed, faces=Faces, process=False, maintain_order=True)
937
+ partial_results['mesh_posed'] = mesh_posed
938
+ if save_imgs_path is not None:
939
+ mesh_posed.export(save_imgs_path + '/mesh_posed_' + img_name + '.obj')
940
+ except:
941
+ print('pass...')
942
+ all_results.append(partial_results)
943
+ '''
944
+
945
+ my_step += 1
946
+
947
+
948
+ file_alignment_errors.close()
949
+ csv_file_alignment_errors.close()
950
+
951
+
952
+ if return_results:
953
+ return all_results
954
+ else:
955
+ return summaries
src/combined_model/train_main_image_to_3d_withbreedrel.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.backends.cudnn
5
+ import torch.nn.parallel
6
+ from tqdm import tqdm
7
+ import os
8
+ import pathlib
9
+ from matplotlib import pyplot as plt
10
+ import cv2
11
+ import numpy as np
12
+ import torch
13
+ import trimesh
14
+
15
+ import sys
16
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
17
+ from stacked_hourglass.utils.evaluation import accuracy, AverageMeter, final_preds, get_preds, get_preds_soft
18
+ from stacked_hourglass.utils.visualization import save_input_image_with_keypoints, save_input_image
19
+ from metrics.metrics import Metrics
20
+ from configs.SMAL_configs import EVAL_KEYPOINTS, KEYPOINT_GROUPS
21
+
22
+
23
+ # ---------------------------------------------------------------------------------------------------------------------------
24
+ def do_training_epoch(train_loader, model, loss_module, device, data_info, optimiser, quiet=False, acc_joints=None, weight_dict=None):
25
+ losses = AverageMeter()
26
+ losses_keyp = AverageMeter()
27
+ losses_silh = AverageMeter()
28
+ losses_shape = AverageMeter()
29
+ losses_pose = AverageMeter()
30
+ losses_class = AverageMeter()
31
+ losses_breed = AverageMeter()
32
+ losses_partseg = AverageMeter()
33
+ accuracies = AverageMeter()
34
+ # Put the model in training mode.
35
+ model.train()
36
+ # prepare progress bar
37
+ iterable = enumerate(train_loader)
38
+ progress = None
39
+ if not quiet:
40
+ progress = tqdm(iterable, desc='Train', total=len(train_loader), ascii=True, leave=False)
41
+ iterable = progress
42
+ # information for normalization
43
+ norm_dict = {
44
+ 'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device),
45
+ 'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device),
46
+ 'trans_std': torch.from_numpy(data_info.trans_std).float().to(device),
47
+ 'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device),
48
+ 'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)}
49
+ # prepare variables, put them on the right device
50
+ for i, (input, target_dict) in iterable:
51
+ batch_size = input.shape[0]
52
+ for key in target_dict.keys():
53
+ if key == 'breed_index':
54
+ target_dict[key] = target_dict[key].long().to(device)
55
+ elif key in ['index', 'pts', 'tpts', 'target_weight', 'silh', 'silh_distmat_tofg', 'silh_distmat_tobg', 'sim_breed_index', 'img_border_mask']:
56
+ target_dict[key] = target_dict[key].float().to(device)
57
+ elif key in ['has_seg', 'gc']:
58
+ target_dict[key] = target_dict[key].to(device)
59
+ else:
60
+ pass
61
+ input = input.float().to(device)
62
+
63
+ # ----------------------- do training step -----------------------
64
+ assert model.training, 'model must be in training mode.'
65
+ with torch.enable_grad():
66
+ # ----- forward pass -----
67
+ output, output_unnorm, output_reproj = model(input, norm_dict=norm_dict)
68
+ # ----- loss -----
69
+ loss, loss_dict = loss_module(output_reproj=output_reproj,
70
+ target_dict=target_dict,
71
+ weight_dict=weight_dict)
72
+ # ----- backward pass and parameter update -----
73
+ optimiser.zero_grad()
74
+ loss.backward()
75
+ optimiser.step()
76
+ # ----------------------------------------------------------------
77
+
78
+ # prepare losses for progress bar
79
+ bs_fake = 1 # batch_size
80
+ losses.update(loss_dict['loss'], bs_fake)
81
+ losses_keyp.update(loss_dict['loss_keyp_weighted'], bs_fake)
82
+ losses_silh.update(loss_dict['loss_silh_weighted'], bs_fake)
83
+ losses_shape.update(loss_dict['loss_shape_weighted'], bs_fake)
84
+ losses_pose.update(loss_dict['loss_poseprior_weighted'], bs_fake)
85
+ losses_class.update(loss_dict['loss_class_weighted'], bs_fake)
86
+ losses_breed.update(loss_dict['loss_breed_weighted'], bs_fake)
87
+ losses_partseg.update(loss_dict['loss_partseg_weighted'], bs_fake)
88
+ acc = - loss_dict['loss_keyp_weighted'] # this will be used to keep track of the 'best model'
89
+ accuracies.update(acc, bs_fake)
90
+ # Show losses as part of the progress bar.
91
+ if progress is not None:
92
+ my_string = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_partseg: {loss_partseg:0.4f}, loss_shape: {loss_shape:0.4f}, loss_pose: {loss_pose:0.4f}, loss_class: {loss_class:0.4f}, loss_breed: {loss_breed:0.4f}'.format(
93
+ loss=losses.avg,
94
+ loss_keyp=losses_keyp.avg,
95
+ loss_silh=losses_silh.avg,
96
+ loss_shape=losses_shape.avg,
97
+ loss_pose=losses_pose.avg,
98
+ loss_class=losses_class.avg,
99
+ loss_breed=losses_breed.avg,
100
+ loss_partseg=losses_partseg.avg
101
+ )
102
+ progress.set_postfix_str(my_string)
103
+
104
+ return my_string, accuracies.avg
105
+
106
+
107
+ # ---------------------------------------------------------------------------------------------------------------------------
108
+ def do_validation_epoch(val_loader, model, loss_module, device, data_info, flip=False, quiet=False, acc_joints=None, save_imgs_path=None, weight_dict=None, metrics=None, val_opt='default', test_name_list=None, render_all=False, pck_thresh=0.15, len_dataset=None):
109
+ losses = AverageMeter()
110
+ losses_keyp = AverageMeter()
111
+ losses_silh = AverageMeter()
112
+ losses_shape = AverageMeter()
113
+ losses_pose = AverageMeter()
114
+ losses_class = AverageMeter()
115
+ losses_breed = AverageMeter()
116
+ losses_partseg = AverageMeter()
117
+ accuracies = AverageMeter()
118
+ if save_imgs_path is not None:
119
+ pathlib.Path(save_imgs_path).mkdir(parents=True, exist_ok=True)
120
+ # Put the model in evaluation mode.
121
+ model.eval()
122
+ # prepare progress bar
123
+ iterable = enumerate(val_loader)
124
+ progress = None
125
+ if not quiet:
126
+ progress = tqdm(iterable, desc='Valid', total=len(val_loader), ascii=True, leave=False)
127
+ iterable = progress
128
+ # summarize information for normalization
129
+ norm_dict = {
130
+ 'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device),
131
+ 'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device),
132
+ 'trans_std': torch.from_numpy(data_info.trans_std).float().to(device),
133
+ 'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device),
134
+ 'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)}
135
+ batch_size = val_loader.batch_size
136
+ # prepare variables, put them on the right device
137
+ my_step = 0
138
+ for i, (input, target_dict) in iterable:
139
+ curr_batch_size = input.shape[0]
140
+ for key in target_dict.keys():
141
+ if key == 'breed_index':
142
+ target_dict[key] = target_dict[key].long().to(device)
143
+ elif key in ['index', 'pts', 'tpts', 'target_weight', 'silh', 'silh_distmat_tofg', 'silh_distmat_tobg', 'sim_breed_index', 'img_border_mask']:
144
+ target_dict[key] = target_dict[key].float().to(device)
145
+ elif key in ['has_seg', 'gc']:
146
+ target_dict[key] = target_dict[key].to(device)
147
+ else:
148
+ pass
149
+ input = input.float().to(device)
150
+
151
+ # ----------------------- do validation step -----------------------
152
+ with torch.no_grad():
153
+ # ----- forward pass -----
154
+ # output: (['pose', 'flength', 'trans', 'keypoints_norm', 'keypoints_scores'])
155
+ # output_unnorm: (['pose_rotmat', 'flength', 'trans', 'keypoints'])
156
+ # output_reproj: (['vertices_smal', 'torch_meshes', 'keyp_3d', 'keyp_2d', 'silh', 'betas', 'pose_rot6d', 'dog_breed', 'shapedirs', 'z', 'flength_unnorm', 'flength'])
157
+ # target_dict: (['index', 'center', 'scale', 'pts', 'tpts', 'target_weight', 'breed_index', 'sim_breed_index', 'ind_dataset', 'silh'])
158
+ output, output_unnorm, output_reproj = model(input, norm_dict=norm_dict)
159
+ # ----- loss -----
160
+ if metrics == 'no_loss':
161
+ loss, loss_dict = loss_module(output_reproj=output_reproj,
162
+ target_dict=target_dict,
163
+ weight_dict=weight_dict)
164
+ # ----------------------------------------------------------------
165
+
166
+ if i == 0:
167
+ if len_dataset is None:
168
+ len_data = val_loader.batch_size * len(val_loader) # 1703
169
+ else:
170
+ len_data = len_dataset
171
+ if metrics == 'all' or metrics == 'no_loss':
172
+ pck = np.zeros((len_data))
173
+ pck_by_part = {group:np.zeros((len_data)) for group in KEYPOINT_GROUPS}
174
+ acc_sil_2d = np.zeros(len_data)
175
+
176
+ all_betas = np.zeros((len_data, output_reproj['betas'].shape[1]))
177
+ all_betas_limbs = np.zeros((len_data, output_reproj['betas_limbs'].shape[1]))
178
+ all_z = np.zeros((len_data, output_reproj['z'].shape[1]))
179
+ all_pose_rotmat = np.zeros((len_data, output_unnorm['pose_rotmat'].shape[1], 3, 3))
180
+ all_flength = np.zeros((len_data, output_unnorm['flength'].shape[1]))
181
+ all_trans = np.zeros((len_data, output_unnorm['trans'].shape[1]))
182
+ all_breed_indices = np.zeros((len_data))
183
+ all_image_names = [] # len_data * [None]
184
+
185
+ index = i
186
+ ind_img = 0
187
+ if save_imgs_path is not None:
188
+ # render predicted 3d models
189
+ visualizations = model.render_vis_nograd(vertices=output_reproj['vertices_smal'],
190
+ focal_lengths=output_unnorm['flength'],
191
+ color=0) # color=2)
192
+ for ind_img in range(len(target_dict['index'])):
193
+ try:
194
+ if test_name_list is not None:
195
+ img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_')
196
+ img_name = img_name.split('.')[0]
197
+ else:
198
+ img_name = str(index) + '_' + str(ind_img)
199
+ # save image with predicted keypoints
200
+ out_path = save_imgs_path + '/keypoints_pred_' + img_name + '.png'
201
+ pred_unp = (output['keypoints_norm'][ind_img, :, :] + 1.) / 2 * (data_info.image_size - 1)
202
+ pred_unp_maxval = output['keypoints_scores'][ind_img, :, :]
203
+ pred_unp_prep = torch.cat((pred_unp, pred_unp_maxval), 1)
204
+ inp_img = input[ind_img, :, :, :].detach().clone()
205
+ save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.1, print_scores=True, ratio_in_out=1.0) # threshold=0.3
206
+ # save predicted 3d model (front view)
207
+ pred_tex = visualizations[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
208
+ pred_tex_max = np.max(pred_tex, axis=2)
209
+ out_path = save_imgs_path + '/tex_pred_' + img_name + '.png'
210
+ plt.imsave(out_path, pred_tex)
211
+ input_image = input[ind_img, :, :, :].detach().clone()
212
+ for t, m, s in zip(input_image, data_info.rgb_mean, data_info.rgb_stddev): t.add_(m)
213
+ input_image_np = input_image.detach().cpu().numpy().transpose(1, 2, 0)
214
+ im_masked = cv2.addWeighted(input_image_np,0.2,pred_tex,0.8,0)
215
+ im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :]
216
+ out_path = save_imgs_path + '/comp_pred_' + img_name + '.png'
217
+ plt.imsave(out_path, im_masked)
218
+ # save predicted 3d model (side view)
219
+ vertices_cent = output_reproj['vertices_smal'] - output_reproj['vertices_smal'].mean(dim=1)[:, None, :]
220
+ roll = np.pi / 2 * torch.ones(1).float().to(device)
221
+ pitch = np.pi / 2 * torch.ones(1).float().to(device)
222
+ tensor_0 = torch.zeros(1).float().to(device)
223
+ tensor_1 = torch.ones(1).float().to(device)
224
+ RX = torch.stack([torch.stack([tensor_1, tensor_0, tensor_0]), torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3,3)
225
+ RY = torch.stack([
226
+ torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]),
227
+ torch.stack([tensor_0, tensor_1, tensor_0]),
228
+ torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3,3)
229
+ vertices_rot = (torch.matmul(RY, vertices_cent.reshape((-1, 3))[:, :, None])).reshape((curr_batch_size, -1, 3))
230
+ vertices_rot[:, :, 2] = vertices_rot[:, :, 2] + torch.ones_like(vertices_rot[:, :, 2]) * 20 # 18 # *16
231
+
232
+ visualizations_rot = model.render_vis_nograd(vertices=vertices_rot,
233
+ focal_lengths=output_unnorm['flength'],
234
+ color=0) # 2)
235
+ pred_tex = visualizations_rot[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
236
+ pred_tex_max = np.max(pred_tex, axis=2)
237
+ out_path = save_imgs_path + '/rot_tex_pred_' + img_name + '.png'
238
+ plt.imsave(out_path, pred_tex)
239
+ if render_all:
240
+ # save input image
241
+ inp_img = input[ind_img, :, :, :].detach().clone()
242
+ out_path = save_imgs_path + '/image_' + img_name + '.png'
243
+ save_input_image(inp_img, out_path)
244
+ # save mesh
245
+ V_posed = output_reproj['vertices_smal'][ind_img, :, :].detach().cpu().numpy()
246
+ Faces = model.smal.f
247
+ mesh_posed = trimesh.Trimesh(vertices=V_posed, faces=Faces, process=False, maintain_order=True)
248
+ mesh_posed.export(save_imgs_path + '/mesh_posed_' + img_name + '.obj')
249
+ except:
250
+ print('dont save an image')
251
+
252
+ if metrics == 'all' or metrics == 'no_loss':
253
+ # prepare a dictionary with all the predicted results
254
+ preds = {}
255
+ preds['betas'] = output_reproj['betas'].cpu().detach().numpy()
256
+ preds['betas_limbs'] = output_reproj['betas_limbs'].cpu().detach().numpy()
257
+ preds['z'] = output_reproj['z'].cpu().detach().numpy()
258
+ preds['pose_rotmat'] = output_unnorm['pose_rotmat'].cpu().detach().numpy()
259
+ preds['flength'] = output_unnorm['flength'].cpu().detach().numpy()
260
+ preds['trans'] = output_unnorm['trans'].cpu().detach().numpy()
261
+ preds['breed_index'] = target_dict['breed_index'].cpu().detach().numpy().reshape((-1))
262
+ img_names = []
263
+ for ind_img2 in range(0, output_reproj['betas'].shape[0]):
264
+ if test_name_list is not None:
265
+ img_name2 = test_name_list[int(target_dict['index'][ind_img2].cpu().detach().numpy())].replace('/', '_')
266
+ img_name2 = img_name2.split('.')[0]
267
+ else:
268
+ img_name2 = str(index) + '_' + str(ind_img2)
269
+ img_names.append(img_name2)
270
+ preds['image_names'] = img_names
271
+ # prepare keypoints for PCK calculation - predicted as well as ground truth
272
+ pred_keypoints_norm = output['keypoints_norm'] # -1 to 1
273
+ pred_keypoints_256 = output_reproj['keyp_2d']
274
+ pred_keypoints = pred_keypoints_256
275
+ gt_keypoints_256 = target_dict['tpts'][:, :, :2] / 64. * (256. - 1)
276
+ gt_keypoints_norm = gt_keypoints_256 / 256 / 0.5 - 1
277
+ gt_keypoints = torch.cat((gt_keypoints_256, target_dict['tpts'][:, :, 2:3]), dim=2) # gt_keypoints_norm
278
+ # prepare silhouette for IoU calculation - predicted as well as ground truth
279
+ has_seg = target_dict['has_seg']
280
+ img_border_mask = target_dict['img_border_mask'][:, 0, :, :]
281
+ gtseg = target_dict['silh']
282
+ synth_silhouettes = output_reproj['silh'][:, 0, :, :] # output_reproj['silh']
283
+ synth_silhouettes[synth_silhouettes>0.5] = 1
284
+ synth_silhouettes[synth_silhouettes<0.5] = 0
285
+ # calculate PCK as well as IoU (similar to WLDO)
286
+ preds['acc_PCK'] = Metrics.PCK(
287
+ pred_keypoints, gt_keypoints,
288
+ gtseg, has_seg, idxs=EVAL_KEYPOINTS,
289
+ thresh_range=[pck_thresh], # [0.15],
290
+ )
291
+ preds['acc_IOU'] = Metrics.IOU(
292
+ synth_silhouettes, gtseg,
293
+ img_border_mask, mask=has_seg
294
+ )
295
+ for group, group_kps in KEYPOINT_GROUPS.items():
296
+ preds[f'{group}_PCK'] = Metrics.PCK(
297
+ pred_keypoints, gt_keypoints, gtseg, has_seg,
298
+ thresh_range=[pck_thresh], # [0.15],
299
+ idxs=group_kps
300
+ )
301
+ # add results for all images in this batch to lists
302
+ curr_batch_size = pred_keypoints_256.shape[0]
303
+ if not (preds['acc_PCK'].data.cpu().numpy().shape == (pck[my_step * batch_size:my_step * batch_size + curr_batch_size]).shape):
304
+ import pdb; pdb.set_trace()
305
+ pck[my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['acc_PCK'].data.cpu().numpy()
306
+ acc_sil_2d[my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['acc_IOU'].data.cpu().numpy()
307
+ for part in pck_by_part:
308
+ pck_by_part[part][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds[f'{part}_PCK'].data.cpu().numpy()
309
+ all_betas[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['betas']
310
+ all_betas_limbs[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['betas_limbs']
311
+ all_z[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['z']
312
+ all_pose_rotmat[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['pose_rotmat']
313
+ all_flength[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['flength']
314
+ all_trans[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['trans']
315
+ all_breed_indices[my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['breed_index']
316
+ all_image_names.extend(preds['image_names'])
317
+ # update progress bar
318
+ if progress is not None:
319
+ my_string = "PCK: {0:.2f}, IOU: {1:.2f}".format(
320
+ pck[:(my_step * batch_size + curr_batch_size)].mean(),
321
+ acc_sil_2d[:(my_step * batch_size + curr_batch_size)].mean())
322
+ progress.set_postfix_str(my_string)
323
+ else:
324
+ # measure accuracy and record loss
325
+ bs_fake = 1 # batch_size
326
+ losses.update(loss_dict['loss'], bs_fake)
327
+ losses_keyp.update(loss_dict['loss_keyp_weighted'], bs_fake)
328
+ losses_silh.update(loss_dict['loss_silh_weighted'], bs_fake)
329
+ losses_shape.update(loss_dict['loss_shape_weighted'], bs_fake)
330
+ losses_pose.update(loss_dict['loss_poseprior_weighted'], bs_fake)
331
+ losses_class.update(loss_dict['loss_class_weighted'], bs_fake)
332
+ losses_breed.update(loss_dict['loss_breed_weighted'], bs_fake)
333
+ losses_partseg.update(loss_dict['loss_partseg_weighted'], bs_fake)
334
+ acc = - loss_dict['loss_keyp_weighted'] # this will be used to keep track of the 'best model'
335
+ accuracies.update(acc, bs_fake)
336
+ # Show losses as part of the progress bar.
337
+ if progress is not None:
338
+ my_string = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_partseg: {loss_partseg:0.4f}, loss_shape: {loss_shape:0.4f}, loss_pose: {loss_pose:0.4f}, loss_class: {loss_class:0.4f}, loss_breed: {loss_breed:0.4f}'.format(
339
+ loss=losses.avg,
340
+ loss_keyp=losses_keyp.avg,
341
+ loss_silh=losses_silh.avg,
342
+ loss_shape=losses_shape.avg,
343
+ loss_pose=losses_pose.avg,
344
+ loss_class=losses_class.avg,
345
+ loss_breed=losses_breed.avg,
346
+ loss_partseg=losses_partseg.avg
347
+ )
348
+ progress.set_postfix_str(my_string)
349
+ my_step += 1
350
+ if metrics == 'all':
351
+ summary = {'pck': pck, 'acc_sil_2d': acc_sil_2d, 'pck_by_part':pck_by_part,
352
+ 'betas': all_betas, 'betas_limbs': all_betas_limbs, 'z': all_z, 'pose_rotmat': all_pose_rotmat,
353
+ 'flenght': all_flength, 'trans': all_trans, 'image_names': all_image_names, 'breed_indices': all_breed_indices}
354
+ return my_string, summary
355
+ elif metrics == 'no_loss':
356
+ return my_string, np.average(np.asarray(acc_sil_2d))
357
+ else:
358
+ return my_string, accuracies.avg
359
+
360
+
361
+ # ---------------------------------------------------------------------------------------------------------------------------
362
+ def do_visual_epoch(val_loader, model, device, data_info, flip=False, quiet=False, acc_joints=None, save_imgs_path=None, weight_dict=None, metrics=None, val_opt='default', test_name_list=None, render_all=False, pck_thresh=0.15, return_results=False):
363
+ if save_imgs_path is not None:
364
+ pathlib.Path(save_imgs_path).mkdir(parents=True, exist_ok=True)
365
+ all_results = []
366
+
367
+ # Put the model in evaluation mode.
368
+ model.eval()
369
+
370
+ iterable = enumerate(val_loader)
371
+
372
+ # information for normalization
373
+ norm_dict = {
374
+ 'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device),
375
+ 'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device),
376
+ 'trans_std': torch.from_numpy(data_info.trans_std).float().to(device),
377
+ 'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device),
378
+ 'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)}
379
+
380
+ '''
381
+ return_mesh_with_gt_groundplane = True
382
+ if return_mesh_with_gt_groundplane:
383
+ remeshing_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data_remeshed/uniform_surface_sampling/my_smpl_39dogsnorm_Jr_4_dog_remesh4000_info.pkl'
384
+ with open(remeshing_path, 'rb') as fp:
385
+ remeshing_dict = pkl.load(fp)
386
+ remeshing_relevant_faces = torch.tensor(remeshing_dict['smal_faces'][remeshing_dict['faceid_closest']], dtype=torch.long, device=device)
387
+ remeshing_relevant_barys = torch.tensor(remeshing_dict['barys_closest'], dtype=torch.float32, device=device)
388
+
389
+ # from smal_pytorch.smal_model.smal_torch_new import SMAL
390
+ print('start: load smal default model (barc), but only for vertices')
391
+ smal = SMAL()
392
+ print('end: load smal default model (barc), but only for vertices')
393
+ smal_template_verts = smal.v_template.detach().cpu().numpy()
394
+ smal_faces = smal.faces.detach().cpu().numpy()
395
+
396
+ file_alignment_errors = open(save_imgs_path + '/a_ref_procrustes_alignmnet_errors.txt', 'a') # append mode
397
+ file_alignment_errors.write(" ----------- start evaluation ------------- \n ")
398
+
399
+ csv_file_alignment_errors = open(save_imgs_path + '/a_ref_procrustes_alignmnet_errors.csv', 'w') # write mode
400
+ fieldnames = ['name', 'error']
401
+ writer = csv.DictWriter(csv_file_alignment_errors, fieldnames=fieldnames)
402
+ writer.writeheader()
403
+ '''
404
+
405
+ my_step = 0
406
+ for i, (input, target_dict) in iterable:
407
+ batch_size = input.shape[0]
408
+ input = input.float().to(device)
409
+ partial_results = {}
410
+
411
+ # ----------------------- do visualization step -----------------------
412
+ with torch.no_grad():
413
+ output, output_unnorm, output_reproj = model(input, norm_dict=norm_dict)
414
+
415
+ index = i
416
+ ind_img = 0
417
+ for ind_img in range(batch_size): # range(min(12, batch_size)): # range(12): # [0]: #range(0, batch_size):
418
+
419
+ try:
420
+ if test_name_list is not None:
421
+ img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_')
422
+ img_name = img_name.split('.')[0]
423
+ else:
424
+ img_name = str(index) + '_' + str(ind_img)
425
+ partial_results['img_name'] = img_name
426
+ visualizations = model.render_vis_nograd(vertices=output_reproj['vertices_smal'],
427
+ focal_lengths=output_unnorm['flength'],
428
+ color=0) # 2)
429
+ # save image with predicted keypoints
430
+ pred_unp = (output['keypoints_norm'][ind_img, :, :] + 1.) / 2 * (data_info.image_size - 1)
431
+ pred_unp_maxval = output['keypoints_scores'][ind_img, :, :]
432
+ pred_unp_prep = torch.cat((pred_unp, pred_unp_maxval), 1)
433
+ inp_img = input[ind_img, :, :, :].detach().clone()
434
+ if save_imgs_path is not None:
435
+ out_path = save_imgs_path + '/keypoints_pred_' + img_name + '.png'
436
+ save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.1, print_scores=True, ratio_in_out=1.0) # threshold=0.3
437
+ # save predicted 3d model
438
+ # (1) front view
439
+ pred_tex = visualizations[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
440
+ pred_tex_max = np.max(pred_tex, axis=2)
441
+ partial_results['tex_pred'] = pred_tex
442
+ if save_imgs_path is not None:
443
+ out_path = save_imgs_path + '/tex_pred_' + img_name + '.png'
444
+ plt.imsave(out_path, pred_tex)
445
+ input_image = input[ind_img, :, :, :].detach().clone()
446
+ for t, m, s in zip(input_image, data_info.rgb_mean, data_info.rgb_stddev): t.add_(m)
447
+ input_image_np = input_image.detach().cpu().numpy().transpose(1, 2, 0)
448
+ im_masked = cv2.addWeighted(input_image_np,0.2,pred_tex,0.8,0)
449
+ im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :]
450
+ partial_results['comp_pred'] = im_masked
451
+ if save_imgs_path is not None:
452
+ out_path = save_imgs_path + '/comp_pred_' + img_name + '.png'
453
+ plt.imsave(out_path, im_masked)
454
+ # (2) side view
455
+ vertices_cent = output_reproj['vertices_smal'] - output_reproj['vertices_smal'].mean(dim=1)[:, None, :]
456
+ roll = np.pi / 2 * torch.ones(1).float().to(device)
457
+ pitch = np.pi / 2 * torch.ones(1).float().to(device)
458
+ tensor_0 = torch.zeros(1).float().to(device)
459
+ tensor_1 = torch.ones(1).float().to(device)
460
+ RX = torch.stack([torch.stack([tensor_1, tensor_0, tensor_0]), torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3,3)
461
+ RY = torch.stack([
462
+ torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]),
463
+ torch.stack([tensor_0, tensor_1, tensor_0]),
464
+ torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3,3)
465
+ vertices_rot = (torch.matmul(RY, vertices_cent.reshape((-1, 3))[:, :, None])).reshape((batch_size, -1, 3))
466
+ vertices_rot[:, :, 2] = vertices_rot[:, :, 2] + torch.ones_like(vertices_rot[:, :, 2]) * 20 # 18 # *16
467
+ visualizations_rot = model.render_vis_nograd(vertices=vertices_rot,
468
+ focal_lengths=output_unnorm['flength'],
469
+ color=0) # 2)
470
+ pred_tex = visualizations_rot[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
471
+ pred_tex_max = np.max(pred_tex, axis=2)
472
+ partial_results['rot_tex_pred'] = pred_tex
473
+ if save_imgs_path is not None:
474
+ out_path = save_imgs_path + '/rot_tex_pred_' + img_name + '.png'
475
+ plt.imsave(out_path, pred_tex)
476
+ render_all = True
477
+ if render_all:
478
+ # save input image
479
+ inp_img = input[ind_img, :, :, :].detach().clone()
480
+ if save_imgs_path is not None:
481
+ out_path = save_imgs_path + '/image_' + img_name + '.png'
482
+ save_input_image(inp_img, out_path)
483
+ # save posed mesh
484
+ V_posed = output_reproj['vertices_smal'][ind_img, :, :].detach().cpu().numpy()
485
+ Faces = model.smal.f
486
+ mesh_posed = trimesh.Trimesh(vertices=V_posed, faces=Faces, process=False, maintain_order=True)
487
+ partial_results['mesh_posed'] = mesh_posed
488
+ if save_imgs_path is not None:
489
+ mesh_posed.export(save_imgs_path + '/mesh_posed_' + img_name + '.obj')
490
+ except:
491
+ print('pass...')
492
+ all_results.append(partial_results)
493
+ if return_results:
494
+ return all_results
495
+ else:
496
+ return
src/configs/SMAL_configs.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import numpy as np
4
+ import os
5
+ import sys
6
+
7
+
8
+ # SMAL_DATA_DIR = '/is/cluster/work/nrueegg/dog_project/pytorch-dogs-inference/src/smal_pytorch/smpl_models/'
9
+ # SMAL_DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'smal_pytorch', 'smal_data')
10
+ SMAL_DATA_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'smal_data')
11
+
12
+ # we replace the old SMAL model by a more dog specific model (see BARC cvpr 2022 paper)
13
+ # our model has several differences compared to the original SMAL model, some of them are:
14
+ # - the PCA shape space is recalculated (from partially new data and weighted)
15
+ # - coefficients for limb length changes are allowed (similar to WLDO, we did borrow some of their code)
16
+ # - all dogs have a core of approximately the same length
17
+ # - dogs are centered in their root joint (which is close to the tail base)
18
+ # -> like this the root rotations is always around this joint AND (0, 0, 0)
19
+ # -> before this it would happen that the animal 'slips' from the image middle to the side when rotating it. Now
20
+ # 'trans' also defines the center of the rotation
21
+ # - we correct the back joint locations such that all those joints are more aligned
22
+
23
+ # logscale_part_list = ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l']
24
+ # logscale_part_list = ['front_legs_l', 'front_legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l', 'back_legs_l', 'back_legs_f']
25
+
26
+ SMAL_MODEL_CONFIG = {
27
+ 'barc': {
28
+ 'smal_model_type': 'barc',
29
+ 'smal_model_path': os.path.join(SMAL_DATA_DIR, 'my_smpl_SMBLD_nbj_v3.pkl'),
30
+ 'smal_model_data_path': os.path.join(SMAL_DATA_DIR, 'my_smpl_data_SMBLD_v3.pkl'),
31
+ 'unity_smal_shape_prior_dogs': os.path.join(SMAL_DATA_DIR, 'my_smpl_data_SMBLD_v3.pkl'),
32
+ 'logscale_part_list': ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'],
33
+ },
34
+ '39dogs_diffsize': {
35
+ 'smal_model_type': '39dogs_diffsize',
36
+ 'smal_model_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_00791_nadine_Jr_4_dog.pkl'),
37
+ 'smal_model_data_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_00791_nadine_Jr_4_dog.pkl'),
38
+ 'unity_smal_shape_prior_dogs': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_00791_nadine_Jr_4_dog.pkl'),
39
+ 'logscale_part_list': ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'],
40
+ },
41
+ '39dogs_norm': {
42
+ 'smal_model_type': '39dogs_norm',
43
+ 'smal_model_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_39dogsnorm_Jr_4_dog.pkl'),
44
+ 'smal_model_data_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_Jr_4_dog.pkl'),
45
+ 'unity_smal_shape_prior_dogs': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_Jr_4_dog.pkl'),
46
+ 'logscale_part_list': ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'],
47
+ },
48
+ '39dogs_norm_9ll': { # 9 limb length parameters
49
+ 'smal_model_type': '39dogs_norm_9ll',
50
+ 'smal_model_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_39dogsnorm_Jr_4_dog.pkl'),
51
+ 'smal_model_data_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_Jr_4_dog.pkl'),
52
+ 'unity_smal_shape_prior_dogs': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_Jr_4_dog.pkl'),
53
+ 'logscale_part_list': ['front_legs_l', 'front_legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l', 'back_legs_l', 'back_legs_f'],
54
+ },
55
+ '39dogs_norm_newv2': { # front and back legs of equal lengths
56
+ 'smal_model_type': '39dogs_norm_newv2',
57
+ 'smal_model_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_39dogsnorm_newv2_dog.pkl'),
58
+ 'smal_model_data_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_newv2_dog.pkl'),
59
+ 'unity_smal_shape_prior_dogs': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_newv2_dog.pkl'),
60
+ 'logscale_part_list': ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'],
61
+ },
62
+ '39dogs_norm_newv3': { # pca on dame AND different front and back legs lengths
63
+ 'smal_model_type': '39dogs_norm_newv3',
64
+ 'smal_model_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_39dogsnorm_newv3_dog.pkl'),
65
+ 'smal_model_data_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_newv3_dog.pkl'),
66
+ 'unity_smal_shape_prior_dogs': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_newv3_dog.pkl'),
67
+ 'logscale_part_list': ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'],
68
+ },
69
+ }
70
+
71
+
72
+ SYMMETRY_INDS_FILE = os.path.join(SMAL_DATA_DIR, 'symmetry_inds.json')
73
+
74
+ mean_dog_bone_lengths_txt = os.path.join(SMAL_DATA_DIR, 'mean_dog_bone_lengths.txt')
75
+
76
+ # some vertex indices, (from silvia zuffi´s code, create_projected_images_cats.py)
77
+ KEY_VIDS = np.array(([1068, 1080, 1029, 1226], # left eye
78
+ [2660, 3030, 2675, 3038], # right eye
79
+ [910], # mouth low
80
+ [360, 1203, 1235, 1230], # front left leg, low
81
+ [3188, 3156, 2327, 3183], # front right leg, low
82
+ [1976, 1974, 1980, 856], # back left leg, low
83
+ [3854, 2820, 3852, 3858], # back right leg, low
84
+ [452, 1811], # tail start
85
+ [416, 235, 182], # front left leg, top
86
+ [2156, 2382, 2203], # front right leg, top
87
+ [829], # back left leg, top
88
+ [2793], # back right leg, top
89
+ [60, 114, 186, 59], # throat, close to base of neck
90
+ [2091, 2037, 2036, 2160], # withers (a bit lower than in reality)
91
+ [384, 799, 1169, 431], # front left leg, middle
92
+ [2351, 2763, 2397, 3127], # front right leg, middle
93
+ [221, 104], # back left leg, middle
94
+ [2754, 2192], # back right leg, middle
95
+ [191, 1158, 3116, 2165], # neck
96
+ [28], # Tail tip
97
+ [542], # Left Ear
98
+ [2507], # Right Ear
99
+ [1039, 1845, 1846, 1870, 1879, 1919, 2997, 3761, 3762], # nose tip
100
+ [0, 464, 465, 726, 1824, 2429, 2430, 2690]), dtype=object) # half tail
101
+
102
+ # the following vertices are used for visibility only: if one of the vertices is visible,
103
+ # then we assume that the joint is visible! There is some noise, but we don't care, as this is
104
+ # for generation of the synthetic dataset only
105
+ KEY_VIDS_VISIBILITY_ONLY = np.array(([1068, 1080, 1029, 1226, 645], # left eye
106
+ [2660, 3030, 2675, 3038, 2567], # right eye
107
+ [910, 11, 5], # mouth low
108
+ [360, 1203, 1235, 1230, 298, 408, 303, 293, 384], # front left leg, low
109
+ [3188, 3156, 2327, 3183, 2261, 2271, 2573, 2265], # front right leg, low
110
+ [1976, 1974, 1980, 856, 559, 851, 556], # back left leg, low
111
+ [3854, 2820, 3852, 3858, 2524, 2522, 2815, 2072], # back right leg, low
112
+ [452, 1811, 63, 194, 52, 370, 64], # tail start
113
+ [416, 235, 182, 440, 8, 80, 73, 112], # front left leg, top
114
+ [2156, 2382, 2203, 2050, 2052, 2406, 3], # front right leg, top
115
+ [829, 219, 218, 173, 17, 7, 279], # back left leg, top
116
+ [2793, 582, 140, 87, 2188, 2147, 2063], # back right leg, top
117
+ [60, 114, 186, 59, 878, 130, 189, 45], # throat, close to base of neck
118
+ [2091, 2037, 2036, 2160, 190, 2164], # withers (a bit lower than in reality)
119
+ [384, 799, 1169, 431, 321, 314, 437, 310, 323], # front left leg, middle
120
+ [2351, 2763, 2397, 3127, 2278, 2285, 2282, 2275, 2359], # front right leg, middle
121
+ [221, 104, 105, 97, 103], # back left leg, middle
122
+ [2754, 2192, 2080, 2251, 2075, 2074], # back right leg, middle
123
+ [191, 1158, 3116, 2165, 154, 653, 133, 339], # neck
124
+ [28, 474, 475, 731, 24], # Tail tip
125
+ [542, 147, 509, 200, 522], # Left Ear
126
+ [2507,2174, 2122, 2126, 2474], # Right Ear
127
+ [1039, 1845, 1846, 1870, 1879, 1919, 2997, 3761, 3762], # nose tip
128
+ [0, 464, 465, 726, 1824, 2429, 2430, 2690]), dtype=object) # half tail
129
+
130
+ # Keypoint indices for 3d sketchfab evaluation
131
+ SMAL_KEYPOINT_NAMES_FOR_3D_EVAL = ['right_front_paw','right_front_elbow','right_back_paw','right_back_hock','right_ear_top','right_ear_bottom','right_eye', \
132
+ 'left_front_paw','left_front_elbow','left_back_paw','left_back_hock','left_ear_top','left_ear_bottom','left_eye', \
133
+ 'nose','tail_start','tail_end']
134
+ SMAL_KEYPOINT_INDICES_FOR_3D_EVAL = [2577, 2361, 2820, 2085, 2125, 2453, 2668, 613, 394, 855, 786, 149, 486, 1079, 1845, 1820, 28]
135
+ SMAL_KEYPOINT_WHICHTOUSE_FOR_3D_EVAL = [1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0] # [1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0]
136
+
137
+
138
+
139
+
140
+ # see: https://github.com/benjiebob/SMALify/blob/master/config.py
141
+ # JOINT DEFINITIONS - based on SMAL joints and additional {eyes, ear tips, chin and nose}
142
+ TORSO_JOINTS = [2, 5, 8, 11, 12, 23]
143
+ CANONICAL_MODEL_JOINTS = [
144
+ 10, 9, 8, # upper_left [paw, middle, top]
145
+ 20, 19, 18, # lower_left [paw, middle, top]
146
+ 14, 13, 12, # upper_right [paw, middle, top]
147
+ 24, 23, 22, # lower_right [paw, middle, top]
148
+ 25, 31, # tail [start, end]
149
+ 33, 34, # ear base [left, right]
150
+ 35, 36, # nose, chin
151
+ 38, 37, # ear tip [left, right]
152
+ 39, 40, # eyes [left, right]
153
+ 6, 11, # withers, throat (throat is inaccurate and withers also)
154
+ 28] # tail middle
155
+ # old: 15, 15, # withers, throat (TODO: Labelled same as throat for now), throat
156
+
157
+ CANONICAL_MODEL_JOINTS_REFINED = [
158
+ 41, 9, 8, # upper_left [paw, middle, top]
159
+ 43, 19, 18, # lower_left [paw, middle, top]
160
+ 42, 13, 12, # upper_right [paw, middle, top]
161
+ 44, 23, 22, # lower_right [paw, middle, top]
162
+ 25, 31, # tail [start, end]
163
+ 33, 34, # ear base [left, right]
164
+ 35, 36, # nose, chin
165
+ 38, 37, # ear tip [left, right]
166
+ 39, 40, # eyes [left, right]
167
+ 46, 45, # withers, throat
168
+ 28] # tail middle
169
+
170
+ # the following list gives the indices of the KEY_VIDS_JOINTS that must be taken in order
171
+ # to judge if the CANONICAL_MODEL_JOINTS are visible - those are all approximations!
172
+ CMJ_VISIBILITY_IN_KEY_VIDS = [
173
+ 3, 14, 8, # left front leg
174
+ 5, 16, 10, # left rear leg
175
+ 4, 15, 9, # right front leg
176
+ 6, 17, 11, # right rear leg
177
+ 7, 19, # tail front, tail back
178
+ 20, 21, # ear base (but can not be found in blue, se we take the tip)
179
+ 2, 2, # mouth (was: 22, 2)
180
+ 20, 21, # ear tips
181
+ 1, 0, # eyes
182
+ 18, # withers, not sure where this point is
183
+ 12, # throat
184
+ 23, # mid tail
185
+ ]
186
+
187
+ # define which bone lengths are used as input to the 2d-to-3d network
188
+ IDXS_BONES_NO_REDUNDANCY = [6,7,8,9,16,17,18,19,32,1,2,3,4,5,14,15,24,25,26,27,28,29,30,31]
189
+ # load bone lengths of the mean dog (already filtered)
190
+ mean_dog_bone_lengths = []
191
+ with open(mean_dog_bone_lengths_txt, 'r') as f:
192
+ for line in f:
193
+ mean_dog_bone_lengths.append(float(line.split('\n')[0]))
194
+ MEAN_DOG_BONE_LENGTHS_NO_RED = np.asarray(mean_dog_bone_lengths)[IDXS_BONES_NO_REDUNDANCY] # (24, )
195
+
196
+ # Body part segmentation:
197
+ # the body can be segmented based on the bones and for the new dog model also based on the new shapedirs
198
+ # axis_horizontal = self.shapedirs[2, :].reshape((-1, 3))[:, 0]
199
+ # all_indices = np.arange(3889)
200
+ # tail_indices = all_indices[axis_horizontal.detach().cpu().numpy() < 0.0]
201
+ VERTEX_IDS_TAIL = [ 0, 4, 9, 10, 24, 25, 28, 453, 454, 456, 457,
202
+ 458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468,
203
+ 469, 470, 471, 472, 473, 474, 475, 724, 725, 726, 727,
204
+ 728, 729, 730, 731, 813, 975, 976, 977, 1109, 1110, 1111,
205
+ 1811, 1813, 1819, 1820, 1821, 1822, 1823, 1824, 1825, 1826, 1827,
206
+ 1828, 1835, 1836, 1960, 1961, 1962, 1963, 1964, 1965, 1966, 1967,
207
+ 1968, 1969, 2418, 2419, 2421, 2422, 2423, 2424, 2425, 2426, 2427,
208
+ 2428, 2429, 2430, 2431, 2432, 2433, 2434, 2435, 2436, 2437, 2438,
209
+ 2439, 2440, 2688, 2689, 2690, 2691, 2692, 2693, 2694, 2695, 2777,
210
+ 3067, 3068, 3069, 3842, 3843, 3844, 3845, 3846, 3847]
211
+
212
+ # same as in https://github.com/benjiebob/WLDO/blob/master/global_utils/config.py
213
+ EVAL_KEYPOINTS = [
214
+ 0, 1, 2, # left front
215
+ 3, 4, 5, # left rear
216
+ 6, 7, 8, # right front
217
+ 9, 10, 11, # right rear
218
+ 12, 13, # tail start -> end
219
+ 14, 15, # left ear, right ear
220
+ 16, 17, # nose, chin
221
+ 18, 19] # left tip, right tip
222
+
223
+ KEYPOINT_GROUPS = {
224
+ 'legs': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # legs
225
+ 'tail': [12, 13], # tail
226
+ 'ears': [14, 15, 18, 19], # ears
227
+ 'face': [16, 17] # face
228
+ }
229
+
230
+
src/configs/anipose_data_info.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List
3
+ import json
4
+ import numpy as np
5
+ import os
6
+
7
+ STATISTICS_DATA_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'statistics')
8
+ STATISTICS_PATH = os.path.join(STATISTICS_DATA_DIR, 'statistics_modified_v1.json')
9
+
10
+ @dataclass
11
+ class DataInfo:
12
+ rgb_mean: List[float]
13
+ rgb_stddev: List[float]
14
+ joint_names: List[str]
15
+ hflip_indices: List[int]
16
+ n_joints: int
17
+ n_keyp: int
18
+ n_bones: int
19
+ n_betas: int
20
+ image_size: int
21
+ trans_mean: np.ndarray
22
+ trans_std: np.ndarray
23
+ flength_mean: np.ndarray
24
+ flength_std: np.ndarray
25
+ pose_rot6d_mean: np.ndarray
26
+ keypoint_weights: List[float]
27
+
28
+ # SMAL samples 3d statistics
29
+ # statistics like mean values were calculated once when the project was started and they were not changed afterwards anymore
30
+ def load_statistics(statistics_path):
31
+ with open(statistics_path) as f:
32
+ statistics = json.load(f)
33
+ '''new_pose_mean = [[[np.round(val, 2) for val in sublst] for sublst in sublst_big] for sublst_big in statistics['pose_mean']]
34
+ statistics['pose_mean'] = new_pose_mean
35
+ j_out = json.dumps(statistics, indent=4) #, sort_keys=True)
36
+ with open(self.statistics_path, 'w') as file: file.write(j_out)'''
37
+ new_statistics = {'trans_mean': np.asarray(statistics['trans_mean']),
38
+ 'trans_std': np.asarray(statistics['trans_std']),
39
+ 'flength_mean': np.asarray(statistics['flength_mean']),
40
+ 'flength_std': np.asarray(statistics['flength_std']),
41
+ 'pose_mean': np.asarray(statistics['pose_mean']),
42
+ }
43
+ new_statistics['pose_rot6d_mean'] = new_statistics['pose_mean'][:, :, :2].reshape((-1, 6))
44
+ return new_statistics
45
+ STATISTICS = load_statistics(STATISTICS_PATH)
46
+
47
+ AniPose_JOINT_NAMES_swapped = [
48
+ 'L_F_Paw', 'L_F_Knee', 'L_F_Elbow',
49
+ 'L_B_Paw', 'L_B_Knee', 'L_B_Elbow',
50
+ 'R_F_Paw', 'R_F_Knee', 'R_F_Elbow',
51
+ 'R_B_Paw', 'R_B_Knee', 'R_B_Elbow',
52
+ 'TailBase', '_Tail_end_', 'L_EarBase', 'R_EarBase',
53
+ 'Nose', '_Chin_', '_Left_ear_tip_', '_Right_ear_tip_',
54
+ 'L_Eye', 'R_Eye', 'Withers', 'Throat']
55
+
56
+ KEYPOINT_WEIGHTS = [3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 3, 2, 2, 3, 1, 2, 2]
57
+
58
+ COMPLETE_DATA_INFO = DataInfo(
59
+ rgb_mean=[0.4404, 0.4440, 0.4327], # not sure
60
+ rgb_stddev=[0.2458, 0.2410, 0.2468], # not sure
61
+ joint_names=AniPose_JOINT_NAMES_swapped, # AniPose_JOINT_NAMES,
62
+ hflip_indices=[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 12, 13, 15, 14, 16, 17, 19, 18, 21, 20, 22, 23],
63
+ n_joints = 35,
64
+ n_keyp = 24, # 20, # 25,
65
+ n_bones = 24,
66
+ n_betas = 30, # 10,
67
+ image_size = 256,
68
+ trans_mean = STATISTICS['trans_mean'],
69
+ trans_std = STATISTICS['trans_std'],
70
+ flength_mean = STATISTICS['flength_mean'],
71
+ flength_std = STATISTICS['flength_std'],
72
+ pose_rot6d_mean = STATISTICS['pose_rot6d_mean'],
73
+ keypoint_weights = KEYPOINT_WEIGHTS
74
+ )
src/configs/barc_cfg_defaults.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from yacs.config import CfgNode as CN
3
+ import argparse
4
+ import yaml
5
+ import os
6
+
7
+ abs_barc_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..',))
8
+
9
+ _C = CN()
10
+ _C.barc_dir = abs_barc_dir
11
+ _C.device = 'cuda'
12
+
13
+ ## path settings
14
+ _C.paths = CN()
15
+ _C.paths.ROOT_OUT_PATH = abs_barc_dir + '/results/'
16
+ _C.paths.ROOT_CHECKPOINT_PATH = abs_barc_dir + '/checkpoint/'
17
+ _C.paths.MODELPATH_NORMFLOW = abs_barc_dir + '/checkpoint/barc_normflow_pret/rgbddog_v3_model.pt'
18
+
19
+ ## parameter settings
20
+ _C.params = CN()
21
+ _C.params.ARCH = 'hg8'
22
+ _C.params.STRUCTURE_POSE_NET = 'normflow' # 'default' # 'vae'
23
+ _C.params.NF_VERSION = 3
24
+ _C.params.N_JOINTS = 35
25
+ _C.params.N_KEYP = 24 #20
26
+ _C.params.N_SEG = 2
27
+ _C.params.N_PARTSEG = 15
28
+ _C.params.UPSAMPLE_SEG = True
29
+ _C.params.ADD_PARTSEG = True # partseg: for the CVPR paper this part of the network exists, but is not trained (no part labels in StanExt)
30
+ _C.params.N_BETAS = 30 # 10
31
+ _C.params.N_BETAS_LIMBS = 7
32
+ _C.params.N_BONES = 24
33
+ _C.params.N_BREEDS = 121 # 120 breeds plus background
34
+ _C.params.IMG_SIZE = 256
35
+ _C.params.SILH_NO_TAIL = False
36
+ _C.params.KP_THRESHOLD = None
37
+ _C.params.ADD_Z_TO_3D_INPUT = False
38
+ _C.params.N_SEGBPS = 64*2
39
+ _C.params.ADD_SEGBPS_TO_3D_INPUT = True
40
+ _C.params.FIX_FLENGTH = False
41
+ _C.params.RENDER_ALL = True
42
+ _C.params.VLIN = 2
43
+ _C.params.STRUCTURE_Z_TO_B = 'lin'
44
+ _C.params.N_Z_FREE = 64
45
+ _C.params.PCK_THRESH = 0.15
46
+ _C.params.REF_NET_TYPE = 'add' # refinement network type
47
+ _C.params.REF_DETACH_SHAPE = True
48
+ _C.params.GRAPHCNN_TYPE = 'inexistent'
49
+ _C.params.ISFLAT_TYPE = 'inexistent'
50
+ _C.params.SHAPEREF_TYPE = 'inexistent'
51
+
52
+ ## SMAL settings
53
+ _C.smal = CN()
54
+ _C.smal.SMAL_MODEL_TYPE = 'barc'
55
+ _C.smal.SMAL_KEYP_CONF = 'green'
56
+
57
+ ## optimization settings
58
+ _C.optim = CN()
59
+ _C.optim.LR = 5e-4
60
+ _C.optim.SCHEDULE = [150, 175, 200]
61
+ _C.optim.GAMMA = 0.1
62
+ _C.optim.MOMENTUM = 0
63
+ _C.optim.WEIGHT_DECAY = 0
64
+ _C.optim.EPOCHS = 220
65
+ _C.optim.BATCH_SIZE = 12 # keep 12 (needs to be an even number, as we have a custom data sampler)
66
+ _C.optim.TRAIN_PARTS = 'all_without_shapedirs'
67
+
68
+ ## dataset settings
69
+ _C.data = CN()
70
+ _C.data.DATASET = 'stanext24'
71
+ _C.data.V12 = True
72
+ _C.data.SHORTEN_VAL_DATASET_TO = None
73
+ _C.data.VAL_OPT = 'val'
74
+ _C.data.VAL_METRICS = 'no_loss'
75
+
76
+ # ---------------------------------------
77
+ def update_dependent_vars(cfg):
78
+ cfg.params.N_CLASSES = cfg.params.N_KEYP + cfg.params.N_SEG
79
+ if cfg.params.VLIN == 0:
80
+ cfg.params.NUM_STAGE_COMB = 2
81
+ cfg.params.NUM_STAGE_HEADS = 1
82
+ cfg.params.NUM_STAGE_HEADS_POSE = 1
83
+ cfg.params.TRANS_SEP = False
84
+ elif cfg.params.VLIN == 1:
85
+ cfg.params.NUM_STAGE_COMB = 3
86
+ cfg.params.NUM_STAGE_HEADS = 1
87
+ cfg.params.NUM_STAGE_HEADS_POSE = 2
88
+ cfg.params.TRANS_SEP = False
89
+ elif cfg.params.VLIN == 2:
90
+ cfg.params.NUM_STAGE_COMB = 3
91
+ cfg.params.NUM_STAGE_HEADS = 1
92
+ cfg.params.NUM_STAGE_HEADS_POSE = 2
93
+ cfg.params.TRANS_SEP = True
94
+ else:
95
+ raise NotImplementedError
96
+ if cfg.params.STRUCTURE_Z_TO_B == '1dconv':
97
+ cfg.params.N_Z = cfg.params.N_BETAS + cfg.params.N_BETAS_LIMBS
98
+ else:
99
+ cfg.params.N_Z = cfg.params.N_Z_FREE
100
+ return
101
+
102
+
103
+ update_dependent_vars(_C)
104
+ global _cfg_global
105
+ _cfg_global = _C.clone()
106
+
107
+
108
+ def get_cfg_defaults():
109
+ # Get a yacs CfgNode object with default values as defined within this file.
110
+ # Return a clone so that the defaults will not be altered.
111
+ return _C.clone()
112
+
113
+ def update_cfg_global_with_yaml(cfg_yaml_file):
114
+ _cfg_global.merge_from_file(cfg_yaml_file)
115
+ update_dependent_vars(_cfg_global)
116
+ return
117
+
118
+ def get_cfg_global_updated():
119
+ # return _cfg_global.clone()
120
+ return _cfg_global
121
+
src/configs/barc_cfg_train.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ paths:
3
+ ROOT_OUT_PATH: './results/'
4
+ ROOT_CHECKPOINT_PATH: './checkpoint/'
5
+ MODELPATH_NORMFLOW: './checkpoint/barc_normflow_pret/rgbddog_v3_model.pt'
6
+
7
+ smal:
8
+ SMAL_MODEL_TYPE: '39dogs_norm_newv3' # '39dogs_diffsize' # 'barc'
9
+ SMAL_KEYP_CONF: 'olive' # 'green'
10
+
11
+ optim:
12
+ LR: 5e-4
13
+ SCHEDULE: [150, 175, 200]
14
+ GAMMA: 0.1
15
+ MOMENTUM: 0
16
+ WEIGHT_DECAY: 0
17
+ EPOCHS: 220
18
+ BATCH_SIZE: 12 # keep 12 (needs to be an even number, as we have a custom data sampler)
19
+ TRAIN_PARTS: 'all_without_shapedirs'
20
+
21
+ data:
22
+ DATASET: 'stanext24'
23
+ SHORTEN_VAL_DATASET_TO: 600 # this is faster as we do not evaluate on the whole validation set
24
+ VAL_OPT: 'val'
src/configs/barc_loss_weights_allzeros.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+ {
5
+ "breed_options": [
6
+ "4"
7
+ ],
8
+ "breed": 0.0,
9
+ "class": 0.0,
10
+ "models3d": 0.0,
11
+ "keyp": 0.0,
12
+ "silh": 0.0,
13
+ "shape_options": [
14
+ "smal",
15
+ "limbs7"
16
+ ],
17
+ "shape": [
18
+ 0,
19
+ 0
20
+ ],
21
+ "poseprior_options": [
22
+ "normalizing_flow_tiger_logprob"
23
+ ],
24
+ "poseprior": 0.0,
25
+ "poselegssidemovement": 0.0,
26
+ "flength": 0.0,
27
+ "partseg": 0,
28
+ "shapedirs": 0,
29
+ "pose_0": 0.0
30
+ }
src/configs/barc_loss_weights_with3dcgloss_higherbetaloss_v2_dm39dnnv3v2.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+ {
5
+ "breed_options": [
6
+ "4"
7
+ ],
8
+ "breed": 5.0,
9
+ "class": 5.0,
10
+ "models3d": 0.1,
11
+ "keyp": 0.2,
12
+ "silh": 50.0,
13
+ "shape_options": [
14
+ "smal",
15
+ "limbs7"
16
+ ],
17
+ "shape": [
18
+ 0.1,
19
+ 1.0
20
+ ],
21
+ "poseprior_options": [
22
+ "normalizing_flow_tiger_logprob"
23
+ ],
24
+ "poseprior": 0.1,
25
+ "poselegssidemovement": 10.0,
26
+ "flength": 1.0,
27
+ "partseg": 0,
28
+ "shapedirs": 0,
29
+ "pose_0": 0.0
30
+ }
src/configs/data_info.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List
3
+ import json
4
+ import numpy as np
5
+ import os
6
+ import sys
7
+
8
+ STATISTICS_DATA_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'statistics')
9
+ STATISTICS_PATH = os.path.join(STATISTICS_DATA_DIR, 'statistics_modified_v1.json')
10
+
11
+ @dataclass
12
+ class DataInfo:
13
+ rgb_mean: List[float]
14
+ rgb_stddev: List[float]
15
+ joint_names: List[str]
16
+ hflip_indices: List[int]
17
+ n_joints: int
18
+ n_keyp: int
19
+ n_bones: int
20
+ n_betas: int
21
+ image_size: int
22
+ trans_mean: np.ndarray
23
+ trans_std: np.ndarray
24
+ flength_mean: np.ndarray
25
+ flength_std: np.ndarray
26
+ pose_rot6d_mean: np.ndarray
27
+ keypoint_weights: List[float]
28
+
29
+ # SMAL samples 3d statistics
30
+ # statistics like mean values were calculated once when the project was started and they were not changed afterwards anymore
31
+ def load_statistics(statistics_path):
32
+ with open(statistics_path) as f:
33
+ statistics = json.load(f)
34
+ '''new_pose_mean = [[[np.round(val, 2) for val in sublst] for sublst in sublst_big] for sublst_big in statistics['pose_mean']]
35
+ statistics['pose_mean'] = new_pose_mean
36
+ j_out = json.dumps(statistics, indent=4) #, sort_keys=True)
37
+ with open(self.statistics_path, 'w') as file: file.write(j_out)'''
38
+ new_statistics = {'trans_mean': np.asarray(statistics['trans_mean']),
39
+ 'trans_std': np.asarray(statistics['trans_std']),
40
+ 'flength_mean': np.asarray(statistics['flength_mean']),
41
+ 'flength_std': np.asarray(statistics['flength_std']),
42
+ 'pose_mean': np.asarray(statistics['pose_mean']),
43
+ }
44
+ new_statistics['pose_rot6d_mean'] = new_statistics['pose_mean'][:, :, :2].reshape((-1, 6))
45
+ return new_statistics
46
+ STATISTICS = load_statistics(STATISTICS_PATH)
47
+
48
+
49
+ ############################################################################
50
+ # for StanExt (original number of keypoints, 20 not 24)
51
+
52
+ # for keypoint names see: https://github.com/benjiebob/StanfordExtra/blob/master/keypoint_definitions.csv
53
+ StanExt_JOINT_NAMES = [
54
+ 'Left_front_leg_paw', 'Left_front_leg_middle_joint', 'Left_front_leg_top',
55
+ 'Left_rear_leg_paw', 'Left_rear_leg_middle_joint', 'Left_rear_leg_top',
56
+ 'Right_front_leg_paw', 'Right_front_leg_middle_joint', 'Right_front_leg_top',
57
+ 'Right_rear_leg_paw', 'Right_rear_leg_middle_joint', 'Right_rear_leg_top',
58
+ 'Tail_start', 'Tail_end', 'Base_of_left_ear', 'Base_of_right_ear',
59
+ 'Nose', 'Chin', 'Left_ear_tip', 'Right_ear_tip']
60
+
61
+ KEYPOINT_WEIGHTS = [3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 3, 2, 2, 3, 1, 2, 2]
62
+
63
+ COMPLETE_DATA_INFO = DataInfo(
64
+ rgb_mean=[0.4404, 0.4440, 0.4327], # not sure
65
+ rgb_stddev=[0.2458, 0.2410, 0.2468], # not sure
66
+ joint_names=StanExt_JOINT_NAMES,
67
+ hflip_indices=[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 12, 13, 15, 14, 16, 17, 19, 18],
68
+ n_joints = 35,
69
+ n_keyp = 20, # 25,
70
+ n_bones = 24,
71
+ n_betas = 30, # 10,
72
+ image_size = 256,
73
+ trans_mean = STATISTICS['trans_mean'],
74
+ trans_std = STATISTICS['trans_std'],
75
+ flength_mean = STATISTICS['flength_mean'],
76
+ flength_std = STATISTICS['flength_std'],
77
+ pose_rot6d_mean = STATISTICS['pose_rot6d_mean'],
78
+ keypoint_weights = KEYPOINT_WEIGHTS
79
+ )
80
+
81
+
82
+ ############################################################################
83
+ # new for StanExt24
84
+
85
+ # ..., 'Left_eye', 'Right_eye', 'Withers', 'Throat'] # the last 4 keypoints are in the animal_pose dataset, but not StanfordExtra
86
+ StanExt_JOINT_NAMES_24 = [
87
+ 'Left_front_leg_paw', 'Left_front_leg_middle_joint', 'Left_front_leg_top',
88
+ 'Left_rear_leg_paw', 'Left_rear_leg_middle_joint', 'Left_rear_leg_top',
89
+ 'Right_front_leg_paw', 'Right_front_leg_middle_joint', 'Right_front_leg_top',
90
+ 'Right_rear_leg_paw', 'Right_rear_leg_middle_joint', 'Right_rear_leg_top',
91
+ 'Tail_start', 'Tail_end', 'Base_of_left_ear', 'Base_of_right_ear',
92
+ 'Nose', 'Chin', 'Left_ear_tip', 'Right_ear_tip',
93
+ 'Left_eye', 'Right_eye', 'Withers', 'Throat']
94
+
95
+ KEYPOINT_WEIGHTS_24 = [3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 3, 2, 2, 3, 1, 2, 2, 1, 1, 0, 0]
96
+
97
+ COMPLETE_DATA_INFO_24 = DataInfo(
98
+ rgb_mean=[0.4404, 0.4440, 0.4327], # not sure
99
+ rgb_stddev=[0.2458, 0.2410, 0.2468], # not sure
100
+ joint_names=StanExt_JOINT_NAMES_24,
101
+ hflip_indices=[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 12, 13, 15, 14, 16, 17, 19, 18, 21, 20, 22, 23],
102
+ n_joints = 35,
103
+ n_keyp = 24, # 20, # 25,
104
+ n_bones = 24,
105
+ n_betas = 30, # 10,
106
+ image_size = 256,
107
+ trans_mean = STATISTICS['trans_mean'],
108
+ trans_std = STATISTICS['trans_std'],
109
+ flength_mean = STATISTICS['flength_mean'],
110
+ flength_std = STATISTICS['flength_std'],
111
+ pose_rot6d_mean = STATISTICS['pose_rot6d_mean'],
112
+ keypoint_weights = KEYPOINT_WEIGHTS_24
113
+ )
114
+
115
+
src/configs/dataset_path_configs.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import numpy as np
4
+ import os
5
+ import sys
6
+
7
+ abs_barc_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..',))
8
+
9
+ # stanext dataset
10
+ # (1) path to stanext dataset
11
+ STAN_V12_ROOT_DIR = '/ps/scratch/nrueegg/new_projects/Animals/data/dog_datasets/Stanford_Dogs_Dataset' + '/StanfordExtra_V12/'
12
+ IMG_V12_DIR = os.path.join(STAN_V12_ROOT_DIR, 'StanExtV12_Images')
13
+ JSON_V12_DIR = os.path.join(STAN_V12_ROOT_DIR, 'labels', "StanfordExtra_v12.json")
14
+ STAN_V12_TRAIN_LIST_DIR = os.path.join(STAN_V12_ROOT_DIR, 'labels', 'train_stanford_StanfordExtra_v12.npy')
15
+ STAN_V12_VAL_LIST_DIR = os.path.join(STAN_V12_ROOT_DIR, 'labels', 'val_stanford_StanfordExtra_v12.npy')
16
+ STAN_V12_TEST_LIST_DIR = os.path.join(STAN_V12_ROOT_DIR, 'labels', 'test_stanford_StanfordExtra_v12.npy')
17
+ # (2) path to related data such as breed indices and prepared predictions for withers, throat and eye keypoints
18
+ STANEXT_RELATED_DATA_ROOT_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'stanext_related_data')
19
+
20
+ # test image crop dataset
21
+ TEST_IMAGE_CROP_ROOT_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'datasets', 'test_image_crops')
src/configs/dog_breeds/dog_breed_class.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import warnings
4
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
5
+ import pandas as pd
6
+ import difflib
7
+ import json
8
+ import pickle as pkl
9
+ import csv
10
+ import numpy as np
11
+
12
+
13
+ # ----------------------------------------------------------------------------------------------------------------- #
14
+ class DogBreed(object):
15
+ def __init__(self, abbrev, name_akc=None, name_stanext=None, name_xlsx=None, path_akc=None, path_stanext=None, ind_in_xlsx=None, ind_in_xlsx_matrix=None, ind_in_stanext=None, clade=None):
16
+ self._abbrev = abbrev
17
+ self._name_xlsx = name_xlsx
18
+ self._name_akc = name_akc
19
+ self._name_stanext = name_stanext
20
+ self._path_stanext = path_stanext
21
+ self._additional_names = set()
22
+ if self._name_akc is not None:
23
+ self.add_akc_info(name_akc, path_akc)
24
+ if self._name_stanext is not None:
25
+ self.add_stanext_info(name_stanext, path_stanext, ind_in_stanext)
26
+ if self._name_xlsx is not None:
27
+ self.add_xlsx_info(name_xlsx, ind_in_xlsx, ind_in_xlsx_matrix, clade)
28
+ def add_xlsx_info(self, name_xlsx, ind_in_xlsx, ind_in_xlsx_matrix, clade):
29
+ assert (name_xlsx is not None) and (ind_in_xlsx is not None) and (ind_in_xlsx_matrix is not None) and (clade is not None)
30
+ self._name_xlsx = name_xlsx
31
+ self._ind_in_xlsx = ind_in_xlsx
32
+ self._ind_in_xlsx_matrix = ind_in_xlsx_matrix
33
+ self._clade = clade
34
+ def add_stanext_info(self, name_stanext, path_stanext, ind_in_stanext):
35
+ assert (name_stanext is not None) and (path_stanext is not None) and (ind_in_stanext is not None)
36
+ self._name_stanext = name_stanext
37
+ self._path_stanext = path_stanext
38
+ self._ind_in_stanext = ind_in_stanext
39
+ def add_akc_info(self, name_akc, path_akc):
40
+ assert (name_akc is not None) and (path_akc is not None)
41
+ self._name_akc = name_akc
42
+ self._path_akc = path_akc
43
+ def add_additional_names(self, name_list):
44
+ self._additional_names = self._additional_names.union(set(name_list))
45
+ def add_text_info(self, text_height, text_weight, text_life_exp):
46
+ self._text_height = text_height
47
+ self._text_weight = text_weight
48
+ self._text_life_exp = text_life_exp
49
+ def get_datasets(self):
50
+ # all datasets in which this breed is found
51
+ datasets = set()
52
+ if self._name_akc is not None:
53
+ datasets.add('akc')
54
+ if self._name_stanext is not None:
55
+ datasets.add('stanext')
56
+ if self._name_xlsx is not None:
57
+ datasets.add('xlsx')
58
+ return datasets
59
+ def get_names(self):
60
+ # set of names for this breed
61
+ names = {self._abbrev, self._name_akc, self._name_stanext, self._name_xlsx, self._path_stanext}.union(self._additional_names)
62
+ names.discard(None)
63
+ return names
64
+ def get_names_as_pointing_dict(self):
65
+ # each name points to the abbreviation
66
+ names = self.get_names()
67
+ my_dict = {}
68
+ for name in names:
69
+ my_dict[name] = self._abbrev
70
+ return my_dict
71
+ def print_overview(self):
72
+ # print important information to get an overview of the class instance
73
+ if self._name_akc is not None:
74
+ name = self._name_akc
75
+ elif self._name_xlsx is not None:
76
+ name = self._name_xlsx
77
+ else:
78
+ name = self._name_stanext
79
+ print('----------------------------------------------------')
80
+ print('----- dog breed: ' + name )
81
+ print('----------------------------------------------------')
82
+ print('[names]')
83
+ print(self.get_names())
84
+ print('[datasets]')
85
+ print(self.get_datasets())
86
+ # see https://stackoverflow.com/questions/9058305/getting-attributes-of-a-class
87
+ print('[instance attributes]')
88
+ for attribute, value in self.__dict__.items():
89
+ print(attribute, '=', value)
90
+ def use_dict_to_save_class_instance(self):
91
+ my_dict = {}
92
+ for attribute, value in self.__dict__.items():
93
+ my_dict[attribute] = value
94
+ return my_dict
95
+ def use_dict_to_load_class_instance(self, my_dict):
96
+ for attribute, value in my_dict.items():
97
+ setattr(self, attribute, value)
98
+ return
99
+
100
+ # ----------------------------------------------------------------------------------------------------------------- #
101
+ def get_name_list_from_summary(summary):
102
+ name_from_abbrev_dict = {}
103
+ for breed in summary.values():
104
+ abbrev = breed._abbrev
105
+ all_names = breed.get_names()
106
+ name_from_abbrev_dict[abbrev] = list(all_names)
107
+ return name_from_abbrev_dict
108
+ def get_partial_summary(summary, part):
109
+ assert part in ['xlsx', 'akc', 'stanext']
110
+ partial_summary = {}
111
+ for key, value in summary.items():
112
+ if (part == 'xlsx' and value._name_xlsx is not None) \
113
+ or (part == 'akc' and value._name_akc is not None) \
114
+ or (part == 'stanext' and value._name_stanext is not None):
115
+ partial_summary[key] = value
116
+ return partial_summary
117
+ def get_akc_but_not_stanext_partial_summary(summary):
118
+ partial_summary = {}
119
+ for key, value in summary.items():
120
+ if value._name_akc is not None:
121
+ if value._name_stanext is None:
122
+ partial_summary[key] = value
123
+ return partial_summary
124
+
125
+ # ----------------------------------------------------------------------------------------------------------------- #
126
+ def main_load_dog_breed_classes(path_complete_abbrev_dict_v1, path_complete_summary_breeds_v1):
127
+ with open(path_complete_abbrev_dict_v1, 'rb') as file:
128
+ complete_abbrev_dict = pkl.load(file)
129
+ with open(path_complete_summary_breeds_v1, 'rb') as file:
130
+ complete_summary_breeds_attributes_only = pkl.load(file)
131
+
132
+ complete_summary_breeds = {}
133
+ for key, value in complete_summary_breeds_attributes_only.items():
134
+ attributes_only = complete_summary_breeds_attributes_only[key]
135
+ complete_summary_breeds[key] = DogBreed(abbrev=attributes_only['_abbrev'])
136
+ complete_summary_breeds[key].use_dict_to_load_class_instance(attributes_only)
137
+ return complete_abbrev_dict, complete_summary_breeds
138
+
139
+
140
+ # ----------------------------------------------------------------------------------------------------------------- #
141
+ def load_similarity_matrix_raw(xlsx_path):
142
+ # --- LOAD EXCEL FILE FROM DOG BREED PAPER
143
+ xlsx = pd.read_excel(xlsx_path)
144
+ # create an array
145
+ abbrev_indices = {}
146
+ matrix_raw = np.zeros((168, 168))
147
+ for ind in range(1, 169):
148
+ abbrev = xlsx[xlsx.columns[2]][ind]
149
+ abbrev_indices[abbrev] = ind-1
150
+ for ind_col in range(0, 168):
151
+ for ind_row in range(0, 168):
152
+ matrix_raw[ind_col, ind_row] = float(xlsx[xlsx.columns[3+ind_col]][1+ind_row])
153
+ return matrix_raw, abbrev_indices
154
+
155
+
156
+
157
+ # ----------------------------------------------------------------------------------------------------------------- #
158
+ # ----------------------------------------------------------------------------------------------------------------- #
159
+ # load the (in advance created) final dict of dog breed classes
160
+ ROOT_PATH_BREED_DATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', '..', 'data', 'breed_data')
161
+ path_complete_abbrev_dict_v1 = os.path.join(ROOT_PATH_BREED_DATA, 'complete_abbrev_dict_v2.pkl')
162
+ path_complete_summary_breeds_v1 = os.path.join(ROOT_PATH_BREED_DATA, 'complete_summary_breeds_v2.pkl')
163
+ COMPLETE_ABBREV_DICT, COMPLETE_SUMMARY_BREEDS = main_load_dog_breed_classes(path_complete_abbrev_dict_v1, path_complete_summary_breeds_v1)
164
+ # load similarity matrix, data from:
165
+ # Parker H. G., Dreger D. L., Rimbault M., Davis B. W., Mullen A. B., Carpintero-Ramirez G., and Ostrander E. A.
166
+ # Genomic analyses reveal the influence of geographic origin, migration, and hybridization on modern dog breed
167
+ # development. Cell Reports, 4(19):697–708, 2017.
168
+ xlsx_path = os.path.join(ROOT_PATH_BREED_DATA, 'NIHMS866262-supplement-2.xlsx')
169
+ SIM_MATRIX_RAW, SIM_ABBREV_INDICES = load_similarity_matrix_raw(xlsx_path)
170
+
src/configs/refinement_cfg_test_withvertexwisegc_csaddnonflat.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ paths:
3
+ ROOT_OUT_PATH: './results/'
4
+ ROOT_CHECKPOINT_PATH: './checkpoint/'
5
+ MODELPATH_NORMFLOW: './checkpoint/barc_normflow_pret/rgbddog_v3_model.pt'
6
+
7
+ smal:
8
+ SMAL_MODEL_TYPE: '39dogs_norm_newv3' # '39dogs_norm' # '39dogs_diffsize' # 'barc'
9
+ SMAL_KEYP_CONF: 'olive' # 'green'
10
+
11
+ optim:
12
+ BATCH_SIZE: 12
13
+
14
+ params:
15
+ REF_NET_TYPE: 'multrot01all_res34' # 'multrot01all_res34' # 'multrot01all' # 'multrot01' # 'multrot' # 'add'
16
+ REF_DETACH_SHAPE: True
17
+ GRAPHCNN_TYPE: 'multistage_simple' # 'inexistent'
18
+ SHAPEREF_TYPE: 'inexistent' # 'linear' # 'inexistent'
19
+ ISFLAT_TYPE: 'linear' # 'inexistent' # 'inexistent'
20
+
21
+ data:
22
+ DATASET: 'stanext24'
23
+ VAL_OPT: 'test' # 'val'
src/configs/refinement_cfg_test_withvertexwisegc_csaddnonflat_crops.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ paths:
3
+ ROOT_OUT_PATH: './results/'
4
+ ROOT_CHECKPOINT_PATH: './checkpoint/'
5
+ MODELPATH_NORMFLOW: './checkpoint/barc_normflow_pret/rgbddog_v3_model.pt'
6
+
7
+ smal:
8
+ SMAL_MODEL_TYPE: '39dogs_norm_newv3' # '39dogs_norm' # '39dogs_diffsize' # 'barc'
9
+ SMAL_KEYP_CONF: 'olive' # 'green'
10
+
11
+ optim:
12
+ BATCH_SIZE: 12
13
+
14
+ params:
15
+ REF_NET_TYPE: 'multrot01all_res34' # 'multrot01all_res34' # 'multrot01all' # 'multrot01' # 'multrot' # 'add'
16
+ REF_DETACH_SHAPE: True
17
+ GRAPHCNN_TYPE: 'multistage_simple' # 'inexistent'
18
+ SHAPEREF_TYPE: 'inexistent' # 'linear' # 'inexistent'
19
+ ISFLAT_TYPE: 'linear' # 'inexistent' # 'inexistent'
20
+
21
+ data:
22
+ DATASET: 'ImgCropList'
23
+ VAL_OPT: 'test' # 'val'
src/configs/refinement_cfg_train_withvertexwisegc_isflat_csmorestanding.yaml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ paths:
3
+ ROOT_OUT_PATH: './results/'
4
+ ROOT_CHECKPOINT_PATH: './checkpoint/'
5
+ MODELPATH_NORMFLOW: './checkpoint/barc_normflow_pret/rgbddog_v3_model.pt'
6
+
7
+ smal:
8
+ SMAL_MODEL_TYPE: '39dogs_norm_newv3' # '39dogs_norm' # '39dogs_diffsize' # 'barc'
9
+ SMAL_KEYP_CONF: 'olive' # 'green'
10
+
11
+ optim:
12
+ LR: 5e-5 # 5e-7 # (new) 5e-6 # 5e-5 # 5e-5 # 5e-4
13
+ SCHEDULE: [150, 175, 200] # [220, 270] # [150, 175, 200]
14
+ GAMMA: 0.1
15
+ MOMENTUM: 0
16
+ WEIGHT_DECAY: 0
17
+ EPOCHS: 220 # 300
18
+ BATCH_SIZE: 14 # 12 # keep 12 (needs to be an even number, as we have a custom data sampler)
19
+ TRAIN_PARTS: 'refinement_model' # 'refinement_model_and_shape' # 'refinement_model'
20
+
21
+ params:
22
+ REF_NET_TYPE: 'multrot01all_res34' # 'multrot01all_res34' # 'multrot01all' # 'multrot01' # 'multrot01' # 'multrot01' # 'multrot' # 'multrot_res34' # 'multrot' # 'add'
23
+ REF_DETACH_SHAPE: True
24
+ GRAPHCNN_TYPE: 'multistage_simple' # 'inexistent'
25
+ SHAPEREF_TYPE: 'inexistent' # 'linear' # 'inexistent'
26
+ ISFLAT_TYPE: 'linear' # 'inexistent' # 'inexistent'
27
+
28
+ data:
29
+ DATASET: 'stanext24_withgc_csaddnonflatmorestanding' # 'stanext24_withgc_csaddnonflat' # 'stanext24_withgc_cs0'
30
+ SHORTEN_VAL_DATASET_TO: 600 # this is faster as we do not evaluate on the whole validation set
31
+ VAL_OPT: 'val'
src/configs/refinement_loss_weights_withgc_withvertexwise_addnonflat.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
+ {
5
+ "keyp_ref": 0.2,
6
+ "silh_ref": 50.0,
7
+ "pose_legs_side": 1.0,
8
+ "pose_legs_tors": 1.0,
9
+ "pose_tail_side": 0.0,
10
+ "pose_tail_tors": 0.0,
11
+ "pose_spine_side": 0.0,
12
+ "pose_spine_tors": 0.0,
13
+ "reg_trans": 0.0,
14
+ "reg_flength": 0.0,
15
+ "reg_pose": 0.0,
16
+ "gc_plane": 5.0,
17
+ "gc_blowplane": 5.0,
18
+ "gc_vertexwise": 10.0,
19
+ "gc_isflat": 0.5
20
+ }
src/configs/ttopt_loss_weights/bite_loss_weights_ttopt.json ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "silhouette": {
3
+ "weight": 40.0,
4
+ "weight_vshift": 20.0,
5
+ "value": 0.0
6
+ },
7
+ "keyp":{
8
+ "weight": 0.2,
9
+ "weight_vshift": 0.01,
10
+ "value": 0.0
11
+ },
12
+ "pose_legs_side":{
13
+ "weight": 1.0,
14
+ "weight_vshift": 1.0,
15
+ "value": 0.0
16
+ },
17
+ "pose_legs_tors":{
18
+ "weight": 10.0,
19
+ "weight_vshift": 10.0,
20
+ "value": 0.0
21
+ },
22
+ "pose_tail_side":{
23
+ "weight": 1,
24
+ "weight_vshift": 1,
25
+ "value": 0.0
26
+ },
27
+ "pose_tail_tors":{
28
+ "weight": 10.0,
29
+ "weight_vshift": 10.0,
30
+ "value": 0.0
31
+ },
32
+ "pose_spine_side":{
33
+ "weight": 0.0,
34
+ "weight_vshift": 0.0,
35
+ "value": 0.0
36
+ },
37
+ "pose_spine_tors":{
38
+ "weight": 0.0,
39
+ "weight_vshift": 0.0,
40
+ "value": 0.0
41
+ },
42
+ "gc_plane":{
43
+ "weight": 10.0,
44
+ "weight_vshift": 20.0,
45
+ "value": 0.0
46
+ },
47
+ "gc_belowplane":{
48
+ "weight": 10.0,
49
+ "weight_vshift": 20.0,
50
+ "value": 0.0
51
+ },
52
+ "lapctf": {
53
+ "weight": 0.0,
54
+ "weight_vshift": 10.0,
55
+ "value": 0.0
56
+ },
57
+ "arap": {
58
+ "weight": 0.0,
59
+ "weight_vshift": 0.0,
60
+ "value": 0.0
61
+ },
62
+ "edge": {
63
+ "weight": 0.0,
64
+ "weight_vshift": 10.0,
65
+ "value": 0.0
66
+ },
67
+ "normal": {
68
+ "weight": 0.0,
69
+ "weight_vshift": 1.0,
70
+ "value": 0.0
71
+ },
72
+ "laplacian": {
73
+ "weight": 0.0,
74
+ "weight_vshift": 0.0,
75
+ "value": 0.0
76
+ }
77
+ }
src/configs/ttopt_loss_weights/ttopt_loss_weights_v2c_withlapcft_v2.json ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "silhouette": {
3
+ "weight": 40.0,
4
+ "weight_vshift": 20.0,
5
+ "value": 0.0
6
+ },
7
+ "keyp":{
8
+ "weight": 0.2,
9
+ "weight_vshift": 0.01,
10
+ "value": 0.0
11
+ },
12
+ "pose_legs_side":{
13
+ "weight": 1.0,
14
+ "weight_vshift": 1.0,
15
+ "value": 0.0
16
+ },
17
+ "pose_legs_tors":{
18
+ "weight": 10.0,
19
+ "weight_vshift": 10.0,
20
+ "value": 0.0
21
+ },
22
+ "pose_tail_side":{
23
+ "weight": 1,
24
+ "weight_vshift": 1,
25
+ "value": 0.0
26
+ },
27
+ "pose_tail_tors":{
28
+ "weight": 10.0,
29
+ "weight_vshift": 10.0,
30
+ "value": 0.0
31
+ },
32
+ "pose_spine_side":{
33
+ "weight": 0.0,
34
+ "weight_vshift": 0.0,
35
+ "value": 0.0
36
+ },
37
+ "pose_spine_tors":{
38
+ "weight": 0.0,
39
+ "weight_vshift": 0.0,
40
+ "value": 0.0
41
+ },
42
+ "gc_plane":{
43
+ "weight": 10.0,
44
+ "weight_vshift": 20.0,
45
+ "value": 0.0
46
+ },
47
+ "gc_belowplane":{
48
+ "weight": 10.0,
49
+ "weight_vshift": 20.0,
50
+ "value": 0.0
51
+ },
52
+ "lapctf": {
53
+ "weight": 0.0,
54
+ "weight_vshift": 10.0,
55
+ "value": 0.0
56
+ },
57
+ "arap": {
58
+ "weight": 0.0,
59
+ "weight_vshift": 0.0,
60
+ "value": 0.0
61
+ },
62
+ "edge": {
63
+ "weight": 0.0,
64
+ "weight_vshift": 10.0,
65
+ "value": 0.0
66
+ },
67
+ "normal": {
68
+ "weight": 0.0,
69
+ "weight_vshift": 1.0,
70
+ "value": 0.0
71
+ },
72
+ "laplacian": {
73
+ "weight": 0.0,
74
+ "weight_vshift": 0.0,
75
+ "value": 0.0
76
+ }
77
+ }
src/graph_networks/__init__.py ADDED
File without changes
src/graph_networks/graphcmr/__init__.py ADDED
File without changes
src/graph_networks/graphcmr/get_downsampled_mesh_npz.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # try to use aenv_conda3 (maybe also export PYOPENGL_PLATFORM=osmesa)
3
+ # python src/graph_networks/graphcmr/get_downsampled_mesh_npz.py
4
+
5
+ # see https://github.com/nkolot/GraphCMR/issues/35
6
+
7
+
8
+ from __future__ import print_function
9
+ # import mesh_sampling
10
+ from psbody.mesh import Mesh, MeshViewer, MeshViewers
11
+ import numpy as np
12
+ import json
13
+ import os
14
+ import copy
15
+ import argparse
16
+ import pickle
17
+ import time
18
+ import sys
19
+ import trimesh
20
+
21
+
22
+
23
+ sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../"))
24
+ from barc_for_bite.src.graph_networks.graphcmr.pytorch_coma_mesh_operations import generate_transform_matrices
25
+ from barc_for_bite.src.configs.SMAL_configs import SMAL_MODEL_CONFIG
26
+ from barc_for_bite.src.smal_pytorch.smal_model.smal_torch_new import SMAL
27
+ # smal_model_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data/new_dog_models/my_smpl_00791_nadine_Jr_4_dog.pkl'
28
+
29
+
30
+ SMAL_MODEL_TYPE = '39dogs_diffsize' # '39dogs_diffsize' # '39dogs_norm' # 'barc'
31
+ smal_model_path = SMAL_MODEL_CONFIG[SMAL_MODEL_TYPE]['smal_model_path']
32
+
33
+ # data_path_root = "/is/cluster/work/nrueegg/icon_pifu_related/ICON/lib/graph_networks/graphcmr/data/"
34
+ data_path_root = "/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/graphcmr/data/"
35
+
36
+ smal_dog_model_name = os.path.basename(smal_model_path).split('.pkl')[0] # 'my_smpl_SMBLD_nbj_v3'
37
+ suffix = "_template"
38
+ template_obj_path = data_path_root + smal_dog_model_name + suffix + ".obj"
39
+
40
+ print("Loading smal .. ")
41
+ print(SMAL_MODEL_TYPE)
42
+ print(smal_model_path)
43
+
44
+ smal = SMAL(smal_model_type=SMAL_MODEL_TYPE, template_name='neutral')
45
+ smal_verts = smal.v_template.detach().cpu().numpy() # (3889, 3)
46
+ smal_faces = smal.f # (7774, 3)
47
+ smal_trimesh = trimesh.base.Trimesh(vertices=smal_verts, faces=smal_faces, process=False, maintain_order=True)
48
+ smal_trimesh.export(file_obj=template_obj_path) # file_type='obj')
49
+
50
+
51
+ print("Loading data .. ")
52
+ reference_mesh_file = template_obj_path # 'data/barc_neutral_vertices.obj' # 'data/smpl_neutral_vertices.obj'
53
+ reference_mesh = Mesh(filename=reference_mesh_file)
54
+
55
+ # ds_factors = [4, 4] # ds_factors = [4,1] # Sampling factor of the mesh at each stage of sampling
56
+ ds_factors = [4, 4, 4, 4]
57
+ print("Generating Transform Matrices ..")
58
+
59
+
60
+ # Generates adjecency matrices A, downsampling matrices D, and upsamling matrices U by sampling
61
+ # the mesh 4 times. Each time the mesh is sampled by a factor of 4
62
+
63
+ # M,A,D,U = mesh_sampling.generate_transform_matrices(reference_mesh, ds_factors)
64
+ M,A,D,U = generate_transform_matrices(reference_mesh, ds_factors)
65
+
66
+ # REMARK: there is a warning:
67
+ # lib/graph_networks/graphcmr/../../../lib/graph_networks/graphcmr/pytorch_coma_mesh_operations.py:237: FutureWarning: `rcond` parameter will
68
+ # change to the default of machine precision times ``max(M, N)`` where M and N are the input matrix dimensions.
69
+ # To use the future default and silence this warning we advise to pass `rcond=None`, to keep using the old, explicitly pass `rcond=-1`.
70
+
71
+
72
+ print(type(A))
73
+ np.savez(data_path_root + 'mesh_downsampling_' + smal_dog_model_name + suffix + '.npz', A = A, D = D, U = U)
74
+ np.savez(data_path_root + 'meshes/' + 'mesh_downsampling_meshes' + smal_dog_model_name + suffix + '.npz', M = M)
75
+
76
+ for ind_m, my_mesh in enumerate(M):
77
+ new_suffix = '_template_downsampled' + str(ind_m)
78
+ my_mesh_tri = trimesh.Trimesh(vertices=my_mesh.v, faces=my_mesh.f, process=False, maintain_order=True)
79
+ my_mesh_tri.export(data_path_root + 'meshes/' + 'mesh_downsampling_meshes' + smal_dog_model_name + new_suffix + '.obj')
80
+
81
+
82
+
83
+
84
+
src/graph_networks/graphcmr/graph_cnn.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ code from https://raw.githubusercontent.com/nkolot/GraphCMR/master/models/graph_cnn.py
3
+ This file contains the Definition of GraphCNN
4
+ GraphCNN includes ResNet50 as a submodule
5
+ """
6
+ from __future__ import division
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from .graph_layers import GraphResBlock, GraphLinear
12
+ from .resnet import resnet50
13
+
14
+ class GraphCNN(nn.Module):
15
+
16
+ def __init__(self, A, ref_vertices, num_layers=5, num_channels=512):
17
+ super(GraphCNN, self).__init__()
18
+ self.A = A
19
+ self.ref_vertices = ref_vertices
20
+ self.resnet = resnet50(pretrained=True)
21
+ layers = [GraphLinear(3 + 2048, 2 * num_channels)]
22
+ layers.append(GraphResBlock(2 * num_channels, num_channels, A))
23
+ for i in range(num_layers):
24
+ layers.append(GraphResBlock(num_channels, num_channels, A))
25
+ self.shape = nn.Sequential(GraphResBlock(num_channels, 64, A),
26
+ GraphResBlock(64, 32, A),
27
+ nn.GroupNorm(32 // 8, 32),
28
+ nn.ReLU(inplace=True),
29
+ GraphLinear(32, 3))
30
+ self.gc = nn.Sequential(*layers)
31
+ self.camera_fc = nn.Sequential(nn.GroupNorm(num_channels // 8, num_channels),
32
+ nn.ReLU(inplace=True),
33
+ GraphLinear(num_channels, 1),
34
+ nn.ReLU(inplace=True),
35
+ nn.Linear(A.shape[0], 3))
36
+
37
+ def forward(self, image):
38
+ """Forward pass
39
+ Inputs:
40
+ image: size = (B, 3, 224, 224)
41
+ Returns:
42
+ Regressed (subsampled) non-parametric shape: size = (B, 1723, 3)
43
+ Weak-perspective camera: size = (B, 3)
44
+ """
45
+ batch_size = image.shape[0]
46
+ ref_vertices = self.ref_vertices[None, :, :].expand(batch_size, -1, -1)
47
+ image_resnet = self.resnet(image)
48
+ image_enc = image_resnet.view(batch_size, 2048, 1).expand(-1, -1, ref_vertices.shape[-1])
49
+ x = torch.cat([ref_vertices, image_enc], dim=1)
50
+ x = self.gc(x)
51
+ shape = self.shape(x)
52
+ camera = self.camera_fc(x).view(batch_size, 3)
53
+ return shape, camera
src/graph_networks/graphcmr/graph_cnn_groundcontact.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ code from https://raw.githubusercontent.com/nkolot/GraphCMR/master/models/graph_cnn.py
3
+ This file contains the Definition of GraphCNN
4
+ GraphCNN includes ResNet50 as a submodule
5
+ """
6
+ from __future__ import division
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ # from .resnet import resnet50
12
+ import torchvision.models as models
13
+
14
+
15
+ import os
16
+ import sys
17
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..'))
18
+ from src.graph_networks.graphcmr.utils_mesh import Mesh
19
+ from src.graph_networks.graphcmr.graph_layers import GraphResBlock, GraphLinear
20
+
21
+
22
+ class GraphCNN(nn.Module):
23
+
24
+ def __init__(self, A, ref_vertices, n_resnet_in, n_resnet_out, num_layers=5, num_channels=512):
25
+ super(GraphCNN, self).__init__()
26
+ self.A = A
27
+ self.ref_vertices = ref_vertices
28
+ # self.resnet = resnet50(pretrained=True)
29
+ # -> within the GraphCMR network they ignore the last fully connected layer
30
+ # replace the first layer
31
+ self.resnet = models.resnet34(pretrained=False)
32
+ n_in = 3 + 1
33
+ self.resnet.conv1 = nn.Conv2d(n_resnet_in, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
34
+ # replace the last layer
35
+ self.resnet.fc = nn.Linear(512, n_resnet_out)
36
+
37
+
38
+ layers = [GraphLinear(3 + n_resnet_out, 2 * num_channels)] # [GraphLinear(3 + 2048, 2 * num_channels)]
39
+ layers.append(GraphResBlock(2 * num_channels, num_channels, A))
40
+ for i in range(num_layers):
41
+ layers.append(GraphResBlock(num_channels, num_channels, A))
42
+ self.n_out_gc = 2 # two labels per vertex
43
+ self.gc = nn.Sequential(GraphResBlock(num_channels, 64, A),
44
+ GraphResBlock(64, 32, A),
45
+ nn.GroupNorm(32 // 8, 32),
46
+ nn.ReLU(inplace=True),
47
+ GraphLinear(32, self.n_out_gc))
48
+ self.gcnn = nn.Sequential(*layers)
49
+ self.n_out_flatground = 1
50
+ self.flat_ground = nn.Sequential(nn.GroupNorm(num_channels // 8, num_channels),
51
+ nn.ReLU(inplace=True),
52
+ GraphLinear(num_channels, 1),
53
+ nn.ReLU(inplace=True),
54
+ nn.Linear(A.shape[0], self.n_out_flatground))
55
+
56
+ def forward(self, image):
57
+ """Forward pass
58
+ Inputs:
59
+ image: size = (B, 3, 256, 256)
60
+ Returns:
61
+ Regressed (subsampled) non-parametric shape: size = (B, 1723, 3)
62
+ Weak-perspective camera: size = (B, 3)
63
+ """
64
+ # import pdb; pdb.set_trace()
65
+
66
+ batch_size = image.shape[0]
67
+ ref_vertices = self.ref_vertices[None, :, :].expand(batch_size, -1, -1) # (bs, 3, 973)
68
+ image_resnet = self.resnet(image) # (bs, 512)
69
+ image_enc = image_resnet.view(batch_size, -1, 1).expand(-1, -1, ref_vertices.shape[-1]) # (bs, 512, 973)
70
+ x = torch.cat([ref_vertices, image_enc], dim=1)
71
+ x = self.gcnn(x) # (bs, 512, 973)
72
+ ground_contact = self.gc(x) # (bs, 2, 973)
73
+ ground_flatness = self.flat_ground(x).view(batch_size, self.n_out_flatground) # (bs, 1)
74
+ return ground_contact, ground_flatness
75
+
76
+
77
+
78
+
79
+ # how to use it:
80
+ #
81
+ # from src.graph_networks.graphcmr.utils_mesh import Mesh
82
+ #
83
+ # create Mesh object
84
+ # self.mesh = Mesh()
85
+ # self.faces = self.mesh.faces.to(self.device)
86
+ #
87
+ # create GraphCNN
88
+ # self.graph_cnn = GraphCNN(self.mesh.adjmat,
89
+ # self.mesh.ref_vertices.t(),
90
+ # num_channels=self.options.num_channels,
91
+ # num_layers=self.options.num_layers
92
+ # ).to(self.device)
93
+ # ------------
94
+ #
95
+ # Feed image in the GraphCNN
96
+ # Returns subsampled mesh and camera parameters
97
+ # pred_vertices_sub, pred_camera = self.graph_cnn(images)
98
+ #
99
+ # Upsample mesh in the original size
100
+ # pred_vertices = self.mesh.upsample(pred_vertices_sub.transpose(1,2))
101
+ #
src/graph_networks/graphcmr/graph_cnn_groundcontact_multistage.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ code from
3
+ https://raw.githubusercontent.com/nkolot/GraphCMR/master/models/graph_cnn.py
4
+ https://github.com/chaneyddtt/Coarse-to-fine-3D-Animal/blob/main/model/graph_hg.py
5
+ This file contains the Definition of GraphCNN
6
+ GraphCNN includes ResNet50 as a submodule
7
+ """
8
+ from __future__ import division
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ # from .resnet import resnet50
14
+ import torchvision.models as models
15
+
16
+
17
+ import os
18
+ import sys
19
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..'))
20
+ from src.graph_networks.graphcmr.utils_mesh import Mesh
21
+ from src.graph_networks.graphcmr.graph_layers import GraphResBlock, GraphLinear
22
+
23
+
24
+ class GraphCNNMS(nn.Module):
25
+
26
+ def __init__(self, mesh, num_downsample=0, num_layers=5, n_resnet_out=256, num_channels=256):
27
+ '''
28
+ Args:
29
+ mesh: mesh data that store the adjacency matrix
30
+ num_channels: number of channels of GCN
31
+ num_downsample: number of downsampling of the input mesh
32
+ '''
33
+
34
+ super(GraphCNNMS, self).__init__()
35
+
36
+ self.A = mesh._A[num_downsample:] # get the correct adjacency matrix because the input might be downsampled
37
+ # self.num_layers = len(self.A) - 1
38
+ self.num_layers = num_layers
39
+ assert self.num_layers <= len(self.A) - 1
40
+ print("Number of downsampling layer: {}".format(self.num_layers))
41
+ self.num_downsample = num_downsample
42
+ self.n_resnet_out = n_resnet_out
43
+
44
+
45
+ '''
46
+ self.use_pret_res = use_pret_res
47
+ # self.resnet = resnet50(pretrained=True)
48
+ # -> within the GraphCMR network they ignore the last fully connected layer
49
+ # replace the first layer
50
+ self.resnet = models.resnet34(pretrained=self.use_pret_res)
51
+ if (self.use_pret_res) and (n_resnet_in == 3):
52
+ print('use full pretrained resnet including first layer!')
53
+ else:
54
+ self.resnet.conv1 = nn.Conv2d(n_resnet_in, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
55
+ # replace the last layer
56
+ self.resnet.fc = nn.Linear(512, n_resnet_out)
57
+ '''
58
+
59
+ self.lin1 = GraphLinear(3 + n_resnet_out, 2 * num_channels)
60
+ self.res1 = GraphResBlock(2 * num_channels, num_channels, self.A[0])
61
+ encode_layers = []
62
+ decode_layers = []
63
+
64
+ for i in range(self.num_layers + 1): # range(len(self.A)):
65
+ encode_layers.append(GraphResBlock(num_channels, num_channels, self.A[i]))
66
+
67
+ decode_layers.append(GraphResBlock((i+1)*num_channels, (i+1)*num_channels,
68
+ self.A[self.num_layers - i]))
69
+ current_channels = (i+1)*num_channels
70
+ # number of channels for the input is different because of the concatenation operation
71
+ self.n_out_gc = 2 # two labels per vertex
72
+ self.gc = nn.Sequential(GraphResBlock(current_channels, 64, self.A[0]),
73
+ GraphResBlock(64, 32, self.A[0]),
74
+ nn.GroupNorm(32 // 8, 32),
75
+ nn.ReLU(inplace=True),
76
+ GraphLinear(32, self.n_out_gc))
77
+
78
+ '''
79
+ self.n_out_flatground = 2
80
+ self.flat_ground = nn.Sequential(nn.GroupNorm(current_channels // 8, current_channels),
81
+ nn.ReLU(inplace=True),
82
+ GraphLinear(current_channels, 1),
83
+ nn.ReLU(inplace=True),
84
+ nn.Linear(A.shape[0], self.n_out_flatground))
85
+ '''
86
+
87
+ self.encoder = nn.Sequential(*encode_layers)
88
+ self.decoder = nn.Sequential(*decode_layers)
89
+ self.mesh = mesh
90
+
91
+
92
+
93
+
94
+ def forward(self, image_enc):
95
+ """Forward pass
96
+ Inputs:
97
+ image_enc: size = (B, self.n_resnet_out)
98
+ Returns:
99
+ Regressed (subsampled) non-parametric shape: size = (B, 1723, 3)
100
+ Weak-perspective camera: size = (B, 3)
101
+ """
102
+ # import pdb; pdb.set_trace()
103
+
104
+ batch_size = image_enc.shape[0]
105
+ # ref_vertices = (self.mesh.get_ref_vertices(n=self.num_downsample).t())[None, :, :].expand(batch_size, -1, -1) # (bs, 3, 973)
106
+ ref_vertices = (self.mesh.ref_vertices.t())[None, :, :].expand(batch_size, -1, -1) # (bs, 3, 973)
107
+ '''image_resnet = self.resnet(image) # (bs, 512)'''
108
+ image_enc_prep = image_enc.view(batch_size, -1, 1).expand(-1, -1, ref_vertices.shape[-1]) # (bs, 512, 973)
109
+
110
+ # prepare network input
111
+ # -> for each node we feed the location of the vertex in the template mesh and an image encoding
112
+ x = torch.cat([ref_vertices, image_enc_prep], dim=1)
113
+ x = self.lin1(x)
114
+ x = self.res1(x)
115
+ x_ = [x]
116
+ output_list = []
117
+ for i in range(self.num_layers + 1):
118
+ if i == self.num_layers:
119
+ x = self.encoder[i](x)
120
+ else:
121
+ x = self.encoder[i](x)
122
+ x = self.mesh.downsample(x.transpose(1, 2), n1=self.num_downsample+i, n2=self.num_downsample+i+1)
123
+ x = x.transpose(1, 2)
124
+ if i < self.num_layers-1:
125
+ x_.append(x)
126
+ for i in range(self.num_layers + 1):
127
+ if i == self.num_layers:
128
+ x = self.decoder[i](x)
129
+ output_list.append(x)
130
+ else:
131
+ x = self.decoder[i](x)
132
+ output_list.append(x)
133
+ x = self.mesh.upsample(x.transpose(1, 2), n1=self.num_layers-i+self.num_downsample,
134
+ n2=self.num_layers-i-1+self.num_downsample)
135
+ x = x.transpose(1, 2)
136
+ x = torch.cat([x, x_[self.num_layers-i-1]], dim=1) # skip connection between encoder and decoder
137
+
138
+ ground_contact = self.gc(x)
139
+
140
+ '''
141
+ ground_flatness = self.flat_ground(x).view(batch_size, self.n_out_flatground) # (bs, 1)
142
+ '''
143
+
144
+ return ground_contact, output_list # , ground_flatness
145
+
146
+
147
+
148
+
149
+
150
+
151
+
152
+ # how to use it:
153
+ #
154
+ # from src.graph_networks.graphcmr.utils_mesh import Mesh
155
+ #
156
+ # create Mesh object
157
+ # self.mesh = Mesh()
158
+ # self.faces = self.mesh.faces.to(self.device)
159
+ #
160
+ # create GraphCNN
161
+ # self.graph_cnn = GraphCNN(self.mesh.adjmat,
162
+ # self.mesh.ref_vertices.t(),
163
+ # num_channels=self.options.num_channels,
164
+ # num_layers=self.options.num_layers
165
+ # ).to(self.device)
166
+ # ------------
167
+ #
168
+ # Feed image in the GraphCNN
169
+ # Returns subsampled mesh and camera parameters
170
+ # pred_vertices_sub, pred_camera = self.graph_cnn(images)
171
+ #
172
+ # Upsample mesh in the original size
173
+ # pred_vertices = self.mesh.upsample(pred_vertices_sub.transpose(1,2))
174
+ #
src/graph_networks/graphcmr/graph_cnn_groundcontact_multistage_includingresnet.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ code from
3
+ https://raw.githubusercontent.com/nkolot/GraphCMR/master/models/graph_cnn.py
4
+ https://github.com/chaneyddtt/Coarse-to-fine-3D-Animal/blob/main/model/graph_hg.py
5
+ This file contains the Definition of GraphCNN
6
+ GraphCNN includes ResNet50 as a submodule
7
+ """
8
+ from __future__ import division
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ # from .resnet import resnet50
14
+ import torchvision.models as models
15
+
16
+
17
+ import os
18
+ import sys
19
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..'))
20
+ from src.graph_networks.graphcmr.utils_mesh import Mesh
21
+ from src.graph_networks.graphcmr.graph_layers import GraphResBlock, GraphLinear
22
+
23
+
24
+ class GraphCNNMS(nn.Module):
25
+
26
+ def __init__(self, mesh, num_downsample=0, num_layers=5, n_resnet_in=3, n_resnet_out=256, num_channels=256, use_pret_res=False):
27
+ '''
28
+ Args:
29
+ mesh: mesh data that store the adjacency matrix
30
+ num_channels: number of channels of GCN
31
+ num_downsample: number of downsampling of the input mesh
32
+ '''
33
+
34
+ super(GraphCNNMS, self).__init__()
35
+
36
+ self.A = mesh._A[num_downsample:] # get the correct adjacency matrix because the input might be downsampled
37
+ # self.num_layers = len(self.A) - 1
38
+ self.num_layers = num_layers
39
+ assert self.num_layers <= len(self.A) - 1
40
+ print("Number of downsampling layer: {}".format(self.num_layers))
41
+ self.num_downsample = num_downsample
42
+ self.use_pret_res = use_pret_res
43
+
44
+ # self.resnet = resnet50(pretrained=True)
45
+ # -> within the GraphCMR network they ignore the last fully connected layer
46
+ # replace the first layer
47
+ self.resnet = models.resnet34(pretrained=self.use_pret_res)
48
+ if (self.use_pret_res) and (n_resnet_in == 3):
49
+ print('use full pretrained resnet including first layer!')
50
+ else:
51
+ self.resnet.conv1 = nn.Conv2d(n_resnet_in, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
52
+ # replace the last layer
53
+ self.resnet.fc = nn.Linear(512, n_resnet_out)
54
+
55
+ self.lin1 = GraphLinear(3 + n_resnet_out, 2 * num_channels)
56
+ self.res1 = GraphResBlock(2 * num_channels, num_channels, self.A[0])
57
+ encode_layers = []
58
+ decode_layers = []
59
+
60
+ for i in range(self.num_layers + 1): # range(len(self.A)):
61
+ encode_layers.append(GraphResBlock(num_channels, num_channels, self.A[i]))
62
+
63
+ decode_layers.append(GraphResBlock((i+1)*num_channels, (i+1)*num_channels,
64
+ self.A[self.num_layers - i]))
65
+ current_channels = (i+1)*num_channels
66
+ # number of channels for the input is different because of the concatenation operation
67
+ self.n_out_gc = 2 # two labels per vertex
68
+ self.gc = nn.Sequential(GraphResBlock(current_channels, 64, self.A[0]),
69
+ GraphResBlock(64, 32, self.A[0]),
70
+ nn.GroupNorm(32 // 8, 32),
71
+ nn.ReLU(inplace=True),
72
+ GraphLinear(32, self.n_out_gc))
73
+
74
+ '''
75
+ self.n_out_flatground = 2
76
+ self.flat_ground = nn.Sequential(nn.GroupNorm(current_channels // 8, current_channels),
77
+ nn.ReLU(inplace=True),
78
+ GraphLinear(current_channels, 1),
79
+ nn.ReLU(inplace=True),
80
+ nn.Linear(A.shape[0], self.n_out_flatground))
81
+ '''
82
+
83
+ self.encoder = nn.Sequential(*encode_layers)
84
+ self.decoder = nn.Sequential(*decode_layers)
85
+ self.mesh = mesh
86
+
87
+
88
+
89
+
90
+ def forward(self, image):
91
+ """Forward pass
92
+ Inputs:
93
+ image: size = (B, 3, 256, 256)
94
+ Returns:
95
+ Regressed (subsampled) non-parametric shape: size = (B, 1723, 3)
96
+ Weak-perspective camera: size = (B, 3)
97
+ """
98
+ # import pdb; pdb.set_trace()
99
+
100
+ batch_size = image.shape[0]
101
+ # ref_vertices = (self.mesh.get_ref_vertices(n=self.num_downsample).t())[None, :, :].expand(batch_size, -1, -1) # (bs, 3, 973)
102
+ ref_vertices = (self.mesh.ref_vertices.t())[None, :, :].expand(batch_size, -1, -1) # (bs, 3, 973)
103
+ image_resnet = self.resnet(image) # (bs, 512)
104
+ image_enc = image_resnet.view(batch_size, -1, 1).expand(-1, -1, ref_vertices.shape[-1]) # (bs, 512, 973)
105
+
106
+ # prepare network input
107
+ # -> for each node we feed the location of the vertex in the template mesh and an image encoding
108
+ x = torch.cat([ref_vertices, image_enc], dim=1)
109
+ x = self.lin1(x)
110
+ x = self.res1(x)
111
+ x_ = [x]
112
+ output_list = []
113
+ for i in range(self.num_layers + 1):
114
+ if i == self.num_layers:
115
+ x = self.encoder[i](x)
116
+ else:
117
+ x = self.encoder[i](x)
118
+ x = self.mesh.downsample(x.transpose(1, 2), n1=self.num_downsample+i, n2=self.num_downsample+i+1)
119
+ x = x.transpose(1, 2)
120
+ if i < self.num_layers-1:
121
+ x_.append(x)
122
+ for i in range(self.num_layers + 1):
123
+ if i == self.num_layers:
124
+ x = self.decoder[i](x)
125
+ output_list.append(x)
126
+ else:
127
+ x = self.decoder[i](x)
128
+ output_list.append(x)
129
+ x = self.mesh.upsample(x.transpose(1, 2), n1=self.num_layers-i+self.num_downsample,
130
+ n2=self.num_layers-i-1+self.num_downsample)
131
+ x = x.transpose(1, 2)
132
+ x = torch.cat([x, x_[self.num_layers-i-1]], dim=1) # skip connection between encoder and decoder
133
+
134
+ ground_contact = self.gc(x)
135
+
136
+ '''
137
+ ground_flatness = self.flat_ground(x).view(batch_size, self.n_out_flatground) # (bs, 1)
138
+ '''
139
+
140
+ return ground_contact, output_list # , ground_flatness
141
+
142
+
143
+
144
+
145
+
146
+
147
+
148
+ # how to use it:
149
+ #
150
+ # from src.graph_networks.graphcmr.utils_mesh import Mesh
151
+ #
152
+ # create Mesh object
153
+ # self.mesh = Mesh()
154
+ # self.faces = self.mesh.faces.to(self.device)
155
+ #
156
+ # create GraphCNN
157
+ # self.graph_cnn = GraphCNN(self.mesh.adjmat,
158
+ # self.mesh.ref_vertices.t(),
159
+ # num_channels=self.options.num_channels,
160
+ # num_layers=self.options.num_layers
161
+ # ).to(self.device)
162
+ # ------------
163
+ #
164
+ # Feed image in the GraphCNN
165
+ # Returns subsampled mesh and camera parameters
166
+ # pred_vertices_sub, pred_camera = self.graph_cnn(images)
167
+ #
168
+ # Upsample mesh in the original size
169
+ # pred_vertices = self.mesh.upsample(pred_vertices_sub.transpose(1,2))
170
+ #
src/graph_networks/graphcmr/graph_layers.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ code from https://github.com/nkolot/GraphCMR/blob/master/models/graph_layers.py
3
+ This file contains definitions of layers used to build the GraphCNN
4
+ """
5
+ from __future__ import division
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import math
11
+
12
+ class GraphConvolution(nn.Module):
13
+ """Simple GCN layer, similar to https://arxiv.org/abs/1609.02907."""
14
+ def __init__(self, in_features, out_features, adjmat, bias=True):
15
+ super(GraphConvolution, self).__init__()
16
+ self.in_features = in_features
17
+ self.out_features = out_features
18
+ self.adjmat = adjmat
19
+ self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
20
+ if bias:
21
+ self.bias = nn.Parameter(torch.FloatTensor(out_features))
22
+ else:
23
+ self.register_parameter('bias', None)
24
+ self.reset_parameters()
25
+
26
+ def reset_parameters(self):
27
+ # stdv = 1. / math.sqrt(self.weight.size(1))
28
+ stdv = 6. / math.sqrt(self.weight.size(0) + self.weight.size(1))
29
+ self.weight.data.uniform_(-stdv, stdv)
30
+ if self.bias is not None:
31
+ self.bias.data.uniform_(-stdv, stdv)
32
+
33
+ def forward(self, x):
34
+ if x.ndimension() == 2:
35
+ support = torch.matmul(x, self.weight)
36
+ output = torch.matmul(self.adjmat, support)
37
+ if self.bias is not None:
38
+ output = output + self.bias
39
+ return output
40
+ else:
41
+ output = []
42
+ for i in range(x.shape[0]):
43
+ support = torch.matmul(x[i], self.weight)
44
+ # output.append(torch.matmul(self.adjmat, support))
45
+ output.append(spmm(self.adjmat, support))
46
+ output = torch.stack(output, dim=0)
47
+ if self.bias is not None:
48
+ output = output + self.bias
49
+ return output
50
+
51
+ def __repr__(self):
52
+ return self.__class__.__name__ + ' (' \
53
+ + str(self.in_features) + ' -> ' \
54
+ + str(self.out_features) + ')'
55
+
56
+ class GraphLinear(nn.Module):
57
+ """
58
+ Generalization of 1x1 convolutions on Graphs
59
+ """
60
+ def __init__(self, in_channels, out_channels):
61
+ super(GraphLinear, self).__init__()
62
+ self.in_channels = in_channels
63
+ self.out_channels = out_channels
64
+ self.W = nn.Parameter(torch.FloatTensor(out_channels, in_channels))
65
+ self.b = nn.Parameter(torch.FloatTensor(out_channels))
66
+ self.reset_parameters()
67
+
68
+ def reset_parameters(self):
69
+ w_stdv = 1 / (self.in_channels * self.out_channels)
70
+ self.W.data.uniform_(-w_stdv, w_stdv)
71
+ self.b.data.uniform_(-w_stdv, w_stdv)
72
+
73
+ def forward(self, x):
74
+ return torch.matmul(self.W[None, :], x) + self.b[None, :, None]
75
+
76
+ class GraphResBlock(nn.Module):
77
+ """
78
+ Graph Residual Block similar to the Bottleneck Residual Block in ResNet
79
+ """
80
+
81
+ def __init__(self, in_channels, out_channels, A):
82
+ super(GraphResBlock, self).__init__()
83
+ self.in_channels = in_channels
84
+ self.out_channels = out_channels
85
+ self.lin1 = GraphLinear(in_channels, out_channels // 2)
86
+ self.conv = GraphConvolution(out_channels // 2, out_channels // 2, A)
87
+ self.lin2 = GraphLinear(out_channels // 2, out_channels)
88
+ self.skip_conv = GraphLinear(in_channels, out_channels)
89
+ self.pre_norm = nn.GroupNorm(in_channels // 8, in_channels)
90
+ self.norm1 = nn.GroupNorm((out_channels // 2) // 8, (out_channels // 2))
91
+ self.norm2 = nn.GroupNorm((out_channels // 2) // 8, (out_channels // 2))
92
+
93
+ def forward(self, x):
94
+ y = F.relu(self.pre_norm(x))
95
+ y = self.lin1(y)
96
+
97
+ y = F.relu(self.norm1(y))
98
+ y = self.conv(y.transpose(1,2)).transpose(1,2)
99
+
100
+ y = F.relu(self.norm2(y))
101
+ y = self.lin2(y)
102
+ if self.in_channels != self.out_channels:
103
+ x = self.skip_conv(x)
104
+ return x+y
105
+
106
+ class SparseMM(torch.autograd.Function):
107
+ """Redefine sparse @ dense matrix multiplication to enable backpropagation.
108
+ The builtin matrix multiplication operation does not support backpropagation in some cases.
109
+ """
110
+ @staticmethod
111
+ def forward(ctx, sparse, dense):
112
+ ctx.req_grad = dense.requires_grad
113
+ ctx.save_for_backward(sparse)
114
+ return torch.matmul(sparse, dense)
115
+
116
+ @staticmethod
117
+ def backward(ctx, grad_output):
118
+ grad_input = None
119
+ sparse, = ctx.saved_tensors
120
+ if ctx.req_grad:
121
+ grad_input = torch.matmul(sparse.t(), grad_output)
122
+ return None, grad_input
123
+
124
+ def spmm(sparse, dense):
125
+ return SparseMM.apply(sparse, dense)
src/graph_networks/graphcmr/graphcnn_coarse_to_fine_animal_pose.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """
3
+ code from: https://github.com/chaneyddtt/Coarse-to-fine-3D-Animal/blob/main/model/graph_hg.py
4
+ This file contains the Definition of GraphCNN
5
+ GraphCNN includes ResNet50 as a submodule
6
+ """
7
+ from __future__ import division
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from model.networks.graph_layers import GraphResBlock, GraphLinear
13
+ from smal.mesh import Mesh
14
+ from smal.smal_torch import SMAL
15
+
16
+ # encoder-decoder structured GCN with skip connections
17
+ class GraphCNN_hg(nn.Module):
18
+
19
+ def __init__(self, mesh, num_channels=256, local_feat=False, num_downsample=0):
20
+ '''
21
+ Args:
22
+ mesh: mesh data that store the adjacency matrix
23
+ num_channels: number of channels of GCN
24
+ local_feat: whether use local feature for refinement
25
+ num_downsample: number of downsampling of the input mesh
26
+ '''
27
+ super(GraphCNN_hg, self).__init__()
28
+ self.A = mesh._A[num_downsample:] # get the correct adjacency matrix because the input might be downsampled
29
+ self.num_layers = len(self.A) - 1
30
+ print("Number of downsampling layer: {}".format(self.num_layers))
31
+ self.num_downsample = num_downsample
32
+ if local_feat:
33
+ self.lin1 = GraphLinear(3 + 2048 + 3840, 2 * num_channels)
34
+ else:
35
+ self.lin1 = GraphLinear(3 + 2048, 2 * num_channels)
36
+ self.res1 = GraphResBlock(2 * num_channels, num_channels, self.A[0])
37
+ encode_layers = []
38
+ decode_layers = []
39
+
40
+ for i in range(len(self.A)):
41
+ encode_layers.append(GraphResBlock(num_channels, num_channels, self.A[i]))
42
+
43
+ decode_layers.append(GraphResBlock((i+1)*num_channels, (i+1)*num_channels,
44
+ self.A[len(self.A) - i - 1]))
45
+ current_channels = (i+1)*num_channels
46
+ # number of channels for the input is different because of the concatenation operation
47
+ self.shape = nn.Sequential(GraphResBlock(current_channels, 64, self.A[0]),
48
+ GraphResBlock(64, 32, self.A[0]),
49
+ nn.GroupNorm(32 // 8, 32),
50
+ nn.ReLU(inplace=True),
51
+ GraphLinear(32, 3))
52
+
53
+ self.encoder = nn.Sequential(*encode_layers)
54
+ self.decoder = nn.Sequential(*decode_layers)
55
+ self.mesh = mesh
56
+
57
+ def forward(self, verts_c, img_fea_global, img_fea_multiscale=None, points_local=None):
58
+ '''
59
+ Args:
60
+ verts_c: vertices from the coarse estimation
61
+ img_fea_global: global feature for mesh refinement
62
+ img_fea_multiscale: multi-scale feature from the encoder, used for local feature extraction
63
+ points_local: 2D keypoint for local feature extraction
64
+ Returns: refined mesh
65
+ '''
66
+ batch_size = img_fea_global.shape[0]
67
+ ref_vertices = verts_c.transpose(1, 2)
68
+ image_enc = img_fea_global.view(batch_size, 2048, 1).expand(-1, -1, ref_vertices.shape[-1])
69
+ if points_local is not None:
70
+ feat_local = torch.nn.functional.grid_sample(img_fea_multiscale, points_local)
71
+ x = torch.cat([ref_vertices, image_enc, feat_local.squeeze(2)], dim=1)
72
+ else:
73
+ x = torch.cat([ref_vertices, image_enc], dim=1)
74
+ x = self.lin1(x)
75
+ x = self.res1(x)
76
+ x_ = [x]
77
+ for i in range(self.num_layers + 1):
78
+ if i == self.num_layers:
79
+ x = self.encoder[i](x)
80
+ else:
81
+ x = self.encoder[i](x)
82
+ x = self.mesh.downsample(x.transpose(1, 2), n1=self.num_downsample+i, n2=self.num_downsample+i+1)
83
+ x = x.transpose(1, 2)
84
+ if i < self.num_layers-1:
85
+ x_.append(x)
86
+ for i in range(self.num_layers + 1):
87
+ if i == self.num_layers:
88
+ x = self.decoder[i](x)
89
+ else:
90
+ x = self.decoder[i](x)
91
+ x = self.mesh.upsample(x.transpose(1, 2), n1=self.num_layers-i+self.num_downsample,
92
+ n2=self.num_layers-i-1+self.num_downsample)
93
+ x = x.transpose(1, 2)
94
+ x = torch.cat([x, x_[self.num_layers-i-1]], dim=1) # skip connection between encoder and decoder
95
+
96
+ shape = self.shape(x)
97
+ return shape
src/graph_networks/graphcmr/my_remarks.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ this folder contains code from https://github.com/nkolot/GraphCMR/tree/master/models
3
+
4
+
5
+ other (newer) networks operating on meshes such as SMAL would be:
6
+ https://github.com/microsoft/MeshTransformer
7
+ https://github.com/microsoft/MeshGraphormer
8
+
9
+ see also:
10
+ https://arxiv.org/pdf/2112.01554.pdf, page 13
11
+ (Neural Head Avatars from Monocular RGB Videos)
src/graph_networks/graphcmr/pytorch_coma_mesh_operations.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # code from https://github.com/pixelite1201/pytorch_coma/blob/master/mesh_operations.py
2
+
3
+ import math
4
+ import heapq
5
+ import numpy as np
6
+ import scipy.sparse as sp
7
+ from psbody.mesh import Mesh
8
+
9
+ def row(A):
10
+ return A.reshape((1, -1))
11
+
12
+ def col(A):
13
+ return A.reshape((-1, 1))
14
+
15
+ def get_vert_connectivity(mesh_v, mesh_f):
16
+ """Returns a sparse matrix (of size #verts x #verts) where each nonzero
17
+ element indicates a neighborhood relation. For example, if there is a
18
+ nonzero element in position (15,12), that means vertex 15 is connected
19
+ by an edge to vertex 12."""
20
+
21
+ vpv = sp.csc_matrix((len(mesh_v),len(mesh_v)))
22
+
23
+ # for each column in the faces...
24
+ for i in range(3):
25
+ IS = mesh_f[:,i]
26
+ JS = mesh_f[:,(i+1)%3]
27
+ data = np.ones(len(IS))
28
+ ij = np.vstack((row(IS.flatten()), row(JS.flatten())))
29
+ mtx = sp.csc_matrix((data, ij), shape=vpv.shape)
30
+ vpv = vpv + mtx + mtx.T
31
+
32
+ return vpv
33
+
34
+ def get_vertices_per_edge(mesh_v, mesh_f):
35
+ """Returns an Ex2 array of adjacencies between vertices, where
36
+ each element in the array is a vertex index. Each edge is included
37
+ only once. If output of get_faces_per_edge is provided, this is used to
38
+ avoid call to get_vert_connectivity()"""
39
+
40
+ vc = sp.coo_matrix(get_vert_connectivity(mesh_v, mesh_f))
41
+ result = np.hstack((col(vc.row), col(vc.col)))
42
+ result = result[result[:,0] < result[:,1]] # for uniqueness
43
+
44
+ return result
45
+
46
+
47
+ def vertex_quadrics(mesh):
48
+ """Computes a quadric for each vertex in the Mesh.
49
+ see also:
50
+ https://www.cs.cmu.edu/~./garland/Papers/quadrics.pdf
51
+ https://users.csc.calpoly.edu/~zwood/teaching/csc570/final06/jseeba/
52
+ Returns:
53
+ v_quadrics: an (N x 4 x 4) array, where N is # vertices.
54
+ """
55
+
56
+ # Allocate quadrics
57
+ v_quadrics = np.zeros((len(mesh.v), 4, 4,))
58
+
59
+ # For each face...
60
+ for f_idx in range(len(mesh.f)):
61
+
62
+ # Compute normalized plane equation for that face
63
+ vert_idxs = mesh.f[f_idx]
64
+ verts = np.hstack((mesh.v[vert_idxs], np.array([1, 1, 1]).reshape(-1, 1)))
65
+ u, s, v = np.linalg.svd(verts)
66
+ eq = v[-1, :].reshape(-1, 1)
67
+ eq = eq / (np.linalg.norm(eq[0:3]))
68
+
69
+ # Add the outer product of the plane equation to the
70
+ # quadrics of the vertices for this face
71
+ for k in range(3):
72
+ v_quadrics[mesh.f[f_idx, k], :, :] += np.outer(eq, eq)
73
+
74
+ return v_quadrics
75
+
76
+ def _get_sparse_transform(faces, num_original_verts):
77
+ verts_left = np.unique(faces.flatten())
78
+ IS = np.arange(len(verts_left))
79
+ JS = verts_left
80
+ data = np.ones(len(JS))
81
+
82
+ mp = np.arange(0, np.max(faces.flatten()) + 1)
83
+ mp[JS] = IS
84
+ new_faces = mp[faces.copy().flatten()].reshape((-1, 3))
85
+
86
+ ij = np.vstack((IS.flatten(), JS.flatten()))
87
+ mtx = sp.csc_matrix((data, ij), shape=(len(verts_left) , num_original_verts ))
88
+
89
+ return (new_faces, mtx)
90
+
91
+ def qslim_decimator_transformer(mesh, factor=None, n_verts_desired=None):
92
+ """Return a simplified version of this mesh.
93
+
94
+ A Qslim-style approach is used here.
95
+
96
+ :param factor: fraction of the original vertices to retain
97
+ :param n_verts_desired: number of the original vertices to retain
98
+ :returns: new_faces: An Fx3 array of faces, mtx: Transformation matrix
99
+ """
100
+
101
+ if factor is None and n_verts_desired is None:
102
+ raise Exception('Need either factor or n_verts_desired.')
103
+
104
+ if n_verts_desired is None:
105
+ n_verts_desired = math.ceil(len(mesh.v) * factor)
106
+
107
+ Qv = vertex_quadrics(mesh)
108
+
109
+ # fill out a sparse matrix indicating vertex-vertex adjacency
110
+ # from psbody.mesh.topology.connectivity import get_vertices_per_edge
111
+ vert_adj = get_vertices_per_edge(mesh.v, mesh.f)
112
+ # vert_adj = sp.lil_matrix((len(mesh.v), len(mesh.v)))
113
+ # for f_idx in range(len(mesh.f)):
114
+ # vert_adj[mesh.f[f_idx], mesh.f[f_idx]] = 1
115
+
116
+ vert_adj = sp.csc_matrix((vert_adj[:, 0] * 0 + 1, (vert_adj[:, 0], vert_adj[:, 1])), shape=(len(mesh.v), len(mesh.v)))
117
+ vert_adj = vert_adj + vert_adj.T
118
+ vert_adj = vert_adj.tocoo()
119
+
120
+ def collapse_cost(Qv, r, c, v):
121
+ Qsum = Qv[r, :, :] + Qv[c, :, :]
122
+ p1 = np.vstack((v[r].reshape(-1, 1), np.array([1]).reshape(-1, 1)))
123
+ p2 = np.vstack((v[c].reshape(-1, 1), np.array([1]).reshape(-1, 1)))
124
+
125
+ destroy_c_cost = p1.T.dot(Qsum).dot(p1)
126
+ destroy_r_cost = p2.T.dot(Qsum).dot(p2)
127
+ result = {
128
+ 'destroy_c_cost': destroy_c_cost,
129
+ 'destroy_r_cost': destroy_r_cost,
130
+ 'collapse_cost': min([destroy_c_cost, destroy_r_cost]),
131
+ 'Qsum': Qsum}
132
+ return result
133
+
134
+ # construct a queue of edges with costs
135
+ queue = []
136
+ for k in range(vert_adj.nnz):
137
+ r = vert_adj.row[k]
138
+ c = vert_adj.col[k]
139
+
140
+ if r > c:
141
+ continue
142
+
143
+ cost = collapse_cost(Qv, r, c, mesh.v)['collapse_cost']
144
+ heapq.heappush(queue, (cost, (r, c)))
145
+
146
+ # decimate
147
+ collapse_list = []
148
+ nverts_total = len(mesh.v)
149
+ faces = mesh.f.copy()
150
+ while nverts_total > n_verts_desired:
151
+ e = heapq.heappop(queue)
152
+ r = e[1][0]
153
+ c = e[1][1]
154
+ if r == c:
155
+ continue
156
+
157
+ cost = collapse_cost(Qv, r, c, mesh.v)
158
+ if cost['collapse_cost'] > e[0]:
159
+ heapq.heappush(queue, (cost['collapse_cost'], e[1]))
160
+ # print 'found outdated cost, %.2f < %.2f' % (e[0], cost['collapse_cost'])
161
+ continue
162
+ else:
163
+
164
+ # update old vert idxs to new one,
165
+ # in queue and in face list
166
+ if cost['destroy_c_cost'] < cost['destroy_r_cost']:
167
+ to_destroy = c
168
+ to_keep = r
169
+ else:
170
+ to_destroy = r
171
+ to_keep = c
172
+
173
+ collapse_list.append([to_keep, to_destroy])
174
+
175
+ # in our face array, replace "to_destroy" vertidx with "to_keep" vertidx
176
+ np.place(faces, faces == to_destroy, to_keep)
177
+
178
+ # same for queue
179
+ which1 = [idx for idx in range(len(queue)) if queue[idx][1][0] == to_destroy]
180
+ which2 = [idx for idx in range(len(queue)) if queue[idx][1][1] == to_destroy]
181
+ for k in which1:
182
+ queue[k] = (queue[k][0], (to_keep, queue[k][1][1]))
183
+ for k in which2:
184
+ queue[k] = (queue[k][0], (queue[k][1][0], to_keep))
185
+
186
+ Qv[r, :, :] = cost['Qsum']
187
+ Qv[c, :, :] = cost['Qsum']
188
+
189
+ a = faces[:, 0] == faces[:, 1]
190
+ b = faces[:, 1] == faces[:, 2]
191
+ c = faces[:, 2] == faces[:, 0]
192
+
193
+ # remove degenerate faces
194
+ def logical_or3(x, y, z):
195
+ return np.logical_or(x, np.logical_or(y, z))
196
+
197
+ faces_to_keep = np.logical_not(logical_or3(a, b, c))
198
+ faces = faces[faces_to_keep, :].copy()
199
+
200
+ nverts_total = (len(np.unique(faces.flatten())))
201
+
202
+ new_faces, mtx = _get_sparse_transform(faces, len(mesh.v))
203
+ return new_faces, mtx
204
+
205
+
206
+ def setup_deformation_transfer(source, target, use_normals=False):
207
+ rows = np.zeros(3 * target.v.shape[0])
208
+ cols = np.zeros(3 * target.v.shape[0])
209
+ coeffs_v = np.zeros(3 * target.v.shape[0])
210
+ coeffs_n = np.zeros(3 * target.v.shape[0])
211
+
212
+ nearest_faces, nearest_parts, nearest_vertices = source.compute_aabb_tree().nearest(target.v, True)
213
+ nearest_faces = nearest_faces.ravel().astype(np.int64)
214
+ nearest_parts = nearest_parts.ravel().astype(np.int64)
215
+ nearest_vertices = nearest_vertices.ravel()
216
+
217
+ for i in range(target.v.shape[0]):
218
+ # Closest triangle index
219
+ f_id = nearest_faces[i]
220
+ # Closest triangle vertex ids
221
+ nearest_f = source.f[f_id]
222
+
223
+ # Closest surface point
224
+ nearest_v = nearest_vertices[3 * i:3 * i + 3]
225
+ # Distance vector to the closest surface point
226
+ dist_vec = target.v[i] - nearest_v
227
+
228
+ rows[3 * i:3 * i + 3] = i * np.ones(3)
229
+ cols[3 * i:3 * i + 3] = nearest_f
230
+
231
+ n_id = nearest_parts[i]
232
+ if n_id == 0:
233
+ # Closest surface point in triangle
234
+ A = np.vstack((source.v[nearest_f])).T
235
+ coeffs_v[3 * i:3 * i + 3] = np.linalg.lstsq(A, nearest_v)[0]
236
+ elif n_id > 0 and n_id <= 3:
237
+ # Closest surface point on edge
238
+ A = np.vstack((source.v[nearest_f[n_id - 1]], source.v[nearest_f[n_id % 3]])).T
239
+ tmp_coeffs = np.linalg.lstsq(A, target.v[i])[0]
240
+ coeffs_v[3 * i + n_id - 1] = tmp_coeffs[0]
241
+ coeffs_v[3 * i + n_id % 3] = tmp_coeffs[1]
242
+ else:
243
+ # Closest surface point a vertex
244
+ coeffs_v[3 * i + n_id - 4] = 1.0
245
+
246
+ # if use_normals:
247
+ # A = np.vstack((vn[nearest_f])).T
248
+ # coeffs_n[3 * i:3 * i + 3] = np.linalg.lstsq(A, dist_vec)[0]
249
+
250
+ #coeffs = np.hstack((coeffs_v, coeffs_n))
251
+ #rows = np.hstack((rows, rows))
252
+ #cols = np.hstack((cols, source.v.shape[0] + cols))
253
+ matrix = sp.csc_matrix((coeffs_v, (rows, cols)), shape=(target.v.shape[0], source.v.shape[0]))
254
+ return matrix
255
+
256
+
257
+ def generate_transform_matrices(mesh, factors):
258
+ """Generates len(factors) meshes, each of them is scaled by factors[i] and
259
+ computes the transformations between them.
260
+
261
+ Returns:
262
+ M: a set of meshes downsampled from mesh by a factor specified in factors.
263
+ A: Adjacency matrix for each of the meshes
264
+ D: Downsampling transforms between each of the meshes
265
+ U: Upsampling transforms between each of the meshes
266
+ """
267
+
268
+ factors = map(lambda x: 1.0 / x, factors)
269
+ M, A, D, U = [], [], [], []
270
+ A.append(get_vert_connectivity(mesh.v, mesh.f).tocoo())
271
+ M.append(mesh)
272
+
273
+ for i,factor in enumerate(factors):
274
+ ds_f, ds_D = qslim_decimator_transformer(M[-1], factor=factor)
275
+ D.append(ds_D.tocoo())
276
+ new_mesh_v = ds_D.dot(M[-1].v)
277
+ new_mesh = Mesh(v=new_mesh_v, f=ds_f)
278
+ M.append(new_mesh)
279
+ A.append(get_vert_connectivity(new_mesh.v, new_mesh.f).tocoo())
280
+ U.append(setup_deformation_transfer(M[-1], M[-2]).tocoo())
281
+
282
+ return M, A, D, U
src/graph_networks/graphcmr/utils_mesh.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # code from https://github.com/nkolot/GraphCMR/blob/master/utils/mesh.py
2
+
3
+ from __future__ import division
4
+ import torch
5
+ import numpy as np
6
+ import scipy.sparse
7
+
8
+ # from models import SMPL
9
+ import os
10
+ import sys
11
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
12
+ from graph_networks.graphcmr.graph_layers import spmm
13
+
14
+ def scipy_to_pytorch(A, U, D):
15
+ """Convert scipy sparse matrices to pytorch sparse matrix."""
16
+ ptU = []
17
+ ptD = []
18
+
19
+ for i in range(len(U)):
20
+ u = scipy.sparse.coo_matrix(U[i])
21
+ i = torch.LongTensor(np.array([u.row, u.col]))
22
+ v = torch.FloatTensor(u.data)
23
+ ptU.append(torch.sparse.FloatTensor(i, v, u.shape))
24
+
25
+ for i in range(len(D)):
26
+ d = scipy.sparse.coo_matrix(D[i])
27
+ i = torch.LongTensor(np.array([d.row, d.col]))
28
+ v = torch.FloatTensor(d.data)
29
+ ptD.append(torch.sparse.FloatTensor(i, v, d.shape))
30
+
31
+ return ptU, ptD
32
+
33
+
34
+ def adjmat_sparse(adjmat, nsize=1):
35
+ """Create row-normalized sparse graph adjacency matrix."""
36
+ adjmat = scipy.sparse.csr_matrix(adjmat)
37
+ if nsize > 1:
38
+ orig_adjmat = adjmat.copy()
39
+ for _ in range(1, nsize):
40
+ adjmat = adjmat * orig_adjmat
41
+ adjmat.data = np.ones_like(adjmat.data)
42
+ for i in range(adjmat.shape[0]):
43
+ adjmat[i,i] = 1
44
+ num_neighbors = np.array(1 / adjmat.sum(axis=-1))
45
+ adjmat = adjmat.multiply(num_neighbors)
46
+ adjmat = scipy.sparse.coo_matrix(adjmat)
47
+ row = adjmat.row
48
+ col = adjmat.col
49
+ data = adjmat.data
50
+ i = torch.LongTensor(np.array([row, col]))
51
+ v = torch.from_numpy(data).float()
52
+ adjmat = torch.sparse.FloatTensor(i, v, adjmat.shape)
53
+ return adjmat
54
+
55
+ def get_graph_params(filename, nsize=1):
56
+ """Load and process graph adjacency matrix and upsampling/downsampling matrices."""
57
+ data = np.load(filename, encoding='latin1', allow_pickle=True) # np.load(filename, encoding='latin1')
58
+ A = data['A']
59
+ U = data['U']
60
+ D = data['D']
61
+ U, D = scipy_to_pytorch(A, U, D)
62
+ A = [adjmat_sparse(a, nsize=nsize) for a in A]
63
+ return A, U, D
64
+
65
+ class Mesh(object):
66
+ """Mesh object that is used for handling certain graph operations."""
67
+ def __init__(self, filename='data/mesh_downsampling.npz',
68
+ num_downsampling=1, nsize=1, body_model=None, device=torch.device('cuda')):
69
+ self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
70
+ self._A = [a.to(device) for a in self._A]
71
+ self._U = [u.to(device) for u in self._U]
72
+ self._D = [d.to(device) for d in self._D]
73
+ self.num_downsampling = num_downsampling
74
+
75
+ # load template vertices from SMPL and normalize them
76
+ if body_model is None:
77
+ smpl = SMPL()
78
+ else:
79
+ smpl = body_model
80
+ ref_vertices = smpl.v_template
81
+ center = 0.5*(ref_vertices.max(dim=0)[0] + ref_vertices.min(dim=0)[0])[None]
82
+ ref_vertices -= center
83
+ ref_vertices /= ref_vertices.abs().max().item()
84
+
85
+ self._ref_vertices = ref_vertices.to(device)
86
+ self.faces = smpl.faces.int().to(device)
87
+
88
+ @property
89
+ def adjmat(self):
90
+ """Return the graph adjacency matrix at the specified subsampling level."""
91
+ return self._A[self.num_downsampling].float()
92
+
93
+ @property
94
+ def ref_vertices(self):
95
+ """Return the template vertices at the specified subsampling level."""
96
+ ref_vertices = self._ref_vertices
97
+ for i in range(self.num_downsampling):
98
+ ref_vertices = torch.spmm(self._D[i], ref_vertices)
99
+ return ref_vertices
100
+
101
+ def get_ref_vertices(self, n_downsample):
102
+ """Return the template vertices at any desired subsampling level."""
103
+ ref_vertices = self._ref_vertices
104
+ for i in range(n_downsample):
105
+ ref_vertices = torch.spmm(self._D[i], ref_vertices)
106
+ return ref_vertices
107
+
108
+ def downsample(self, x, n1=0, n2=None):
109
+ """Downsample mesh."""
110
+ if n2 is None:
111
+ n2 = self.num_downsampling
112
+ if x.ndimension() < 3:
113
+ for i in range(n1, n2):
114
+ x = spmm(self._D[i], x)
115
+ elif x.ndimension() == 3:
116
+ out = []
117
+ for i in range(x.shape[0]):
118
+ y = x[i]
119
+ for j in range(n1, n2):
120
+ y = spmm(self._D[j], y)
121
+ out.append(y)
122
+ x = torch.stack(out, dim=0)
123
+ return x
124
+
125
+ def upsample(self, x, n1=1, n2=0):
126
+ """Upsample mesh."""
127
+ if x.ndimension() < 3:
128
+ for i in reversed(range(n2, n1)):
129
+ x = spmm(self._U[i], x)
130
+ elif x.ndimension() == 3:
131
+ out = []
132
+ for i in range(x.shape[0]):
133
+ y = x[i]
134
+ for j in reversed(range(n2, n1)):
135
+ y = spmm(self._U[j], y)
136
+ out.append(y)
137
+ x = torch.stack(out, dim=0)
138
+ return x
src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """
3
+ code adapted from: https://github.com/mikedh/trimesh/blob/main/examples/shortest.py
4
+ shortest.py
5
+ ----------------
6
+ Given a mesh and two vertex indices find the shortest path
7
+ between the two vertices while only traveling along edges
8
+ of the mesh.
9
+ """
10
+
11
+ # python src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh.py
12
+
13
+
14
+ import os
15
+ import sys
16
+ import glob
17
+ import csv
18
+ import json
19
+ import shutil
20
+ import tqdm
21
+ import numpy as np
22
+ import pickle as pkl
23
+ import trimesh
24
+ import networkx as nx
25
+
26
+
27
+
28
+
29
+
30
+ def read_csv(csv_file):
31
+ with open(csv_file,'r') as f:
32
+ reader = csv.reader(f)
33
+ headers = next(reader)
34
+ row_list = [{h:x for (h,x) in zip(headers,row)} for row in reader]
35
+ return row_list
36
+
37
+
38
+ def load_all_template_mesh_distances(root_out_path, filename='all_vertex_distances.npy'):
39
+ vert_dists = np.load(root_out_path + filename)
40
+ return vert_dists
41
+
42
+
43
+ def prepare_graph_from_template_mesh_and_calculate_all_distances(path_mesh, root_out_path, calc_dist_mat=False):
44
+ # root_out_path = ROOT_OUT_PATH
45
+ '''
46
+ from smal_pytorch.smal_model.smal_torch_new import SMAL
47
+ smal = SMAL()
48
+ verts = smal.v_template.detach().cpu().numpy()
49
+ faces = smal.faces.detach().cpu().numpy()
50
+ '''
51
+ # path_mesh = ROOT_PATH_MESH + 'mesh_downsampling_meshesmy_smpl_39dogsnorm_Jr_4_dog_template_downsampled0.obj'
52
+ my_mesh = trimesh.load_mesh(path_mesh, process=False, maintain_order=True)
53
+ verts = my_mesh.vertices
54
+ faces = my_mesh.faces
55
+ # edges without duplication
56
+ edges = my_mesh.edges_unique
57
+ # the actual length of each unique edge
58
+ length = my_mesh.edges_unique_length
59
+ # create the graph with edge attributes for length (option A)
60
+ # g = nx.Graph()
61
+ # for edge, L in zip(edges, length): g.add_edge(*edge, length=L)
62
+ # you can create the graph with from_edgelist and
63
+ # a list comprehension (option B)
64
+ ga = nx.from_edgelist([(e[0], e[1], {'length': L}) for e, L in zip(edges, length)])
65
+ # calculate the distances between all vertex pairs
66
+ if calc_dist_mat:
67
+ # calculate distances between all possible vertex pairs
68
+ # shortest_path = nx.shortest_path(ga, source=ind_v0, target=ind_v1, weight='length')
69
+ # shortest_dist = nx.shortest_path_length(ga, source=ind_v0, target=ind_v1, weight='length')
70
+ dis = dict(nx.shortest_path_length(ga, weight='length', method='dijkstra'))
71
+ vertex_distances = np.zeros((n_verts_smal, n_verts_smal))
72
+ for ind_v0 in range(n_verts_smal):
73
+ print(ind_v0)
74
+ for ind_v1 in range(ind_v0, n_verts_smal):
75
+ vertex_distances[ind_v0, ind_v1] = dis[ind_v0][ind_v1]
76
+ vertex_distances[ind_v1, ind_v0] = dis[ind_v0][ind_v1]
77
+ # save those distances
78
+ np.save(root_out_path + 'all_vertex_distances.npy', vertex_distances)
79
+ vert_dists = vertex_distances
80
+ else:
81
+ vert_dists = np.load(root_out_path + 'all_vertex_distances.npy')
82
+ return ga, vert_dists
83
+
84
+
85
+ def calculate_vertex_overview_for_gc_annotation(name, gc_info_raw, vert_dists, root_out_path_vis=None, verts=None, faces=None, img_v12_dir=None):
86
+ # input:
87
+ # root_out_path_vis = ROOT_OUT_PATH
88
+ # img_v12_dir = IMG_V12_DIR
89
+ # name = images_with_gc_labelled[ind_img]
90
+ # gc_info_raw = gc_dict['bite/' + name]
91
+ # output:
92
+ # vertex_overview: np array of shape (n_verts_smal, 3) with [first: no-contact=0 contact=1 second: index of vertex third: dist]
93
+ n_verts_smal = 3889
94
+ gc_vertices = []
95
+ gc_info_np = np.zeros((n_verts_smal))
96
+ for ind_v in gc_info_raw:
97
+ if ind_v < n_verts_smal:
98
+ gc_vertices.append(ind_v)
99
+ gc_info_np[ind_v] = 1
100
+ # save a visualization of those annotations
101
+ if root_out_path_vis is not None:
102
+ my_mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False, maintain_order=True)
103
+ if img_v12_dir is not None and root_out_path_vis is not None:
104
+ vert_colors = np.repeat(255*gc_info_np[:, None], 3, 1)
105
+ my_mesh.visual.vertex_colors = vert_colors
106
+ my_mesh.export(root_out_path_vis + (name).replace('.jpg', '_withgc.obj'))
107
+ img_path = img_v12_dir + name
108
+ shutil.copy(img_path, root_out_path_vis + name)
109
+ # calculate for each vertex the distance to the closest element of the other group
110
+ non_gc_vertices = list(set(range(n_verts_smal)) - set(gc_vertices))
111
+ print('vertices in contact: ' + str(len(gc_vertices)))
112
+ print('vertices without contact: ' + str(len(non_gc_vertices)))
113
+ vertex_overview = np.zeros((n_verts_smal, 3)) # first: no-contact=0 contact=1 second: index of vertex third: dist
114
+ vertex_overview[:, 0] = gc_info_np
115
+ # loop through all contact vertices
116
+ for ind_v in gc_vertices:
117
+ min_length = 100
118
+ for ind_v_ps in non_gc_vertices: # possible solution
119
+ # this_path = nx.shortest_path(ga, source=ind_v, target=ind_v_ps, weight='length')
120
+ # this_length = nx.shortest_path_length(ga, source=ind_v, target=ind_v_ps, weight='length')
121
+ this_length = vert_dists[ind_v, ind_v_ps]
122
+ if this_length < min_length:
123
+ min_length = this_length
124
+ vertex_overview[ind_v, 1] = ind_v_ps
125
+ vertex_overview[ind_v, 2] = this_length
126
+ # loop through all non-contact vertices
127
+ for ind_v in non_gc_vertices:
128
+ min_length = 100
129
+ for ind_v_ps in gc_vertices: # possible solution
130
+ # this_path = nx.shortest_path(ga, source=ind_v, target=ind_v_ps, weight='length')
131
+ # this_length = nx.shortest_path_length(ga, source=ind_v, target=ind_v_ps, weight='length')
132
+ this_length = vert_dists[ind_v, ind_v_ps]
133
+ if this_length < min_length:
134
+ min_length = this_length
135
+ vertex_overview[ind_v, 1] = ind_v_ps
136
+ vertex_overview[ind_v, 2] = this_length
137
+ if root_out_path_vis is not None:
138
+ # save a colored mesh
139
+ my_mesh_dists = my_mesh.copy()
140
+ scale_0 = (vertex_overview[vertex_overview[:, 0]==0, 2]).max()
141
+ scale_1 = (vertex_overview[vertex_overview[:, 0]==1, 2]).max()
142
+ vert_col = np.zeros((n_verts_smal, 3))
143
+ vert_col[vertex_overview[:, 0]==0, 1] = vertex_overview[vertex_overview[:, 0]==0, 2] * 255 / scale_0 # green
144
+ vert_col[vertex_overview[:, 0]==1, 0] = vertex_overview[vertex_overview[:, 0]==1, 2] * 255 / scale_1 # red
145
+ my_mesh_dists.visual.vertex_colors = np.uint8(vert_col)
146
+ my_mesh_dists.export(root_out_path_vis + (name).replace('.jpg', '_withgcdists.obj'))
147
+ return vertex_overview
148
+
149
+
150
+
151
+
152
+
153
+
154
+
155
+
156
+ def main():
157
+
158
+ ROOT_PATH_MESH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/graphcmr/data/meshes/'
159
+ ROOT_PATH_ANNOT = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage3/'
160
+ IMG_V12_DIR = '/ps/scratch/nrueegg/new_projects/Animals/data/dog_datasets/Stanford_Dogs_Dataset/StanfordExtra_V12/StanExtV12_Images/'
161
+ # ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/losses_for_vertex_wise_predictions/debugging_results/'
162
+ ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage3/'
163
+ ROOT_OUT_PATH_VIS = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage3/vis/'
164
+ ROOT_OUT_PATH_DISTSGCNONGC = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage3/vertex_distances_gc_nongc/'
165
+ ROOT_PATH_ALL_VERT_DIST_TEMPLATE = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/'
166
+
167
+ # load all vertex distances
168
+ path_mesh = ROOT_PATH_MESH + 'mesh_downsampling_meshesmy_smpl_39dogsnorm_Jr_4_dog_template_downsampled0.obj'
169
+ my_mesh = trimesh.load_mesh(path_mesh, process=False, maintain_order=True)
170
+ verts = my_mesh.vertices
171
+ faces = my_mesh.faces
172
+ # vert_dists, ga = prepare_graph_from_template_mesh_and_calculate_all_distances(path_mesh, ROOT_OUT_PATH, calc_dist_mat=False)
173
+ vert_dists = load_all_template_mesh_distances(ROOT_PATH_ALL_VERT_DIST_TEMPLATE, filename='all_vertex_distances.npy')
174
+
175
+
176
+
177
+
178
+ all_keys = []
179
+ gc_dict = {}
180
+ # data/stanext_related_data/ground_contact_annotations/stage3/main_partA1667_20221021_140108.csv
181
+ # for csv_file in ['main_partA500_20221018_131139.csv', 'pilot_20221017_104201.csv', 'my_gcannotations_qualification.csv']:
182
+ # for csv_file in ['main_partA1667_20221021_140108.csv', 'main_partA500_20221018_131139.csv', 'pilot_20221017_104201.csv', 'my_gcannotations_qualification.csv']:
183
+ for csv_file in ['main_partA1667_20221021_140108.csv', 'main_partA500_20221018_131139.csv', 'main_partB20221023_150926.csv', 'pilot_20221017_104201.csv', 'my_gcannotations_qualification.csv']:
184
+ # load all ground contact annotations
185
+ gc_annot_csv = ROOT_PATH_ANNOT + csv_file # 'my_gcannotations_qualification.csv'
186
+ gc_row_list = read_csv(gc_annot_csv)
187
+ for ind_row in range(len(gc_row_list)):
188
+ json_acceptable_string = (gc_row_list[ind_row]['vertices']).replace("'", "\"")
189
+ gc_dict_temp = json.loads(json_acceptable_string)
190
+ all_keys.extend(gc_dict_temp.keys())
191
+ gc_dict.update(gc_dict_temp)
192
+ print(len(gc_dict.keys()))
193
+
194
+ print('number of labeled images: ' + str(len(gc_dict.keys()))) # WHY IS THIS ONLY 699?
195
+
196
+ import pdb; pdb.set_trace()
197
+
198
+
199
+ # prepare and save contact annotations including distances
200
+ vertex_overview_dict = {}
201
+ for ind_img, name_ingcdict in enumerate(gc_dict.keys()): # range(len(gc_dict.keys())):
202
+ name = name_ingcdict.split('bite/')[1]
203
+ # name = images_with_gc_labelled[ind_img]
204
+ print('work on image ' + str(ind_img) + ': ' + name)
205
+ # gc_info_raw = gc_dict['bite/' + name] # a list with all vertex numbers that are in ground contact
206
+ gc_info_raw = gc_dict[name_ingcdict] # a list with all vertex numbers that are in ground contact
207
+
208
+ if not os.path.exists(ROOT_OUT_PATH_VIS + name.split('/')[0]): os.makedirs(ROOT_OUT_PATH_VIS + name.split('/')[0])
209
+ if not os.path.exists(ROOT_OUT_PATH_DISTSGCNONGC + name.split('/')[0]): os.makedirs(ROOT_OUT_PATH_DISTSGCNONGC + name.split('/')[0])
210
+
211
+ vertex_overview = calculate_vertex_overview_for_gc_annotation(name, gc_info_raw, vert_dists, root_out_path_vis=ROOT_OUT_PATH_VIS, verts=verts, faces=faces, img_v12_dir=None)
212
+ np.save(ROOT_OUT_PATH_DISTSGCNONGC + name.replace('.jpg', '_gc_vertdists_overview.npy'), vertex_overview)
213
+
214
+ vertex_overview_dict[name.split('.')[0]] = {'gc_vertdists_overview': vertex_overview, 'gc_index_list': gc_info_raw}
215
+
216
+
217
+
218
+
219
+
220
+ # import pdb; pdb.set_trace()
221
+
222
+ with open(ROOT_OUT_PATH + 'gc_annots_overview_stage3complete_withtraintestval_xx.pkl', 'wb') as fp:
223
+ pkl.dump(vertex_overview_dict, fp)
224
+
225
+
226
+
227
+
228
+
229
+
230
+
231
+
232
+
233
+
234
+
235
+
236
+
237
+ if __name__ == "__main__":
238
+ main()
239
+
240
+
241
+
242
+
243
+
244
+
245
+
src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh_forfourpaws.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """
3
+ code adapted from: https://github.com/mikedh/trimesh/blob/main/examples/shortest.py
4
+ shortest.py
5
+ ----------------
6
+ Given a mesh and two vertex indices find the shortest path
7
+ between the two vertices while only traveling along edges
8
+ of the mesh.
9
+ """
10
+
11
+ # python src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh_forfourpaws.py
12
+
13
+
14
+ import os
15
+ import sys
16
+ import glob
17
+ import csv
18
+ import json
19
+ import shutil
20
+ import tqdm
21
+ import numpy as np
22
+ import pickle as pkl
23
+ import trimesh
24
+ import networkx as nx
25
+
26
+
27
+
28
+
29
+
30
+ def read_csv(csv_file):
31
+ with open(csv_file,'r') as f:
32
+ reader = csv.reader(f)
33
+ headers = next(reader)
34
+ row_list = [{h:x for (h,x) in zip(headers,row)} for row in reader]
35
+ return row_list
36
+
37
+
38
+ def load_all_template_mesh_distances(root_out_path, filename='all_vertex_distances.npy'):
39
+ vert_dists = np.load(root_out_path + filename)
40
+ return vert_dists
41
+
42
+
43
+ def prepare_graph_from_template_mesh_and_calculate_all_distances(path_mesh, root_out_path, calc_dist_mat=False):
44
+ # root_out_path = ROOT_OUT_PATH
45
+ '''
46
+ from smal_pytorch.smal_model.smal_torch_new import SMAL
47
+ smal = SMAL()
48
+ verts = smal.v_template.detach().cpu().numpy()
49
+ faces = smal.faces.detach().cpu().numpy()
50
+ '''
51
+ # path_mesh = ROOT_PATH_MESH + 'mesh_downsampling_meshesmy_smpl_39dogsnorm_Jr_4_dog_template_downsampled0.obj'
52
+ my_mesh = trimesh.load_mesh(path_mesh, process=False, maintain_order=True)
53
+ verts = my_mesh.vertices
54
+ faces = my_mesh.faces
55
+ # edges without duplication
56
+ edges = my_mesh.edges_unique
57
+ # the actual length of each unique edge
58
+ length = my_mesh.edges_unique_length
59
+ # create the graph with edge attributes for length (option A)
60
+ # g = nx.Graph()
61
+ # for edge, L in zip(edges, length): g.add_edge(*edge, length=L)
62
+ # you can create the graph with from_edgelist and
63
+ # a list comprehension (option B)
64
+ ga = nx.from_edgelist([(e[0], e[1], {'length': L}) for e, L in zip(edges, length)])
65
+ # calculate the distances between all vertex pairs
66
+ if calc_dist_mat:
67
+ # calculate distances between all possible vertex pairs
68
+ # shortest_path = nx.shortest_path(ga, source=ind_v0, target=ind_v1, weight='length')
69
+ # shortest_dist = nx.shortest_path_length(ga, source=ind_v0, target=ind_v1, weight='length')
70
+ dis = dict(nx.shortest_path_length(ga, weight='length', method='dijkstra'))
71
+ vertex_distances = np.zeros((n_verts_smal, n_verts_smal))
72
+ for ind_v0 in range(n_verts_smal):
73
+ print(ind_v0)
74
+ for ind_v1 in range(ind_v0, n_verts_smal):
75
+ vertex_distances[ind_v0, ind_v1] = dis[ind_v0][ind_v1]
76
+ vertex_distances[ind_v1, ind_v0] = dis[ind_v0][ind_v1]
77
+ # save those distances
78
+ np.save(root_out_path + 'all_vertex_distances.npy', vertex_distances)
79
+ vert_dists = vertex_distances
80
+ else:
81
+ vert_dists = np.load(root_out_path + 'all_vertex_distances.npy')
82
+ return ga, vert_dists
83
+
84
+
85
+ def calculate_vertex_overview_for_gc_annotation(name, gc_info_raw, vert_dists, root_out_path_vis=None, verts=None, faces=None, img_v12_dir=None):
86
+ # input:
87
+ # root_out_path_vis = ROOT_OUT_PATH
88
+ # img_v12_dir = IMG_V12_DIR
89
+ # name = images_with_gc_labelled[ind_img]
90
+ # gc_info_raw = gc_dict['bite/' + name]
91
+ # output:
92
+ # vertex_overview: np array of shape (n_verts_smal, 3) with [first: no-contact=0 contact=1 second: index of vertex third: dist]
93
+ n_verts_smal = 3889
94
+ gc_vertices = []
95
+ gc_info_np = np.zeros((n_verts_smal))
96
+ for ind_v in gc_info_raw:
97
+ if ind_v < n_verts_smal:
98
+ gc_vertices.append(ind_v)
99
+ gc_info_np[ind_v] = 1
100
+ # save a visualization of those annotations
101
+ if root_out_path_vis is not None:
102
+ my_mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False, maintain_order=True)
103
+ if img_v12_dir is not None and root_out_path_vis is not None:
104
+ vert_colors = np.repeat(255*gc_info_np[:, None], 3, 1)
105
+ my_mesh.visual.vertex_colors = vert_colors
106
+ my_mesh.export(root_out_path_vis + (name).replace('.jpg', '_withgc.obj'))
107
+ img_path = img_v12_dir + name
108
+ shutil.copy(img_path, root_out_path_vis + name)
109
+ # calculate for each vertex the distance to the closest element of the other group
110
+ non_gc_vertices = list(set(range(n_verts_smal)) - set(gc_vertices))
111
+ print('vertices in contact: ' + str(len(gc_vertices)))
112
+ print('vertices without contact: ' + str(len(non_gc_vertices)))
113
+ vertex_overview = np.zeros((n_verts_smal, 3)) # first: no-contact=0 contact=1 second: index of vertex third: dist
114
+ vertex_overview[:, 0] = gc_info_np
115
+ # loop through all contact vertices
116
+ for ind_v in gc_vertices:
117
+ min_length = 100
118
+ for ind_v_ps in non_gc_vertices: # possible solution
119
+ # this_path = nx.shortest_path(ga, source=ind_v, target=ind_v_ps, weight='length')
120
+ # this_length = nx.shortest_path_length(ga, source=ind_v, target=ind_v_ps, weight='length')
121
+ this_length = vert_dists[ind_v, ind_v_ps]
122
+ if this_length < min_length:
123
+ min_length = this_length
124
+ vertex_overview[ind_v, 1] = ind_v_ps
125
+ vertex_overview[ind_v, 2] = this_length
126
+ # loop through all non-contact vertices
127
+ for ind_v in non_gc_vertices:
128
+ min_length = 100
129
+ for ind_v_ps in gc_vertices: # possible solution
130
+ # this_path = nx.shortest_path(ga, source=ind_v, target=ind_v_ps, weight='length')
131
+ # this_length = nx.shortest_path_length(ga, source=ind_v, target=ind_v_ps, weight='length')
132
+ this_length = vert_dists[ind_v, ind_v_ps]
133
+ if this_length < min_length:
134
+ min_length = this_length
135
+ vertex_overview[ind_v, 1] = ind_v_ps
136
+ vertex_overview[ind_v, 2] = this_length
137
+ if root_out_path_vis is not None:
138
+ # save a colored mesh
139
+ my_mesh_dists = my_mesh.copy()
140
+ scale_0 = (vertex_overview[vertex_overview[:, 0]==0, 2]).max()
141
+ scale_1 = (vertex_overview[vertex_overview[:, 0]==1, 2]).max()
142
+ vert_col = np.zeros((n_verts_smal, 3))
143
+ vert_col[vertex_overview[:, 0]==0, 1] = vertex_overview[vertex_overview[:, 0]==0, 2] * 255 / scale_0 # green
144
+ vert_col[vertex_overview[:, 0]==1, 0] = vertex_overview[vertex_overview[:, 0]==1, 2] * 255 / scale_1 # red
145
+ my_mesh_dists.visual.vertex_colors = np.uint8(vert_col)
146
+ my_mesh_dists.export(root_out_path_vis + (name).replace('.jpg', '_withgcdists.obj'))
147
+ return vertex_overview
148
+
149
+
150
+
151
+
152
+
153
+
154
+
155
+
156
+
157
+ def main():
158
+
159
+ ROOT_PATH_MESH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/graphcmr/data/meshes/'
160
+ IMG_V12_DIR = '/ps/scratch/nrueegg/new_projects/Animals/data/dog_datasets/Stanford_Dogs_Dataset/StanfordExtra_V12/StanExtV12_Images/'
161
+ # ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/losses_for_vertex_wise_predictions/debugging_results/'
162
+ ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stages12together/'
163
+ ROOT_PATH_ALL_VERT_DIST_TEMPLATE = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/'
164
+
165
+ # load all vertex distances
166
+ path_mesh = ROOT_PATH_MESH + 'mesh_downsampling_meshesmy_smpl_39dogsnorm_Jr_4_dog_template_downsampled0.obj'
167
+ my_mesh = trimesh.load_mesh(path_mesh, process=False, maintain_order=True)
168
+ verts = my_mesh.vertices
169
+ faces = my_mesh.faces
170
+ # vert_dists, ga = prepare_graph_from_template_mesh_and_calculate_all_distances(path_mesh, ROOT_OUT_PATH, calc_dist_mat=False)
171
+ vert_dists = load_all_template_mesh_distances(ROOT_PATH_ALL_VERT_DIST_TEMPLATE, filename='all_vertex_distances.npy')
172
+
173
+ # paw vertices:
174
+ # left and right is a bit different, but that is ok (we will anyways mirror data at training time)
175
+ right_front_paw = [3829,+3827,+3825,+3718,+3722,+3723,+3743,+3831,+3719,+3726,+3716,+3724,+3828,+3717,+3721,+3725,+3832,+3830,+3720,+3288,+3740,+3714,+3826,+3715,+3728,+3712,+3287,+3284,+3727,+3285,+3742,+3291,+3710,+3697,+3711,+3289,+3730,+3713,+3739,+3282,+3738,+3708,+3709,+3741,+3698,+3696,+3308,+3695,+3706,+3700,+3707,+3306,+3305,+3737,+3304,+3303,+3307,+3736,+3735,+3250,+3261,+3732,+3734,+3733,+3731,+3729,+3299,+3297,+3298,+3295,+3293,+3296,+3294,+3292,+3312,+3311,+3314,+3309,+3290,+3313,+3410,+3315,+3411,+3412,+3316,+3421,+3317,+3415,+3445,+3327,+3328,+3283,+3343,+3326,+3325,+3330,+3286,+3399,+3398,+3329,+3446,+3400,+3331,+3401,+3281,+3332,+3279,+3402,+3419,+3407,+3356,+3358,+3357,+3280,+3354,+3277,+3278,+3346,+3347,+3377,+3378,+3345,+3386,+3379,+3348,+3384,+3418,+3372,+3276,+3275,+3374,+3274,+3373,+3375,+3369,+3371,+3376,+3273,+3396,+3397,+3395,+3388,+3360,+3370,+3361,+3394,+3387,+3420,+3359,+3389,+3272,+3391,+3393,+3390,+3392,+3363,+3362,+3367,+3365,+3705,+3271,+3704,+3703,+3270,+3269,+3702,+3268,+3224,+3267,+3701,+3225,+3699,+3265,+3264,+3266,+3263,+3262,+3249,+3228,+3230,+3251,+3301,+3300,+3302,+3252]
176
+ right_back_paw = [3472,+3627,+3470,+3469,+3471,+3473,+3626,+3625,+3475,+3655,+3519,+3468,+3629,+3466,+3476,+3624,+3521,+3654,+3657,+3838,+3518,+3653,+3839,+3553,+3474,+3516,+3656,+3628,+3834,+3535,+3630,+3658,+3477,+3520,+3517,+3595,+3522,+3597,+3596,+3501,+3534,+3503,+3478,+3500,+3479,+3502,+3607,+3499,+3608,+3496,+3605,+3609,+3504,+3606,+3642,+3614,+3498,+3480,+3631,+3610,+3613,+3506,+3659,+3660,+3632,+3841,+3661,+3836,+3662,+3633,+3663,+3664,+3634,+3635,+3486,+3665,+3636,+3637,+3666,+3490,+3837,+3667,+3493,+3638,+3492,+3495,+3616,+3644,+3494,+3835,+3643,+3833,+3840,+3615,+3650,+3668,+3652,+3651,+3645,+3646,+3647,+3649,+3648,+3622,+3617,+3448,+3621,+3618,+3623,+3462,+3464,+3460,+3620,+3458,+3461,+3463,+3465,+3573,+3571,+3467,+3569,+3557,+3558,+3572,+3570,+3556,+3585,+3593,+3594,+3459,+3566,+3592,+3567,+3568,+3538,+3539,+3555,+3537,+3536,+3554,+3575,+3574,+3583,+3541,+3550,+3576,+3581,+3639,+3577,+3551,+3582,+3580,+3552,+3578,+3542,+3549,+3579,+3523,+3526,+3598,+3525,+3600,+3640,+3599,+3601,+3602,+3603,+3529,+3604,+3530,+3533,+3532,+3611,+3612,+3482,+3481,+3505,+3452,+3455,+3456,+3454,+3457,+3619,+3451,+3450,+3449,+3591,+3589,+3641,+3584,+3561,+3587,+3559,+3488,+3484,+3483]
177
+ left_front_paw = [1791,+1950,+1948,+1790,+1789,+1746,+1788,+1747,+1949,+1944,+1792,+1945,+1356,+1775,+1759,+1777,+1787,+1946,+1757,+1761,+1745,+1943,+1947,+1744,+1309,+1786,+1771,+1354,+1774,+1765,+1767,+1768,+1772,+1763,+1770,+1773,+1769,+1764,+1766,+1758,+1760,+1762,+1336,+1333,+1330,+1325,+1756,+1323,+1755,+1753,+1749,+1754,+1751,+1321,+1752,+1748,+1750,+1312,+1319,+1315,+1313,+1317,+1318,+1316,+1314,+1311,+1310,+1299,+1276,+1355,+1297,+1353,+1298,+1300,+1352,+1351,+1785,+1784,+1349,+1783,+1782,+1781,+1780,+1779,+1778,+1776,+1343,+1341,+1344,+1339,+1342,+1340,+1360,+1335,+1338,+1362,+1357,+1361,+1363,+1458,+1337,+1459,+1456,+1460,+1493,+1332,+1375,+1376,+1331,+1374,+1378,+1334,+1373,+1494,+1377,+1446,+1448,+1379,+1449,+1329,+1327,+1404,+1406,+1405,+1402,+1328,+1426,+1432,+1434,+1403,+1394,+1395,+1433,+1425,+1286,+1380,+1466,+1431,+1290,+1401,+1381,+1427,+1450,+1393,+1430,+1326,+1396,+1428,+1397,+1429,+1398,+1420,+1324,+1422,+1417,+1419,+1421,+1443,+1418,+1423,+1444,+1442,+1424,+1445,+1495,+1440,+1441,+1468,+1436,+1408,+1322,+1435,+1415,+1439,+1409,+1283,+1438,+1416,+1407,+1437,+1411,+1413,+1414,+1320,+1273,+1272,+1278,+1469,+1463,+1457,+1358,+1464,+1465,+1359,+1372,+1391,+1390,+1455,+1447,+1454,+1467,+1453,+1452,+1451,+1383,+1345,+1347,+1348,+1350,+1364,+1392,+1410,+1412]
178
+ left_back_paw = [1957,+1958,+1701,+1956,+1951,+1703,+1715,+1702,+1700,+1673,+1705,+1952,+1955,+1674,+1699,+1675,+1953,+1704,+1954,+1698,+1677,+1671,+1672,+1714,+1706,+1676,+1519,+1523,+1686,+1713,+1692,+1685,+1543,+1664,+1712,+1691,+1959,+1541,+1684,+1542,+1496,+1663,+1540,+1497,+1499,+1498,+1500,+1693,+1665,+1694,+1716,+1666,+1695,+1501,+1502,+1696,+1667,+1503,+1697,+1504,+1668,+1669,+1506,+1670,+1508,+1510,+1507,+1509,+1511,+1512,+1621,+1606,+1619,+1605,+1513,+1620,+1618,+1604,+1633,+1641,+1642,+1607,+1617,+1514,+1632,+1614,+1689,+1640,+1515,+1586,+1616,+1516,+1517,+1603,+1615,+1639,+1585,+1521,+1602,+1587,+1584,+1601,+1623,+1622,+1631,+1598,+1624,+1629,+1589,+1687,+1625,+1599,+1630,+1569,+1570,+1628,+1626,+1597,+1627,+1590,+1594,+1571,+1568,+1567,+1574,+1646,+1573,+1645,+1648,+1564,+1688,+1647,+1643,+1649,+1650,+1651,+1577,+1644,+1565,+1652,+1566,+1578,+1518,+1524,+1583,+1582,+1520,+1581,+1522,+1525,+1549,+1551,+1580,+1552,+1550,+1656,+1658,+1554,+1657,+1659,+1548,+1655,+1690,+1660,+1556,+1653,+1558,+1661,+1544,+1662,+1654,+1547,+1545,+1527,+1560,+1526,+1678,+1679,+1528,+1708,+1707,+1680,+1529,+1530,+1709,+1546,+1681,+1710,+1711,+1682,+1532,+1531,+1683,+1534,+1533,+1536,+1538,+1600,+1553]
179
+
180
+
181
+ all_contact_vertices = right_front_paw + right_back_paw + left_front_paw + left_back_paw
182
+
183
+ name = 'all4pawsincontact.jpg'
184
+ print('work on 4paw images')
185
+ gc_info_raw = all_contact_vertices # a list with all vertex numbers that are in ground contact
186
+
187
+ vertex_overview = calculate_vertex_overview_for_gc_annotation(name, gc_info_raw, vert_dists, root_out_path_vis=ROOT_OUT_PATH, verts=verts, faces=faces, img_v12_dir=None)
188
+ np.save(ROOT_OUT_PATH + name.replace('.jpg', '_gc_vertdists_overview.npy'), vertex_overview)
189
+
190
+ vertex_overview_dict = {}
191
+ vertex_overview_dict[name.split('.')[0]] = {'gc_vertdists_overview': vertex_overview, 'gc_index_list': gc_info_raw}
192
+ with open(ROOT_OUT_PATH + 'gc_annots_overview_all4pawsincontact_xx.pkl', 'wb') as fp:
193
+ pkl.dump(vertex_overview_dict, fp)
194
+
195
+
196
+
197
+
198
+
199
+
200
+
201
+
202
+
203
+
204
+
205
+ if __name__ == "__main__":
206
+ main()
207
+
208
+
209
+
210
+
211
+
212
+
213
+
src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh_forpaws.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """
3
+ code adapted from: https://github.com/mikedh/trimesh/blob/main/examples/shortest.py
4
+ shortest.py
5
+ ----------------
6
+ Given a mesh and two vertex indices find the shortest path
7
+ between the two vertices while only traveling along edges
8
+ of the mesh.
9
+ """
10
+
11
+ # python src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh_forpaws.py
12
+
13
+
14
+ import os
15
+ import sys
16
+ import glob
17
+ import csv
18
+ import json
19
+ import shutil
20
+ import tqdm
21
+ import numpy as np
22
+ import pickle as pkl
23
+ import trimesh
24
+ import networkx as nx
25
+
26
+
27
+
28
+
29
+
30
+ def read_csv(csv_file):
31
+ with open(csv_file,'r') as f:
32
+ reader = csv.reader(f)
33
+ headers = next(reader)
34
+ row_list = [{h:x for (h,x) in zip(headers,row)} for row in reader]
35
+ return row_list
36
+
37
+
38
+ def load_all_template_mesh_distances(root_out_path, filename='all_vertex_distances.npy'):
39
+ vert_dists = np.load(root_out_path + filename)
40
+ return vert_dists
41
+
42
+
43
+ def prepare_graph_from_template_mesh_and_calculate_all_distances(path_mesh, root_out_path, calc_dist_mat=False):
44
+ # root_out_path = ROOT_OUT_PATH
45
+ '''
46
+ from smal_pytorch.smal_model.smal_torch_new import SMAL
47
+ smal = SMAL()
48
+ verts = smal.v_template.detach().cpu().numpy()
49
+ faces = smal.faces.detach().cpu().numpy()
50
+ '''
51
+ # path_mesh = ROOT_PATH_MESH + 'mesh_downsampling_meshesmy_smpl_39dogsnorm_Jr_4_dog_template_downsampled0.obj'
52
+ my_mesh = trimesh.load_mesh(path_mesh, process=False, maintain_order=True)
53
+ verts = my_mesh.vertices
54
+ faces = my_mesh.faces
55
+ # edges without duplication
56
+ edges = my_mesh.edges_unique
57
+ # the actual length of each unique edge
58
+ length = my_mesh.edges_unique_length
59
+ # create the graph with edge attributes for length (option A)
60
+ # g = nx.Graph()
61
+ # for edge, L in zip(edges, length): g.add_edge(*edge, length=L)
62
+ # you can create the graph with from_edgelist and
63
+ # a list comprehension (option B)
64
+ ga = nx.from_edgelist([(e[0], e[1], {'length': L}) for e, L in zip(edges, length)])
65
+ # calculate the distances between all vertex pairs
66
+ if calc_dist_mat:
67
+ # calculate distances between all possible vertex pairs
68
+ # shortest_path = nx.shortest_path(ga, source=ind_v0, target=ind_v1, weight='length')
69
+ # shortest_dist = nx.shortest_path_length(ga, source=ind_v0, target=ind_v1, weight='length')
70
+ dis = dict(nx.shortest_path_length(ga, weight='length', method='dijkstra'))
71
+ vertex_distances = np.zeros((n_verts_smal, n_verts_smal))
72
+ for ind_v0 in range(n_verts_smal):
73
+ print(ind_v0)
74
+ for ind_v1 in range(ind_v0, n_verts_smal):
75
+ vertex_distances[ind_v0, ind_v1] = dis[ind_v0][ind_v1]
76
+ vertex_distances[ind_v1, ind_v0] = dis[ind_v0][ind_v1]
77
+ # save those distances
78
+ np.save(root_out_path + 'all_vertex_distances.npy', vertex_distances)
79
+ vert_dists = vertex_distances
80
+ else:
81
+ vert_dists = np.load(root_out_path + 'all_vertex_distances.npy')
82
+ return ga, vert_dists
83
+
84
+
85
+ def calculate_vertex_overview_for_gc_annotation(name, gc_info_raw, vert_dists, root_out_path_vis=None, verts=None, faces=None, img_v12_dir=None):
86
+ # input:
87
+ # root_out_path_vis = ROOT_OUT_PATH
88
+ # img_v12_dir = IMG_V12_DIR
89
+ # name = images_with_gc_labelled[ind_img]
90
+ # gc_info_raw = gc_dict['bite/' + name]
91
+ # output:
92
+ # vertex_overview: np array of shape (n_verts_smal, 3) with [first: no-contact=0 contact=1 second: index of vertex third: dist]
93
+ n_verts_smal = 3889
94
+ gc_vertices = []
95
+ gc_info_np = np.zeros((n_verts_smal))
96
+ for ind_v in gc_info_raw:
97
+ if ind_v < n_verts_smal:
98
+ gc_vertices.append(ind_v)
99
+ gc_info_np[ind_v] = 1
100
+ # save a visualization of those annotations
101
+ if root_out_path_vis is not None:
102
+ my_mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False, maintain_order=True)
103
+ if img_v12_dir is not None and root_out_path_vis is not None:
104
+ vert_colors = np.repeat(255*gc_info_np[:, None], 3, 1)
105
+ my_mesh.visual.vertex_colors = vert_colors
106
+ my_mesh.export(root_out_path_vis + (name).replace('.jpg', '_withgc.obj'))
107
+ img_path = img_v12_dir + name
108
+ shutil.copy(img_path, root_out_path_vis + name)
109
+ # calculate for each vertex the distance to the closest element of the other group
110
+ non_gc_vertices = list(set(range(n_verts_smal)) - set(gc_vertices))
111
+ print('vertices in contact: ' + str(len(gc_vertices)))
112
+ print('vertices without contact: ' + str(len(non_gc_vertices)))
113
+ vertex_overview = np.zeros((n_verts_smal, 3)) # first: no-contact=0 contact=1 second: index of vertex third: dist
114
+ vertex_overview[:, 0] = gc_info_np
115
+ # loop through all contact vertices
116
+ for ind_v in gc_vertices:
117
+ min_length = 100
118
+ for ind_v_ps in non_gc_vertices: # possible solution
119
+ # this_path = nx.shortest_path(ga, source=ind_v, target=ind_v_ps, weight='length')
120
+ # this_length = nx.shortest_path_length(ga, source=ind_v, target=ind_v_ps, weight='length')
121
+ this_length = vert_dists[ind_v, ind_v_ps]
122
+ if this_length < min_length:
123
+ min_length = this_length
124
+ vertex_overview[ind_v, 1] = ind_v_ps
125
+ vertex_overview[ind_v, 2] = this_length
126
+ # loop through all non-contact vertices
127
+ for ind_v in non_gc_vertices:
128
+ min_length = 100
129
+ for ind_v_ps in gc_vertices: # possible solution
130
+ # this_path = nx.shortest_path(ga, source=ind_v, target=ind_v_ps, weight='length')
131
+ # this_length = nx.shortest_path_length(ga, source=ind_v, target=ind_v_ps, weight='length')
132
+ this_length = vert_dists[ind_v, ind_v_ps]
133
+ if this_length < min_length:
134
+ min_length = this_length
135
+ vertex_overview[ind_v, 1] = ind_v_ps
136
+ vertex_overview[ind_v, 2] = this_length
137
+ if root_out_path_vis is not None:
138
+ # save a colored mesh
139
+ my_mesh_dists = my_mesh.copy()
140
+ scale_0 = (vertex_overview[vertex_overview[:, 0]==0, 2]).max()
141
+ scale_1 = (vertex_overview[vertex_overview[:, 0]==1, 2]).max()
142
+ vert_col = np.zeros((n_verts_smal, 3))
143
+ vert_col[vertex_overview[:, 0]==0, 1] = vertex_overview[vertex_overview[:, 0]==0, 2] * 255 / scale_0 # green
144
+ vert_col[vertex_overview[:, 0]==1, 0] = vertex_overview[vertex_overview[:, 0]==1, 2] * 255 / scale_1 # red
145
+ my_mesh_dists.visual.vertex_colors = np.uint8(vert_col)
146
+ my_mesh_dists.export(root_out_path_vis + (name).replace('.jpg', '_withgcdists.obj'))
147
+ return vertex_overview
148
+
149
+
150
+ def summarize_results_stage2b(row_list, display_worker_performance=False):
151
+ # four catch trials are included in every batch
152
+ annot_n02088466_3184 = {'paw_rb': 0, 'paw_rf': 1, 'paw_lb': 1, 'paw_lf': 1, 'additional_part': 0, 'no_contact': 0}
153
+ annot_n02100583_9922 = {'paw_rb': 1, 'paw_rf': 0, 'paw_lb': 0, 'paw_lf': 0, 'additional_part': 0, 'no_contact': 0}
154
+ annot_n02105056_2798 = {'paw_rb': 1, 'paw_rf': 1, 'paw_lb': 1, 'paw_lf': 1, 'additional_part': 1, 'no_contact': 0}
155
+ annot_n02091831_2288 = {'paw_rb': 0, 'paw_rf': 1, 'paw_lb': 1, 'paw_lf': 0, 'additional_part': 0, 'no_contact': 0}
156
+ all_comments = []
157
+ all_annotations = {}
158
+ for row in row_list:
159
+ all_comments.append(row['Answer.submitComments'])
160
+ worker_id = row['WorkerId']
161
+ if display_worker_performance:
162
+ print('----------------------------------------------------------------------------------------------')
163
+ print('Worker ID: ' + worker_id)
164
+ n_wrong = 0
165
+ n_correct = 0
166
+ for ind in range(0, len(row['Answer.submitValuesNotSure'].split(';')) - 1):
167
+ input_image = (row['Input.images'].split(';')[ind]).split('StanExtV12_Images/')[-1]
168
+ paw_rb = row['Answer.submitValuesRightBack'].split(';')[ind]
169
+ paw_rf = row['Answer.submitValuesRightFront'].split(';')[ind]
170
+ paw_lb = row['Answer.submitValuesLeftBack'].split(';')[ind]
171
+ paw_lf = row['Answer.submitValuesLeftFront'].split(';')[ind]
172
+ addpart = row['Answer.submitValuesAdditional'].split(';')[ind]
173
+ no_contact = row['Answer.submitValuesNoContact'].split(';')[ind]
174
+ unsure = row['Answer.submitValuesNotSure'].split(';')[ind]
175
+ annot = {'paw_rb': paw_rb, 'paw_rf': paw_rf, 'paw_lb': paw_lb, 'paw_lf': paw_lf,
176
+ 'additional_part': addpart, 'no_contact': no_contact, 'not_sure': unsure,
177
+ 'worker_id': worker_id} # , 'input_image': input_image}
178
+ if ind == 0:
179
+ gt = annot_n02088466_3184
180
+ elif ind == 1:
181
+ gt = annot_n02105056_2798
182
+ elif ind == 2:
183
+ gt = annot_n02091831_2288
184
+ elif ind == 3:
185
+ gt = annot_n02100583_9922
186
+ else:
187
+ pass
188
+ if ind < 4:
189
+ for key in gt.keys():
190
+ if str(annot[key]) == str(gt[key]):
191
+ n_correct += 1
192
+ else:
193
+ if display_worker_performance:
194
+ print(input_image)
195
+ print(key + ':[ expected: ' + str(gt[key]) + ' predicted: ' + str(annot[key]) + ' ]')
196
+ n_wrong += 1
197
+ else:
198
+ all_annotations[input_image] = annot
199
+ if display_worker_performance:
200
+ print('n_correct: ' + str(n_correct))
201
+ print('n_wrong: ' + str(n_wrong))
202
+ return all_annotations, all_comments
203
+
204
+
205
+
206
+
207
+
208
+
209
+ def main():
210
+
211
+ ROOT_PATH_MESH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/graphcmr/data/meshes/'
212
+ ROOT_PATH_ANNOT = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage2b/'
213
+ IMG_V12_DIR = '/ps/scratch/nrueegg/new_projects/Animals/data/dog_datasets/Stanford_Dogs_Dataset/StanfordExtra_V12/StanExtV12_Images/'
214
+ # ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/losses_for_vertex_wise_predictions/debugging_results/'
215
+ ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage2b/'
216
+ ROOT_OUT_PATH_VIS = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage2b/vis/'
217
+ ROOT_OUT_PATH_DISTSGCNONGC = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage2b/vertex_distances_gc_nongc/'
218
+ ROOT_PATH_ALL_VERT_DIST_TEMPLATE = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/'
219
+
220
+ # load all vertex distances
221
+ path_mesh = ROOT_PATH_MESH + 'mesh_downsampling_meshesmy_smpl_39dogsnorm_Jr_4_dog_template_downsampled0.obj'
222
+ my_mesh = trimesh.load_mesh(path_mesh, process=False, maintain_order=True)
223
+ verts = my_mesh.vertices
224
+ faces = my_mesh.faces
225
+ # vert_dists, ga = prepare_graph_from_template_mesh_and_calculate_all_distances(path_mesh, ROOT_OUT_PATH, calc_dist_mat=False)
226
+ vert_dists = load_all_template_mesh_distances(ROOT_PATH_ALL_VERT_DIST_TEMPLATE, filename='all_vertex_distances.npy')
227
+
228
+
229
+
230
+
231
+
232
+ # paw vertices:
233
+ # left and right is a bit different, but that is ok (we will anyways mirror data at training time)
234
+ right_front_paw = [3829,+3827,+3825,+3718,+3722,+3723,+3743,+3831,+3719,+3726,+3716,+3724,+3828,+3717,+3721,+3725,+3832,+3830,+3720,+3288,+3740,+3714,+3826,+3715,+3728,+3712,+3287,+3284,+3727,+3285,+3742,+3291,+3710,+3697,+3711,+3289,+3730,+3713,+3739,+3282,+3738,+3708,+3709,+3741,+3698,+3696,+3308,+3695,+3706,+3700,+3707,+3306,+3305,+3737,+3304,+3303,+3307,+3736,+3735,+3250,+3261,+3732,+3734,+3733,+3731,+3729,+3299,+3297,+3298,+3295,+3293,+3296,+3294,+3292,+3312,+3311,+3314,+3309,+3290,+3313,+3410,+3315,+3411,+3412,+3316,+3421,+3317,+3415,+3445,+3327,+3328,+3283,+3343,+3326,+3325,+3330,+3286,+3399,+3398,+3329,+3446,+3400,+3331,+3401,+3281,+3332,+3279,+3402,+3419,+3407,+3356,+3358,+3357,+3280,+3354,+3277,+3278,+3346,+3347,+3377,+3378,+3345,+3386,+3379,+3348,+3384,+3418,+3372,+3276,+3275,+3374,+3274,+3373,+3375,+3369,+3371,+3376,+3273,+3396,+3397,+3395,+3388,+3360,+3370,+3361,+3394,+3387,+3420,+3359,+3389,+3272,+3391,+3393,+3390,+3392,+3363,+3362,+3367,+3365,+3705,+3271,+3704,+3703,+3270,+3269,+3702,+3268,+3224,+3267,+3701,+3225,+3699,+3265,+3264,+3266,+3263,+3262,+3249,+3228,+3230,+3251,+3301,+3300,+3302,+3252]
235
+ right_back_paw = [3472,+3627,+3470,+3469,+3471,+3473,+3626,+3625,+3475,+3655,+3519,+3468,+3629,+3466,+3476,+3624,+3521,+3654,+3657,+3838,+3518,+3653,+3839,+3553,+3474,+3516,+3656,+3628,+3834,+3535,+3630,+3658,+3477,+3520,+3517,+3595,+3522,+3597,+3596,+3501,+3534,+3503,+3478,+3500,+3479,+3502,+3607,+3499,+3608,+3496,+3605,+3609,+3504,+3606,+3642,+3614,+3498,+3480,+3631,+3610,+3613,+3506,+3659,+3660,+3632,+3841,+3661,+3836,+3662,+3633,+3663,+3664,+3634,+3635,+3486,+3665,+3636,+3637,+3666,+3490,+3837,+3667,+3493,+3638,+3492,+3495,+3616,+3644,+3494,+3835,+3643,+3833,+3840,+3615,+3650,+3668,+3652,+3651,+3645,+3646,+3647,+3649,+3648,+3622,+3617,+3448,+3621,+3618,+3623,+3462,+3464,+3460,+3620,+3458,+3461,+3463,+3465,+3573,+3571,+3467,+3569,+3557,+3558,+3572,+3570,+3556,+3585,+3593,+3594,+3459,+3566,+3592,+3567,+3568,+3538,+3539,+3555,+3537,+3536,+3554,+3575,+3574,+3583,+3541,+3550,+3576,+3581,+3639,+3577,+3551,+3582,+3580,+3552,+3578,+3542,+3549,+3579,+3523,+3526,+3598,+3525,+3600,+3640,+3599,+3601,+3602,+3603,+3529,+3604,+3530,+3533,+3532,+3611,+3612,+3482,+3481,+3505,+3452,+3455,+3456,+3454,+3457,+3619,+3451,+3450,+3449,+3591,+3589,+3641,+3584,+3561,+3587,+3559,+3488,+3484,+3483]
236
+ left_front_paw = [1791,+1950,+1948,+1790,+1789,+1746,+1788,+1747,+1949,+1944,+1792,+1945,+1356,+1775,+1759,+1777,+1787,+1946,+1757,+1761,+1745,+1943,+1947,+1744,+1309,+1786,+1771,+1354,+1774,+1765,+1767,+1768,+1772,+1763,+1770,+1773,+1769,+1764,+1766,+1758,+1760,+1762,+1336,+1333,+1330,+1325,+1756,+1323,+1755,+1753,+1749,+1754,+1751,+1321,+1752,+1748,+1750,+1312,+1319,+1315,+1313,+1317,+1318,+1316,+1314,+1311,+1310,+1299,+1276,+1355,+1297,+1353,+1298,+1300,+1352,+1351,+1785,+1784,+1349,+1783,+1782,+1781,+1780,+1779,+1778,+1776,+1343,+1341,+1344,+1339,+1342,+1340,+1360,+1335,+1338,+1362,+1357,+1361,+1363,+1458,+1337,+1459,+1456,+1460,+1493,+1332,+1375,+1376,+1331,+1374,+1378,+1334,+1373,+1494,+1377,+1446,+1448,+1379,+1449,+1329,+1327,+1404,+1406,+1405,+1402,+1328,+1426,+1432,+1434,+1403,+1394,+1395,+1433,+1425,+1286,+1380,+1466,+1431,+1290,+1401,+1381,+1427,+1450,+1393,+1430,+1326,+1396,+1428,+1397,+1429,+1398,+1420,+1324,+1422,+1417,+1419,+1421,+1443,+1418,+1423,+1444,+1442,+1424,+1445,+1495,+1440,+1441,+1468,+1436,+1408,+1322,+1435,+1415,+1439,+1409,+1283,+1438,+1416,+1407,+1437,+1411,+1413,+1414,+1320,+1273,+1272,+1278,+1469,+1463,+1457,+1358,+1464,+1465,+1359,+1372,+1391,+1390,+1455,+1447,+1454,+1467,+1453,+1452,+1451,+1383,+1345,+1347,+1348,+1350,+1364,+1392,+1410,+1412]
237
+ left_back_paw = [1957,+1958,+1701,+1956,+1951,+1703,+1715,+1702,+1700,+1673,+1705,+1952,+1955,+1674,+1699,+1675,+1953,+1704,+1954,+1698,+1677,+1671,+1672,+1714,+1706,+1676,+1519,+1523,+1686,+1713,+1692,+1685,+1543,+1664,+1712,+1691,+1959,+1541,+1684,+1542,+1496,+1663,+1540,+1497,+1499,+1498,+1500,+1693,+1665,+1694,+1716,+1666,+1695,+1501,+1502,+1696,+1667,+1503,+1697,+1504,+1668,+1669,+1506,+1670,+1508,+1510,+1507,+1509,+1511,+1512,+1621,+1606,+1619,+1605,+1513,+1620,+1618,+1604,+1633,+1641,+1642,+1607,+1617,+1514,+1632,+1614,+1689,+1640,+1515,+1586,+1616,+1516,+1517,+1603,+1615,+1639,+1585,+1521,+1602,+1587,+1584,+1601,+1623,+1622,+1631,+1598,+1624,+1629,+1589,+1687,+1625,+1599,+1630,+1569,+1570,+1628,+1626,+1597,+1627,+1590,+1594,+1571,+1568,+1567,+1574,+1646,+1573,+1645,+1648,+1564,+1688,+1647,+1643,+1649,+1650,+1651,+1577,+1644,+1565,+1652,+1566,+1578,+1518,+1524,+1583,+1582,+1520,+1581,+1522,+1525,+1549,+1551,+1580,+1552,+1550,+1656,+1658,+1554,+1657,+1659,+1548,+1655,+1690,+1660,+1556,+1653,+1558,+1661,+1544,+1662,+1654,+1547,+1545,+1527,+1560,+1526,+1678,+1679,+1528,+1708,+1707,+1680,+1529,+1530,+1709,+1546,+1681,+1710,+1711,+1682,+1532,+1531,+1683,+1534,+1533,+1536,+1538,+1600,+1553]
238
+
239
+
240
+
241
+
242
+
243
+ all_keys = []
244
+ gc_dict = {}
245
+ vertex_overview_nocontact = {}
246
+ # data/stanext_related_data/ground_contact_annotations/stage3/main_partA1667_20221021_140108.csv
247
+ for csv_file in ['Stage2b_finalResults.csv']:
248
+ # load all ground contact annotations
249
+ gc_annot_csv = ROOT_PATH_ANNOT + csv_file # 'my_gcannotations_qualification.csv'
250
+ gc_row_list = read_csv(gc_annot_csv)
251
+ all_annotations, all_comments = summarize_results_stage2b(gc_row_list, display_worker_performance=False)
252
+ for key, value in all_annotations.items():
253
+ if value['not_sure'] == '0':
254
+ if value['no_contact'] == '1':
255
+ vertex_overview_nocontact[key.split('.')[0]] = {'gc_vertdists_overview': 'no contact', 'gc_index_list': None}
256
+ else:
257
+ all_contact_vertices = []
258
+ if value['paw_rf'] == '1':
259
+ all_contact_vertices.extend(right_front_paw)
260
+ if value['paw_rb'] == '1':
261
+ all_contact_vertices.extend(right_back_paw)
262
+ if value['paw_lf'] == '1':
263
+ all_contact_vertices.extend(left_front_paw)
264
+ if value['paw_lb'] == '1':
265
+ all_contact_vertices.extend(left_back_paw)
266
+ gc_dict[key] = all_contact_vertices
267
+ print('number of labeled images: ' + str(len(gc_dict.keys())))
268
+ print('number of images without contact: ' + str(len(vertex_overview_nocontact.keys())))
269
+
270
+ # prepare and save contact annotations including distances
271
+ vertex_overview_dict = {}
272
+ for ind_img, name_ingcdict in enumerate(gc_dict.keys()): # range(len(gc_dict.keys())):
273
+ name = name_ingcdict # name_ingcdict.split('bite/')[1]
274
+ # name = images_with_gc_labelled[ind_img]
275
+ print('work on image ' + str(ind_img) + ': ' + name)
276
+ # gc_info_raw = gc_dict['bite/' + name] # a list with all vertex numbers that are in ground contact
277
+ gc_info_raw = gc_dict[name_ingcdict] # a list with all vertex numbers that are in ground contact
278
+
279
+ if not os.path.exists(ROOT_OUT_PATH_VIS + name.split('/')[0]): os.makedirs(ROOT_OUT_PATH_VIS + name.split('/')[0])
280
+ if not os.path.exists(ROOT_OUT_PATH_DISTSGCNONGC + name.split('/')[0]): os.makedirs(ROOT_OUT_PATH_DISTSGCNONGC + name.split('/')[0])
281
+
282
+ vertex_overview = calculate_vertex_overview_for_gc_annotation(name, gc_info_raw, vert_dists, root_out_path_vis=ROOT_OUT_PATH_VIS, verts=verts, faces=faces, img_v12_dir=None)
283
+ np.save(ROOT_OUT_PATH_DISTSGCNONGC + name.replace('.jpg', '_gc_vertdists_overview.npy'), vertex_overview)
284
+
285
+ vertex_overview_dict[name.split('.')[0]] = {'gc_vertdists_overview': vertex_overview, 'gc_index_list': gc_info_raw}
286
+
287
+
288
+
289
+
290
+
291
+ # import pdb; pdb.set_trace()
292
+
293
+ with open(ROOT_OUT_PATH + 'gc_annots_overview_stage2b_contact_complete_xx.pkl', 'wb') as fp:
294
+ pkl.dump(vertex_overview_dict, fp)
295
+
296
+ with open(ROOT_OUT_PATH + 'gc_annots_overview_stage2b_nocontact_complete_xx.pkl', 'wb') as fp:
297
+ pkl.dump(vertex_overview_nocontact, fp)
298
+
299
+
300
+
301
+
302
+
303
+
304
+
305
+
306
+
307
+
308
+
309
+ if __name__ == "__main__":
310
+ main()
311
+
312
+
313
+
314
+
315
+
316
+
317
+