# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright © 2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG), acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import pytorch_lightning as pl
import torch
from termcolor import colored
from ..dataset.mesh_util import *
from ..net.geometry import orthogonal

class Format:
    # ANSI escape codes used to underline console output.
    end = '\033[0m'      # reset all attributes
    start = '\033[4m'    # begin underline
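
# Usage sketch (illustrative, not part of the original file):
#
#   print(Format.start + "underlined heading" + Format.end)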

def init_loss():

    losses = {
        # Cloth: chamfer distance
        "cloth": {"weight": 1e3, "value": 0.0},
        # Stiffness: [RT]_v1 - [RT]_v2 (v1-edge-v2)
        "stiff": {"weight": 1e5, "value": 0.0},
        # Cloth: det(R) = 1
        "rigid": {"weight": 1e5, "value": 0.0},
        # Cloth: edge length
        "edge": {"weight": 0, "value": 0.0},
        # Cloth: normal consistency
        "nc": {"weight": 0, "value": 0.0},
        # Cloth: Laplacian smoothing
        "lapla": {"weight": 1e2, "value": 0.0},
        # Body: Normal_pred - Normal_smpl
        "normal": {"weight": 1e0, "value": 0.0},
        # Body: Silhouette_pred - Silhouette_smpl
        "silhouette": {"weight": 1e0, "value": 0.0},
        # Joint: reprojected joints difference
        "joint": {"weight": 5e0, "value": 0.0},
    }

    return losses
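
# Usage sketch (illustrative): during optimization each term's "value" is
# filled with a scalar tensor and the weighted sum is backpropagated.
# `chamfer_loss` below is a hypothetical tensor from the training loop.
#
#   losses = init_loss()
#   losses["cloth"]["value"] = chamfer_loss
#   total = sum(term["weight"] * term["value"] for term in losses.values())
#   total.backward()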

class SubTrainer(pl.Trainer):
    def save_checkpoint(self, filepath, weights_only=False):
        """Save model/training states as a checkpoint file through state-dump and file-write.

        Args:
            filepath: write-target file's path
            weights_only: saving model weights only
        """
        _checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)

        # Strip auxiliary sub-modules (normal filter, voxelization, recon
        # engine) from the state dict before writing.
        del_keys = []
        for key in _checkpoint["state_dict"].keys():
            for ignore_key in ["normal_filter", "voxelization", "reconEngine"]:
                if ignore_key in key:
                    del_keys.append(key)
        for key in del_keys:
            del _checkpoint["state_dict"][key]

        pl.utilities.cloud_io.atomic_save(_checkpoint, filepath)
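
# Usage sketch (illustrative; constructor arguments and model names are
# hypothetical): SubTrainer is a drop-in replacement for pl.Trainer, so
# checkpoints saved through the usual API silently exclude the filtered
# sub-modules.
#
#   trainer = SubTrainer(max_epochs=10)
#   trainer.fit(model, datamodule=dm)
#   trainer.save_checkpoint("last.ckpt", weights_only=True)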

def query_func(opt, netG, features, points, proj_matrix=None):
    """Query the implicit network netG at a batch of 3D points.

    - points: size of (bz, N, 3)
    - proj_matrix: size of (bz, 4, 4)
    return: size of (bz, 1, N)
    """
    assert len(points) == 1    # only batch size 1 is supported
    samples = points.repeat(opt.num_views, 1, 1)
    samples = samples.permute(0, 2, 1)    # [bz, 3, N]

    # view-specific query
    if proj_matrix is not None:
        samples = orthogonal(samples, proj_matrix)

    calib_tensor = torch.stack([torch.eye(4).float()], dim=0).type_as(samples)

    preds = netG.query(
        features=features,
        points=samples,
        calibs=calib_tensor,
        regressor=netG.if_regressor,
    )

    if isinstance(preds, list):
        preds = preds[0]

    return preds
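
# Usage sketch (illustrative): query functions like this are typically
# curried and handed to an iso-surface extractor that evaluates them on a
# dense grid. `opt`, `netG`, `features`, and `grid_points` are assumed to
# come from the surrounding reconstruction code.
#
#   from functools import partial
#   occ_fn = partial(query_func, opt, netG, features)
#   occ = occ_fn(grid_points)    # (1, 1, N) predictions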

def query_func_IF(batch, netG, points):
    """Query an implicit-function network that consumes a whole batch dict.

    - points: size of (bz, N, 3)
    return: size of (bz, 1, N)
    """
    batch["samples_geo"] = points
    batch["calib"] = torch.stack([torch.eye(4).float()], dim=0).type_as(points)

    preds = netG(batch)

    return preds.unsqueeze(1)
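
# Usage sketch (illustrative): same contract as query_func, but everything
# is routed through the batch dict that netG expects.
#
#   from functools import partial
#   occ_fn = partial(query_func_IF, batch, netG)
#   occ = occ_fn(grid_points)    # (bz, 1, N)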

def batch_mean(res, key):
    """Average the metric `key` over a list of per-batch result dicts."""
    return torch.stack([
        x[key] if torch.is_tensor(x[key]) else torch.as_tensor(x[key]) for x in res
    ]).mean()
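
# Usage sketch (illustrative):
#
#   outputs = [{"chamfer": 0.8}, {"chamfer": torch.tensor(0.6)}]
#   batch_mean(outputs, "chamfer")    # -> tensor(0.7000)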

def accumulate(outputs, rot_num, split):
    """Average per-sample metrics over each dataset's slice of `outputs`.

    - outputs: list of per-sample metric dicts, ordered by dataset
    - rot_num: number of entries per subject (one per rotation)
    - split: {dataset: (start, end)} subject ranges into `outputs`
    """
    hparam_log_dict = {}

    metrics = outputs[0].keys()
    datasets = split.keys()

    for dataset in datasets:
        for metric in metrics:
            keyword = f"{dataset}/{metric}"
            if keyword not in hparam_log_dict:
                hparam_log_dict[keyword] = 0
            for idx in range(split[dataset][0] * rot_num, split[dataset][1] * rot_num):
                hparam_log_dict[keyword] += outputs[idx][metric].item()
            hparam_log_dict[keyword] /= (split[dataset][1] - split[dataset][0]) * rot_num

    print(colored(hparam_log_dict, "green"))

    return hparam_log_dict
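
# Usage sketch (illustrative, names are hypothetical):
#
#   outputs = [{"chamfer": torch.tensor(0.5)}] * 6
#   accumulate(outputs, rot_num=3, split={"cape": (0, 2)})
#   # -> {"cape/chamfer": 0.5}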