Spaces:
Runtime error
Runtime error
Nadine Rueegg
commited on
Commit
•
753fd9a
1
Parent(s):
45abb23
initial commit with code and data
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- LICENSE +60 -0
- README.md +7 -10
- packages.txt +8 -0
- requirements.txt +15 -0
- scripts/gradio_demo.py +672 -0
- src/__init__.py +0 -0
- src/bps_2d/bps_for_segmentation.py +114 -0
- src/combined_model/__init__.py +0 -0
- src/combined_model/helper.py +207 -0
- src/combined_model/helper3.py +17 -0
- src/combined_model/loss_image_to_3d_refinement.py +216 -0
- src/combined_model/loss_image_to_3d_withbreedrel.py +342 -0
- src/combined_model/loss_utils/loss_arap.py +153 -0
- src/combined_model/loss_utils/loss_laplacian_mesh_comparison.py +45 -0
- src/combined_model/loss_utils/loss_sdf.py +122 -0
- src/combined_model/loss_utils/loss_utils.py +191 -0
- src/combined_model/loss_utils/loss_utils_gc.py +179 -0
- src/combined_model/model_shape_v7_withref_withgraphcnn.py +927 -0
- src/combined_model/train_main_image_to_3d_wbr_withref.py +955 -0
- src/combined_model/train_main_image_to_3d_withbreedrel.py +496 -0
- src/configs/SMAL_configs.py +230 -0
- src/configs/anipose_data_info.py +74 -0
- src/configs/barc_cfg_defaults.py +121 -0
- src/configs/barc_cfg_train.yaml +24 -0
- src/configs/barc_loss_weights_allzeros.json +30 -0
- src/configs/barc_loss_weights_with3dcgloss_higherbetaloss_v2_dm39dnnv3v2.json +30 -0
- src/configs/data_info.py +115 -0
- src/configs/dataset_path_configs.py +21 -0
- src/configs/dog_breeds/dog_breed_class.py +170 -0
- src/configs/refinement_cfg_test_withvertexwisegc_csaddnonflat.yaml +23 -0
- src/configs/refinement_cfg_test_withvertexwisegc_csaddnonflat_crops.yaml +23 -0
- src/configs/refinement_cfg_train_withvertexwisegc_isflat_csmorestanding.yaml +31 -0
- src/configs/refinement_loss_weights_withgc_withvertexwise_addnonflat.json +20 -0
- src/configs/ttopt_loss_weights/bite_loss_weights_ttopt.json +77 -0
- src/configs/ttopt_loss_weights/ttopt_loss_weights_v2c_withlapcft_v2.json +77 -0
- src/graph_networks/__init__.py +0 -0
- src/graph_networks/graphcmr/__init__.py +0 -0
- src/graph_networks/graphcmr/get_downsampled_mesh_npz.py +84 -0
- src/graph_networks/graphcmr/graph_cnn.py +53 -0
- src/graph_networks/graphcmr/graph_cnn_groundcontact.py +101 -0
- src/graph_networks/graphcmr/graph_cnn_groundcontact_multistage.py +174 -0
- src/graph_networks/graphcmr/graph_cnn_groundcontact_multistage_includingresnet.py +170 -0
- src/graph_networks/graphcmr/graph_layers.py +125 -0
- src/graph_networks/graphcmr/graphcnn_coarse_to_fine_animal_pose.py +97 -0
- src/graph_networks/graphcmr/my_remarks.txt +11 -0
- src/graph_networks/graphcmr/pytorch_coma_mesh_operations.py +282 -0
- src/graph_networks/graphcmr/utils_mesh.py +138 -0
- src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh.py +245 -0
- src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh_forfourpaws.py +213 -0
- src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh_forpaws.py +317 -0
LICENSE
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
License
|
2 |
+
Software Copyright License for non-commercial scientific research purposes
|
3 |
+
Please read carefully the following terms and conditions and any accompanying documentation before you download and/or use BITE data, model and software, (the "Data & Software"), including 3D meshes, images, videos, textures, software, scripts, and animations. By downloading and/or using the Data & Software (including downloading, cloning, installing, and any other use of the corresponding github repository), you acknowledge that you have read these terms and conditions, understand them, and agree to be bound by them. If you do not agree with these terms and conditions, you must not download and/or use the Data & Software. Any infringement of the terms of this agreement will automatically terminate your rights under this License
|
4 |
+
|
5 |
+
Ownership / Licensees
|
6 |
+
The Software and the associated materials has been developed at the
|
7 |
+
|
8 |
+
Max Planck Institute for Intelligent Systems
|
9 |
+
and
|
10 |
+
ETH Zurich
|
11 |
+
|
12 |
+
Any copyright or patent right is owned by and proprietary material of the
|
13 |
+
|
14 |
+
Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (hereinafter “MPG”; MPI and MPG hereinafter collectively “Max-Planck”)
|
15 |
+
|
16 |
+
hereinafter the “Licensor”.
|
17 |
+
|
18 |
+
License Grant
|
19 |
+
Licensor grants you (Licensee) personally a single-user, non-exclusive, non-transferable, free of charge right:
|
20 |
+
|
21 |
+
To install the Data & Software on computers owned, leased or otherwise controlled by you and/or your organization;
|
22 |
+
To use the Data & Software for the sole purpose of performing non-commercial scientific research, non-commercial education, or non-commercial artistic projects;
|
23 |
+
Any other use, in particular any use for commercial, pornographic, military, or surveillance, purposes is prohibited. This includes, without limitation, incorporation in a commercial product, use in a commercial service, or production of other artifacts for commercial purposes. The Data & Software may not be used to create fake, libelous, misleading, or defamatory content of any kind excluding analyses in peer-reviewed scientific research. The Data & Software may not be reproduced, modified and/or made available in any form to any third party without Max-Planck’s prior written permission.
|
24 |
+
|
25 |
+
The Data & Software may not be used for pornographic purposes or to generate pornographic material whether commercial or not. This license also prohibits the use of the Software to train methods/algorithms/neural networks/etc. for commercial, pornographic, military, surveillance, or defamatory use of any kind. By downloading the Data & Software, you agree not to reverse engineer it.
|
26 |
+
|
27 |
+
No Distribution
|
28 |
+
The Data & Software and the license herein granted shall not be copied, shared, distributed, re-sold, offered for re-sale, transferred or sub-licensed in whole or in part except that you may make one copy for archive purposes only.
|
29 |
+
|
30 |
+
Disclaimer of Representations and Warranties
|
31 |
+
You expressly acknowledge and agree that the Data & Software results from basic research, is provided “AS IS”, may contain errors, and that any use of the Data & Software is at your sole risk. LICENSOR MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE DATA & SOFTWARE, NEITHER EXPRESS NOR IMPLIED, AND THE ABSENCE OF ANY LEGAL OR ACTUAL DEFECTS, WHETHER DISCOVERABLE OR NOT. Specifically, and not to limit the foregoing, licensor makes no representations or warranties (i) regarding the merchantability or fitness for a particular purpose of the Data & Software, (ii) that the use of the Data & Software will not infringe any patents, copyrights or other intellectual property rights of a third party, and (iii) that the use of the Data & Software will not cause any damage of any kind to you or a third party.
|
32 |
+
|
33 |
+
Limitation of Liability
|
34 |
+
Because this Data & Software License Agreement qualifies as a donation, according to Section 521 of the German Civil Code (Bürgerliches Gesetzbuch – BGB) Licensor as a donor is liable for intent and gross negligence only. If the Licensor fraudulently conceals a legal or material defect, they are obliged to compensate the Licensee for the resulting damage.
|
35 |
+
Licensor shall be liable for loss of data only up to the amount of typical recovery costs which would have arisen had proper and regular data backup measures been taken. For the avoidance of doubt Licensor shall be liable in accordance with the German Product Liability Act in the event of product liability. The foregoing applies also to Licensor’s legal representatives or assistants in performance. Any further liability shall be excluded.
|
36 |
+
Patent claims generated through the usage of the Data & Software cannot be directed towards the copyright holders.
|
37 |
+
The Data & Software is provided in the state of development the licensor defines. If modified or extended by Licensee, the Licensor makes no claims about the fitness of the Data & Software and is not responsible for any problems such modifications cause.
|
38 |
+
|
39 |
+
No Maintenance Services
|
40 |
+
You understand and agree that Licensor is under no obligation to provide either maintenance services, update services, notices of latent defects, or corrections of defects with regard to the Data & Software. Licensor nevertheless reserves the right to update, modify, or discontinue the Data & Software at any time.
|
41 |
+
|
42 |
+
Defects of the Data & Software must be notified in writing to the Licensor with a comprehensible description of the error symptoms. The notification of the defect should enable the reproduction of the error. The Licensee is encouraged to communicate any use, results, modification or publication.
|
43 |
+
|
44 |
+
Publications using the Data & Software
|
45 |
+
You acknowledge that the Data & Software is a valuable scientific resource and agree to appropriately reference the following paper in any publication making use of the Data & Software.
|
46 |
+
|
47 |
+
Citation:
|
48 |
+
|
49 |
+
|
50 |
+
@inproceedings{BITE:2023,
|
51 |
+
title = {BITE: Beyond priors for Improved Three-D dog pose Estimation},
|
52 |
+
author = {Rueegg, Nadine and Tripathi, Shashank and Schindler, Konrad and Black, Michael J. and Zuffi, Silvia},
|
53 |
+
booktitle = {under review},
|
54 |
+
year = {2023}
|
55 |
+
url = {https://bite.is.tue.mpg.de}
|
56 |
+
}
|
57 |
+
Commercial licensing opportunities
|
58 |
+
For commercial uses of the Data & Software, please send email to [email protected]
|
59 |
+
|
60 |
+
This Agreement shall be governed by the laws of the Federal Republic of Germany except for the UN Sales Convention.
|
README.md
CHANGED
@@ -1,12 +1,9 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
colorTo: pink
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 3.
|
8 |
-
app_file:
|
9 |
pinned: false
|
10 |
-
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
+
title: BITE
|
2 |
+
emoji: 🐩 🐶 🐕
|
3 |
+
colorFrom: pink
|
4 |
+
colorTo: green
|
|
|
5 |
sdk: gradio
|
6 |
+
sdk_version: 3.0.2
|
7 |
+
app_file: ./scripts/gradio_demo.py
|
8 |
pinned: false
|
9 |
+
python_version: 3.7.6
|
|
|
|
packages.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
libgl1
|
2 |
+
unzip
|
3 |
+
ffmpeg
|
4 |
+
libsm6
|
5 |
+
libxext6
|
6 |
+
libgl1-mesa-dri
|
7 |
+
libegl1-mesa
|
8 |
+
libgbm1
|
requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==1.6.0
|
2 |
+
torchvision==0.7.0
|
3 |
+
pytorch3d==0.2.5
|
4 |
+
kornia==0.4.0
|
5 |
+
matplotlib
|
6 |
+
opencv-python
|
7 |
+
trimesh
|
8 |
+
scipy
|
9 |
+
chumpy
|
10 |
+
pymp
|
11 |
+
importlib-resources
|
12 |
+
pycocotools
|
13 |
+
openpyxl
|
14 |
+
dominate
|
15 |
+
git+https://github.com/runa91/FrEIA.git
|
scripts/gradio_demo.py
ADDED
@@ -0,0 +1,672 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# aenv_new_icon_2
|
3 |
+
|
4 |
+
# was used for ttoptv6_sketchfab_v16: python src/test_time_optimization/ttopt_fromref_v6_sketchfab.py --workers 12 --save-images True --config refinement_cfg_visualization_withgc_withvertexwisegc_isflat.yaml --model-file-complete=cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0/checkpoint.pth.tar --sketchfab 1
|
5 |
+
|
6 |
+
# for stanext images:
|
7 |
+
# python scripts/gradio.py --workers 12 --config refinement_cfg_test_withvertexwisegc_csaddnonflat.yaml --model-file-complete=cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0/checkpoint.pth.tar -s ttopt_vtest1
|
8 |
+
# for all images from the folder datasets/test_image_crops:
|
9 |
+
# python scripts/gradio.py --workers 12 --config refinement_cfg_test_withvertexwisegc_csaddnonflat_crops.yaml --model-file-complete=cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0/checkpoint.pth.tar -s ttopt_vtest2
|
10 |
+
|
11 |
+
'''import os
|
12 |
+
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
|
13 |
+
os.environ["CUDA_VISIBLE_DEVICES"]="0"
|
14 |
+
try:
|
15 |
+
# os.system("pip install --upgrade torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html")
|
16 |
+
os.system("pip install --upgrade torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/cu101/torch_stable.html")
|
17 |
+
except Exception as e:
|
18 |
+
print(e)'''
|
19 |
+
|
20 |
+
import argparse
|
21 |
+
import os.path
|
22 |
+
import json
|
23 |
+
import numpy as np
|
24 |
+
import pickle as pkl
|
25 |
+
import csv
|
26 |
+
from distutils.util import strtobool
|
27 |
+
import torch
|
28 |
+
from torch import nn
|
29 |
+
import torch.backends.cudnn
|
30 |
+
from torch.nn import DataParallel
|
31 |
+
from torch.utils.data import DataLoader
|
32 |
+
from collections import OrderedDict
|
33 |
+
import glob
|
34 |
+
from tqdm import tqdm
|
35 |
+
from dominate import document
|
36 |
+
from dominate.tags import *
|
37 |
+
from PIL import Image
|
38 |
+
from matplotlib import pyplot as plt
|
39 |
+
import trimesh
|
40 |
+
import cv2
|
41 |
+
import shutil
|
42 |
+
import random
|
43 |
+
import gradio as gr
|
44 |
+
|
45 |
+
import torchvision
|
46 |
+
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
|
47 |
+
import torchvision.transforms as T
|
48 |
+
from pytorch3d.structures import Meshes
|
49 |
+
from pytorch3d.loss import mesh_edge_loss, mesh_laplacian_smoothing, mesh_normal_consistency
|
50 |
+
|
51 |
+
|
52 |
+
import sys
|
53 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'src'))
|
54 |
+
|
55 |
+
from combined_model.train_main_image_to_3d_wbr_withref import do_validation_epoch
|
56 |
+
from combined_model.model_shape_v7_withref_withgraphcnn import ModelImageTo3d_withshape_withproj
|
57 |
+
|
58 |
+
from configs.barc_cfg_defaults import get_cfg_defaults, update_cfg_global_with_yaml, get_cfg_global_updated
|
59 |
+
|
60 |
+
from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, rotmat_to_rot6d
|
61 |
+
from stacked_hourglass.datasets.utils_dataset_selection import get_evaluation_dataset, get_sketchfab_evaluation_dataset, get_crop_evaluation_dataset, get_norm_dict, get_single_crop_dataset_from_image
|
62 |
+
|
63 |
+
from test_time_optimization.bite_inference_model_for_ttopt import BITEInferenceModel
|
64 |
+
from smal_pytorch.smal_model.smal_torch_new import SMAL
|
65 |
+
from configs.SMAL_configs import SMAL_MODEL_CONFIG
|
66 |
+
from smal_pytorch.renderer.differentiable_renderer import SilhRenderer
|
67 |
+
from test_time_optimization.utils.utils_ttopt import reset_loss_values, get_optimed_pose_with_glob
|
68 |
+
|
69 |
+
from combined_model.loss_utils.loss_utils import leg_sideway_error, leg_torsion_error, tail_sideway_error, tail_torsion_error, spine_torsion_error, spine_sideway_error
|
70 |
+
from combined_model.loss_utils.loss_utils_gc import LossGConMesh, calculate_plane_errors_batch
|
71 |
+
from combined_model.loss_utils.loss_arap import Arap_Loss
|
72 |
+
from combined_model.loss_utils.loss_laplacian_mesh_comparison import LaplacianCTF # (coarse to fine animal)
|
73 |
+
from graph_networks import graphcmr # .utils_mesh import Mesh
|
74 |
+
from stacked_hourglass.utils.visualization import save_input_image_with_keypoints, save_input_image
|
75 |
+
|
76 |
+
random.seed(0)
|
77 |
+
|
78 |
+
print(
|
79 |
+
"torch: ", torch.__version__,
|
80 |
+
"\ntorchvision: ", torchvision.__version__,
|
81 |
+
)
|
82 |
+
|
83 |
+
|
84 |
+
def get_prediction(model, img_path_or_img, confidence=0.5):
|
85 |
+
"""
|
86 |
+
see https://haochen23.github.io/2020/04/object-detection-faster-rcnn.html#.YsMCm4TP3-g
|
87 |
+
get_prediction
|
88 |
+
parameters:
|
89 |
+
- img_path - path of the input image
|
90 |
+
- confidence - threshold value for prediction score
|
91 |
+
method:
|
92 |
+
- Image is obtained from the image path
|
93 |
+
- the image is converted to image tensor using PyTorch's Transforms
|
94 |
+
- image is passed through the model to get the predictions
|
95 |
+
- class, box coordinates are obtained, but only prediction score > threshold
|
96 |
+
are chosen.
|
97 |
+
"""
|
98 |
+
if isinstance(img_path_or_img, str):
|
99 |
+
img = Image.open(img_path_or_img).convert('RGB')
|
100 |
+
else:
|
101 |
+
img = img_path_or_img
|
102 |
+
transform = T.Compose([T.ToTensor()])
|
103 |
+
img = transform(img)
|
104 |
+
pred = model([img])
|
105 |
+
# pred_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in list(pred[0]['labels'].numpy())]
|
106 |
+
pred_class = list(pred[0]['labels'].numpy())
|
107 |
+
pred_boxes = [[(int(i[0]), int(i[1])), (int(i[2]), int(i[3]))] for i in list(pred[0]['boxes'].detach().numpy())]
|
108 |
+
pred_score = list(pred[0]['scores'].detach().numpy())
|
109 |
+
try:
|
110 |
+
pred_t = [pred_score.index(x) for x in pred_score if x>confidence][-1]
|
111 |
+
pred_boxes = pred_boxes[:pred_t+1]
|
112 |
+
pred_class = pred_class[:pred_t+1]
|
113 |
+
return pred_boxes, pred_class, pred_score
|
114 |
+
except:
|
115 |
+
print('no bounding box with a score that is high enough found! -> work on full image')
|
116 |
+
return None, None, None
|
117 |
+
|
118 |
+
|
119 |
+
def detect_object(model, img_path_or_img, confidence=0.5, rect_th=2, text_size=0.5, text_th=1):
|
120 |
+
"""
|
121 |
+
see https://haochen23.github.io/2020/04/object-detection-faster-rcnn.html#.YsMCm4TP3-g
|
122 |
+
object_detection_api
|
123 |
+
parameters:
|
124 |
+
- img_path_or_img - path of the input image
|
125 |
+
- confidence - threshold value for prediction score
|
126 |
+
- rect_th - thickness of bounding box
|
127 |
+
- text_size - size of the class label text
|
128 |
+
- text_th - thichness of the text
|
129 |
+
method:
|
130 |
+
- prediction is obtained from get_prediction method
|
131 |
+
- for each prediction, bounding box is drawn and text is written
|
132 |
+
with opencv
|
133 |
+
- the final image is displayed
|
134 |
+
"""
|
135 |
+
boxes, pred_cls, pred_scores = get_prediction(model, img_path_or_img, confidence)
|
136 |
+
if isinstance(img_path_or_img, str):
|
137 |
+
img = cv2.imread(img_path_or_img)
|
138 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
139 |
+
else:
|
140 |
+
img = img_path_or_img
|
141 |
+
is_first = True
|
142 |
+
bbox = None
|
143 |
+
if boxes is not None:
|
144 |
+
for i in range(len(boxes)):
|
145 |
+
cls = pred_cls[i]
|
146 |
+
if cls == 18 and bbox is None:
|
147 |
+
cv2.rectangle(img, boxes[i][0], boxes[i][1],color=(0, 255, 0), thickness=rect_th)
|
148 |
+
# cv2.putText(img, pred_cls[i], boxes[i][0], cv2.FONT_HERSHEY_SIMPLEX, text_size, (0,255,0),thickness=text_th)
|
149 |
+
# cv2.putText(img, str(pred_scores[i]), boxes[i][0], cv2.FONT_HERSHEY_SIMPLEX, text_size, (0,255,0),thickness=text_th)
|
150 |
+
bbox = boxes[i]
|
151 |
+
return img, bbox
|
152 |
+
|
153 |
+
|
154 |
+
# -------------------------------------------------------------------------------------------------------------------- #
|
155 |
+
model_bbox = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
|
156 |
+
model_bbox.eval()
|
157 |
+
|
158 |
+
def run_bbox_inference(input_image):
|
159 |
+
# load configs
|
160 |
+
cfg = get_cfg_global_updated()
|
161 |
+
out_path = os.path.join(cfg.paths.ROOT_OUT_PATH, 'gradio_examples', 'test2.png')
|
162 |
+
img, bbox = detect_object(model=model_bbox, img_path_or_img=input_image, confidence=0.5)
|
163 |
+
fig = plt.figure() # plt.figure(figsize=(20,30))
|
164 |
+
plt.imsave(out_path, img)
|
165 |
+
return img, bbox
|
166 |
+
|
167 |
+
|
168 |
+
|
169 |
+
# -------------------------------------------------------------------------------------------------------------------- #
|
170 |
+
# python scripts/gradio.py --workers 12 --config refinement_cfg_test_withvertexwisegc_csaddnonflat.yaml --model-file-complete=cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0/checkpoint.pth.tar
|
171 |
+
args_config = "refinement_cfg_test_withvertexwisegc_csaddnonflat.yaml"
|
172 |
+
args_model_file_complete = "cvpr23_dm39dnnv3barcv2b_refwithgcpervertisflat0morestanding0/checkpoint.pth.tar"
|
173 |
+
args_suffix = "ttopt_v0"
|
174 |
+
args_loss_weight_ttopt_path = "bite_loss_weights_ttopt.json"
|
175 |
+
args_workers = 12
|
176 |
+
# -------------------------------------------------------------------------------------------------------------------- #
|
177 |
+
|
178 |
+
|
179 |
+
|
180 |
+
# load configs
|
181 |
+
# step 1: load default configs
|
182 |
+
# step 2: load updates from .yaml file
|
183 |
+
path_config = os.path.join(get_cfg_defaults().barc_dir, 'src', 'configs', args_config)
|
184 |
+
update_cfg_global_with_yaml(path_config)
|
185 |
+
cfg = get_cfg_global_updated()
|
186 |
+
|
187 |
+
# define path to load the trained model
|
188 |
+
path_model_file_complete = os.path.join(cfg.paths.ROOT_CHECKPOINT_PATH, args_model_file_complete)
|
189 |
+
|
190 |
+
# define and create paths to save results
|
191 |
+
out_sub_name = cfg.data.VAL_OPT + '_' + cfg.data.DATASET + '_' + args_suffix + '/'
|
192 |
+
root_out_path = os.path.join(os.path.dirname(path_model_file_complete).replace(cfg.paths.ROOT_CHECKPOINT_PATH, cfg.paths.ROOT_OUT_PATH + 'results_gradio/'), out_sub_name)
|
193 |
+
root_out_path_details = root_out_path + 'details/'
|
194 |
+
if not os.path.exists(root_out_path): os.makedirs(root_out_path)
|
195 |
+
if not os.path.exists(root_out_path_details): os.makedirs(root_out_path_details)
|
196 |
+
print('root_out_path: ' + root_out_path)
|
197 |
+
|
198 |
+
# other paths
|
199 |
+
root_data_path = os.path.join(os.path.dirname(__file__), '../', 'data')
|
200 |
+
# downsampling as used in graph neural network
|
201 |
+
root_smal_downsampling = os.path.join(root_data_path, 'graphcmr_data')
|
202 |
+
# remeshing as used for ground contact
|
203 |
+
remeshing_path = os.path.join(root_data_path, 'smal_data_remeshed', 'uniform_surface_sampling', 'my_smpl_39dogsnorm_Jr_4_dog_remesh4000_info.pkl')
|
204 |
+
|
205 |
+
loss_weight_path = os.path.join(os.path.dirname(__file__), '../', 'src', 'configs', 'ttopt_loss_weights', args_loss_weight_ttopt_path)
|
206 |
+
print(loss_weight_path)
|
207 |
+
|
208 |
+
|
209 |
+
# Select the hardware device to use for training.
|
210 |
+
if torch.cuda.is_available() and cfg.device=='cuda':
|
211 |
+
device = torch.device('cuda', torch.cuda.current_device())
|
212 |
+
torch.backends.cudnn.benchmark = False # True
|
213 |
+
else:
|
214 |
+
device = torch.device('cpu')
|
215 |
+
|
216 |
+
print('structure_pose_net: ' + cfg.params.STRUCTURE_POSE_NET)
|
217 |
+
print('refinement network type: ' + cfg.params.REF_NET_TYPE)
|
218 |
+
print('smal_model_type: ' + cfg.smal.SMAL_MODEL_TYPE)
|
219 |
+
|
220 |
+
# prepare complete model
|
221 |
+
norm_dict = get_norm_dict(data_info=None, device=device)
|
222 |
+
bite_model = BITEInferenceModel(cfg, path_model_file_complete, norm_dict)
|
223 |
+
smal_model_type = bite_model.smal_model_type
|
224 |
+
logscale_part_list = SMAL_MODEL_CONFIG[smal_model_type]['logscale_part_list'] # ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l']
|
225 |
+
smal = SMAL(smal_model_type=smal_model_type, template_name='neutral', logscale_part_list=logscale_part_list).to(device)
|
226 |
+
silh_renderer = SilhRenderer(image_size=256).to(device)
|
227 |
+
|
228 |
+
# load loss modules -> not necessary!
|
229 |
+
# loss_module = Loss(smal_model_type=cfg.smal.SMAL_MODEL_TYPE, data_info=StanExt.DATA_INFO, nf_version=cfg.params.NF_VERSION).to(device)
|
230 |
+
# loss_module_ref = LossRef(smal_model_type=cfg.smal.SMAL_MODEL_TYPE, data_info=StanExt.DATA_INFO, nf_version=cfg.params.NF_VERSION).to(device)
|
231 |
+
|
232 |
+
# remeshing utils
|
233 |
+
with open(remeshing_path, 'rb') as fp:
|
234 |
+
remeshing_dict = pkl.load(fp)
|
235 |
+
remeshing_relevant_faces = torch.tensor(remeshing_dict['smal_faces'][remeshing_dict['faceid_closest']], dtype=torch.long, device=device)
|
236 |
+
remeshing_relevant_barys = torch.tensor(remeshing_dict['barys_closest'], dtype=torch.float32, device=device)
|
237 |
+
|
238 |
+
|
239 |
+
|
240 |
+
|
241 |
+
# create path for output files
|
242 |
+
save_imgs_path = os.path.join(cfg.paths.ROOT_OUT_PATH, 'gradio_examples')
|
243 |
+
if not os.path.exists(save_imgs_path):
|
244 |
+
os.makedirs(save_imgs_path)
|
245 |
+
|
246 |
+
|
247 |
+
|
248 |
+
|
249 |
+
|
250 |
+
def run_bite_inference(input_image, bbox=None):
|
251 |
+
|
252 |
+
with open(loss_weight_path, 'r') as j:
|
253 |
+
losses = json.loads(j.read())
|
254 |
+
shutil.copyfile(loss_weight_path, root_out_path_details + os.path.basename(loss_weight_path))
|
255 |
+
print(losses)
|
256 |
+
|
257 |
+
# prepare dataset and dataset loader
|
258 |
+
val_dataset, val_loader, len_val_dataset, test_name_list, stanext_data_info, stanext_acc_joints = get_single_crop_dataset_from_image(input_image, bbox=bbox)
|
259 |
+
|
260 |
+
# summarize information for normalization
|
261 |
+
norm_dict = get_norm_dict(stanext_data_info, device)
|
262 |
+
# get keypoint weights
|
263 |
+
keypoint_weights = torch.tensor(stanext_data_info.keypoint_weights, dtype=torch.float)[None, :].to(device)
|
264 |
+
|
265 |
+
|
266 |
+
# prepare progress bar
|
267 |
+
iterable = enumerate(val_loader) # the length of this iterator should be 1
|
268 |
+
progress = None
|
269 |
+
if True: # not quiet:
|
270 |
+
progress = tqdm(iterable, desc='Train', total=len(val_loader), ascii=True, leave=False)
|
271 |
+
iterable = progress
|
272 |
+
ind_img_tot = 0
|
273 |
+
|
274 |
+
for i, (input, target_dict) in iterable:
|
275 |
+
batch_size = input.shape[0]
|
276 |
+
# prepare variables, put them on the right device
|
277 |
+
for key in target_dict.keys():
|
278 |
+
if key == 'breed_index':
|
279 |
+
target_dict[key] = target_dict[key].long().to(device)
|
280 |
+
elif key in ['index', 'pts', 'tpts', 'target_weight', 'silh', 'silh_distmat_tofg', 'silh_distmat_tobg', 'sim_breed_index', 'img_border_mask']:
|
281 |
+
target_dict[key] = target_dict[key].float().to(device)
|
282 |
+
elif key == 'has_seg':
|
283 |
+
target_dict[key] = target_dict[key].to(device)
|
284 |
+
else:
|
285 |
+
pass
|
286 |
+
input = input.float().to(device)
|
287 |
+
|
288 |
+
# get starting values for the optimization
|
289 |
+
preds_dict = bite_model.get_all_results(input)
|
290 |
+
# res_normal_and_ref = bite_model.get_selected_results(preds_dict=preds_dict, result_networks=['normal', 'ref'])
|
291 |
+
res = bite_model.get_selected_results(preds_dict=preds_dict, result_networks=['ref'])['ref']
|
292 |
+
bs = res['pose_rotmat'].shape[0]
|
293 |
+
all_pose_6d = rotmat_to_rot6d(res['pose_rotmat'][:, None, 1:, :, :].clone().reshape((-1, 3, 3))).reshape((bs, -1, 6)) # [bs, 34, 6]
|
294 |
+
all_orient_6d = rotmat_to_rot6d(res['pose_rotmat'][:, None, :1, :, :].clone().reshape((-1, 3, 3))).reshape((bs, -1, 6)) # [bs, 1, 6]
|
295 |
+
|
296 |
+
|
297 |
+
ind_img = 0
|
298 |
+
name = (test_name_list[target_dict['index'][ind_img].long()]).replace('/', '__').split('.')[0]
|
299 |
+
|
300 |
+
print('ind_img_tot: ' + str(ind_img_tot) + ' -> ' + name)
|
301 |
+
ind_img_tot += 1
|
302 |
+
batch_size = 1
|
303 |
+
|
304 |
+
# save initial visualizations
|
305 |
+
# save the image with keypoints as predicted by the stacked hourglass
|
306 |
+
pred_unp_prep = torch.cat((res['hg_keyp_256'][ind_img, :, :].detach(), res['hg_keyp_scores'][ind_img, :, :]), 1)
|
307 |
+
inp_img = input[ind_img, :, :, :].detach().clone()
|
308 |
+
out_path = root_out_path + name + '_hg_key.png'
|
309 |
+
save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.01, print_scores=True, ratio_in_out=1.0) # threshold=0.3
|
310 |
+
# save the input image
|
311 |
+
img_inp = input[ind_img, :, :, :].clone()
|
312 |
+
for t, m, s in zip(img_inp, stanext_data_info.rgb_mean, stanext_data_info.rgb_stddev): t.add_(m) # inverse to transforms.color_normalize()
|
313 |
+
img_inp = img_inp.detach().cpu().numpy().transpose(1, 2, 0)
|
314 |
+
img_init = Image.fromarray(np.uint8(255*img_inp)).convert('RGB')
|
315 |
+
img_init.save(root_out_path_details + name + '_img_ainit.png')
|
316 |
+
# save ground truth silhouette (for visualization only, it is not used during the optimization)
|
317 |
+
target_img_silh = Image.fromarray(np.uint8(255*target_dict['silh'][ind_img, :, :].detach().cpu().numpy())).convert('RGB')
|
318 |
+
target_img_silh.save(root_out_path_details + name + '_target_silh.png')
|
319 |
+
# save the silhouette as predicted by the stacked hourglass
|
320 |
+
hg_img_silh = Image.fromarray(np.uint8(255*res['hg_silh_prep'][ind_img, :, :].detach().cpu().numpy())).convert('RGB')
|
321 |
+
hg_img_silh.save(root_out_path + name + '_hg_silh.png')
|
322 |
+
|
323 |
+
# initialize the variables over which we want to optimize
|
324 |
+
optimed_pose_6d = all_pose_6d[ind_img, None, :, :].to(device).clone().detach().requires_grad_(True)
|
325 |
+
optimed_orient_6d = all_orient_6d[ind_img, None, :, :].to(device).clone().detach().requires_grad_(True) # [1, 1, 6]
|
326 |
+
optimed_betas = res['betas'][ind_img, None, :].to(device).clone().detach().requires_grad_(True) # [1,30]
|
327 |
+
optimed_trans_xy = res['trans'][ind_img, None, :2].to(device).clone().detach().requires_grad_(True)
|
328 |
+
optimed_trans_z =res['trans'][ind_img, None, 2:3].to(device).clone().detach().requires_grad_(True)
|
329 |
+
optimed_camera_flength = res['flength'][ind_img, None, :].to(device).clone().detach().requires_grad_(True) # [1,1]
|
330 |
+
n_vert_comp = 2*smal.n_center + 3*smal.n_left
|
331 |
+
optimed_vert_off_compact = torch.tensor(np.zeros((batch_size, n_vert_comp)), dtype=torch.float,
|
332 |
+
device=device,
|
333 |
+
requires_grad=True)
|
334 |
+
assert len(logscale_part_list) == 7
|
335 |
+
new_betas_limb_lengths = res['betas_limbs'][ind_img, None, :]
|
336 |
+
optimed_betas_limbs = new_betas_limb_lengths.to(device).clone().detach().requires_grad_(True) # [1,7]
|
337 |
+
|
338 |
+
# define the optimizers
|
339 |
+
optimizer = torch.optim.SGD(
|
340 |
+
# [optimed_pose, optimed_trans_xy, optimed_betas, optimed_betas_limbs, optimed_orient, optimed_vert_off_compact],
|
341 |
+
[optimed_camera_flength, optimed_trans_z, optimed_trans_xy, optimed_pose_6d, optimed_orient_6d, optimed_betas, optimed_betas_limbs],
|
342 |
+
lr=5*1e-4, # 1e-3,
|
343 |
+
momentum=0.9)
|
344 |
+
optimizer_vshift = torch.optim.SGD(
|
345 |
+
[optimed_camera_flength, optimed_trans_z, optimed_trans_xy, optimed_pose_6d, optimed_orient_6d, optimed_betas, optimed_betas_limbs, optimed_vert_off_compact],
|
346 |
+
lr=1e-4, # 1e-4,
|
347 |
+
momentum=0.9)
|
348 |
+
nopose_optimizer = torch.optim.SGD(
|
349 |
+
# [optimed_pose, optimed_trans_xy, optimed_betas, optimed_betas_limbs, optimed_orient, optimed_vert_off_compact],
|
350 |
+
[optimed_camera_flength, optimed_trans_z, optimed_trans_xy, optimed_orient_6d, optimed_betas, optimed_betas_limbs],
|
351 |
+
lr=5*1e-4, # 1e-3,
|
352 |
+
momentum=0.9)
|
353 |
+
nopose_optimizer_vshift = torch.optim.SGD(
|
354 |
+
[optimed_camera_flength, optimed_trans_z, optimed_trans_xy, optimed_orient_6d, optimed_betas, optimed_betas_limbs, optimed_vert_off_compact],
|
355 |
+
lr=1e-4, # 1e-4,
|
356 |
+
momentum=0.9)
|
357 |
+
# define schedulers
|
358 |
+
patience = 5
|
359 |
+
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
360 |
+
optimizer,
|
361 |
+
mode='min',
|
362 |
+
factor=0.5,
|
363 |
+
verbose=0,
|
364 |
+
min_lr=1e-5,
|
365 |
+
patience=patience)
|
366 |
+
scheduler_vshift = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
367 |
+
optimizer_vshift,
|
368 |
+
mode='min',
|
369 |
+
factor=0.5,
|
370 |
+
verbose=0,
|
371 |
+
min_lr=1e-5,
|
372 |
+
patience=patience)
|
373 |
+
|
374 |
+
# set all loss values to 0
|
375 |
+
losses = reset_loss_values(losses)
|
376 |
+
|
377 |
+
# prepare all the target labels: keypoints, silhouette, ground contact, ...
|
378 |
+
with torch.no_grad():
|
379 |
+
thr_kp = 0.2
|
380 |
+
kp_weights = res['hg_keyp_scores']
|
381 |
+
kp_weights[res['hg_keyp_scores']<thr_kp] = 0
|
382 |
+
weights_resh = kp_weights[ind_img, None, :, :].reshape((-1)) # target_dict['tpts'][:, :, 2].reshape((-1))
|
383 |
+
keyp_w_resh = keypoint_weights.repeat((batch_size, 1)).reshape((-1))
|
384 |
+
# prepare predicted ground contact labels
|
385 |
+
sm = nn.Softmax(dim=1)
|
386 |
+
target_gc_class = sm(res['vertexwise_ground_contact'][ind_img, :, :])[None, :, 1] # values between 0 and 1
|
387 |
+
target_gc_class_remeshed = torch.einsum('ij,aij->ai', remeshing_relevant_barys, target_gc_class[:, remeshing_relevant_faces].to(device=device, dtype=torch.float32))
|
388 |
+
target_gc_class_remeshed_prep = torch.round(target_gc_class_remeshed).to(torch.long)
|
389 |
+
vert_colors = np.repeat(255*target_gc_class.detach().cpu().numpy()[0, :, None], 3, 1)
|
390 |
+
vert_colors[:, 2] = 255
|
391 |
+
faces_prep = smal.faces.unsqueeze(0).expand((batch_size, -1, -1))
|
392 |
+
# prepare target silhouette and keypoints, from stacked hourglass predictions
|
393 |
+
target_hg_silh = res['hg_silh_prep'][ind_img, :, :].detach()
|
394 |
+
target_kp_resh = res['hg_keyp_256'][ind_img, None, :, :].reshape((-1, 2)).detach()
|
395 |
+
# find out if ground contact constraints should be used for the image at hand
|
396 |
+
# print('is flat: ' + str(res['isflat_prep'][ind_img]))
|
397 |
+
if res['isflat_prep'][ind_img] >= 0.5: # threshold should probably be set higher
|
398 |
+
isflat = [True]
|
399 |
+
else:
|
400 |
+
isflat = [False]
|
401 |
+
if target_gc_class_remeshed_prep.sum() > 3:
|
402 |
+
istouching = [True]
|
403 |
+
else:
|
404 |
+
istouching = [False]
|
405 |
+
ignore_pose_optimization = False
|
406 |
+
|
407 |
+
|
408 |
+
##########################################################################################################
|
409 |
+
# start optimizing for this image
|
410 |
+
n_iter = 301 # how many iterations are desired? (+1)
|
411 |
+
loop = tqdm(range(n_iter))
|
412 |
+
per_loop_lst = []
|
413 |
+
list_error_procrustes = []
|
414 |
+
for i in loop:
|
415 |
+
# for the first 150 iterations steps we don't allow vertex shifts
|
416 |
+
if i == 0:
|
417 |
+
current_i = 0
|
418 |
+
if ignore_pose_optimization:
|
419 |
+
current_optimizer = nopose_optimizer
|
420 |
+
else:
|
421 |
+
current_optimizer = optimizer
|
422 |
+
current_scheduler = scheduler
|
423 |
+
current_weight_name = 'weight'
|
424 |
+
# after 150 iteration steps we start with vertex shifts
|
425 |
+
elif i == 150:
|
426 |
+
current_i = 0
|
427 |
+
if ignore_pose_optimization:
|
428 |
+
current_optimizer = nopose_optimizer_vshift
|
429 |
+
else:
|
430 |
+
current_optimizer = optimizer_vshift
|
431 |
+
current_scheduler = scheduler_vshift
|
432 |
+
current_weight_name = 'weight_vshift'
|
433 |
+
# set up arap loss
|
434 |
+
if losses["arap"]['weight_vshift'] > 0.0:
|
435 |
+
with torch.no_grad():
|
436 |
+
torch_mesh_comparison = Meshes(smal_verts.detach(), faces_prep.detach())
|
437 |
+
arap_loss = Arap_Loss(meshes=torch_mesh_comparison, device=device)
|
438 |
+
# is there a laplacian loss similar as in coarse-to-fine?
|
439 |
+
if losses["lapctf"]['weight_vshift'] > 0.0:
|
440 |
+
torch_verts_comparison = smal_verts.detach().clone()
|
441 |
+
smal_model_type_downsampling = '39dogs_norm'
|
442 |
+
smal_downsampling_npz_name = 'mesh_downsampling_' + os.path.basename(SMAL_MODEL_CONFIG[smal_model_type_downsampling]['smal_model_path']).replace('.pkl', '_template.npz')
|
443 |
+
smal_downsampling_npz_path = os.path.join(root_smal_downsampling, smal_downsampling_npz_name)
|
444 |
+
data = np.load(smal_downsampling_npz_path, encoding='latin1', allow_pickle=True)
|
445 |
+
adjmat = data['A'][0]
|
446 |
+
laplacian_ctf = LaplacianCTF(adjmat, device=device)
|
447 |
+
else:
|
448 |
+
pass
|
449 |
+
|
450 |
+
|
451 |
+
current_optimizer.zero_grad()
|
452 |
+
|
453 |
+
# get 3d smal model
|
454 |
+
optimed_pose_with_glob = get_optimed_pose_with_glob(optimed_orient_6d, optimed_pose_6d)
|
455 |
+
optimed_trans = torch.cat((optimed_trans_xy, optimed_trans_z), dim=1)
|
456 |
+
smal_verts, keyp_3d, _ = smal(beta=optimed_betas, betas_limbs=optimed_betas_limbs, pose=optimed_pose_with_glob, vert_off_compact=optimed_vert_off_compact, trans=optimed_trans, keyp_conf='olive', get_skin=True)
|
457 |
+
|
458 |
+
# render silhouette and keypoints
|
459 |
+
pred_silh_images, pred_keyp_raw = silh_renderer(vertices=smal_verts, points=keyp_3d, faces=faces_prep, focal_lengths=optimed_camera_flength)
|
460 |
+
pred_keyp = pred_keyp_raw[:, :24, :]
|
461 |
+
|
462 |
+
# save silhouette reprojection visualization
|
463 |
+
if i==0:
|
464 |
+
img_silh = Image.fromarray(np.uint8(255*pred_silh_images[0, 0, :, :].detach().cpu().numpy())).convert('RGB')
|
465 |
+
img_silh.save(root_out_path_details + name + '_silh_ainit.png')
|
466 |
+
my_mesh_tri = trimesh.Trimesh(vertices=smal_verts[0, ...].detach().cpu().numpy(), faces=faces_prep[0, ...].detach().cpu().numpy(), process=False, maintain_order=True)
|
467 |
+
my_mesh_tri.export(root_out_path_details + name + '_res_ainit.obj')
|
468 |
+
|
469 |
+
# silhouette loss
|
470 |
+
diff_silh = torch.abs(pred_silh_images[0, 0, :, :] - target_hg_silh)
|
471 |
+
losses['silhouette']['value'] = diff_silh.mean()
|
472 |
+
|
473 |
+
# keypoint_loss
|
474 |
+
output_kp_resh = (pred_keyp[0, :, :]).reshape((-1, 2))
|
475 |
+
losses['keyp']['value'] = ((((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt() * \
|
476 |
+
weights_resh[weights_resh>0])*keyp_w_resh[weights_resh>0]).sum() / \
|
477 |
+
max((weights_resh[weights_resh>0]*keyp_w_resh[weights_resh>0]).sum(), 1e-5)
|
478 |
+
# losses['keyp']['value'] = ((((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt()*weights_resh[weights_resh>0])*keyp_w_resh[weights_resh>0]).sum() / max((weights_resh[weights_resh>0]*keyp_w_resh[weights_resh>0]).sum(), 1e-5)
|
479 |
+
|
480 |
+
# pose priors on refined pose
|
481 |
+
losses['pose_legs_side']['value'] = leg_sideway_error(optimed_pose_with_glob)
|
482 |
+
losses['pose_legs_tors']['value'] = leg_torsion_error(optimed_pose_with_glob)
|
483 |
+
losses['pose_tail_side']['value'] = tail_sideway_error(optimed_pose_with_glob)
|
484 |
+
losses['pose_tail_tors']['value'] = tail_torsion_error(optimed_pose_with_glob)
|
485 |
+
losses['pose_spine_side']['value'] = spine_sideway_error(optimed_pose_with_glob)
|
486 |
+
losses['pose_spine_tors']['value'] = spine_torsion_error(optimed_pose_with_glob)
|
487 |
+
|
488 |
+
# ground contact loss
|
489 |
+
sel_verts = torch.index_select(smal_verts, dim=1, index=remeshing_relevant_faces.reshape((-1))).reshape((batch_size, remeshing_relevant_faces.shape[0], 3, 3))
|
490 |
+
verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts)
|
491 |
+
|
492 |
+
# gc_errors_plane, gc_errors_under_plane = calculate_plane_errors_batch(verts_remeshed, target_gc_class_remeshed_prep, target_dict['has_gc'], target_dict['has_gc_is_touching'])
|
493 |
+
gc_errors_plane, gc_errors_under_plane = calculate_plane_errors_batch(verts_remeshed, target_gc_class_remeshed_prep, isflat, istouching)
|
494 |
+
|
495 |
+
losses['gc_plane']['value'] = torch.mean(gc_errors_plane)
|
496 |
+
losses['gc_belowplane']['value'] = torch.mean(gc_errors_under_plane)
|
497 |
+
|
498 |
+
# edge length of the predicted mesh
|
499 |
+
if (losses["edge"][current_weight_name] + losses["normal"][ current_weight_name] + losses["laplacian"][ current_weight_name]) > 0:
|
500 |
+
torch_mesh = Meshes(smal_verts, faces_prep.detach())
|
501 |
+
losses["edge"]['value'] = mesh_edge_loss(torch_mesh)
|
502 |
+
# mesh normal consistency
|
503 |
+
losses["normal"]['value'] = mesh_normal_consistency(torch_mesh)
|
504 |
+
# mesh laplacian smoothing
|
505 |
+
losses["laplacian"]['value'] = mesh_laplacian_smoothing(torch_mesh, method="uniform")
|
506 |
+
|
507 |
+
# arap loss
|
508 |
+
if losses["arap"][current_weight_name] > 0.0:
|
509 |
+
torch_mesh = Meshes(smal_verts, faces_prep.detach())
|
510 |
+
losses["arap"]['value'] = arap_loss(torch_mesh)
|
511 |
+
|
512 |
+
# laplacian loss for comparison (from coarse-to-fine paper)
|
513 |
+
if losses["lapctf"][current_weight_name] > 0.0:
|
514 |
+
verts_refine = smal_verts
|
515 |
+
loss_almost_arap, loss_smooth = laplacian_ctf(verts_refine, torch_verts_comparison)
|
516 |
+
losses["lapctf"]['value'] = loss_almost_arap
|
517 |
+
|
518 |
+
# Weighted sum of the losses
|
519 |
+
total_loss = 0.0
|
520 |
+
for k in ['keyp', 'silhouette', 'pose_legs_side', 'pose_legs_tors', 'pose_tail_side', 'pose_tail_tors', 'pose_spine_tors', 'pose_spine_side', 'gc_plane', 'gc_belowplane', 'edge', 'normal', 'laplacian', 'arap', 'lapctf']:
|
521 |
+
if losses[k][current_weight_name] > 0.0:
|
522 |
+
total_loss += losses[k]['value'] * losses[k][current_weight_name]
|
523 |
+
|
524 |
+
# calculate gradient and make optimization step
|
525 |
+
total_loss.backward(retain_graph=True) #
|
526 |
+
current_optimizer.step()
|
527 |
+
current_scheduler.step(total_loss)
|
528 |
+
loop.set_description(f"Body Fitting = {total_loss.item():.3f}")
|
529 |
+
|
530 |
+
# save the result three times (0, 150, 300)
|
531 |
+
if i % 150 == 0:
|
532 |
+
# save silhouette image
|
533 |
+
img_silh = Image.fromarray(np.uint8(255*pred_silh_images[0, 0, :, :].detach().cpu().numpy())).convert('RGB')
|
534 |
+
img_silh.save(root_out_path_details + name + '_silh_e' + format(i, '03d') + '.png')
|
535 |
+
# save image overlay
|
536 |
+
visualizations = silh_renderer.get_visualization_nograd(smal_verts, faces_prep, optimed_camera_flength, color=0)
|
537 |
+
pred_tex = visualizations[0, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
|
538 |
+
# out_path = root_out_path_details + name + '_tex_pred_e' + format(i, '03d') + '.png'
|
539 |
+
# plt.imsave(out_path, pred_tex)
|
540 |
+
input_image_np = img_inp.copy()
|
541 |
+
im_masked = cv2.addWeighted(input_image_np,0.2,pred_tex,0.8,0)
|
542 |
+
pred_tex_max = np.max(pred_tex, axis=2)
|
543 |
+
im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :]
|
544 |
+
out_path = root_out_path + name + '_comp_pred_e' + format(i, '03d') + '.png'
|
545 |
+
plt.imsave(out_path, im_masked)
|
546 |
+
# save mesh
|
547 |
+
my_mesh_tri = trimesh.Trimesh(vertices=smal_verts[0, ...].detach().cpu().numpy(), faces=faces_prep[0, ...].detach().cpu().numpy(), process=False, maintain_order=True)
|
548 |
+
my_mesh_tri.visual.vertex_colors = vert_colors
|
549 |
+
my_mesh_tri.export(root_out_path + name + '_res_e' + format(i, '03d') + '.obj')
|
550 |
+
# save focal length (together with the mesh this is enough to create an overlay in blender)
|
551 |
+
out_file_flength = root_out_path_details + name + '_flength_e' + format(i, '03d') # + '.npz'
|
552 |
+
np.save(out_file_flength, optimed_camera_flength.detach().cpu().numpy())
|
553 |
+
current_i += 1
|
554 |
+
|
555 |
+
# prepare output mesh
|
556 |
+
mesh = my_mesh_tri # all_results[0]['mesh_posed']
|
557 |
+
mesh.apply_transform([[-1, 0, 0, 0],
|
558 |
+
[0, -1, 0, 0],
|
559 |
+
[0, 0, 1, 1],
|
560 |
+
[0, 0, 0, 1]])
|
561 |
+
result_path = os.path.join(save_imgs_path, test_name_list[0] + '_z')
|
562 |
+
mesh.export(file_obj=result_path + '.glb')
|
563 |
+
result_gltf = result_path + '.glb'
|
564 |
+
return result_gltf
|
565 |
+
|
566 |
+
|
567 |
+
|
568 |
+
|
569 |
+
|
570 |
+
# -------------------------------------------------------------------------------------------------------------------- #
|
571 |
+
|
572 |
+
|
573 |
+
def run_complete_inference(img_path_or_img, crop_choice):
|
574 |
+
# depending on crop_choice: run faster r-cnn or take the input image directly
|
575 |
+
if crop_choice == "input image is cropped":
|
576 |
+
if isinstance(img_path_or_img, str):
|
577 |
+
img = cv2.imread(img_path_or_img)
|
578 |
+
output_interm_image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
579 |
+
else:
|
580 |
+
output_interm_image = img_path_or_img
|
581 |
+
output_interm_bbox = None
|
582 |
+
else:
|
583 |
+
output_interm_image, output_interm_bbox = run_bbox_inference(img_path_or_img.copy())
|
584 |
+
# run barc inference
|
585 |
+
result_gltf = run_bite_inference(img_path_or_img, output_interm_bbox)
|
586 |
+
# add white border to image for nicer alignment
|
587 |
+
output_interm_image_vis = np.concatenate((255*np.ones_like(output_interm_image), output_interm_image, 255*np.ones_like(output_interm_image)), axis=1)
|
588 |
+
return [result_gltf, result_gltf, output_interm_image_vis]
|
589 |
+
|
590 |
+
|
591 |
+
|
592 |
+
|
593 |
+
########################################################################################################################
|
594 |
+
|
595 |
+
# see: https://huggingface.co/spaces/radames/PIFu-Clothed-Human-Digitization/blob/main/PIFu/spaces.py
|
596 |
+
|
597 |
+
description = '''
|
598 |
+
# BITE
|
599 |
+
|
600 |
+
#### Project Page
|
601 |
+
* https://bite.is.tue.mpg.de/
|
602 |
+
|
603 |
+
#### Description
|
604 |
+
This is a demo for BITE (*B*eyond Priors for *I*mproved *T*hree-{D} Dog Pose *E*stimation).
|
605 |
+
You can either submit a cropped image or choose the option to run a pretrained Faster R-CNN in order to obtain a bounding box.
|
606 |
+
Please have a look at the examples below.
|
607 |
+
<details>
|
608 |
+
|
609 |
+
<summary>More</summary>
|
610 |
+
|
611 |
+
#### Citation
|
612 |
+
|
613 |
+
```
|
614 |
+
@inproceedings{bite2023rueegg,
|
615 |
+
title = {{BITE}: Beyond Priors for Improved Three-{D} Dog Pose Estimation},
|
616 |
+
author = {R\"uegg, Nadine and Tripathi, Shashank and Schindler, Konrad and Black, Michael J. and Zuffi, Silvia},
|
617 |
+
booktitle = {IEEE/CVF Conf.~on Computer Vision and Pattern Recognition (CVPR)},
|
618 |
+
pages = {8867-8876},
|
619 |
+
year = {2023},
|
620 |
+
}
|
621 |
+
```
|
622 |
+
|
623 |
+
#### Image Sources
|
624 |
+
* Stanford extra image dataset
|
625 |
+
* Images from google search engine
|
626 |
+
* https://www.dogtrainingnation.com/wp-content/uploads/2015/02/keep-dog-training-sessions-short.jpg
|
627 |
+
* https://thumbs.dreamstime.com/b/hund-und-seine-neue-hundeh%C3%BCtte-36757551.jpg
|
628 |
+
* https://www.mydearwhippet.com/wp-content/uploads/2021/04/whippet-temperament-2.jpg
|
629 |
+
* https://media.istockphoto.com/photos/ibizan-hound-at-the-shore-in-winter-picture-id1092705644?k=20&m=1092705644&s=612x612&w=0&h=ppwg92s9jI8GWnk22SOR_DWWNP8b2IUmLXSQmVey5Ss=
|
630 |
+
|
631 |
+
|
632 |
+
</details>
|
633 |
+
'''
|
634 |
+
|
635 |
+
|
636 |
+
|
637 |
+
|
638 |
+
|
639 |
+
|
640 |
+
example_images = sorted(glob.glob(os.path.join(os.path.dirname(__file__), '../', 'datasets', 'test_image_crops', '*.jpg')) + glob.glob(os.path.join(os.path.dirname(__file__), '../', 'datasets', 'test_image_crops', '*.png')))
|
641 |
+
random.shuffle(example_images)
|
642 |
+
# example_images.reverse()
|
643 |
+
# examples = [[img, "input image is cropped"] for img in example_images]
|
644 |
+
examples = []
|
645 |
+
for img in example_images:
|
646 |
+
if os.path.basename(img)[:2] == 'z_':
|
647 |
+
examples.append([img, "use Faster R-CNN to get a bounding box"])
|
648 |
+
else:
|
649 |
+
examples.append([img, "input image is cropped"])
|
650 |
+
|
651 |
+
demo = gr.Interface(
|
652 |
+
fn=run_complete_inference,
|
653 |
+
description=description,
|
654 |
+
# inputs=gr.Image(type="filepath", label="Input Image"),
|
655 |
+
inputs=[gr.Image(label="Input Image"),
|
656 |
+
gr.Radio(["input image is cropped", "use Faster R-CNN to get a bounding box"], value="use Faster R-CNN to get a bounding box", label="Crop Choice"),
|
657 |
+
],
|
658 |
+
outputs=[
|
659 |
+
gr.Model3D(
|
660 |
+
clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model"),
|
661 |
+
gr.File(label="Download 3D Model"),
|
662 |
+
gr.Image(label="Bounding Box (Faster R-CNN prediction)"),
|
663 |
+
|
664 |
+
],
|
665 |
+
examples=examples,
|
666 |
+
thumbnail="bite_thumbnail.png",
|
667 |
+
allow_flagging="never",
|
668 |
+
cache_examples=True,
|
669 |
+
examples_per_page=14,
|
670 |
+
)
|
671 |
+
|
672 |
+
demo.launch(share=True)
|
src/__init__.py
ADDED
File without changes
|
src/bps_2d/bps_for_segmentation.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# code idea from https://github.com/sergeyprokudin/bps
|
3 |
+
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
import time
|
8 |
+
import scipy
|
9 |
+
import scipy.spatial
|
10 |
+
import pymp
|
11 |
+
|
12 |
+
|
13 |
+
#####################
|
14 |
+
QUERY_POINTS = np.asarray([30, 34, 31, 55, 29, 84, 35, 108, 34, 145, 29, 171, 27,
|
15 |
+
196, 29, 228, 58, 35, 61, 55, 57, 83, 56, 109, 63, 148, 58, 164, 57, 197, 60,
|
16 |
+
227, 81, 26, 87, 58, 85, 87, 89, 117, 86, 142, 89, 172, 84, 197, 88, 227, 113,
|
17 |
+
32, 116, 58, 112, 88, 118, 113, 109, 147, 114, 173, 119, 201, 113, 229, 139,
|
18 |
+
29, 141, 59, 142, 93, 139, 117, 146, 147, 141, 173, 142, 201, 143, 227, 170,
|
19 |
+
26, 173, 59, 166, 90, 174, 117, 176, 141, 169, 175, 167, 198, 172, 227, 198,
|
20 |
+
30, 195, 59, 204, 85, 198, 116, 195, 140, 198, 175, 194, 193, 199, 227, 221,
|
21 |
+
26, 223, 57, 227, 83, 227, 113, 227, 140, 226, 173, 230, 196, 228, 229]).reshape((64, 2))
|
22 |
+
#####################
|
23 |
+
|
24 |
+
class SegBPS():
|
25 |
+
|
26 |
+
def __init__(self, query_points=QUERY_POINTS, size=256):
|
27 |
+
self.size = size
|
28 |
+
self.query_points = query_points
|
29 |
+
row, col = np.indices((self.size, self.size))
|
30 |
+
self.indices_rc = np.stack((row, col), axis=2) # (256, 256, 2)
|
31 |
+
self.pts_aranged = np.arange(64)
|
32 |
+
return
|
33 |
+
|
34 |
+
def _do_kdtree(self, combined_x_y_arrays, points):
|
35 |
+
# see https://stackoverflow.com/questions/10818546/finding-index-of-nearest-
|
36 |
+
# point-in-numpy-arrays-of-x-and-y-coordinates
|
37 |
+
mytree = scipy.spatial.cKDTree(combined_x_y_arrays)
|
38 |
+
dist, indexes = mytree.query(points)
|
39 |
+
return indexes
|
40 |
+
|
41 |
+
def calculate_bps_points(self, seg, thr=0.5, vis=False, out_path=None):
|
42 |
+
# seg: input segmentation image of shape (256, 256) with values between 0 and 1
|
43 |
+
query_val = seg[self.query_points[:, 0], self.query_points[:, 1]]
|
44 |
+
pts_fg = self.pts_aranged[query_val>=thr]
|
45 |
+
pts_bg = self.pts_aranged[query_val<thr]
|
46 |
+
candidate_inds_bg = self.indices_rc[seg<thr]
|
47 |
+
candidate_inds_fg = self.indices_rc[seg>=thr]
|
48 |
+
if candidate_inds_bg.shape[0] == 0:
|
49 |
+
candidate_inds_bg = np.ones((1, 2)) * 128 # np.zeros((1, 2))
|
50 |
+
if candidate_inds_fg.shape[0] == 0:
|
51 |
+
candidate_inds_fg = np.ones((1, 2)) * 128 # np.zeros((1, 2))
|
52 |
+
# calculate nearest points
|
53 |
+
all_nearest_points = np.zeros((64, 2))
|
54 |
+
all_nearest_points[pts_fg, :] = candidate_inds_bg[self._do_kdtree(candidate_inds_bg, self.query_points[pts_fg, :]), :]
|
55 |
+
all_nearest_points[pts_bg, :] = candidate_inds_fg[self._do_kdtree(candidate_inds_fg, self.query_points[pts_bg, :]), :]
|
56 |
+
all_nearest_points_01 = all_nearest_points / 255.
|
57 |
+
if vis:
|
58 |
+
self.visualize_result(seg, all_nearest_points, out_path=out_path)
|
59 |
+
return all_nearest_points_01
|
60 |
+
|
61 |
+
def calculate_bps_points_batch(self, seg_batch, thr=0.5, vis=False, out_path=None):
|
62 |
+
# seg_batch: input segmentation image of shape (bs, 256, 256) with values between 0 and 1
|
63 |
+
bs = seg_batch.shape[0]
|
64 |
+
all_nearest_points_01_batch = np.zeros((bs, self.query_points.shape[0], 2))
|
65 |
+
for ind in range(0, bs): # 0.25
|
66 |
+
seg = seg_batch[ind, :, :]
|
67 |
+
all_nearest_points_01 = self.calculate_bps_points(seg, thr=thr, vis=vis, out_path=out_path)
|
68 |
+
all_nearest_points_01_batch[ind, :, :] = all_nearest_points_01
|
69 |
+
return all_nearest_points_01_batch
|
70 |
+
|
71 |
+
def visualize_result(self, seg, all_nearest_points, out_path=None):
|
72 |
+
import matplotlib as mpl
|
73 |
+
mpl.use('Agg')
|
74 |
+
import matplotlib.pyplot as plt
|
75 |
+
# img: (256, 256, 3)
|
76 |
+
img = (np.stack((seg, seg, seg), axis=2) * 155).astype(np.int)
|
77 |
+
if out_path is None:
|
78 |
+
ind_img = 0
|
79 |
+
out_path = '../test_img' + str(ind_img) + '.png'
|
80 |
+
fig, ax = plt.subplots()
|
81 |
+
plt.imshow(img)
|
82 |
+
plt.gca().set_axis_off()
|
83 |
+
plt.subplots_adjust(top = 1, bottom = 0, right = 1, left = 0, hspace = 0, wspace = 0)
|
84 |
+
plt.margins(0,0)
|
85 |
+
ratio_in_out = 1 # 255
|
86 |
+
for idx, (y, x) in enumerate(self.query_points):
|
87 |
+
x = int(x*ratio_in_out)
|
88 |
+
y = int(y*ratio_in_out)
|
89 |
+
plt.scatter([x], [y], marker="x", s=50)
|
90 |
+
x2 = int(all_nearest_points[idx, 1])
|
91 |
+
y2 = int(all_nearest_points[idx, 0])
|
92 |
+
plt.scatter([x2], [y2], marker="o", s=50)
|
93 |
+
plt.plot([x, x2], [y, y2])
|
94 |
+
plt.savefig(out_path, bbox_inches='tight', pad_inches=0)
|
95 |
+
plt.close()
|
96 |
+
return
|
97 |
+
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
if __name__ == "__main__":
|
103 |
+
ind_img = 2 # 4
|
104 |
+
path_seg_top = '...../pytorch-stacked-hourglass/results/dogs_hg8_ks_24_v1/test/'
|
105 |
+
path_seg = os.path.join(path_seg_top, 'seg_big_' + str(ind_img) + '.png')
|
106 |
+
img = np.asarray(Image.open(path_seg))
|
107 |
+
# min is 0.004, max is 0.9
|
108 |
+
# low values are background, high values are foreground
|
109 |
+
seg = img[:, :, 1] / 255.
|
110 |
+
# calculate points
|
111 |
+
bps = SegBPS()
|
112 |
+
bps.calculate_bps_points(seg, thr=0.5, vis=False, out_path=None)
|
113 |
+
|
114 |
+
|
src/combined_model/__init__.py
ADDED
File without changes
|
src/combined_model/helper.py
ADDED
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.backends.cudnn
|
5 |
+
import torch.nn.parallel
|
6 |
+
from tqdm import tqdm
|
7 |
+
import os
|
8 |
+
import pathlib
|
9 |
+
from matplotlib import pyplot as plt
|
10 |
+
import cv2
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import trimesh
|
14 |
+
|
15 |
+
import sys
|
16 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
|
17 |
+
from stacked_hourglass.utils.evaluation import accuracy, AverageMeter, final_preds, get_preds, get_preds_soft
|
18 |
+
from stacked_hourglass.utils.visualization import save_input_image_with_keypoints, save_input_image
|
19 |
+
from metrics.metrics import Metrics
|
20 |
+
from configs.SMAL_configs import EVAL_KEYPOINTS, KEYPOINT_GROUPS
|
21 |
+
|
22 |
+
|
23 |
+
# GOAL: have all the functions from the validation and visual epoch together
|
24 |
+
|
25 |
+
|
26 |
+
'''
|
27 |
+
save_imgs_path = ...
|
28 |
+
prefix = ''
|
29 |
+
input # this is the image
|
30 |
+
data_info
|
31 |
+
target_dict
|
32 |
+
render_all
|
33 |
+
model
|
34 |
+
|
35 |
+
|
36 |
+
vertices_smal = output_reproj['vertices_smal']
|
37 |
+
flength = output_unnorm['flength']
|
38 |
+
hg_keyp_norm = output['keypoints_norm']
|
39 |
+
hg_keyp_scores = output['keypoints_scores']
|
40 |
+
betas = output_reproj['betas']
|
41 |
+
betas_limbs = output_reproj['betas_limbs']
|
42 |
+
zz = output_reproj['z']
|
43 |
+
pose_rotmat = output_unnorm['pose_rotmat']
|
44 |
+
trans = output_unnorm['trans']
|
45 |
+
pred_keyp = output_reproj['keyp_2d']
|
46 |
+
pred_silh = output_reproj['silh']
|
47 |
+
'''
|
48 |
+
|
49 |
+
#################################################
|
50 |
+
|
51 |
+
def eval_save_visualizations_and_meshes(model, input, data_info, target_dict, test_name_list, vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat, trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, render_all=False):
|
52 |
+
device = input.device
|
53 |
+
curr_batch_size = input.shape[0]
|
54 |
+
# render predicted 3d models
|
55 |
+
visualizations = model.render_vis_nograd(vertices=vertices_smal,
|
56 |
+
focal_lengths=flength,
|
57 |
+
color=0) # color=2)
|
58 |
+
for ind_img in range(len(target_dict['index'])):
|
59 |
+
try:
|
60 |
+
# import pdb; pdb.set_trace()
|
61 |
+
if test_name_list is not None:
|
62 |
+
img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_')
|
63 |
+
img_name = img_name.split('.')[0]
|
64 |
+
else:
|
65 |
+
img_name = str(index) + '_' + str(ind_img)
|
66 |
+
# save image with predicted keypoints
|
67 |
+
out_path = save_imgs_path + '/keypoints_pred_' + img_name + '.png'
|
68 |
+
pred_unp = (hg_keyp_norm[ind_img, :, :] + 1.) / 2 * (data_info.image_size - 1)
|
69 |
+
pred_unp_maxval = hg_keyp_scores[ind_img, :, :]
|
70 |
+
pred_unp_prep = torch.cat((pred_unp, pred_unp_maxval), 1)
|
71 |
+
inp_img = input[ind_img, :, :, :].detach().clone()
|
72 |
+
save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.1, print_scores=True, ratio_in_out=1.0) # threshold=0.3
|
73 |
+
# save predicted 3d model (front view)
|
74 |
+
pred_tex = visualizations[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
|
75 |
+
pred_tex_max = np.max(pred_tex, axis=2)
|
76 |
+
out_path = save_imgs_path + '/' + prefix + 'tex_pred_' + img_name + '.png'
|
77 |
+
plt.imsave(out_path, pred_tex)
|
78 |
+
input_image = input[ind_img, :, :, :].detach().clone()
|
79 |
+
for t, m, s in zip(input_image, data_info.rgb_mean, data_info.rgb_stddev): t.add_(m)
|
80 |
+
input_image_np = input_image.detach().cpu().numpy().transpose(1, 2, 0)
|
81 |
+
im_masked = cv2.addWeighted(input_image_np,0.2,pred_tex,0.8,0)
|
82 |
+
im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :]
|
83 |
+
out_path = save_imgs_path + '/' + prefix + 'comp_pred_' + img_name + '.png'
|
84 |
+
plt.imsave(out_path, im_masked)
|
85 |
+
# save predicted 3d model (side view)
|
86 |
+
vertices_cent = vertices_smal - vertices_smal.mean(dim=1)[:, None, :]
|
87 |
+
roll = np.pi / 2 * torch.ones(1).float().to(device)
|
88 |
+
pitch = np.pi / 2 * torch.ones(1).float().to(device)
|
89 |
+
tensor_0 = torch.zeros(1).float().to(device)
|
90 |
+
tensor_1 = torch.ones(1).float().to(device)
|
91 |
+
RX = torch.stack([torch.stack([tensor_1, tensor_0, tensor_0]), torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3,3)
|
92 |
+
RY = torch.stack([
|
93 |
+
torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]),
|
94 |
+
torch.stack([tensor_0, tensor_1, tensor_0]),
|
95 |
+
torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3,3)
|
96 |
+
vertices_rot = (torch.matmul(RY, vertices_cent.reshape((-1, 3))[:, :, None])).reshape((curr_batch_size, -1, 3))
|
97 |
+
vertices_rot[:, :, 2] = vertices_rot[:, :, 2] + torch.ones_like(vertices_rot[:, :, 2]) * 20 # 18 # *16
|
98 |
+
|
99 |
+
visualizations_rot = model.render_vis_nograd(vertices=vertices_rot,
|
100 |
+
focal_lengths=flength,
|
101 |
+
color=0) # 2)
|
102 |
+
pred_tex = visualizations_rot[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
|
103 |
+
pred_tex_max = np.max(pred_tex, axis=2)
|
104 |
+
out_path = save_imgs_path + '/' + prefix + 'rot_tex_pred_' + img_name + '.png'
|
105 |
+
plt.imsave(out_path, pred_tex)
|
106 |
+
if render_all:
|
107 |
+
# save input image
|
108 |
+
inp_img = input[ind_img, :, :, :].detach().clone()
|
109 |
+
out_path = save_imgs_path + '/image_' + img_name + '.png'
|
110 |
+
save_input_image(inp_img, out_path)
|
111 |
+
# save mesh
|
112 |
+
V_posed = vertices_smal[ind_img, :, :].detach().cpu().numpy()
|
113 |
+
Faces = model.smal.f
|
114 |
+
mesh_posed = trimesh.Trimesh(vertices=V_posed, faces=Faces, process=False, maintain_order=True)
|
115 |
+
mesh_posed.export(save_imgs_path + '/' + prefix + 'mesh_posed_' + img_name + '.obj')
|
116 |
+
except:
|
117 |
+
print('dont save an image')
|
118 |
+
|
119 |
+
############
|
120 |
+
|
121 |
+
def eval_prepare_pck_and_iou(model, input, data_info, target_dict, test_name_list, vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat, trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, pck_thresh, progress=None, skip_pck_and_iou=False):
|
122 |
+
preds = {}
|
123 |
+
preds['betas'] = betas.cpu().detach().numpy()
|
124 |
+
preds['betas_limbs'] = betas_limbs.cpu().detach().numpy()
|
125 |
+
preds['z'] = zz.cpu().detach().numpy()
|
126 |
+
preds['pose_rotmat'] = pose_rotmat.cpu().detach().numpy()
|
127 |
+
preds['flength'] = flength.cpu().detach().numpy()
|
128 |
+
preds['trans'] = trans.cpu().detach().numpy()
|
129 |
+
preds['breed_index'] = target_dict['breed_index'].cpu().detach().numpy().reshape((-1))
|
130 |
+
img_names = []
|
131 |
+
for ind_img2 in range(0, betas.shape[0]):
|
132 |
+
if test_name_list is not None:
|
133 |
+
img_name2 = test_name_list[int(target_dict['index'][ind_img2].cpu().detach().numpy())].replace('/', '_')
|
134 |
+
img_name2 = img_name2.split('.')[0]
|
135 |
+
else:
|
136 |
+
img_name2 = str(index) + '_' + str(ind_img2)
|
137 |
+
img_names.append(img_name2)
|
138 |
+
preds['image_names'] = img_names
|
139 |
+
if not skip_pck_and_iou:
|
140 |
+
# prepare keypoints for PCK calculation - predicted as well as ground truth
|
141 |
+
# pred_keyp = output_reproj['keyp_2d'] # 256
|
142 |
+
gt_keypoints_256 = target_dict['tpts'][:, :, :2] / 64. * (256. - 1)
|
143 |
+
# gt_keypoints_norm = gt_keypoints_256 / 256 / 0.5 - 1
|
144 |
+
gt_keypoints = torch.cat((gt_keypoints_256, target_dict['tpts'][:, :, 2:3]), dim=2) # gt_keypoints_norm
|
145 |
+
# prepare silhouette for IoU calculation - predicted as well as ground truth
|
146 |
+
has_seg = target_dict['has_seg']
|
147 |
+
img_border_mask = target_dict['img_border_mask'][:, 0, :, :]
|
148 |
+
gtseg = target_dict['silh']
|
149 |
+
synth_silhouettes = pred_silh[:, 0, :, :] # output_reproj['silh']
|
150 |
+
synth_silhouettes[synth_silhouettes>0.5] = 1
|
151 |
+
synth_silhouettes[synth_silhouettes<0.5] = 0
|
152 |
+
# calculate PCK as well as IoU (similar to WLDO)
|
153 |
+
preds['acc_PCK'] = Metrics.PCK(
|
154 |
+
pred_keyp, gt_keypoints,
|
155 |
+
gtseg, has_seg, idxs=EVAL_KEYPOINTS,
|
156 |
+
thresh_range=[pck_thresh], # [0.15],
|
157 |
+
)
|
158 |
+
preds['acc_IOU'] = Metrics.IOU(
|
159 |
+
synth_silhouettes, gtseg,
|
160 |
+
img_border_mask, mask=has_seg
|
161 |
+
)
|
162 |
+
for group, group_kps in KEYPOINT_GROUPS.items():
|
163 |
+
preds[f'{group}_PCK'] = Metrics.PCK(
|
164 |
+
pred_keyp, gt_keypoints, gtseg, has_seg,
|
165 |
+
thresh_range=[pck_thresh], # [0.15],
|
166 |
+
idxs=group_kps
|
167 |
+
)
|
168 |
+
return preds
|
169 |
+
|
170 |
+
|
171 |
+
# preds['acc_PCK'] = Metrics.PCK(pred_keyp, gt_keypoints, gtseg, has_seg, idxs=EVAL_KEYPOINTS, thresh_range=[pck_thresh])
|
172 |
+
# preds['acc_IOU'] = Metrics.IOU(synth_silhouettes, gtseg, img_border_mask, mask=has_seg)
|
173 |
+
#############################
|
174 |
+
|
175 |
+
def eval_add_preds_to_summary(summary, preds, my_step, batch_size, curr_batch_size, skip_pck_and_iou=False):
|
176 |
+
if not skip_pck_and_iou:
|
177 |
+
if not (preds['acc_PCK'].data.cpu().numpy().shape == (summary['pck'][my_step * batch_size:my_step * batch_size + curr_batch_size]).shape):
|
178 |
+
import pdb; pdb.set_trace()
|
179 |
+
summary['pck'][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['acc_PCK'].data.cpu().numpy()
|
180 |
+
summary['acc_sil_2d'][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['acc_IOU'].data.cpu().numpy()
|
181 |
+
for part in summary['pck_by_part']:
|
182 |
+
summary['pck_by_part'][part][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds[f'{part}_PCK'].data.cpu().numpy()
|
183 |
+
summary['betas'][my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['betas']
|
184 |
+
summary['betas_limbs'][my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['betas_limbs']
|
185 |
+
summary['z'][my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['z']
|
186 |
+
summary['pose_rotmat'][my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['pose_rotmat']
|
187 |
+
summary['flength'][my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['flength']
|
188 |
+
summary['trans'][my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['trans']
|
189 |
+
summary['breed_indices'][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['breed_index']
|
190 |
+
summary['image_names'].extend(preds['image_names'])
|
191 |
+
return
|
192 |
+
|
193 |
+
|
194 |
+
def get_triangle_faces_from_pyvista_poly(poly):
|
195 |
+
"""Fetch all triangle faces."""
|
196 |
+
stream = poly.faces
|
197 |
+
tris = []
|
198 |
+
i = 0
|
199 |
+
while i < len(stream):
|
200 |
+
n = stream[i]
|
201 |
+
if n != 3:
|
202 |
+
i += n + 1
|
203 |
+
continue
|
204 |
+
stop = i + n + 1
|
205 |
+
tris.append(stream[i+1:stop])
|
206 |
+
i = stop
|
207 |
+
return np.array(tris)
|
src/combined_model/helper3.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
def get_triangle_faces_from_pyvista_poly(poly):
|
5 |
+
"""Fetch all triangle faces."""
|
6 |
+
stream = poly.faces
|
7 |
+
tris = []
|
8 |
+
i = 0
|
9 |
+
while i < len(stream):
|
10 |
+
n = stream[i]
|
11 |
+
if n != 3:
|
12 |
+
i += n + 1
|
13 |
+
continue
|
14 |
+
stop = i + n + 1
|
15 |
+
tris.append(stream[i+1:stop])
|
16 |
+
i = stop
|
17 |
+
return np.array(tris)
|
src/combined_model/loss_image_to_3d_refinement.py
ADDED
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import pickle as pkl
|
6 |
+
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
|
10 |
+
# from priors.pose_prior_35 import Prior
|
11 |
+
# from priors.tiger_pose_prior.tiger_pose_prior import GaussianMixturePrior
|
12 |
+
from priors.normalizing_flow_prior.normalizing_flow_prior import NormalizingFlowPrior
|
13 |
+
from priors.shape_prior import ShapePrior
|
14 |
+
from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, batch_rot2aa, geodesic_loss_R
|
15 |
+
from combined_model.loss_utils.loss_utils import leg_sideway_error, leg_torsion_error, tail_sideway_error, tail_torsion_error, spine_torsion_error, spine_sideway_error
|
16 |
+
from combined_model.loss_utils.loss_utils_gc import LossGConMesh, calculate_plane_errors_batch
|
17 |
+
|
18 |
+
from priors.shape_prior import ShapePrior
|
19 |
+
from configs.SMAL_configs import SMAL_MODEL_CONFIG
|
20 |
+
|
21 |
+
from priors.helper_3dcgmodel_loss import load_dog_betas_for_3dcgmodel_loss
|
22 |
+
|
23 |
+
|
24 |
+
class LossRef(torch.nn.Module):
|
25 |
+
def __init__(self, smal_model_type, data_info, nf_version=None):
|
26 |
+
super(LossRef, self).__init__()
|
27 |
+
self.criterion_regr = torch.nn.MSELoss() # takes the mean
|
28 |
+
self.criterion_class = torch.nn.CrossEntropyLoss()
|
29 |
+
|
30 |
+
class_weights_isflat = torch.tensor([12, 2])
|
31 |
+
self.criterion_class_isflat = torch.nn.CrossEntropyLoss(weight=class_weights_isflat)
|
32 |
+
self.criterion_l1 = torch.nn.L1Loss()
|
33 |
+
self.geodesic_loss = geodesic_loss_R(reduction='mean')
|
34 |
+
self.gc_loss_on_mesh = LossGConMesh()
|
35 |
+
self.data_info = data_info
|
36 |
+
self.smal_model_type = smal_model_type
|
37 |
+
self.register_buffer('keypoint_weights', torch.tensor(data_info.keypoint_weights)[None, :])
|
38 |
+
# if nf_version is not None:
|
39 |
+
# self.normalizing_flow_pose_prior = NormalizingFlowPrior(nf_version=nf_version)
|
40 |
+
|
41 |
+
self.smal_model_data_path = SMAL_MODEL_CONFIG[self.smal_model_type]['smal_model_data_path']
|
42 |
+
self.shape_prior = ShapePrior(self.smal_model_data_path) # here we just need mean and cov
|
43 |
+
|
44 |
+
remeshing_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data_remeshed/uniform_surface_sampling/my_smpl_39dogsnorm_Jr_4_dog_remesh4000_info.pkl'
|
45 |
+
with open(remeshing_path, 'rb') as fp:
|
46 |
+
self.remeshing_dict = pkl.load(fp)
|
47 |
+
self.remeshing_relevant_faces = torch.tensor(self.remeshing_dict['smal_faces'][self.remeshing_dict['faceid_closest']], dtype=torch.long)
|
48 |
+
self.remeshing_relevant_barys = torch.tensor(self.remeshing_dict['barys_closest'], dtype=torch.float32)
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
# load 3d data for the unity dogs (an optional shape prior for 11 breeds)
|
53 |
+
self.unity_smal_shape_prior_dogs = SMAL_MODEL_CONFIG[self.smal_model_type]['unity_smal_shape_prior_dogs']
|
54 |
+
if self.unity_smal_shape_prior_dogs is not None:
|
55 |
+
self.dog_betas_unity = load_dog_betas_for_3dcgmodel_loss(self.unity_smal_shape_prior_dogs, self.smal_model_type)
|
56 |
+
else:
|
57 |
+
self.dog_betas_unity = None
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
|
65 |
+
def forward(self, output_ref, output_ref_comp, target_dict, weight_dict_ref):
|
66 |
+
# output_reproj: ['vertices_smal', 'keyp_3d', 'keyp_2d', 'silh_image']
|
67 |
+
# target_dict: ['index', 'center', 'scale', 'pts', 'tpts', 'target_weight']
|
68 |
+
batch_size = output_ref['keyp_2d'].shape[0]
|
69 |
+
loss_dict_temp = {}
|
70 |
+
|
71 |
+
# loss on reprojected keypoints
|
72 |
+
output_kp_resh = (output_ref['keyp_2d']).reshape((-1, 2))
|
73 |
+
target_kp_resh = (target_dict['tpts'][:, :, :2] / 64. * (256. - 1)).reshape((-1, 2))
|
74 |
+
weights_resh = target_dict['tpts'][:, :, 2].reshape((-1))
|
75 |
+
keyp_w_resh = self.keypoint_weights.repeat((batch_size, 1)).reshape((-1))
|
76 |
+
loss_dict_temp['keyp_ref'] = ((((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt()*weights_resh[weights_resh>0])*keyp_w_resh[weights_resh>0]).sum() / \
|
77 |
+
max((weights_resh[weights_resh>0]*keyp_w_resh[weights_resh>0]).sum(), 1e-5)
|
78 |
+
|
79 |
+
# loss on reprojected silhouette
|
80 |
+
assert output_ref['silh'].shape == (target_dict['silh'][:, None, :, :]).shape
|
81 |
+
silh_loss_type = 'default'
|
82 |
+
if silh_loss_type == 'default':
|
83 |
+
with torch.no_grad():
|
84 |
+
thr_silh = 20
|
85 |
+
diff = torch.norm(output_kp_resh - target_kp_resh, dim=1)
|
86 |
+
diff_x = diff.reshape((batch_size, -1))
|
87 |
+
weights_resh_x = weights_resh.reshape((batch_size, -1))
|
88 |
+
unweighted_kp_mean_dist = (diff_x * weights_resh_x).sum(dim=1) / ((weights_resh_x).sum(dim=1)+1e-6)
|
89 |
+
loss_silh_bs = ((output_ref['silh'] - target_dict['silh'][:, None, :, :]) ** 2).sum(axis=3).sum(axis=2).sum(axis=1) / (output_ref['silh'].shape[2]*output_ref['silh'].shape[3])
|
90 |
+
loss_dict_temp['silh_ref'] = loss_silh_bs[unweighted_kp_mean_dist<thr_silh].sum() / batch_size
|
91 |
+
else:
|
92 |
+
print('silh_loss_type: ' + silh_loss_type)
|
93 |
+
raise ValueError
|
94 |
+
|
95 |
+
# regularization: losses on difference between previous prediction and refinement
|
96 |
+
loss_dict_temp['reg_trans'] = self.criterion_l1(output_ref_comp['ref_trans_notnorm'], output_ref_comp['old_trans_notnorm'].detach()) * 3
|
97 |
+
loss_dict_temp['reg_flength'] = self.criterion_l1(output_ref_comp['ref_flength_notnorm'], output_ref_comp['old_flength_notnorm'].detach()) * 1
|
98 |
+
loss_dict_temp['reg_pose'] = self.geodesic_loss(output_ref_comp['ref_pose_rotmat'], output_ref_comp['old_pose_rotmat'].detach()) * 35 * 6
|
99 |
+
|
100 |
+
# pose priors on refined pose
|
101 |
+
loss_dict_temp['pose_legs_side'] = leg_sideway_error(output_ref['pose_rotmat'])
|
102 |
+
loss_dict_temp['pose_legs_tors'] = leg_torsion_error(output_ref['pose_rotmat'])
|
103 |
+
loss_dict_temp['pose_tail_side'] = tail_sideway_error(output_ref['pose_rotmat'])
|
104 |
+
loss_dict_temp['pose_tail_tors'] = tail_torsion_error(output_ref['pose_rotmat'])
|
105 |
+
loss_dict_temp['pose_spine_side'] = spine_sideway_error(output_ref['pose_rotmat'])
|
106 |
+
loss_dict_temp['pose_spine_tors'] = spine_torsion_error(output_ref['pose_rotmat'])
|
107 |
+
|
108 |
+
# loss to predict ground contact per vertex
|
109 |
+
# import pdb; pdb.set_trace()
|
110 |
+
if 'gc_vertexwise' in weight_dict_ref.keys():
|
111 |
+
# import pdb; pdb.set_trace()
|
112 |
+
device = output_ref['vertexwise_ground_contact'].device
|
113 |
+
pred_gc = output_ref['vertexwise_ground_contact']
|
114 |
+
loss_dict_temp['gc_vertexwise'] = self.gc_loss_on_mesh(pred_gc, target_dict['gc'].to(device=device, dtype=torch.long), target_dict['has_gc'], loss_type_gcmesh='ce')
|
115 |
+
|
116 |
+
keep_smal_mesh = False
|
117 |
+
if 'gc_plane' in weight_dict_ref.keys():
|
118 |
+
if weight_dict_ref['gc_plane'] > 0:
|
119 |
+
if keep_smal_mesh:
|
120 |
+
target_gc_class = target_dict['gc'][:, :, 0]
|
121 |
+
gc_errors_plane = calculate_plane_errors_batch(output_ref['vertices_smal'], target_gc_class, target_dict['has_gc'], target_dict['has_gc_is_touching'])
|
122 |
+
loss_dict_temp['gc_plane'] = torch.mean(gc_errors_plane)
|
123 |
+
else: # use a uniformly sampled mesh
|
124 |
+
target_gc_class = target_dict['gc'][:, :, 0]
|
125 |
+
device = output_ref['vertices_smal'].device
|
126 |
+
remeshing_relevant_faces = self.remeshing_relevant_faces.to(device)
|
127 |
+
remeshing_relevant_barys = self.remeshing_relevant_barys.to(device)
|
128 |
+
|
129 |
+
bs = output_ref['vertices_smal'].shape[0]
|
130 |
+
# verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, output_ref['vertices_smal'][:, self.remeshing_relevant_faces])
|
131 |
+
# sel_verts_comparison = output_ref['vertices_smal'][:, self.remeshing_relevant_faces]
|
132 |
+
# verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts_comparison)
|
133 |
+
sel_verts = torch.index_select(output_ref['vertices_smal'], dim=1, index=remeshing_relevant_faces.reshape((-1))).reshape((bs, remeshing_relevant_faces.shape[0], 3, 3))
|
134 |
+
verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts)
|
135 |
+
target_gc_class_remeshed = torch.einsum('ij,aij->ai', remeshing_relevant_barys, target_gc_class[:, self.remeshing_relevant_faces].to(device=device, dtype=torch.float32))
|
136 |
+
target_gc_class_remeshed_prep = torch.round(target_gc_class_remeshed).to(torch.long)
|
137 |
+
gc_errors_plane, gc_errors_under_plane = calculate_plane_errors_batch(verts_remeshed, target_gc_class_remeshed_prep, target_dict['has_gc'], target_dict['has_gc_is_touching'])
|
138 |
+
loss_dict_temp['gc_plane'] = torch.mean(gc_errors_plane)
|
139 |
+
loss_dict_temp['gc_blowplane'] = torch.mean(gc_errors_under_plane)
|
140 |
+
|
141 |
+
# error on classification if the ground plane is flat
|
142 |
+
if 'gc_isflat' in weight_dict_ref.keys():
|
143 |
+
# import pdb; pdb.set_trace()
|
144 |
+
self.criterion_class_isflat.to(device)
|
145 |
+
loss_dict_temp['gc_isflat'] = self.criterion_class(output_ref['isflat'], target_dict['isflat'].to(device))
|
146 |
+
|
147 |
+
# if we refine the shape WITHIN the refinement newtork (shaperef_type is not inexistent)
|
148 |
+
# shape regularization
|
149 |
+
# 'smal': loss on betas (pca coefficients), betas should be close to 0
|
150 |
+
# 'limbs...' loss on selected betas_limbs
|
151 |
+
device = output_ref_comp['ref_trans_notnorm'].device
|
152 |
+
loss_shape_weighted_list = [torch.zeros((1), device=device).mean()]
|
153 |
+
if 'shape_options' in weight_dict_ref.keys():
|
154 |
+
for ind_sp, sp in enumerate(weight_dict_ref['shape_options']):
|
155 |
+
weight_sp = weight_dict_ref['shape'][ind_sp]
|
156 |
+
# self.logscale_part_list = ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l']
|
157 |
+
if sp == 'smal':
|
158 |
+
loss_shape_tmp = self.shape_prior(output_ref['betas'])
|
159 |
+
elif sp == 'limbs':
|
160 |
+
loss_shape_tmp = torch.mean((output_ref['betas_limbs'])**2)
|
161 |
+
elif sp == 'limbs7':
|
162 |
+
limb_coeffs_list = [0.01, 1, 0.1, 1, 1, 0.1, 2]
|
163 |
+
limb_coeffs = torch.tensor(limb_coeffs_list).to(torch.float32).to(target_dict['tpts'].device)
|
164 |
+
loss_shape_tmp = torch.mean((output_ref['betas_limbs'] * limb_coeffs[None, :])**2)
|
165 |
+
else:
|
166 |
+
raise NotImplementedError
|
167 |
+
loss_shape_weighted_list.append(weight_sp * loss_shape_tmp)
|
168 |
+
loss_shape_weighted = torch.stack((loss_shape_weighted_list)).sum()
|
169 |
+
|
170 |
+
|
171 |
+
|
172 |
+
|
173 |
+
|
174 |
+
# 3D loss for dogs for which we have a unity model or toy figure
|
175 |
+
loss_dict_temp['models3d'] = torch.zeros((1), device=device).mean().to(output_ref['betas'].device)
|
176 |
+
if 'models3d' in weight_dict_ref.keys():
|
177 |
+
if weight_dict_ref['models3d'] > 0:
|
178 |
+
assert (self.dog_betas_unity is not None)
|
179 |
+
if weight_dict_ref['models3d'] > 0:
|
180 |
+
for ind_dog in range(target_dict['breed_index'].shape[0]):
|
181 |
+
breed_index = np.asscalar(target_dict['breed_index'][ind_dog].detach().cpu().numpy())
|
182 |
+
if breed_index in self.dog_betas_unity.keys():
|
183 |
+
betas_target = self.dog_betas_unity[breed_index][:output_ref['betas'].shape[1]].to(output_ref['betas'].device)
|
184 |
+
betas_output = output_ref['betas'][ind_dog, :]
|
185 |
+
betas_limbs_output = output_ref['betas_limbs'][ind_dog, :]
|
186 |
+
loss_dict_temp['models3d'] += ((betas_limbs_output**2).sum() + ((betas_output-betas_target)**2).sum()) / (output_ref['betas'].shape[1] + output_ref['betas_limbs'].shape[1])
|
187 |
+
else:
|
188 |
+
weight_dict_ref['models3d'] = 0.0
|
189 |
+
else:
|
190 |
+
weight_dict_ref['models3d'] = 0.0
|
191 |
+
|
192 |
+
|
193 |
+
|
194 |
+
|
195 |
+
|
196 |
+
|
197 |
+
|
198 |
+
|
199 |
+
|
200 |
+
|
201 |
+
|
202 |
+
# weight the losses
|
203 |
+
loss = torch.zeros((1)).mean().to(device=output_ref['keyp_2d'].device, dtype=output_ref['keyp_2d'].dtype)
|
204 |
+
loss_dict = {}
|
205 |
+
for loss_name in weight_dict_ref.keys():
|
206 |
+
if not loss_name in ['shape', 'shape_options']:
|
207 |
+
if weight_dict_ref[loss_name] > 0:
|
208 |
+
loss_weighted = loss_dict_temp[loss_name] * weight_dict_ref[loss_name]
|
209 |
+
loss_dict[loss_name] = loss_weighted.item()
|
210 |
+
loss += loss_weighted
|
211 |
+
loss += loss_shape_weighted
|
212 |
+
loss_dict['loss'] = loss.item()
|
213 |
+
|
214 |
+
return loss, loss_dict
|
215 |
+
|
216 |
+
|
src/combined_model/loss_image_to_3d_withbreedrel.py
ADDED
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import pickle as pkl
|
6 |
+
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', 'src'))
|
10 |
+
# from priors.pose_prior_35 import Prior
|
11 |
+
# from priors.tiger_pose_prior.tiger_pose_prior import GaussianMixturePrior
|
12 |
+
from priors.normalizing_flow_prior.normalizing_flow_prior import NormalizingFlowPrior
|
13 |
+
from priors.shape_prior import ShapePrior
|
14 |
+
from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, batch_rot2aa
|
15 |
+
# from configs.SMAL_configs import SMAL_MODEL_DATA_PATH, UNITY_SMAL_SHAPE_PRIOR_DOGS, SMAL_MODEL_TYPE
|
16 |
+
from configs.SMAL_configs import SMAL_MODEL_CONFIG
|
17 |
+
|
18 |
+
from priors.helper_3dcgmodel_loss import load_dog_betas_for_3dcgmodel_loss
|
19 |
+
from combined_model.loss_utils.loss_utils_gc import calculate_plane_errors_batch
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
class Loss(torch.nn.Module):
|
24 |
+
def __init__(self, smal_model_type, data_info, nf_version=None):
|
25 |
+
super(Loss, self).__init__()
|
26 |
+
self.criterion_regr = torch.nn.MSELoss() # takes the mean
|
27 |
+
self.criterion_class = torch.nn.CrossEntropyLoss()
|
28 |
+
self.data_info = data_info
|
29 |
+
self.register_buffer('keypoint_weights', torch.tensor(data_info.keypoint_weights)[None, :])
|
30 |
+
self.l_anchor = None
|
31 |
+
self.l_pos = None
|
32 |
+
self.l_neg = None
|
33 |
+
self.smal_model_type = smal_model_type
|
34 |
+
self.smal_model_data_path = SMAL_MODEL_CONFIG[self.smal_model_type]['smal_model_data_path']
|
35 |
+
self.unity_smal_shape_prior_dogs = SMAL_MODEL_CONFIG[self.smal_model_type]['unity_smal_shape_prior_dogs']
|
36 |
+
|
37 |
+
if nf_version is not None:
|
38 |
+
self.normalizing_flow_pose_prior = NormalizingFlowPrior(nf_version=nf_version)
|
39 |
+
self.shape_prior = ShapePrior(self.smal_model_data_path) # here we just need mean and cov
|
40 |
+
self.criterion_triplet = torch.nn.TripletMarginLoss(margin=1)
|
41 |
+
|
42 |
+
# load 3d data for the unity dogs (an optional shape prior for 11 breeds)
|
43 |
+
if self.unity_smal_shape_prior_dogs is not None:
|
44 |
+
self.dog_betas_unity = load_dog_betas_for_3dcgmodel_loss(self.unity_smal_shape_prior_dogs, self.smal_model_type)
|
45 |
+
else:
|
46 |
+
self.dog_betas_unity = None
|
47 |
+
|
48 |
+
remeshing_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data_remeshed/uniform_surface_sampling/my_smpl_39dogsnorm_Jr_4_dog_remesh4000_info.pkl'
|
49 |
+
with open(remeshing_path, 'rb') as fp:
|
50 |
+
self.remeshing_dict = pkl.load(fp)
|
51 |
+
self.remeshing_relevant_faces = torch.tensor(self.remeshing_dict['smal_faces'][self.remeshing_dict['faceid_closest']], dtype=torch.long)
|
52 |
+
self.remeshing_relevant_barys = torch.tensor(self.remeshing_dict['barys_closest'], dtype=torch.float32)
|
53 |
+
|
54 |
+
|
55 |
+
def prepare_anchor_pos_neg(self, batch_size, device):
|
56 |
+
l0 = np.arange(0, batch_size, 2)
|
57 |
+
l_anchor = []
|
58 |
+
l_pos = []
|
59 |
+
l_neg = []
|
60 |
+
for ind in l0:
|
61 |
+
xx = set(np.arange(0, batch_size))
|
62 |
+
xx.discard(ind)
|
63 |
+
xx.discard(ind+1)
|
64 |
+
for ind2 in xx:
|
65 |
+
if ind2 % 2 == 0:
|
66 |
+
l_anchor.append(ind)
|
67 |
+
l_pos.append(ind + 1)
|
68 |
+
else:
|
69 |
+
l_anchor.append(ind + 1)
|
70 |
+
l_pos.append(ind)
|
71 |
+
l_neg.append(ind2)
|
72 |
+
self.l_anchor = torch.Tensor(l_anchor).to(torch.int64).to(device)
|
73 |
+
self.l_pos = torch.Tensor(l_pos).to(torch.int64).to(device)
|
74 |
+
self.l_neg = torch.Tensor(l_neg).to(torch.int64).to(device)
|
75 |
+
return
|
76 |
+
|
77 |
+
|
78 |
+
def forward(self, output_reproj, target_dict, weight_dict=None):
|
79 |
+
|
80 |
+
# output_reproj: ['vertices_smal', 'keyp_3d', 'keyp_2d', 'silh_image']
|
81 |
+
# target_dict: ['index', 'center', 'scale', 'pts', 'tpts', 'target_weight']
|
82 |
+
batch_size = output_reproj['keyp_2d'].shape[0]
|
83 |
+
device = output_reproj['keyp_2d'].device
|
84 |
+
|
85 |
+
# loss on reprojected keypoints
|
86 |
+
output_kp_resh = (output_reproj['keyp_2d']).reshape((-1, 2))
|
87 |
+
target_kp_resh = (target_dict['tpts'][:, :, :2] / 64. * (256. - 1)).reshape((-1, 2))
|
88 |
+
weights_resh = target_dict['tpts'][:, :, 2].reshape((-1))
|
89 |
+
keyp_w_resh = self.keypoint_weights.repeat((batch_size, 1)).reshape((-1))
|
90 |
+
loss_keyp = ((((output_kp_resh - target_kp_resh)[weights_resh>0]**2).sum(axis=1).sqrt()*weights_resh[weights_resh>0])*keyp_w_resh[weights_resh>0]).sum() / \
|
91 |
+
max((weights_resh[weights_resh>0]*keyp_w_resh[weights_resh>0]).sum(), 1e-5)
|
92 |
+
|
93 |
+
# loss on reprojected silhouette
|
94 |
+
assert output_reproj['silh'].shape == (target_dict['silh'][:, None, :, :]).shape
|
95 |
+
silh_loss_type = 'default'
|
96 |
+
if silh_loss_type == 'default':
|
97 |
+
with torch.no_grad():
|
98 |
+
thr_silh = 20
|
99 |
+
diff = torch.norm(output_kp_resh - target_kp_resh, dim=1)
|
100 |
+
diff_x = diff.reshape((batch_size, -1))
|
101 |
+
weights_resh_x = weights_resh.reshape((batch_size, -1))
|
102 |
+
unweighted_kp_mean_dist = (diff_x * weights_resh_x).sum(dim=1) / ((weights_resh_x).sum(dim=1)+1e-6)
|
103 |
+
loss_silh_bs = ((output_reproj['silh'] - target_dict['silh'][:, None, :, :]) ** 2).sum(axis=3).sum(axis=2).sum(axis=1) / (output_reproj['silh'].shape[2]*output_reproj['silh'].shape[3])
|
104 |
+
loss_silh = loss_silh_bs[unweighted_kp_mean_dist<thr_silh].sum() / batch_size
|
105 |
+
else:
|
106 |
+
print('silh_loss_type: ' + silh_loss_type)
|
107 |
+
raise ValueError
|
108 |
+
|
109 |
+
# shape regularization
|
110 |
+
# 'smal': loss on betas (pca coefficients), betas should be close to 0
|
111 |
+
# 'limbs...' loss on selected betas_limbs
|
112 |
+
loss_shape_weighted_list = [torch.zeros((1), device=device).mean().to(output_reproj['keyp_2d'].device)]
|
113 |
+
for ind_sp, sp in enumerate(weight_dict['shape_options']):
|
114 |
+
weight_sp = weight_dict['shape'][ind_sp]
|
115 |
+
# self.logscale_part_list = ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l']
|
116 |
+
if sp == 'smal':
|
117 |
+
loss_shape_tmp = self.shape_prior(output_reproj['betas'])
|
118 |
+
elif sp == 'limbs':
|
119 |
+
loss_shape_tmp = torch.mean((output_reproj['betas_limbs'])**2)
|
120 |
+
elif sp == 'limbs7':
|
121 |
+
limb_coeffs_list = [0.01, 1, 0.1, 1, 1, 0.1, 2]
|
122 |
+
limb_coeffs = torch.tensor(limb_coeffs_list).to(torch.float32).to(target_dict['tpts'].device)
|
123 |
+
loss_shape_tmp = torch.mean((output_reproj['betas_limbs'] * limb_coeffs[None, :])**2)
|
124 |
+
else:
|
125 |
+
raise NotImplementedError
|
126 |
+
loss_shape_weighted_list.append(weight_sp * loss_shape_tmp)
|
127 |
+
loss_shape_weighted = torch.stack((loss_shape_weighted_list)).sum()
|
128 |
+
|
129 |
+
# 3D loss for dogs for which we have a unity model or toy figure
|
130 |
+
loss_models3d = torch.zeros((1), device=device).mean().to(output_reproj['betas'].device)
|
131 |
+
if 'models3d' in weight_dict.keys():
|
132 |
+
if weight_dict['models3d'] > 0:
|
133 |
+
assert (self.dog_betas_unity is not None)
|
134 |
+
if weight_dict['models3d'] > 0:
|
135 |
+
for ind_dog in range(target_dict['breed_index'].shape[0]):
|
136 |
+
breed_index = np.asscalar(target_dict['breed_index'][ind_dog].detach().cpu().numpy())
|
137 |
+
if breed_index in self.dog_betas_unity.keys():
|
138 |
+
betas_target = self.dog_betas_unity[breed_index][:output_reproj['betas'].shape[1]].to(output_reproj['betas'].device)
|
139 |
+
betas_output = output_reproj['betas'][ind_dog, :]
|
140 |
+
betas_limbs_output = output_reproj['betas_limbs'][ind_dog, :]
|
141 |
+
loss_models3d += ((betas_limbs_output**2).sum() + ((betas_output-betas_target)**2).sum()) / (output_reproj['betas'].shape[1] + output_reproj['betas_limbs'].shape[1])
|
142 |
+
else:
|
143 |
+
weight_dict['models3d'] = 0.0
|
144 |
+
else:
|
145 |
+
weight_dict['models3d'] = 0.0
|
146 |
+
|
147 |
+
# shape resularization loss on shapedirs
|
148 |
+
# -> in the current version shapedirs are kept fixed, so we don't need those losses
|
149 |
+
if weight_dict['shapedirs'] > 0:
|
150 |
+
raise NotImplementedError
|
151 |
+
else:
|
152 |
+
loss_shapedirs = torch.zeros((1), device=device).mean().to(output_reproj['betas'].device)
|
153 |
+
|
154 |
+
# prior on back joints (not used in cvpr 2022 paper)
|
155 |
+
# -> elementwise MSE loss on all 6 coefficients of 6d rotation representation
|
156 |
+
if 'pose_0' in weight_dict.keys():
|
157 |
+
if weight_dict['pose_0'] > 0:
|
158 |
+
pred_pose_rot6d = output_reproj['pose_rot6d']
|
159 |
+
w_rj_np = np.zeros((pred_pose_rot6d.shape[1]))
|
160 |
+
w_rj_np[[2, 3, 4, 5]] = 1.0 # back
|
161 |
+
w_rj = torch.tensor(w_rj_np).to(torch.float32).to(pred_pose_rot6d.device)
|
162 |
+
zero_rot = torch.tensor([1, 0, 0, 1, 0, 0]).to(pred_pose_rot6d.device).to(torch.float32)[None, None, :].repeat((batch_size, pred_pose_rot6d.shape[1], 1))
|
163 |
+
loss_pose = self.criterion_regr(pred_pose_rot6d*w_rj[None, :, None], zero_rot*w_rj[None, :, None])
|
164 |
+
else:
|
165 |
+
loss_pose = torch.zeros((1), device=device).mean()
|
166 |
+
|
167 |
+
# pose prior
|
168 |
+
# -> we did experiment with different pose priors, for example:
|
169 |
+
# * similart to SMALify (https://github.com/benjiebob/SMALify/blob/master/smal_fitter/smal_fitter.py,
|
170 |
+
# https://github.com/benjiebob/SMALify/blob/master/smal_fitter/priors/pose_prior_35.py)
|
171 |
+
# * vae
|
172 |
+
# * normalizing flow pose prior
|
173 |
+
# -> our cvpr 2022 paper uses the normalizing flow pose prior as implemented below
|
174 |
+
if 'poseprior' in weight_dict.keys():
|
175 |
+
if weight_dict['poseprior'] > 0:
|
176 |
+
pred_pose_rot6d = output_reproj['pose_rot6d']
|
177 |
+
pred_pose = rot6d_to_rotmat(pred_pose_rot6d.reshape((-1, 6))).reshape((batch_size, -1, 3, 3))
|
178 |
+
if 'normalizing_flow_tiger' in weight_dict['poseprior_options']:
|
179 |
+
if output_reproj['normflow_z'] is not None:
|
180 |
+
loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss_from_z(output_reproj['normflow_z'], type='square')
|
181 |
+
else:
|
182 |
+
loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss(pred_pose_rot6d, type='square')
|
183 |
+
elif 'normalizing_flow_tiger_logprob' in weight_dict['poseprior_options']:
|
184 |
+
if output_reproj['normflow_z'] is not None:
|
185 |
+
loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss_from_z(output_reproj['normflow_z'], type='neg_log_prob')
|
186 |
+
else:
|
187 |
+
loss_poseprior = self.normalizing_flow_pose_prior.calculate_loss(pred_pose_rot6d, type='neg_log_prob')
|
188 |
+
else:
|
189 |
+
raise NotImplementedError
|
190 |
+
else:
|
191 |
+
loss_poseprior = torch.zeros((1), device=device).mean()
|
192 |
+
else:
|
193 |
+
weight_dict['poseprior'] = 0
|
194 |
+
loss_poseprior = torch.zeros((1), device=device).mean()
|
195 |
+
|
196 |
+
# add a prior which penalizes side-movement angles for legs
|
197 |
+
if 'poselegssidemovement' in weight_dict.keys():
|
198 |
+
if weight_dict['poselegssidemovement'] > 0:
|
199 |
+
use_pose_legs_side_loss = True
|
200 |
+
else:
|
201 |
+
use_pose_legs_side_loss = False
|
202 |
+
else:
|
203 |
+
use_pose_legs_side_loss = False
|
204 |
+
if use_pose_legs_side_loss:
|
205 |
+
leg_indices_right = np.asarray([7, 8, 9, 10, 17, 18, 19, 20]) # front, back
|
206 |
+
leg_indices_left = np.asarray([11, 12, 13, 14, 21, 22, 23, 24]) # front, back
|
207 |
+
vec = torch.zeros((3, 1)).to(device=pred_pose.device, dtype=pred_pose.dtype)
|
208 |
+
vec[2] = -1
|
209 |
+
x0_rotmat = pred_pose
|
210 |
+
x0_rotmat_legs_left = x0_rotmat[:, leg_indices_left, :, :]
|
211 |
+
x0_rotmat_legs_right = x0_rotmat[:, leg_indices_right, :, :]
|
212 |
+
x0_legs_left = x0_rotmat_legs_left.reshape((-1, 3, 3))@vec
|
213 |
+
x0_legs_right = x0_rotmat_legs_right.reshape((-1, 3, 3))@vec
|
214 |
+
eps=0 # 1e-7
|
215 |
+
# use the component of the vector which points to the side
|
216 |
+
loss_poselegssidemovement = (x0_legs_left[:, 1]**2).mean() + (x0_legs_right[:, 1]**2).mean()
|
217 |
+
else:
|
218 |
+
loss_poselegssidemovement = torch.zeros((1), device=device).mean()
|
219 |
+
weight_dict['poselegssidemovement'] = 0
|
220 |
+
|
221 |
+
# dog breed classification loss
|
222 |
+
dog_breed_gt = target_dict['breed_index']
|
223 |
+
dog_breed_pred = output_reproj['dog_breed']
|
224 |
+
loss_class = self.criterion_class(dog_breed_pred, dog_breed_gt)
|
225 |
+
|
226 |
+
# dog breed relationship loss
|
227 |
+
# -> we did experiment with many other options, but none was significantly better
|
228 |
+
if '4' in weight_dict['breed_options']: # we have pairs of dogs of the same breed
|
229 |
+
if weight_dict['breed'] > 0:
|
230 |
+
assert output_reproj['dog_breed'].shape[0] == 12
|
231 |
+
# assert weight_dict['breed'] > 0
|
232 |
+
z = output_reproj['z']
|
233 |
+
# go through all pairs and compare them to each other sample
|
234 |
+
if self.l_anchor is None:
|
235 |
+
self.prepare_anchor_pos_neg(batch_size, z.device)
|
236 |
+
anchor = torch.index_select(z, 0, self.l_anchor)
|
237 |
+
positive = torch.index_select(z, 0, self.l_pos)
|
238 |
+
negative = torch.index_select(z, 0, self.l_neg)
|
239 |
+
loss_breed = self.criterion_triplet(anchor, positive, negative)
|
240 |
+
else:
|
241 |
+
loss_breed = torch.zeros((1), device=device).mean()
|
242 |
+
else:
|
243 |
+
loss_breed = torch.zeros((1), device=device).mean()
|
244 |
+
|
245 |
+
# regularizarion for focal length
|
246 |
+
loss_flength_near_mean = torch.mean(output_reproj['flength']**2)
|
247 |
+
loss_flength = loss_flength_near_mean
|
248 |
+
|
249 |
+
# bodypart segmentation loss
|
250 |
+
if 'partseg' in weight_dict.keys():
|
251 |
+
if weight_dict['partseg'] > 0:
|
252 |
+
raise NotImplementedError
|
253 |
+
else:
|
254 |
+
loss_partseg = torch.zeros((1), device=device).mean()
|
255 |
+
else:
|
256 |
+
weight_dict['partseg'] = 0
|
257 |
+
loss_partseg = torch.zeros((1), device=device).mean()
|
258 |
+
|
259 |
+
|
260 |
+
# NEW: ground contact loss for main network
|
261 |
+
keep_smal_mesh = False
|
262 |
+
if 'gc_plane' in weight_dict.keys():
|
263 |
+
if weight_dict['gc_plane'] > 0:
|
264 |
+
if keep_smal_mesh:
|
265 |
+
target_gc_class = target_dict['gc'][:, :, 0]
|
266 |
+
gc_errors_plane = calculate_plane_errors_batch(output_reproj['vertices_smal'], target_gc_class, target_dict['has_gc'], target_dict['has_gc_is_touching'])
|
267 |
+
loss_gc_plane = torch.mean(gc_errors_plane)
|
268 |
+
else: # use a uniformly sampled mesh
|
269 |
+
target_gc_class = target_dict['gc'][:, :, 0]
|
270 |
+
device = output_reproj['vertices_smal'].device
|
271 |
+
remeshing_relevant_faces = self.remeshing_relevant_faces.to(device)
|
272 |
+
remeshing_relevant_barys = self.remeshing_relevant_barys.to(device)
|
273 |
+
|
274 |
+
bs = output_reproj['vertices_smal'].shape[0]
|
275 |
+
# verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, output_reproj['vertices_smal'][:, self.remeshing_relevant_faces])
|
276 |
+
# sel_verts_comparison = output_reproj['vertices_smal'][:, self.remeshing_relevant_faces]
|
277 |
+
# verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts_comparison)
|
278 |
+
sel_verts = torch.index_select(output_reproj['vertices_smal'], dim=1, index=remeshing_relevant_faces.reshape((-1))).reshape((bs, remeshing_relevant_faces.shape[0], 3, 3))
|
279 |
+
verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts)
|
280 |
+
target_gc_class_remeshed = torch.einsum('ij,aij->ai', remeshing_relevant_barys, target_gc_class[:, self.remeshing_relevant_faces].to(device=device, dtype=torch.float32))
|
281 |
+
target_gc_class_remeshed_prep = torch.round(target_gc_class_remeshed).to(torch.long)
|
282 |
+
gc_errors_plane, gc_errors_under_plane = calculate_plane_errors_batch(verts_remeshed, target_gc_class_remeshed_prep, target_dict['has_gc'], target_dict['has_gc_is_touching'])
|
283 |
+
loss_gc_plane = torch.mean(gc_errors_plane)
|
284 |
+
loss_gc_belowplane = torch.mean(gc_errors_under_plane)
|
285 |
+
# loss_dict_temp['gc_plane'] = torch.mean(gc_errors_plane)
|
286 |
+
else:
|
287 |
+
loss_gc_plane = torch.zeros((1), device=device).mean()
|
288 |
+
loss_gc_belowplane = torch.zeros((1), device=device).mean()
|
289 |
+
else:
|
290 |
+
loss_gc_plane = torch.zeros((1), device=device).mean()
|
291 |
+
loss_gc_belowplane = torch.zeros((1), device=device).mean()
|
292 |
+
weight_dict['gc_plane'] = 0
|
293 |
+
weight_dict['gc_belowplane'] = 0
|
294 |
+
|
295 |
+
|
296 |
+
|
297 |
+
# weight and combine losses
|
298 |
+
loss_keyp_weighted = loss_keyp * weight_dict['keyp']
|
299 |
+
loss_silh_weighted = loss_silh * weight_dict['silh']
|
300 |
+
loss_shapedirs_weighted = loss_shapedirs * weight_dict['shapedirs']
|
301 |
+
loss_pose_weighted = loss_pose * weight_dict['pose_0']
|
302 |
+
loss_class_weighted = loss_class * weight_dict['class']
|
303 |
+
loss_breed_weighted = loss_breed * weight_dict['breed']
|
304 |
+
loss_flength_weighted = loss_flength * weight_dict['flength']
|
305 |
+
loss_poseprior_weighted = loss_poseprior * weight_dict['poseprior']
|
306 |
+
loss_partseg_weighted = loss_partseg * weight_dict['partseg']
|
307 |
+
loss_models3d_weighted = loss_models3d * weight_dict['models3d']
|
308 |
+
loss_poselegssidemovement_weighted = loss_poselegssidemovement * weight_dict['poselegssidemovement']
|
309 |
+
|
310 |
+
loss_gc_plane_weighted = loss_gc_plane * weight_dict['gc_plane']
|
311 |
+
loss_gc_belowplane_weighted = loss_gc_belowplane * weight_dict['gc_belowplane']
|
312 |
+
|
313 |
+
|
314 |
+
####################################################################################################
|
315 |
+
loss = loss_keyp_weighted + loss_silh_weighted + loss_shape_weighted + loss_pose_weighted + loss_class_weighted + \
|
316 |
+
loss_shapedirs_weighted + loss_breed_weighted + loss_flength_weighted + loss_poseprior_weighted + \
|
317 |
+
loss_partseg_weighted + loss_models3d_weighted + loss_poselegssidemovement_weighted + \
|
318 |
+
loss_gc_plane_weighted + loss_gc_belowplane_weighted
|
319 |
+
####################################################################################################
|
320 |
+
|
321 |
+
loss_dict = {'loss': loss.item(),
|
322 |
+
'loss_keyp_weighted': loss_keyp_weighted.item(), \
|
323 |
+
'loss_silh_weighted': loss_silh_weighted.item(), \
|
324 |
+
'loss_shape_weighted': loss_shape_weighted.item(), \
|
325 |
+
'loss_shapedirs_weighted': loss_shapedirs_weighted.item(), \
|
326 |
+
'loss_pose0_weighted': loss_pose_weighted.item(), \
|
327 |
+
'loss_class_weighted': loss_class_weighted.item(), \
|
328 |
+
'loss_breed_weighted': loss_breed_weighted.item(), \
|
329 |
+
'loss_flength_weighted': loss_flength_weighted.item(), \
|
330 |
+
'loss_poseprior_weighted': loss_poseprior_weighted.item(), \
|
331 |
+
'loss_partseg_weighted': loss_partseg_weighted.item(), \
|
332 |
+
'loss_models3d_weighted': loss_models3d_weighted.item(), \
|
333 |
+
'loss_poselegssidemovement_weighted': loss_poselegssidemovement_weighted.item(), \
|
334 |
+
'loss_gc_plane_weighted': loss_gc_plane_weighted.item(), \
|
335 |
+
'loss_gc_belowplane_weighted': loss_gc_belowplane_weighted.item()
|
336 |
+
}
|
337 |
+
|
338 |
+
return loss, loss_dict
|
339 |
+
|
340 |
+
|
341 |
+
|
342 |
+
|
src/combined_model/loss_utils/loss_arap.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
# code from https://raw.githubusercontent.com/yufu-wang/aves/main/optimization/loss_arap.py
|
4 |
+
|
5 |
+
|
6 |
+
class Arap_Loss():
|
7 |
+
'''
|
8 |
+
Pytorch implementaion: As-rigid-as-possible loss class
|
9 |
+
|
10 |
+
'''
|
11 |
+
|
12 |
+
def __init__(self, meshes, device='cpu', vertex_w=None):
|
13 |
+
|
14 |
+
with torch.no_grad(): # new nadine
|
15 |
+
|
16 |
+
self.device = device
|
17 |
+
self.bn = len(meshes)
|
18 |
+
|
19 |
+
# get lapacian cotangent matrix
|
20 |
+
L = self.get_laplacian_cot(meshes)
|
21 |
+
self.wij = L.values().clone()
|
22 |
+
self.wij[self.wij<0] = 0.
|
23 |
+
|
24 |
+
# get ajacency matrix
|
25 |
+
V = meshes.num_verts_per_mesh().sum()
|
26 |
+
edges_packed = meshes.edges_packed()
|
27 |
+
e0, e1 = edges_packed.unbind(1)
|
28 |
+
idx01 = torch.stack([e0, e1], dim=1)
|
29 |
+
idx10 = torch.stack([e1, e0], dim=1)
|
30 |
+
idx = torch.cat([idx01, idx10], dim=0).t()
|
31 |
+
|
32 |
+
ones = torch.ones(idx.shape[1], dtype=torch.float32).to(device)
|
33 |
+
A = torch.sparse.FloatTensor(idx, ones, (V, V))
|
34 |
+
self.deg = torch.sparse.sum(A, dim=1).to_dense().long()
|
35 |
+
self.idx = self.sort_idx(idx)
|
36 |
+
|
37 |
+
# get edges of default mesh
|
38 |
+
self.eij = self.get_edges(meshes)
|
39 |
+
|
40 |
+
# get per vertex regularization strength
|
41 |
+
self.vertex_w = vertex_w
|
42 |
+
|
43 |
+
|
44 |
+
def __call__(self, new_meshes):
|
45 |
+
new_meshes._compute_packed()
|
46 |
+
|
47 |
+
optimal_R = self.step_1(new_meshes)
|
48 |
+
arap_loss = self.step_2(optimal_R, new_meshes)
|
49 |
+
return arap_loss
|
50 |
+
|
51 |
+
|
52 |
+
def step_1(self, new_meshes):
|
53 |
+
bn = self.bn
|
54 |
+
eij = self.eij.view(bn, -1, 3).cpu()
|
55 |
+
|
56 |
+
with torch.no_grad():
|
57 |
+
eij_ = self.get_edges(new_meshes)
|
58 |
+
|
59 |
+
eij_ = eij_.view(bn, -1, 3).cpu()
|
60 |
+
wij = self.wij.view(bn, -1).cpu()
|
61 |
+
|
62 |
+
deg_1 = self.deg.view(bn, -1)[0].cpu() # assuming same topology
|
63 |
+
S = torch.zeros([bn, len(deg_1), 3, 3])
|
64 |
+
for i in range(len(deg_1)):
|
65 |
+
start, end = deg_1[:i].sum(), deg_1[:i+1].sum()
|
66 |
+
P = eij[:, start : end]
|
67 |
+
P_ = eij_[:, start : end]
|
68 |
+
D = wij[:, start : end]
|
69 |
+
D = torch.diag_embed(D)
|
70 |
+
S[:, i] = P.transpose(-2,-1) @ D @ P_
|
71 |
+
|
72 |
+
S = S.view(-1, 3, 3)
|
73 |
+
|
74 |
+
u, _, v = torch.svd(S)
|
75 |
+
R = v @ u.transpose(-2, -1)
|
76 |
+
det = torch.det(R)
|
77 |
+
|
78 |
+
u[det<0, :, -1] *= -1
|
79 |
+
R = v @ u.transpose(-2, -1)
|
80 |
+
R = R.to(self.device)
|
81 |
+
|
82 |
+
return R
|
83 |
+
|
84 |
+
|
85 |
+
def step_2(self, R, new_meshes):
|
86 |
+
R = torch.repeat_interleave(R, self.deg, dim=0)
|
87 |
+
Reij = R @ self.eij.unsqueeze(2)
|
88 |
+
Reij = Reij.squeeze()
|
89 |
+
|
90 |
+
eij_ = self.get_edges(new_meshes)
|
91 |
+
arap_loss = self.wij * (eij_ - Reij).norm(dim=1)
|
92 |
+
|
93 |
+
if self.vertex_w is not None:
|
94 |
+
vertex_w = torch.repeat_interleave(self.vertex_w, self.deg, dim=0)
|
95 |
+
arap_loss = arap_loss * vertex_w
|
96 |
+
|
97 |
+
arap_loss = arap_loss.sum() / self.bn
|
98 |
+
|
99 |
+
return arap_loss
|
100 |
+
|
101 |
+
|
102 |
+
def get_edges(self, meshes):
|
103 |
+
verts_packed = meshes.verts_packed()
|
104 |
+
vi = torch.repeat_interleave(verts_packed, self.deg, dim=0)
|
105 |
+
vj = verts_packed[self.idx[1]]
|
106 |
+
eij = vi - vj
|
107 |
+
return eij
|
108 |
+
|
109 |
+
|
110 |
+
def sort_idx(self, idx):
|
111 |
+
_, order = (idx[0] + idx[1]*1e-6).sort()
|
112 |
+
|
113 |
+
return idx[:, order]
|
114 |
+
|
115 |
+
|
116 |
+
def get_laplacian_cot(self, meshes):
|
117 |
+
'''
|
118 |
+
Routine modified from :
|
119 |
+
pytorch3d/loss/mesh_laplacian_smoothing.py
|
120 |
+
'''
|
121 |
+
verts_packed = meshes.verts_packed()
|
122 |
+
faces_packed = meshes.faces_packed()
|
123 |
+
V, F = verts_packed.shape[0], faces_packed.shape[0]
|
124 |
+
|
125 |
+
face_verts = verts_packed[faces_packed]
|
126 |
+
v0, v1, v2 = face_verts[:,0], face_verts[:,1], face_verts[:,2]
|
127 |
+
|
128 |
+
A = (v1-v2).norm(dim=1)
|
129 |
+
B = (v0-v2).norm(dim=1)
|
130 |
+
C = (v0-v1).norm(dim=1)
|
131 |
+
|
132 |
+
s = 0.5 * (A+B+C)
|
133 |
+
area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=1e-12).sqrt()
|
134 |
+
|
135 |
+
A2, B2, C2 = A * A, B * B, C * C
|
136 |
+
cota = (B2 + C2 - A2) / area
|
137 |
+
cotb = (A2 + C2 - B2) / area
|
138 |
+
cotc = (A2 + B2 - C2) / area
|
139 |
+
cot = torch.stack([cota, cotb, cotc], dim=1)
|
140 |
+
cot /= 4.0
|
141 |
+
|
142 |
+
ii = faces_packed[:, [1,2,0]]
|
143 |
+
jj = faces_packed[:, [2,0,1]]
|
144 |
+
idx = torch.stack([ii, jj], dim=0).view(2, F*3)
|
145 |
+
L = torch.sparse.FloatTensor(idx, cot.view(-1), (V, V))
|
146 |
+
L += L.t()
|
147 |
+
L = L.coalesce()
|
148 |
+
L /= 2.0 # normalized according to arap paper
|
149 |
+
|
150 |
+
return L
|
151 |
+
|
152 |
+
|
153 |
+
|
src/combined_model/loss_utils/loss_laplacian_mesh_comparison.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# code from: https://github.com/chaneyddtt/Coarse-to-fine-3D-Animal/blob/main/util/loss_utils.py
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
# Laplacian loss, calculate the Laplacian coordiante of both coarse and refined vertices and then compare the difference
|
8 |
+
class LaplacianCTF(torch.nn.Module):
|
9 |
+
def __init__(self, adjmat, device):
|
10 |
+
'''
|
11 |
+
Args:
|
12 |
+
adjmat: adjacency matrix of the input graph data
|
13 |
+
device: specify device for training
|
14 |
+
'''
|
15 |
+
super(LaplacianCTF, self).__init__()
|
16 |
+
adjmat.data = np.ones_like(adjmat.data)
|
17 |
+
adjmat = torch.from_numpy(adjmat.todense()).float()
|
18 |
+
dg = torch.sum(adjmat, dim=-1)
|
19 |
+
dg_m = torch.diag(dg)
|
20 |
+
ls = dg_m - adjmat
|
21 |
+
self.ls = ls.unsqueeze(0).to(device) # Should be normalized by the diagonal elements according to
|
22 |
+
# the origial definition, this one also works fine.
|
23 |
+
|
24 |
+
def forward(self, verts_pred, verts_gt, smooth=False):
|
25 |
+
verts_pred = torch.matmul(self.ls, verts_pred)
|
26 |
+
verts_gt = torch.matmul(self.ls, verts_gt)
|
27 |
+
loss = torch.norm(verts_pred - verts_gt, dim=-1).mean()
|
28 |
+
if smooth:
|
29 |
+
loss_smooth = torch.norm(torch.matmul(self.ls, verts_pred), dim=-1).mean()
|
30 |
+
return loss, loss_smooth
|
31 |
+
return loss, None
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
#
|
37 |
+
# read the adjacency matrix, which will used in the Laplacian regularizer
|
38 |
+
# data = np.load('./data/mesh_down_sampling_4.npz', encoding='latin1', allow_pickle=True)
|
39 |
+
# adjmat = data['A'][0]
|
40 |
+
# laplacianloss = Laplacian(adjmat, device)
|
41 |
+
#
|
42 |
+
# verts_clone = verts.detach().clone()
|
43 |
+
# loss_arap, loss_smooth = laplacianloss(verts_refine, verts_clone)
|
44 |
+
# loss_arap = args.w_arap * loss_arap
|
45 |
+
#
|
src/combined_model/loss_utils/loss_sdf.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# code from: https://github.com/chaneyddtt/Coarse-to-fine-3D-Animal/blob/main/util/loss_sdf.py
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
from scipy.ndimage import distance_transform_edt as distance
|
7 |
+
from skimage import segmentation as skimage_seg
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
|
10 |
+
|
11 |
+
def dice_loss(score, target):
|
12 |
+
# implemented from paper https://arxiv.org/pdf/1606.04797.pdf
|
13 |
+
target = target.float()
|
14 |
+
smooth = 1e-5
|
15 |
+
intersect = torch.sum(score * target)
|
16 |
+
y_sum = torch.sum(target * target)
|
17 |
+
z_sum = torch.sum(score * score)
|
18 |
+
loss = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
|
19 |
+
loss = 1 - loss
|
20 |
+
return loss
|
21 |
+
|
22 |
+
|
23 |
+
class tversky_loss(torch.nn.Module):
|
24 |
+
# implemented from https://arxiv.org/pdf/1706.05721.pdf
|
25 |
+
def __init__(self, alpha, beta):
|
26 |
+
'''
|
27 |
+
Args:
|
28 |
+
alpha: coefficient for false positive prediction
|
29 |
+
beta: coefficient for false negtive prediction
|
30 |
+
'''
|
31 |
+
super(tversky_loss, self).__init__()
|
32 |
+
self.alpha = alpha
|
33 |
+
self.beta = beta
|
34 |
+
|
35 |
+
def __call__(self, score, target):
|
36 |
+
target = target.float()
|
37 |
+
smooth = 1e-5
|
38 |
+
tp = torch.sum(score * target)
|
39 |
+
fn = torch.sum(target * (1 - score))
|
40 |
+
fp = torch.sum((1-target) * score)
|
41 |
+
loss = (tp + smooth) / (tp + self.alpha * fp + self.beta * fn + smooth)
|
42 |
+
loss = 1 - loss
|
43 |
+
return loss
|
44 |
+
|
45 |
+
|
46 |
+
def compute_sdf1_1(img_gt, out_shape):
|
47 |
+
"""
|
48 |
+
compute the normalized signed distance map of binary mask
|
49 |
+
input: segmentation, shape = (batch_size, x, y, z)
|
50 |
+
output: the Signed Distance Map (SDM)
|
51 |
+
sdf(x) = 0; x in segmentation boundary
|
52 |
+
-inf|x-y|; x in segmentation
|
53 |
+
+inf|x-y|; x out of segmentation
|
54 |
+
normalize sdf to [-1, 1]
|
55 |
+
"""
|
56 |
+
|
57 |
+
img_gt = img_gt.astype(np.uint8)
|
58 |
+
|
59 |
+
normalized_sdf = np.zeros(out_shape)
|
60 |
+
|
61 |
+
for b in range(out_shape[0]): # batch size
|
62 |
+
# ignore background
|
63 |
+
for c in range(1, out_shape[1]):
|
64 |
+
posmask = img_gt[b]
|
65 |
+
negmask = 1-posmask
|
66 |
+
posdis = distance(posmask)
|
67 |
+
negdis = distance(negmask)
|
68 |
+
boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
|
69 |
+
sdf = (negdis-np.min(negdis))/(np.max(negdis)-np.min(negdis)) - (posdis-np.min(posdis))/(np.max(posdis)-np.min(posdis))
|
70 |
+
sdf[boundary==1] = 0
|
71 |
+
normalized_sdf[b][c] = sdf
|
72 |
+
assert np.min(sdf) == -1.0, print(np.min(posdis), np.min(negdis), np.max(posdis), np.max(negdis))
|
73 |
+
assert np.max(sdf) == 1.0, print(np.min(posdis), np.min(negdis), np.max(posdis), np.max(negdis))
|
74 |
+
|
75 |
+
return normalized_sdf
|
76 |
+
|
77 |
+
|
78 |
+
def compute_sdf(img_gt, out_shape):
|
79 |
+
"""
|
80 |
+
compute the signed distance map of binary mask
|
81 |
+
input: segmentation, shape = (batch_size, x, y, z)
|
82 |
+
output: the Signed Distance Map (SDM)
|
83 |
+
sdf(x) = 0; x in segmentation boundary
|
84 |
+
-inf|x-y|; x in segmentation
|
85 |
+
+inf|x-y|; x out of segmentation
|
86 |
+
"""
|
87 |
+
|
88 |
+
img_gt = img_gt.astype(np.uint8)
|
89 |
+
|
90 |
+
gt_sdf = np.zeros(out_shape)
|
91 |
+
debug = False
|
92 |
+
for b in range(out_shape[0]): # batch size
|
93 |
+
for c in range(0, out_shape[1]):
|
94 |
+
posmask = img_gt[b]
|
95 |
+
negmask = 1-posmask
|
96 |
+
posdis = distance(posmask)
|
97 |
+
negdis = distance(negmask)
|
98 |
+
boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
|
99 |
+
sdf = negdis - posdis
|
100 |
+
sdf[boundary==1] = 0
|
101 |
+
gt_sdf[b][c] = sdf
|
102 |
+
if debug:
|
103 |
+
plt.figure()
|
104 |
+
plt.subplot(1, 2, 1), plt.imshow(img_gt[b, 0, :, :]), plt.colorbar()
|
105 |
+
plt.subplot(1, 2, 2), plt.imshow(gt_sdf[b, 0, :, :]), plt.colorbar()
|
106 |
+
plt.show()
|
107 |
+
|
108 |
+
return gt_sdf
|
109 |
+
|
110 |
+
|
111 |
+
def boundary_loss(output, gt):
|
112 |
+
"""
|
113 |
+
compute boundary loss for binary segmentation
|
114 |
+
input: outputs_soft: softmax results, shape=(b,2,x,y,z)
|
115 |
+
gt_sdf: sdf of ground truth (can be original or normalized sdf); shape=(b,2,x,y,z)
|
116 |
+
output: boundary_loss; sclar
|
117 |
+
adopted from http://proceedings.mlr.press/v102/kervadec19a/kervadec19a.pdf
|
118 |
+
"""
|
119 |
+
multipled = torch.einsum('bcxy, bcxy->bcxy', output, gt)
|
120 |
+
bd_loss = multipled.mean()
|
121 |
+
|
122 |
+
return bd_loss
|
src/combined_model/loss_utils/loss_utils.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
'''
|
7 |
+
def keyp_rep_error_l1(smpl_keyp_2d, keyp_hourglass, keyp_hourglass_scores, thr_kp=0.3):
|
8 |
+
# step 1: make sure that the hg prediction and barc are close
|
9 |
+
with torch.no_grad():
|
10 |
+
kp_weights = keyp_hourglass_scores
|
11 |
+
kp_weights[keyp_hourglass_scores<thr_kp] = 0
|
12 |
+
loss_keyp_rep = torch.mean((torch.abs((smpl_keyp_2d - keyp_hourglass)/512)).sum(dim=2)*kp_weights[:, :, 0])
|
13 |
+
return loss_keyp_rep
|
14 |
+
|
15 |
+
def keyp_rep_error(smpl_keyp_2d, keyp_hourglass, keyp_hourglass_scores, thr_kp=0.3):
|
16 |
+
# step 1: make sure that the hg prediction and barc are close
|
17 |
+
with torch.no_grad():
|
18 |
+
kp_weights = keyp_hourglass_scores
|
19 |
+
kp_weights[keyp_hourglass_scores<thr_kp] = 0
|
20 |
+
# losses['kp_reproj']['value'] = torch.mean((((smpl_keyp_2d - keyp_reproj_init)/512)**2).sum(dim=2)*kp_weights[:, :, 0])
|
21 |
+
loss_keyp_rep = torch.mean((((smpl_keyp_2d - keyp_hourglass)/512)**2).sum(dim=2)*kp_weights[:, :, 0])
|
22 |
+
return loss_keyp_rep
|
23 |
+
'''
|
24 |
+
|
25 |
+
def leg_sideway_error(optimed_pose_with_glob):
|
26 |
+
assert optimed_pose_with_glob.shape[1] == 35
|
27 |
+
leg_indices_right = np.asarray([7, 8, 9, 10, 17, 18, 19, 20]) # front, back
|
28 |
+
leg_indices_left = np.asarray([11, 12, 13, 14, 21, 22, 23, 24]) # front, back
|
29 |
+
# leg_indices_right = np.asarray([8, 9, 10, 18, 19, 20]) # front, back
|
30 |
+
# leg_indices_left = np.asarray([12, 13, 14, 22, 23, 24]) # front, back
|
31 |
+
x0_rotmat = optimed_pose_with_glob # (1, 35, 3, 3)
|
32 |
+
x0_rotmat_legs_left = x0_rotmat[:, leg_indices_left, :, :]
|
33 |
+
x0_rotmat_legs_right = x0_rotmat[:, leg_indices_right, :, :]
|
34 |
+
vec = torch.zeros((3, 1)).to(device=optimed_pose_with_glob.device, dtype=optimed_pose_with_glob.dtype)
|
35 |
+
vec[2] = -1
|
36 |
+
x0_legs_left = x0_rotmat_legs_left.reshape((-1, 3, 3))@vec
|
37 |
+
x0_legs_right = x0_rotmat_legs_right.reshape((-1, 3, 3))@vec
|
38 |
+
loss_pose_legs_side = (x0_legs_left[:, 1]**2).mean() + (x0_legs_right[:, 1]**2).mean()
|
39 |
+
return loss_pose_legs_side
|
40 |
+
|
41 |
+
|
42 |
+
def leg_torsion_error(optimed_pose_with_glob):
|
43 |
+
leg_indices_right = np.asarray([7, 8, 9, 10, 17, 18, 19, 20]) # front, back
|
44 |
+
leg_indices_left = np.asarray([11, 12, 13, 14, 21, 22, 23, 24]) # front, back
|
45 |
+
x0_rotmat = optimed_pose_with_glob # (1, 35, 3, 3)
|
46 |
+
x0_rotmat_legs_left = x0_rotmat[:, leg_indices_left, :, :]
|
47 |
+
x0_rotmat_legs_right = x0_rotmat[:, leg_indices_right, :, :]
|
48 |
+
vec_x = torch.zeros((3, 1)).to(device=optimed_pose_with_glob.device, dtype=optimed_pose_with_glob.dtype)
|
49 |
+
vec_x[0] = 1 # in x direction
|
50 |
+
x_x_legs_left = x0_rotmat_legs_left.reshape((-1, 3, 3))@vec_x
|
51 |
+
x_x_legs_right = x0_rotmat_legs_right.reshape((-1, 3, 3))@vec_x
|
52 |
+
loss_pose_legs_torsion = (x_x_legs_left[:, 1]**2).mean() + (x_x_legs_right[:, 1]**2).mean()
|
53 |
+
return loss_pose_legs_torsion
|
54 |
+
|
55 |
+
|
56 |
+
def frontleg_walkingdir_error(optimed_pose_with_glob):
|
57 |
+
# this prior should only be used for standing poses!
|
58 |
+
leg_indices_right = np.asarray([7, 8, 9, 10]) # front, back
|
59 |
+
leg_indices_left = np.asarray([11, 12, 13, 14]) # front, back
|
60 |
+
relevant_back_indices = np.asarray([1, 2, 3, 4, 5, 6]) # np.asarray([6]) # back joint in the front
|
61 |
+
x0_rotmat = optimed_pose_with_glob # (1, 35, 3, 3)
|
62 |
+
x0_rotmat_legs_left = x0_rotmat[:, leg_indices_left, :, :]
|
63 |
+
x0_rotmat_legs_right = x0_rotmat[:, leg_indices_right, :, :]
|
64 |
+
x0_rotmat_back = x0_rotmat[:, relevant_back_indices, :, :]
|
65 |
+
vec = torch.zeros((3, 1)).to(device=optimed_pose_with_glob.device, dtype=optimed_pose_with_glob.dtype)
|
66 |
+
vec[2] = -1 # vector down
|
67 |
+
x0_legs_left = x0_rotmat_legs_left.reshape((-1, 3, 3))@vec
|
68 |
+
x0_legs_right = x0_rotmat_legs_right.reshape((-1, 3, 3))@vec
|
69 |
+
x0_back = x0_rotmat_back.reshape((-1, 3, 3))@vec
|
70 |
+
loss_pose_legs_side = (x0_legs_left[:, 0]**2).mean() + (x0_legs_right[:, 0]**2).mean() + (x0_back[:, 0]**2).mean() # penalize movement to front
|
71 |
+
return loss_pose_legs_side
|
72 |
+
|
73 |
+
|
74 |
+
def tail_sideway_error(optimed_pose_with_glob):
|
75 |
+
tail_indices = np.asarray([25, 26, 27, 28, 29, 30, 31])
|
76 |
+
x0_rotmat = optimed_pose_with_glob # (1, 35, 3, 3)
|
77 |
+
x0_rotmat_tail = x0_rotmat[:, tail_indices, :, :]
|
78 |
+
vec = torch.zeros((3, 1)).to(device=optimed_pose_with_glob.device, dtype=optimed_pose_with_glob.dtype)
|
79 |
+
'''vec[2] = -1
|
80 |
+
x0_tail = x0_rotmat_tail.reshape((-1, 3, 3))@vec
|
81 |
+
loss_pose_tail_side = (x0_tail[:, 1]**2).mean()'''
|
82 |
+
vec[0] = -1
|
83 |
+
x0_tail = x0_rotmat_tail.reshape((-1, 3, 3))@vec
|
84 |
+
loss_pose_tail_side = (x0_tail[:, 1]**2).mean()
|
85 |
+
return loss_pose_tail_side
|
86 |
+
|
87 |
+
|
88 |
+
def tail_torsion_error(optimed_pose_with_glob):
|
89 |
+
tail_indices = np.asarray([25, 26, 27, 28, 29, 30, 31])
|
90 |
+
x0_rotmat = optimed_pose_with_glob # (1, 35, 3, 3)
|
91 |
+
x0_rotmat_tail = x0_rotmat[:, tail_indices, :, :]
|
92 |
+
vec_x = torch.zeros((3, 1)).to(device=optimed_pose_with_glob.device, dtype=optimed_pose_with_glob.dtype)
|
93 |
+
'''vec_x[0] = 1 # in x direction
|
94 |
+
x_x_tail = x0_rotmat_tail.reshape((-1, 3, 3))@vec_x
|
95 |
+
loss_pose_tail_torsion = (x_x_tail[:, 1]**2).mean()'''
|
96 |
+
vec_x[2] = 1 # in y direction
|
97 |
+
x_x_tail = x0_rotmat_tail.reshape((-1, 3, 3))@vec_x
|
98 |
+
loss_pose_tail_torsion = (x_x_tail[:, 1]**2).mean()
|
99 |
+
return loss_pose_tail_torsion
|
100 |
+
|
101 |
+
|
102 |
+
def spine_sideway_error(optimed_pose_with_glob):
|
103 |
+
tail_indices = np.asarray([1, 2, 3, 4, 5, 6]) # was wrong
|
104 |
+
x0_rotmat = optimed_pose_with_glob # (1, 35, 3, 3)
|
105 |
+
x0_rotmat_tail = x0_rotmat[:, tail_indices, :, :]
|
106 |
+
vec = torch.zeros((3, 1)).to(device=optimed_pose_with_glob.device, dtype=optimed_pose_with_glob.dtype)
|
107 |
+
vec[0] = -1
|
108 |
+
x0_tail = x0_rotmat_tail.reshape((-1, 3, 3))@vec
|
109 |
+
loss_pose_tail_side = (x0_tail[:, 1]**2).mean()
|
110 |
+
return loss_pose_tail_side
|
111 |
+
|
112 |
+
|
113 |
+
def spine_torsion_error(optimed_pose_with_glob):
|
114 |
+
tail_indices = np.asarray([1, 2, 3, 4, 5, 6])
|
115 |
+
x0_rotmat = optimed_pose_with_glob # (1, 35, 3, 3)
|
116 |
+
x0_rotmat_tail = x0_rotmat[:, tail_indices, :, :]
|
117 |
+
vec_x = torch.zeros((3, 1)).to(device=optimed_pose_with_glob.device, dtype=optimed_pose_with_glob.dtype)
|
118 |
+
vec_x[2] = 1 # vec_x[0] = 1 # in z direction
|
119 |
+
x_x_tail = x0_rotmat_tail.reshape((-1, 3, 3))@vec_x
|
120 |
+
loss_pose_tail_torsion = (x_x_tail[:, 1]**2).mean() # (x_x_tail[:, 1]**2).mean()
|
121 |
+
return loss_pose_tail_torsion
|
122 |
+
|
123 |
+
|
124 |
+
def fit_plane(points_npx3):
|
125 |
+
# remarks:
|
126 |
+
# visualization of the plane: debug_code/curve_fitting_v2.py
|
127 |
+
# theory: https://www.ltu.se/cms_fs/1.51590!/svd-fitting.pdf
|
128 |
+
# remark: torch.svd is depreciated
|
129 |
+
# new plane equation:
|
130 |
+
# a(x−x0)+b(y−y0)+c(z−z0)=0
|
131 |
+
# ax+by+cz=d with d=ax0+by0+cz0
|
132 |
+
# z = (d-ax-by)/c
|
133 |
+
# here:
|
134 |
+
# a, b, c describe the plane normal
|
135 |
+
# d can be calculated (from a, b, c, x0, y0, z0)
|
136 |
+
# (x0, y0, z0) are the coordinates of a point on the
|
137 |
+
# plane, for example points_centroid
|
138 |
+
# (x, y, z) are the coordinates of a query point on the plane
|
139 |
+
#
|
140 |
+
# points_npx3: (n_points, 3)
|
141 |
+
# REMARK: this loss is not yet for batches!
|
142 |
+
# import pdb; pdb.set_trace()
|
143 |
+
# print('this loss is not yet for batches!')
|
144 |
+
assert (points_npx3.ndim == 2)
|
145 |
+
assert (points_npx3.shape[1] == 3)
|
146 |
+
points = torch.transpose(points_npx3, 0, 1) # (3, n_points)
|
147 |
+
points_centroid = torch.mean(points, dim=1)
|
148 |
+
input_svd = points - points_centroid[:, None]
|
149 |
+
U_svd, sigma_svd, V_svd = torch.svd(input_svd, compute_uv=True)
|
150 |
+
plane_normal = U_svd[:, 2]
|
151 |
+
plane_squaredsumofdists = sigma_svd[2]
|
152 |
+
error = plane_squaredsumofdists
|
153 |
+
return points_centroid, plane_normal, error
|
154 |
+
|
155 |
+
|
156 |
+
def paws_to_groundplane_error(vertices, return_details=False):
|
157 |
+
# list of feet vertices (some of them)
|
158 |
+
# remark: we did annotate left indices and find the right insices using sym_ids_dict
|
159 |
+
# REMARK: this loss is not yet for batches!
|
160 |
+
# import pdb; pdb.set_trace()
|
161 |
+
# print('this loss is not yet for batches!')
|
162 |
+
list_back_left = [1524, 1517, 1512, 1671, 1678, 1664, 1956, 1680, 1685, 1602, 1953, 1569]
|
163 |
+
list_front_left = [1331, 1327, 1332, 1764, 1767, 1747, 1779, 1789, 1944, 1339, 1323, 1420]
|
164 |
+
list_back_right = [3476, 3469, 3464, 3623, 3630, 3616, 3838, 3632, 3637, 3554, 3835, 3521]
|
165 |
+
list_front_right = [3283, 3279, 3284, 3715, 3718, 3698, 3730, 3740, 3826, 3291, 3275, 3372]
|
166 |
+
assert vertices.shape[0] == 3889
|
167 |
+
assert vertices.shape[1] == 3
|
168 |
+
all_paw_vert_idxs = list_back_left + list_front_left + list_back_right + list_front_right
|
169 |
+
verts_paws = vertices[all_paw_vert_idxs, :]
|
170 |
+
plane_centroid, plane_normal, error = fit_plane(verts_paws)
|
171 |
+
if return_details:
|
172 |
+
return plane_centroid, plane_normal, error
|
173 |
+
else:
|
174 |
+
return error
|
175 |
+
|
176 |
+
def groundcontact_error(vertices, gclabels, return_details=False):
|
177 |
+
# import pdb; pdb.set_trace()
|
178 |
+
# REMARK: this loss is not yet for batches!
|
179 |
+
import pdb; pdb.set_trace()
|
180 |
+
print('this loss is not yet for batches!')
|
181 |
+
assert vertices.shape[0] == 3889
|
182 |
+
assert vertices.shape[1] == 3
|
183 |
+
verts_gc = vertices[gclabels, :]
|
184 |
+
plane_centroid, plane_normal, error = fit_plane(verts_gc)
|
185 |
+
if return_details:
|
186 |
+
return plane_centroid, plane_normal, error
|
187 |
+
else:
|
188 |
+
return error
|
189 |
+
|
190 |
+
|
191 |
+
|
src/combined_model/loss_utils/loss_utils_gc.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
|
8 |
+
class LossGConMesh(torch.nn.Module):
|
9 |
+
def __init__(self , n_verts=3889):
|
10 |
+
super(LossGConMesh, self).__init__()
|
11 |
+
self.n_verts = n_verts
|
12 |
+
self.criterion_class = torch.nn.CrossEntropyLoss(reduction='mean')
|
13 |
+
|
14 |
+
def forward(self, pred_gc, target_gc, has_gc, loss_type_gcmesh='ce'):
|
15 |
+
# pred_gc has shape (bs, n_verts, 2)
|
16 |
+
# target_gc has shape (bs, n_verts, 3)
|
17 |
+
# with [first: no-contact=0 contact=1
|
18 |
+
# second: index of closest vertex with opposite label
|
19 |
+
# third: dist to that closest vertex]
|
20 |
+
target_gc_class = target_gc[:, :, 0]
|
21 |
+
target_gc_nearoppvert_ind = target_gc[:, :, 1]
|
22 |
+
target_gc_nearoppvert_dist = target_gc[:, :, 2]
|
23 |
+
# bs = pred_gc.shape[0]
|
24 |
+
bs = has_gc.sum()
|
25 |
+
|
26 |
+
if loss_type_gcmesh == 'ce': # cross entropy
|
27 |
+
# import pdb; pdb.set_trace()
|
28 |
+
|
29 |
+
# classification_loss = self.criterion_class(pred_gc.reshape((bs*self.n_verts, 2)), target_gc_class.reshape((bs*self.n_verts)))
|
30 |
+
classification_loss = self.criterion_class(pred_gc[has_gc==True, ...].reshape((bs*self.n_verts, 2)), target_gc_class[has_gc==True, ...].reshape((bs*self.n_verts)))
|
31 |
+
loss = classification_loss
|
32 |
+
else:
|
33 |
+
raise ValueError
|
34 |
+
|
35 |
+
return loss
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
def calculate_plane_errors_batch(vertices, target_gc_class, target_has_gc, has_gc_is_touching, return_error_under_plane=True):
|
43 |
+
# remarks:
|
44 |
+
# visualization of the plane: debug_code/curve_fitting_v2.py
|
45 |
+
# theory: https://www.ltu.se/cms_fs/1.51590!/svd-fitting.pdf
|
46 |
+
# remark: torch.svd is depreciated
|
47 |
+
# new plane equation:
|
48 |
+
# a(x−x0)+b(y−y0)+c(z−z0)=0
|
49 |
+
# ax+by+cz=d with d=ax0+by0+cz0
|
50 |
+
# z = (d-ax-by)/c
|
51 |
+
# here:
|
52 |
+
# a, b, c describe the plane normal
|
53 |
+
# d can be calculated (from a, b, c, x0, y0, z0)
|
54 |
+
# (x0, y0, z0) are the coordinates of a point on the
|
55 |
+
# plane, for example points_centroid
|
56 |
+
# (x, y, z) are the coordinates of a query point on the plane
|
57 |
+
#
|
58 |
+
# input:
|
59 |
+
# vertices: (bs, 3889, 3)
|
60 |
+
# target_gc_class: (bs, 3889)
|
61 |
+
#
|
62 |
+
bs = vertices.shape[0]
|
63 |
+
error_list = []
|
64 |
+
error_under_plane_list = []
|
65 |
+
|
66 |
+
for ind_b in range(bs):
|
67 |
+
if target_has_gc[ind_b] == 1 and has_gc_is_touching[ind_b] == 1:
|
68 |
+
try:
|
69 |
+
points_npx3 = vertices[ind_b, target_gc_class[ind_b, :]==1, :]
|
70 |
+
points = torch.transpose(points_npx3, 0, 1) # (3, n_points)
|
71 |
+
points_centroid = torch.mean(points, dim=1)
|
72 |
+
input_svd = points - points_centroid[:, None]
|
73 |
+
# U_svd, sigma_svd, V_svd = torch.svd(input_svd, compute_uv=True)
|
74 |
+
# plane_normal = U_svd[:, 2]
|
75 |
+
# _, sigma_svd, _ = torch.svd(input_svd, compute_uv=False)
|
76 |
+
# _, sigma_svd, _ = torch.svd(input_svd, compute_uv=True)
|
77 |
+
U_svd, sigma_svd, V_svd = torch.svd(input_svd, compute_uv=True)
|
78 |
+
plane_squaredsumofdists = sigma_svd[2]
|
79 |
+
error_list.append(plane_squaredsumofdists)
|
80 |
+
|
81 |
+
if return_error_under_plane:
|
82 |
+
# plane information
|
83 |
+
# plane_centroid = points_centroid
|
84 |
+
plane_normal = U_svd[:, 2]
|
85 |
+
|
86 |
+
# non-plane points
|
87 |
+
nonplane_points_npx3 = vertices[ind_b, target_gc_class[ind_b, :]==0, :] # (n_points_3)
|
88 |
+
nonplane_points = torch.transpose(nonplane_points_npx3, 0, 1) # (3, n_points)
|
89 |
+
nonplane_points_centered = nonplane_points - points_centroid[:, None]
|
90 |
+
|
91 |
+
nonplane_points_projected = torch.matmul(plane_normal[None, :], nonplane_points_centered) # plane normal already has length 1
|
92 |
+
|
93 |
+
if nonplane_points_projected.sum() > 0:
|
94 |
+
# bug corrected 07.11.22
|
95 |
+
# error_under_plane = nonplane_points_projected[nonplane_points_projected<0].sum() / 100
|
96 |
+
error_under_plane = - nonplane_points_projected[nonplane_points_projected<0].sum() / 100
|
97 |
+
else:
|
98 |
+
error_under_plane = nonplane_points_projected[nonplane_points_projected>0].sum() / 100
|
99 |
+
error_under_plane_list.append(error_under_plane)
|
100 |
+
except:
|
101 |
+
print('was not able to calculate plane error for this image')
|
102 |
+
error_list.append(torch.zeros((1), dtype=vertices.dtype, device=vertices.device)[0])
|
103 |
+
error_under_plane_list.append(torch.zeros((1), dtype=vertices.dtype, device=vertices.device)[0])
|
104 |
+
else:
|
105 |
+
error_list.append(torch.zeros((1), dtype=vertices.dtype, device=vertices.device)[0])
|
106 |
+
error_under_plane_list.append(torch.zeros((1), dtype=vertices.dtype, device=vertices.device)[0])
|
107 |
+
errors = torch.stack(error_list, dim=0)
|
108 |
+
errors_under_plane = torch.stack(error_under_plane_list, dim=0)
|
109 |
+
|
110 |
+
if return_error_under_plane:
|
111 |
+
return errors, errors_under_plane
|
112 |
+
else:
|
113 |
+
return errors
|
114 |
+
|
115 |
+
|
116 |
+
|
117 |
+
# def calculate_vertex_wise_labeling_error():
|
118 |
+
# vertexwise_ground_contact
|
119 |
+
|
120 |
+
|
121 |
+
|
122 |
+
|
123 |
+
|
124 |
+
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
|
129 |
+
|
130 |
+
|
131 |
+
|
132 |
+
'''
|
133 |
+
|
134 |
+
def paws_to_groundplane_error_batch(vertices, return_details=False):
|
135 |
+
# list of feet vertices (some of them)
|
136 |
+
# remark: we did annotate left indices and find the right insices using sym_ids_dict
|
137 |
+
# REMARK: this loss is not yet for batches!
|
138 |
+
import pdb; pdb.set_trace()
|
139 |
+
print('this loss is not yet for batches!')
|
140 |
+
list_back_left = [1524, 1517, 1512, 1671, 1678, 1664, 1956, 1680, 1685, 1602, 1953, 1569]
|
141 |
+
list_front_left = [1331, 1327, 1332, 1764, 1767, 1747, 1779, 1789, 1944, 1339, 1323, 1420]
|
142 |
+
list_back_right = [3476, 3469, 3464, 3623, 3630, 3616, 3838, 3632, 3637, 3554, 3835, 3521]
|
143 |
+
list_front_right = [3283, 3279, 3284, 3715, 3718, 3698, 3730, 3740, 3826, 3291, 3275, 3372]
|
144 |
+
assert vertices.shape[0] == 3889
|
145 |
+
assert vertices.shape[1] == 3
|
146 |
+
all_paw_vert_idxs = list_back_left + list_front_left + list_back_right + list_front_right
|
147 |
+
verts_paws = vertices[all_paw_vert_idxs, :]
|
148 |
+
plane_centroid, plane_normal, error = fit_plane_batch(verts_paws)
|
149 |
+
if return_details:
|
150 |
+
return plane_centroid, plane_normal, error
|
151 |
+
else:
|
152 |
+
return error
|
153 |
+
|
154 |
+
def paws_to_groundplane_error_batch_new(vertices, return_details=False):
|
155 |
+
# list of feet vertices (some of them)
|
156 |
+
# remark: we did annotate left indices and find the right insices using sym_ids_dict
|
157 |
+
# REMARK: this loss is not yet for batches!
|
158 |
+
import pdb; pdb.set_trace()
|
159 |
+
print('this loss is not yet for batches!')
|
160 |
+
list_back_left = [1524, 1517, 1512, 1671, 1678, 1664, 1956, 1680, 1685, 1602, 1953, 1569]
|
161 |
+
list_front_left = [1331, 1327, 1332, 1764, 1767, 1747, 1779, 1789, 1944, 1339, 1323, 1420]
|
162 |
+
list_back_right = [3476, 3469, 3464, 3623, 3630, 3616, 3838, 3632, 3637, 3554, 3835, 3521]
|
163 |
+
list_front_right = [3283, 3279, 3284, 3715, 3718, 3698, 3730, 3740, 3826, 3291, 3275, 3372]
|
164 |
+
assert vertices.shape[0] == 3889
|
165 |
+
assert vertices.shape[1] == 3
|
166 |
+
all_paw_vert_idxs = list_back_left + list_front_left + list_back_right + list_front_right
|
167 |
+
verts_paws = vertices[all_paw_vert_idxs, :]
|
168 |
+
plane_centroid, plane_normal, error = fit_plane_batch(verts_paws)
|
169 |
+
print('this loss is not yet for batches!')
|
170 |
+
points = torch.transpose(points_npx3, 0, 1) # (3, n_points)
|
171 |
+
points_centroid = torch.mean(points, dim=1)
|
172 |
+
input_svd = points - points_centroid[:, None]
|
173 |
+
U_svd, sigma_svd, V_svd = torch.svd(input_svd, compute_uv=True)
|
174 |
+
plane_normal = U_svd[:, 2]
|
175 |
+
plane_squaredsumofdists = sigma_svd[2]
|
176 |
+
error = plane_squaredsumofdists
|
177 |
+
print('error: ' + str(error.item()))
|
178 |
+
return error
|
179 |
+
'''
|
src/combined_model/model_shape_v7_withref_withgraphcnn.py
ADDED
@@ -0,0 +1,927 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import pickle as pkl
|
3 |
+
import numpy as np
|
4 |
+
import torchvision.models as models
|
5 |
+
from torchvision import transforms
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
from torch.nn.parameter import Parameter
|
9 |
+
from kornia.geometry.subpix import dsnt # kornia 0.4.0
|
10 |
+
|
11 |
+
import os
|
12 |
+
import sys
|
13 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
14 |
+
from stacked_hourglass.utils.evaluation import get_preds_soft
|
15 |
+
from stacked_hourglass import hg1, hg2, hg8
|
16 |
+
from lifting_to_3d.linear_model import LinearModelComplete, LinearModel
|
17 |
+
from lifting_to_3d.inn_model_for_shape import INNForShape
|
18 |
+
from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, rotmat_to_rot6d
|
19 |
+
from smal_pytorch.smal_model.smal_torch_new import SMAL
|
20 |
+
from smal_pytorch.renderer.differentiable_renderer import SilhRenderer
|
21 |
+
from bps_2d.bps_for_segmentation import SegBPS
|
22 |
+
# from configs.SMAL_configs import SMAL_MODEL_DATA_PATH as SHAPE_PRIOR
|
23 |
+
from configs.SMAL_configs import SMAL_MODEL_CONFIG
|
24 |
+
from configs.SMAL_configs import MEAN_DOG_BONE_LENGTHS_NO_RED, VERTEX_IDS_TAIL
|
25 |
+
|
26 |
+
# NEW: for graph cnn part
|
27 |
+
from smal_pytorch.smal_model.smal_torch_new import SMAL
|
28 |
+
from configs.SMAL_configs import SMAL_MODEL_CONFIG
|
29 |
+
from graph_networks.graphcmr.utils_mesh import Mesh
|
30 |
+
from graph_networks.graphcmr.graph_cnn_groundcontact_multistage import GraphCNNMS
|
31 |
+
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
class SmallLinear(nn.Module):
|
36 |
+
def __init__(self, input_size=64, output_size=30, linear_size=128):
|
37 |
+
super(SmallLinear, self).__init__()
|
38 |
+
self.relu = nn.ReLU(inplace=True)
|
39 |
+
self.w1 = nn.Linear(input_size, linear_size)
|
40 |
+
self.w2 = nn.Linear(linear_size, linear_size)
|
41 |
+
self.w3 = nn.Linear(linear_size, output_size)
|
42 |
+
def forward(self, x):
|
43 |
+
# pre-processing
|
44 |
+
y = self.w1(x)
|
45 |
+
y = self.relu(y)
|
46 |
+
y = self.w2(y)
|
47 |
+
y = self.relu(y)
|
48 |
+
y = self.w3(y)
|
49 |
+
return y
|
50 |
+
|
51 |
+
|
52 |
+
class MyConv1d(nn.Module):
|
53 |
+
def __init__(self, input_size=37, output_size=30, start=True):
|
54 |
+
super(MyConv1d, self).__init__()
|
55 |
+
self.input_size = input_size
|
56 |
+
self.output_size = output_size
|
57 |
+
self.start = start
|
58 |
+
self.weight = Parameter(torch.ones((self.output_size)))
|
59 |
+
self.bias = Parameter(torch.zeros((self.output_size)))
|
60 |
+
def forward(self, x):
|
61 |
+
# pre-processing
|
62 |
+
if self.start:
|
63 |
+
y = x[:, :self.output_size]
|
64 |
+
else:
|
65 |
+
y = x[:, -self.output_size:]
|
66 |
+
y = y * self.weight[None, :] + self.bias[None, :]
|
67 |
+
return y
|
68 |
+
|
69 |
+
|
70 |
+
class ModelShapeAndBreed(nn.Module):
|
71 |
+
def __init__(self, smal_model_type, n_betas=10, n_betas_limbs=13, n_breeds=121, n_z=512, structure_z_to_betas='default'):
|
72 |
+
super(ModelShapeAndBreed, self).__init__()
|
73 |
+
self.n_betas = n_betas
|
74 |
+
self.n_betas_limbs = n_betas_limbs # n_betas_logscale
|
75 |
+
self.n_breeds = n_breeds
|
76 |
+
self.structure_z_to_betas = structure_z_to_betas
|
77 |
+
if self.structure_z_to_betas == '1dconv':
|
78 |
+
if not (n_z == self.n_betas+self.n_betas_limbs):
|
79 |
+
raise ValueError
|
80 |
+
self.smal_model_type = smal_model_type
|
81 |
+
# shape branch
|
82 |
+
self.resnet = models.resnet34(pretrained=False)
|
83 |
+
# replace the first layer
|
84 |
+
n_in = 3 + 1
|
85 |
+
self.resnet.conv1 = nn.Conv2d(n_in, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
|
86 |
+
# replace the last layer
|
87 |
+
self.resnet.fc = nn.Linear(512, n_z)
|
88 |
+
# softmax
|
89 |
+
self.soft_max = torch.nn.Softmax(dim=1)
|
90 |
+
# fc network (and other versions) to connect z with betas
|
91 |
+
p_dropout = 0.2
|
92 |
+
if self.structure_z_to_betas == 'default':
|
93 |
+
self.linear_betas = LinearModel(linear_size=1024,
|
94 |
+
num_stage=1,
|
95 |
+
p_dropout=p_dropout,
|
96 |
+
input_size=n_z,
|
97 |
+
output_size=self.n_betas)
|
98 |
+
self.linear_betas_limbs = LinearModel(linear_size=1024,
|
99 |
+
num_stage=1,
|
100 |
+
p_dropout=p_dropout,
|
101 |
+
input_size=n_z,
|
102 |
+
output_size=self.n_betas_limbs)
|
103 |
+
elif self.structure_z_to_betas == 'lin':
|
104 |
+
self.linear_betas = nn.Linear(n_z, self.n_betas)
|
105 |
+
self.linear_betas_limbs = nn.Linear(n_z, self.n_betas_limbs)
|
106 |
+
elif self.structure_z_to_betas == 'fc_0':
|
107 |
+
self.linear_betas = SmallLinear(linear_size=128, # 1024,
|
108 |
+
input_size=n_z,
|
109 |
+
output_size=self.n_betas)
|
110 |
+
self.linear_betas_limbs = SmallLinear(linear_size=128, # 1024,
|
111 |
+
input_size=n_z,
|
112 |
+
output_size=self.n_betas_limbs)
|
113 |
+
elif structure_z_to_betas == 'fc_1':
|
114 |
+
self.linear_betas = LinearModel(linear_size=64, # 1024,
|
115 |
+
num_stage=1,
|
116 |
+
p_dropout=0,
|
117 |
+
input_size=n_z,
|
118 |
+
output_size=self.n_betas)
|
119 |
+
self.linear_betas_limbs = LinearModel(linear_size=64, # 1024,
|
120 |
+
num_stage=1,
|
121 |
+
p_dropout=0,
|
122 |
+
input_size=n_z,
|
123 |
+
output_size=self.n_betas_limbs)
|
124 |
+
elif self.structure_z_to_betas == '1dconv':
|
125 |
+
self.linear_betas = MyConv1d(n_z, self.n_betas, start=True)
|
126 |
+
self.linear_betas_limbs = MyConv1d(n_z, self.n_betas_limbs, start=False)
|
127 |
+
elif self.structure_z_to_betas == 'inn':
|
128 |
+
self.linear_betas_and_betas_limbs = INNForShape(self.n_betas, self.n_betas_limbs, betas_scale=1.0, betas_limbs_scale=1.0)
|
129 |
+
else:
|
130 |
+
raise ValueError
|
131 |
+
# network to connect latent shape vector z with dog breed classification
|
132 |
+
self.linear_breeds = LinearModel(linear_size=1024, # 1024,
|
133 |
+
num_stage=1,
|
134 |
+
p_dropout=p_dropout,
|
135 |
+
input_size=n_z,
|
136 |
+
output_size=self.n_breeds)
|
137 |
+
# shape multiplicator
|
138 |
+
self.shape_multiplicator_np = np.ones(self.n_betas)
|
139 |
+
with open(SMAL_MODEL_CONFIG[self.smal_model_type]['smal_model_data_path'], 'rb') as file:
|
140 |
+
u = pkl._Unpickler(file)
|
141 |
+
u.encoding = 'latin1'
|
142 |
+
res = u.load()
|
143 |
+
# shape predictions are centered around the mean dog of our dog model
|
144 |
+
if 'dog_cluster_mean' in res.keys():
|
145 |
+
self.betas_mean_np = res['dog_cluster_mean']
|
146 |
+
else:
|
147 |
+
assert res['cluster_means'].shape[0]==1
|
148 |
+
self.betas_mean_np = res['cluster_means'][0, :]
|
149 |
+
|
150 |
+
|
151 |
+
def forward(self, img, seg_raw=None, seg_prep=None):
|
152 |
+
# img is the network input image
|
153 |
+
# seg_raw is before softmax and subtracting 0.5
|
154 |
+
# seg_prep would be the prepared_segmentation
|
155 |
+
if seg_prep is None:
|
156 |
+
seg_prep = self.soft_max(seg_raw)[:, 1:2, :, :] - 0.5
|
157 |
+
input_img_and_seg = torch.cat((img, seg_prep), axis=1)
|
158 |
+
res_output = self.resnet(input_img_and_seg)
|
159 |
+
dog_breed_output = self.linear_breeds(res_output)
|
160 |
+
if self.structure_z_to_betas == 'inn':
|
161 |
+
shape_output_orig, shape_limbs_output_orig = self.linear_betas_and_betas_limbs(res_output)
|
162 |
+
else:
|
163 |
+
shape_output_orig = self.linear_betas(res_output) * 0.1
|
164 |
+
betas_mean = torch.tensor(self.betas_mean_np).float().to(img.device)
|
165 |
+
shape_output = shape_output_orig + betas_mean[None, 0:self.n_betas]
|
166 |
+
shape_limbs_output_orig = self.linear_betas_limbs(res_output)
|
167 |
+
shape_limbs_output = shape_limbs_output_orig * 0.1
|
168 |
+
output_dict = {'z': res_output,
|
169 |
+
'breeds': dog_breed_output,
|
170 |
+
'betas': shape_output_orig,
|
171 |
+
'betas_limbs': shape_limbs_output_orig}
|
172 |
+
return output_dict
|
173 |
+
|
174 |
+
|
175 |
+
|
176 |
+
class LearnableShapedirs(nn.Module):
|
177 |
+
def __init__(self, sym_ids_dict, shapedirs_init, n_betas, n_betas_fixed=10):
|
178 |
+
super(LearnableShapedirs, self).__init__()
|
179 |
+
# shapedirs_init = self.smal.shapedirs.detach()
|
180 |
+
self.n_betas = n_betas
|
181 |
+
self.n_betas_fixed = n_betas_fixed
|
182 |
+
self.sym_ids_dict = sym_ids_dict
|
183 |
+
sym_left_ids = self.sym_ids_dict['left']
|
184 |
+
sym_right_ids = self.sym_ids_dict['right']
|
185 |
+
sym_center_ids = self.sym_ids_dict['center']
|
186 |
+
self.n_center = sym_center_ids.shape[0]
|
187 |
+
self.n_left = sym_left_ids.shape[0]
|
188 |
+
self.n_sd = self.n_betas - self.n_betas_fixed # number of learnable shapedirs
|
189 |
+
# get indices to go from half_shapedirs to shapedirs
|
190 |
+
inds_back = np.zeros((3889))
|
191 |
+
for ind in range(0, sym_center_ids.shape[0]):
|
192 |
+
ind_in_forward = sym_center_ids[ind]
|
193 |
+
inds_back[ind_in_forward] = ind
|
194 |
+
for ind in range(0, sym_left_ids.shape[0]):
|
195 |
+
ind_in_forward = sym_left_ids[ind]
|
196 |
+
inds_back[ind_in_forward] = sym_center_ids.shape[0] + ind
|
197 |
+
for ind in range(0, sym_right_ids.shape[0]):
|
198 |
+
ind_in_forward = sym_right_ids[ind]
|
199 |
+
inds_back[ind_in_forward] = sym_center_ids.shape[0] + sym_left_ids.shape[0] + ind
|
200 |
+
self.register_buffer('inds_back_torch', torch.Tensor(inds_back).long())
|
201 |
+
# self.smal.shapedirs: (51, 11667)
|
202 |
+
# shapedirs: (3889, 3, n_sd)
|
203 |
+
# shapedirs_half: (2012, 3, n_sd)
|
204 |
+
sd = shapedirs_init[:self.n_betas, :].permute((1, 0)).reshape((-1, 3, self.n_betas))
|
205 |
+
self.register_buffer('sd', sd)
|
206 |
+
sd_center = sd[sym_center_ids, :, self.n_betas_fixed:]
|
207 |
+
sd_left = sd[sym_left_ids, :, self.n_betas_fixed:]
|
208 |
+
self.register_parameter('learnable_half_shapedirs_c0', torch.nn.Parameter(sd_center[:, 0, :].detach()))
|
209 |
+
self.register_parameter('learnable_half_shapedirs_c2', torch.nn.Parameter(sd_center[:, 2, :].detach()))
|
210 |
+
self.register_parameter('learnable_half_shapedirs_l0', torch.nn.Parameter(sd_left[:, 0, :].detach()))
|
211 |
+
self.register_parameter('learnable_half_shapedirs_l1', torch.nn.Parameter(sd_left[:, 1, :].detach()))
|
212 |
+
self.register_parameter('learnable_half_shapedirs_l2', torch.nn.Parameter(sd_left[:, 2, :].detach()))
|
213 |
+
def forward(self):
|
214 |
+
device = self.learnable_half_shapedirs_c0.device
|
215 |
+
half_shapedirs_center = torch.stack((self.learnable_half_shapedirs_c0, \
|
216 |
+
torch.zeros((self.n_center, self.n_sd)).to(device), \
|
217 |
+
self.learnable_half_shapedirs_c2), axis=1)
|
218 |
+
half_shapedirs_left = torch.stack((self.learnable_half_shapedirs_l0, \
|
219 |
+
self.learnable_half_shapedirs_l1, \
|
220 |
+
self.learnable_half_shapedirs_l2), axis=1)
|
221 |
+
half_shapedirs_right = torch.stack((self.learnable_half_shapedirs_l0, \
|
222 |
+
- self.learnable_half_shapedirs_l1, \
|
223 |
+
self.learnable_half_shapedirs_l2), axis=1)
|
224 |
+
half_shapedirs_tot = torch.cat((half_shapedirs_center, half_shapedirs_left, half_shapedirs_right))
|
225 |
+
shapedirs = torch.index_select(half_shapedirs_tot, dim=0, index=self.inds_back_torch)
|
226 |
+
shapedirs_complete = torch.cat((self.sd[:, :, :self.n_betas_fixed], shapedirs), axis=2) # (3889, 3, n_sd)
|
227 |
+
shapedirs_complete_prepared = torch.cat((self.sd[:, :, :10], shapedirs), axis=2).reshape((-1, 30)).permute((1, 0)) # (n_sd, 11667)
|
228 |
+
return shapedirs_complete, shapedirs_complete_prepared
|
229 |
+
|
230 |
+
|
231 |
+
class ModelRefinement(nn.Module):
|
232 |
+
def __init__(self, n_betas=10, n_betas_limbs=7, n_breeds=121, n_keyp=20, n_joints=35, ref_net_type='add', graphcnn_type='inexistent', isflat_type='inexistent', shaperef_type='inexistent'):
|
233 |
+
super(ModelRefinement, self).__init__()
|
234 |
+
self.n_betas = n_betas
|
235 |
+
self.n_betas_limbs = n_betas_limbs
|
236 |
+
self.n_breeds = n_breeds
|
237 |
+
self.n_keyp = n_keyp
|
238 |
+
self.n_joints = n_joints
|
239 |
+
self.n_out_seg = 256
|
240 |
+
self.n_out_keyp = 256
|
241 |
+
self.n_out_enc = 256
|
242 |
+
self.linear_size = 1024
|
243 |
+
self.linear_size_small = 128
|
244 |
+
self.ref_net_type = ref_net_type
|
245 |
+
self.graphcnn_type = graphcnn_type
|
246 |
+
self.isflat_type = isflat_type
|
247 |
+
self.shaperef_type = shaperef_type
|
248 |
+
p_dropout = 0.2
|
249 |
+
# --- segmentation encoder
|
250 |
+
if self.ref_net_type in ['multrot_res34', 'multrot01all_res34']:
|
251 |
+
self.ref_res = models.resnet34(pretrained=False)
|
252 |
+
else:
|
253 |
+
self.ref_res = models.resnet18(pretrained=False)
|
254 |
+
# replace the first layer
|
255 |
+
self.ref_res.conv1 = nn.Conv2d(2, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
|
256 |
+
# replace the last layer
|
257 |
+
self.ref_res.fc = nn.Linear(512, self.n_out_seg)
|
258 |
+
# softmax
|
259 |
+
self.soft_max = torch.nn.Softmax(dim=1)
|
260 |
+
# --- keypoint encoder
|
261 |
+
self.linear_keyp = LinearModel(linear_size=self.linear_size,
|
262 |
+
num_stage=1,
|
263 |
+
p_dropout=p_dropout,
|
264 |
+
input_size=n_keyp*2*2,
|
265 |
+
output_size=self.n_out_keyp)
|
266 |
+
# --- decoder
|
267 |
+
self.linear_combined = LinearModel(linear_size=self.linear_size,
|
268 |
+
num_stage=1,
|
269 |
+
p_dropout=p_dropout,
|
270 |
+
input_size=self.n_out_seg+self.n_out_keyp,
|
271 |
+
output_size=self.n_out_enc)
|
272 |
+
# output info
|
273 |
+
pose = {'name': 'pose', 'n': self.n_joints*6, 'out_shape':[self.n_joints, 6]}
|
274 |
+
trans = {'name': 'trans_notnorm', 'n': 3}
|
275 |
+
cam = {'name': 'flength_notnorm', 'n': 1}
|
276 |
+
betas = {'name': 'betas', 'n': self.n_betas}
|
277 |
+
betas_limbs = {'name': 'betas_limbs', 'n': self.n_betas_limbs}
|
278 |
+
if self.shaperef_type=='inexistent':
|
279 |
+
self.output_info = [pose, trans, cam] # , betas]
|
280 |
+
else:
|
281 |
+
self.output_info = [pose, trans, cam, betas, betas_limbs]
|
282 |
+
# output branches
|
283 |
+
self.output_info_linear_models = []
|
284 |
+
for ind_el, element in enumerate(self.output_info):
|
285 |
+
n_in = self.n_out_enc + element['n']
|
286 |
+
self.output_info_linear_models.append(LinearModel(linear_size=self.linear_size,
|
287 |
+
num_stage=1,
|
288 |
+
p_dropout=p_dropout,
|
289 |
+
input_size=n_in,
|
290 |
+
output_size=element['n']))
|
291 |
+
element['linear_model_index'] = ind_el
|
292 |
+
self.output_info_linear_models = nn.ModuleList(self.output_info_linear_models)
|
293 |
+
# new: predict if the ground is flat
|
294 |
+
if not self.isflat_type=='inexistent':
|
295 |
+
self.linear_isflat = LinearModel(linear_size=self.linear_size_small,
|
296 |
+
num_stage=1,
|
297 |
+
p_dropout=p_dropout,
|
298 |
+
input_size=self.n_out_enc,
|
299 |
+
output_size=2) # answer is just yes or no
|
300 |
+
|
301 |
+
|
302 |
+
# new for ground contact prediction: graph cnn
|
303 |
+
if not self.graphcnn_type=='inexistent':
|
304 |
+
num_downsampling = 1
|
305 |
+
smal_model_type = '39dogs_norm'
|
306 |
+
smal = SMAL(smal_model_type=smal_model_type, template_name='neutral')
|
307 |
+
ROOT_smal_downsampling = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/graphcmr/data/'
|
308 |
+
smal_downsampling_npz_name = 'mesh_downsampling_' + os.path.basename(SMAL_MODEL_CONFIG[smal_model_type]['smal_model_path']).replace('.pkl', '_template.npz')
|
309 |
+
smal_downsampling_npz_path = ROOT_smal_downsampling + smal_downsampling_npz_name # 'data/mesh_downsampling.npz'
|
310 |
+
self.my_custom_smal_dog_mesh = Mesh(filename=smal_downsampling_npz_path, num_downsampling=num_downsampling, nsize=1, body_model=smal) # , device=device)
|
311 |
+
# create GraphCNN
|
312 |
+
num_layers = 2 # <= len(my_custom_mesh._A)-1
|
313 |
+
n_resnet_out = self.n_out_enc # 256
|
314 |
+
num_channels = 256 # 512
|
315 |
+
self.graph_cnn = GraphCNNMS(mesh=self.my_custom_smal_dog_mesh,
|
316 |
+
num_downsample = num_downsampling,
|
317 |
+
num_layers = num_layers,
|
318 |
+
n_resnet_out = n_resnet_out,
|
319 |
+
num_channels = num_channels) # .to(device)
|
320 |
+
|
321 |
+
|
322 |
+
|
323 |
+
def forward(self, keyp_sh, keyp_pred, in_pose_3x3, in_trans_notnorm, in_cam_notnorm, in_betas, in_betas_limbs, seg_pred_prep=None, seg_sh_raw=None, seg_sh_prep=None):
|
324 |
+
# img is the network input image
|
325 |
+
# seg_raw is before softmax and subtracting 0.5
|
326 |
+
# seg_prep would be the prepared_segmentation
|
327 |
+
batch_size = in_pose_3x3.shape[0]
|
328 |
+
device = in_pose_3x3.device
|
329 |
+
dtype = in_pose_3x3.dtype
|
330 |
+
# --- segmentation encoder
|
331 |
+
if seg_sh_prep is None:
|
332 |
+
seg_sh_prep = self.soft_max(seg_sh_raw)[:, 1:2, :, :] - 0.5 # class 1 is the dog
|
333 |
+
input_seg_conc = torch.cat((seg_sh_prep, seg_pred_prep), axis=1)
|
334 |
+
network_output_seg = self.ref_res(input_seg_conc)
|
335 |
+
# --- keypoint encoder
|
336 |
+
keyp_conc = torch.cat((keyp_sh.reshape((-1, keyp_sh.shape[1]*keyp_sh.shape[2])), keyp_pred.reshape((-1, keyp_sh.shape[1]*keyp_sh.shape[2]))), axis=1)
|
337 |
+
network_output_keyp = self.linear_keyp(keyp_conc)
|
338 |
+
# --- decoder
|
339 |
+
x = torch.cat((network_output_seg, network_output_keyp), axis=1)
|
340 |
+
y_comb = self.linear_combined(x)
|
341 |
+
in_pose_6d = rotmat_to_rot6d(in_pose_3x3.reshape((-1, 3, 3))).reshape((in_pose_3x3.shape[0], -1, 6))
|
342 |
+
in_dict = {'pose': in_pose_6d,
|
343 |
+
'trans_notnorm': in_trans_notnorm,
|
344 |
+
'flength_notnorm': in_cam_notnorm,
|
345 |
+
'betas': in_betas,
|
346 |
+
'betas_limbs': in_betas_limbs}
|
347 |
+
results = {}
|
348 |
+
for element in self.output_info:
|
349 |
+
# import pdb; pdb.set_trace()
|
350 |
+
|
351 |
+
linear_model = self.output_info_linear_models[element['linear_model_index']]
|
352 |
+
y = torch.cat((y_comb, in_dict[element['name']].reshape((-1, element['n']))), axis=1)
|
353 |
+
if 'out_shape' in element.keys():
|
354 |
+
if element['name'] == 'pose':
|
355 |
+
if self.ref_net_type in ['multrot', 'multrot01', 'multrot01all', 'multrotxx', 'multrot_res34', 'multrot01all_res34']: # if self.ref_net_type == 'multrot' or self.ref_net_type == 'multrot_res34':
|
356 |
+
# multiply the rotations with each other -> just predict a correction
|
357 |
+
# the correction should be initialized as identity
|
358 |
+
# res_pose_out = (linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1])) + in_dict[element['name']]
|
359 |
+
identity_rot6d = torch.tensor(([1., 0., 0., 1., 0., 0.])).repeat((in_pose_3x3.shape[0]*in_pose_3x3.shape[1], 1)).to(device=device, dtype=dtype)
|
360 |
+
if self.ref_net_type in ['multrot01', 'multrot01all', 'multrot01all_res34']:
|
361 |
+
res_pose_out = identity_rot6d + 0.1*(linear_model(y)).reshape((-1, element['out_shape'][1]))
|
362 |
+
elif self.ref_net_type == 'multrotxx':
|
363 |
+
res_pose_out = identity_rot6d + 0.0*(linear_model(y)).reshape((-1, element['out_shape'][1]))
|
364 |
+
else:
|
365 |
+
res_pose_out = identity_rot6d + (linear_model(y)).reshape((-1, element['out_shape'][1]))
|
366 |
+
res_pose_rotmat = rot6d_to_rotmat(res_pose_out.reshape((-1, 6))) # (bs*35, 3, 3) .reshape((batch_size, -1, 3, 3))
|
367 |
+
res_tot_rotmat = torch.bmm(res_pose_rotmat.reshape((-1, 3, 3)), in_pose_3x3.reshape((-1, 3, 3))).reshape((batch_size, -1, 3, 3)) # (bs, 5, 3, 3)
|
368 |
+
results['pose_rotmat'] = res_tot_rotmat
|
369 |
+
elif self.ref_net_type == 'add':
|
370 |
+
res_6d = (linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1])) + in_dict['pose']
|
371 |
+
results['pose_rotmat'] = rot6d_to_rotmat(res_6d.reshape((-1, 6))).reshape((batch_size, -1, 3, 3))
|
372 |
+
else:
|
373 |
+
raise ValueError
|
374 |
+
else:
|
375 |
+
if self.ref_net_type in ['multrot01all', 'multrot01all_res34']:
|
376 |
+
results[element['name']] = (0.1*linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1])) + in_dict[element['name']]
|
377 |
+
else:
|
378 |
+
results[element['name']] = (linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1])) + in_dict[element['name']]
|
379 |
+
else:
|
380 |
+
if self.ref_net_type in ['multrot01all', 'multrot01all_res34']:
|
381 |
+
results[element['name']] = 0.1*linear_model(y) + in_dict[element['name']]
|
382 |
+
else:
|
383 |
+
results[element['name']] = linear_model(y) + in_dict[element['name']]
|
384 |
+
|
385 |
+
# add prediction if ground is flat
|
386 |
+
if not self.isflat_type=='inexistent':
|
387 |
+
isflat = self.linear_isflat(y_comb)
|
388 |
+
results['isflat'] = isflat
|
389 |
+
|
390 |
+
# add graph cnn
|
391 |
+
if not self.graphcnn_type=='inexistent':
|
392 |
+
ground_contact_downsampled, ground_cntact_all_stages_output = self.graph_cnn(y_comb)
|
393 |
+
ground_contact = self.my_custom_smal_dog_mesh.upsample(ground_contact_downsampled.transpose(1,2))
|
394 |
+
results['vertexwise_ground_contact'] = ground_contact
|
395 |
+
|
396 |
+
return results
|
397 |
+
|
398 |
+
|
399 |
+
|
400 |
+
|
401 |
+
class ModelImageToBreed(nn.Module):
|
402 |
+
def __init__(self, smal_model_type, arch='hg8', n_joints=35, n_classes=20, n_partseg=15, n_keyp=20, n_bones=24, n_betas=10, n_betas_limbs=7, n_breeds=121, image_size=256, n_z=512, thr_keyp_sc=None, add_partseg=True):
|
403 |
+
super(ModelImageToBreed, self).__init__()
|
404 |
+
self.n_classes = n_classes
|
405 |
+
self.n_partseg = n_partseg
|
406 |
+
self.n_betas = n_betas
|
407 |
+
self.n_betas_limbs = n_betas_limbs
|
408 |
+
self.n_keyp = n_keyp
|
409 |
+
self.n_bones = n_bones
|
410 |
+
self.n_breeds = n_breeds
|
411 |
+
self.image_size = image_size
|
412 |
+
self.upsample_seg = True
|
413 |
+
self.threshold_scores = thr_keyp_sc
|
414 |
+
self.n_z = n_z
|
415 |
+
self.add_partseg = add_partseg
|
416 |
+
self.smal_model_type = smal_model_type
|
417 |
+
# ------------------------------ STACKED HOUR GLASS ------------------------------
|
418 |
+
if arch == 'hg8':
|
419 |
+
self.stacked_hourglass = hg8(pretrained=False, num_classes=self.n_classes, num_partseg=self.n_partseg, upsample_seg=self.upsample_seg, add_partseg=self.add_partseg)
|
420 |
+
else:
|
421 |
+
raise Exception('unrecognised model architecture: ' + arch)
|
422 |
+
# ------------------------------ SHAPE AND BREED MODEL ------------------------------
|
423 |
+
self.breed_model = ModelShapeAndBreed(smal_model_type=self.smal_model_type, n_betas=self.n_betas, n_betas_limbs=self.n_betas_limbs, n_breeds=self.n_breeds, n_z=self.n_z)
|
424 |
+
def forward(self, input_img, norm_dict=None, bone_lengths_prepared=None, betas=None):
|
425 |
+
batch_size = input_img.shape[0]
|
426 |
+
device = input_img.device
|
427 |
+
# ------------------------------ STACKED HOUR GLASS ------------------------------
|
428 |
+
hourglass_out_dict = self.stacked_hourglass(input_img)
|
429 |
+
last_seg = hourglass_out_dict['seg_final']
|
430 |
+
last_heatmap = hourglass_out_dict['out_list_kp'][-1]
|
431 |
+
# - prepare keypoints (from heatmap)
|
432 |
+
# normalize predictions -> from logits to probability distribution
|
433 |
+
# last_heatmap_norm = dsnt.spatial_softmax2d(last_heatmap, temperature=torch.tensor(1))
|
434 |
+
# keypoints = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=False) + 1 # (bs, 20, 2)
|
435 |
+
# keypoints_norm = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=True) # (bs, 20, 2)
|
436 |
+
keypoints_norm, scores = get_preds_soft(last_heatmap, return_maxval=True, norm_coords=True)
|
437 |
+
if self.threshold_scores is not None:
|
438 |
+
scores[scores>self.threshold_scores] = 1.0
|
439 |
+
scores[scores<=self.threshold_scores] = 0.0
|
440 |
+
# ------------------------------ SHAPE AND BREED MODEL ------------------------------
|
441 |
+
# breed_model takes as input the image as well as the predicted segmentation map
|
442 |
+
# -> we need to split up ModelImageTo3d, such that we can use the silhouette
|
443 |
+
resnet_output = self.breed_model(img=input_img, seg_raw=last_seg)
|
444 |
+
pred_breed = resnet_output['breeds'] # (bs, n_breeds)
|
445 |
+
pred_betas = resnet_output['betas']
|
446 |
+
pred_betas_limbs = resnet_output['betas_limbs']
|
447 |
+
small_output = {'keypoints_norm': keypoints_norm,
|
448 |
+
'keypoints_scores': scores}
|
449 |
+
small_output_reproj = {'betas': pred_betas,
|
450 |
+
'betas_limbs': pred_betas_limbs,
|
451 |
+
'dog_breed': pred_breed}
|
452 |
+
return small_output, None, small_output_reproj
|
453 |
+
|
454 |
+
class ModelImageTo3d_withshape_withproj(nn.Module):
|
455 |
+
def __init__(self, smal_model_type, smal_keyp_conf=None, arch='hg8', num_stage_comb=2, num_stage_heads=1, num_stage_heads_pose=1, trans_sep=False, n_joints=35, n_classes=20, n_partseg=15, n_keyp=20, n_bones=24, n_betas=10, n_betas_limbs=6, n_breeds=121, image_size=256, n_z=512, n_segbps=64*2, thr_keyp_sc=None, add_z_to_3d_input=True, add_segbps_to_3d_input=False, add_partseg=True, silh_no_tail=True, fix_flength=False, render_partseg=False, structure_z_to_betas='default', structure_pose_net='default', nf_version=None, ref_net_type='add', ref_detach_shape=True, graphcnn_type='inexistent', isflat_type='inexistent', shaperef_type='inexistent'):
|
456 |
+
super(ModelImageTo3d_withshape_withproj, self).__init__()
|
457 |
+
self.n_classes = n_classes
|
458 |
+
self.n_partseg = n_partseg
|
459 |
+
self.n_betas = n_betas
|
460 |
+
self.n_betas_limbs = n_betas_limbs
|
461 |
+
self.n_keyp = n_keyp
|
462 |
+
self.n_joints = n_joints
|
463 |
+
self.n_bones = n_bones
|
464 |
+
self.n_breeds = n_breeds
|
465 |
+
self.image_size = image_size
|
466 |
+
self.threshold_scores = thr_keyp_sc
|
467 |
+
self.upsample_seg = True
|
468 |
+
self.silh_no_tail = silh_no_tail
|
469 |
+
self.add_z_to_3d_input = add_z_to_3d_input
|
470 |
+
self.add_segbps_to_3d_input = add_segbps_to_3d_input
|
471 |
+
self.add_partseg = add_partseg
|
472 |
+
self.ref_net_type = ref_net_type
|
473 |
+
self.ref_detach_shape = ref_detach_shape
|
474 |
+
self.graphcnn_type = graphcnn_type
|
475 |
+
self.isflat_type = isflat_type
|
476 |
+
self.shaperef_type = shaperef_type
|
477 |
+
assert (not self.add_segbps_to_3d_input) or (not self.add_z_to_3d_input)
|
478 |
+
self.n_z = n_z
|
479 |
+
if add_segbps_to_3d_input:
|
480 |
+
self.n_segbps = n_segbps # 64
|
481 |
+
self.segbps_model = SegBPS()
|
482 |
+
else:
|
483 |
+
self.n_segbps = 0
|
484 |
+
self.fix_flength = fix_flength
|
485 |
+
self.render_partseg = render_partseg
|
486 |
+
self.structure_z_to_betas = structure_z_to_betas
|
487 |
+
self.structure_pose_net = structure_pose_net
|
488 |
+
assert self.structure_pose_net in ['default', 'vae', 'normflow']
|
489 |
+
self.nf_version = nf_version
|
490 |
+
self.smal_model_type = smal_model_type
|
491 |
+
assert (smal_keyp_conf is not None)
|
492 |
+
self.smal_keyp_conf = smal_keyp_conf
|
493 |
+
self.register_buffer('betas_zeros', torch.zeros((1, self.n_betas)))
|
494 |
+
self.register_buffer('mean_dog_bone_lengths', torch.tensor(MEAN_DOG_BONE_LENGTHS_NO_RED, dtype=torch.float32))
|
495 |
+
p_dropout = 0.2 # 0.5
|
496 |
+
# ------------------------------ SMAL MODEL ------------------------------
|
497 |
+
self.smal = SMAL(smal_model_type=self.smal_model_type, template_name='neutral')
|
498 |
+
print('SMAL model type: ' + self.smal.smal_model_type)
|
499 |
+
# New for rendering without tail
|
500 |
+
f_np = self.smal.faces.detach().cpu().numpy()
|
501 |
+
self.f_no_tail_np = f_np[np.isin(f_np[:,:], VERTEX_IDS_TAIL).sum(axis=1)==0, :]
|
502 |
+
# in theory we could optimize for improved shapedirs, but we do not do that
|
503 |
+
# -> would need to implement regularizations
|
504 |
+
# -> there are better ways than changing the shapedirs
|
505 |
+
self.model_learnable_shapedirs = LearnableShapedirs(self.smal.sym_ids_dict, self.smal.shapedirs.detach(), self.n_betas, 10)
|
506 |
+
# ------------------------------ STACKED HOUR GLASS ------------------------------
|
507 |
+
if arch == 'hg8':
|
508 |
+
self.stacked_hourglass = hg8(pretrained=False, num_classes=self.n_classes, num_partseg=self.n_partseg, upsample_seg=self.upsample_seg, add_partseg=self.add_partseg)
|
509 |
+
else:
|
510 |
+
raise Exception('unrecognised model architecture: ' + arch)
|
511 |
+
# ------------------------------ SHAPE AND BREED MODEL ------------------------------
|
512 |
+
self.breed_model = ModelShapeAndBreed(self.smal_model_type, n_betas=self.n_betas, n_betas_limbs=self.n_betas_limbs, n_breeds=self.n_breeds, n_z=self.n_z, structure_z_to_betas=self.structure_z_to_betas)
|
513 |
+
# ------------------------------ LINEAR 3D MODEL ------------------------------
|
514 |
+
# 3d model -> from image to 3d parameters {2d keypoints from heatmap, pose, trans, flength}
|
515 |
+
self.soft_max = torch.nn.Softmax(dim=1)
|
516 |
+
input_size = self.n_keyp*3 + self.n_bones
|
517 |
+
self.model_3d = LinearModelComplete(linear_size=1024,
|
518 |
+
num_stage_comb=num_stage_comb,
|
519 |
+
num_stage_heads=num_stage_heads,
|
520 |
+
num_stage_heads_pose=num_stage_heads_pose,
|
521 |
+
trans_sep=trans_sep,
|
522 |
+
p_dropout=p_dropout, # 0.5,
|
523 |
+
input_size=input_size,
|
524 |
+
intermediate_size=1024,
|
525 |
+
output_info=None,
|
526 |
+
n_joints=self.n_joints,
|
527 |
+
n_z=self.n_z,
|
528 |
+
add_z_to_3d_input=self.add_z_to_3d_input,
|
529 |
+
n_segbps=self.n_segbps,
|
530 |
+
add_segbps_to_3d_input=self.add_segbps_to_3d_input,
|
531 |
+
structure_pose_net=self.structure_pose_net,
|
532 |
+
nf_version = self.nf_version)
|
533 |
+
# ------------------------------ RENDERING ------------------------------
|
534 |
+
self.silh_renderer = SilhRenderer(image_size)
|
535 |
+
# ------------------------------ REFINEMENT -----------------------------
|
536 |
+
self.refinement_model = ModelRefinement(n_betas=self.n_betas, n_betas_limbs=self.n_betas_limbs, n_breeds=self.n_breeds, n_keyp=self.n_keyp, n_joints=self.n_joints, ref_net_type=self.ref_net_type, graphcnn_type=self.graphcnn_type, isflat_type=self.isflat_type, shaperef_type=self.shaperef_type)
|
537 |
+
|
538 |
+
|
539 |
+
def forward(self, input_img, norm_dict=None, bone_lengths_prepared=None, betas=None):
|
540 |
+
batch_size = input_img.shape[0]
|
541 |
+
device = input_img.device
|
542 |
+
# ------------------------------ STACKED HOUR GLASS ------------------------------
|
543 |
+
hourglass_out_dict = self.stacked_hourglass(input_img)
|
544 |
+
last_seg = hourglass_out_dict['seg_final']
|
545 |
+
last_heatmap = hourglass_out_dict['out_list_kp'][-1]
|
546 |
+
# - prepare keypoints (from heatmap)
|
547 |
+
# normalize predictions -> from logits to probability distribution
|
548 |
+
# last_heatmap_norm = dsnt.spatial_softmax2d(last_heatmap, temperature=torch.tensor(1))
|
549 |
+
# keypoints = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=False) + 1 # (bs, 20, 2)
|
550 |
+
# keypoints_norm = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=True) # (bs, 20, 2)
|
551 |
+
keypoints_norm, scores = get_preds_soft(last_heatmap, return_maxval=True, norm_coords=True)
|
552 |
+
if self.threshold_scores is not None:
|
553 |
+
scores[scores>self.threshold_scores] = 1.0
|
554 |
+
scores[scores<=self.threshold_scores] = 0.0
|
555 |
+
# ------------------------------ LEARNABLE SHAPE MODEL ------------------------------
|
556 |
+
# in our cvpr 2022 paper we do not change the shapedirs
|
557 |
+
# learnable_sd_complete has shape (3889, 3, n_sd)
|
558 |
+
# learnable_sd_complete_prepared has shape (n_sd, 11667)
|
559 |
+
learnable_sd_complete, learnable_sd_complete_prepared = self.model_learnable_shapedirs()
|
560 |
+
shapedirs_sel = learnable_sd_complete_prepared # None
|
561 |
+
# ------------------------------ SHAPE AND BREED MODEL ------------------------------
|
562 |
+
# breed_model takes as input the image as well as the predicted segmentation map
|
563 |
+
# -> we need to split up ModelImageTo3d, such that we can use the silhouette
|
564 |
+
resnet_output = self.breed_model(img=input_img, seg_raw=last_seg)
|
565 |
+
pred_breed = resnet_output['breeds'] # (bs, n_breeds)
|
566 |
+
pred_z = resnet_output['z']
|
567 |
+
# - prepare shape
|
568 |
+
pred_betas = resnet_output['betas']
|
569 |
+
pred_betas_limbs = resnet_output['betas_limbs']
|
570 |
+
# - calculate bone lengths
|
571 |
+
with torch.no_grad():
|
572 |
+
use_mean_bone_lengths = False
|
573 |
+
if use_mean_bone_lengths:
|
574 |
+
bone_lengths_prepared = torch.cat(batch_size*[self.mean_dog_bone_lengths.reshape((1, -1))])
|
575 |
+
else:
|
576 |
+
assert (bone_lengths_prepared is None)
|
577 |
+
bone_lengths_prepared = self.smal.caclulate_bone_lengths(pred_betas, pred_betas_limbs, shapedirs_sel=shapedirs_sel, short=True)
|
578 |
+
# ------------------------------ LINEAR 3D MODEL ------------------------------
|
579 |
+
# 3d model -> from image to 3d parameters {2d keypoints from heatmap, pose, trans, flength}
|
580 |
+
# prepare input for 2d-to-3d network
|
581 |
+
keypoints_prepared = torch.cat((keypoints_norm, scores), axis=2)
|
582 |
+
if bone_lengths_prepared is None:
|
583 |
+
bone_lengths_prepared = torch.cat(batch_size*[self.mean_dog_bone_lengths.reshape((1, -1))])
|
584 |
+
# should we add silhouette to 3d input? should we add z?
|
585 |
+
if self.add_segbps_to_3d_input:
|
586 |
+
seg_raw = last_seg
|
587 |
+
seg_prep_bps = self.soft_max(seg_raw)[:, 1, :, :] # class 1 is the dog
|
588 |
+
with torch.no_grad():
|
589 |
+
seg_prep_np = seg_prep_bps.detach().cpu().numpy()
|
590 |
+
bps_output_np = self.segbps_model.calculate_bps_points_batch(seg_prep_np) # (bs, 64, 2)
|
591 |
+
bps_output = torch.tensor(bps_output_np, dtype=torch.float32).to(device).reshape((batch_size, -1))
|
592 |
+
bps_output_prep = bps_output * 2. - 1
|
593 |
+
input_vec_keyp_bones = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1)
|
594 |
+
input_vec = torch.cat((input_vec_keyp_bones, bps_output_prep), dim=1)
|
595 |
+
elif self.add_z_to_3d_input:
|
596 |
+
# we do not use this in our cvpr 2022 version
|
597 |
+
input_vec_keyp_bones = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1)
|
598 |
+
input_vec_additional = pred_z
|
599 |
+
input_vec = torch.cat((input_vec_keyp_bones, input_vec_additional), dim=1)
|
600 |
+
else:
|
601 |
+
input_vec = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1)
|
602 |
+
# predict 3d parameters (those are normalized, we need to correct mean and std in a next step)
|
603 |
+
output = self.model_3d(input_vec)
|
604 |
+
# add predicted keypoints to the output dict
|
605 |
+
output['keypoints_norm'] = keypoints_norm
|
606 |
+
output['keypoints_scores'] = scores
|
607 |
+
# add predicted segmentation to output dictc
|
608 |
+
output['seg_hg'] = hourglass_out_dict['seg_final']
|
609 |
+
# - denormalize 3d parameters -> so far predictions were normalized, now we denormalize them again
|
610 |
+
pred_trans = output['trans'] * norm_dict['trans_std'][None, :] + norm_dict['trans_mean'][None, :] # (bs, 3)
|
611 |
+
if self.structure_pose_net == 'default':
|
612 |
+
pred_pose_rot6d = output['pose'] + norm_dict['pose_rot6d_mean'][None, :]
|
613 |
+
elif self.structure_pose_net == 'normflow':
|
614 |
+
pose_rot6d_mean_zeros = torch.zeros_like(norm_dict['pose_rot6d_mean'][None, :])
|
615 |
+
pose_rot6d_mean_zeros[:, 0, :] = norm_dict['pose_rot6d_mean'][None, 0, :]
|
616 |
+
pred_pose_rot6d = output['pose'] + pose_rot6d_mean_zeros
|
617 |
+
else:
|
618 |
+
pose_rot6d_mean_zeros = torch.zeros_like(norm_dict['pose_rot6d_mean'][None, :])
|
619 |
+
pose_rot6d_mean_zeros[:, 0, :] = norm_dict['pose_rot6d_mean'][None, 0, :]
|
620 |
+
pred_pose_rot6d = output['pose'] + pose_rot6d_mean_zeros
|
621 |
+
pred_pose_reshx33 = rot6d_to_rotmat(pred_pose_rot6d.reshape((-1, 6)))
|
622 |
+
pred_pose = pred_pose_reshx33.reshape((batch_size, -1, 3, 3))
|
623 |
+
pred_pose_rot6d = rotmat_to_rot6d(pred_pose_reshx33).reshape((batch_size, -1, 6))
|
624 |
+
|
625 |
+
if self.fix_flength:
|
626 |
+
output['flength'] = torch.zeros_like(output['flength'])
|
627 |
+
pred_flength = torch.ones_like(output['flength'])*2100 # norm_dict['flength_mean'][None, :]
|
628 |
+
else:
|
629 |
+
pred_flength_orig = output['flength'] * norm_dict['flength_std'][None, :] + norm_dict['flength_mean'][None, :] # (bs, 1)
|
630 |
+
pred_flength = pred_flength_orig.clone() # torch.abs(pred_flength_orig)
|
631 |
+
pred_flength[pred_flength_orig<=0] = norm_dict['flength_mean'][None, :]
|
632 |
+
|
633 |
+
# ------------------------------ RENDERING ------------------------------
|
634 |
+
# get 3d model (SMAL)
|
635 |
+
V, keyp_green_3d, _ = self.smal(beta=pred_betas, betas_limbs=pred_betas_limbs, pose=pred_pose, trans=pred_trans, get_skin=True, keyp_conf=self.smal_keyp_conf, shapedirs_sel=shapedirs_sel)
|
636 |
+
keyp_3d = keyp_green_3d[:, :self.n_keyp, :] # (bs, 20, 3)
|
637 |
+
# render silhouette
|
638 |
+
faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1))
|
639 |
+
if not self.silh_no_tail:
|
640 |
+
pred_silh_images, pred_keyp = self.silh_renderer(vertices=V,
|
641 |
+
points=keyp_3d, faces=faces_prep, focal_lengths=pred_flength)
|
642 |
+
else:
|
643 |
+
faces_no_tail_prep = torch.tensor(self.f_no_tail_np).to(device).expand((batch_size, -1, -1))
|
644 |
+
pred_silh_images, pred_keyp = self.silh_renderer(vertices=V,
|
645 |
+
points=keyp_3d, faces=faces_no_tail_prep, focal_lengths=pred_flength)
|
646 |
+
# get torch 'Meshes'
|
647 |
+
torch_meshes = self.silh_renderer.get_torch_meshes(vertices=V, faces=faces_prep)
|
648 |
+
|
649 |
+
# render body parts (not part of cvpr 2022 version)
|
650 |
+
if self.render_partseg:
|
651 |
+
raise NotImplementedError
|
652 |
+
else:
|
653 |
+
partseg_images = None
|
654 |
+
partseg_images_hg = None
|
655 |
+
|
656 |
+
|
657 |
+
# ------------------------------ REFINEMENT MODEL ------------------------------
|
658 |
+
|
659 |
+
# refinement model
|
660 |
+
pred_keyp_norm = (pred_keyp.detach() / (self.image_size - 1) - 0.5)*2
|
661 |
+
'''output_ref = self.refinement_model(keypoints_norm.detach(), pred_keyp_norm, \
|
662 |
+
seg_sh_raw=last_seg[:, :, :, :].detach(), seg_pred_prep=pred_silh_images[:, :, :, :].detach()-0.5, \
|
663 |
+
in_pose=output['pose'].detach(), in_trans=output['trans'].detach(), in_cam=output['flength'].detach(), in_betas=pred_betas.detach())'''
|
664 |
+
output_ref = self.refinement_model(keypoints_norm.detach(), pred_keyp_norm, \
|
665 |
+
seg_sh_raw=last_seg[:, :, :, :].detach(), seg_pred_prep=pred_silh_images[:, :, :, :].detach()-0.5, \
|
666 |
+
in_pose_3x3=pred_pose.detach(), in_trans_notnorm=output['trans'].detach(), in_cam_notnorm=output['flength'].detach(), in_betas=pred_betas.detach(), in_betas_limbs=pred_betas_limbs.detach())
|
667 |
+
# a better alternative would be to submit pred_pose_reshx33
|
668 |
+
|
669 |
+
|
670 |
+
|
671 |
+
# nothing changes for betas or shapedirs or z ##################### should probably not be detached in the end
|
672 |
+
if self.shaperef_type == 'inexistent':
|
673 |
+
if self.ref_detach_shape:
|
674 |
+
output_ref['betas'] = pred_betas.detach()
|
675 |
+
output_ref['betas_limbs'] = pred_betas_limbs.detach()
|
676 |
+
output_ref['z'] = pred_z.detach()
|
677 |
+
output_ref['shapedirs'] = shapedirs_sel.detach()
|
678 |
+
else:
|
679 |
+
output_ref['betas'] = pred_betas
|
680 |
+
output_ref['betas_limbs'] = pred_betas_limbs
|
681 |
+
output_ref['z'] = pred_z
|
682 |
+
output_ref['shapedirs'] = shapedirs_sel
|
683 |
+
else:
|
684 |
+
assert ('betas' in output_ref.keys())
|
685 |
+
assert ('betas_limbs' in output_ref.keys())
|
686 |
+
output_ref['shapedirs'] = shapedirs_sel
|
687 |
+
|
688 |
+
|
689 |
+
# we denormalize flength and trans, but pose is handled differently
|
690 |
+
if self.fix_flength:
|
691 |
+
output_ref['flength_notnorm'] = torch.zeros_like(output['flength'])
|
692 |
+
ref_pred_flength = torch.ones_like(output['flength_notnorm'])*2100 # norm_dict['flength_mean'][None, :]
|
693 |
+
raise ValueError # not sure if we want to have a fixed flength in refinement
|
694 |
+
else:
|
695 |
+
ref_pred_flength_orig = output_ref['flength_notnorm'] * norm_dict['flength_std'][None, :] + norm_dict['flength_mean'][None, :] # (bs, 1)
|
696 |
+
ref_pred_flength = ref_pred_flength_orig.clone() # torch.abs(pred_flength_orig)
|
697 |
+
ref_pred_flength[ref_pred_flength_orig<=0] = norm_dict['flength_mean'][None, :]
|
698 |
+
ref_pred_trans = output_ref['trans_notnorm'] * norm_dict['trans_std'][None, :] + norm_dict['trans_mean'][None, :] # (bs, 3)
|
699 |
+
|
700 |
+
|
701 |
+
|
702 |
+
|
703 |
+
# ref_pred_pose_rot6d = output_ref['pose']
|
704 |
+
# ref_pred_pose_reshx33 = rot6d_to_rotmat(output_ref['pose'].reshape((-1, 6))).reshape((batch_size, -1, 3, 3))
|
705 |
+
ref_pred_pose_reshx33 = output_ref['pose_rotmat'].reshape((batch_size, -1, 3, 3))
|
706 |
+
ref_pred_pose_rot6d = rotmat_to_rot6d(ref_pred_pose_reshx33.reshape((-1, 3, 3))).reshape((batch_size, -1, 6))
|
707 |
+
|
708 |
+
ref_V, ref_keyp_green_3d, _ = self.smal(beta=output_ref['betas'], betas_limbs=output_ref['betas_limbs'],
|
709 |
+
pose=ref_pred_pose_reshx33, trans=ref_pred_trans, get_skin=True, keyp_conf=self.smal_keyp_conf,
|
710 |
+
shapedirs_sel=output_ref['shapedirs'])
|
711 |
+
ref_keyp_3d = ref_keyp_green_3d[:, :self.n_keyp, :] # (bs, 20, 3)
|
712 |
+
|
713 |
+
if not self.silh_no_tail:
|
714 |
+
faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1))
|
715 |
+
ref_pred_silh_images, ref_pred_keyp = self.silh_renderer(vertices=ref_V,
|
716 |
+
points=ref_keyp_3d, faces=faces_prep, focal_lengths=ref_pred_flength)
|
717 |
+
else:
|
718 |
+
faces_no_tail_prep = torch.tensor(self.f_no_tail_np).to(device).expand((batch_size, -1, -1))
|
719 |
+
ref_pred_silh_images, ref_pred_keyp = self.silh_renderer(vertices=ref_V,
|
720 |
+
points=ref_keyp_3d, faces=faces_no_tail_prep, focal_lengths=ref_pred_flength)
|
721 |
+
|
722 |
+
output_ref_unnorm = {'vertices_smal': ref_V,
|
723 |
+
'keyp_3d': ref_keyp_3d,
|
724 |
+
'keyp_2d': ref_pred_keyp,
|
725 |
+
'silh': ref_pred_silh_images,
|
726 |
+
'trans': ref_pred_trans,
|
727 |
+
'flength': ref_pred_flength,
|
728 |
+
'betas': output_ref['betas'],
|
729 |
+
'betas_limbs': output_ref['betas_limbs'],
|
730 |
+
# 'z': output_ref['z'],
|
731 |
+
'pose_rot6d': ref_pred_pose_rot6d,
|
732 |
+
'pose_rotmat': ref_pred_pose_reshx33}
|
733 |
+
# 'shapedirs': shapedirs_sel}
|
734 |
+
|
735 |
+
if not self.graphcnn_type == 'inexistent':
|
736 |
+
output_ref_unnorm['vertexwise_ground_contact'] = output_ref['vertexwise_ground_contact']
|
737 |
+
if not self.isflat_type=='inexistent':
|
738 |
+
output_ref_unnorm['isflat'] = output_ref['isflat']
|
739 |
+
if self.shaperef_type == 'inexistent':
|
740 |
+
output_ref_unnorm['z'] = output_ref['z']
|
741 |
+
|
742 |
+
# REMARK: we will want to have the predicted differences, for pose this would
|
743 |
+
# be a rotation matrix, ...
|
744 |
+
# -> TODO: adjust output_orig_ref_comparison
|
745 |
+
output_orig_ref_comparison = {#'pose': output['pose'].detach(),
|
746 |
+
#'trans': output['trans'].detach(),
|
747 |
+
#'flength': output['flength'].detach(),
|
748 |
+
# 'pose': output['pose'],
|
749 |
+
'old_pose_rotmat': pred_pose_reshx33,
|
750 |
+
'old_trans_notnorm': output['trans'],
|
751 |
+
'old_flength_notnorm': output['flength'],
|
752 |
+
# 'ref_pose': output_ref['pose'],
|
753 |
+
'ref_pose_rotmat': ref_pred_pose_reshx33,
|
754 |
+
'ref_trans_notnorm': output_ref['trans_notnorm'],
|
755 |
+
'ref_flength_notnorm': output_ref['flength_notnorm']}
|
756 |
+
|
757 |
+
|
758 |
+
|
759 |
+
# ------------------------------ PREPARE OUTPUT ------------------------------
|
760 |
+
# create output dictionarys
|
761 |
+
# output: contains all output from model_image_to_3d
|
762 |
+
# output_unnorm: same as output, but normalizations are undone
|
763 |
+
# output_reproj: smal output and reprojected keypoints as well as silhouette
|
764 |
+
keypoints_heatmap_256 = (output['keypoints_norm'] / 2. + 0.5) * (self.image_size - 1)
|
765 |
+
output_unnorm = {'pose_rotmat': pred_pose,
|
766 |
+
'flength': pred_flength,
|
767 |
+
'trans': pred_trans,
|
768 |
+
'keypoints':keypoints_heatmap_256}
|
769 |
+
output_reproj = {'vertices_smal': V,
|
770 |
+
'torch_meshes': torch_meshes,
|
771 |
+
'keyp_3d': keyp_3d,
|
772 |
+
'keyp_2d': pred_keyp,
|
773 |
+
'silh': pred_silh_images,
|
774 |
+
'betas': pred_betas,
|
775 |
+
'betas_limbs': pred_betas_limbs,
|
776 |
+
'pose_rot6d': pred_pose_rot6d, # used for pose prior...
|
777 |
+
'dog_breed': pred_breed,
|
778 |
+
'shapedirs': shapedirs_sel,
|
779 |
+
'z': pred_z,
|
780 |
+
'flength_unnorm': pred_flength,
|
781 |
+
'flength': output['flength'],
|
782 |
+
'partseg_images_rend': partseg_images,
|
783 |
+
'partseg_images_hg_nograd': partseg_images_hg,
|
784 |
+
'normflow_z': output['normflow_z']}
|
785 |
+
|
786 |
+
return output, output_unnorm, output_reproj, output_ref_unnorm, output_orig_ref_comparison
|
787 |
+
|
788 |
+
|
789 |
+
def forward_with_multiple_refinements(self, input_img, norm_dict=None, bone_lengths_prepared=None, betas=None):
|
790 |
+
|
791 |
+
# import pdb; pdb.set_trace()
|
792 |
+
|
793 |
+
# run normal network part
|
794 |
+
output, output_unnorm, output_reproj, output_ref_unnorm, output_orig_ref_comparison = self.forward(input_img, norm_dict=norm_dict, bone_lengths_prepared=bone_lengths_prepared, betas=betas)
|
795 |
+
|
796 |
+
# prepare input for second refinement stage
|
797 |
+
batch_size = output['keypoints_norm'].shape[0]
|
798 |
+
keypoints_norm = output['keypoints_norm']
|
799 |
+
pred_keyp_norm = (output_ref_unnorm['keyp_2d'].detach() / (self.image_size - 1) - 0.5)*2
|
800 |
+
|
801 |
+
last_seg = output['seg_hg']
|
802 |
+
pred_silh_images = output_ref_unnorm['silh'].detach()
|
803 |
+
|
804 |
+
trans_notnorm = output_orig_ref_comparison['ref_trans_notnorm']
|
805 |
+
flength_notnorm = output_orig_ref_comparison['ref_flength_notnorm']
|
806 |
+
# trans_notnorm = output_orig_ref_comparison['ref_pose_rotmat']
|
807 |
+
pred_pose = output_ref_unnorm['pose_rotmat'].reshape((batch_size, -1, 3, 3))
|
808 |
+
|
809 |
+
# run second refinement step
|
810 |
+
output_ref_new = self.refinement_model(keypoints_norm.detach(), pred_keyp_norm, \
|
811 |
+
seg_sh_raw=last_seg[:, :, :, :].detach(), seg_pred_prep=pred_silh_images[:, :, :, :].detach()-0.5, \
|
812 |
+
in_pose_3x3=pred_pose.detach(), in_trans_notnorm=trans_notnorm.detach(), in_cam_notnorm=flength_notnorm.detach(), \
|
813 |
+
in_betas=output_ref_unnorm['betas'].detach(), in_betas_limbs=output_ref_unnorm['betas_limbs'].detach())
|
814 |
+
# output_ref_new = self.refinement_model(keypoints_norm.detach(), pred_keyp_norm, seg_sh_raw=last_seg[:, :, :, :].detach(), seg_pred_prep=pred_silh_images[:, :, :, :].detach()-0.5, in_pose_3x3=pred_pose.detach(), in_trans_notnorm=trans_notnorm.detach(), in_cam_notnorm=flength_notnorm.detach(), in_betas=output_ref_unnorm['betas'].detach(), in_betas_limbs=output_ref_unnorm['betas_limbs'].detach())
|
815 |
+
|
816 |
+
|
817 |
+
# new shape
|
818 |
+
if self.shaperef_type == 'inexistent':
|
819 |
+
if self.ref_detach_shape:
|
820 |
+
output_ref_new['betas'] = output_ref_unnorm['betas'].detach()
|
821 |
+
output_ref_new['betas_limbs'] = output_ref_unnorm['betas_limbs'].detach()
|
822 |
+
output_ref_new['z'] = output_ref_unnorm['z'].detach()
|
823 |
+
output_ref_new['shapedirs'] = output_reproj['shapedirs'].detach()
|
824 |
+
else:
|
825 |
+
output_ref_new['betas'] = output_ref_unnorm['betas']
|
826 |
+
output_ref_new['betas_limbs'] = output_ref_unnorm['betas_limbs']
|
827 |
+
output_ref_new['z'] = output_ref_unnorm['z']
|
828 |
+
output_ref_new['shapedirs'] = output_reproj['shapedirs']
|
829 |
+
else:
|
830 |
+
assert ('betas' in output_ref_new.keys())
|
831 |
+
assert ('betas_limbs' in output_ref_new.keys())
|
832 |
+
output_ref_new['shapedirs'] = output_reproj['shapedirs']
|
833 |
+
|
834 |
+
# we denormalize flength and trans, but pose is handled differently
|
835 |
+
if self.fix_flength:
|
836 |
+
raise ValueError # not sure if we want to have a fixed flength in refinement
|
837 |
+
else:
|
838 |
+
ref_pred_flength_orig = output_ref_new['flength_notnorm'] * norm_dict['flength_std'][None, :] + norm_dict['flength_mean'][None, :] # (bs, 1)
|
839 |
+
ref_pred_flength = ref_pred_flength_orig.clone() # torch.abs(pred_flength_orig)
|
840 |
+
ref_pred_flength[ref_pred_flength_orig<=0] = norm_dict['flength_mean'][None, :]
|
841 |
+
ref_pred_trans = output_ref_new['trans_notnorm'] * norm_dict['trans_std'][None, :] + norm_dict['trans_mean'][None, :] # (bs, 3)
|
842 |
+
|
843 |
+
|
844 |
+
ref_pred_pose_reshx33 = output_ref_new['pose_rotmat'].reshape((batch_size, -1, 3, 3))
|
845 |
+
ref_pred_pose_rot6d = rotmat_to_rot6d(ref_pred_pose_reshx33.reshape((-1, 3, 3))).reshape((batch_size, -1, 6))
|
846 |
+
|
847 |
+
ref_V, ref_keyp_green_3d, _ = self.smal(beta=output_ref_new['betas'], betas_limbs=output_ref_new['betas_limbs'],
|
848 |
+
pose=ref_pred_pose_reshx33, trans=ref_pred_trans, get_skin=True, keyp_conf=self.smal_keyp_conf,
|
849 |
+
shapedirs_sel=output_ref_new['shapedirs'])
|
850 |
+
|
851 |
+
# ref_V, ref_keyp_green_3d, _ = self.smal(beta=output_ref_new['betas'], betas_limbs=output_ref_new['betas_limbs'], pose=ref_pred_pose_reshx33, trans=ref_pred_trans, get_skin=True, keyp_conf=self.smal_keyp_conf, shapedirs_sel=output_ref_new['shapedirs'])
|
852 |
+
ref_keyp_3d = ref_keyp_green_3d[:, :self.n_keyp, :] # (bs, 20, 3)
|
853 |
+
|
854 |
+
if not self.silh_no_tail:
|
855 |
+
faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1))
|
856 |
+
ref_pred_silh_images, ref_pred_keyp = self.silh_renderer(vertices=ref_V,
|
857 |
+
points=ref_keyp_3d, faces=faces_prep, focal_lengths=ref_pred_flength)
|
858 |
+
else:
|
859 |
+
faces_no_tail_prep = torch.tensor(self.f_no_tail_np).to(device).expand((batch_size, -1, -1))
|
860 |
+
ref_pred_silh_images, ref_pred_keyp = self.silh_renderer(vertices=ref_V,
|
861 |
+
points=ref_keyp_3d, faces=faces_no_tail_prep, focal_lengths=ref_pred_flength)
|
862 |
+
|
863 |
+
output_ref_unnorm_new = {'vertices_smal': ref_V,
|
864 |
+
'keyp_3d': ref_keyp_3d,
|
865 |
+
'keyp_2d': ref_pred_keyp,
|
866 |
+
'silh': ref_pred_silh_images,
|
867 |
+
'trans': ref_pred_trans,
|
868 |
+
'flength': ref_pred_flength,
|
869 |
+
'betas': output_ref_new['betas'],
|
870 |
+
'betas_limbs': output_ref_new['betas_limbs'],
|
871 |
+
'pose_rot6d': ref_pred_pose_rot6d,
|
872 |
+
'pose_rotmat': ref_pred_pose_reshx33}
|
873 |
+
|
874 |
+
if not self.graphcnn_type == 'inexistent':
|
875 |
+
output_ref_unnorm_new['vertexwise_ground_contact'] = output_ref_new['vertexwise_ground_contact']
|
876 |
+
if not self.isflat_type=='inexistent':
|
877 |
+
output_ref_unnorm_new['isflat'] = output_ref_new['isflat']
|
878 |
+
if self.shaperef_type == 'inexistent':
|
879 |
+
output_ref_unnorm_new['z'] = output_ref_new['z']
|
880 |
+
|
881 |
+
output_orig_ref_comparison_new = {'ref_pose_rotmat': ref_pred_pose_reshx33,
|
882 |
+
'ref_trans_notnorm': output_ref_new['trans_notnorm'],
|
883 |
+
'ref_flength_notnorm': output_ref_new['flength_notnorm']}
|
884 |
+
|
885 |
+
results = {
|
886 |
+
'output': output,
|
887 |
+
'output_unnorm': output_unnorm,
|
888 |
+
'output_reproj':output_reproj,
|
889 |
+
'output_ref_unnorm': output_ref_unnorm,
|
890 |
+
'output_orig_ref_comparison':output_orig_ref_comparison,
|
891 |
+
'output_ref_unnorm_new': output_ref_unnorm_new,
|
892 |
+
'output_orig_ref_comparison_new': output_orig_ref_comparison_new}
|
893 |
+
return results
|
894 |
+
|
895 |
+
|
896 |
+
|
897 |
+
|
898 |
+
|
899 |
+
|
900 |
+
|
901 |
+
|
902 |
+
|
903 |
+
|
904 |
+
|
905 |
+
|
906 |
+
|
907 |
+
|
908 |
+
|
909 |
+
|
910 |
+
|
911 |
+
|
912 |
+
|
913 |
+
|
914 |
+
|
915 |
+
def render_vis_nograd(self, vertices, focal_lengths, color=0):
|
916 |
+
# this function is for visualization only
|
917 |
+
# vertices: (bs, n_verts, 3)
|
918 |
+
# focal_lengths: (bs, 1)
|
919 |
+
# color: integer, either 0 or 1
|
920 |
+
# returns a torch tensor of shape (bs, image_size, image_size, 3)
|
921 |
+
with torch.no_grad():
|
922 |
+
batch_size = vertices.shape[0]
|
923 |
+
faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1))
|
924 |
+
visualizations = self.silh_renderer.get_visualization_nograd(vertices,
|
925 |
+
faces_prep, focal_lengths, color=color)
|
926 |
+
return visualizations
|
927 |
+
|
src/combined_model/train_main_image_to_3d_wbr_withref.py
ADDED
@@ -0,0 +1,955 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.backends.cudnn
|
5 |
+
import torch.nn.parallel
|
6 |
+
from tqdm import tqdm
|
7 |
+
import os
|
8 |
+
import pathlib
|
9 |
+
from matplotlib import pyplot as plt
|
10 |
+
import cv2
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import trimesh
|
14 |
+
import pickle as pkl
|
15 |
+
import csv
|
16 |
+
from scipy.spatial.transform import Rotation as R_sc
|
17 |
+
|
18 |
+
|
19 |
+
import sys
|
20 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
|
21 |
+
from stacked_hourglass.utils.evaluation import accuracy, AverageMeter, final_preds, get_preds, get_preds_soft
|
22 |
+
from stacked_hourglass.utils.visualization import save_input_image_with_keypoints, save_input_image
|
23 |
+
from metrics.metrics import Metrics
|
24 |
+
from configs.SMAL_configs import EVAL_KEYPOINTS, KEYPOINT_GROUPS, SMAL_KEYPOINT_NAMES_FOR_3D_EVAL, SMAL_KEYPOINT_INDICES_FOR_3D_EVAL, SMAL_KEYPOINT_WHICHTOUSE_FOR_3D_EVAL
|
25 |
+
from combined_model.helper import eval_save_visualizations_and_meshes, eval_prepare_pck_and_iou, eval_add_preds_to_summary
|
26 |
+
|
27 |
+
from smal_pytorch.smal_model.smal_torch_new import SMAL # for gc visualization
|
28 |
+
from src.combined_model.loss_utils.loss_utils import fit_plane
|
29 |
+
# from src.evaluation.sketchfab_evaluation.alignment_utils.calculate_v2v_error_release import compute_similarity_transform
|
30 |
+
# from src.evaluation.sketchfab_evaluation.alignment_utils.calculate_alignment_error import calculate_alignemnt_errors
|
31 |
+
|
32 |
+
# ---------------------------------------------------------------------------------------------------------------------------
|
33 |
+
def do_training_epoch(train_loader, model, loss_module, loss_module_ref, device, data_info, optimiser, quiet=False, acc_joints=None, weight_dict=None, weight_dict_ref=None):
|
34 |
+
losses = AverageMeter()
|
35 |
+
losses_keyp = AverageMeter()
|
36 |
+
losses_silh = AverageMeter()
|
37 |
+
losses_shape = AverageMeter()
|
38 |
+
losses_pose = AverageMeter()
|
39 |
+
losses_class = AverageMeter()
|
40 |
+
losses_breed = AverageMeter()
|
41 |
+
losses_partseg = AverageMeter()
|
42 |
+
losses_ref_keyp = AverageMeter()
|
43 |
+
losses_ref_silh = AverageMeter()
|
44 |
+
losses_ref_pose = AverageMeter()
|
45 |
+
losses_ref_reg = AverageMeter()
|
46 |
+
accuracies = AverageMeter()
|
47 |
+
# Put the model in training mode.
|
48 |
+
model.train()
|
49 |
+
# prepare progress bar
|
50 |
+
iterable = enumerate(train_loader)
|
51 |
+
progress = None
|
52 |
+
if not quiet:
|
53 |
+
progress = tqdm(iterable, desc='Train', total=len(train_loader), ascii=True, leave=False)
|
54 |
+
iterable = progress
|
55 |
+
# information for normalization
|
56 |
+
norm_dict = {
|
57 |
+
'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device),
|
58 |
+
'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device),
|
59 |
+
'trans_std': torch.from_numpy(data_info.trans_std).float().to(device),
|
60 |
+
'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device),
|
61 |
+
'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)}
|
62 |
+
# prepare variables, put them on the right device
|
63 |
+
for i, (input, target_dict) in iterable:
|
64 |
+
batch_size = input.shape[0]
|
65 |
+
for key in target_dict.keys():
|
66 |
+
if key == 'breed_index':
|
67 |
+
target_dict[key] = target_dict[key].long().to(device)
|
68 |
+
elif key in ['index', 'pts', 'tpts', 'target_weight', 'silh', 'silh_distmat_tofg', 'silh_distmat_tobg', 'sim_breed_index', 'img_border_mask']:
|
69 |
+
target_dict[key] = target_dict[key].float().to(device)
|
70 |
+
elif key in ['has_seg', 'gc']:
|
71 |
+
target_dict[key] = target_dict[key].to(device)
|
72 |
+
else:
|
73 |
+
pass
|
74 |
+
input = input.float().to(device)
|
75 |
+
|
76 |
+
# ----------------------- do training step -----------------------
|
77 |
+
assert model.training, 'model must be in training mode.'
|
78 |
+
with torch.enable_grad():
|
79 |
+
# ----- forward pass -----
|
80 |
+
output, output_unnorm, output_reproj, output_ref, output_ref_comp = model(input, norm_dict=norm_dict)
|
81 |
+
# ----- loss -----
|
82 |
+
# --- from main network
|
83 |
+
loss, loss_dict = loss_module(output_reproj=output_reproj,
|
84 |
+
target_dict=target_dict,
|
85 |
+
weight_dict=weight_dict)
|
86 |
+
# ---from refinement network
|
87 |
+
loss_ref, loss_dict_ref = loss_module_ref(output_ref=output_ref,
|
88 |
+
output_ref_comp=output_ref_comp,
|
89 |
+
target_dict=target_dict,
|
90 |
+
weight_dict_ref=weight_dict_ref)
|
91 |
+
loss_total = loss + loss_ref
|
92 |
+
# ----- backward pass and parameter update -----
|
93 |
+
optimiser.zero_grad()
|
94 |
+
loss_total.backward()
|
95 |
+
optimiser.step()
|
96 |
+
# ----------------------------------------------------------------
|
97 |
+
|
98 |
+
# prepare losses for progress bar
|
99 |
+
bs_fake = 1 # batch_size
|
100 |
+
losses.update(loss_dict['loss'] + loss_dict_ref['loss'], bs_fake)
|
101 |
+
losses_keyp.update(loss_dict['loss_keyp_weighted'], bs_fake)
|
102 |
+
losses_silh.update(loss_dict['loss_silh_weighted'], bs_fake)
|
103 |
+
losses_shape.update(loss_dict['loss_shape_weighted'], bs_fake)
|
104 |
+
losses_pose.update(loss_dict['loss_poseprior_weighted'], bs_fake)
|
105 |
+
losses_class.update(loss_dict['loss_class_weighted'], bs_fake)
|
106 |
+
losses_breed.update(loss_dict['loss_breed_weighted'], bs_fake)
|
107 |
+
losses_partseg.update(loss_dict['loss_partseg_weighted'], bs_fake)
|
108 |
+
losses_ref_keyp.update(loss_dict_ref['keyp_ref'], bs_fake)
|
109 |
+
losses_ref_silh.update(loss_dict_ref['silh_ref'], bs_fake)
|
110 |
+
loss_ref_pose = 0
|
111 |
+
for l_name in ['pose_legs_side', 'pose_legs_tors', 'pose_tail_side', 'pose_tail_tors', 'pose_spine_side', 'pose_spine_tors']:
|
112 |
+
if l_name in loss_dict_ref.keys():
|
113 |
+
loss_ref_pose += loss_dict_ref[l_name]
|
114 |
+
losses_ref_pose.update(loss_ref_pose, bs_fake)
|
115 |
+
loss_ref_reg = 0
|
116 |
+
for l_name in ['reg_trans', 'reg_flength', 'reg_pose']:
|
117 |
+
if l_name in loss_dict_ref.keys():
|
118 |
+
loss_ref_reg += loss_dict_ref[l_name]
|
119 |
+
losses_ref_reg.update(loss_ref_reg, bs_fake)
|
120 |
+
acc = - loss_dict['loss_keyp_weighted'] # this will be used to keep track of the 'best model'
|
121 |
+
accuracies.update(acc, bs_fake)
|
122 |
+
# Show losses as part of the progress bar.
|
123 |
+
if progress is not None:
|
124 |
+
my_string = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_partseg: {loss_partseg:0.4f}, loss_shape: {loss_shape:0.4f}, loss_pose: {loss_pose:0.4f}, loss_class: {loss_class:0.4f}, loss_breed: {loss_breed:0.4f}, loss_ref_keyp: {loss_ref_keyp:0.4f}, loss_ref_silh: {loss_ref_silh:0.4f}, loss_ref_pose: {loss_ref_pose:0.4f}, loss_ref_reg: {loss_ref_reg:0.4f}'.format(
|
125 |
+
loss=losses.avg,
|
126 |
+
loss_keyp=losses_keyp.avg,
|
127 |
+
loss_silh=losses_silh.avg,
|
128 |
+
loss_shape=losses_shape.avg,
|
129 |
+
loss_pose=losses_pose.avg,
|
130 |
+
loss_class=losses_class.avg,
|
131 |
+
loss_breed=losses_breed.avg,
|
132 |
+
loss_partseg=losses_partseg.avg,
|
133 |
+
loss_ref_keyp=losses_ref_keyp.avg,
|
134 |
+
loss_ref_silh=losses_ref_silh.avg,
|
135 |
+
loss_ref_pose=losses_ref_pose.avg,
|
136 |
+
loss_ref_reg=losses_ref_reg.avg)
|
137 |
+
my_string_short = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_ref_keyp: {loss_ref_keyp:0.4f}, loss_ref_silh: {loss_ref_silh:0.4f}, loss_ref_pose: {loss_ref_pose:0.4f}, loss_ref_reg: {loss_ref_reg:0.4f}'.format(
|
138 |
+
loss=losses.avg,
|
139 |
+
loss_keyp=losses_keyp.avg,
|
140 |
+
loss_silh=losses_silh.avg,
|
141 |
+
loss_ref_keyp=losses_ref_keyp.avg,
|
142 |
+
loss_ref_silh=losses_ref_silh.avg,
|
143 |
+
loss_ref_pose=losses_ref_pose.avg,
|
144 |
+
loss_ref_reg=losses_ref_reg.avg)
|
145 |
+
progress.set_postfix_str(my_string_short)
|
146 |
+
|
147 |
+
return my_string, accuracies.avg
|
148 |
+
|
149 |
+
|
150 |
+
# ---------------------------------------------------------------------------------------------------------------------------
|
151 |
+
def do_validation_epoch(val_loader, model, loss_module, loss_module_ref, device, data_info, flip=False, quiet=False, acc_joints=None, save_imgs_path=None, weight_dict=None, weight_dict_ref=None, metrics=None, val_opt='default', test_name_list=None, render_all=False, pck_thresh=0.15, len_dataset=None):
|
152 |
+
losses = AverageMeter()
|
153 |
+
losses_keyp = AverageMeter()
|
154 |
+
losses_silh = AverageMeter()
|
155 |
+
losses_shape = AverageMeter()
|
156 |
+
losses_pose = AverageMeter()
|
157 |
+
losses_class = AverageMeter()
|
158 |
+
losses_breed = AverageMeter()
|
159 |
+
losses_partseg = AverageMeter()
|
160 |
+
losses_ref_keyp = AverageMeter()
|
161 |
+
losses_ref_silh = AverageMeter()
|
162 |
+
losses_ref_pose = AverageMeter()
|
163 |
+
losses_ref_reg = AverageMeter()
|
164 |
+
accuracies = AverageMeter()
|
165 |
+
if save_imgs_path is not None:
|
166 |
+
pathlib.Path(save_imgs_path).mkdir(parents=True, exist_ok=True)
|
167 |
+
# Put the model in evaluation mode.
|
168 |
+
model.eval()
|
169 |
+
# prepare progress bar
|
170 |
+
iterable = enumerate(val_loader)
|
171 |
+
progress = None
|
172 |
+
if not quiet:
|
173 |
+
progress = tqdm(iterable, desc='Valid', total=len(val_loader), ascii=True, leave=False)
|
174 |
+
iterable = progress
|
175 |
+
# summarize information for normalization
|
176 |
+
norm_dict = {
|
177 |
+
'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device),
|
178 |
+
'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device),
|
179 |
+
'trans_std': torch.from_numpy(data_info.trans_std).float().to(device),
|
180 |
+
'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device),
|
181 |
+
'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)}
|
182 |
+
batch_size = val_loader.batch_size
|
183 |
+
|
184 |
+
return_mesh_with_gt_groundplane = True
|
185 |
+
if return_mesh_with_gt_groundplane:
|
186 |
+
remeshing_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data_remeshed/uniform_surface_sampling/my_smpl_39dogsnorm_Jr_4_dog_remesh4000_info.pkl'
|
187 |
+
with open(remeshing_path, 'rb') as fp:
|
188 |
+
remeshing_dict = pkl.load(fp)
|
189 |
+
remeshing_relevant_faces = torch.tensor(remeshing_dict['smal_faces'][remeshing_dict['faceid_closest']], dtype=torch.long, device=device)
|
190 |
+
remeshing_relevant_barys = torch.tensor(remeshing_dict['barys_closest'], dtype=torch.float32, device=device)
|
191 |
+
|
192 |
+
|
193 |
+
# from smal_pytorch.smal_model.smal_torch_new import SMAL
|
194 |
+
print('start: load smal default model (barc), but only for vertices')
|
195 |
+
smal = SMAL()
|
196 |
+
print('end: load smal default model (barc), but only for vertices')
|
197 |
+
smal_template_verts = smal.v_template.detach().cpu().numpy()
|
198 |
+
smal_faces = smal.faces.detach().cpu().numpy()
|
199 |
+
|
200 |
+
|
201 |
+
my_step = 0
|
202 |
+
for index, (input, target_dict) in iterable:
|
203 |
+
|
204 |
+
# prepare variables, put them on the right device
|
205 |
+
curr_batch_size = input.shape[0]
|
206 |
+
for key in target_dict.keys():
|
207 |
+
if key == 'breed_index':
|
208 |
+
target_dict[key] = target_dict[key].long().to(device)
|
209 |
+
elif key in ['index', 'pts', 'tpts', 'target_weight', 'silh', 'silh_distmat_tofg', 'silh_distmat_tobg', 'sim_breed_index', 'img_border_mask']:
|
210 |
+
target_dict[key] = target_dict[key].float().to(device)
|
211 |
+
elif key in ['has_seg', 'gc']:
|
212 |
+
target_dict[key] = target_dict[key].to(device)
|
213 |
+
else:
|
214 |
+
pass
|
215 |
+
input = input.float().to(device)
|
216 |
+
|
217 |
+
# ----------------------- do validation step -----------------------
|
218 |
+
with torch.no_grad():
|
219 |
+
# ----- forward pass -----
|
220 |
+
# output: (['pose', 'flength', 'trans', 'keypoints_norm', 'keypoints_scores'])
|
221 |
+
# output_unnorm: (['pose_rotmat', 'flength', 'trans', 'keypoints'])
|
222 |
+
# output_reproj: (['vertices_smal', 'torch_meshes', 'keyp_3d', 'keyp_2d', 'silh', 'betas', 'pose_rot6d', 'dog_breed', 'shapedirs', 'z', 'flength_unnorm', 'flength'])
|
223 |
+
# target_dict: (['index', 'center', 'scale', 'pts', 'tpts', 'target_weight', 'breed_index', 'sim_breed_index', 'ind_dataset', 'silh'])
|
224 |
+
output, output_unnorm, output_reproj, output_ref, output_ref_comp = model(input, norm_dict=norm_dict)
|
225 |
+
# ----- loss -----
|
226 |
+
if metrics == 'no_loss':
|
227 |
+
# --- from main network
|
228 |
+
loss, loss_dict = loss_module(output_reproj=output_reproj,
|
229 |
+
target_dict=target_dict,
|
230 |
+
weight_dict=weight_dict)
|
231 |
+
# ---from refinement network
|
232 |
+
loss_ref, loss_dict_ref = loss_module_ref(output_ref=output_ref,
|
233 |
+
output_ref_comp=output_ref_comp,
|
234 |
+
target_dict=target_dict,
|
235 |
+
weight_dict_ref=weight_dict_ref)
|
236 |
+
loss_total = loss + loss_ref
|
237 |
+
|
238 |
+
# ----------------------------------------------------------------
|
239 |
+
|
240 |
+
|
241 |
+
for result_network in ['normal', 'ref']:
|
242 |
+
# variabled that are not refined
|
243 |
+
hg_keyp_norm = output['keypoints_norm']
|
244 |
+
hg_keyp_scores = output['keypoints_scores']
|
245 |
+
betas = output_reproj['betas']
|
246 |
+
betas_limbs = output_reproj['betas_limbs']
|
247 |
+
zz = output_reproj['z']
|
248 |
+
if result_network == 'normal':
|
249 |
+
# STEP 1: normal network
|
250 |
+
vertices_smal = output_reproj['vertices_smal']
|
251 |
+
flength = output_unnorm['flength']
|
252 |
+
pose_rotmat = output_unnorm['pose_rotmat']
|
253 |
+
trans = output_unnorm['trans']
|
254 |
+
pred_keyp = output_reproj['keyp_2d']
|
255 |
+
pred_silh = output_reproj['silh']
|
256 |
+
prefix = 'normal_'
|
257 |
+
else:
|
258 |
+
# STEP 1: refinement network
|
259 |
+
vertices_smal = output_ref['vertices_smal']
|
260 |
+
flength = output_ref['flength']
|
261 |
+
pose_rotmat = output_ref['pose_rotmat']
|
262 |
+
trans = output_ref['trans']
|
263 |
+
pred_keyp = output_ref['keyp_2d']
|
264 |
+
pred_silh = output_ref['silh']
|
265 |
+
prefix = 'ref_'
|
266 |
+
if return_mesh_with_gt_groundplane and 'gc' in target_dict.keys():
|
267 |
+
bs = vertices_smal.shape[0]
|
268 |
+
target_gc_class = target_dict['gc'][:, :, 0]
|
269 |
+
sel_verts = torch.index_select(output_ref['vertices_smal'], dim=1, index=remeshing_relevant_faces.reshape((-1))).reshape((bs, remeshing_relevant_faces.shape[0], 3, 3))
|
270 |
+
verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts)
|
271 |
+
target_gc_class_remeshed = torch.einsum('ij,aij->ai', remeshing_relevant_barys, target_gc_class[:, remeshing_relevant_faces].to(device=device, dtype=torch.float32))
|
272 |
+
target_gc_class_remeshed_prep = torch.round(target_gc_class_remeshed).to(torch.long)
|
273 |
+
|
274 |
+
|
275 |
+
|
276 |
+
|
277 |
+
|
278 |
+
# import pdb; pdb.set_trace()
|
279 |
+
|
280 |
+
# new for vertex wise ground contact
|
281 |
+
if (not model.graphcnn_type == 'inexistent') and (save_imgs_path is not None):
|
282 |
+
# import pdb; pdb.set_trace()
|
283 |
+
|
284 |
+
sm = torch.nn.Softmax(dim=2)
|
285 |
+
ground_contact_probs = sm(output_ref['vertexwise_ground_contact'])
|
286 |
+
|
287 |
+
for ind_img in range(ground_contact_probs.shape[0]):
|
288 |
+
# ind_img = 0
|
289 |
+
if test_name_list is not None:
|
290 |
+
img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_')
|
291 |
+
img_name = img_name.split('.')[0]
|
292 |
+
else:
|
293 |
+
img_name = str(index) + '_' + str(ind_img)
|
294 |
+
out_path_gcmesh = save_imgs_path + '/' + prefix + 'gcmesh_' + img_name + '.obj'
|
295 |
+
|
296 |
+
gc_prob = ground_contact_probs[ind_img, :, 1] # contact probability
|
297 |
+
vert_colors = np.repeat(255*gc_prob.detach().cpu().numpy()[:, None], 3, 1)
|
298 |
+
my_mesh = trimesh.Trimesh(vertices=smal_template_verts, faces=smal_faces, process=False, maintain_order=True)
|
299 |
+
my_mesh.visual.vertex_colors = vert_colors
|
300 |
+
save_gc_mesh = True # False
|
301 |
+
if save_gc_mesh:
|
302 |
+
my_mesh.export(out_path_gcmesh)
|
303 |
+
|
304 |
+
'''
|
305 |
+
input_image = input[ind_img, :, :, :].detach().clone()
|
306 |
+
for t, m, s in zip(input_image, data_info.rgb_mean,data_info.rgb_stddev): t.add_(m)
|
307 |
+
input_image_np = input_image.detach().cpu().numpy().transpose(1, 2, 0)
|
308 |
+
out_path = save_debug_path + 'b' + str(ind_img) +'_input.png'
|
309 |
+
plt.imsave(out_path, input_image_np)
|
310 |
+
'''
|
311 |
+
|
312 |
+
# -------------------------------------
|
313 |
+
|
314 |
+
# import pdb; pdb.set_trace()
|
315 |
+
|
316 |
+
|
317 |
+
'''
|
318 |
+
target_gc_class = target_dict['gc'][ind_img, :, 0]
|
319 |
+
|
320 |
+
current_vertices_smal = vertices_smal[ind_img, :, :]
|
321 |
+
|
322 |
+
points_centroid, plane_normal, error = fit_plane(current_vertices_smal[target_gc_class==1, :])
|
323 |
+
'''
|
324 |
+
|
325 |
+
# calculate ground plane
|
326 |
+
# (see /is/cluster/work/nrueegg/icon_pifu_related/ICON/debug_code/curve_fitting_v2.py)
|
327 |
+
if return_mesh_with_gt_groundplane and 'gc' in target_dict.keys():
|
328 |
+
|
329 |
+
current_verts_remeshed = verts_remeshed[ind_img, :, :]
|
330 |
+
current_target_gc_class_remeshed_prep = target_gc_class_remeshed_prep[ind_img, ...]
|
331 |
+
|
332 |
+
if current_target_gc_class_remeshed_prep.sum() > 3:
|
333 |
+
points_on_plane = current_verts_remeshed[current_target_gc_class_remeshed_prep==1, :]
|
334 |
+
data_centroid, plane_normal, error = fit_plane(points_on_plane)
|
335 |
+
nonplane_points_centered = current_verts_remeshed[current_target_gc_class_remeshed_prep==0, :] - data_centroid[None, :]
|
336 |
+
nonplane_points_projected = torch.matmul(plane_normal[None, :], nonplane_points_centered.transpose(0,1))
|
337 |
+
|
338 |
+
if nonplane_points_projected.sum() > 0: # plane normal points towards the animal
|
339 |
+
plane_normal = plane_normal.detach().cpu().numpy()
|
340 |
+
else:
|
341 |
+
plane_normal = - plane_normal.detach().cpu().numpy()
|
342 |
+
data_centroid = data_centroid.detach().cpu().numpy()
|
343 |
+
|
344 |
+
|
345 |
+
|
346 |
+
# import pdb; pdb.set_trace()
|
347 |
+
|
348 |
+
|
349 |
+
desired_plane_normal_vector = np.asarray([[0, -1, 0]])
|
350 |
+
# new approach: use cross product
|
351 |
+
rotation_axis = np.cross(plane_normal, desired_plane_normal_vector) # np.cross(plane_normal, desired_plane_normal_vector)
|
352 |
+
lengt_rotation_axis = np.linalg.norm(rotation_axis) # = sin(alpha) (because vectors have unit length)
|
353 |
+
angle = np.sin(lengt_rotation_axis)
|
354 |
+
rot = R_sc.from_rotvec(angle * rotation_axis * 1/lengt_rotation_axis)
|
355 |
+
rot_mat = rot[0].as_matrix()
|
356 |
+
rot_upsidedown = R_sc.from_rotvec(np.pi * np.asarray([[1, 0, 0]]))
|
357 |
+
# rot_upsidedown[0].apply(rot[0].apply(plane_normal))
|
358 |
+
current_vertices_smal = vertices_smal[ind_img, :, :].detach().cpu().numpy()
|
359 |
+
new_smal_vertices = rot_upsidedown[0].apply(rot[0].apply(current_vertices_smal - data_centroid[None, :]))
|
360 |
+
my_mesh = trimesh.Trimesh(vertices=new_smal_vertices, faces=smal_faces, process=False, maintain_order=True)
|
361 |
+
vert_colors[:, 2] = 255
|
362 |
+
my_mesh.visual.vertex_colors = vert_colors
|
363 |
+
out_path_gc_rotated = save_imgs_path + '/' + prefix + 'gc_rotated_' + img_name + '_new.obj'
|
364 |
+
my_mesh.export(out_path_gc_rotated)
|
365 |
+
|
366 |
+
|
367 |
+
|
368 |
+
|
369 |
+
|
370 |
+
|
371 |
+
'''# rot = R_sc.align_vectors(plane_normal.reshape((1, -1)), desired_plane_normal_vector)
|
372 |
+
desired_plane_normal_vector = np.asarray([[0, 1, 0]])
|
373 |
+
|
374 |
+
rot = R_sc.align_vectors(desired_plane_normal_vector, plane_normal.reshape((1, -1))) # inv
|
375 |
+
rot_mat = rot[0].as_matrix()
|
376 |
+
|
377 |
+
|
378 |
+
current_vertices_smal = vertices_smal[ind_img, :, :].detach().cpu().numpy()
|
379 |
+
new_smal_vertices = rot[0].apply((current_vertices_smal - data_centroid[None, :]))
|
380 |
+
|
381 |
+
my_mesh = trimesh.Trimesh(vertices=new_smal_vertices, faces=smal_faces, process=False, maintain_order=True)
|
382 |
+
my_mesh.visual.vertex_colors = vert_colors
|
383 |
+
out_path_gc_rotated = save_imgs_path + '/' + prefix + 'gc_rotated_' + img_name + '_y.obj'
|
384 |
+
my_mesh.export(out_path_gc_rotated)
|
385 |
+
'''
|
386 |
+
|
387 |
+
|
388 |
+
|
389 |
+
|
390 |
+
|
391 |
+
|
392 |
+
|
393 |
+
|
394 |
+
|
395 |
+
# ----
|
396 |
+
|
397 |
+
|
398 |
+
# -------------------------------------
|
399 |
+
|
400 |
+
|
401 |
+
|
402 |
+
|
403 |
+
if index == 0:
|
404 |
+
if len_dataset is None:
|
405 |
+
len_data = val_loader.batch_size * len(val_loader) # 1703
|
406 |
+
else:
|
407 |
+
len_data = len_dataset
|
408 |
+
if metrics == 'all' or metrics == 'no_loss':
|
409 |
+
if result_network == 'normal':
|
410 |
+
summaries = {'normal': dict(), 'ref': dict()}
|
411 |
+
summary = summaries['normal']
|
412 |
+
else:
|
413 |
+
summary = summaries['ref']
|
414 |
+
summary['pck'] = np.zeros((len_data))
|
415 |
+
summary['pck_by_part'] = {group:np.zeros((len_data)) for group in KEYPOINT_GROUPS}
|
416 |
+
summary['acc_sil_2d'] = np.zeros(len_data)
|
417 |
+
summary['betas'] = np.zeros((len_data,betas.shape[1]))
|
418 |
+
summary['betas_limbs'] = np.zeros((len_data, betas_limbs.shape[1]))
|
419 |
+
summary['z'] = np.zeros((len_data, zz.shape[1]))
|
420 |
+
summary['pose_rotmat'] = np.zeros((len_data, pose_rotmat.shape[1], 3, 3))
|
421 |
+
summary['flength'] = np.zeros((len_data, flength.shape[1]))
|
422 |
+
summary['trans'] = np.zeros((len_data, trans.shape[1]))
|
423 |
+
summary['breed_indices'] = np.zeros((len_data))
|
424 |
+
summary['image_names'] = [] # len_data * [None]
|
425 |
+
else:
|
426 |
+
if result_network == 'normal':
|
427 |
+
summary = summaries['normal']
|
428 |
+
else:
|
429 |
+
summary = summaries['ref']
|
430 |
+
|
431 |
+
if save_imgs_path is not None:
|
432 |
+
eval_save_visualizations_and_meshes(model, input, data_info, target_dict, test_name_list, vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat, trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, render_all=render_all)
|
433 |
+
|
434 |
+
if metrics == 'all' or metrics == 'no_loss':
|
435 |
+
preds = eval_prepare_pck_and_iou(model, input, data_info, target_dict, test_name_list, vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat, trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, pck_thresh, progress=progress)
|
436 |
+
# add results for all images in this batch to lists
|
437 |
+
curr_batch_size = pred_keyp.shape[0]
|
438 |
+
eval_add_preds_to_summary(summary, preds, my_step, batch_size, curr_batch_size)
|
439 |
+
else:
|
440 |
+
# measure accuracy and record loss
|
441 |
+
bs_fake = 1 # batch_size
|
442 |
+
# import pdb; pdb.set_trace()
|
443 |
+
|
444 |
+
|
445 |
+
# save_imgs_path + '/' + prefix + 'rot_tex_pred_' + img_name + '.png'
|
446 |
+
# import pdb; pdb.set_trace()
|
447 |
+
'''
|
448 |
+
for ind_img in range(len(target_dict['index'])):
|
449 |
+
try:
|
450 |
+
if test_name_list is not None:
|
451 |
+
img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_')
|
452 |
+
img_name = img_name.split('.')[0]
|
453 |
+
else:
|
454 |
+
img_name = str(index) + '_' + str(ind_img)
|
455 |
+
all_image_names = ['keypoints_pred_' + img_name + '.png', 'normal_comp_pred_' + img_name + '.png', 'normal_rot_tex_pred_' + img_name + '.png', 'ref_comp_pred_' + img_name + '.png', 'ref_rot_tex_pred_' + img_name + '.png']
|
456 |
+
all_saved_images = []
|
457 |
+
for sub_img_name in all_image_names:
|
458 |
+
saved_img = cv2.imread(save_imgs_path + '/' + sub_img_name)
|
459 |
+
if not (saved_img.shape[0] == 256 and saved_img.shape[1] == 256):
|
460 |
+
saved_img = cv2.resize(saved_img, (256, 256))
|
461 |
+
all_saved_images.append(saved_img)
|
462 |
+
final_image = np.concatenate(all_saved_images, axis=1)
|
463 |
+
save_imgs_path_sum = save_imgs_path.replace('test_', 'summary_test_')
|
464 |
+
if not os.path.exists(save_imgs_path_sum): os.makedirs(save_imgs_path_sum)
|
465 |
+
final_image_path = save_imgs_path_sum + '/summary_' + img_name + '.png'
|
466 |
+
cv2.imwrite(final_image_path, final_image)
|
467 |
+
except:
|
468 |
+
print('dont save a summary image')
|
469 |
+
'''
|
470 |
+
|
471 |
+
|
472 |
+
bs_fake = 1
|
473 |
+
if metrics == 'all' or metrics == 'no_loss':
|
474 |
+
# update progress bar
|
475 |
+
if progress is not None:
|
476 |
+
'''my_string = "PCK: {0:.2f}, IOU: {1:.2f}".format(
|
477 |
+
pck[:(my_step * batch_size + curr_batch_size)].mean(),
|
478 |
+
acc_sil_2d[:(my_step * batch_size + curr_batch_size)].mean())'''
|
479 |
+
my_string = "normal_PCK: {0:.2f}, normal_IOU: {1:.2f}, ref_PCK: {2:.2f}, ref_IOU: {3:.2f}".format(
|
480 |
+
summaries['normal']['pck'][:(my_step * batch_size + curr_batch_size)].mean(),
|
481 |
+
summaries['normal']['acc_sil_2d'][:(my_step * batch_size + curr_batch_size)].mean(),
|
482 |
+
summaries['ref']['pck'][:(my_step * batch_size + curr_batch_size)].mean(),
|
483 |
+
summaries['ref']['acc_sil_2d'][:(my_step * batch_size + curr_batch_size)].mean())
|
484 |
+
progress.set_postfix_str(my_string)
|
485 |
+
else:
|
486 |
+
losses.update(loss_dict['loss'] + loss_dict_ref['loss'], bs_fake)
|
487 |
+
losses_keyp.update(loss_dict['loss_keyp_weighted'], bs_fake)
|
488 |
+
losses_silh.update(loss_dict['loss_silh_weighted'], bs_fake)
|
489 |
+
losses_shape.update(loss_dict['loss_shape_weighted'], bs_fake)
|
490 |
+
losses_pose.update(loss_dict['loss_poseprior_weighted'], bs_fake)
|
491 |
+
losses_class.update(loss_dict['loss_class_weighted'], bs_fake)
|
492 |
+
losses_breed.update(loss_dict['loss_breed_weighted'], bs_fake)
|
493 |
+
losses_partseg.update(loss_dict['loss_partseg_weighted'], bs_fake)
|
494 |
+
losses_ref_keyp.update(loss_dict_ref['keyp_ref'], bs_fake)
|
495 |
+
losses_ref_silh.update(loss_dict_ref['silh_ref'], bs_fake)
|
496 |
+
loss_ref_pose = 0
|
497 |
+
for l_name in ['pose_legs_side', 'pose_legs_tors', 'pose_tail_side', 'pose_tail_tors', 'pose_spine_side', 'pose_spine_tors']:
|
498 |
+
loss_ref_pose += loss_dict_ref[l_name]
|
499 |
+
losses_ref_pose.update(loss_ref_pose, bs_fake)
|
500 |
+
loss_ref_reg = 0
|
501 |
+
for l_name in ['reg_trans', 'reg_flength', 'reg_pose']:
|
502 |
+
loss_ref_reg += loss_dict_ref[l_name]
|
503 |
+
losses_ref_reg.update(loss_ref_reg, bs_fake)
|
504 |
+
acc = - loss_dict['loss_keyp_weighted'] # this will be used to keep track of the 'best model'
|
505 |
+
accuracies.update(acc, bs_fake)
|
506 |
+
# Show losses as part of the progress bar.
|
507 |
+
if progress is not None:
|
508 |
+
my_string = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_partseg: {loss_partseg:0.4f}, loss_shape: {loss_shape:0.4f}, loss_pose: {loss_pose:0.4f}, loss_class: {loss_class:0.4f}, loss_breed: {loss_breed:0.4f}, loss_ref_keyp: {loss_ref_keyp:0.4f}, loss_ref_silh: {loss_ref_silh:0.4f}, loss_ref_pose: {loss_ref_pose:0.4f}, loss_ref_reg: {loss_ref_reg:0.4f}'.format(
|
509 |
+
loss=losses.avg,
|
510 |
+
loss_keyp=losses_keyp.avg,
|
511 |
+
loss_silh=losses_silh.avg,
|
512 |
+
loss_shape=losses_shape.avg,
|
513 |
+
loss_pose=losses_pose.avg,
|
514 |
+
loss_class=losses_class.avg,
|
515 |
+
loss_breed=losses_breed.avg,
|
516 |
+
loss_partseg=losses_partseg.avg,
|
517 |
+
loss_ref_keyp=losses_ref_keyp.avg,
|
518 |
+
loss_ref_silh=losses_ref_silh.avg,
|
519 |
+
loss_ref_pose=losses_ref_pose.avg,
|
520 |
+
loss_ref_reg=losses_ref_reg.avg)
|
521 |
+
my_string_short = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_ref_keyp: {loss_ref_keyp:0.4f}, loss_ref_silh: {loss_ref_silh:0.4f}, loss_ref_pose: {loss_ref_pose:0.4f}, loss_ref_reg: {loss_ref_reg:0.4f}'.format(
|
522 |
+
loss=losses.avg,
|
523 |
+
loss_keyp=losses_keyp.avg,
|
524 |
+
loss_silh=losses_silh.avg,
|
525 |
+
loss_ref_keyp=losses_ref_keyp.avg,
|
526 |
+
loss_ref_silh=losses_ref_silh.avg,
|
527 |
+
loss_ref_pose=losses_ref_pose.avg,
|
528 |
+
loss_ref_reg=losses_ref_reg.avg)
|
529 |
+
progress.set_postfix_str(my_string_short)
|
530 |
+
my_step += 1
|
531 |
+
if metrics == 'all':
|
532 |
+
return my_string, summaries # summary
|
533 |
+
elif metrics == 'no_loss':
|
534 |
+
return my_string, np.average(np.asarray(summaries['ref']['acc_sil_2d'])) # np.average(np.asarray(summary['acc_sil_2d']))
|
535 |
+
else:
|
536 |
+
return my_string, accuracies.avg
|
537 |
+
|
538 |
+
|
539 |
+
# ---------------------------------------------------------------------------------------------------------------------------
|
540 |
+
def do_visual_epoch(val_loader, model, device, data_info, flip=False, quiet=False, acc_joints=None, save_imgs_path=None, weight_dict=None, weight_dict_ref=None, metrics=None, val_opt='default', test_name_list=None, render_all=False, pck_thresh=0.15, return_results=False, len_dataset=None):
|
541 |
+
if save_imgs_path is not None:
|
542 |
+
pathlib.Path(save_imgs_path).mkdir(parents=True, exist_ok=True)
|
543 |
+
all_results = []
|
544 |
+
|
545 |
+
# Put the model in evaluation mode.
|
546 |
+
model.eval()
|
547 |
+
|
548 |
+
iterable = enumerate(val_loader)
|
549 |
+
|
550 |
+
# information for normalization
|
551 |
+
norm_dict = {
|
552 |
+
'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device),
|
553 |
+
'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device),
|
554 |
+
'trans_std': torch.from_numpy(data_info.trans_std).float().to(device),
|
555 |
+
'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device),
|
556 |
+
'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)}
|
557 |
+
|
558 |
+
|
559 |
+
return_mesh_with_gt_groundplane = True
|
560 |
+
if return_mesh_with_gt_groundplane:
|
561 |
+
remeshing_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data_remeshed/uniform_surface_sampling/my_smpl_39dogsnorm_Jr_4_dog_remesh4000_info.pkl'
|
562 |
+
with open(remeshing_path, 'rb') as fp:
|
563 |
+
remeshing_dict = pkl.load(fp)
|
564 |
+
remeshing_relevant_faces = torch.tensor(remeshing_dict['smal_faces'][remeshing_dict['faceid_closest']], dtype=torch.long, device=device)
|
565 |
+
remeshing_relevant_barys = torch.tensor(remeshing_dict['barys_closest'], dtype=torch.float32, device=device)
|
566 |
+
|
567 |
+
# from smal_pytorch.smal_model.smal_torch_new import SMAL
|
568 |
+
print('start: load smal default model (barc), but only for vertices')
|
569 |
+
smal = SMAL()
|
570 |
+
print('end: load smal default model (barc), but only for vertices')
|
571 |
+
smal_template_verts = smal.v_template.detach().cpu().numpy()
|
572 |
+
smal_faces = smal.faces.detach().cpu().numpy()
|
573 |
+
|
574 |
+
file_alignment_errors = open(save_imgs_path + '/a_ref_procrustes_alignmnet_errors.txt', 'a') # append mode
|
575 |
+
file_alignment_errors.write(" ----------- start evaluation ------------- \n ")
|
576 |
+
|
577 |
+
csv_file_alignment_errors = open(save_imgs_path + '/a_ref_procrustes_alignmnet_errors.csv', 'w') # write mode
|
578 |
+
fieldnames = ['name', 'error']
|
579 |
+
writer = csv.DictWriter(csv_file_alignment_errors, fieldnames=fieldnames)
|
580 |
+
writer.writeheader()
|
581 |
+
|
582 |
+
my_step = 0
|
583 |
+
for index, (input, target_dict) in iterable:
|
584 |
+
batch_size = input.shape[0]
|
585 |
+
input = input.float().to(device)
|
586 |
+
partial_results = {}
|
587 |
+
|
588 |
+
# ----------------------- do visualization step -----------------------
|
589 |
+
with torch.no_grad():
|
590 |
+
output, output_unnorm, output_reproj, output_ref, output_ref_comp = model(input, norm_dict=norm_dict)
|
591 |
+
|
592 |
+
|
593 |
+
# import pdb; pdb.set_trace()
|
594 |
+
|
595 |
+
|
596 |
+
sm = torch.nn.Softmax(dim=2)
|
597 |
+
ground_contact_probs = sm(output_ref['vertexwise_ground_contact'])
|
598 |
+
|
599 |
+
for result_network in ['normal', 'ref']:
|
600 |
+
# variabled that are not refined
|
601 |
+
hg_keyp_norm = output['keypoints_norm']
|
602 |
+
hg_keyp_scores = output['keypoints_scores']
|
603 |
+
betas = output_reproj['betas']
|
604 |
+
betas_limbs = output_reproj['betas_limbs']
|
605 |
+
zz = output_reproj['z']
|
606 |
+
if result_network == 'normal':
|
607 |
+
# STEP 1: normal network
|
608 |
+
vertices_smal = output_reproj['vertices_smal']
|
609 |
+
flength = output_unnorm['flength']
|
610 |
+
pose_rotmat = output_unnorm['pose_rotmat']
|
611 |
+
trans = output_unnorm['trans']
|
612 |
+
pred_keyp = output_reproj['keyp_2d']
|
613 |
+
pred_silh = output_reproj['silh']
|
614 |
+
prefix = 'normal_'
|
615 |
+
else:
|
616 |
+
# STEP 1: refinement network
|
617 |
+
vertices_smal = output_ref['vertices_smal']
|
618 |
+
flength = output_ref['flength']
|
619 |
+
pose_rotmat = output_ref['pose_rotmat']
|
620 |
+
trans = output_ref['trans']
|
621 |
+
pred_keyp = output_ref['keyp_2d']
|
622 |
+
pred_silh = output_ref['silh']
|
623 |
+
prefix = 'ref_'
|
624 |
+
|
625 |
+
bs = vertices_smal.shape[0]
|
626 |
+
# target_gc_class = target_dict['gc'][:, :, 0]
|
627 |
+
target_gc_class = torch.round(ground_contact_probs).long()[:, :, 1]
|
628 |
+
sel_verts = torch.index_select(output_ref['vertices_smal'], dim=1, index=remeshing_relevant_faces.reshape((-1))).reshape((bs, remeshing_relevant_faces.shape[0], 3, 3))
|
629 |
+
verts_remeshed = torch.einsum('ij,aijk->aik', remeshing_relevant_barys, sel_verts)
|
630 |
+
target_gc_class_remeshed = torch.einsum('ij,aij->ai', remeshing_relevant_barys, target_gc_class[:, remeshing_relevant_faces].to(device=device, dtype=torch.float32))
|
631 |
+
target_gc_class_remeshed_prep = torch.round(target_gc_class_remeshed).to(torch.long)
|
632 |
+
|
633 |
+
|
634 |
+
|
635 |
+
|
636 |
+
# index = i
|
637 |
+
# ind_img = 0
|
638 |
+
for ind_img in range(batch_size): # range(min(12, batch_size)): # range(12): # [0]: #range(0, batch_size):
|
639 |
+
|
640 |
+
# ind_img = 0
|
641 |
+
if test_name_list is not None:
|
642 |
+
img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_')
|
643 |
+
img_name = img_name.split('.')[0]
|
644 |
+
else:
|
645 |
+
img_name = str(index) + '_' + str(ind_img)
|
646 |
+
out_path_gcmesh = save_imgs_path + '/' + prefix + 'gcmesh_' + img_name + '.obj'
|
647 |
+
|
648 |
+
gc_prob = ground_contact_probs[ind_img, :, 1] # contact probability
|
649 |
+
vert_colors = np.repeat(255*gc_prob.detach().cpu().numpy()[:, None], 3, 1)
|
650 |
+
my_mesh = trimesh.Trimesh(vertices=smal_template_verts, faces=smal_faces, process=False, maintain_order=True)
|
651 |
+
my_mesh.visual.vertex_colors = vert_colors
|
652 |
+
save_gc_mesh = False
|
653 |
+
if save_gc_mesh:
|
654 |
+
my_mesh.export(out_path_gcmesh)
|
655 |
+
|
656 |
+
current_verts_remeshed = verts_remeshed[ind_img, :, :]
|
657 |
+
current_target_gc_class_remeshed_prep = target_gc_class_remeshed_prep[ind_img, ...]
|
658 |
+
|
659 |
+
if current_target_gc_class_remeshed_prep.sum() > 3:
|
660 |
+
points_on_plane = current_verts_remeshed[current_target_gc_class_remeshed_prep==1, :]
|
661 |
+
data_centroid, plane_normal, error = fit_plane(points_on_plane)
|
662 |
+
nonplane_points_centered = current_verts_remeshed[current_target_gc_class_remeshed_prep==0, :] - data_centroid[None, :]
|
663 |
+
nonplane_points_projected = torch.matmul(plane_normal[None, :], nonplane_points_centered.transpose(0,1))
|
664 |
+
|
665 |
+
if nonplane_points_projected.sum() > 0: # plane normal points towards the animal
|
666 |
+
plane_normal = plane_normal.detach().cpu().numpy()
|
667 |
+
else:
|
668 |
+
plane_normal = - plane_normal.detach().cpu().numpy()
|
669 |
+
data_centroid = data_centroid.detach().cpu().numpy()
|
670 |
+
|
671 |
+
|
672 |
+
|
673 |
+
# import pdb; pdb.set_trace()
|
674 |
+
|
675 |
+
|
676 |
+
desired_plane_normal_vector = np.asarray([[0, -1, 0]])
|
677 |
+
# new approach: use cross product
|
678 |
+
rotation_axis = np.cross(plane_normal, desired_plane_normal_vector) # np.cross(plane_normal, desired_plane_normal_vector)
|
679 |
+
lengt_rotation_axis = np.linalg.norm(rotation_axis) # = sin(alpha) (because vectors have unit length)
|
680 |
+
angle = np.sin(lengt_rotation_axis)
|
681 |
+
rot = R_sc.from_rotvec(angle * rotation_axis * 1/lengt_rotation_axis)
|
682 |
+
rot_mat = rot[0].as_matrix()
|
683 |
+
rot_upsidedown = R_sc.from_rotvec(np.pi * np.asarray([[1, 0, 0]]))
|
684 |
+
# rot_upsidedown[0].apply(rot[0].apply(plane_normal))
|
685 |
+
current_vertices_smal = vertices_smal[ind_img, :, :].detach().cpu().numpy()
|
686 |
+
new_smal_vertices = rot_upsidedown[0].apply(rot[0].apply(current_vertices_smal - data_centroid[None, :]))
|
687 |
+
my_mesh = trimesh.Trimesh(vertices=new_smal_vertices, faces=smal_faces, process=False, maintain_order=True)
|
688 |
+
vert_colors[:, 2] = 255
|
689 |
+
my_mesh.visual.vertex_colors = vert_colors
|
690 |
+
out_path_gc_rotated = save_imgs_path + '/' + prefix + 'gc_rotated_' + img_name + '_new.obj'
|
691 |
+
my_mesh.export(out_path_gc_rotated)
|
692 |
+
|
693 |
+
|
694 |
+
|
695 |
+
'''
|
696 |
+
import pdb; pdb.set_trace()
|
697 |
+
|
698 |
+
from src.evaluation.registration import preprocess_point_cloud, o3d_ransac, draw_registration_result
|
699 |
+
import open3d as o3d
|
700 |
+
import copy
|
701 |
+
|
702 |
+
|
703 |
+
mesh_gt_path = target_dict['mesh_path'][ind_img]
|
704 |
+
mesh_gt = o3d.io.read_triangle_mesh(mesh_gt_path)
|
705 |
+
|
706 |
+
mesh_gt_verts = np.asarray(mesh_gt.vertices)
|
707 |
+
mesh_gt_faces = np.asarray(mesh_gt.triangles)
|
708 |
+
diag_gt = np.sqrt(sum((mesh_gt_verts.max(axis=0) - mesh_gt_verts.min(axis=0))**2))
|
709 |
+
|
710 |
+
mesh_pred_verts = np.asarray(new_smal_vertices)
|
711 |
+
mesh_pred_faces = np.asarray(smal_faces)
|
712 |
+
diag_pred = np.sqrt(sum((mesh_pred_verts.max(axis=0) - mesh_pred_verts.min(axis=0))**2))
|
713 |
+
mesh_pred = o3d.geometry.TriangleMesh()
|
714 |
+
mesh_pred.vertices = o3d.utility.Vector3dVector(mesh_pred_verts)
|
715 |
+
mesh_pred.triangles = o3d.utility.Vector3iVector(mesh_pred_faces)
|
716 |
+
|
717 |
+
# center the predicted mesh around 0
|
718 |
+
trans = - mesh_pred_verts.mean(axis=0)
|
719 |
+
mesh_pred_verts_new = mesh_pred_verts + trans
|
720 |
+
# change the size of the predicted mesh
|
721 |
+
mesh_pred_verts_new = mesh_pred_verts_new * diag_gt / diag_pred
|
722 |
+
|
723 |
+
# transform the predicted mesh (rough alignment)
|
724 |
+
mesh_pred_new = copy.deepcopy(mesh_pred)
|
725 |
+
mesh_pred_new.vertices = o3d.utility.Vector3dVector(np.asarray(mesh_pred_verts_new)) # normals should not have changed
|
726 |
+
voxel_size = 0.01 # 0.5
|
727 |
+
distance_threshold = 0.015 # 0.005 # 0.02 # 1.0
|
728 |
+
result, src_down, src_fpfh, dst_down, dst_fpfh = o3d_ransac(mesh_pred_new, mesh_gt, voxel_size=voxel_size, distance_threshold=distance_threshold, return_all=True)
|
729 |
+
transform = result.transformation
|
730 |
+
mesh_pred_transf = copy.deepcopy(mesh_pred_new).transform(transform)
|
731 |
+
|
732 |
+
out_path_pred_transf = save_imgs_path + '/' + prefix + 'alignment_initial_' + img_name + '.obj'
|
733 |
+
o3d.io.write_triangle_mesh(out_path_pred_transf, mesh_pred_transf)
|
734 |
+
|
735 |
+
# img_name_part = img_name.split(img_name.split('_')[-1] + '_')[0]
|
736 |
+
# out_path_gt = save_imgs_path + '/' + prefix + 'ground_truth_' + img_name_part + '.obj'
|
737 |
+
# o3d.io.write_triangle_mesh(out_path_gt, mesh_gt)
|
738 |
+
|
739 |
+
|
740 |
+
trans_init = transform
|
741 |
+
threshold = 0.02 # 0.1 # 0.02
|
742 |
+
|
743 |
+
n_points = 10000
|
744 |
+
src = mesh_pred_new.sample_points_uniformly(number_of_points=n_points)
|
745 |
+
dst = mesh_gt.sample_points_uniformly(number_of_points=n_points)
|
746 |
+
|
747 |
+
# reg_p2p = o3d.pipelines.registration.registration_icp(src_down, dst_down, threshold, trans_init, o3d.pipelines.registration.TransformationEstimationPointToPoint(), o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=2000))
|
748 |
+
reg_p2p = o3d.pipelines.registration.registration_icp(src, dst, threshold, trans_init, o3d.pipelines.registration.TransformationEstimationPointToPoint(), o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=2000))
|
749 |
+
|
750 |
+
# mesh_pred_transf_refined = copy.deepcopy(mesh_pred_new).transform(reg_p2p.transformation)
|
751 |
+
# out_path_pred_transf_refined = save_imgs_path + '/' + prefix + 'alignment_final_' + img_name + '.obj'
|
752 |
+
# o3d.io.write_triangle_mesh(out_path_pred_transf_refined, mesh_pred_transf_refined)
|
753 |
+
|
754 |
+
|
755 |
+
aligned_mesh_final = trimesh.Trimesh(mesh_pred_new.vertices, mesh_pred_new.triangles, vertex_colors=[0, 255, 0])
|
756 |
+
gt_mesh = trimesh.Trimesh(mesh_gt.vertices, mesh_gt.triangles, vertex_colors=[255, 0, 0])
|
757 |
+
scene = trimesh.Scene([aligned_mesh_final, gt_mesh])
|
758 |
+
out_path_alignment_with_gt = save_imgs_path + '/' + prefix + 'alignment_with_gt_' + img_name + '.obj'
|
759 |
+
|
760 |
+
scene.export(out_path_alignment_with_gt)
|
761 |
+
'''
|
762 |
+
|
763 |
+
# import pdb; pdb.set_trace()
|
764 |
+
|
765 |
+
|
766 |
+
# SMAL_KEYPOINT_NAMES_FOR_3D_EVAL # 17 keypoints
|
767 |
+
# prepare target
|
768 |
+
target_keyp_isvalid = target_dict['keypoints_3d'][ind_img, :, 3].detach().cpu().numpy()
|
769 |
+
keyp_to_use = (np.asarray(SMAL_KEYPOINT_WHICHTOUSE_FOR_3D_EVAL)==1)*(target_keyp_isvalid==1)
|
770 |
+
target_keyp_raw = target_dict['keypoints_3d'][ind_img, :, :3].detach().cpu().numpy()
|
771 |
+
target_keypoints = target_keyp_raw[keyp_to_use, :]
|
772 |
+
target_pointcloud = target_dict['pointcloud_points'][ind_img, :, :].detach().cpu().numpy()
|
773 |
+
# prepare prediction
|
774 |
+
pred_keypoints_raw = output_ref['vertices_smal'][ind_img, SMAL_KEYPOINT_INDICES_FOR_3D_EVAL, :].detach().cpu().numpy()
|
775 |
+
pred_keypoints = pred_keypoints_raw[keyp_to_use, :]
|
776 |
+
pred_pointcloud = verts_remeshed[ind_img, :, :].detach().cpu().numpy()
|
777 |
+
|
778 |
+
|
779 |
+
|
780 |
+
|
781 |
+
'''
|
782 |
+
pred_keypoints_transf, pred_pointcloud_transf, procrustes_params = compute_similarity_transform(pred_keypoints, target_keypoints, num_joints=None, verts=pred_pointcloud)
|
783 |
+
pa_error = np.sqrt(np.sum((target_keypoints - pred_keypoints_transf) ** 2, axis=1))
|
784 |
+
error_procrustes = np.mean(pa_error)
|
785 |
+
|
786 |
+
|
787 |
+
col_target = np.zeros((target_pointcloud.shape[0], 3), dtype=np.uint8)
|
788 |
+
col_target[:, 0] = 255
|
789 |
+
col_pred = np.zeros((pred_pointcloud_transf.shape[0], 3), dtype=np.uint8)
|
790 |
+
col_pred[:, 1] = 255
|
791 |
+
pc = trimesh.points.PointCloud(np.concatenate((target_pointcloud, pred_pointcloud_transf)), colors=np.concatenate((col_target, col_pred)))
|
792 |
+
out_path_pc = save_imgs_path + '/' + prefix + 'pointclouds_aligned_' + img_name + '.obj'
|
793 |
+
pc.export(out_path_pc)
|
794 |
+
|
795 |
+
print(target_dict['mesh_path'][ind_img])
|
796 |
+
print(error_procrustes)
|
797 |
+
file_alignment_errors.write(target_dict['mesh_path'][ind_img] + '\n')
|
798 |
+
file_alignment_errors.write('error: ' + str(error_procrustes) + ' \n')
|
799 |
+
|
800 |
+
writer.writerow({'name': (target_dict['mesh_path'][ind_img]).split('/')[-1], 'error': str(error_procrustes)})
|
801 |
+
|
802 |
+
# import pdb; pdb.set_trace()
|
803 |
+
# alignment_dict = calculate_alignemnt_errors(output_ref['vertices_smal'][ind_img, :, :], target_dict['keypoints_3d'][ind_img, :, :], target_dict['pointcloud_points'][ind_img, :, :])
|
804 |
+
# file_alignment_errors.write('error: ' + str(alignment_dict['error_procrustes']) + ' \n')
|
805 |
+
'''
|
806 |
+
|
807 |
+
|
808 |
+
|
809 |
+
|
810 |
+
|
811 |
+
|
812 |
+
if index == 0:
|
813 |
+
if len_dataset is None:
|
814 |
+
len_data = val_loader.batch_size * len(val_loader) # 1703
|
815 |
+
else:
|
816 |
+
len_data = len_dataset
|
817 |
+
if result_network == 'normal':
|
818 |
+
summaries = {'normal': dict(), 'ref': dict()}
|
819 |
+
summary = summaries['normal']
|
820 |
+
else:
|
821 |
+
summary = summaries['ref']
|
822 |
+
summary['pck'] = np.zeros((len_data))
|
823 |
+
summary['pck_by_part'] = {group:np.zeros((len_data)) for group in KEYPOINT_GROUPS}
|
824 |
+
summary['acc_sil_2d'] = np.zeros(len_data)
|
825 |
+
summary['betas'] = np.zeros((len_data,betas.shape[1]))
|
826 |
+
summary['betas_limbs'] = np.zeros((len_data, betas_limbs.shape[1]))
|
827 |
+
summary['z'] = np.zeros((len_data, zz.shape[1]))
|
828 |
+
summary['pose_rotmat'] = np.zeros((len_data, pose_rotmat.shape[1], 3, 3))
|
829 |
+
summary['flength'] = np.zeros((len_data, flength.shape[1]))
|
830 |
+
summary['trans'] = np.zeros((len_data, trans.shape[1]))
|
831 |
+
summary['breed_indices'] = np.zeros((len_data))
|
832 |
+
summary['image_names'] = [] # len_data * [None]
|
833 |
+
# ['vertices_smal'] = np.zeros((len_data, vertices_smal.shape[1], 3))
|
834 |
+
else:
|
835 |
+
if result_network == 'normal':
|
836 |
+
summary = summaries['normal']
|
837 |
+
else:
|
838 |
+
summary = summaries['ref']
|
839 |
+
|
840 |
+
|
841 |
+
# import pdb; pdb.set_trace()
|
842 |
+
|
843 |
+
|
844 |
+
eval_save_visualizations_and_meshes(model, input, data_info, target_dict, test_name_list, vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat, trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, render_all=render_all)
|
845 |
+
|
846 |
+
|
847 |
+
preds = eval_prepare_pck_and_iou(model, input, data_info, target_dict, test_name_list, vertices_smal, hg_keyp_norm, hg_keyp_scores, zz, betas, betas_limbs, pose_rotmat, trans, flength, pred_keyp, pred_silh, save_imgs_path, prefix, index, pck_thresh=None, skip_pck_and_iou=True)
|
848 |
+
# add results for all images in this batch to lists
|
849 |
+
curr_batch_size = pred_keyp.shape[0]
|
850 |
+
eval_add_preds_to_summary(summary, preds, my_step, batch_size, curr_batch_size, skip_pck_and_iou=True)
|
851 |
+
|
852 |
+
# summary['vertices_smal'][my_step * batch_size:my_step * batch_size + curr_batch_size] = vertices_smal.detach().cpu().numpy()
|
853 |
+
|
854 |
+
|
855 |
+
|
856 |
+
|
857 |
+
|
858 |
+
|
859 |
+
|
860 |
+
|
861 |
+
|
862 |
+
|
863 |
+
|
864 |
+
|
865 |
+
|
866 |
+
|
867 |
+
|
868 |
+
'''
|
869 |
+
try:
|
870 |
+
if test_name_list is not None:
|
871 |
+
img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_')
|
872 |
+
img_name = img_name.split('.')[0]
|
873 |
+
else:
|
874 |
+
img_name = str(index) + '_' + str(ind_img)
|
875 |
+
partial_results['img_name'] = img_name
|
876 |
+
visualizations = model.render_vis_nograd(vertices=output_reproj['vertices_smal'],
|
877 |
+
focal_lengths=output_unnorm['flength'],
|
878 |
+
color=0) # 2)
|
879 |
+
# save image with predicted keypoints
|
880 |
+
pred_unp = (output['keypoints_norm'][ind_img, :, :] + 1.) / 2 * (data_info.image_size - 1)
|
881 |
+
pred_unp_maxval = output['keypoints_scores'][ind_img, :, :]
|
882 |
+
pred_unp_prep = torch.cat((pred_unp, pred_unp_maxval), 1)
|
883 |
+
inp_img = input[ind_img, :, :, :].detach().clone()
|
884 |
+
if save_imgs_path is not None:
|
885 |
+
out_path = save_imgs_path + '/keypoints_pred_' + img_name + '.png'
|
886 |
+
save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.1, print_scores=True, ratio_in_out=1.0) # threshold=0.3
|
887 |
+
# save predicted 3d model
|
888 |
+
# (1) front view
|
889 |
+
pred_tex = visualizations[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
|
890 |
+
pred_tex_max = np.max(pred_tex, axis=2)
|
891 |
+
partial_results['tex_pred'] = pred_tex
|
892 |
+
if save_imgs_path is not None:
|
893 |
+
out_path = save_imgs_path + '/tex_pred_' + img_name + '.png'
|
894 |
+
plt.imsave(out_path, pred_tex)
|
895 |
+
input_image = input[ind_img, :, :, :].detach().clone()
|
896 |
+
for t, m, s in zip(input_image, data_info.rgb_mean, data_info.rgb_stddev): t.add_(m)
|
897 |
+
input_image_np = input_image.detach().cpu().numpy().transpose(1, 2, 0)
|
898 |
+
im_masked = cv2.addWeighted(input_image_np,0.2,pred_tex,0.8,0)
|
899 |
+
im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :]
|
900 |
+
partial_results['comp_pred'] = im_masked
|
901 |
+
if save_imgs_path is not None:
|
902 |
+
out_path = save_imgs_path + '/comp_pred_' + img_name + '.png'
|
903 |
+
plt.imsave(out_path, im_masked)
|
904 |
+
# (2) side view
|
905 |
+
vertices_cent = output_reproj['vertices_smal'] - output_reproj['vertices_smal'].mean(dim=1)[:, None, :]
|
906 |
+
roll = np.pi / 2 * torch.ones(1).float().to(device)
|
907 |
+
pitch = np.pi / 2 * torch.ones(1).float().to(device)
|
908 |
+
tensor_0 = torch.zeros(1).float().to(device)
|
909 |
+
tensor_1 = torch.ones(1).float().to(device)
|
910 |
+
RX = torch.stack([torch.stack([tensor_1, tensor_0, tensor_0]), torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3,3)
|
911 |
+
RY = torch.stack([
|
912 |
+
torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]),
|
913 |
+
torch.stack([tensor_0, tensor_1, tensor_0]),
|
914 |
+
torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3,3)
|
915 |
+
vertices_rot = (torch.matmul(RY, vertices_cent.reshape((-1, 3))[:, :, None])).reshape((batch_size, -1, 3))
|
916 |
+
vertices_rot[:, :, 2] = vertices_rot[:, :, 2] + torch.ones_like(vertices_rot[:, :, 2]) * 20 # 18 # *16
|
917 |
+
visualizations_rot = model.render_vis_nograd(vertices=vertices_rot,
|
918 |
+
focal_lengths=output_unnorm['flength'],
|
919 |
+
color=0) # 2)
|
920 |
+
pred_tex = visualizations_rot[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
|
921 |
+
pred_tex_max = np.max(pred_tex, axis=2)
|
922 |
+
partial_results['rot_tex_pred'] = pred_tex
|
923 |
+
if save_imgs_path is not None:
|
924 |
+
out_path = save_imgs_path + '/rot_tex_pred_' + img_name + '.png'
|
925 |
+
plt.imsave(out_path, pred_tex)
|
926 |
+
render_all = True
|
927 |
+
if render_all:
|
928 |
+
# save input image
|
929 |
+
inp_img = input[ind_img, :, :, :].detach().clone()
|
930 |
+
if save_imgs_path is not None:
|
931 |
+
out_path = save_imgs_path + '/image_' + img_name + '.png'
|
932 |
+
save_input_image(inp_img, out_path)
|
933 |
+
# save posed mesh
|
934 |
+
V_posed = output_reproj['vertices_smal'][ind_img, :, :].detach().cpu().numpy()
|
935 |
+
Faces = model.smal.f
|
936 |
+
mesh_posed = trimesh.Trimesh(vertices=V_posed, faces=Faces, process=False, maintain_order=True)
|
937 |
+
partial_results['mesh_posed'] = mesh_posed
|
938 |
+
if save_imgs_path is not None:
|
939 |
+
mesh_posed.export(save_imgs_path + '/mesh_posed_' + img_name + '.obj')
|
940 |
+
except:
|
941 |
+
print('pass...')
|
942 |
+
all_results.append(partial_results)
|
943 |
+
'''
|
944 |
+
|
945 |
+
my_step += 1
|
946 |
+
|
947 |
+
|
948 |
+
file_alignment_errors.close()
|
949 |
+
csv_file_alignment_errors.close()
|
950 |
+
|
951 |
+
|
952 |
+
if return_results:
|
953 |
+
return all_results
|
954 |
+
else:
|
955 |
+
return summaries
|
src/combined_model/train_main_image_to_3d_withbreedrel.py
ADDED
@@ -0,0 +1,496 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.backends.cudnn
|
5 |
+
import torch.nn.parallel
|
6 |
+
from tqdm import tqdm
|
7 |
+
import os
|
8 |
+
import pathlib
|
9 |
+
from matplotlib import pyplot as plt
|
10 |
+
import cv2
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import trimesh
|
14 |
+
|
15 |
+
import sys
|
16 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
|
17 |
+
from stacked_hourglass.utils.evaluation import accuracy, AverageMeter, final_preds, get_preds, get_preds_soft
|
18 |
+
from stacked_hourglass.utils.visualization import save_input_image_with_keypoints, save_input_image
|
19 |
+
from metrics.metrics import Metrics
|
20 |
+
from configs.SMAL_configs import EVAL_KEYPOINTS, KEYPOINT_GROUPS
|
21 |
+
|
22 |
+
|
23 |
+
# ---------------------------------------------------------------------------------------------------------------------------
|
24 |
+
def do_training_epoch(train_loader, model, loss_module, device, data_info, optimiser, quiet=False, acc_joints=None, weight_dict=None):
|
25 |
+
losses = AverageMeter()
|
26 |
+
losses_keyp = AverageMeter()
|
27 |
+
losses_silh = AverageMeter()
|
28 |
+
losses_shape = AverageMeter()
|
29 |
+
losses_pose = AverageMeter()
|
30 |
+
losses_class = AverageMeter()
|
31 |
+
losses_breed = AverageMeter()
|
32 |
+
losses_partseg = AverageMeter()
|
33 |
+
accuracies = AverageMeter()
|
34 |
+
# Put the model in training mode.
|
35 |
+
model.train()
|
36 |
+
# prepare progress bar
|
37 |
+
iterable = enumerate(train_loader)
|
38 |
+
progress = None
|
39 |
+
if not quiet:
|
40 |
+
progress = tqdm(iterable, desc='Train', total=len(train_loader), ascii=True, leave=False)
|
41 |
+
iterable = progress
|
42 |
+
# information for normalization
|
43 |
+
norm_dict = {
|
44 |
+
'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device),
|
45 |
+
'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device),
|
46 |
+
'trans_std': torch.from_numpy(data_info.trans_std).float().to(device),
|
47 |
+
'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device),
|
48 |
+
'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)}
|
49 |
+
# prepare variables, put them on the right device
|
50 |
+
for i, (input, target_dict) in iterable:
|
51 |
+
batch_size = input.shape[0]
|
52 |
+
for key in target_dict.keys():
|
53 |
+
if key == 'breed_index':
|
54 |
+
target_dict[key] = target_dict[key].long().to(device)
|
55 |
+
elif key in ['index', 'pts', 'tpts', 'target_weight', 'silh', 'silh_distmat_tofg', 'silh_distmat_tobg', 'sim_breed_index', 'img_border_mask']:
|
56 |
+
target_dict[key] = target_dict[key].float().to(device)
|
57 |
+
elif key in ['has_seg', 'gc']:
|
58 |
+
target_dict[key] = target_dict[key].to(device)
|
59 |
+
else:
|
60 |
+
pass
|
61 |
+
input = input.float().to(device)
|
62 |
+
|
63 |
+
# ----------------------- do training step -----------------------
|
64 |
+
assert model.training, 'model must be in training mode.'
|
65 |
+
with torch.enable_grad():
|
66 |
+
# ----- forward pass -----
|
67 |
+
output, output_unnorm, output_reproj = model(input, norm_dict=norm_dict)
|
68 |
+
# ----- loss -----
|
69 |
+
loss, loss_dict = loss_module(output_reproj=output_reproj,
|
70 |
+
target_dict=target_dict,
|
71 |
+
weight_dict=weight_dict)
|
72 |
+
# ----- backward pass and parameter update -----
|
73 |
+
optimiser.zero_grad()
|
74 |
+
loss.backward()
|
75 |
+
optimiser.step()
|
76 |
+
# ----------------------------------------------------------------
|
77 |
+
|
78 |
+
# prepare losses for progress bar
|
79 |
+
bs_fake = 1 # batch_size
|
80 |
+
losses.update(loss_dict['loss'], bs_fake)
|
81 |
+
losses_keyp.update(loss_dict['loss_keyp_weighted'], bs_fake)
|
82 |
+
losses_silh.update(loss_dict['loss_silh_weighted'], bs_fake)
|
83 |
+
losses_shape.update(loss_dict['loss_shape_weighted'], bs_fake)
|
84 |
+
losses_pose.update(loss_dict['loss_poseprior_weighted'], bs_fake)
|
85 |
+
losses_class.update(loss_dict['loss_class_weighted'], bs_fake)
|
86 |
+
losses_breed.update(loss_dict['loss_breed_weighted'], bs_fake)
|
87 |
+
losses_partseg.update(loss_dict['loss_partseg_weighted'], bs_fake)
|
88 |
+
acc = - loss_dict['loss_keyp_weighted'] # this will be used to keep track of the 'best model'
|
89 |
+
accuracies.update(acc, bs_fake)
|
90 |
+
# Show losses as part of the progress bar.
|
91 |
+
if progress is not None:
|
92 |
+
my_string = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_partseg: {loss_partseg:0.4f}, loss_shape: {loss_shape:0.4f}, loss_pose: {loss_pose:0.4f}, loss_class: {loss_class:0.4f}, loss_breed: {loss_breed:0.4f}'.format(
|
93 |
+
loss=losses.avg,
|
94 |
+
loss_keyp=losses_keyp.avg,
|
95 |
+
loss_silh=losses_silh.avg,
|
96 |
+
loss_shape=losses_shape.avg,
|
97 |
+
loss_pose=losses_pose.avg,
|
98 |
+
loss_class=losses_class.avg,
|
99 |
+
loss_breed=losses_breed.avg,
|
100 |
+
loss_partseg=losses_partseg.avg
|
101 |
+
)
|
102 |
+
progress.set_postfix_str(my_string)
|
103 |
+
|
104 |
+
return my_string, accuracies.avg
|
105 |
+
|
106 |
+
|
107 |
+
# ---------------------------------------------------------------------------------------------------------------------------
|
108 |
+
def do_validation_epoch(val_loader, model, loss_module, device, data_info, flip=False, quiet=False, acc_joints=None, save_imgs_path=None, weight_dict=None, metrics=None, val_opt='default', test_name_list=None, render_all=False, pck_thresh=0.15, len_dataset=None):
|
109 |
+
losses = AverageMeter()
|
110 |
+
losses_keyp = AverageMeter()
|
111 |
+
losses_silh = AverageMeter()
|
112 |
+
losses_shape = AverageMeter()
|
113 |
+
losses_pose = AverageMeter()
|
114 |
+
losses_class = AverageMeter()
|
115 |
+
losses_breed = AverageMeter()
|
116 |
+
losses_partseg = AverageMeter()
|
117 |
+
accuracies = AverageMeter()
|
118 |
+
if save_imgs_path is not None:
|
119 |
+
pathlib.Path(save_imgs_path).mkdir(parents=True, exist_ok=True)
|
120 |
+
# Put the model in evaluation mode.
|
121 |
+
model.eval()
|
122 |
+
# prepare progress bar
|
123 |
+
iterable = enumerate(val_loader)
|
124 |
+
progress = None
|
125 |
+
if not quiet:
|
126 |
+
progress = tqdm(iterable, desc='Valid', total=len(val_loader), ascii=True, leave=False)
|
127 |
+
iterable = progress
|
128 |
+
# summarize information for normalization
|
129 |
+
norm_dict = {
|
130 |
+
'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device),
|
131 |
+
'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device),
|
132 |
+
'trans_std': torch.from_numpy(data_info.trans_std).float().to(device),
|
133 |
+
'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device),
|
134 |
+
'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)}
|
135 |
+
batch_size = val_loader.batch_size
|
136 |
+
# prepare variables, put them on the right device
|
137 |
+
my_step = 0
|
138 |
+
for i, (input, target_dict) in iterable:
|
139 |
+
curr_batch_size = input.shape[0]
|
140 |
+
for key in target_dict.keys():
|
141 |
+
if key == 'breed_index':
|
142 |
+
target_dict[key] = target_dict[key].long().to(device)
|
143 |
+
elif key in ['index', 'pts', 'tpts', 'target_weight', 'silh', 'silh_distmat_tofg', 'silh_distmat_tobg', 'sim_breed_index', 'img_border_mask']:
|
144 |
+
target_dict[key] = target_dict[key].float().to(device)
|
145 |
+
elif key in ['has_seg', 'gc']:
|
146 |
+
target_dict[key] = target_dict[key].to(device)
|
147 |
+
else:
|
148 |
+
pass
|
149 |
+
input = input.float().to(device)
|
150 |
+
|
151 |
+
# ----------------------- do validation step -----------------------
|
152 |
+
with torch.no_grad():
|
153 |
+
# ----- forward pass -----
|
154 |
+
# output: (['pose', 'flength', 'trans', 'keypoints_norm', 'keypoints_scores'])
|
155 |
+
# output_unnorm: (['pose_rotmat', 'flength', 'trans', 'keypoints'])
|
156 |
+
# output_reproj: (['vertices_smal', 'torch_meshes', 'keyp_3d', 'keyp_2d', 'silh', 'betas', 'pose_rot6d', 'dog_breed', 'shapedirs', 'z', 'flength_unnorm', 'flength'])
|
157 |
+
# target_dict: (['index', 'center', 'scale', 'pts', 'tpts', 'target_weight', 'breed_index', 'sim_breed_index', 'ind_dataset', 'silh'])
|
158 |
+
output, output_unnorm, output_reproj = model(input, norm_dict=norm_dict)
|
159 |
+
# ----- loss -----
|
160 |
+
if metrics == 'no_loss':
|
161 |
+
loss, loss_dict = loss_module(output_reproj=output_reproj,
|
162 |
+
target_dict=target_dict,
|
163 |
+
weight_dict=weight_dict)
|
164 |
+
# ----------------------------------------------------------------
|
165 |
+
|
166 |
+
if i == 0:
|
167 |
+
if len_dataset is None:
|
168 |
+
len_data = val_loader.batch_size * len(val_loader) # 1703
|
169 |
+
else:
|
170 |
+
len_data = len_dataset
|
171 |
+
if metrics == 'all' or metrics == 'no_loss':
|
172 |
+
pck = np.zeros((len_data))
|
173 |
+
pck_by_part = {group:np.zeros((len_data)) for group in KEYPOINT_GROUPS}
|
174 |
+
acc_sil_2d = np.zeros(len_data)
|
175 |
+
|
176 |
+
all_betas = np.zeros((len_data, output_reproj['betas'].shape[1]))
|
177 |
+
all_betas_limbs = np.zeros((len_data, output_reproj['betas_limbs'].shape[1]))
|
178 |
+
all_z = np.zeros((len_data, output_reproj['z'].shape[1]))
|
179 |
+
all_pose_rotmat = np.zeros((len_data, output_unnorm['pose_rotmat'].shape[1], 3, 3))
|
180 |
+
all_flength = np.zeros((len_data, output_unnorm['flength'].shape[1]))
|
181 |
+
all_trans = np.zeros((len_data, output_unnorm['trans'].shape[1]))
|
182 |
+
all_breed_indices = np.zeros((len_data))
|
183 |
+
all_image_names = [] # len_data * [None]
|
184 |
+
|
185 |
+
index = i
|
186 |
+
ind_img = 0
|
187 |
+
if save_imgs_path is not None:
|
188 |
+
# render predicted 3d models
|
189 |
+
visualizations = model.render_vis_nograd(vertices=output_reproj['vertices_smal'],
|
190 |
+
focal_lengths=output_unnorm['flength'],
|
191 |
+
color=0) # color=2)
|
192 |
+
for ind_img in range(len(target_dict['index'])):
|
193 |
+
try:
|
194 |
+
if test_name_list is not None:
|
195 |
+
img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_')
|
196 |
+
img_name = img_name.split('.')[0]
|
197 |
+
else:
|
198 |
+
img_name = str(index) + '_' + str(ind_img)
|
199 |
+
# save image with predicted keypoints
|
200 |
+
out_path = save_imgs_path + '/keypoints_pred_' + img_name + '.png'
|
201 |
+
pred_unp = (output['keypoints_norm'][ind_img, :, :] + 1.) / 2 * (data_info.image_size - 1)
|
202 |
+
pred_unp_maxval = output['keypoints_scores'][ind_img, :, :]
|
203 |
+
pred_unp_prep = torch.cat((pred_unp, pred_unp_maxval), 1)
|
204 |
+
inp_img = input[ind_img, :, :, :].detach().clone()
|
205 |
+
save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.1, print_scores=True, ratio_in_out=1.0) # threshold=0.3
|
206 |
+
# save predicted 3d model (front view)
|
207 |
+
pred_tex = visualizations[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
|
208 |
+
pred_tex_max = np.max(pred_tex, axis=2)
|
209 |
+
out_path = save_imgs_path + '/tex_pred_' + img_name + '.png'
|
210 |
+
plt.imsave(out_path, pred_tex)
|
211 |
+
input_image = input[ind_img, :, :, :].detach().clone()
|
212 |
+
for t, m, s in zip(input_image, data_info.rgb_mean, data_info.rgb_stddev): t.add_(m)
|
213 |
+
input_image_np = input_image.detach().cpu().numpy().transpose(1, 2, 0)
|
214 |
+
im_masked = cv2.addWeighted(input_image_np,0.2,pred_tex,0.8,0)
|
215 |
+
im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :]
|
216 |
+
out_path = save_imgs_path + '/comp_pred_' + img_name + '.png'
|
217 |
+
plt.imsave(out_path, im_masked)
|
218 |
+
# save predicted 3d model (side view)
|
219 |
+
vertices_cent = output_reproj['vertices_smal'] - output_reproj['vertices_smal'].mean(dim=1)[:, None, :]
|
220 |
+
roll = np.pi / 2 * torch.ones(1).float().to(device)
|
221 |
+
pitch = np.pi / 2 * torch.ones(1).float().to(device)
|
222 |
+
tensor_0 = torch.zeros(1).float().to(device)
|
223 |
+
tensor_1 = torch.ones(1).float().to(device)
|
224 |
+
RX = torch.stack([torch.stack([tensor_1, tensor_0, tensor_0]), torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3,3)
|
225 |
+
RY = torch.stack([
|
226 |
+
torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]),
|
227 |
+
torch.stack([tensor_0, tensor_1, tensor_0]),
|
228 |
+
torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3,3)
|
229 |
+
vertices_rot = (torch.matmul(RY, vertices_cent.reshape((-1, 3))[:, :, None])).reshape((curr_batch_size, -1, 3))
|
230 |
+
vertices_rot[:, :, 2] = vertices_rot[:, :, 2] + torch.ones_like(vertices_rot[:, :, 2]) * 20 # 18 # *16
|
231 |
+
|
232 |
+
visualizations_rot = model.render_vis_nograd(vertices=vertices_rot,
|
233 |
+
focal_lengths=output_unnorm['flength'],
|
234 |
+
color=0) # 2)
|
235 |
+
pred_tex = visualizations_rot[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
|
236 |
+
pred_tex_max = np.max(pred_tex, axis=2)
|
237 |
+
out_path = save_imgs_path + '/rot_tex_pred_' + img_name + '.png'
|
238 |
+
plt.imsave(out_path, pred_tex)
|
239 |
+
if render_all:
|
240 |
+
# save input image
|
241 |
+
inp_img = input[ind_img, :, :, :].detach().clone()
|
242 |
+
out_path = save_imgs_path + '/image_' + img_name + '.png'
|
243 |
+
save_input_image(inp_img, out_path)
|
244 |
+
# save mesh
|
245 |
+
V_posed = output_reproj['vertices_smal'][ind_img, :, :].detach().cpu().numpy()
|
246 |
+
Faces = model.smal.f
|
247 |
+
mesh_posed = trimesh.Trimesh(vertices=V_posed, faces=Faces, process=False, maintain_order=True)
|
248 |
+
mesh_posed.export(save_imgs_path + '/mesh_posed_' + img_name + '.obj')
|
249 |
+
except:
|
250 |
+
print('dont save an image')
|
251 |
+
|
252 |
+
if metrics == 'all' or metrics == 'no_loss':
|
253 |
+
# prepare a dictionary with all the predicted results
|
254 |
+
preds = {}
|
255 |
+
preds['betas'] = output_reproj['betas'].cpu().detach().numpy()
|
256 |
+
preds['betas_limbs'] = output_reproj['betas_limbs'].cpu().detach().numpy()
|
257 |
+
preds['z'] = output_reproj['z'].cpu().detach().numpy()
|
258 |
+
preds['pose_rotmat'] = output_unnorm['pose_rotmat'].cpu().detach().numpy()
|
259 |
+
preds['flength'] = output_unnorm['flength'].cpu().detach().numpy()
|
260 |
+
preds['trans'] = output_unnorm['trans'].cpu().detach().numpy()
|
261 |
+
preds['breed_index'] = target_dict['breed_index'].cpu().detach().numpy().reshape((-1))
|
262 |
+
img_names = []
|
263 |
+
for ind_img2 in range(0, output_reproj['betas'].shape[0]):
|
264 |
+
if test_name_list is not None:
|
265 |
+
img_name2 = test_name_list[int(target_dict['index'][ind_img2].cpu().detach().numpy())].replace('/', '_')
|
266 |
+
img_name2 = img_name2.split('.')[0]
|
267 |
+
else:
|
268 |
+
img_name2 = str(index) + '_' + str(ind_img2)
|
269 |
+
img_names.append(img_name2)
|
270 |
+
preds['image_names'] = img_names
|
271 |
+
# prepare keypoints for PCK calculation - predicted as well as ground truth
|
272 |
+
pred_keypoints_norm = output['keypoints_norm'] # -1 to 1
|
273 |
+
pred_keypoints_256 = output_reproj['keyp_2d']
|
274 |
+
pred_keypoints = pred_keypoints_256
|
275 |
+
gt_keypoints_256 = target_dict['tpts'][:, :, :2] / 64. * (256. - 1)
|
276 |
+
gt_keypoints_norm = gt_keypoints_256 / 256 / 0.5 - 1
|
277 |
+
gt_keypoints = torch.cat((gt_keypoints_256, target_dict['tpts'][:, :, 2:3]), dim=2) # gt_keypoints_norm
|
278 |
+
# prepare silhouette for IoU calculation - predicted as well as ground truth
|
279 |
+
has_seg = target_dict['has_seg']
|
280 |
+
img_border_mask = target_dict['img_border_mask'][:, 0, :, :]
|
281 |
+
gtseg = target_dict['silh']
|
282 |
+
synth_silhouettes = output_reproj['silh'][:, 0, :, :] # output_reproj['silh']
|
283 |
+
synth_silhouettes[synth_silhouettes>0.5] = 1
|
284 |
+
synth_silhouettes[synth_silhouettes<0.5] = 0
|
285 |
+
# calculate PCK as well as IoU (similar to WLDO)
|
286 |
+
preds['acc_PCK'] = Metrics.PCK(
|
287 |
+
pred_keypoints, gt_keypoints,
|
288 |
+
gtseg, has_seg, idxs=EVAL_KEYPOINTS,
|
289 |
+
thresh_range=[pck_thresh], # [0.15],
|
290 |
+
)
|
291 |
+
preds['acc_IOU'] = Metrics.IOU(
|
292 |
+
synth_silhouettes, gtseg,
|
293 |
+
img_border_mask, mask=has_seg
|
294 |
+
)
|
295 |
+
for group, group_kps in KEYPOINT_GROUPS.items():
|
296 |
+
preds[f'{group}_PCK'] = Metrics.PCK(
|
297 |
+
pred_keypoints, gt_keypoints, gtseg, has_seg,
|
298 |
+
thresh_range=[pck_thresh], # [0.15],
|
299 |
+
idxs=group_kps
|
300 |
+
)
|
301 |
+
# add results for all images in this batch to lists
|
302 |
+
curr_batch_size = pred_keypoints_256.shape[0]
|
303 |
+
if not (preds['acc_PCK'].data.cpu().numpy().shape == (pck[my_step * batch_size:my_step * batch_size + curr_batch_size]).shape):
|
304 |
+
import pdb; pdb.set_trace()
|
305 |
+
pck[my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['acc_PCK'].data.cpu().numpy()
|
306 |
+
acc_sil_2d[my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['acc_IOU'].data.cpu().numpy()
|
307 |
+
for part in pck_by_part:
|
308 |
+
pck_by_part[part][my_step * batch_size:my_step * batch_size + curr_batch_size] = preds[f'{part}_PCK'].data.cpu().numpy()
|
309 |
+
all_betas[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['betas']
|
310 |
+
all_betas_limbs[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['betas_limbs']
|
311 |
+
all_z[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['z']
|
312 |
+
all_pose_rotmat[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['pose_rotmat']
|
313 |
+
all_flength[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['flength']
|
314 |
+
all_trans[my_step * batch_size:my_step * batch_size + curr_batch_size, ...] = preds['trans']
|
315 |
+
all_breed_indices[my_step * batch_size:my_step * batch_size + curr_batch_size] = preds['breed_index']
|
316 |
+
all_image_names.extend(preds['image_names'])
|
317 |
+
# update progress bar
|
318 |
+
if progress is not None:
|
319 |
+
my_string = "PCK: {0:.2f}, IOU: {1:.2f}".format(
|
320 |
+
pck[:(my_step * batch_size + curr_batch_size)].mean(),
|
321 |
+
acc_sil_2d[:(my_step * batch_size + curr_batch_size)].mean())
|
322 |
+
progress.set_postfix_str(my_string)
|
323 |
+
else:
|
324 |
+
# measure accuracy and record loss
|
325 |
+
bs_fake = 1 # batch_size
|
326 |
+
losses.update(loss_dict['loss'], bs_fake)
|
327 |
+
losses_keyp.update(loss_dict['loss_keyp_weighted'], bs_fake)
|
328 |
+
losses_silh.update(loss_dict['loss_silh_weighted'], bs_fake)
|
329 |
+
losses_shape.update(loss_dict['loss_shape_weighted'], bs_fake)
|
330 |
+
losses_pose.update(loss_dict['loss_poseprior_weighted'], bs_fake)
|
331 |
+
losses_class.update(loss_dict['loss_class_weighted'], bs_fake)
|
332 |
+
losses_breed.update(loss_dict['loss_breed_weighted'], bs_fake)
|
333 |
+
losses_partseg.update(loss_dict['loss_partseg_weighted'], bs_fake)
|
334 |
+
acc = - loss_dict['loss_keyp_weighted'] # this will be used to keep track of the 'best model'
|
335 |
+
accuracies.update(acc, bs_fake)
|
336 |
+
# Show losses as part of the progress bar.
|
337 |
+
if progress is not None:
|
338 |
+
my_string = 'Loss: {loss:0.4f}, loss_keyp: {loss_keyp:0.4f}, loss_silh: {loss_silh:0.4f}, loss_partseg: {loss_partseg:0.4f}, loss_shape: {loss_shape:0.4f}, loss_pose: {loss_pose:0.4f}, loss_class: {loss_class:0.4f}, loss_breed: {loss_breed:0.4f}'.format(
|
339 |
+
loss=losses.avg,
|
340 |
+
loss_keyp=losses_keyp.avg,
|
341 |
+
loss_silh=losses_silh.avg,
|
342 |
+
loss_shape=losses_shape.avg,
|
343 |
+
loss_pose=losses_pose.avg,
|
344 |
+
loss_class=losses_class.avg,
|
345 |
+
loss_breed=losses_breed.avg,
|
346 |
+
loss_partseg=losses_partseg.avg
|
347 |
+
)
|
348 |
+
progress.set_postfix_str(my_string)
|
349 |
+
my_step += 1
|
350 |
+
if metrics == 'all':
|
351 |
+
summary = {'pck': pck, 'acc_sil_2d': acc_sil_2d, 'pck_by_part':pck_by_part,
|
352 |
+
'betas': all_betas, 'betas_limbs': all_betas_limbs, 'z': all_z, 'pose_rotmat': all_pose_rotmat,
|
353 |
+
'flenght': all_flength, 'trans': all_trans, 'image_names': all_image_names, 'breed_indices': all_breed_indices}
|
354 |
+
return my_string, summary
|
355 |
+
elif metrics == 'no_loss':
|
356 |
+
return my_string, np.average(np.asarray(acc_sil_2d))
|
357 |
+
else:
|
358 |
+
return my_string, accuracies.avg
|
359 |
+
|
360 |
+
|
361 |
+
# ---------------------------------------------------------------------------------------------------------------------------
|
362 |
+
def do_visual_epoch(val_loader, model, device, data_info, flip=False, quiet=False, acc_joints=None, save_imgs_path=None, weight_dict=None, metrics=None, val_opt='default', test_name_list=None, render_all=False, pck_thresh=0.15, return_results=False):
|
363 |
+
if save_imgs_path is not None:
|
364 |
+
pathlib.Path(save_imgs_path).mkdir(parents=True, exist_ok=True)
|
365 |
+
all_results = []
|
366 |
+
|
367 |
+
# Put the model in evaluation mode.
|
368 |
+
model.eval()
|
369 |
+
|
370 |
+
iterable = enumerate(val_loader)
|
371 |
+
|
372 |
+
# information for normalization
|
373 |
+
norm_dict = {
|
374 |
+
'pose_rot6d_mean': torch.from_numpy(data_info.pose_rot6d_mean).float().to(device),
|
375 |
+
'trans_mean': torch.from_numpy(data_info.trans_mean).float().to(device),
|
376 |
+
'trans_std': torch.from_numpy(data_info.trans_std).float().to(device),
|
377 |
+
'flength_mean': torch.from_numpy(data_info.flength_mean).float().to(device),
|
378 |
+
'flength_std': torch.from_numpy(data_info.flength_std).float().to(device)}
|
379 |
+
|
380 |
+
'''
|
381 |
+
return_mesh_with_gt_groundplane = True
|
382 |
+
if return_mesh_with_gt_groundplane:
|
383 |
+
remeshing_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data_remeshed/uniform_surface_sampling/my_smpl_39dogsnorm_Jr_4_dog_remesh4000_info.pkl'
|
384 |
+
with open(remeshing_path, 'rb') as fp:
|
385 |
+
remeshing_dict = pkl.load(fp)
|
386 |
+
remeshing_relevant_faces = torch.tensor(remeshing_dict['smal_faces'][remeshing_dict['faceid_closest']], dtype=torch.long, device=device)
|
387 |
+
remeshing_relevant_barys = torch.tensor(remeshing_dict['barys_closest'], dtype=torch.float32, device=device)
|
388 |
+
|
389 |
+
# from smal_pytorch.smal_model.smal_torch_new import SMAL
|
390 |
+
print('start: load smal default model (barc), but only for vertices')
|
391 |
+
smal = SMAL()
|
392 |
+
print('end: load smal default model (barc), but only for vertices')
|
393 |
+
smal_template_verts = smal.v_template.detach().cpu().numpy()
|
394 |
+
smal_faces = smal.faces.detach().cpu().numpy()
|
395 |
+
|
396 |
+
file_alignment_errors = open(save_imgs_path + '/a_ref_procrustes_alignmnet_errors.txt', 'a') # append mode
|
397 |
+
file_alignment_errors.write(" ----------- start evaluation ------------- \n ")
|
398 |
+
|
399 |
+
csv_file_alignment_errors = open(save_imgs_path + '/a_ref_procrustes_alignmnet_errors.csv', 'w') # write mode
|
400 |
+
fieldnames = ['name', 'error']
|
401 |
+
writer = csv.DictWriter(csv_file_alignment_errors, fieldnames=fieldnames)
|
402 |
+
writer.writeheader()
|
403 |
+
'''
|
404 |
+
|
405 |
+
my_step = 0
|
406 |
+
for i, (input, target_dict) in iterable:
|
407 |
+
batch_size = input.shape[0]
|
408 |
+
input = input.float().to(device)
|
409 |
+
partial_results = {}
|
410 |
+
|
411 |
+
# ----------------------- do visualization step -----------------------
|
412 |
+
with torch.no_grad():
|
413 |
+
output, output_unnorm, output_reproj = model(input, norm_dict=norm_dict)
|
414 |
+
|
415 |
+
index = i
|
416 |
+
ind_img = 0
|
417 |
+
for ind_img in range(batch_size): # range(min(12, batch_size)): # range(12): # [0]: #range(0, batch_size):
|
418 |
+
|
419 |
+
try:
|
420 |
+
if test_name_list is not None:
|
421 |
+
img_name = test_name_list[int(target_dict['index'][ind_img].cpu().detach().numpy())].replace('/', '_')
|
422 |
+
img_name = img_name.split('.')[0]
|
423 |
+
else:
|
424 |
+
img_name = str(index) + '_' + str(ind_img)
|
425 |
+
partial_results['img_name'] = img_name
|
426 |
+
visualizations = model.render_vis_nograd(vertices=output_reproj['vertices_smal'],
|
427 |
+
focal_lengths=output_unnorm['flength'],
|
428 |
+
color=0) # 2)
|
429 |
+
# save image with predicted keypoints
|
430 |
+
pred_unp = (output['keypoints_norm'][ind_img, :, :] + 1.) / 2 * (data_info.image_size - 1)
|
431 |
+
pred_unp_maxval = output['keypoints_scores'][ind_img, :, :]
|
432 |
+
pred_unp_prep = torch.cat((pred_unp, pred_unp_maxval), 1)
|
433 |
+
inp_img = input[ind_img, :, :, :].detach().clone()
|
434 |
+
if save_imgs_path is not None:
|
435 |
+
out_path = save_imgs_path + '/keypoints_pred_' + img_name + '.png'
|
436 |
+
save_input_image_with_keypoints(inp_img, pred_unp_prep, out_path=out_path, threshold=0.1, print_scores=True, ratio_in_out=1.0) # threshold=0.3
|
437 |
+
# save predicted 3d model
|
438 |
+
# (1) front view
|
439 |
+
pred_tex = visualizations[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
|
440 |
+
pred_tex_max = np.max(pred_tex, axis=2)
|
441 |
+
partial_results['tex_pred'] = pred_tex
|
442 |
+
if save_imgs_path is not None:
|
443 |
+
out_path = save_imgs_path + '/tex_pred_' + img_name + '.png'
|
444 |
+
plt.imsave(out_path, pred_tex)
|
445 |
+
input_image = input[ind_img, :, :, :].detach().clone()
|
446 |
+
for t, m, s in zip(input_image, data_info.rgb_mean, data_info.rgb_stddev): t.add_(m)
|
447 |
+
input_image_np = input_image.detach().cpu().numpy().transpose(1, 2, 0)
|
448 |
+
im_masked = cv2.addWeighted(input_image_np,0.2,pred_tex,0.8,0)
|
449 |
+
im_masked[pred_tex_max<0.01, :] = input_image_np[pred_tex_max<0.01, :]
|
450 |
+
partial_results['comp_pred'] = im_masked
|
451 |
+
if save_imgs_path is not None:
|
452 |
+
out_path = save_imgs_path + '/comp_pred_' + img_name + '.png'
|
453 |
+
plt.imsave(out_path, im_masked)
|
454 |
+
# (2) side view
|
455 |
+
vertices_cent = output_reproj['vertices_smal'] - output_reproj['vertices_smal'].mean(dim=1)[:, None, :]
|
456 |
+
roll = np.pi / 2 * torch.ones(1).float().to(device)
|
457 |
+
pitch = np.pi / 2 * torch.ones(1).float().to(device)
|
458 |
+
tensor_0 = torch.zeros(1).float().to(device)
|
459 |
+
tensor_1 = torch.ones(1).float().to(device)
|
460 |
+
RX = torch.stack([torch.stack([tensor_1, tensor_0, tensor_0]), torch.stack([tensor_0, torch.cos(roll), -torch.sin(roll)]),torch.stack([tensor_0, torch.sin(roll), torch.cos(roll)])]).reshape(3,3)
|
461 |
+
RY = torch.stack([
|
462 |
+
torch.stack([torch.cos(pitch), tensor_0, torch.sin(pitch)]),
|
463 |
+
torch.stack([tensor_0, tensor_1, tensor_0]),
|
464 |
+
torch.stack([-torch.sin(pitch), tensor_0, torch.cos(pitch)])]).reshape(3,3)
|
465 |
+
vertices_rot = (torch.matmul(RY, vertices_cent.reshape((-1, 3))[:, :, None])).reshape((batch_size, -1, 3))
|
466 |
+
vertices_rot[:, :, 2] = vertices_rot[:, :, 2] + torch.ones_like(vertices_rot[:, :, 2]) * 20 # 18 # *16
|
467 |
+
visualizations_rot = model.render_vis_nograd(vertices=vertices_rot,
|
468 |
+
focal_lengths=output_unnorm['flength'],
|
469 |
+
color=0) # 2)
|
470 |
+
pred_tex = visualizations_rot[ind_img, :, :, :].permute((1, 2, 0)).cpu().detach().numpy() / 256
|
471 |
+
pred_tex_max = np.max(pred_tex, axis=2)
|
472 |
+
partial_results['rot_tex_pred'] = pred_tex
|
473 |
+
if save_imgs_path is not None:
|
474 |
+
out_path = save_imgs_path + '/rot_tex_pred_' + img_name + '.png'
|
475 |
+
plt.imsave(out_path, pred_tex)
|
476 |
+
render_all = True
|
477 |
+
if render_all:
|
478 |
+
# save input image
|
479 |
+
inp_img = input[ind_img, :, :, :].detach().clone()
|
480 |
+
if save_imgs_path is not None:
|
481 |
+
out_path = save_imgs_path + '/image_' + img_name + '.png'
|
482 |
+
save_input_image(inp_img, out_path)
|
483 |
+
# save posed mesh
|
484 |
+
V_posed = output_reproj['vertices_smal'][ind_img, :, :].detach().cpu().numpy()
|
485 |
+
Faces = model.smal.f
|
486 |
+
mesh_posed = trimesh.Trimesh(vertices=V_posed, faces=Faces, process=False, maintain_order=True)
|
487 |
+
partial_results['mesh_posed'] = mesh_posed
|
488 |
+
if save_imgs_path is not None:
|
489 |
+
mesh_posed.export(save_imgs_path + '/mesh_posed_' + img_name + '.obj')
|
490 |
+
except:
|
491 |
+
print('pass...')
|
492 |
+
all_results.append(partial_results)
|
493 |
+
if return_results:
|
494 |
+
return all_results
|
495 |
+
else:
|
496 |
+
return
|
src/configs/SMAL_configs.py
ADDED
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
|
7 |
+
|
8 |
+
# SMAL_DATA_DIR = '/is/cluster/work/nrueegg/dog_project/pytorch-dogs-inference/src/smal_pytorch/smpl_models/'
|
9 |
+
# SMAL_DATA_DIR = os.path.join(os.path.dirname(__file__), '..', 'smal_pytorch', 'smal_data')
|
10 |
+
SMAL_DATA_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'smal_data')
|
11 |
+
|
12 |
+
# we replace the old SMAL model by a more dog specific model (see BARC cvpr 2022 paper)
|
13 |
+
# our model has several differences compared to the original SMAL model, some of them are:
|
14 |
+
# - the PCA shape space is recalculated (from partially new data and weighted)
|
15 |
+
# - coefficients for limb length changes are allowed (similar to WLDO, we did borrow some of their code)
|
16 |
+
# - all dogs have a core of approximately the same length
|
17 |
+
# - dogs are centered in their root joint (which is close to the tail base)
|
18 |
+
# -> like this the root rotations is always around this joint AND (0, 0, 0)
|
19 |
+
# -> before this it would happen that the animal 'slips' from the image middle to the side when rotating it. Now
|
20 |
+
# 'trans' also defines the center of the rotation
|
21 |
+
# - we correct the back joint locations such that all those joints are more aligned
|
22 |
+
|
23 |
+
# logscale_part_list = ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l']
|
24 |
+
# logscale_part_list = ['front_legs_l', 'front_legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l', 'back_legs_l', 'back_legs_f']
|
25 |
+
|
26 |
+
SMAL_MODEL_CONFIG = {
|
27 |
+
'barc': {
|
28 |
+
'smal_model_type': 'barc',
|
29 |
+
'smal_model_path': os.path.join(SMAL_DATA_DIR, 'my_smpl_SMBLD_nbj_v3.pkl'),
|
30 |
+
'smal_model_data_path': os.path.join(SMAL_DATA_DIR, 'my_smpl_data_SMBLD_v3.pkl'),
|
31 |
+
'unity_smal_shape_prior_dogs': os.path.join(SMAL_DATA_DIR, 'my_smpl_data_SMBLD_v3.pkl'),
|
32 |
+
'logscale_part_list': ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'],
|
33 |
+
},
|
34 |
+
'39dogs_diffsize': {
|
35 |
+
'smal_model_type': '39dogs_diffsize',
|
36 |
+
'smal_model_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_00791_nadine_Jr_4_dog.pkl'),
|
37 |
+
'smal_model_data_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_00791_nadine_Jr_4_dog.pkl'),
|
38 |
+
'unity_smal_shape_prior_dogs': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_00791_nadine_Jr_4_dog.pkl'),
|
39 |
+
'logscale_part_list': ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'],
|
40 |
+
},
|
41 |
+
'39dogs_norm': {
|
42 |
+
'smal_model_type': '39dogs_norm',
|
43 |
+
'smal_model_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_39dogsnorm_Jr_4_dog.pkl'),
|
44 |
+
'smal_model_data_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_Jr_4_dog.pkl'),
|
45 |
+
'unity_smal_shape_prior_dogs': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_Jr_4_dog.pkl'),
|
46 |
+
'logscale_part_list': ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'],
|
47 |
+
},
|
48 |
+
'39dogs_norm_9ll': { # 9 limb length parameters
|
49 |
+
'smal_model_type': '39dogs_norm_9ll',
|
50 |
+
'smal_model_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_39dogsnorm_Jr_4_dog.pkl'),
|
51 |
+
'smal_model_data_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_Jr_4_dog.pkl'),
|
52 |
+
'unity_smal_shape_prior_dogs': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_Jr_4_dog.pkl'),
|
53 |
+
'logscale_part_list': ['front_legs_l', 'front_legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l', 'back_legs_l', 'back_legs_f'],
|
54 |
+
},
|
55 |
+
'39dogs_norm_newv2': { # front and back legs of equal lengths
|
56 |
+
'smal_model_type': '39dogs_norm_newv2',
|
57 |
+
'smal_model_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_39dogsnorm_newv2_dog.pkl'),
|
58 |
+
'smal_model_data_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_newv2_dog.pkl'),
|
59 |
+
'unity_smal_shape_prior_dogs': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_newv2_dog.pkl'),
|
60 |
+
'logscale_part_list': ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'],
|
61 |
+
},
|
62 |
+
'39dogs_norm_newv3': { # pca on dame AND different front and back legs lengths
|
63 |
+
'smal_model_type': '39dogs_norm_newv3',
|
64 |
+
'smal_model_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_39dogsnorm_newv3_dog.pkl'),
|
65 |
+
'smal_model_data_path': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_newv3_dog.pkl'),
|
66 |
+
'unity_smal_shape_prior_dogs': os.path.join(SMAL_DATA_DIR, 'new_dog_models', 'my_smpl_data_39dogsnorm_newv3_dog.pkl'),
|
67 |
+
'logscale_part_list': ['legs_l', 'legs_f', 'tail_l', 'tail_f', 'ears_y', 'ears_l', 'head_l'],
|
68 |
+
},
|
69 |
+
}
|
70 |
+
|
71 |
+
|
72 |
+
SYMMETRY_INDS_FILE = os.path.join(SMAL_DATA_DIR, 'symmetry_inds.json')
|
73 |
+
|
74 |
+
mean_dog_bone_lengths_txt = os.path.join(SMAL_DATA_DIR, 'mean_dog_bone_lengths.txt')
|
75 |
+
|
76 |
+
# some vertex indices, (from silvia zuffi´s code, create_projected_images_cats.py)
|
77 |
+
KEY_VIDS = np.array(([1068, 1080, 1029, 1226], # left eye
|
78 |
+
[2660, 3030, 2675, 3038], # right eye
|
79 |
+
[910], # mouth low
|
80 |
+
[360, 1203, 1235, 1230], # front left leg, low
|
81 |
+
[3188, 3156, 2327, 3183], # front right leg, low
|
82 |
+
[1976, 1974, 1980, 856], # back left leg, low
|
83 |
+
[3854, 2820, 3852, 3858], # back right leg, low
|
84 |
+
[452, 1811], # tail start
|
85 |
+
[416, 235, 182], # front left leg, top
|
86 |
+
[2156, 2382, 2203], # front right leg, top
|
87 |
+
[829], # back left leg, top
|
88 |
+
[2793], # back right leg, top
|
89 |
+
[60, 114, 186, 59], # throat, close to base of neck
|
90 |
+
[2091, 2037, 2036, 2160], # withers (a bit lower than in reality)
|
91 |
+
[384, 799, 1169, 431], # front left leg, middle
|
92 |
+
[2351, 2763, 2397, 3127], # front right leg, middle
|
93 |
+
[221, 104], # back left leg, middle
|
94 |
+
[2754, 2192], # back right leg, middle
|
95 |
+
[191, 1158, 3116, 2165], # neck
|
96 |
+
[28], # Tail tip
|
97 |
+
[542], # Left Ear
|
98 |
+
[2507], # Right Ear
|
99 |
+
[1039, 1845, 1846, 1870, 1879, 1919, 2997, 3761, 3762], # nose tip
|
100 |
+
[0, 464, 465, 726, 1824, 2429, 2430, 2690]), dtype=object) # half tail
|
101 |
+
|
102 |
+
# the following vertices are used for visibility only: if one of the vertices is visible,
|
103 |
+
# then we assume that the joint is visible! There is some noise, but we don't care, as this is
|
104 |
+
# for generation of the synthetic dataset only
|
105 |
+
KEY_VIDS_VISIBILITY_ONLY = np.array(([1068, 1080, 1029, 1226, 645], # left eye
|
106 |
+
[2660, 3030, 2675, 3038, 2567], # right eye
|
107 |
+
[910, 11, 5], # mouth low
|
108 |
+
[360, 1203, 1235, 1230, 298, 408, 303, 293, 384], # front left leg, low
|
109 |
+
[3188, 3156, 2327, 3183, 2261, 2271, 2573, 2265], # front right leg, low
|
110 |
+
[1976, 1974, 1980, 856, 559, 851, 556], # back left leg, low
|
111 |
+
[3854, 2820, 3852, 3858, 2524, 2522, 2815, 2072], # back right leg, low
|
112 |
+
[452, 1811, 63, 194, 52, 370, 64], # tail start
|
113 |
+
[416, 235, 182, 440, 8, 80, 73, 112], # front left leg, top
|
114 |
+
[2156, 2382, 2203, 2050, 2052, 2406, 3], # front right leg, top
|
115 |
+
[829, 219, 218, 173, 17, 7, 279], # back left leg, top
|
116 |
+
[2793, 582, 140, 87, 2188, 2147, 2063], # back right leg, top
|
117 |
+
[60, 114, 186, 59, 878, 130, 189, 45], # throat, close to base of neck
|
118 |
+
[2091, 2037, 2036, 2160, 190, 2164], # withers (a bit lower than in reality)
|
119 |
+
[384, 799, 1169, 431, 321, 314, 437, 310, 323], # front left leg, middle
|
120 |
+
[2351, 2763, 2397, 3127, 2278, 2285, 2282, 2275, 2359], # front right leg, middle
|
121 |
+
[221, 104, 105, 97, 103], # back left leg, middle
|
122 |
+
[2754, 2192, 2080, 2251, 2075, 2074], # back right leg, middle
|
123 |
+
[191, 1158, 3116, 2165, 154, 653, 133, 339], # neck
|
124 |
+
[28, 474, 475, 731, 24], # Tail tip
|
125 |
+
[542, 147, 509, 200, 522], # Left Ear
|
126 |
+
[2507,2174, 2122, 2126, 2474], # Right Ear
|
127 |
+
[1039, 1845, 1846, 1870, 1879, 1919, 2997, 3761, 3762], # nose tip
|
128 |
+
[0, 464, 465, 726, 1824, 2429, 2430, 2690]), dtype=object) # half tail
|
129 |
+
|
130 |
+
# Keypoint indices for 3d sketchfab evaluation
|
131 |
+
SMAL_KEYPOINT_NAMES_FOR_3D_EVAL = ['right_front_paw','right_front_elbow','right_back_paw','right_back_hock','right_ear_top','right_ear_bottom','right_eye', \
|
132 |
+
'left_front_paw','left_front_elbow','left_back_paw','left_back_hock','left_ear_top','left_ear_bottom','left_eye', \
|
133 |
+
'nose','tail_start','tail_end']
|
134 |
+
SMAL_KEYPOINT_INDICES_FOR_3D_EVAL = [2577, 2361, 2820, 2085, 2125, 2453, 2668, 613, 394, 855, 786, 149, 486, 1079, 1845, 1820, 28]
|
135 |
+
SMAL_KEYPOINT_WHICHTOUSE_FOR_3D_EVAL = [1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0] # [1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0]
|
136 |
+
|
137 |
+
|
138 |
+
|
139 |
+
|
140 |
+
# see: https://github.com/benjiebob/SMALify/blob/master/config.py
|
141 |
+
# JOINT DEFINITIONS - based on SMAL joints and additional {eyes, ear tips, chin and nose}
|
142 |
+
TORSO_JOINTS = [2, 5, 8, 11, 12, 23]
|
143 |
+
CANONICAL_MODEL_JOINTS = [
|
144 |
+
10, 9, 8, # upper_left [paw, middle, top]
|
145 |
+
20, 19, 18, # lower_left [paw, middle, top]
|
146 |
+
14, 13, 12, # upper_right [paw, middle, top]
|
147 |
+
24, 23, 22, # lower_right [paw, middle, top]
|
148 |
+
25, 31, # tail [start, end]
|
149 |
+
33, 34, # ear base [left, right]
|
150 |
+
35, 36, # nose, chin
|
151 |
+
38, 37, # ear tip [left, right]
|
152 |
+
39, 40, # eyes [left, right]
|
153 |
+
6, 11, # withers, throat (throat is inaccurate and withers also)
|
154 |
+
28] # tail middle
|
155 |
+
# old: 15, 15, # withers, throat (TODO: Labelled same as throat for now), throat
|
156 |
+
|
157 |
+
CANONICAL_MODEL_JOINTS_REFINED = [
|
158 |
+
41, 9, 8, # upper_left [paw, middle, top]
|
159 |
+
43, 19, 18, # lower_left [paw, middle, top]
|
160 |
+
42, 13, 12, # upper_right [paw, middle, top]
|
161 |
+
44, 23, 22, # lower_right [paw, middle, top]
|
162 |
+
25, 31, # tail [start, end]
|
163 |
+
33, 34, # ear base [left, right]
|
164 |
+
35, 36, # nose, chin
|
165 |
+
38, 37, # ear tip [left, right]
|
166 |
+
39, 40, # eyes [left, right]
|
167 |
+
46, 45, # withers, throat
|
168 |
+
28] # tail middle
|
169 |
+
|
170 |
+
# the following list gives the indices of the KEY_VIDS_JOINTS that must be taken in order
|
171 |
+
# to judge if the CANONICAL_MODEL_JOINTS are visible - those are all approximations!
|
172 |
+
CMJ_VISIBILITY_IN_KEY_VIDS = [
|
173 |
+
3, 14, 8, # left front leg
|
174 |
+
5, 16, 10, # left rear leg
|
175 |
+
4, 15, 9, # right front leg
|
176 |
+
6, 17, 11, # right rear leg
|
177 |
+
7, 19, # tail front, tail back
|
178 |
+
20, 21, # ear base (but can not be found in blue, se we take the tip)
|
179 |
+
2, 2, # mouth (was: 22, 2)
|
180 |
+
20, 21, # ear tips
|
181 |
+
1, 0, # eyes
|
182 |
+
18, # withers, not sure where this point is
|
183 |
+
12, # throat
|
184 |
+
23, # mid tail
|
185 |
+
]
|
186 |
+
|
187 |
+
# define which bone lengths are used as input to the 2d-to-3d network
|
188 |
+
IDXS_BONES_NO_REDUNDANCY = [6,7,8,9,16,17,18,19,32,1,2,3,4,5,14,15,24,25,26,27,28,29,30,31]
|
189 |
+
# load bone lengths of the mean dog (already filtered)
|
190 |
+
mean_dog_bone_lengths = []
|
191 |
+
with open(mean_dog_bone_lengths_txt, 'r') as f:
|
192 |
+
for line in f:
|
193 |
+
mean_dog_bone_lengths.append(float(line.split('\n')[0]))
|
194 |
+
MEAN_DOG_BONE_LENGTHS_NO_RED = np.asarray(mean_dog_bone_lengths)[IDXS_BONES_NO_REDUNDANCY] # (24, )
|
195 |
+
|
196 |
+
# Body part segmentation:
|
197 |
+
# the body can be segmented based on the bones and for the new dog model also based on the new shapedirs
|
198 |
+
# axis_horizontal = self.shapedirs[2, :].reshape((-1, 3))[:, 0]
|
199 |
+
# all_indices = np.arange(3889)
|
200 |
+
# tail_indices = all_indices[axis_horizontal.detach().cpu().numpy() < 0.0]
|
201 |
+
VERTEX_IDS_TAIL = [ 0, 4, 9, 10, 24, 25, 28, 453, 454, 456, 457,
|
202 |
+
458, 459, 460, 461, 462, 463, 464, 465, 466, 467, 468,
|
203 |
+
469, 470, 471, 472, 473, 474, 475, 724, 725, 726, 727,
|
204 |
+
728, 729, 730, 731, 813, 975, 976, 977, 1109, 1110, 1111,
|
205 |
+
1811, 1813, 1819, 1820, 1821, 1822, 1823, 1824, 1825, 1826, 1827,
|
206 |
+
1828, 1835, 1836, 1960, 1961, 1962, 1963, 1964, 1965, 1966, 1967,
|
207 |
+
1968, 1969, 2418, 2419, 2421, 2422, 2423, 2424, 2425, 2426, 2427,
|
208 |
+
2428, 2429, 2430, 2431, 2432, 2433, 2434, 2435, 2436, 2437, 2438,
|
209 |
+
2439, 2440, 2688, 2689, 2690, 2691, 2692, 2693, 2694, 2695, 2777,
|
210 |
+
3067, 3068, 3069, 3842, 3843, 3844, 3845, 3846, 3847]
|
211 |
+
|
212 |
+
# same as in https://github.com/benjiebob/WLDO/blob/master/global_utils/config.py
|
213 |
+
EVAL_KEYPOINTS = [
|
214 |
+
0, 1, 2, # left front
|
215 |
+
3, 4, 5, # left rear
|
216 |
+
6, 7, 8, # right front
|
217 |
+
9, 10, 11, # right rear
|
218 |
+
12, 13, # tail start -> end
|
219 |
+
14, 15, # left ear, right ear
|
220 |
+
16, 17, # nose, chin
|
221 |
+
18, 19] # left tip, right tip
|
222 |
+
|
223 |
+
KEYPOINT_GROUPS = {
|
224 |
+
'legs': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], # legs
|
225 |
+
'tail': [12, 13], # tail
|
226 |
+
'ears': [14, 15, 18, 19], # ears
|
227 |
+
'face': [16, 17] # face
|
228 |
+
}
|
229 |
+
|
230 |
+
|
src/configs/anipose_data_info.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import List
|
3 |
+
import json
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
|
7 |
+
STATISTICS_DATA_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'statistics')
|
8 |
+
STATISTICS_PATH = os.path.join(STATISTICS_DATA_DIR, 'statistics_modified_v1.json')
|
9 |
+
|
10 |
+
@dataclass
|
11 |
+
class DataInfo:
|
12 |
+
rgb_mean: List[float]
|
13 |
+
rgb_stddev: List[float]
|
14 |
+
joint_names: List[str]
|
15 |
+
hflip_indices: List[int]
|
16 |
+
n_joints: int
|
17 |
+
n_keyp: int
|
18 |
+
n_bones: int
|
19 |
+
n_betas: int
|
20 |
+
image_size: int
|
21 |
+
trans_mean: np.ndarray
|
22 |
+
trans_std: np.ndarray
|
23 |
+
flength_mean: np.ndarray
|
24 |
+
flength_std: np.ndarray
|
25 |
+
pose_rot6d_mean: np.ndarray
|
26 |
+
keypoint_weights: List[float]
|
27 |
+
|
28 |
+
# SMAL samples 3d statistics
|
29 |
+
# statistics like mean values were calculated once when the project was started and they were not changed afterwards anymore
|
30 |
+
def load_statistics(statistics_path):
|
31 |
+
with open(statistics_path) as f:
|
32 |
+
statistics = json.load(f)
|
33 |
+
'''new_pose_mean = [[[np.round(val, 2) for val in sublst] for sublst in sublst_big] for sublst_big in statistics['pose_mean']]
|
34 |
+
statistics['pose_mean'] = new_pose_mean
|
35 |
+
j_out = json.dumps(statistics, indent=4) #, sort_keys=True)
|
36 |
+
with open(self.statistics_path, 'w') as file: file.write(j_out)'''
|
37 |
+
new_statistics = {'trans_mean': np.asarray(statistics['trans_mean']),
|
38 |
+
'trans_std': np.asarray(statistics['trans_std']),
|
39 |
+
'flength_mean': np.asarray(statistics['flength_mean']),
|
40 |
+
'flength_std': np.asarray(statistics['flength_std']),
|
41 |
+
'pose_mean': np.asarray(statistics['pose_mean']),
|
42 |
+
}
|
43 |
+
new_statistics['pose_rot6d_mean'] = new_statistics['pose_mean'][:, :, :2].reshape((-1, 6))
|
44 |
+
return new_statistics
|
45 |
+
STATISTICS = load_statistics(STATISTICS_PATH)
|
46 |
+
|
47 |
+
AniPose_JOINT_NAMES_swapped = [
|
48 |
+
'L_F_Paw', 'L_F_Knee', 'L_F_Elbow',
|
49 |
+
'L_B_Paw', 'L_B_Knee', 'L_B_Elbow',
|
50 |
+
'R_F_Paw', 'R_F_Knee', 'R_F_Elbow',
|
51 |
+
'R_B_Paw', 'R_B_Knee', 'R_B_Elbow',
|
52 |
+
'TailBase', '_Tail_end_', 'L_EarBase', 'R_EarBase',
|
53 |
+
'Nose', '_Chin_', '_Left_ear_tip_', '_Right_ear_tip_',
|
54 |
+
'L_Eye', 'R_Eye', 'Withers', 'Throat']
|
55 |
+
|
56 |
+
KEYPOINT_WEIGHTS = [3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 3, 2, 2, 3, 1, 2, 2]
|
57 |
+
|
58 |
+
COMPLETE_DATA_INFO = DataInfo(
|
59 |
+
rgb_mean=[0.4404, 0.4440, 0.4327], # not sure
|
60 |
+
rgb_stddev=[0.2458, 0.2410, 0.2468], # not sure
|
61 |
+
joint_names=AniPose_JOINT_NAMES_swapped, # AniPose_JOINT_NAMES,
|
62 |
+
hflip_indices=[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 12, 13, 15, 14, 16, 17, 19, 18, 21, 20, 22, 23],
|
63 |
+
n_joints = 35,
|
64 |
+
n_keyp = 24, # 20, # 25,
|
65 |
+
n_bones = 24,
|
66 |
+
n_betas = 30, # 10,
|
67 |
+
image_size = 256,
|
68 |
+
trans_mean = STATISTICS['trans_mean'],
|
69 |
+
trans_std = STATISTICS['trans_std'],
|
70 |
+
flength_mean = STATISTICS['flength_mean'],
|
71 |
+
flength_std = STATISTICS['flength_std'],
|
72 |
+
pose_rot6d_mean = STATISTICS['pose_rot6d_mean'],
|
73 |
+
keypoint_weights = KEYPOINT_WEIGHTS
|
74 |
+
)
|
src/configs/barc_cfg_defaults.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from yacs.config import CfgNode as CN
|
3 |
+
import argparse
|
4 |
+
import yaml
|
5 |
+
import os
|
6 |
+
|
7 |
+
abs_barc_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..',))
|
8 |
+
|
9 |
+
_C = CN()
|
10 |
+
_C.barc_dir = abs_barc_dir
|
11 |
+
_C.device = 'cuda'
|
12 |
+
|
13 |
+
## path settings
|
14 |
+
_C.paths = CN()
|
15 |
+
_C.paths.ROOT_OUT_PATH = abs_barc_dir + '/results/'
|
16 |
+
_C.paths.ROOT_CHECKPOINT_PATH = abs_barc_dir + '/checkpoint/'
|
17 |
+
_C.paths.MODELPATH_NORMFLOW = abs_barc_dir + '/checkpoint/barc_normflow_pret/rgbddog_v3_model.pt'
|
18 |
+
|
19 |
+
## parameter settings
|
20 |
+
_C.params = CN()
|
21 |
+
_C.params.ARCH = 'hg8'
|
22 |
+
_C.params.STRUCTURE_POSE_NET = 'normflow' # 'default' # 'vae'
|
23 |
+
_C.params.NF_VERSION = 3
|
24 |
+
_C.params.N_JOINTS = 35
|
25 |
+
_C.params.N_KEYP = 24 #20
|
26 |
+
_C.params.N_SEG = 2
|
27 |
+
_C.params.N_PARTSEG = 15
|
28 |
+
_C.params.UPSAMPLE_SEG = True
|
29 |
+
_C.params.ADD_PARTSEG = True # partseg: for the CVPR paper this part of the network exists, but is not trained (no part labels in StanExt)
|
30 |
+
_C.params.N_BETAS = 30 # 10
|
31 |
+
_C.params.N_BETAS_LIMBS = 7
|
32 |
+
_C.params.N_BONES = 24
|
33 |
+
_C.params.N_BREEDS = 121 # 120 breeds plus background
|
34 |
+
_C.params.IMG_SIZE = 256
|
35 |
+
_C.params.SILH_NO_TAIL = False
|
36 |
+
_C.params.KP_THRESHOLD = None
|
37 |
+
_C.params.ADD_Z_TO_3D_INPUT = False
|
38 |
+
_C.params.N_SEGBPS = 64*2
|
39 |
+
_C.params.ADD_SEGBPS_TO_3D_INPUT = True
|
40 |
+
_C.params.FIX_FLENGTH = False
|
41 |
+
_C.params.RENDER_ALL = True
|
42 |
+
_C.params.VLIN = 2
|
43 |
+
_C.params.STRUCTURE_Z_TO_B = 'lin'
|
44 |
+
_C.params.N_Z_FREE = 64
|
45 |
+
_C.params.PCK_THRESH = 0.15
|
46 |
+
_C.params.REF_NET_TYPE = 'add' # refinement network type
|
47 |
+
_C.params.REF_DETACH_SHAPE = True
|
48 |
+
_C.params.GRAPHCNN_TYPE = 'inexistent'
|
49 |
+
_C.params.ISFLAT_TYPE = 'inexistent'
|
50 |
+
_C.params.SHAPEREF_TYPE = 'inexistent'
|
51 |
+
|
52 |
+
## SMAL settings
|
53 |
+
_C.smal = CN()
|
54 |
+
_C.smal.SMAL_MODEL_TYPE = 'barc'
|
55 |
+
_C.smal.SMAL_KEYP_CONF = 'green'
|
56 |
+
|
57 |
+
## optimization settings
|
58 |
+
_C.optim = CN()
|
59 |
+
_C.optim.LR = 5e-4
|
60 |
+
_C.optim.SCHEDULE = [150, 175, 200]
|
61 |
+
_C.optim.GAMMA = 0.1
|
62 |
+
_C.optim.MOMENTUM = 0
|
63 |
+
_C.optim.WEIGHT_DECAY = 0
|
64 |
+
_C.optim.EPOCHS = 220
|
65 |
+
_C.optim.BATCH_SIZE = 12 # keep 12 (needs to be an even number, as we have a custom data sampler)
|
66 |
+
_C.optim.TRAIN_PARTS = 'all_without_shapedirs'
|
67 |
+
|
68 |
+
## dataset settings
|
69 |
+
_C.data = CN()
|
70 |
+
_C.data.DATASET = 'stanext24'
|
71 |
+
_C.data.V12 = True
|
72 |
+
_C.data.SHORTEN_VAL_DATASET_TO = None
|
73 |
+
_C.data.VAL_OPT = 'val'
|
74 |
+
_C.data.VAL_METRICS = 'no_loss'
|
75 |
+
|
76 |
+
# ---------------------------------------
|
77 |
+
def update_dependent_vars(cfg):
|
78 |
+
cfg.params.N_CLASSES = cfg.params.N_KEYP + cfg.params.N_SEG
|
79 |
+
if cfg.params.VLIN == 0:
|
80 |
+
cfg.params.NUM_STAGE_COMB = 2
|
81 |
+
cfg.params.NUM_STAGE_HEADS = 1
|
82 |
+
cfg.params.NUM_STAGE_HEADS_POSE = 1
|
83 |
+
cfg.params.TRANS_SEP = False
|
84 |
+
elif cfg.params.VLIN == 1:
|
85 |
+
cfg.params.NUM_STAGE_COMB = 3
|
86 |
+
cfg.params.NUM_STAGE_HEADS = 1
|
87 |
+
cfg.params.NUM_STAGE_HEADS_POSE = 2
|
88 |
+
cfg.params.TRANS_SEP = False
|
89 |
+
elif cfg.params.VLIN == 2:
|
90 |
+
cfg.params.NUM_STAGE_COMB = 3
|
91 |
+
cfg.params.NUM_STAGE_HEADS = 1
|
92 |
+
cfg.params.NUM_STAGE_HEADS_POSE = 2
|
93 |
+
cfg.params.TRANS_SEP = True
|
94 |
+
else:
|
95 |
+
raise NotImplementedError
|
96 |
+
if cfg.params.STRUCTURE_Z_TO_B == '1dconv':
|
97 |
+
cfg.params.N_Z = cfg.params.N_BETAS + cfg.params.N_BETAS_LIMBS
|
98 |
+
else:
|
99 |
+
cfg.params.N_Z = cfg.params.N_Z_FREE
|
100 |
+
return
|
101 |
+
|
102 |
+
|
103 |
+
update_dependent_vars(_C)
|
104 |
+
global _cfg_global
|
105 |
+
_cfg_global = _C.clone()
|
106 |
+
|
107 |
+
|
108 |
+
def get_cfg_defaults():
|
109 |
+
# Get a yacs CfgNode object with default values as defined within this file.
|
110 |
+
# Return a clone so that the defaults will not be altered.
|
111 |
+
return _C.clone()
|
112 |
+
|
113 |
+
def update_cfg_global_with_yaml(cfg_yaml_file):
|
114 |
+
_cfg_global.merge_from_file(cfg_yaml_file)
|
115 |
+
update_dependent_vars(_cfg_global)
|
116 |
+
return
|
117 |
+
|
118 |
+
def get_cfg_global_updated():
|
119 |
+
# return _cfg_global.clone()
|
120 |
+
return _cfg_global
|
121 |
+
|
src/configs/barc_cfg_train.yaml
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
paths:
|
3 |
+
ROOT_OUT_PATH: './results/'
|
4 |
+
ROOT_CHECKPOINT_PATH: './checkpoint/'
|
5 |
+
MODELPATH_NORMFLOW: './checkpoint/barc_normflow_pret/rgbddog_v3_model.pt'
|
6 |
+
|
7 |
+
smal:
|
8 |
+
SMAL_MODEL_TYPE: '39dogs_norm_newv3' # '39dogs_diffsize' # 'barc'
|
9 |
+
SMAL_KEYP_CONF: 'olive' # 'green'
|
10 |
+
|
11 |
+
optim:
|
12 |
+
LR: 5e-4
|
13 |
+
SCHEDULE: [150, 175, 200]
|
14 |
+
GAMMA: 0.1
|
15 |
+
MOMENTUM: 0
|
16 |
+
WEIGHT_DECAY: 0
|
17 |
+
EPOCHS: 220
|
18 |
+
BATCH_SIZE: 12 # keep 12 (needs to be an even number, as we have a custom data sampler)
|
19 |
+
TRAIN_PARTS: 'all_without_shapedirs'
|
20 |
+
|
21 |
+
data:
|
22 |
+
DATASET: 'stanext24'
|
23 |
+
SHORTEN_VAL_DATASET_TO: 600 # this is faster as we do not evaluate on the whole validation set
|
24 |
+
VAL_OPT: 'val'
|
src/configs/barc_loss_weights_allzeros.json
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
|
4 |
+
{
|
5 |
+
"breed_options": [
|
6 |
+
"4"
|
7 |
+
],
|
8 |
+
"breed": 0.0,
|
9 |
+
"class": 0.0,
|
10 |
+
"models3d": 0.0,
|
11 |
+
"keyp": 0.0,
|
12 |
+
"silh": 0.0,
|
13 |
+
"shape_options": [
|
14 |
+
"smal",
|
15 |
+
"limbs7"
|
16 |
+
],
|
17 |
+
"shape": [
|
18 |
+
0,
|
19 |
+
0
|
20 |
+
],
|
21 |
+
"poseprior_options": [
|
22 |
+
"normalizing_flow_tiger_logprob"
|
23 |
+
],
|
24 |
+
"poseprior": 0.0,
|
25 |
+
"poselegssidemovement": 0.0,
|
26 |
+
"flength": 0.0,
|
27 |
+
"partseg": 0,
|
28 |
+
"shapedirs": 0,
|
29 |
+
"pose_0": 0.0
|
30 |
+
}
|
src/configs/barc_loss_weights_with3dcgloss_higherbetaloss_v2_dm39dnnv3v2.json
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
|
4 |
+
{
|
5 |
+
"breed_options": [
|
6 |
+
"4"
|
7 |
+
],
|
8 |
+
"breed": 5.0,
|
9 |
+
"class": 5.0,
|
10 |
+
"models3d": 0.1,
|
11 |
+
"keyp": 0.2,
|
12 |
+
"silh": 50.0,
|
13 |
+
"shape_options": [
|
14 |
+
"smal",
|
15 |
+
"limbs7"
|
16 |
+
],
|
17 |
+
"shape": [
|
18 |
+
0.1,
|
19 |
+
1.0
|
20 |
+
],
|
21 |
+
"poseprior_options": [
|
22 |
+
"normalizing_flow_tiger_logprob"
|
23 |
+
],
|
24 |
+
"poseprior": 0.1,
|
25 |
+
"poselegssidemovement": 10.0,
|
26 |
+
"flength": 1.0,
|
27 |
+
"partseg": 0,
|
28 |
+
"shapedirs": 0,
|
29 |
+
"pose_0": 0.0
|
30 |
+
}
|
src/configs/data_info.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import List
|
3 |
+
import json
|
4 |
+
import numpy as np
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
|
8 |
+
STATISTICS_DATA_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'statistics')
|
9 |
+
STATISTICS_PATH = os.path.join(STATISTICS_DATA_DIR, 'statistics_modified_v1.json')
|
10 |
+
|
11 |
+
@dataclass
|
12 |
+
class DataInfo:
|
13 |
+
rgb_mean: List[float]
|
14 |
+
rgb_stddev: List[float]
|
15 |
+
joint_names: List[str]
|
16 |
+
hflip_indices: List[int]
|
17 |
+
n_joints: int
|
18 |
+
n_keyp: int
|
19 |
+
n_bones: int
|
20 |
+
n_betas: int
|
21 |
+
image_size: int
|
22 |
+
trans_mean: np.ndarray
|
23 |
+
trans_std: np.ndarray
|
24 |
+
flength_mean: np.ndarray
|
25 |
+
flength_std: np.ndarray
|
26 |
+
pose_rot6d_mean: np.ndarray
|
27 |
+
keypoint_weights: List[float]
|
28 |
+
|
29 |
+
# SMAL samples 3d statistics
|
30 |
+
# statistics like mean values were calculated once when the project was started and they were not changed afterwards anymore
|
31 |
+
def load_statistics(statistics_path):
|
32 |
+
with open(statistics_path) as f:
|
33 |
+
statistics = json.load(f)
|
34 |
+
'''new_pose_mean = [[[np.round(val, 2) for val in sublst] for sublst in sublst_big] for sublst_big in statistics['pose_mean']]
|
35 |
+
statistics['pose_mean'] = new_pose_mean
|
36 |
+
j_out = json.dumps(statistics, indent=4) #, sort_keys=True)
|
37 |
+
with open(self.statistics_path, 'w') as file: file.write(j_out)'''
|
38 |
+
new_statistics = {'trans_mean': np.asarray(statistics['trans_mean']),
|
39 |
+
'trans_std': np.asarray(statistics['trans_std']),
|
40 |
+
'flength_mean': np.asarray(statistics['flength_mean']),
|
41 |
+
'flength_std': np.asarray(statistics['flength_std']),
|
42 |
+
'pose_mean': np.asarray(statistics['pose_mean']),
|
43 |
+
}
|
44 |
+
new_statistics['pose_rot6d_mean'] = new_statistics['pose_mean'][:, :, :2].reshape((-1, 6))
|
45 |
+
return new_statistics
|
46 |
+
STATISTICS = load_statistics(STATISTICS_PATH)
|
47 |
+
|
48 |
+
|
49 |
+
############################################################################
|
50 |
+
# for StanExt (original number of keypoints, 20 not 24)
|
51 |
+
|
52 |
+
# for keypoint names see: https://github.com/benjiebob/StanfordExtra/blob/master/keypoint_definitions.csv
|
53 |
+
StanExt_JOINT_NAMES = [
|
54 |
+
'Left_front_leg_paw', 'Left_front_leg_middle_joint', 'Left_front_leg_top',
|
55 |
+
'Left_rear_leg_paw', 'Left_rear_leg_middle_joint', 'Left_rear_leg_top',
|
56 |
+
'Right_front_leg_paw', 'Right_front_leg_middle_joint', 'Right_front_leg_top',
|
57 |
+
'Right_rear_leg_paw', 'Right_rear_leg_middle_joint', 'Right_rear_leg_top',
|
58 |
+
'Tail_start', 'Tail_end', 'Base_of_left_ear', 'Base_of_right_ear',
|
59 |
+
'Nose', 'Chin', 'Left_ear_tip', 'Right_ear_tip']
|
60 |
+
|
61 |
+
KEYPOINT_WEIGHTS = [3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 3, 2, 2, 3, 1, 2, 2]
|
62 |
+
|
63 |
+
COMPLETE_DATA_INFO = DataInfo(
|
64 |
+
rgb_mean=[0.4404, 0.4440, 0.4327], # not sure
|
65 |
+
rgb_stddev=[0.2458, 0.2410, 0.2468], # not sure
|
66 |
+
joint_names=StanExt_JOINT_NAMES,
|
67 |
+
hflip_indices=[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 12, 13, 15, 14, 16, 17, 19, 18],
|
68 |
+
n_joints = 35,
|
69 |
+
n_keyp = 20, # 25,
|
70 |
+
n_bones = 24,
|
71 |
+
n_betas = 30, # 10,
|
72 |
+
image_size = 256,
|
73 |
+
trans_mean = STATISTICS['trans_mean'],
|
74 |
+
trans_std = STATISTICS['trans_std'],
|
75 |
+
flength_mean = STATISTICS['flength_mean'],
|
76 |
+
flength_std = STATISTICS['flength_std'],
|
77 |
+
pose_rot6d_mean = STATISTICS['pose_rot6d_mean'],
|
78 |
+
keypoint_weights = KEYPOINT_WEIGHTS
|
79 |
+
)
|
80 |
+
|
81 |
+
|
82 |
+
############################################################################
|
83 |
+
# new for StanExt24
|
84 |
+
|
85 |
+
# ..., 'Left_eye', 'Right_eye', 'Withers', 'Throat'] # the last 4 keypoints are in the animal_pose dataset, but not StanfordExtra
|
86 |
+
StanExt_JOINT_NAMES_24 = [
|
87 |
+
'Left_front_leg_paw', 'Left_front_leg_middle_joint', 'Left_front_leg_top',
|
88 |
+
'Left_rear_leg_paw', 'Left_rear_leg_middle_joint', 'Left_rear_leg_top',
|
89 |
+
'Right_front_leg_paw', 'Right_front_leg_middle_joint', 'Right_front_leg_top',
|
90 |
+
'Right_rear_leg_paw', 'Right_rear_leg_middle_joint', 'Right_rear_leg_top',
|
91 |
+
'Tail_start', 'Tail_end', 'Base_of_left_ear', 'Base_of_right_ear',
|
92 |
+
'Nose', 'Chin', 'Left_ear_tip', 'Right_ear_tip',
|
93 |
+
'Left_eye', 'Right_eye', 'Withers', 'Throat']
|
94 |
+
|
95 |
+
KEYPOINT_WEIGHTS_24 = [3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 2, 2, 3, 3, 2, 2, 3, 1, 2, 2, 1, 1, 0, 0]
|
96 |
+
|
97 |
+
COMPLETE_DATA_INFO_24 = DataInfo(
|
98 |
+
rgb_mean=[0.4404, 0.4440, 0.4327], # not sure
|
99 |
+
rgb_stddev=[0.2458, 0.2410, 0.2468], # not sure
|
100 |
+
joint_names=StanExt_JOINT_NAMES_24,
|
101 |
+
hflip_indices=[6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 12, 13, 15, 14, 16, 17, 19, 18, 21, 20, 22, 23],
|
102 |
+
n_joints = 35,
|
103 |
+
n_keyp = 24, # 20, # 25,
|
104 |
+
n_bones = 24,
|
105 |
+
n_betas = 30, # 10,
|
106 |
+
image_size = 256,
|
107 |
+
trans_mean = STATISTICS['trans_mean'],
|
108 |
+
trans_std = STATISTICS['trans_std'],
|
109 |
+
flength_mean = STATISTICS['flength_mean'],
|
110 |
+
flength_std = STATISTICS['flength_std'],
|
111 |
+
pose_rot6d_mean = STATISTICS['pose_rot6d_mean'],
|
112 |
+
keypoint_weights = KEYPOINT_WEIGHTS_24
|
113 |
+
)
|
114 |
+
|
115 |
+
|
src/configs/dataset_path_configs.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
|
7 |
+
abs_barc_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..',))
|
8 |
+
|
9 |
+
# stanext dataset
|
10 |
+
# (1) path to stanext dataset
|
11 |
+
STAN_V12_ROOT_DIR = '/ps/scratch/nrueegg/new_projects/Animals/data/dog_datasets/Stanford_Dogs_Dataset' + '/StanfordExtra_V12/'
|
12 |
+
IMG_V12_DIR = os.path.join(STAN_V12_ROOT_DIR, 'StanExtV12_Images')
|
13 |
+
JSON_V12_DIR = os.path.join(STAN_V12_ROOT_DIR, 'labels', "StanfordExtra_v12.json")
|
14 |
+
STAN_V12_TRAIN_LIST_DIR = os.path.join(STAN_V12_ROOT_DIR, 'labels', 'train_stanford_StanfordExtra_v12.npy')
|
15 |
+
STAN_V12_VAL_LIST_DIR = os.path.join(STAN_V12_ROOT_DIR, 'labels', 'val_stanford_StanfordExtra_v12.npy')
|
16 |
+
STAN_V12_TEST_LIST_DIR = os.path.join(STAN_V12_ROOT_DIR, 'labels', 'test_stanford_StanfordExtra_v12.npy')
|
17 |
+
# (2) path to related data such as breed indices and prepared predictions for withers, throat and eye keypoints
|
18 |
+
STANEXT_RELATED_DATA_ROOT_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'stanext_related_data')
|
19 |
+
|
20 |
+
# test image crop dataset
|
21 |
+
TEST_IMAGE_CROP_ROOT_DIR = os.path.join(os.path.dirname(__file__), '..', '..', 'datasets', 'test_image_crops')
|
src/configs/dog_breeds/dog_breed_class.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import os
|
3 |
+
import warnings
|
4 |
+
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
5 |
+
import pandas as pd
|
6 |
+
import difflib
|
7 |
+
import json
|
8 |
+
import pickle as pkl
|
9 |
+
import csv
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
|
13 |
+
# ----------------------------------------------------------------------------------------------------------------- #
|
14 |
+
class DogBreed(object):
|
15 |
+
def __init__(self, abbrev, name_akc=None, name_stanext=None, name_xlsx=None, path_akc=None, path_stanext=None, ind_in_xlsx=None, ind_in_xlsx_matrix=None, ind_in_stanext=None, clade=None):
|
16 |
+
self._abbrev = abbrev
|
17 |
+
self._name_xlsx = name_xlsx
|
18 |
+
self._name_akc = name_akc
|
19 |
+
self._name_stanext = name_stanext
|
20 |
+
self._path_stanext = path_stanext
|
21 |
+
self._additional_names = set()
|
22 |
+
if self._name_akc is not None:
|
23 |
+
self.add_akc_info(name_akc, path_akc)
|
24 |
+
if self._name_stanext is not None:
|
25 |
+
self.add_stanext_info(name_stanext, path_stanext, ind_in_stanext)
|
26 |
+
if self._name_xlsx is not None:
|
27 |
+
self.add_xlsx_info(name_xlsx, ind_in_xlsx, ind_in_xlsx_matrix, clade)
|
28 |
+
def add_xlsx_info(self, name_xlsx, ind_in_xlsx, ind_in_xlsx_matrix, clade):
|
29 |
+
assert (name_xlsx is not None) and (ind_in_xlsx is not None) and (ind_in_xlsx_matrix is not None) and (clade is not None)
|
30 |
+
self._name_xlsx = name_xlsx
|
31 |
+
self._ind_in_xlsx = ind_in_xlsx
|
32 |
+
self._ind_in_xlsx_matrix = ind_in_xlsx_matrix
|
33 |
+
self._clade = clade
|
34 |
+
def add_stanext_info(self, name_stanext, path_stanext, ind_in_stanext):
|
35 |
+
assert (name_stanext is not None) and (path_stanext is not None) and (ind_in_stanext is not None)
|
36 |
+
self._name_stanext = name_stanext
|
37 |
+
self._path_stanext = path_stanext
|
38 |
+
self._ind_in_stanext = ind_in_stanext
|
39 |
+
def add_akc_info(self, name_akc, path_akc):
|
40 |
+
assert (name_akc is not None) and (path_akc is not None)
|
41 |
+
self._name_akc = name_akc
|
42 |
+
self._path_akc = path_akc
|
43 |
+
def add_additional_names(self, name_list):
|
44 |
+
self._additional_names = self._additional_names.union(set(name_list))
|
45 |
+
def add_text_info(self, text_height, text_weight, text_life_exp):
|
46 |
+
self._text_height = text_height
|
47 |
+
self._text_weight = text_weight
|
48 |
+
self._text_life_exp = text_life_exp
|
49 |
+
def get_datasets(self):
|
50 |
+
# all datasets in which this breed is found
|
51 |
+
datasets = set()
|
52 |
+
if self._name_akc is not None:
|
53 |
+
datasets.add('akc')
|
54 |
+
if self._name_stanext is not None:
|
55 |
+
datasets.add('stanext')
|
56 |
+
if self._name_xlsx is not None:
|
57 |
+
datasets.add('xlsx')
|
58 |
+
return datasets
|
59 |
+
def get_names(self):
|
60 |
+
# set of names for this breed
|
61 |
+
names = {self._abbrev, self._name_akc, self._name_stanext, self._name_xlsx, self._path_stanext}.union(self._additional_names)
|
62 |
+
names.discard(None)
|
63 |
+
return names
|
64 |
+
def get_names_as_pointing_dict(self):
|
65 |
+
# each name points to the abbreviation
|
66 |
+
names = self.get_names()
|
67 |
+
my_dict = {}
|
68 |
+
for name in names:
|
69 |
+
my_dict[name] = self._abbrev
|
70 |
+
return my_dict
|
71 |
+
def print_overview(self):
|
72 |
+
# print important information to get an overview of the class instance
|
73 |
+
if self._name_akc is not None:
|
74 |
+
name = self._name_akc
|
75 |
+
elif self._name_xlsx is not None:
|
76 |
+
name = self._name_xlsx
|
77 |
+
else:
|
78 |
+
name = self._name_stanext
|
79 |
+
print('----------------------------------------------------')
|
80 |
+
print('----- dog breed: ' + name )
|
81 |
+
print('----------------------------------------------------')
|
82 |
+
print('[names]')
|
83 |
+
print(self.get_names())
|
84 |
+
print('[datasets]')
|
85 |
+
print(self.get_datasets())
|
86 |
+
# see https://stackoverflow.com/questions/9058305/getting-attributes-of-a-class
|
87 |
+
print('[instance attributes]')
|
88 |
+
for attribute, value in self.__dict__.items():
|
89 |
+
print(attribute, '=', value)
|
90 |
+
def use_dict_to_save_class_instance(self):
|
91 |
+
my_dict = {}
|
92 |
+
for attribute, value in self.__dict__.items():
|
93 |
+
my_dict[attribute] = value
|
94 |
+
return my_dict
|
95 |
+
def use_dict_to_load_class_instance(self, my_dict):
|
96 |
+
for attribute, value in my_dict.items():
|
97 |
+
setattr(self, attribute, value)
|
98 |
+
return
|
99 |
+
|
100 |
+
# ----------------------------------------------------------------------------------------------------------------- #
|
101 |
+
def get_name_list_from_summary(summary):
|
102 |
+
name_from_abbrev_dict = {}
|
103 |
+
for breed in summary.values():
|
104 |
+
abbrev = breed._abbrev
|
105 |
+
all_names = breed.get_names()
|
106 |
+
name_from_abbrev_dict[abbrev] = list(all_names)
|
107 |
+
return name_from_abbrev_dict
|
108 |
+
def get_partial_summary(summary, part):
|
109 |
+
assert part in ['xlsx', 'akc', 'stanext']
|
110 |
+
partial_summary = {}
|
111 |
+
for key, value in summary.items():
|
112 |
+
if (part == 'xlsx' and value._name_xlsx is not None) \
|
113 |
+
or (part == 'akc' and value._name_akc is not None) \
|
114 |
+
or (part == 'stanext' and value._name_stanext is not None):
|
115 |
+
partial_summary[key] = value
|
116 |
+
return partial_summary
|
117 |
+
def get_akc_but_not_stanext_partial_summary(summary):
|
118 |
+
partial_summary = {}
|
119 |
+
for key, value in summary.items():
|
120 |
+
if value._name_akc is not None:
|
121 |
+
if value._name_stanext is None:
|
122 |
+
partial_summary[key] = value
|
123 |
+
return partial_summary
|
124 |
+
|
125 |
+
# ----------------------------------------------------------------------------------------------------------------- #
|
126 |
+
def main_load_dog_breed_classes(path_complete_abbrev_dict_v1, path_complete_summary_breeds_v1):
|
127 |
+
with open(path_complete_abbrev_dict_v1, 'rb') as file:
|
128 |
+
complete_abbrev_dict = pkl.load(file)
|
129 |
+
with open(path_complete_summary_breeds_v1, 'rb') as file:
|
130 |
+
complete_summary_breeds_attributes_only = pkl.load(file)
|
131 |
+
|
132 |
+
complete_summary_breeds = {}
|
133 |
+
for key, value in complete_summary_breeds_attributes_only.items():
|
134 |
+
attributes_only = complete_summary_breeds_attributes_only[key]
|
135 |
+
complete_summary_breeds[key] = DogBreed(abbrev=attributes_only['_abbrev'])
|
136 |
+
complete_summary_breeds[key].use_dict_to_load_class_instance(attributes_only)
|
137 |
+
return complete_abbrev_dict, complete_summary_breeds
|
138 |
+
|
139 |
+
|
140 |
+
# ----------------------------------------------------------------------------------------------------------------- #
|
141 |
+
def load_similarity_matrix_raw(xlsx_path):
|
142 |
+
# --- LOAD EXCEL FILE FROM DOG BREED PAPER
|
143 |
+
xlsx = pd.read_excel(xlsx_path)
|
144 |
+
# create an array
|
145 |
+
abbrev_indices = {}
|
146 |
+
matrix_raw = np.zeros((168, 168))
|
147 |
+
for ind in range(1, 169):
|
148 |
+
abbrev = xlsx[xlsx.columns[2]][ind]
|
149 |
+
abbrev_indices[abbrev] = ind-1
|
150 |
+
for ind_col in range(0, 168):
|
151 |
+
for ind_row in range(0, 168):
|
152 |
+
matrix_raw[ind_col, ind_row] = float(xlsx[xlsx.columns[3+ind_col]][1+ind_row])
|
153 |
+
return matrix_raw, abbrev_indices
|
154 |
+
|
155 |
+
|
156 |
+
|
157 |
+
# ----------------------------------------------------------------------------------------------------------------- #
|
158 |
+
# ----------------------------------------------------------------------------------------------------------------- #
|
159 |
+
# load the (in advance created) final dict of dog breed classes
|
160 |
+
ROOT_PATH_BREED_DATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', '..', 'data', 'breed_data')
|
161 |
+
path_complete_abbrev_dict_v1 = os.path.join(ROOT_PATH_BREED_DATA, 'complete_abbrev_dict_v2.pkl')
|
162 |
+
path_complete_summary_breeds_v1 = os.path.join(ROOT_PATH_BREED_DATA, 'complete_summary_breeds_v2.pkl')
|
163 |
+
COMPLETE_ABBREV_DICT, COMPLETE_SUMMARY_BREEDS = main_load_dog_breed_classes(path_complete_abbrev_dict_v1, path_complete_summary_breeds_v1)
|
164 |
+
# load similarity matrix, data from:
|
165 |
+
# Parker H. G., Dreger D. L., Rimbault M., Davis B. W., Mullen A. B., Carpintero-Ramirez G., and Ostrander E. A.
|
166 |
+
# Genomic analyses reveal the influence of geographic origin, migration, and hybridization on modern dog breed
|
167 |
+
# development. Cell Reports, 4(19):697–708, 2017.
|
168 |
+
xlsx_path = os.path.join(ROOT_PATH_BREED_DATA, 'NIHMS866262-supplement-2.xlsx')
|
169 |
+
SIM_MATRIX_RAW, SIM_ABBREV_INDICES = load_similarity_matrix_raw(xlsx_path)
|
170 |
+
|
src/configs/refinement_cfg_test_withvertexwisegc_csaddnonflat.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
paths:
|
3 |
+
ROOT_OUT_PATH: './results/'
|
4 |
+
ROOT_CHECKPOINT_PATH: './checkpoint/'
|
5 |
+
MODELPATH_NORMFLOW: './checkpoint/barc_normflow_pret/rgbddog_v3_model.pt'
|
6 |
+
|
7 |
+
smal:
|
8 |
+
SMAL_MODEL_TYPE: '39dogs_norm_newv3' # '39dogs_norm' # '39dogs_diffsize' # 'barc'
|
9 |
+
SMAL_KEYP_CONF: 'olive' # 'green'
|
10 |
+
|
11 |
+
optim:
|
12 |
+
BATCH_SIZE: 12
|
13 |
+
|
14 |
+
params:
|
15 |
+
REF_NET_TYPE: 'multrot01all_res34' # 'multrot01all_res34' # 'multrot01all' # 'multrot01' # 'multrot' # 'add'
|
16 |
+
REF_DETACH_SHAPE: True
|
17 |
+
GRAPHCNN_TYPE: 'multistage_simple' # 'inexistent'
|
18 |
+
SHAPEREF_TYPE: 'inexistent' # 'linear' # 'inexistent'
|
19 |
+
ISFLAT_TYPE: 'linear' # 'inexistent' # 'inexistent'
|
20 |
+
|
21 |
+
data:
|
22 |
+
DATASET: 'stanext24'
|
23 |
+
VAL_OPT: 'test' # 'val'
|
src/configs/refinement_cfg_test_withvertexwisegc_csaddnonflat_crops.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
paths:
|
3 |
+
ROOT_OUT_PATH: './results/'
|
4 |
+
ROOT_CHECKPOINT_PATH: './checkpoint/'
|
5 |
+
MODELPATH_NORMFLOW: './checkpoint/barc_normflow_pret/rgbddog_v3_model.pt'
|
6 |
+
|
7 |
+
smal:
|
8 |
+
SMAL_MODEL_TYPE: '39dogs_norm_newv3' # '39dogs_norm' # '39dogs_diffsize' # 'barc'
|
9 |
+
SMAL_KEYP_CONF: 'olive' # 'green'
|
10 |
+
|
11 |
+
optim:
|
12 |
+
BATCH_SIZE: 12
|
13 |
+
|
14 |
+
params:
|
15 |
+
REF_NET_TYPE: 'multrot01all_res34' # 'multrot01all_res34' # 'multrot01all' # 'multrot01' # 'multrot' # 'add'
|
16 |
+
REF_DETACH_SHAPE: True
|
17 |
+
GRAPHCNN_TYPE: 'multistage_simple' # 'inexistent'
|
18 |
+
SHAPEREF_TYPE: 'inexistent' # 'linear' # 'inexistent'
|
19 |
+
ISFLAT_TYPE: 'linear' # 'inexistent' # 'inexistent'
|
20 |
+
|
21 |
+
data:
|
22 |
+
DATASET: 'ImgCropList'
|
23 |
+
VAL_OPT: 'test' # 'val'
|
src/configs/refinement_cfg_train_withvertexwisegc_isflat_csmorestanding.yaml
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
paths:
|
3 |
+
ROOT_OUT_PATH: './results/'
|
4 |
+
ROOT_CHECKPOINT_PATH: './checkpoint/'
|
5 |
+
MODELPATH_NORMFLOW: './checkpoint/barc_normflow_pret/rgbddog_v3_model.pt'
|
6 |
+
|
7 |
+
smal:
|
8 |
+
SMAL_MODEL_TYPE: '39dogs_norm_newv3' # '39dogs_norm' # '39dogs_diffsize' # 'barc'
|
9 |
+
SMAL_KEYP_CONF: 'olive' # 'green'
|
10 |
+
|
11 |
+
optim:
|
12 |
+
LR: 5e-5 # 5e-7 # (new) 5e-6 # 5e-5 # 5e-5 # 5e-4
|
13 |
+
SCHEDULE: [150, 175, 200] # [220, 270] # [150, 175, 200]
|
14 |
+
GAMMA: 0.1
|
15 |
+
MOMENTUM: 0
|
16 |
+
WEIGHT_DECAY: 0
|
17 |
+
EPOCHS: 220 # 300
|
18 |
+
BATCH_SIZE: 14 # 12 # keep 12 (needs to be an even number, as we have a custom data sampler)
|
19 |
+
TRAIN_PARTS: 'refinement_model' # 'refinement_model_and_shape' # 'refinement_model'
|
20 |
+
|
21 |
+
params:
|
22 |
+
REF_NET_TYPE: 'multrot01all_res34' # 'multrot01all_res34' # 'multrot01all' # 'multrot01' # 'multrot01' # 'multrot01' # 'multrot' # 'multrot_res34' # 'multrot' # 'add'
|
23 |
+
REF_DETACH_SHAPE: True
|
24 |
+
GRAPHCNN_TYPE: 'multistage_simple' # 'inexistent'
|
25 |
+
SHAPEREF_TYPE: 'inexistent' # 'linear' # 'inexistent'
|
26 |
+
ISFLAT_TYPE: 'linear' # 'inexistent' # 'inexistent'
|
27 |
+
|
28 |
+
data:
|
29 |
+
DATASET: 'stanext24_withgc_csaddnonflatmorestanding' # 'stanext24_withgc_csaddnonflat' # 'stanext24_withgc_cs0'
|
30 |
+
SHORTEN_VAL_DATASET_TO: 600 # this is faster as we do not evaluate on the whole validation set
|
31 |
+
VAL_OPT: 'val'
|
src/configs/refinement_loss_weights_withgc_withvertexwise_addnonflat.json
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
|
4 |
+
{
|
5 |
+
"keyp_ref": 0.2,
|
6 |
+
"silh_ref": 50.0,
|
7 |
+
"pose_legs_side": 1.0,
|
8 |
+
"pose_legs_tors": 1.0,
|
9 |
+
"pose_tail_side": 0.0,
|
10 |
+
"pose_tail_tors": 0.0,
|
11 |
+
"pose_spine_side": 0.0,
|
12 |
+
"pose_spine_tors": 0.0,
|
13 |
+
"reg_trans": 0.0,
|
14 |
+
"reg_flength": 0.0,
|
15 |
+
"reg_pose": 0.0,
|
16 |
+
"gc_plane": 5.0,
|
17 |
+
"gc_blowplane": 5.0,
|
18 |
+
"gc_vertexwise": 10.0,
|
19 |
+
"gc_isflat": 0.5
|
20 |
+
}
|
src/configs/ttopt_loss_weights/bite_loss_weights_ttopt.json
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"silhouette": {
|
3 |
+
"weight": 40.0,
|
4 |
+
"weight_vshift": 20.0,
|
5 |
+
"value": 0.0
|
6 |
+
},
|
7 |
+
"keyp":{
|
8 |
+
"weight": 0.2,
|
9 |
+
"weight_vshift": 0.01,
|
10 |
+
"value": 0.0
|
11 |
+
},
|
12 |
+
"pose_legs_side":{
|
13 |
+
"weight": 1.0,
|
14 |
+
"weight_vshift": 1.0,
|
15 |
+
"value": 0.0
|
16 |
+
},
|
17 |
+
"pose_legs_tors":{
|
18 |
+
"weight": 10.0,
|
19 |
+
"weight_vshift": 10.0,
|
20 |
+
"value": 0.0
|
21 |
+
},
|
22 |
+
"pose_tail_side":{
|
23 |
+
"weight": 1,
|
24 |
+
"weight_vshift": 1,
|
25 |
+
"value": 0.0
|
26 |
+
},
|
27 |
+
"pose_tail_tors":{
|
28 |
+
"weight": 10.0,
|
29 |
+
"weight_vshift": 10.0,
|
30 |
+
"value": 0.0
|
31 |
+
},
|
32 |
+
"pose_spine_side":{
|
33 |
+
"weight": 0.0,
|
34 |
+
"weight_vshift": 0.0,
|
35 |
+
"value": 0.0
|
36 |
+
},
|
37 |
+
"pose_spine_tors":{
|
38 |
+
"weight": 0.0,
|
39 |
+
"weight_vshift": 0.0,
|
40 |
+
"value": 0.0
|
41 |
+
},
|
42 |
+
"gc_plane":{
|
43 |
+
"weight": 10.0,
|
44 |
+
"weight_vshift": 20.0,
|
45 |
+
"value": 0.0
|
46 |
+
},
|
47 |
+
"gc_belowplane":{
|
48 |
+
"weight": 10.0,
|
49 |
+
"weight_vshift": 20.0,
|
50 |
+
"value": 0.0
|
51 |
+
},
|
52 |
+
"lapctf": {
|
53 |
+
"weight": 0.0,
|
54 |
+
"weight_vshift": 10.0,
|
55 |
+
"value": 0.0
|
56 |
+
},
|
57 |
+
"arap": {
|
58 |
+
"weight": 0.0,
|
59 |
+
"weight_vshift": 0.0,
|
60 |
+
"value": 0.0
|
61 |
+
},
|
62 |
+
"edge": {
|
63 |
+
"weight": 0.0,
|
64 |
+
"weight_vshift": 10.0,
|
65 |
+
"value": 0.0
|
66 |
+
},
|
67 |
+
"normal": {
|
68 |
+
"weight": 0.0,
|
69 |
+
"weight_vshift": 1.0,
|
70 |
+
"value": 0.0
|
71 |
+
},
|
72 |
+
"laplacian": {
|
73 |
+
"weight": 0.0,
|
74 |
+
"weight_vshift": 0.0,
|
75 |
+
"value": 0.0
|
76 |
+
}
|
77 |
+
}
|
src/configs/ttopt_loss_weights/ttopt_loss_weights_v2c_withlapcft_v2.json
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"silhouette": {
|
3 |
+
"weight": 40.0,
|
4 |
+
"weight_vshift": 20.0,
|
5 |
+
"value": 0.0
|
6 |
+
},
|
7 |
+
"keyp":{
|
8 |
+
"weight": 0.2,
|
9 |
+
"weight_vshift": 0.01,
|
10 |
+
"value": 0.0
|
11 |
+
},
|
12 |
+
"pose_legs_side":{
|
13 |
+
"weight": 1.0,
|
14 |
+
"weight_vshift": 1.0,
|
15 |
+
"value": 0.0
|
16 |
+
},
|
17 |
+
"pose_legs_tors":{
|
18 |
+
"weight": 10.0,
|
19 |
+
"weight_vshift": 10.0,
|
20 |
+
"value": 0.0
|
21 |
+
},
|
22 |
+
"pose_tail_side":{
|
23 |
+
"weight": 1,
|
24 |
+
"weight_vshift": 1,
|
25 |
+
"value": 0.0
|
26 |
+
},
|
27 |
+
"pose_tail_tors":{
|
28 |
+
"weight": 10.0,
|
29 |
+
"weight_vshift": 10.0,
|
30 |
+
"value": 0.0
|
31 |
+
},
|
32 |
+
"pose_spine_side":{
|
33 |
+
"weight": 0.0,
|
34 |
+
"weight_vshift": 0.0,
|
35 |
+
"value": 0.0
|
36 |
+
},
|
37 |
+
"pose_spine_tors":{
|
38 |
+
"weight": 0.0,
|
39 |
+
"weight_vshift": 0.0,
|
40 |
+
"value": 0.0
|
41 |
+
},
|
42 |
+
"gc_plane":{
|
43 |
+
"weight": 10.0,
|
44 |
+
"weight_vshift": 20.0,
|
45 |
+
"value": 0.0
|
46 |
+
},
|
47 |
+
"gc_belowplane":{
|
48 |
+
"weight": 10.0,
|
49 |
+
"weight_vshift": 20.0,
|
50 |
+
"value": 0.0
|
51 |
+
},
|
52 |
+
"lapctf": {
|
53 |
+
"weight": 0.0,
|
54 |
+
"weight_vshift": 10.0,
|
55 |
+
"value": 0.0
|
56 |
+
},
|
57 |
+
"arap": {
|
58 |
+
"weight": 0.0,
|
59 |
+
"weight_vshift": 0.0,
|
60 |
+
"value": 0.0
|
61 |
+
},
|
62 |
+
"edge": {
|
63 |
+
"weight": 0.0,
|
64 |
+
"weight_vshift": 10.0,
|
65 |
+
"value": 0.0
|
66 |
+
},
|
67 |
+
"normal": {
|
68 |
+
"weight": 0.0,
|
69 |
+
"weight_vshift": 1.0,
|
70 |
+
"value": 0.0
|
71 |
+
},
|
72 |
+
"laplacian": {
|
73 |
+
"weight": 0.0,
|
74 |
+
"weight_vshift": 0.0,
|
75 |
+
"value": 0.0
|
76 |
+
}
|
77 |
+
}
|
src/graph_networks/__init__.py
ADDED
File without changes
|
src/graph_networks/graphcmr/__init__.py
ADDED
File without changes
|
src/graph_networks/graphcmr/get_downsampled_mesh_npz.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
# try to use aenv_conda3 (maybe also export PYOPENGL_PLATFORM=osmesa)
|
3 |
+
# python src/graph_networks/graphcmr/get_downsampled_mesh_npz.py
|
4 |
+
|
5 |
+
# see https://github.com/nkolot/GraphCMR/issues/35
|
6 |
+
|
7 |
+
|
8 |
+
from __future__ import print_function
|
9 |
+
# import mesh_sampling
|
10 |
+
from psbody.mesh import Mesh, MeshViewer, MeshViewers
|
11 |
+
import numpy as np
|
12 |
+
import json
|
13 |
+
import os
|
14 |
+
import copy
|
15 |
+
import argparse
|
16 |
+
import pickle
|
17 |
+
import time
|
18 |
+
import sys
|
19 |
+
import trimesh
|
20 |
+
|
21 |
+
|
22 |
+
|
23 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), "../../../../"))
|
24 |
+
from barc_for_bite.src.graph_networks.graphcmr.pytorch_coma_mesh_operations import generate_transform_matrices
|
25 |
+
from barc_for_bite.src.configs.SMAL_configs import SMAL_MODEL_CONFIG
|
26 |
+
from barc_for_bite.src.smal_pytorch.smal_model.smal_torch_new import SMAL
|
27 |
+
# smal_model_path = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/smal_data/new_dog_models/my_smpl_00791_nadine_Jr_4_dog.pkl'
|
28 |
+
|
29 |
+
|
30 |
+
SMAL_MODEL_TYPE = '39dogs_diffsize' # '39dogs_diffsize' # '39dogs_norm' # 'barc'
|
31 |
+
smal_model_path = SMAL_MODEL_CONFIG[SMAL_MODEL_TYPE]['smal_model_path']
|
32 |
+
|
33 |
+
# data_path_root = "/is/cluster/work/nrueegg/icon_pifu_related/ICON/lib/graph_networks/graphcmr/data/"
|
34 |
+
data_path_root = "/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/graphcmr/data/"
|
35 |
+
|
36 |
+
smal_dog_model_name = os.path.basename(smal_model_path).split('.pkl')[0] # 'my_smpl_SMBLD_nbj_v3'
|
37 |
+
suffix = "_template"
|
38 |
+
template_obj_path = data_path_root + smal_dog_model_name + suffix + ".obj"
|
39 |
+
|
40 |
+
print("Loading smal .. ")
|
41 |
+
print(SMAL_MODEL_TYPE)
|
42 |
+
print(smal_model_path)
|
43 |
+
|
44 |
+
smal = SMAL(smal_model_type=SMAL_MODEL_TYPE, template_name='neutral')
|
45 |
+
smal_verts = smal.v_template.detach().cpu().numpy() # (3889, 3)
|
46 |
+
smal_faces = smal.f # (7774, 3)
|
47 |
+
smal_trimesh = trimesh.base.Trimesh(vertices=smal_verts, faces=smal_faces, process=False, maintain_order=True)
|
48 |
+
smal_trimesh.export(file_obj=template_obj_path) # file_type='obj')
|
49 |
+
|
50 |
+
|
51 |
+
print("Loading data .. ")
|
52 |
+
reference_mesh_file = template_obj_path # 'data/barc_neutral_vertices.obj' # 'data/smpl_neutral_vertices.obj'
|
53 |
+
reference_mesh = Mesh(filename=reference_mesh_file)
|
54 |
+
|
55 |
+
# ds_factors = [4, 4] # ds_factors = [4,1] # Sampling factor of the mesh at each stage of sampling
|
56 |
+
ds_factors = [4, 4, 4, 4]
|
57 |
+
print("Generating Transform Matrices ..")
|
58 |
+
|
59 |
+
|
60 |
+
# Generates adjecency matrices A, downsampling matrices D, and upsamling matrices U by sampling
|
61 |
+
# the mesh 4 times. Each time the mesh is sampled by a factor of 4
|
62 |
+
|
63 |
+
# M,A,D,U = mesh_sampling.generate_transform_matrices(reference_mesh, ds_factors)
|
64 |
+
M,A,D,U = generate_transform_matrices(reference_mesh, ds_factors)
|
65 |
+
|
66 |
+
# REMARK: there is a warning:
|
67 |
+
# lib/graph_networks/graphcmr/../../../lib/graph_networks/graphcmr/pytorch_coma_mesh_operations.py:237: FutureWarning: `rcond` parameter will
|
68 |
+
# change to the default of machine precision times ``max(M, N)`` where M and N are the input matrix dimensions.
|
69 |
+
# To use the future default and silence this warning we advise to pass `rcond=None`, to keep using the old, explicitly pass `rcond=-1`.
|
70 |
+
|
71 |
+
|
72 |
+
print(type(A))
|
73 |
+
np.savez(data_path_root + 'mesh_downsampling_' + smal_dog_model_name + suffix + '.npz', A = A, D = D, U = U)
|
74 |
+
np.savez(data_path_root + 'meshes/' + 'mesh_downsampling_meshes' + smal_dog_model_name + suffix + '.npz', M = M)
|
75 |
+
|
76 |
+
for ind_m, my_mesh in enumerate(M):
|
77 |
+
new_suffix = '_template_downsampled' + str(ind_m)
|
78 |
+
my_mesh_tri = trimesh.Trimesh(vertices=my_mesh.v, faces=my_mesh.f, process=False, maintain_order=True)
|
79 |
+
my_mesh_tri.export(data_path_root + 'meshes/' + 'mesh_downsampling_meshes' + smal_dog_model_name + new_suffix + '.obj')
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
|
src/graph_networks/graphcmr/graph_cnn.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
code from https://raw.githubusercontent.com/nkolot/GraphCMR/master/models/graph_cnn.py
|
3 |
+
This file contains the Definition of GraphCNN
|
4 |
+
GraphCNN includes ResNet50 as a submodule
|
5 |
+
"""
|
6 |
+
from __future__ import division
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
from .graph_layers import GraphResBlock, GraphLinear
|
12 |
+
from .resnet import resnet50
|
13 |
+
|
14 |
+
class GraphCNN(nn.Module):
|
15 |
+
|
16 |
+
def __init__(self, A, ref_vertices, num_layers=5, num_channels=512):
|
17 |
+
super(GraphCNN, self).__init__()
|
18 |
+
self.A = A
|
19 |
+
self.ref_vertices = ref_vertices
|
20 |
+
self.resnet = resnet50(pretrained=True)
|
21 |
+
layers = [GraphLinear(3 + 2048, 2 * num_channels)]
|
22 |
+
layers.append(GraphResBlock(2 * num_channels, num_channels, A))
|
23 |
+
for i in range(num_layers):
|
24 |
+
layers.append(GraphResBlock(num_channels, num_channels, A))
|
25 |
+
self.shape = nn.Sequential(GraphResBlock(num_channels, 64, A),
|
26 |
+
GraphResBlock(64, 32, A),
|
27 |
+
nn.GroupNorm(32 // 8, 32),
|
28 |
+
nn.ReLU(inplace=True),
|
29 |
+
GraphLinear(32, 3))
|
30 |
+
self.gc = nn.Sequential(*layers)
|
31 |
+
self.camera_fc = nn.Sequential(nn.GroupNorm(num_channels // 8, num_channels),
|
32 |
+
nn.ReLU(inplace=True),
|
33 |
+
GraphLinear(num_channels, 1),
|
34 |
+
nn.ReLU(inplace=True),
|
35 |
+
nn.Linear(A.shape[0], 3))
|
36 |
+
|
37 |
+
def forward(self, image):
|
38 |
+
"""Forward pass
|
39 |
+
Inputs:
|
40 |
+
image: size = (B, 3, 224, 224)
|
41 |
+
Returns:
|
42 |
+
Regressed (subsampled) non-parametric shape: size = (B, 1723, 3)
|
43 |
+
Weak-perspective camera: size = (B, 3)
|
44 |
+
"""
|
45 |
+
batch_size = image.shape[0]
|
46 |
+
ref_vertices = self.ref_vertices[None, :, :].expand(batch_size, -1, -1)
|
47 |
+
image_resnet = self.resnet(image)
|
48 |
+
image_enc = image_resnet.view(batch_size, 2048, 1).expand(-1, -1, ref_vertices.shape[-1])
|
49 |
+
x = torch.cat([ref_vertices, image_enc], dim=1)
|
50 |
+
x = self.gc(x)
|
51 |
+
shape = self.shape(x)
|
52 |
+
camera = self.camera_fc(x).view(batch_size, 3)
|
53 |
+
return shape, camera
|
src/graph_networks/graphcmr/graph_cnn_groundcontact.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
code from https://raw.githubusercontent.com/nkolot/GraphCMR/master/models/graph_cnn.py
|
3 |
+
This file contains the Definition of GraphCNN
|
4 |
+
GraphCNN includes ResNet50 as a submodule
|
5 |
+
"""
|
6 |
+
from __future__ import division
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
# from .resnet import resnet50
|
12 |
+
import torchvision.models as models
|
13 |
+
|
14 |
+
|
15 |
+
import os
|
16 |
+
import sys
|
17 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..'))
|
18 |
+
from src.graph_networks.graphcmr.utils_mesh import Mesh
|
19 |
+
from src.graph_networks.graphcmr.graph_layers import GraphResBlock, GraphLinear
|
20 |
+
|
21 |
+
|
22 |
+
class GraphCNN(nn.Module):
|
23 |
+
|
24 |
+
def __init__(self, A, ref_vertices, n_resnet_in, n_resnet_out, num_layers=5, num_channels=512):
|
25 |
+
super(GraphCNN, self).__init__()
|
26 |
+
self.A = A
|
27 |
+
self.ref_vertices = ref_vertices
|
28 |
+
# self.resnet = resnet50(pretrained=True)
|
29 |
+
# -> within the GraphCMR network they ignore the last fully connected layer
|
30 |
+
# replace the first layer
|
31 |
+
self.resnet = models.resnet34(pretrained=False)
|
32 |
+
n_in = 3 + 1
|
33 |
+
self.resnet.conv1 = nn.Conv2d(n_resnet_in, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
|
34 |
+
# replace the last layer
|
35 |
+
self.resnet.fc = nn.Linear(512, n_resnet_out)
|
36 |
+
|
37 |
+
|
38 |
+
layers = [GraphLinear(3 + n_resnet_out, 2 * num_channels)] # [GraphLinear(3 + 2048, 2 * num_channels)]
|
39 |
+
layers.append(GraphResBlock(2 * num_channels, num_channels, A))
|
40 |
+
for i in range(num_layers):
|
41 |
+
layers.append(GraphResBlock(num_channels, num_channels, A))
|
42 |
+
self.n_out_gc = 2 # two labels per vertex
|
43 |
+
self.gc = nn.Sequential(GraphResBlock(num_channels, 64, A),
|
44 |
+
GraphResBlock(64, 32, A),
|
45 |
+
nn.GroupNorm(32 // 8, 32),
|
46 |
+
nn.ReLU(inplace=True),
|
47 |
+
GraphLinear(32, self.n_out_gc))
|
48 |
+
self.gcnn = nn.Sequential(*layers)
|
49 |
+
self.n_out_flatground = 1
|
50 |
+
self.flat_ground = nn.Sequential(nn.GroupNorm(num_channels // 8, num_channels),
|
51 |
+
nn.ReLU(inplace=True),
|
52 |
+
GraphLinear(num_channels, 1),
|
53 |
+
nn.ReLU(inplace=True),
|
54 |
+
nn.Linear(A.shape[0], self.n_out_flatground))
|
55 |
+
|
56 |
+
def forward(self, image):
|
57 |
+
"""Forward pass
|
58 |
+
Inputs:
|
59 |
+
image: size = (B, 3, 256, 256)
|
60 |
+
Returns:
|
61 |
+
Regressed (subsampled) non-parametric shape: size = (B, 1723, 3)
|
62 |
+
Weak-perspective camera: size = (B, 3)
|
63 |
+
"""
|
64 |
+
# import pdb; pdb.set_trace()
|
65 |
+
|
66 |
+
batch_size = image.shape[0]
|
67 |
+
ref_vertices = self.ref_vertices[None, :, :].expand(batch_size, -1, -1) # (bs, 3, 973)
|
68 |
+
image_resnet = self.resnet(image) # (bs, 512)
|
69 |
+
image_enc = image_resnet.view(batch_size, -1, 1).expand(-1, -1, ref_vertices.shape[-1]) # (bs, 512, 973)
|
70 |
+
x = torch.cat([ref_vertices, image_enc], dim=1)
|
71 |
+
x = self.gcnn(x) # (bs, 512, 973)
|
72 |
+
ground_contact = self.gc(x) # (bs, 2, 973)
|
73 |
+
ground_flatness = self.flat_ground(x).view(batch_size, self.n_out_flatground) # (bs, 1)
|
74 |
+
return ground_contact, ground_flatness
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
|
79 |
+
# how to use it:
|
80 |
+
#
|
81 |
+
# from src.graph_networks.graphcmr.utils_mesh import Mesh
|
82 |
+
#
|
83 |
+
# create Mesh object
|
84 |
+
# self.mesh = Mesh()
|
85 |
+
# self.faces = self.mesh.faces.to(self.device)
|
86 |
+
#
|
87 |
+
# create GraphCNN
|
88 |
+
# self.graph_cnn = GraphCNN(self.mesh.adjmat,
|
89 |
+
# self.mesh.ref_vertices.t(),
|
90 |
+
# num_channels=self.options.num_channels,
|
91 |
+
# num_layers=self.options.num_layers
|
92 |
+
# ).to(self.device)
|
93 |
+
# ------------
|
94 |
+
#
|
95 |
+
# Feed image in the GraphCNN
|
96 |
+
# Returns subsampled mesh and camera parameters
|
97 |
+
# pred_vertices_sub, pred_camera = self.graph_cnn(images)
|
98 |
+
#
|
99 |
+
# Upsample mesh in the original size
|
100 |
+
# pred_vertices = self.mesh.upsample(pred_vertices_sub.transpose(1,2))
|
101 |
+
#
|
src/graph_networks/graphcmr/graph_cnn_groundcontact_multistage.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
code from
|
3 |
+
https://raw.githubusercontent.com/nkolot/GraphCMR/master/models/graph_cnn.py
|
4 |
+
https://github.com/chaneyddtt/Coarse-to-fine-3D-Animal/blob/main/model/graph_hg.py
|
5 |
+
This file contains the Definition of GraphCNN
|
6 |
+
GraphCNN includes ResNet50 as a submodule
|
7 |
+
"""
|
8 |
+
from __future__ import division
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
|
13 |
+
# from .resnet import resnet50
|
14 |
+
import torchvision.models as models
|
15 |
+
|
16 |
+
|
17 |
+
import os
|
18 |
+
import sys
|
19 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..'))
|
20 |
+
from src.graph_networks.graphcmr.utils_mesh import Mesh
|
21 |
+
from src.graph_networks.graphcmr.graph_layers import GraphResBlock, GraphLinear
|
22 |
+
|
23 |
+
|
24 |
+
class GraphCNNMS(nn.Module):
|
25 |
+
|
26 |
+
def __init__(self, mesh, num_downsample=0, num_layers=5, n_resnet_out=256, num_channels=256):
|
27 |
+
'''
|
28 |
+
Args:
|
29 |
+
mesh: mesh data that store the adjacency matrix
|
30 |
+
num_channels: number of channels of GCN
|
31 |
+
num_downsample: number of downsampling of the input mesh
|
32 |
+
'''
|
33 |
+
|
34 |
+
super(GraphCNNMS, self).__init__()
|
35 |
+
|
36 |
+
self.A = mesh._A[num_downsample:] # get the correct adjacency matrix because the input might be downsampled
|
37 |
+
# self.num_layers = len(self.A) - 1
|
38 |
+
self.num_layers = num_layers
|
39 |
+
assert self.num_layers <= len(self.A) - 1
|
40 |
+
print("Number of downsampling layer: {}".format(self.num_layers))
|
41 |
+
self.num_downsample = num_downsample
|
42 |
+
self.n_resnet_out = n_resnet_out
|
43 |
+
|
44 |
+
|
45 |
+
'''
|
46 |
+
self.use_pret_res = use_pret_res
|
47 |
+
# self.resnet = resnet50(pretrained=True)
|
48 |
+
# -> within the GraphCMR network they ignore the last fully connected layer
|
49 |
+
# replace the first layer
|
50 |
+
self.resnet = models.resnet34(pretrained=self.use_pret_res)
|
51 |
+
if (self.use_pret_res) and (n_resnet_in == 3):
|
52 |
+
print('use full pretrained resnet including first layer!')
|
53 |
+
else:
|
54 |
+
self.resnet.conv1 = nn.Conv2d(n_resnet_in, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
|
55 |
+
# replace the last layer
|
56 |
+
self.resnet.fc = nn.Linear(512, n_resnet_out)
|
57 |
+
'''
|
58 |
+
|
59 |
+
self.lin1 = GraphLinear(3 + n_resnet_out, 2 * num_channels)
|
60 |
+
self.res1 = GraphResBlock(2 * num_channels, num_channels, self.A[0])
|
61 |
+
encode_layers = []
|
62 |
+
decode_layers = []
|
63 |
+
|
64 |
+
for i in range(self.num_layers + 1): # range(len(self.A)):
|
65 |
+
encode_layers.append(GraphResBlock(num_channels, num_channels, self.A[i]))
|
66 |
+
|
67 |
+
decode_layers.append(GraphResBlock((i+1)*num_channels, (i+1)*num_channels,
|
68 |
+
self.A[self.num_layers - i]))
|
69 |
+
current_channels = (i+1)*num_channels
|
70 |
+
# number of channels for the input is different because of the concatenation operation
|
71 |
+
self.n_out_gc = 2 # two labels per vertex
|
72 |
+
self.gc = nn.Sequential(GraphResBlock(current_channels, 64, self.A[0]),
|
73 |
+
GraphResBlock(64, 32, self.A[0]),
|
74 |
+
nn.GroupNorm(32 // 8, 32),
|
75 |
+
nn.ReLU(inplace=True),
|
76 |
+
GraphLinear(32, self.n_out_gc))
|
77 |
+
|
78 |
+
'''
|
79 |
+
self.n_out_flatground = 2
|
80 |
+
self.flat_ground = nn.Sequential(nn.GroupNorm(current_channels // 8, current_channels),
|
81 |
+
nn.ReLU(inplace=True),
|
82 |
+
GraphLinear(current_channels, 1),
|
83 |
+
nn.ReLU(inplace=True),
|
84 |
+
nn.Linear(A.shape[0], self.n_out_flatground))
|
85 |
+
'''
|
86 |
+
|
87 |
+
self.encoder = nn.Sequential(*encode_layers)
|
88 |
+
self.decoder = nn.Sequential(*decode_layers)
|
89 |
+
self.mesh = mesh
|
90 |
+
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
def forward(self, image_enc):
|
95 |
+
"""Forward pass
|
96 |
+
Inputs:
|
97 |
+
image_enc: size = (B, self.n_resnet_out)
|
98 |
+
Returns:
|
99 |
+
Regressed (subsampled) non-parametric shape: size = (B, 1723, 3)
|
100 |
+
Weak-perspective camera: size = (B, 3)
|
101 |
+
"""
|
102 |
+
# import pdb; pdb.set_trace()
|
103 |
+
|
104 |
+
batch_size = image_enc.shape[0]
|
105 |
+
# ref_vertices = (self.mesh.get_ref_vertices(n=self.num_downsample).t())[None, :, :].expand(batch_size, -1, -1) # (bs, 3, 973)
|
106 |
+
ref_vertices = (self.mesh.ref_vertices.t())[None, :, :].expand(batch_size, -1, -1) # (bs, 3, 973)
|
107 |
+
'''image_resnet = self.resnet(image) # (bs, 512)'''
|
108 |
+
image_enc_prep = image_enc.view(batch_size, -1, 1).expand(-1, -1, ref_vertices.shape[-1]) # (bs, 512, 973)
|
109 |
+
|
110 |
+
# prepare network input
|
111 |
+
# -> for each node we feed the location of the vertex in the template mesh and an image encoding
|
112 |
+
x = torch.cat([ref_vertices, image_enc_prep], dim=1)
|
113 |
+
x = self.lin1(x)
|
114 |
+
x = self.res1(x)
|
115 |
+
x_ = [x]
|
116 |
+
output_list = []
|
117 |
+
for i in range(self.num_layers + 1):
|
118 |
+
if i == self.num_layers:
|
119 |
+
x = self.encoder[i](x)
|
120 |
+
else:
|
121 |
+
x = self.encoder[i](x)
|
122 |
+
x = self.mesh.downsample(x.transpose(1, 2), n1=self.num_downsample+i, n2=self.num_downsample+i+1)
|
123 |
+
x = x.transpose(1, 2)
|
124 |
+
if i < self.num_layers-1:
|
125 |
+
x_.append(x)
|
126 |
+
for i in range(self.num_layers + 1):
|
127 |
+
if i == self.num_layers:
|
128 |
+
x = self.decoder[i](x)
|
129 |
+
output_list.append(x)
|
130 |
+
else:
|
131 |
+
x = self.decoder[i](x)
|
132 |
+
output_list.append(x)
|
133 |
+
x = self.mesh.upsample(x.transpose(1, 2), n1=self.num_layers-i+self.num_downsample,
|
134 |
+
n2=self.num_layers-i-1+self.num_downsample)
|
135 |
+
x = x.transpose(1, 2)
|
136 |
+
x = torch.cat([x, x_[self.num_layers-i-1]], dim=1) # skip connection between encoder and decoder
|
137 |
+
|
138 |
+
ground_contact = self.gc(x)
|
139 |
+
|
140 |
+
'''
|
141 |
+
ground_flatness = self.flat_ground(x).view(batch_size, self.n_out_flatground) # (bs, 1)
|
142 |
+
'''
|
143 |
+
|
144 |
+
return ground_contact, output_list # , ground_flatness
|
145 |
+
|
146 |
+
|
147 |
+
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
|
152 |
+
# how to use it:
|
153 |
+
#
|
154 |
+
# from src.graph_networks.graphcmr.utils_mesh import Mesh
|
155 |
+
#
|
156 |
+
# create Mesh object
|
157 |
+
# self.mesh = Mesh()
|
158 |
+
# self.faces = self.mesh.faces.to(self.device)
|
159 |
+
#
|
160 |
+
# create GraphCNN
|
161 |
+
# self.graph_cnn = GraphCNN(self.mesh.adjmat,
|
162 |
+
# self.mesh.ref_vertices.t(),
|
163 |
+
# num_channels=self.options.num_channels,
|
164 |
+
# num_layers=self.options.num_layers
|
165 |
+
# ).to(self.device)
|
166 |
+
# ------------
|
167 |
+
#
|
168 |
+
# Feed image in the GraphCNN
|
169 |
+
# Returns subsampled mesh and camera parameters
|
170 |
+
# pred_vertices_sub, pred_camera = self.graph_cnn(images)
|
171 |
+
#
|
172 |
+
# Upsample mesh in the original size
|
173 |
+
# pred_vertices = self.mesh.upsample(pred_vertices_sub.transpose(1,2))
|
174 |
+
#
|
src/graph_networks/graphcmr/graph_cnn_groundcontact_multistage_includingresnet.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
code from
|
3 |
+
https://raw.githubusercontent.com/nkolot/GraphCMR/master/models/graph_cnn.py
|
4 |
+
https://github.com/chaneyddtt/Coarse-to-fine-3D-Animal/blob/main/model/graph_hg.py
|
5 |
+
This file contains the Definition of GraphCNN
|
6 |
+
GraphCNN includes ResNet50 as a submodule
|
7 |
+
"""
|
8 |
+
from __future__ import division
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
|
13 |
+
# from .resnet import resnet50
|
14 |
+
import torchvision.models as models
|
15 |
+
|
16 |
+
|
17 |
+
import os
|
18 |
+
import sys
|
19 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..', '..'))
|
20 |
+
from src.graph_networks.graphcmr.utils_mesh import Mesh
|
21 |
+
from src.graph_networks.graphcmr.graph_layers import GraphResBlock, GraphLinear
|
22 |
+
|
23 |
+
|
24 |
+
class GraphCNNMS(nn.Module):
|
25 |
+
|
26 |
+
def __init__(self, mesh, num_downsample=0, num_layers=5, n_resnet_in=3, n_resnet_out=256, num_channels=256, use_pret_res=False):
|
27 |
+
'''
|
28 |
+
Args:
|
29 |
+
mesh: mesh data that store the adjacency matrix
|
30 |
+
num_channels: number of channels of GCN
|
31 |
+
num_downsample: number of downsampling of the input mesh
|
32 |
+
'''
|
33 |
+
|
34 |
+
super(GraphCNNMS, self).__init__()
|
35 |
+
|
36 |
+
self.A = mesh._A[num_downsample:] # get the correct adjacency matrix because the input might be downsampled
|
37 |
+
# self.num_layers = len(self.A) - 1
|
38 |
+
self.num_layers = num_layers
|
39 |
+
assert self.num_layers <= len(self.A) - 1
|
40 |
+
print("Number of downsampling layer: {}".format(self.num_layers))
|
41 |
+
self.num_downsample = num_downsample
|
42 |
+
self.use_pret_res = use_pret_res
|
43 |
+
|
44 |
+
# self.resnet = resnet50(pretrained=True)
|
45 |
+
# -> within the GraphCMR network they ignore the last fully connected layer
|
46 |
+
# replace the first layer
|
47 |
+
self.resnet = models.resnet34(pretrained=self.use_pret_res)
|
48 |
+
if (self.use_pret_res) and (n_resnet_in == 3):
|
49 |
+
print('use full pretrained resnet including first layer!')
|
50 |
+
else:
|
51 |
+
self.resnet.conv1 = nn.Conv2d(n_resnet_in, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
|
52 |
+
# replace the last layer
|
53 |
+
self.resnet.fc = nn.Linear(512, n_resnet_out)
|
54 |
+
|
55 |
+
self.lin1 = GraphLinear(3 + n_resnet_out, 2 * num_channels)
|
56 |
+
self.res1 = GraphResBlock(2 * num_channels, num_channels, self.A[0])
|
57 |
+
encode_layers = []
|
58 |
+
decode_layers = []
|
59 |
+
|
60 |
+
for i in range(self.num_layers + 1): # range(len(self.A)):
|
61 |
+
encode_layers.append(GraphResBlock(num_channels, num_channels, self.A[i]))
|
62 |
+
|
63 |
+
decode_layers.append(GraphResBlock((i+1)*num_channels, (i+1)*num_channels,
|
64 |
+
self.A[self.num_layers - i]))
|
65 |
+
current_channels = (i+1)*num_channels
|
66 |
+
# number of channels for the input is different because of the concatenation operation
|
67 |
+
self.n_out_gc = 2 # two labels per vertex
|
68 |
+
self.gc = nn.Sequential(GraphResBlock(current_channels, 64, self.A[0]),
|
69 |
+
GraphResBlock(64, 32, self.A[0]),
|
70 |
+
nn.GroupNorm(32 // 8, 32),
|
71 |
+
nn.ReLU(inplace=True),
|
72 |
+
GraphLinear(32, self.n_out_gc))
|
73 |
+
|
74 |
+
'''
|
75 |
+
self.n_out_flatground = 2
|
76 |
+
self.flat_ground = nn.Sequential(nn.GroupNorm(current_channels // 8, current_channels),
|
77 |
+
nn.ReLU(inplace=True),
|
78 |
+
GraphLinear(current_channels, 1),
|
79 |
+
nn.ReLU(inplace=True),
|
80 |
+
nn.Linear(A.shape[0], self.n_out_flatground))
|
81 |
+
'''
|
82 |
+
|
83 |
+
self.encoder = nn.Sequential(*encode_layers)
|
84 |
+
self.decoder = nn.Sequential(*decode_layers)
|
85 |
+
self.mesh = mesh
|
86 |
+
|
87 |
+
|
88 |
+
|
89 |
+
|
90 |
+
def forward(self, image):
|
91 |
+
"""Forward pass
|
92 |
+
Inputs:
|
93 |
+
image: size = (B, 3, 256, 256)
|
94 |
+
Returns:
|
95 |
+
Regressed (subsampled) non-parametric shape: size = (B, 1723, 3)
|
96 |
+
Weak-perspective camera: size = (B, 3)
|
97 |
+
"""
|
98 |
+
# import pdb; pdb.set_trace()
|
99 |
+
|
100 |
+
batch_size = image.shape[0]
|
101 |
+
# ref_vertices = (self.mesh.get_ref_vertices(n=self.num_downsample).t())[None, :, :].expand(batch_size, -1, -1) # (bs, 3, 973)
|
102 |
+
ref_vertices = (self.mesh.ref_vertices.t())[None, :, :].expand(batch_size, -1, -1) # (bs, 3, 973)
|
103 |
+
image_resnet = self.resnet(image) # (bs, 512)
|
104 |
+
image_enc = image_resnet.view(batch_size, -1, 1).expand(-1, -1, ref_vertices.shape[-1]) # (bs, 512, 973)
|
105 |
+
|
106 |
+
# prepare network input
|
107 |
+
# -> for each node we feed the location of the vertex in the template mesh and an image encoding
|
108 |
+
x = torch.cat([ref_vertices, image_enc], dim=1)
|
109 |
+
x = self.lin1(x)
|
110 |
+
x = self.res1(x)
|
111 |
+
x_ = [x]
|
112 |
+
output_list = []
|
113 |
+
for i in range(self.num_layers + 1):
|
114 |
+
if i == self.num_layers:
|
115 |
+
x = self.encoder[i](x)
|
116 |
+
else:
|
117 |
+
x = self.encoder[i](x)
|
118 |
+
x = self.mesh.downsample(x.transpose(1, 2), n1=self.num_downsample+i, n2=self.num_downsample+i+1)
|
119 |
+
x = x.transpose(1, 2)
|
120 |
+
if i < self.num_layers-1:
|
121 |
+
x_.append(x)
|
122 |
+
for i in range(self.num_layers + 1):
|
123 |
+
if i == self.num_layers:
|
124 |
+
x = self.decoder[i](x)
|
125 |
+
output_list.append(x)
|
126 |
+
else:
|
127 |
+
x = self.decoder[i](x)
|
128 |
+
output_list.append(x)
|
129 |
+
x = self.mesh.upsample(x.transpose(1, 2), n1=self.num_layers-i+self.num_downsample,
|
130 |
+
n2=self.num_layers-i-1+self.num_downsample)
|
131 |
+
x = x.transpose(1, 2)
|
132 |
+
x = torch.cat([x, x_[self.num_layers-i-1]], dim=1) # skip connection between encoder and decoder
|
133 |
+
|
134 |
+
ground_contact = self.gc(x)
|
135 |
+
|
136 |
+
'''
|
137 |
+
ground_flatness = self.flat_ground(x).view(batch_size, self.n_out_flatground) # (bs, 1)
|
138 |
+
'''
|
139 |
+
|
140 |
+
return ground_contact, output_list # , ground_flatness
|
141 |
+
|
142 |
+
|
143 |
+
|
144 |
+
|
145 |
+
|
146 |
+
|
147 |
+
|
148 |
+
# how to use it:
|
149 |
+
#
|
150 |
+
# from src.graph_networks.graphcmr.utils_mesh import Mesh
|
151 |
+
#
|
152 |
+
# create Mesh object
|
153 |
+
# self.mesh = Mesh()
|
154 |
+
# self.faces = self.mesh.faces.to(self.device)
|
155 |
+
#
|
156 |
+
# create GraphCNN
|
157 |
+
# self.graph_cnn = GraphCNN(self.mesh.adjmat,
|
158 |
+
# self.mesh.ref_vertices.t(),
|
159 |
+
# num_channels=self.options.num_channels,
|
160 |
+
# num_layers=self.options.num_layers
|
161 |
+
# ).to(self.device)
|
162 |
+
# ------------
|
163 |
+
#
|
164 |
+
# Feed image in the GraphCNN
|
165 |
+
# Returns subsampled mesh and camera parameters
|
166 |
+
# pred_vertices_sub, pred_camera = self.graph_cnn(images)
|
167 |
+
#
|
168 |
+
# Upsample mesh in the original size
|
169 |
+
# pred_vertices = self.mesh.upsample(pred_vertices_sub.transpose(1,2))
|
170 |
+
#
|
src/graph_networks/graphcmr/graph_layers.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
code from https://github.com/nkolot/GraphCMR/blob/master/models/graph_layers.py
|
3 |
+
This file contains definitions of layers used to build the GraphCNN
|
4 |
+
"""
|
5 |
+
from __future__ import division
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import math
|
11 |
+
|
12 |
+
class GraphConvolution(nn.Module):
|
13 |
+
"""Simple GCN layer, similar to https://arxiv.org/abs/1609.02907."""
|
14 |
+
def __init__(self, in_features, out_features, adjmat, bias=True):
|
15 |
+
super(GraphConvolution, self).__init__()
|
16 |
+
self.in_features = in_features
|
17 |
+
self.out_features = out_features
|
18 |
+
self.adjmat = adjmat
|
19 |
+
self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
|
20 |
+
if bias:
|
21 |
+
self.bias = nn.Parameter(torch.FloatTensor(out_features))
|
22 |
+
else:
|
23 |
+
self.register_parameter('bias', None)
|
24 |
+
self.reset_parameters()
|
25 |
+
|
26 |
+
def reset_parameters(self):
|
27 |
+
# stdv = 1. / math.sqrt(self.weight.size(1))
|
28 |
+
stdv = 6. / math.sqrt(self.weight.size(0) + self.weight.size(1))
|
29 |
+
self.weight.data.uniform_(-stdv, stdv)
|
30 |
+
if self.bias is not None:
|
31 |
+
self.bias.data.uniform_(-stdv, stdv)
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
if x.ndimension() == 2:
|
35 |
+
support = torch.matmul(x, self.weight)
|
36 |
+
output = torch.matmul(self.adjmat, support)
|
37 |
+
if self.bias is not None:
|
38 |
+
output = output + self.bias
|
39 |
+
return output
|
40 |
+
else:
|
41 |
+
output = []
|
42 |
+
for i in range(x.shape[0]):
|
43 |
+
support = torch.matmul(x[i], self.weight)
|
44 |
+
# output.append(torch.matmul(self.adjmat, support))
|
45 |
+
output.append(spmm(self.adjmat, support))
|
46 |
+
output = torch.stack(output, dim=0)
|
47 |
+
if self.bias is not None:
|
48 |
+
output = output + self.bias
|
49 |
+
return output
|
50 |
+
|
51 |
+
def __repr__(self):
|
52 |
+
return self.__class__.__name__ + ' (' \
|
53 |
+
+ str(self.in_features) + ' -> ' \
|
54 |
+
+ str(self.out_features) + ')'
|
55 |
+
|
56 |
+
class GraphLinear(nn.Module):
|
57 |
+
"""
|
58 |
+
Generalization of 1x1 convolutions on Graphs
|
59 |
+
"""
|
60 |
+
def __init__(self, in_channels, out_channels):
|
61 |
+
super(GraphLinear, self).__init__()
|
62 |
+
self.in_channels = in_channels
|
63 |
+
self.out_channels = out_channels
|
64 |
+
self.W = nn.Parameter(torch.FloatTensor(out_channels, in_channels))
|
65 |
+
self.b = nn.Parameter(torch.FloatTensor(out_channels))
|
66 |
+
self.reset_parameters()
|
67 |
+
|
68 |
+
def reset_parameters(self):
|
69 |
+
w_stdv = 1 / (self.in_channels * self.out_channels)
|
70 |
+
self.W.data.uniform_(-w_stdv, w_stdv)
|
71 |
+
self.b.data.uniform_(-w_stdv, w_stdv)
|
72 |
+
|
73 |
+
def forward(self, x):
|
74 |
+
return torch.matmul(self.W[None, :], x) + self.b[None, :, None]
|
75 |
+
|
76 |
+
class GraphResBlock(nn.Module):
|
77 |
+
"""
|
78 |
+
Graph Residual Block similar to the Bottleneck Residual Block in ResNet
|
79 |
+
"""
|
80 |
+
|
81 |
+
def __init__(self, in_channels, out_channels, A):
|
82 |
+
super(GraphResBlock, self).__init__()
|
83 |
+
self.in_channels = in_channels
|
84 |
+
self.out_channels = out_channels
|
85 |
+
self.lin1 = GraphLinear(in_channels, out_channels // 2)
|
86 |
+
self.conv = GraphConvolution(out_channels // 2, out_channels // 2, A)
|
87 |
+
self.lin2 = GraphLinear(out_channels // 2, out_channels)
|
88 |
+
self.skip_conv = GraphLinear(in_channels, out_channels)
|
89 |
+
self.pre_norm = nn.GroupNorm(in_channels // 8, in_channels)
|
90 |
+
self.norm1 = nn.GroupNorm((out_channels // 2) // 8, (out_channels // 2))
|
91 |
+
self.norm2 = nn.GroupNorm((out_channels // 2) // 8, (out_channels // 2))
|
92 |
+
|
93 |
+
def forward(self, x):
|
94 |
+
y = F.relu(self.pre_norm(x))
|
95 |
+
y = self.lin1(y)
|
96 |
+
|
97 |
+
y = F.relu(self.norm1(y))
|
98 |
+
y = self.conv(y.transpose(1,2)).transpose(1,2)
|
99 |
+
|
100 |
+
y = F.relu(self.norm2(y))
|
101 |
+
y = self.lin2(y)
|
102 |
+
if self.in_channels != self.out_channels:
|
103 |
+
x = self.skip_conv(x)
|
104 |
+
return x+y
|
105 |
+
|
106 |
+
class SparseMM(torch.autograd.Function):
|
107 |
+
"""Redefine sparse @ dense matrix multiplication to enable backpropagation.
|
108 |
+
The builtin matrix multiplication operation does not support backpropagation in some cases.
|
109 |
+
"""
|
110 |
+
@staticmethod
|
111 |
+
def forward(ctx, sparse, dense):
|
112 |
+
ctx.req_grad = dense.requires_grad
|
113 |
+
ctx.save_for_backward(sparse)
|
114 |
+
return torch.matmul(sparse, dense)
|
115 |
+
|
116 |
+
@staticmethod
|
117 |
+
def backward(ctx, grad_output):
|
118 |
+
grad_input = None
|
119 |
+
sparse, = ctx.saved_tensors
|
120 |
+
if ctx.req_grad:
|
121 |
+
grad_input = torch.matmul(sparse.t(), grad_output)
|
122 |
+
return None, grad_input
|
123 |
+
|
124 |
+
def spmm(sparse, dense):
|
125 |
+
return SparseMM.apply(sparse, dense)
|
src/graph_networks/graphcmr/graphcnn_coarse_to_fine_animal_pose.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
"""
|
3 |
+
code from: https://github.com/chaneyddtt/Coarse-to-fine-3D-Animal/blob/main/model/graph_hg.py
|
4 |
+
This file contains the Definition of GraphCNN
|
5 |
+
GraphCNN includes ResNet50 as a submodule
|
6 |
+
"""
|
7 |
+
from __future__ import division
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
from model.networks.graph_layers import GraphResBlock, GraphLinear
|
13 |
+
from smal.mesh import Mesh
|
14 |
+
from smal.smal_torch import SMAL
|
15 |
+
|
16 |
+
# encoder-decoder structured GCN with skip connections
|
17 |
+
class GraphCNN_hg(nn.Module):
|
18 |
+
|
19 |
+
def __init__(self, mesh, num_channels=256, local_feat=False, num_downsample=0):
|
20 |
+
'''
|
21 |
+
Args:
|
22 |
+
mesh: mesh data that store the adjacency matrix
|
23 |
+
num_channels: number of channels of GCN
|
24 |
+
local_feat: whether use local feature for refinement
|
25 |
+
num_downsample: number of downsampling of the input mesh
|
26 |
+
'''
|
27 |
+
super(GraphCNN_hg, self).__init__()
|
28 |
+
self.A = mesh._A[num_downsample:] # get the correct adjacency matrix because the input might be downsampled
|
29 |
+
self.num_layers = len(self.A) - 1
|
30 |
+
print("Number of downsampling layer: {}".format(self.num_layers))
|
31 |
+
self.num_downsample = num_downsample
|
32 |
+
if local_feat:
|
33 |
+
self.lin1 = GraphLinear(3 + 2048 + 3840, 2 * num_channels)
|
34 |
+
else:
|
35 |
+
self.lin1 = GraphLinear(3 + 2048, 2 * num_channels)
|
36 |
+
self.res1 = GraphResBlock(2 * num_channels, num_channels, self.A[0])
|
37 |
+
encode_layers = []
|
38 |
+
decode_layers = []
|
39 |
+
|
40 |
+
for i in range(len(self.A)):
|
41 |
+
encode_layers.append(GraphResBlock(num_channels, num_channels, self.A[i]))
|
42 |
+
|
43 |
+
decode_layers.append(GraphResBlock((i+1)*num_channels, (i+1)*num_channels,
|
44 |
+
self.A[len(self.A) - i - 1]))
|
45 |
+
current_channels = (i+1)*num_channels
|
46 |
+
# number of channels for the input is different because of the concatenation operation
|
47 |
+
self.shape = nn.Sequential(GraphResBlock(current_channels, 64, self.A[0]),
|
48 |
+
GraphResBlock(64, 32, self.A[0]),
|
49 |
+
nn.GroupNorm(32 // 8, 32),
|
50 |
+
nn.ReLU(inplace=True),
|
51 |
+
GraphLinear(32, 3))
|
52 |
+
|
53 |
+
self.encoder = nn.Sequential(*encode_layers)
|
54 |
+
self.decoder = nn.Sequential(*decode_layers)
|
55 |
+
self.mesh = mesh
|
56 |
+
|
57 |
+
def forward(self, verts_c, img_fea_global, img_fea_multiscale=None, points_local=None):
|
58 |
+
'''
|
59 |
+
Args:
|
60 |
+
verts_c: vertices from the coarse estimation
|
61 |
+
img_fea_global: global feature for mesh refinement
|
62 |
+
img_fea_multiscale: multi-scale feature from the encoder, used for local feature extraction
|
63 |
+
points_local: 2D keypoint for local feature extraction
|
64 |
+
Returns: refined mesh
|
65 |
+
'''
|
66 |
+
batch_size = img_fea_global.shape[0]
|
67 |
+
ref_vertices = verts_c.transpose(1, 2)
|
68 |
+
image_enc = img_fea_global.view(batch_size, 2048, 1).expand(-1, -1, ref_vertices.shape[-1])
|
69 |
+
if points_local is not None:
|
70 |
+
feat_local = torch.nn.functional.grid_sample(img_fea_multiscale, points_local)
|
71 |
+
x = torch.cat([ref_vertices, image_enc, feat_local.squeeze(2)], dim=1)
|
72 |
+
else:
|
73 |
+
x = torch.cat([ref_vertices, image_enc], dim=1)
|
74 |
+
x = self.lin1(x)
|
75 |
+
x = self.res1(x)
|
76 |
+
x_ = [x]
|
77 |
+
for i in range(self.num_layers + 1):
|
78 |
+
if i == self.num_layers:
|
79 |
+
x = self.encoder[i](x)
|
80 |
+
else:
|
81 |
+
x = self.encoder[i](x)
|
82 |
+
x = self.mesh.downsample(x.transpose(1, 2), n1=self.num_downsample+i, n2=self.num_downsample+i+1)
|
83 |
+
x = x.transpose(1, 2)
|
84 |
+
if i < self.num_layers-1:
|
85 |
+
x_.append(x)
|
86 |
+
for i in range(self.num_layers + 1):
|
87 |
+
if i == self.num_layers:
|
88 |
+
x = self.decoder[i](x)
|
89 |
+
else:
|
90 |
+
x = self.decoder[i](x)
|
91 |
+
x = self.mesh.upsample(x.transpose(1, 2), n1=self.num_layers-i+self.num_downsample,
|
92 |
+
n2=self.num_layers-i-1+self.num_downsample)
|
93 |
+
x = x.transpose(1, 2)
|
94 |
+
x = torch.cat([x, x_[self.num_layers-i-1]], dim=1) # skip connection between encoder and decoder
|
95 |
+
|
96 |
+
shape = self.shape(x)
|
97 |
+
return shape
|
src/graph_networks/graphcmr/my_remarks.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
this folder contains code from https://github.com/nkolot/GraphCMR/tree/master/models
|
3 |
+
|
4 |
+
|
5 |
+
other (newer) networks operating on meshes such as SMAL would be:
|
6 |
+
https://github.com/microsoft/MeshTransformer
|
7 |
+
https://github.com/microsoft/MeshGraphormer
|
8 |
+
|
9 |
+
see also:
|
10 |
+
https://arxiv.org/pdf/2112.01554.pdf, page 13
|
11 |
+
(Neural Head Avatars from Monocular RGB Videos)
|
src/graph_networks/graphcmr/pytorch_coma_mesh_operations.py
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# code from https://github.com/pixelite1201/pytorch_coma/blob/master/mesh_operations.py
|
2 |
+
|
3 |
+
import math
|
4 |
+
import heapq
|
5 |
+
import numpy as np
|
6 |
+
import scipy.sparse as sp
|
7 |
+
from psbody.mesh import Mesh
|
8 |
+
|
9 |
+
def row(A):
|
10 |
+
return A.reshape((1, -1))
|
11 |
+
|
12 |
+
def col(A):
|
13 |
+
return A.reshape((-1, 1))
|
14 |
+
|
15 |
+
def get_vert_connectivity(mesh_v, mesh_f):
|
16 |
+
"""Returns a sparse matrix (of size #verts x #verts) where each nonzero
|
17 |
+
element indicates a neighborhood relation. For example, if there is a
|
18 |
+
nonzero element in position (15,12), that means vertex 15 is connected
|
19 |
+
by an edge to vertex 12."""
|
20 |
+
|
21 |
+
vpv = sp.csc_matrix((len(mesh_v),len(mesh_v)))
|
22 |
+
|
23 |
+
# for each column in the faces...
|
24 |
+
for i in range(3):
|
25 |
+
IS = mesh_f[:,i]
|
26 |
+
JS = mesh_f[:,(i+1)%3]
|
27 |
+
data = np.ones(len(IS))
|
28 |
+
ij = np.vstack((row(IS.flatten()), row(JS.flatten())))
|
29 |
+
mtx = sp.csc_matrix((data, ij), shape=vpv.shape)
|
30 |
+
vpv = vpv + mtx + mtx.T
|
31 |
+
|
32 |
+
return vpv
|
33 |
+
|
34 |
+
def get_vertices_per_edge(mesh_v, mesh_f):
|
35 |
+
"""Returns an Ex2 array of adjacencies between vertices, where
|
36 |
+
each element in the array is a vertex index. Each edge is included
|
37 |
+
only once. If output of get_faces_per_edge is provided, this is used to
|
38 |
+
avoid call to get_vert_connectivity()"""
|
39 |
+
|
40 |
+
vc = sp.coo_matrix(get_vert_connectivity(mesh_v, mesh_f))
|
41 |
+
result = np.hstack((col(vc.row), col(vc.col)))
|
42 |
+
result = result[result[:,0] < result[:,1]] # for uniqueness
|
43 |
+
|
44 |
+
return result
|
45 |
+
|
46 |
+
|
47 |
+
def vertex_quadrics(mesh):
|
48 |
+
"""Computes a quadric for each vertex in the Mesh.
|
49 |
+
see also:
|
50 |
+
https://www.cs.cmu.edu/~./garland/Papers/quadrics.pdf
|
51 |
+
https://users.csc.calpoly.edu/~zwood/teaching/csc570/final06/jseeba/
|
52 |
+
Returns:
|
53 |
+
v_quadrics: an (N x 4 x 4) array, where N is # vertices.
|
54 |
+
"""
|
55 |
+
|
56 |
+
# Allocate quadrics
|
57 |
+
v_quadrics = np.zeros((len(mesh.v), 4, 4,))
|
58 |
+
|
59 |
+
# For each face...
|
60 |
+
for f_idx in range(len(mesh.f)):
|
61 |
+
|
62 |
+
# Compute normalized plane equation for that face
|
63 |
+
vert_idxs = mesh.f[f_idx]
|
64 |
+
verts = np.hstack((mesh.v[vert_idxs], np.array([1, 1, 1]).reshape(-1, 1)))
|
65 |
+
u, s, v = np.linalg.svd(verts)
|
66 |
+
eq = v[-1, :].reshape(-1, 1)
|
67 |
+
eq = eq / (np.linalg.norm(eq[0:3]))
|
68 |
+
|
69 |
+
# Add the outer product of the plane equation to the
|
70 |
+
# quadrics of the vertices for this face
|
71 |
+
for k in range(3):
|
72 |
+
v_quadrics[mesh.f[f_idx, k], :, :] += np.outer(eq, eq)
|
73 |
+
|
74 |
+
return v_quadrics
|
75 |
+
|
76 |
+
def _get_sparse_transform(faces, num_original_verts):
|
77 |
+
verts_left = np.unique(faces.flatten())
|
78 |
+
IS = np.arange(len(verts_left))
|
79 |
+
JS = verts_left
|
80 |
+
data = np.ones(len(JS))
|
81 |
+
|
82 |
+
mp = np.arange(0, np.max(faces.flatten()) + 1)
|
83 |
+
mp[JS] = IS
|
84 |
+
new_faces = mp[faces.copy().flatten()].reshape((-1, 3))
|
85 |
+
|
86 |
+
ij = np.vstack((IS.flatten(), JS.flatten()))
|
87 |
+
mtx = sp.csc_matrix((data, ij), shape=(len(verts_left) , num_original_verts ))
|
88 |
+
|
89 |
+
return (new_faces, mtx)
|
90 |
+
|
91 |
+
def qslim_decimator_transformer(mesh, factor=None, n_verts_desired=None):
|
92 |
+
"""Return a simplified version of this mesh.
|
93 |
+
|
94 |
+
A Qslim-style approach is used here.
|
95 |
+
|
96 |
+
:param factor: fraction of the original vertices to retain
|
97 |
+
:param n_verts_desired: number of the original vertices to retain
|
98 |
+
:returns: new_faces: An Fx3 array of faces, mtx: Transformation matrix
|
99 |
+
"""
|
100 |
+
|
101 |
+
if factor is None and n_verts_desired is None:
|
102 |
+
raise Exception('Need either factor or n_verts_desired.')
|
103 |
+
|
104 |
+
if n_verts_desired is None:
|
105 |
+
n_verts_desired = math.ceil(len(mesh.v) * factor)
|
106 |
+
|
107 |
+
Qv = vertex_quadrics(mesh)
|
108 |
+
|
109 |
+
# fill out a sparse matrix indicating vertex-vertex adjacency
|
110 |
+
# from psbody.mesh.topology.connectivity import get_vertices_per_edge
|
111 |
+
vert_adj = get_vertices_per_edge(mesh.v, mesh.f)
|
112 |
+
# vert_adj = sp.lil_matrix((len(mesh.v), len(mesh.v)))
|
113 |
+
# for f_idx in range(len(mesh.f)):
|
114 |
+
# vert_adj[mesh.f[f_idx], mesh.f[f_idx]] = 1
|
115 |
+
|
116 |
+
vert_adj = sp.csc_matrix((vert_adj[:, 0] * 0 + 1, (vert_adj[:, 0], vert_adj[:, 1])), shape=(len(mesh.v), len(mesh.v)))
|
117 |
+
vert_adj = vert_adj + vert_adj.T
|
118 |
+
vert_adj = vert_adj.tocoo()
|
119 |
+
|
120 |
+
def collapse_cost(Qv, r, c, v):
|
121 |
+
Qsum = Qv[r, :, :] + Qv[c, :, :]
|
122 |
+
p1 = np.vstack((v[r].reshape(-1, 1), np.array([1]).reshape(-1, 1)))
|
123 |
+
p2 = np.vstack((v[c].reshape(-1, 1), np.array([1]).reshape(-1, 1)))
|
124 |
+
|
125 |
+
destroy_c_cost = p1.T.dot(Qsum).dot(p1)
|
126 |
+
destroy_r_cost = p2.T.dot(Qsum).dot(p2)
|
127 |
+
result = {
|
128 |
+
'destroy_c_cost': destroy_c_cost,
|
129 |
+
'destroy_r_cost': destroy_r_cost,
|
130 |
+
'collapse_cost': min([destroy_c_cost, destroy_r_cost]),
|
131 |
+
'Qsum': Qsum}
|
132 |
+
return result
|
133 |
+
|
134 |
+
# construct a queue of edges with costs
|
135 |
+
queue = []
|
136 |
+
for k in range(vert_adj.nnz):
|
137 |
+
r = vert_adj.row[k]
|
138 |
+
c = vert_adj.col[k]
|
139 |
+
|
140 |
+
if r > c:
|
141 |
+
continue
|
142 |
+
|
143 |
+
cost = collapse_cost(Qv, r, c, mesh.v)['collapse_cost']
|
144 |
+
heapq.heappush(queue, (cost, (r, c)))
|
145 |
+
|
146 |
+
# decimate
|
147 |
+
collapse_list = []
|
148 |
+
nverts_total = len(mesh.v)
|
149 |
+
faces = mesh.f.copy()
|
150 |
+
while nverts_total > n_verts_desired:
|
151 |
+
e = heapq.heappop(queue)
|
152 |
+
r = e[1][0]
|
153 |
+
c = e[1][1]
|
154 |
+
if r == c:
|
155 |
+
continue
|
156 |
+
|
157 |
+
cost = collapse_cost(Qv, r, c, mesh.v)
|
158 |
+
if cost['collapse_cost'] > e[0]:
|
159 |
+
heapq.heappush(queue, (cost['collapse_cost'], e[1]))
|
160 |
+
# print 'found outdated cost, %.2f < %.2f' % (e[0], cost['collapse_cost'])
|
161 |
+
continue
|
162 |
+
else:
|
163 |
+
|
164 |
+
# update old vert idxs to new one,
|
165 |
+
# in queue and in face list
|
166 |
+
if cost['destroy_c_cost'] < cost['destroy_r_cost']:
|
167 |
+
to_destroy = c
|
168 |
+
to_keep = r
|
169 |
+
else:
|
170 |
+
to_destroy = r
|
171 |
+
to_keep = c
|
172 |
+
|
173 |
+
collapse_list.append([to_keep, to_destroy])
|
174 |
+
|
175 |
+
# in our face array, replace "to_destroy" vertidx with "to_keep" vertidx
|
176 |
+
np.place(faces, faces == to_destroy, to_keep)
|
177 |
+
|
178 |
+
# same for queue
|
179 |
+
which1 = [idx for idx in range(len(queue)) if queue[idx][1][0] == to_destroy]
|
180 |
+
which2 = [idx for idx in range(len(queue)) if queue[idx][1][1] == to_destroy]
|
181 |
+
for k in which1:
|
182 |
+
queue[k] = (queue[k][0], (to_keep, queue[k][1][1]))
|
183 |
+
for k in which2:
|
184 |
+
queue[k] = (queue[k][0], (queue[k][1][0], to_keep))
|
185 |
+
|
186 |
+
Qv[r, :, :] = cost['Qsum']
|
187 |
+
Qv[c, :, :] = cost['Qsum']
|
188 |
+
|
189 |
+
a = faces[:, 0] == faces[:, 1]
|
190 |
+
b = faces[:, 1] == faces[:, 2]
|
191 |
+
c = faces[:, 2] == faces[:, 0]
|
192 |
+
|
193 |
+
# remove degenerate faces
|
194 |
+
def logical_or3(x, y, z):
|
195 |
+
return np.logical_or(x, np.logical_or(y, z))
|
196 |
+
|
197 |
+
faces_to_keep = np.logical_not(logical_or3(a, b, c))
|
198 |
+
faces = faces[faces_to_keep, :].copy()
|
199 |
+
|
200 |
+
nverts_total = (len(np.unique(faces.flatten())))
|
201 |
+
|
202 |
+
new_faces, mtx = _get_sparse_transform(faces, len(mesh.v))
|
203 |
+
return new_faces, mtx
|
204 |
+
|
205 |
+
|
206 |
+
def setup_deformation_transfer(source, target, use_normals=False):
|
207 |
+
rows = np.zeros(3 * target.v.shape[0])
|
208 |
+
cols = np.zeros(3 * target.v.shape[0])
|
209 |
+
coeffs_v = np.zeros(3 * target.v.shape[0])
|
210 |
+
coeffs_n = np.zeros(3 * target.v.shape[0])
|
211 |
+
|
212 |
+
nearest_faces, nearest_parts, nearest_vertices = source.compute_aabb_tree().nearest(target.v, True)
|
213 |
+
nearest_faces = nearest_faces.ravel().astype(np.int64)
|
214 |
+
nearest_parts = nearest_parts.ravel().astype(np.int64)
|
215 |
+
nearest_vertices = nearest_vertices.ravel()
|
216 |
+
|
217 |
+
for i in range(target.v.shape[0]):
|
218 |
+
# Closest triangle index
|
219 |
+
f_id = nearest_faces[i]
|
220 |
+
# Closest triangle vertex ids
|
221 |
+
nearest_f = source.f[f_id]
|
222 |
+
|
223 |
+
# Closest surface point
|
224 |
+
nearest_v = nearest_vertices[3 * i:3 * i + 3]
|
225 |
+
# Distance vector to the closest surface point
|
226 |
+
dist_vec = target.v[i] - nearest_v
|
227 |
+
|
228 |
+
rows[3 * i:3 * i + 3] = i * np.ones(3)
|
229 |
+
cols[3 * i:3 * i + 3] = nearest_f
|
230 |
+
|
231 |
+
n_id = nearest_parts[i]
|
232 |
+
if n_id == 0:
|
233 |
+
# Closest surface point in triangle
|
234 |
+
A = np.vstack((source.v[nearest_f])).T
|
235 |
+
coeffs_v[3 * i:3 * i + 3] = np.linalg.lstsq(A, nearest_v)[0]
|
236 |
+
elif n_id > 0 and n_id <= 3:
|
237 |
+
# Closest surface point on edge
|
238 |
+
A = np.vstack((source.v[nearest_f[n_id - 1]], source.v[nearest_f[n_id % 3]])).T
|
239 |
+
tmp_coeffs = np.linalg.lstsq(A, target.v[i])[0]
|
240 |
+
coeffs_v[3 * i + n_id - 1] = tmp_coeffs[0]
|
241 |
+
coeffs_v[3 * i + n_id % 3] = tmp_coeffs[1]
|
242 |
+
else:
|
243 |
+
# Closest surface point a vertex
|
244 |
+
coeffs_v[3 * i + n_id - 4] = 1.0
|
245 |
+
|
246 |
+
# if use_normals:
|
247 |
+
# A = np.vstack((vn[nearest_f])).T
|
248 |
+
# coeffs_n[3 * i:3 * i + 3] = np.linalg.lstsq(A, dist_vec)[0]
|
249 |
+
|
250 |
+
#coeffs = np.hstack((coeffs_v, coeffs_n))
|
251 |
+
#rows = np.hstack((rows, rows))
|
252 |
+
#cols = np.hstack((cols, source.v.shape[0] + cols))
|
253 |
+
matrix = sp.csc_matrix((coeffs_v, (rows, cols)), shape=(target.v.shape[0], source.v.shape[0]))
|
254 |
+
return matrix
|
255 |
+
|
256 |
+
|
257 |
+
def generate_transform_matrices(mesh, factors):
|
258 |
+
"""Generates len(factors) meshes, each of them is scaled by factors[i] and
|
259 |
+
computes the transformations between them.
|
260 |
+
|
261 |
+
Returns:
|
262 |
+
M: a set of meshes downsampled from mesh by a factor specified in factors.
|
263 |
+
A: Adjacency matrix for each of the meshes
|
264 |
+
D: Downsampling transforms between each of the meshes
|
265 |
+
U: Upsampling transforms between each of the meshes
|
266 |
+
"""
|
267 |
+
|
268 |
+
factors = map(lambda x: 1.0 / x, factors)
|
269 |
+
M, A, D, U = [], [], [], []
|
270 |
+
A.append(get_vert_connectivity(mesh.v, mesh.f).tocoo())
|
271 |
+
M.append(mesh)
|
272 |
+
|
273 |
+
for i,factor in enumerate(factors):
|
274 |
+
ds_f, ds_D = qslim_decimator_transformer(M[-1], factor=factor)
|
275 |
+
D.append(ds_D.tocoo())
|
276 |
+
new_mesh_v = ds_D.dot(M[-1].v)
|
277 |
+
new_mesh = Mesh(v=new_mesh_v, f=ds_f)
|
278 |
+
M.append(new_mesh)
|
279 |
+
A.append(get_vert_connectivity(new_mesh.v, new_mesh.f).tocoo())
|
280 |
+
U.append(setup_deformation_transfer(M[-1], M[-2]).tocoo())
|
281 |
+
|
282 |
+
return M, A, D, U
|
src/graph_networks/graphcmr/utils_mesh.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# code from https://github.com/nkolot/GraphCMR/blob/master/utils/mesh.py
|
2 |
+
|
3 |
+
from __future__ import division
|
4 |
+
import torch
|
5 |
+
import numpy as np
|
6 |
+
import scipy.sparse
|
7 |
+
|
8 |
+
# from models import SMPL
|
9 |
+
import os
|
10 |
+
import sys
|
11 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
|
12 |
+
from graph_networks.graphcmr.graph_layers import spmm
|
13 |
+
|
14 |
+
def scipy_to_pytorch(A, U, D):
|
15 |
+
"""Convert scipy sparse matrices to pytorch sparse matrix."""
|
16 |
+
ptU = []
|
17 |
+
ptD = []
|
18 |
+
|
19 |
+
for i in range(len(U)):
|
20 |
+
u = scipy.sparse.coo_matrix(U[i])
|
21 |
+
i = torch.LongTensor(np.array([u.row, u.col]))
|
22 |
+
v = torch.FloatTensor(u.data)
|
23 |
+
ptU.append(torch.sparse.FloatTensor(i, v, u.shape))
|
24 |
+
|
25 |
+
for i in range(len(D)):
|
26 |
+
d = scipy.sparse.coo_matrix(D[i])
|
27 |
+
i = torch.LongTensor(np.array([d.row, d.col]))
|
28 |
+
v = torch.FloatTensor(d.data)
|
29 |
+
ptD.append(torch.sparse.FloatTensor(i, v, d.shape))
|
30 |
+
|
31 |
+
return ptU, ptD
|
32 |
+
|
33 |
+
|
34 |
+
def adjmat_sparse(adjmat, nsize=1):
|
35 |
+
"""Create row-normalized sparse graph adjacency matrix."""
|
36 |
+
adjmat = scipy.sparse.csr_matrix(adjmat)
|
37 |
+
if nsize > 1:
|
38 |
+
orig_adjmat = adjmat.copy()
|
39 |
+
for _ in range(1, nsize):
|
40 |
+
adjmat = adjmat * orig_adjmat
|
41 |
+
adjmat.data = np.ones_like(adjmat.data)
|
42 |
+
for i in range(adjmat.shape[0]):
|
43 |
+
adjmat[i,i] = 1
|
44 |
+
num_neighbors = np.array(1 / adjmat.sum(axis=-1))
|
45 |
+
adjmat = adjmat.multiply(num_neighbors)
|
46 |
+
adjmat = scipy.sparse.coo_matrix(adjmat)
|
47 |
+
row = adjmat.row
|
48 |
+
col = adjmat.col
|
49 |
+
data = adjmat.data
|
50 |
+
i = torch.LongTensor(np.array([row, col]))
|
51 |
+
v = torch.from_numpy(data).float()
|
52 |
+
adjmat = torch.sparse.FloatTensor(i, v, adjmat.shape)
|
53 |
+
return adjmat
|
54 |
+
|
55 |
+
def get_graph_params(filename, nsize=1):
|
56 |
+
"""Load and process graph adjacency matrix and upsampling/downsampling matrices."""
|
57 |
+
data = np.load(filename, encoding='latin1', allow_pickle=True) # np.load(filename, encoding='latin1')
|
58 |
+
A = data['A']
|
59 |
+
U = data['U']
|
60 |
+
D = data['D']
|
61 |
+
U, D = scipy_to_pytorch(A, U, D)
|
62 |
+
A = [adjmat_sparse(a, nsize=nsize) for a in A]
|
63 |
+
return A, U, D
|
64 |
+
|
65 |
+
class Mesh(object):
|
66 |
+
"""Mesh object that is used for handling certain graph operations."""
|
67 |
+
def __init__(self, filename='data/mesh_downsampling.npz',
|
68 |
+
num_downsampling=1, nsize=1, body_model=None, device=torch.device('cuda')):
|
69 |
+
self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize)
|
70 |
+
self._A = [a.to(device) for a in self._A]
|
71 |
+
self._U = [u.to(device) for u in self._U]
|
72 |
+
self._D = [d.to(device) for d in self._D]
|
73 |
+
self.num_downsampling = num_downsampling
|
74 |
+
|
75 |
+
# load template vertices from SMPL and normalize them
|
76 |
+
if body_model is None:
|
77 |
+
smpl = SMPL()
|
78 |
+
else:
|
79 |
+
smpl = body_model
|
80 |
+
ref_vertices = smpl.v_template
|
81 |
+
center = 0.5*(ref_vertices.max(dim=0)[0] + ref_vertices.min(dim=0)[0])[None]
|
82 |
+
ref_vertices -= center
|
83 |
+
ref_vertices /= ref_vertices.abs().max().item()
|
84 |
+
|
85 |
+
self._ref_vertices = ref_vertices.to(device)
|
86 |
+
self.faces = smpl.faces.int().to(device)
|
87 |
+
|
88 |
+
@property
|
89 |
+
def adjmat(self):
|
90 |
+
"""Return the graph adjacency matrix at the specified subsampling level."""
|
91 |
+
return self._A[self.num_downsampling].float()
|
92 |
+
|
93 |
+
@property
|
94 |
+
def ref_vertices(self):
|
95 |
+
"""Return the template vertices at the specified subsampling level."""
|
96 |
+
ref_vertices = self._ref_vertices
|
97 |
+
for i in range(self.num_downsampling):
|
98 |
+
ref_vertices = torch.spmm(self._D[i], ref_vertices)
|
99 |
+
return ref_vertices
|
100 |
+
|
101 |
+
def get_ref_vertices(self, n_downsample):
|
102 |
+
"""Return the template vertices at any desired subsampling level."""
|
103 |
+
ref_vertices = self._ref_vertices
|
104 |
+
for i in range(n_downsample):
|
105 |
+
ref_vertices = torch.spmm(self._D[i], ref_vertices)
|
106 |
+
return ref_vertices
|
107 |
+
|
108 |
+
def downsample(self, x, n1=0, n2=None):
|
109 |
+
"""Downsample mesh."""
|
110 |
+
if n2 is None:
|
111 |
+
n2 = self.num_downsampling
|
112 |
+
if x.ndimension() < 3:
|
113 |
+
for i in range(n1, n2):
|
114 |
+
x = spmm(self._D[i], x)
|
115 |
+
elif x.ndimension() == 3:
|
116 |
+
out = []
|
117 |
+
for i in range(x.shape[0]):
|
118 |
+
y = x[i]
|
119 |
+
for j in range(n1, n2):
|
120 |
+
y = spmm(self._D[j], y)
|
121 |
+
out.append(y)
|
122 |
+
x = torch.stack(out, dim=0)
|
123 |
+
return x
|
124 |
+
|
125 |
+
def upsample(self, x, n1=1, n2=0):
|
126 |
+
"""Upsample mesh."""
|
127 |
+
if x.ndimension() < 3:
|
128 |
+
for i in reversed(range(n2, n1)):
|
129 |
+
x = spmm(self._U[i], x)
|
130 |
+
elif x.ndimension() == 3:
|
131 |
+
out = []
|
132 |
+
for i in range(x.shape[0]):
|
133 |
+
y = x[i]
|
134 |
+
for j in reversed(range(n2, n1)):
|
135 |
+
y = spmm(self._U[j], y)
|
136 |
+
out.append(y)
|
137 |
+
x = torch.stack(out, dim=0)
|
138 |
+
return x
|
src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
"""
|
3 |
+
code adapted from: https://github.com/mikedh/trimesh/blob/main/examples/shortest.py
|
4 |
+
shortest.py
|
5 |
+
----------------
|
6 |
+
Given a mesh and two vertex indices find the shortest path
|
7 |
+
between the two vertices while only traveling along edges
|
8 |
+
of the mesh.
|
9 |
+
"""
|
10 |
+
|
11 |
+
# python src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh.py
|
12 |
+
|
13 |
+
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
import glob
|
17 |
+
import csv
|
18 |
+
import json
|
19 |
+
import shutil
|
20 |
+
import tqdm
|
21 |
+
import numpy as np
|
22 |
+
import pickle as pkl
|
23 |
+
import trimesh
|
24 |
+
import networkx as nx
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
def read_csv(csv_file):
|
31 |
+
with open(csv_file,'r') as f:
|
32 |
+
reader = csv.reader(f)
|
33 |
+
headers = next(reader)
|
34 |
+
row_list = [{h:x for (h,x) in zip(headers,row)} for row in reader]
|
35 |
+
return row_list
|
36 |
+
|
37 |
+
|
38 |
+
def load_all_template_mesh_distances(root_out_path, filename='all_vertex_distances.npy'):
|
39 |
+
vert_dists = np.load(root_out_path + filename)
|
40 |
+
return vert_dists
|
41 |
+
|
42 |
+
|
43 |
+
def prepare_graph_from_template_mesh_and_calculate_all_distances(path_mesh, root_out_path, calc_dist_mat=False):
|
44 |
+
# root_out_path = ROOT_OUT_PATH
|
45 |
+
'''
|
46 |
+
from smal_pytorch.smal_model.smal_torch_new import SMAL
|
47 |
+
smal = SMAL()
|
48 |
+
verts = smal.v_template.detach().cpu().numpy()
|
49 |
+
faces = smal.faces.detach().cpu().numpy()
|
50 |
+
'''
|
51 |
+
# path_mesh = ROOT_PATH_MESH + 'mesh_downsampling_meshesmy_smpl_39dogsnorm_Jr_4_dog_template_downsampled0.obj'
|
52 |
+
my_mesh = trimesh.load_mesh(path_mesh, process=False, maintain_order=True)
|
53 |
+
verts = my_mesh.vertices
|
54 |
+
faces = my_mesh.faces
|
55 |
+
# edges without duplication
|
56 |
+
edges = my_mesh.edges_unique
|
57 |
+
# the actual length of each unique edge
|
58 |
+
length = my_mesh.edges_unique_length
|
59 |
+
# create the graph with edge attributes for length (option A)
|
60 |
+
# g = nx.Graph()
|
61 |
+
# for edge, L in zip(edges, length): g.add_edge(*edge, length=L)
|
62 |
+
# you can create the graph with from_edgelist and
|
63 |
+
# a list comprehension (option B)
|
64 |
+
ga = nx.from_edgelist([(e[0], e[1], {'length': L}) for e, L in zip(edges, length)])
|
65 |
+
# calculate the distances between all vertex pairs
|
66 |
+
if calc_dist_mat:
|
67 |
+
# calculate distances between all possible vertex pairs
|
68 |
+
# shortest_path = nx.shortest_path(ga, source=ind_v0, target=ind_v1, weight='length')
|
69 |
+
# shortest_dist = nx.shortest_path_length(ga, source=ind_v0, target=ind_v1, weight='length')
|
70 |
+
dis = dict(nx.shortest_path_length(ga, weight='length', method='dijkstra'))
|
71 |
+
vertex_distances = np.zeros((n_verts_smal, n_verts_smal))
|
72 |
+
for ind_v0 in range(n_verts_smal):
|
73 |
+
print(ind_v0)
|
74 |
+
for ind_v1 in range(ind_v0, n_verts_smal):
|
75 |
+
vertex_distances[ind_v0, ind_v1] = dis[ind_v0][ind_v1]
|
76 |
+
vertex_distances[ind_v1, ind_v0] = dis[ind_v0][ind_v1]
|
77 |
+
# save those distances
|
78 |
+
np.save(root_out_path + 'all_vertex_distances.npy', vertex_distances)
|
79 |
+
vert_dists = vertex_distances
|
80 |
+
else:
|
81 |
+
vert_dists = np.load(root_out_path + 'all_vertex_distances.npy')
|
82 |
+
return ga, vert_dists
|
83 |
+
|
84 |
+
|
85 |
+
def calculate_vertex_overview_for_gc_annotation(name, gc_info_raw, vert_dists, root_out_path_vis=None, verts=None, faces=None, img_v12_dir=None):
|
86 |
+
# input:
|
87 |
+
# root_out_path_vis = ROOT_OUT_PATH
|
88 |
+
# img_v12_dir = IMG_V12_DIR
|
89 |
+
# name = images_with_gc_labelled[ind_img]
|
90 |
+
# gc_info_raw = gc_dict['bite/' + name]
|
91 |
+
# output:
|
92 |
+
# vertex_overview: np array of shape (n_verts_smal, 3) with [first: no-contact=0 contact=1 second: index of vertex third: dist]
|
93 |
+
n_verts_smal = 3889
|
94 |
+
gc_vertices = []
|
95 |
+
gc_info_np = np.zeros((n_verts_smal))
|
96 |
+
for ind_v in gc_info_raw:
|
97 |
+
if ind_v < n_verts_smal:
|
98 |
+
gc_vertices.append(ind_v)
|
99 |
+
gc_info_np[ind_v] = 1
|
100 |
+
# save a visualization of those annotations
|
101 |
+
if root_out_path_vis is not None:
|
102 |
+
my_mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False, maintain_order=True)
|
103 |
+
if img_v12_dir is not None and root_out_path_vis is not None:
|
104 |
+
vert_colors = np.repeat(255*gc_info_np[:, None], 3, 1)
|
105 |
+
my_mesh.visual.vertex_colors = vert_colors
|
106 |
+
my_mesh.export(root_out_path_vis + (name).replace('.jpg', '_withgc.obj'))
|
107 |
+
img_path = img_v12_dir + name
|
108 |
+
shutil.copy(img_path, root_out_path_vis + name)
|
109 |
+
# calculate for each vertex the distance to the closest element of the other group
|
110 |
+
non_gc_vertices = list(set(range(n_verts_smal)) - set(gc_vertices))
|
111 |
+
print('vertices in contact: ' + str(len(gc_vertices)))
|
112 |
+
print('vertices without contact: ' + str(len(non_gc_vertices)))
|
113 |
+
vertex_overview = np.zeros((n_verts_smal, 3)) # first: no-contact=0 contact=1 second: index of vertex third: dist
|
114 |
+
vertex_overview[:, 0] = gc_info_np
|
115 |
+
# loop through all contact vertices
|
116 |
+
for ind_v in gc_vertices:
|
117 |
+
min_length = 100
|
118 |
+
for ind_v_ps in non_gc_vertices: # possible solution
|
119 |
+
# this_path = nx.shortest_path(ga, source=ind_v, target=ind_v_ps, weight='length')
|
120 |
+
# this_length = nx.shortest_path_length(ga, source=ind_v, target=ind_v_ps, weight='length')
|
121 |
+
this_length = vert_dists[ind_v, ind_v_ps]
|
122 |
+
if this_length < min_length:
|
123 |
+
min_length = this_length
|
124 |
+
vertex_overview[ind_v, 1] = ind_v_ps
|
125 |
+
vertex_overview[ind_v, 2] = this_length
|
126 |
+
# loop through all non-contact vertices
|
127 |
+
for ind_v in non_gc_vertices:
|
128 |
+
min_length = 100
|
129 |
+
for ind_v_ps in gc_vertices: # possible solution
|
130 |
+
# this_path = nx.shortest_path(ga, source=ind_v, target=ind_v_ps, weight='length')
|
131 |
+
# this_length = nx.shortest_path_length(ga, source=ind_v, target=ind_v_ps, weight='length')
|
132 |
+
this_length = vert_dists[ind_v, ind_v_ps]
|
133 |
+
if this_length < min_length:
|
134 |
+
min_length = this_length
|
135 |
+
vertex_overview[ind_v, 1] = ind_v_ps
|
136 |
+
vertex_overview[ind_v, 2] = this_length
|
137 |
+
if root_out_path_vis is not None:
|
138 |
+
# save a colored mesh
|
139 |
+
my_mesh_dists = my_mesh.copy()
|
140 |
+
scale_0 = (vertex_overview[vertex_overview[:, 0]==0, 2]).max()
|
141 |
+
scale_1 = (vertex_overview[vertex_overview[:, 0]==1, 2]).max()
|
142 |
+
vert_col = np.zeros((n_verts_smal, 3))
|
143 |
+
vert_col[vertex_overview[:, 0]==0, 1] = vertex_overview[vertex_overview[:, 0]==0, 2] * 255 / scale_0 # green
|
144 |
+
vert_col[vertex_overview[:, 0]==1, 0] = vertex_overview[vertex_overview[:, 0]==1, 2] * 255 / scale_1 # red
|
145 |
+
my_mesh_dists.visual.vertex_colors = np.uint8(vert_col)
|
146 |
+
my_mesh_dists.export(root_out_path_vis + (name).replace('.jpg', '_withgcdists.obj'))
|
147 |
+
return vertex_overview
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
|
152 |
+
|
153 |
+
|
154 |
+
|
155 |
+
|
156 |
+
def main():
|
157 |
+
|
158 |
+
ROOT_PATH_MESH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/graphcmr/data/meshes/'
|
159 |
+
ROOT_PATH_ANNOT = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage3/'
|
160 |
+
IMG_V12_DIR = '/ps/scratch/nrueegg/new_projects/Animals/data/dog_datasets/Stanford_Dogs_Dataset/StanfordExtra_V12/StanExtV12_Images/'
|
161 |
+
# ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/losses_for_vertex_wise_predictions/debugging_results/'
|
162 |
+
ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage3/'
|
163 |
+
ROOT_OUT_PATH_VIS = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage3/vis/'
|
164 |
+
ROOT_OUT_PATH_DISTSGCNONGC = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage3/vertex_distances_gc_nongc/'
|
165 |
+
ROOT_PATH_ALL_VERT_DIST_TEMPLATE = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/'
|
166 |
+
|
167 |
+
# load all vertex distances
|
168 |
+
path_mesh = ROOT_PATH_MESH + 'mesh_downsampling_meshesmy_smpl_39dogsnorm_Jr_4_dog_template_downsampled0.obj'
|
169 |
+
my_mesh = trimesh.load_mesh(path_mesh, process=False, maintain_order=True)
|
170 |
+
verts = my_mesh.vertices
|
171 |
+
faces = my_mesh.faces
|
172 |
+
# vert_dists, ga = prepare_graph_from_template_mesh_and_calculate_all_distances(path_mesh, ROOT_OUT_PATH, calc_dist_mat=False)
|
173 |
+
vert_dists = load_all_template_mesh_distances(ROOT_PATH_ALL_VERT_DIST_TEMPLATE, filename='all_vertex_distances.npy')
|
174 |
+
|
175 |
+
|
176 |
+
|
177 |
+
|
178 |
+
all_keys = []
|
179 |
+
gc_dict = {}
|
180 |
+
# data/stanext_related_data/ground_contact_annotations/stage3/main_partA1667_20221021_140108.csv
|
181 |
+
# for csv_file in ['main_partA500_20221018_131139.csv', 'pilot_20221017_104201.csv', 'my_gcannotations_qualification.csv']:
|
182 |
+
# for csv_file in ['main_partA1667_20221021_140108.csv', 'main_partA500_20221018_131139.csv', 'pilot_20221017_104201.csv', 'my_gcannotations_qualification.csv']:
|
183 |
+
for csv_file in ['main_partA1667_20221021_140108.csv', 'main_partA500_20221018_131139.csv', 'main_partB20221023_150926.csv', 'pilot_20221017_104201.csv', 'my_gcannotations_qualification.csv']:
|
184 |
+
# load all ground contact annotations
|
185 |
+
gc_annot_csv = ROOT_PATH_ANNOT + csv_file # 'my_gcannotations_qualification.csv'
|
186 |
+
gc_row_list = read_csv(gc_annot_csv)
|
187 |
+
for ind_row in range(len(gc_row_list)):
|
188 |
+
json_acceptable_string = (gc_row_list[ind_row]['vertices']).replace("'", "\"")
|
189 |
+
gc_dict_temp = json.loads(json_acceptable_string)
|
190 |
+
all_keys.extend(gc_dict_temp.keys())
|
191 |
+
gc_dict.update(gc_dict_temp)
|
192 |
+
print(len(gc_dict.keys()))
|
193 |
+
|
194 |
+
print('number of labeled images: ' + str(len(gc_dict.keys()))) # WHY IS THIS ONLY 699?
|
195 |
+
|
196 |
+
import pdb; pdb.set_trace()
|
197 |
+
|
198 |
+
|
199 |
+
# prepare and save contact annotations including distances
|
200 |
+
vertex_overview_dict = {}
|
201 |
+
for ind_img, name_ingcdict in enumerate(gc_dict.keys()): # range(len(gc_dict.keys())):
|
202 |
+
name = name_ingcdict.split('bite/')[1]
|
203 |
+
# name = images_with_gc_labelled[ind_img]
|
204 |
+
print('work on image ' + str(ind_img) + ': ' + name)
|
205 |
+
# gc_info_raw = gc_dict['bite/' + name] # a list with all vertex numbers that are in ground contact
|
206 |
+
gc_info_raw = gc_dict[name_ingcdict] # a list with all vertex numbers that are in ground contact
|
207 |
+
|
208 |
+
if not os.path.exists(ROOT_OUT_PATH_VIS + name.split('/')[0]): os.makedirs(ROOT_OUT_PATH_VIS + name.split('/')[0])
|
209 |
+
if not os.path.exists(ROOT_OUT_PATH_DISTSGCNONGC + name.split('/')[0]): os.makedirs(ROOT_OUT_PATH_DISTSGCNONGC + name.split('/')[0])
|
210 |
+
|
211 |
+
vertex_overview = calculate_vertex_overview_for_gc_annotation(name, gc_info_raw, vert_dists, root_out_path_vis=ROOT_OUT_PATH_VIS, verts=verts, faces=faces, img_v12_dir=None)
|
212 |
+
np.save(ROOT_OUT_PATH_DISTSGCNONGC + name.replace('.jpg', '_gc_vertdists_overview.npy'), vertex_overview)
|
213 |
+
|
214 |
+
vertex_overview_dict[name.split('.')[0]] = {'gc_vertdists_overview': vertex_overview, 'gc_index_list': gc_info_raw}
|
215 |
+
|
216 |
+
|
217 |
+
|
218 |
+
|
219 |
+
|
220 |
+
# import pdb; pdb.set_trace()
|
221 |
+
|
222 |
+
with open(ROOT_OUT_PATH + 'gc_annots_overview_stage3complete_withtraintestval_xx.pkl', 'wb') as fp:
|
223 |
+
pkl.dump(vertex_overview_dict, fp)
|
224 |
+
|
225 |
+
|
226 |
+
|
227 |
+
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
|
232 |
+
|
233 |
+
|
234 |
+
|
235 |
+
|
236 |
+
|
237 |
+
if __name__ == "__main__":
|
238 |
+
main()
|
239 |
+
|
240 |
+
|
241 |
+
|
242 |
+
|
243 |
+
|
244 |
+
|
245 |
+
|
src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh_forfourpaws.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
"""
|
3 |
+
code adapted from: https://github.com/mikedh/trimesh/blob/main/examples/shortest.py
|
4 |
+
shortest.py
|
5 |
+
----------------
|
6 |
+
Given a mesh and two vertex indices find the shortest path
|
7 |
+
between the two vertices while only traveling along edges
|
8 |
+
of the mesh.
|
9 |
+
"""
|
10 |
+
|
11 |
+
# python src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh_forfourpaws.py
|
12 |
+
|
13 |
+
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
import glob
|
17 |
+
import csv
|
18 |
+
import json
|
19 |
+
import shutil
|
20 |
+
import tqdm
|
21 |
+
import numpy as np
|
22 |
+
import pickle as pkl
|
23 |
+
import trimesh
|
24 |
+
import networkx as nx
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
def read_csv(csv_file):
|
31 |
+
with open(csv_file,'r') as f:
|
32 |
+
reader = csv.reader(f)
|
33 |
+
headers = next(reader)
|
34 |
+
row_list = [{h:x for (h,x) in zip(headers,row)} for row in reader]
|
35 |
+
return row_list
|
36 |
+
|
37 |
+
|
38 |
+
def load_all_template_mesh_distances(root_out_path, filename='all_vertex_distances.npy'):
|
39 |
+
vert_dists = np.load(root_out_path + filename)
|
40 |
+
return vert_dists
|
41 |
+
|
42 |
+
|
43 |
+
def prepare_graph_from_template_mesh_and_calculate_all_distances(path_mesh, root_out_path, calc_dist_mat=False):
|
44 |
+
# root_out_path = ROOT_OUT_PATH
|
45 |
+
'''
|
46 |
+
from smal_pytorch.smal_model.smal_torch_new import SMAL
|
47 |
+
smal = SMAL()
|
48 |
+
verts = smal.v_template.detach().cpu().numpy()
|
49 |
+
faces = smal.faces.detach().cpu().numpy()
|
50 |
+
'''
|
51 |
+
# path_mesh = ROOT_PATH_MESH + 'mesh_downsampling_meshesmy_smpl_39dogsnorm_Jr_4_dog_template_downsampled0.obj'
|
52 |
+
my_mesh = trimesh.load_mesh(path_mesh, process=False, maintain_order=True)
|
53 |
+
verts = my_mesh.vertices
|
54 |
+
faces = my_mesh.faces
|
55 |
+
# edges without duplication
|
56 |
+
edges = my_mesh.edges_unique
|
57 |
+
# the actual length of each unique edge
|
58 |
+
length = my_mesh.edges_unique_length
|
59 |
+
# create the graph with edge attributes for length (option A)
|
60 |
+
# g = nx.Graph()
|
61 |
+
# for edge, L in zip(edges, length): g.add_edge(*edge, length=L)
|
62 |
+
# you can create the graph with from_edgelist and
|
63 |
+
# a list comprehension (option B)
|
64 |
+
ga = nx.from_edgelist([(e[0], e[1], {'length': L}) for e, L in zip(edges, length)])
|
65 |
+
# calculate the distances between all vertex pairs
|
66 |
+
if calc_dist_mat:
|
67 |
+
# calculate distances between all possible vertex pairs
|
68 |
+
# shortest_path = nx.shortest_path(ga, source=ind_v0, target=ind_v1, weight='length')
|
69 |
+
# shortest_dist = nx.shortest_path_length(ga, source=ind_v0, target=ind_v1, weight='length')
|
70 |
+
dis = dict(nx.shortest_path_length(ga, weight='length', method='dijkstra'))
|
71 |
+
vertex_distances = np.zeros((n_verts_smal, n_verts_smal))
|
72 |
+
for ind_v0 in range(n_verts_smal):
|
73 |
+
print(ind_v0)
|
74 |
+
for ind_v1 in range(ind_v0, n_verts_smal):
|
75 |
+
vertex_distances[ind_v0, ind_v1] = dis[ind_v0][ind_v1]
|
76 |
+
vertex_distances[ind_v1, ind_v0] = dis[ind_v0][ind_v1]
|
77 |
+
# save those distances
|
78 |
+
np.save(root_out_path + 'all_vertex_distances.npy', vertex_distances)
|
79 |
+
vert_dists = vertex_distances
|
80 |
+
else:
|
81 |
+
vert_dists = np.load(root_out_path + 'all_vertex_distances.npy')
|
82 |
+
return ga, vert_dists
|
83 |
+
|
84 |
+
|
85 |
+
def calculate_vertex_overview_for_gc_annotation(name, gc_info_raw, vert_dists, root_out_path_vis=None, verts=None, faces=None, img_v12_dir=None):
|
86 |
+
# input:
|
87 |
+
# root_out_path_vis = ROOT_OUT_PATH
|
88 |
+
# img_v12_dir = IMG_V12_DIR
|
89 |
+
# name = images_with_gc_labelled[ind_img]
|
90 |
+
# gc_info_raw = gc_dict['bite/' + name]
|
91 |
+
# output:
|
92 |
+
# vertex_overview: np array of shape (n_verts_smal, 3) with [first: no-contact=0 contact=1 second: index of vertex third: dist]
|
93 |
+
n_verts_smal = 3889
|
94 |
+
gc_vertices = []
|
95 |
+
gc_info_np = np.zeros((n_verts_smal))
|
96 |
+
for ind_v in gc_info_raw:
|
97 |
+
if ind_v < n_verts_smal:
|
98 |
+
gc_vertices.append(ind_v)
|
99 |
+
gc_info_np[ind_v] = 1
|
100 |
+
# save a visualization of those annotations
|
101 |
+
if root_out_path_vis is not None:
|
102 |
+
my_mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False, maintain_order=True)
|
103 |
+
if img_v12_dir is not None and root_out_path_vis is not None:
|
104 |
+
vert_colors = np.repeat(255*gc_info_np[:, None], 3, 1)
|
105 |
+
my_mesh.visual.vertex_colors = vert_colors
|
106 |
+
my_mesh.export(root_out_path_vis + (name).replace('.jpg', '_withgc.obj'))
|
107 |
+
img_path = img_v12_dir + name
|
108 |
+
shutil.copy(img_path, root_out_path_vis + name)
|
109 |
+
# calculate for each vertex the distance to the closest element of the other group
|
110 |
+
non_gc_vertices = list(set(range(n_verts_smal)) - set(gc_vertices))
|
111 |
+
print('vertices in contact: ' + str(len(gc_vertices)))
|
112 |
+
print('vertices without contact: ' + str(len(non_gc_vertices)))
|
113 |
+
vertex_overview = np.zeros((n_verts_smal, 3)) # first: no-contact=0 contact=1 second: index of vertex third: dist
|
114 |
+
vertex_overview[:, 0] = gc_info_np
|
115 |
+
# loop through all contact vertices
|
116 |
+
for ind_v in gc_vertices:
|
117 |
+
min_length = 100
|
118 |
+
for ind_v_ps in non_gc_vertices: # possible solution
|
119 |
+
# this_path = nx.shortest_path(ga, source=ind_v, target=ind_v_ps, weight='length')
|
120 |
+
# this_length = nx.shortest_path_length(ga, source=ind_v, target=ind_v_ps, weight='length')
|
121 |
+
this_length = vert_dists[ind_v, ind_v_ps]
|
122 |
+
if this_length < min_length:
|
123 |
+
min_length = this_length
|
124 |
+
vertex_overview[ind_v, 1] = ind_v_ps
|
125 |
+
vertex_overview[ind_v, 2] = this_length
|
126 |
+
# loop through all non-contact vertices
|
127 |
+
for ind_v in non_gc_vertices:
|
128 |
+
min_length = 100
|
129 |
+
for ind_v_ps in gc_vertices: # possible solution
|
130 |
+
# this_path = nx.shortest_path(ga, source=ind_v, target=ind_v_ps, weight='length')
|
131 |
+
# this_length = nx.shortest_path_length(ga, source=ind_v, target=ind_v_ps, weight='length')
|
132 |
+
this_length = vert_dists[ind_v, ind_v_ps]
|
133 |
+
if this_length < min_length:
|
134 |
+
min_length = this_length
|
135 |
+
vertex_overview[ind_v, 1] = ind_v_ps
|
136 |
+
vertex_overview[ind_v, 2] = this_length
|
137 |
+
if root_out_path_vis is not None:
|
138 |
+
# save a colored mesh
|
139 |
+
my_mesh_dists = my_mesh.copy()
|
140 |
+
scale_0 = (vertex_overview[vertex_overview[:, 0]==0, 2]).max()
|
141 |
+
scale_1 = (vertex_overview[vertex_overview[:, 0]==1, 2]).max()
|
142 |
+
vert_col = np.zeros((n_verts_smal, 3))
|
143 |
+
vert_col[vertex_overview[:, 0]==0, 1] = vertex_overview[vertex_overview[:, 0]==0, 2] * 255 / scale_0 # green
|
144 |
+
vert_col[vertex_overview[:, 0]==1, 0] = vertex_overview[vertex_overview[:, 0]==1, 2] * 255 / scale_1 # red
|
145 |
+
my_mesh_dists.visual.vertex_colors = np.uint8(vert_col)
|
146 |
+
my_mesh_dists.export(root_out_path_vis + (name).replace('.jpg', '_withgcdists.obj'))
|
147 |
+
return vertex_overview
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
|
152 |
+
|
153 |
+
|
154 |
+
|
155 |
+
|
156 |
+
|
157 |
+
def main():
|
158 |
+
|
159 |
+
ROOT_PATH_MESH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/graphcmr/data/meshes/'
|
160 |
+
IMG_V12_DIR = '/ps/scratch/nrueegg/new_projects/Animals/data/dog_datasets/Stanford_Dogs_Dataset/StanfordExtra_V12/StanExtV12_Images/'
|
161 |
+
# ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/losses_for_vertex_wise_predictions/debugging_results/'
|
162 |
+
ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stages12together/'
|
163 |
+
ROOT_PATH_ALL_VERT_DIST_TEMPLATE = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/'
|
164 |
+
|
165 |
+
# load all vertex distances
|
166 |
+
path_mesh = ROOT_PATH_MESH + 'mesh_downsampling_meshesmy_smpl_39dogsnorm_Jr_4_dog_template_downsampled0.obj'
|
167 |
+
my_mesh = trimesh.load_mesh(path_mesh, process=False, maintain_order=True)
|
168 |
+
verts = my_mesh.vertices
|
169 |
+
faces = my_mesh.faces
|
170 |
+
# vert_dists, ga = prepare_graph_from_template_mesh_and_calculate_all_distances(path_mesh, ROOT_OUT_PATH, calc_dist_mat=False)
|
171 |
+
vert_dists = load_all_template_mesh_distances(ROOT_PATH_ALL_VERT_DIST_TEMPLATE, filename='all_vertex_distances.npy')
|
172 |
+
|
173 |
+
# paw vertices:
|
174 |
+
# left and right is a bit different, but that is ok (we will anyways mirror data at training time)
|
175 |
+
right_front_paw = [3829,+3827,+3825,+3718,+3722,+3723,+3743,+3831,+3719,+3726,+3716,+3724,+3828,+3717,+3721,+3725,+3832,+3830,+3720,+3288,+3740,+3714,+3826,+3715,+3728,+3712,+3287,+3284,+3727,+3285,+3742,+3291,+3710,+3697,+3711,+3289,+3730,+3713,+3739,+3282,+3738,+3708,+3709,+3741,+3698,+3696,+3308,+3695,+3706,+3700,+3707,+3306,+3305,+3737,+3304,+3303,+3307,+3736,+3735,+3250,+3261,+3732,+3734,+3733,+3731,+3729,+3299,+3297,+3298,+3295,+3293,+3296,+3294,+3292,+3312,+3311,+3314,+3309,+3290,+3313,+3410,+3315,+3411,+3412,+3316,+3421,+3317,+3415,+3445,+3327,+3328,+3283,+3343,+3326,+3325,+3330,+3286,+3399,+3398,+3329,+3446,+3400,+3331,+3401,+3281,+3332,+3279,+3402,+3419,+3407,+3356,+3358,+3357,+3280,+3354,+3277,+3278,+3346,+3347,+3377,+3378,+3345,+3386,+3379,+3348,+3384,+3418,+3372,+3276,+3275,+3374,+3274,+3373,+3375,+3369,+3371,+3376,+3273,+3396,+3397,+3395,+3388,+3360,+3370,+3361,+3394,+3387,+3420,+3359,+3389,+3272,+3391,+3393,+3390,+3392,+3363,+3362,+3367,+3365,+3705,+3271,+3704,+3703,+3270,+3269,+3702,+3268,+3224,+3267,+3701,+3225,+3699,+3265,+3264,+3266,+3263,+3262,+3249,+3228,+3230,+3251,+3301,+3300,+3302,+3252]
|
176 |
+
right_back_paw = [3472,+3627,+3470,+3469,+3471,+3473,+3626,+3625,+3475,+3655,+3519,+3468,+3629,+3466,+3476,+3624,+3521,+3654,+3657,+3838,+3518,+3653,+3839,+3553,+3474,+3516,+3656,+3628,+3834,+3535,+3630,+3658,+3477,+3520,+3517,+3595,+3522,+3597,+3596,+3501,+3534,+3503,+3478,+3500,+3479,+3502,+3607,+3499,+3608,+3496,+3605,+3609,+3504,+3606,+3642,+3614,+3498,+3480,+3631,+3610,+3613,+3506,+3659,+3660,+3632,+3841,+3661,+3836,+3662,+3633,+3663,+3664,+3634,+3635,+3486,+3665,+3636,+3637,+3666,+3490,+3837,+3667,+3493,+3638,+3492,+3495,+3616,+3644,+3494,+3835,+3643,+3833,+3840,+3615,+3650,+3668,+3652,+3651,+3645,+3646,+3647,+3649,+3648,+3622,+3617,+3448,+3621,+3618,+3623,+3462,+3464,+3460,+3620,+3458,+3461,+3463,+3465,+3573,+3571,+3467,+3569,+3557,+3558,+3572,+3570,+3556,+3585,+3593,+3594,+3459,+3566,+3592,+3567,+3568,+3538,+3539,+3555,+3537,+3536,+3554,+3575,+3574,+3583,+3541,+3550,+3576,+3581,+3639,+3577,+3551,+3582,+3580,+3552,+3578,+3542,+3549,+3579,+3523,+3526,+3598,+3525,+3600,+3640,+3599,+3601,+3602,+3603,+3529,+3604,+3530,+3533,+3532,+3611,+3612,+3482,+3481,+3505,+3452,+3455,+3456,+3454,+3457,+3619,+3451,+3450,+3449,+3591,+3589,+3641,+3584,+3561,+3587,+3559,+3488,+3484,+3483]
|
177 |
+
left_front_paw = [1791,+1950,+1948,+1790,+1789,+1746,+1788,+1747,+1949,+1944,+1792,+1945,+1356,+1775,+1759,+1777,+1787,+1946,+1757,+1761,+1745,+1943,+1947,+1744,+1309,+1786,+1771,+1354,+1774,+1765,+1767,+1768,+1772,+1763,+1770,+1773,+1769,+1764,+1766,+1758,+1760,+1762,+1336,+1333,+1330,+1325,+1756,+1323,+1755,+1753,+1749,+1754,+1751,+1321,+1752,+1748,+1750,+1312,+1319,+1315,+1313,+1317,+1318,+1316,+1314,+1311,+1310,+1299,+1276,+1355,+1297,+1353,+1298,+1300,+1352,+1351,+1785,+1784,+1349,+1783,+1782,+1781,+1780,+1779,+1778,+1776,+1343,+1341,+1344,+1339,+1342,+1340,+1360,+1335,+1338,+1362,+1357,+1361,+1363,+1458,+1337,+1459,+1456,+1460,+1493,+1332,+1375,+1376,+1331,+1374,+1378,+1334,+1373,+1494,+1377,+1446,+1448,+1379,+1449,+1329,+1327,+1404,+1406,+1405,+1402,+1328,+1426,+1432,+1434,+1403,+1394,+1395,+1433,+1425,+1286,+1380,+1466,+1431,+1290,+1401,+1381,+1427,+1450,+1393,+1430,+1326,+1396,+1428,+1397,+1429,+1398,+1420,+1324,+1422,+1417,+1419,+1421,+1443,+1418,+1423,+1444,+1442,+1424,+1445,+1495,+1440,+1441,+1468,+1436,+1408,+1322,+1435,+1415,+1439,+1409,+1283,+1438,+1416,+1407,+1437,+1411,+1413,+1414,+1320,+1273,+1272,+1278,+1469,+1463,+1457,+1358,+1464,+1465,+1359,+1372,+1391,+1390,+1455,+1447,+1454,+1467,+1453,+1452,+1451,+1383,+1345,+1347,+1348,+1350,+1364,+1392,+1410,+1412]
|
178 |
+
left_back_paw = [1957,+1958,+1701,+1956,+1951,+1703,+1715,+1702,+1700,+1673,+1705,+1952,+1955,+1674,+1699,+1675,+1953,+1704,+1954,+1698,+1677,+1671,+1672,+1714,+1706,+1676,+1519,+1523,+1686,+1713,+1692,+1685,+1543,+1664,+1712,+1691,+1959,+1541,+1684,+1542,+1496,+1663,+1540,+1497,+1499,+1498,+1500,+1693,+1665,+1694,+1716,+1666,+1695,+1501,+1502,+1696,+1667,+1503,+1697,+1504,+1668,+1669,+1506,+1670,+1508,+1510,+1507,+1509,+1511,+1512,+1621,+1606,+1619,+1605,+1513,+1620,+1618,+1604,+1633,+1641,+1642,+1607,+1617,+1514,+1632,+1614,+1689,+1640,+1515,+1586,+1616,+1516,+1517,+1603,+1615,+1639,+1585,+1521,+1602,+1587,+1584,+1601,+1623,+1622,+1631,+1598,+1624,+1629,+1589,+1687,+1625,+1599,+1630,+1569,+1570,+1628,+1626,+1597,+1627,+1590,+1594,+1571,+1568,+1567,+1574,+1646,+1573,+1645,+1648,+1564,+1688,+1647,+1643,+1649,+1650,+1651,+1577,+1644,+1565,+1652,+1566,+1578,+1518,+1524,+1583,+1582,+1520,+1581,+1522,+1525,+1549,+1551,+1580,+1552,+1550,+1656,+1658,+1554,+1657,+1659,+1548,+1655,+1690,+1660,+1556,+1653,+1558,+1661,+1544,+1662,+1654,+1547,+1545,+1527,+1560,+1526,+1678,+1679,+1528,+1708,+1707,+1680,+1529,+1530,+1709,+1546,+1681,+1710,+1711,+1682,+1532,+1531,+1683,+1534,+1533,+1536,+1538,+1600,+1553]
|
179 |
+
|
180 |
+
|
181 |
+
all_contact_vertices = right_front_paw + right_back_paw + left_front_paw + left_back_paw
|
182 |
+
|
183 |
+
name = 'all4pawsincontact.jpg'
|
184 |
+
print('work on 4paw images')
|
185 |
+
gc_info_raw = all_contact_vertices # a list with all vertex numbers that are in ground contact
|
186 |
+
|
187 |
+
vertex_overview = calculate_vertex_overview_for_gc_annotation(name, gc_info_raw, vert_dists, root_out_path_vis=ROOT_OUT_PATH, verts=verts, faces=faces, img_v12_dir=None)
|
188 |
+
np.save(ROOT_OUT_PATH + name.replace('.jpg', '_gc_vertdists_overview.npy'), vertex_overview)
|
189 |
+
|
190 |
+
vertex_overview_dict = {}
|
191 |
+
vertex_overview_dict[name.split('.')[0]] = {'gc_vertdists_overview': vertex_overview, 'gc_index_list': gc_info_raw}
|
192 |
+
with open(ROOT_OUT_PATH + 'gc_annots_overview_all4pawsincontact_xx.pkl', 'wb') as fp:
|
193 |
+
pkl.dump(vertex_overview_dict, fp)
|
194 |
+
|
195 |
+
|
196 |
+
|
197 |
+
|
198 |
+
|
199 |
+
|
200 |
+
|
201 |
+
|
202 |
+
|
203 |
+
|
204 |
+
|
205 |
+
if __name__ == "__main__":
|
206 |
+
main()
|
207 |
+
|
208 |
+
|
209 |
+
|
210 |
+
|
211 |
+
|
212 |
+
|
213 |
+
|
src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh_forpaws.py
ADDED
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
"""
|
3 |
+
code adapted from: https://github.com/mikedh/trimesh/blob/main/examples/shortest.py
|
4 |
+
shortest.py
|
5 |
+
----------------
|
6 |
+
Given a mesh and two vertex indices find the shortest path
|
7 |
+
between the two vertices while only traveling along edges
|
8 |
+
of the mesh.
|
9 |
+
"""
|
10 |
+
|
11 |
+
# python src/graph_networks/losses_for_vertex_wise_predictions/calculate_distance_between_points_on_mesh_forpaws.py
|
12 |
+
|
13 |
+
|
14 |
+
import os
|
15 |
+
import sys
|
16 |
+
import glob
|
17 |
+
import csv
|
18 |
+
import json
|
19 |
+
import shutil
|
20 |
+
import tqdm
|
21 |
+
import numpy as np
|
22 |
+
import pickle as pkl
|
23 |
+
import trimesh
|
24 |
+
import networkx as nx
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
def read_csv(csv_file):
|
31 |
+
with open(csv_file,'r') as f:
|
32 |
+
reader = csv.reader(f)
|
33 |
+
headers = next(reader)
|
34 |
+
row_list = [{h:x for (h,x) in zip(headers,row)} for row in reader]
|
35 |
+
return row_list
|
36 |
+
|
37 |
+
|
38 |
+
def load_all_template_mesh_distances(root_out_path, filename='all_vertex_distances.npy'):
|
39 |
+
vert_dists = np.load(root_out_path + filename)
|
40 |
+
return vert_dists
|
41 |
+
|
42 |
+
|
43 |
+
def prepare_graph_from_template_mesh_and_calculate_all_distances(path_mesh, root_out_path, calc_dist_mat=False):
|
44 |
+
# root_out_path = ROOT_OUT_PATH
|
45 |
+
'''
|
46 |
+
from smal_pytorch.smal_model.smal_torch_new import SMAL
|
47 |
+
smal = SMAL()
|
48 |
+
verts = smal.v_template.detach().cpu().numpy()
|
49 |
+
faces = smal.faces.detach().cpu().numpy()
|
50 |
+
'''
|
51 |
+
# path_mesh = ROOT_PATH_MESH + 'mesh_downsampling_meshesmy_smpl_39dogsnorm_Jr_4_dog_template_downsampled0.obj'
|
52 |
+
my_mesh = trimesh.load_mesh(path_mesh, process=False, maintain_order=True)
|
53 |
+
verts = my_mesh.vertices
|
54 |
+
faces = my_mesh.faces
|
55 |
+
# edges without duplication
|
56 |
+
edges = my_mesh.edges_unique
|
57 |
+
# the actual length of each unique edge
|
58 |
+
length = my_mesh.edges_unique_length
|
59 |
+
# create the graph with edge attributes for length (option A)
|
60 |
+
# g = nx.Graph()
|
61 |
+
# for edge, L in zip(edges, length): g.add_edge(*edge, length=L)
|
62 |
+
# you can create the graph with from_edgelist and
|
63 |
+
# a list comprehension (option B)
|
64 |
+
ga = nx.from_edgelist([(e[0], e[1], {'length': L}) for e, L in zip(edges, length)])
|
65 |
+
# calculate the distances between all vertex pairs
|
66 |
+
if calc_dist_mat:
|
67 |
+
# calculate distances between all possible vertex pairs
|
68 |
+
# shortest_path = nx.shortest_path(ga, source=ind_v0, target=ind_v1, weight='length')
|
69 |
+
# shortest_dist = nx.shortest_path_length(ga, source=ind_v0, target=ind_v1, weight='length')
|
70 |
+
dis = dict(nx.shortest_path_length(ga, weight='length', method='dijkstra'))
|
71 |
+
vertex_distances = np.zeros((n_verts_smal, n_verts_smal))
|
72 |
+
for ind_v0 in range(n_verts_smal):
|
73 |
+
print(ind_v0)
|
74 |
+
for ind_v1 in range(ind_v0, n_verts_smal):
|
75 |
+
vertex_distances[ind_v0, ind_v1] = dis[ind_v0][ind_v1]
|
76 |
+
vertex_distances[ind_v1, ind_v0] = dis[ind_v0][ind_v1]
|
77 |
+
# save those distances
|
78 |
+
np.save(root_out_path + 'all_vertex_distances.npy', vertex_distances)
|
79 |
+
vert_dists = vertex_distances
|
80 |
+
else:
|
81 |
+
vert_dists = np.load(root_out_path + 'all_vertex_distances.npy')
|
82 |
+
return ga, vert_dists
|
83 |
+
|
84 |
+
|
85 |
+
def calculate_vertex_overview_for_gc_annotation(name, gc_info_raw, vert_dists, root_out_path_vis=None, verts=None, faces=None, img_v12_dir=None):
|
86 |
+
# input:
|
87 |
+
# root_out_path_vis = ROOT_OUT_PATH
|
88 |
+
# img_v12_dir = IMG_V12_DIR
|
89 |
+
# name = images_with_gc_labelled[ind_img]
|
90 |
+
# gc_info_raw = gc_dict['bite/' + name]
|
91 |
+
# output:
|
92 |
+
# vertex_overview: np array of shape (n_verts_smal, 3) with [first: no-contact=0 contact=1 second: index of vertex third: dist]
|
93 |
+
n_verts_smal = 3889
|
94 |
+
gc_vertices = []
|
95 |
+
gc_info_np = np.zeros((n_verts_smal))
|
96 |
+
for ind_v in gc_info_raw:
|
97 |
+
if ind_v < n_verts_smal:
|
98 |
+
gc_vertices.append(ind_v)
|
99 |
+
gc_info_np[ind_v] = 1
|
100 |
+
# save a visualization of those annotations
|
101 |
+
if root_out_path_vis is not None:
|
102 |
+
my_mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False, maintain_order=True)
|
103 |
+
if img_v12_dir is not None and root_out_path_vis is not None:
|
104 |
+
vert_colors = np.repeat(255*gc_info_np[:, None], 3, 1)
|
105 |
+
my_mesh.visual.vertex_colors = vert_colors
|
106 |
+
my_mesh.export(root_out_path_vis + (name).replace('.jpg', '_withgc.obj'))
|
107 |
+
img_path = img_v12_dir + name
|
108 |
+
shutil.copy(img_path, root_out_path_vis + name)
|
109 |
+
# calculate for each vertex the distance to the closest element of the other group
|
110 |
+
non_gc_vertices = list(set(range(n_verts_smal)) - set(gc_vertices))
|
111 |
+
print('vertices in contact: ' + str(len(gc_vertices)))
|
112 |
+
print('vertices without contact: ' + str(len(non_gc_vertices)))
|
113 |
+
vertex_overview = np.zeros((n_verts_smal, 3)) # first: no-contact=0 contact=1 second: index of vertex third: dist
|
114 |
+
vertex_overview[:, 0] = gc_info_np
|
115 |
+
# loop through all contact vertices
|
116 |
+
for ind_v in gc_vertices:
|
117 |
+
min_length = 100
|
118 |
+
for ind_v_ps in non_gc_vertices: # possible solution
|
119 |
+
# this_path = nx.shortest_path(ga, source=ind_v, target=ind_v_ps, weight='length')
|
120 |
+
# this_length = nx.shortest_path_length(ga, source=ind_v, target=ind_v_ps, weight='length')
|
121 |
+
this_length = vert_dists[ind_v, ind_v_ps]
|
122 |
+
if this_length < min_length:
|
123 |
+
min_length = this_length
|
124 |
+
vertex_overview[ind_v, 1] = ind_v_ps
|
125 |
+
vertex_overview[ind_v, 2] = this_length
|
126 |
+
# loop through all non-contact vertices
|
127 |
+
for ind_v in non_gc_vertices:
|
128 |
+
min_length = 100
|
129 |
+
for ind_v_ps in gc_vertices: # possible solution
|
130 |
+
# this_path = nx.shortest_path(ga, source=ind_v, target=ind_v_ps, weight='length')
|
131 |
+
# this_length = nx.shortest_path_length(ga, source=ind_v, target=ind_v_ps, weight='length')
|
132 |
+
this_length = vert_dists[ind_v, ind_v_ps]
|
133 |
+
if this_length < min_length:
|
134 |
+
min_length = this_length
|
135 |
+
vertex_overview[ind_v, 1] = ind_v_ps
|
136 |
+
vertex_overview[ind_v, 2] = this_length
|
137 |
+
if root_out_path_vis is not None:
|
138 |
+
# save a colored mesh
|
139 |
+
my_mesh_dists = my_mesh.copy()
|
140 |
+
scale_0 = (vertex_overview[vertex_overview[:, 0]==0, 2]).max()
|
141 |
+
scale_1 = (vertex_overview[vertex_overview[:, 0]==1, 2]).max()
|
142 |
+
vert_col = np.zeros((n_verts_smal, 3))
|
143 |
+
vert_col[vertex_overview[:, 0]==0, 1] = vertex_overview[vertex_overview[:, 0]==0, 2] * 255 / scale_0 # green
|
144 |
+
vert_col[vertex_overview[:, 0]==1, 0] = vertex_overview[vertex_overview[:, 0]==1, 2] * 255 / scale_1 # red
|
145 |
+
my_mesh_dists.visual.vertex_colors = np.uint8(vert_col)
|
146 |
+
my_mesh_dists.export(root_out_path_vis + (name).replace('.jpg', '_withgcdists.obj'))
|
147 |
+
return vertex_overview
|
148 |
+
|
149 |
+
|
150 |
+
def summarize_results_stage2b(row_list, display_worker_performance=False):
|
151 |
+
# four catch trials are included in every batch
|
152 |
+
annot_n02088466_3184 = {'paw_rb': 0, 'paw_rf': 1, 'paw_lb': 1, 'paw_lf': 1, 'additional_part': 0, 'no_contact': 0}
|
153 |
+
annot_n02100583_9922 = {'paw_rb': 1, 'paw_rf': 0, 'paw_lb': 0, 'paw_lf': 0, 'additional_part': 0, 'no_contact': 0}
|
154 |
+
annot_n02105056_2798 = {'paw_rb': 1, 'paw_rf': 1, 'paw_lb': 1, 'paw_lf': 1, 'additional_part': 1, 'no_contact': 0}
|
155 |
+
annot_n02091831_2288 = {'paw_rb': 0, 'paw_rf': 1, 'paw_lb': 1, 'paw_lf': 0, 'additional_part': 0, 'no_contact': 0}
|
156 |
+
all_comments = []
|
157 |
+
all_annotations = {}
|
158 |
+
for row in row_list:
|
159 |
+
all_comments.append(row['Answer.submitComments'])
|
160 |
+
worker_id = row['WorkerId']
|
161 |
+
if display_worker_performance:
|
162 |
+
print('----------------------------------------------------------------------------------------------')
|
163 |
+
print('Worker ID: ' + worker_id)
|
164 |
+
n_wrong = 0
|
165 |
+
n_correct = 0
|
166 |
+
for ind in range(0, len(row['Answer.submitValuesNotSure'].split(';')) - 1):
|
167 |
+
input_image = (row['Input.images'].split(';')[ind]).split('StanExtV12_Images/')[-1]
|
168 |
+
paw_rb = row['Answer.submitValuesRightBack'].split(';')[ind]
|
169 |
+
paw_rf = row['Answer.submitValuesRightFront'].split(';')[ind]
|
170 |
+
paw_lb = row['Answer.submitValuesLeftBack'].split(';')[ind]
|
171 |
+
paw_lf = row['Answer.submitValuesLeftFront'].split(';')[ind]
|
172 |
+
addpart = row['Answer.submitValuesAdditional'].split(';')[ind]
|
173 |
+
no_contact = row['Answer.submitValuesNoContact'].split(';')[ind]
|
174 |
+
unsure = row['Answer.submitValuesNotSure'].split(';')[ind]
|
175 |
+
annot = {'paw_rb': paw_rb, 'paw_rf': paw_rf, 'paw_lb': paw_lb, 'paw_lf': paw_lf,
|
176 |
+
'additional_part': addpart, 'no_contact': no_contact, 'not_sure': unsure,
|
177 |
+
'worker_id': worker_id} # , 'input_image': input_image}
|
178 |
+
if ind == 0:
|
179 |
+
gt = annot_n02088466_3184
|
180 |
+
elif ind == 1:
|
181 |
+
gt = annot_n02105056_2798
|
182 |
+
elif ind == 2:
|
183 |
+
gt = annot_n02091831_2288
|
184 |
+
elif ind == 3:
|
185 |
+
gt = annot_n02100583_9922
|
186 |
+
else:
|
187 |
+
pass
|
188 |
+
if ind < 4:
|
189 |
+
for key in gt.keys():
|
190 |
+
if str(annot[key]) == str(gt[key]):
|
191 |
+
n_correct += 1
|
192 |
+
else:
|
193 |
+
if display_worker_performance:
|
194 |
+
print(input_image)
|
195 |
+
print(key + ':[ expected: ' + str(gt[key]) + ' predicted: ' + str(annot[key]) + ' ]')
|
196 |
+
n_wrong += 1
|
197 |
+
else:
|
198 |
+
all_annotations[input_image] = annot
|
199 |
+
if display_worker_performance:
|
200 |
+
print('n_correct: ' + str(n_correct))
|
201 |
+
print('n_wrong: ' + str(n_wrong))
|
202 |
+
return all_annotations, all_comments
|
203 |
+
|
204 |
+
|
205 |
+
|
206 |
+
|
207 |
+
|
208 |
+
|
209 |
+
def main():
|
210 |
+
|
211 |
+
ROOT_PATH_MESH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/graphcmr/data/meshes/'
|
212 |
+
ROOT_PATH_ANNOT = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage2b/'
|
213 |
+
IMG_V12_DIR = '/ps/scratch/nrueegg/new_projects/Animals/data/dog_datasets/Stanford_Dogs_Dataset/StanfordExtra_V12/StanExtV12_Images/'
|
214 |
+
# ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/src/graph_networks/losses_for_vertex_wise_predictions/debugging_results/'
|
215 |
+
ROOT_OUT_PATH = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage2b/'
|
216 |
+
ROOT_OUT_PATH_VIS = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage2b/vis/'
|
217 |
+
ROOT_OUT_PATH_DISTSGCNONGC = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/stage2b/vertex_distances_gc_nongc/'
|
218 |
+
ROOT_PATH_ALL_VERT_DIST_TEMPLATE = '/is/cluster/work/nrueegg/icon_pifu_related/barc_for_bite/data/stanext_related_data/ground_contact_annotations/'
|
219 |
+
|
220 |
+
# load all vertex distances
|
221 |
+
path_mesh = ROOT_PATH_MESH + 'mesh_downsampling_meshesmy_smpl_39dogsnorm_Jr_4_dog_template_downsampled0.obj'
|
222 |
+
my_mesh = trimesh.load_mesh(path_mesh, process=False, maintain_order=True)
|
223 |
+
verts = my_mesh.vertices
|
224 |
+
faces = my_mesh.faces
|
225 |
+
# vert_dists, ga = prepare_graph_from_template_mesh_and_calculate_all_distances(path_mesh, ROOT_OUT_PATH, calc_dist_mat=False)
|
226 |
+
vert_dists = load_all_template_mesh_distances(ROOT_PATH_ALL_VERT_DIST_TEMPLATE, filename='all_vertex_distances.npy')
|
227 |
+
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
|
232 |
+
# paw vertices:
|
233 |
+
# left and right is a bit different, but that is ok (we will anyways mirror data at training time)
|
234 |
+
right_front_paw = [3829,+3827,+3825,+3718,+3722,+3723,+3743,+3831,+3719,+3726,+3716,+3724,+3828,+3717,+3721,+3725,+3832,+3830,+3720,+3288,+3740,+3714,+3826,+3715,+3728,+3712,+3287,+3284,+3727,+3285,+3742,+3291,+3710,+3697,+3711,+3289,+3730,+3713,+3739,+3282,+3738,+3708,+3709,+3741,+3698,+3696,+3308,+3695,+3706,+3700,+3707,+3306,+3305,+3737,+3304,+3303,+3307,+3736,+3735,+3250,+3261,+3732,+3734,+3733,+3731,+3729,+3299,+3297,+3298,+3295,+3293,+3296,+3294,+3292,+3312,+3311,+3314,+3309,+3290,+3313,+3410,+3315,+3411,+3412,+3316,+3421,+3317,+3415,+3445,+3327,+3328,+3283,+3343,+3326,+3325,+3330,+3286,+3399,+3398,+3329,+3446,+3400,+3331,+3401,+3281,+3332,+3279,+3402,+3419,+3407,+3356,+3358,+3357,+3280,+3354,+3277,+3278,+3346,+3347,+3377,+3378,+3345,+3386,+3379,+3348,+3384,+3418,+3372,+3276,+3275,+3374,+3274,+3373,+3375,+3369,+3371,+3376,+3273,+3396,+3397,+3395,+3388,+3360,+3370,+3361,+3394,+3387,+3420,+3359,+3389,+3272,+3391,+3393,+3390,+3392,+3363,+3362,+3367,+3365,+3705,+3271,+3704,+3703,+3270,+3269,+3702,+3268,+3224,+3267,+3701,+3225,+3699,+3265,+3264,+3266,+3263,+3262,+3249,+3228,+3230,+3251,+3301,+3300,+3302,+3252]
|
235 |
+
right_back_paw = [3472,+3627,+3470,+3469,+3471,+3473,+3626,+3625,+3475,+3655,+3519,+3468,+3629,+3466,+3476,+3624,+3521,+3654,+3657,+3838,+3518,+3653,+3839,+3553,+3474,+3516,+3656,+3628,+3834,+3535,+3630,+3658,+3477,+3520,+3517,+3595,+3522,+3597,+3596,+3501,+3534,+3503,+3478,+3500,+3479,+3502,+3607,+3499,+3608,+3496,+3605,+3609,+3504,+3606,+3642,+3614,+3498,+3480,+3631,+3610,+3613,+3506,+3659,+3660,+3632,+3841,+3661,+3836,+3662,+3633,+3663,+3664,+3634,+3635,+3486,+3665,+3636,+3637,+3666,+3490,+3837,+3667,+3493,+3638,+3492,+3495,+3616,+3644,+3494,+3835,+3643,+3833,+3840,+3615,+3650,+3668,+3652,+3651,+3645,+3646,+3647,+3649,+3648,+3622,+3617,+3448,+3621,+3618,+3623,+3462,+3464,+3460,+3620,+3458,+3461,+3463,+3465,+3573,+3571,+3467,+3569,+3557,+3558,+3572,+3570,+3556,+3585,+3593,+3594,+3459,+3566,+3592,+3567,+3568,+3538,+3539,+3555,+3537,+3536,+3554,+3575,+3574,+3583,+3541,+3550,+3576,+3581,+3639,+3577,+3551,+3582,+3580,+3552,+3578,+3542,+3549,+3579,+3523,+3526,+3598,+3525,+3600,+3640,+3599,+3601,+3602,+3603,+3529,+3604,+3530,+3533,+3532,+3611,+3612,+3482,+3481,+3505,+3452,+3455,+3456,+3454,+3457,+3619,+3451,+3450,+3449,+3591,+3589,+3641,+3584,+3561,+3587,+3559,+3488,+3484,+3483]
|
236 |
+
left_front_paw = [1791,+1950,+1948,+1790,+1789,+1746,+1788,+1747,+1949,+1944,+1792,+1945,+1356,+1775,+1759,+1777,+1787,+1946,+1757,+1761,+1745,+1943,+1947,+1744,+1309,+1786,+1771,+1354,+1774,+1765,+1767,+1768,+1772,+1763,+1770,+1773,+1769,+1764,+1766,+1758,+1760,+1762,+1336,+1333,+1330,+1325,+1756,+1323,+1755,+1753,+1749,+1754,+1751,+1321,+1752,+1748,+1750,+1312,+1319,+1315,+1313,+1317,+1318,+1316,+1314,+1311,+1310,+1299,+1276,+1355,+1297,+1353,+1298,+1300,+1352,+1351,+1785,+1784,+1349,+1783,+1782,+1781,+1780,+1779,+1778,+1776,+1343,+1341,+1344,+1339,+1342,+1340,+1360,+1335,+1338,+1362,+1357,+1361,+1363,+1458,+1337,+1459,+1456,+1460,+1493,+1332,+1375,+1376,+1331,+1374,+1378,+1334,+1373,+1494,+1377,+1446,+1448,+1379,+1449,+1329,+1327,+1404,+1406,+1405,+1402,+1328,+1426,+1432,+1434,+1403,+1394,+1395,+1433,+1425,+1286,+1380,+1466,+1431,+1290,+1401,+1381,+1427,+1450,+1393,+1430,+1326,+1396,+1428,+1397,+1429,+1398,+1420,+1324,+1422,+1417,+1419,+1421,+1443,+1418,+1423,+1444,+1442,+1424,+1445,+1495,+1440,+1441,+1468,+1436,+1408,+1322,+1435,+1415,+1439,+1409,+1283,+1438,+1416,+1407,+1437,+1411,+1413,+1414,+1320,+1273,+1272,+1278,+1469,+1463,+1457,+1358,+1464,+1465,+1359,+1372,+1391,+1390,+1455,+1447,+1454,+1467,+1453,+1452,+1451,+1383,+1345,+1347,+1348,+1350,+1364,+1392,+1410,+1412]
|
237 |
+
left_back_paw = [1957,+1958,+1701,+1956,+1951,+1703,+1715,+1702,+1700,+1673,+1705,+1952,+1955,+1674,+1699,+1675,+1953,+1704,+1954,+1698,+1677,+1671,+1672,+1714,+1706,+1676,+1519,+1523,+1686,+1713,+1692,+1685,+1543,+1664,+1712,+1691,+1959,+1541,+1684,+1542,+1496,+1663,+1540,+1497,+1499,+1498,+1500,+1693,+1665,+1694,+1716,+1666,+1695,+1501,+1502,+1696,+1667,+1503,+1697,+1504,+1668,+1669,+1506,+1670,+1508,+1510,+1507,+1509,+1511,+1512,+1621,+1606,+1619,+1605,+1513,+1620,+1618,+1604,+1633,+1641,+1642,+1607,+1617,+1514,+1632,+1614,+1689,+1640,+1515,+1586,+1616,+1516,+1517,+1603,+1615,+1639,+1585,+1521,+1602,+1587,+1584,+1601,+1623,+1622,+1631,+1598,+1624,+1629,+1589,+1687,+1625,+1599,+1630,+1569,+1570,+1628,+1626,+1597,+1627,+1590,+1594,+1571,+1568,+1567,+1574,+1646,+1573,+1645,+1648,+1564,+1688,+1647,+1643,+1649,+1650,+1651,+1577,+1644,+1565,+1652,+1566,+1578,+1518,+1524,+1583,+1582,+1520,+1581,+1522,+1525,+1549,+1551,+1580,+1552,+1550,+1656,+1658,+1554,+1657,+1659,+1548,+1655,+1690,+1660,+1556,+1653,+1558,+1661,+1544,+1662,+1654,+1547,+1545,+1527,+1560,+1526,+1678,+1679,+1528,+1708,+1707,+1680,+1529,+1530,+1709,+1546,+1681,+1710,+1711,+1682,+1532,+1531,+1683,+1534,+1533,+1536,+1538,+1600,+1553]
|
238 |
+
|
239 |
+
|
240 |
+
|
241 |
+
|
242 |
+
|
243 |
+
all_keys = []
|
244 |
+
gc_dict = {}
|
245 |
+
vertex_overview_nocontact = {}
|
246 |
+
# data/stanext_related_data/ground_contact_annotations/stage3/main_partA1667_20221021_140108.csv
|
247 |
+
for csv_file in ['Stage2b_finalResults.csv']:
|
248 |
+
# load all ground contact annotations
|
249 |
+
gc_annot_csv = ROOT_PATH_ANNOT + csv_file # 'my_gcannotations_qualification.csv'
|
250 |
+
gc_row_list = read_csv(gc_annot_csv)
|
251 |
+
all_annotations, all_comments = summarize_results_stage2b(gc_row_list, display_worker_performance=False)
|
252 |
+
for key, value in all_annotations.items():
|
253 |
+
if value['not_sure'] == '0':
|
254 |
+
if value['no_contact'] == '1':
|
255 |
+
vertex_overview_nocontact[key.split('.')[0]] = {'gc_vertdists_overview': 'no contact', 'gc_index_list': None}
|
256 |
+
else:
|
257 |
+
all_contact_vertices = []
|
258 |
+
if value['paw_rf'] == '1':
|
259 |
+
all_contact_vertices.extend(right_front_paw)
|
260 |
+
if value['paw_rb'] == '1':
|
261 |
+
all_contact_vertices.extend(right_back_paw)
|
262 |
+
if value['paw_lf'] == '1':
|
263 |
+
all_contact_vertices.extend(left_front_paw)
|
264 |
+
if value['paw_lb'] == '1':
|
265 |
+
all_contact_vertices.extend(left_back_paw)
|
266 |
+
gc_dict[key] = all_contact_vertices
|
267 |
+
print('number of labeled images: ' + str(len(gc_dict.keys())))
|
268 |
+
print('number of images without contact: ' + str(len(vertex_overview_nocontact.keys())))
|
269 |
+
|
270 |
+
# prepare and save contact annotations including distances
|
271 |
+
vertex_overview_dict = {}
|
272 |
+
for ind_img, name_ingcdict in enumerate(gc_dict.keys()): # range(len(gc_dict.keys())):
|
273 |
+
name = name_ingcdict # name_ingcdict.split('bite/')[1]
|
274 |
+
# name = images_with_gc_labelled[ind_img]
|
275 |
+
print('work on image ' + str(ind_img) + ': ' + name)
|
276 |
+
# gc_info_raw = gc_dict['bite/' + name] # a list with all vertex numbers that are in ground contact
|
277 |
+
gc_info_raw = gc_dict[name_ingcdict] # a list with all vertex numbers that are in ground contact
|
278 |
+
|
279 |
+
if not os.path.exists(ROOT_OUT_PATH_VIS + name.split('/')[0]): os.makedirs(ROOT_OUT_PATH_VIS + name.split('/')[0])
|
280 |
+
if not os.path.exists(ROOT_OUT_PATH_DISTSGCNONGC + name.split('/')[0]): os.makedirs(ROOT_OUT_PATH_DISTSGCNONGC + name.split('/')[0])
|
281 |
+
|
282 |
+
vertex_overview = calculate_vertex_overview_for_gc_annotation(name, gc_info_raw, vert_dists, root_out_path_vis=ROOT_OUT_PATH_VIS, verts=verts, faces=faces, img_v12_dir=None)
|
283 |
+
np.save(ROOT_OUT_PATH_DISTSGCNONGC + name.replace('.jpg', '_gc_vertdists_overview.npy'), vertex_overview)
|
284 |
+
|
285 |
+
vertex_overview_dict[name.split('.')[0]] = {'gc_vertdists_overview': vertex_overview, 'gc_index_list': gc_info_raw}
|
286 |
+
|
287 |
+
|
288 |
+
|
289 |
+
|
290 |
+
|
291 |
+
# import pdb; pdb.set_trace()
|
292 |
+
|
293 |
+
with open(ROOT_OUT_PATH + 'gc_annots_overview_stage2b_contact_complete_xx.pkl', 'wb') as fp:
|
294 |
+
pkl.dump(vertex_overview_dict, fp)
|
295 |
+
|
296 |
+
with open(ROOT_OUT_PATH + 'gc_annots_overview_stage2b_nocontact_complete_xx.pkl', 'wb') as fp:
|
297 |
+
pkl.dump(vertex_overview_nocontact, fp)
|
298 |
+
|
299 |
+
|
300 |
+
|
301 |
+
|
302 |
+
|
303 |
+
|
304 |
+
|
305 |
+
|
306 |
+
|
307 |
+
|
308 |
+
|
309 |
+
if __name__ == "__main__":
|
310 |
+
main()
|
311 |
+
|
312 |
+
|
313 |
+
|
314 |
+
|
315 |
+
|
316 |
+
|
317 |
+
|