import pickle as pkl
import numpy as np
import torchvision.models as models
from torchvision import transforms
import torch
from torch import nn
from torch.nn.parameter import Parameter
from kornia.geometry.subpix import dsnt  # kornia 0.4.0
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from stacked_hourglass.utils.evaluation import get_preds_soft
from stacked_hourglass import hg1, hg2, hg8
from lifting_to_3d.linear_model import LinearModelComplete, LinearModel
from lifting_to_3d.inn_model_for_shape import INNForShape
from lifting_to_3d.utils.geometry_utils import rot6d_to_rotmat, rotmat_to_rot6d
from smal_pytorch.smal_model.smal_torch_new import SMAL
from smal_pytorch.renderer.differentiable_renderer import SilhRenderer
from bps_2d.bps_for_segmentation import SegBPS
# from configs.SMAL_configs import SMAL_MODEL_DATA_PATH as SHAPE_PRIOR
from configs.SMAL_configs import SMAL_MODEL_CONFIG
from configs.SMAL_configs import MEAN_DOG_BONE_LENGTHS_NO_RED, VERTEX_IDS_TAIL
# for the graph cnn part (ground contact prediction)
from graph_networks.graphcmr.utils_mesh import Mesh
from graph_networks.graphcmr.graph_cnn_groundcontact_multistage import GraphCNNMS

class SmallLinear(nn.Module):
    def __init__(self, input_size=64, output_size=30, linear_size=128):
        super(SmallLinear, self).__init__()
        self.relu = nn.ReLU(inplace=True)
        self.w1 = nn.Linear(input_size, linear_size)
        self.w2 = nn.Linear(linear_size, linear_size)
        self.w3 = nn.Linear(linear_size, output_size)

    def forward(self, x):
        # plain 3-layer MLP: linear -> relu -> linear -> relu -> linear
        y = self.w1(x)
        y = self.relu(y)
        y = self.w2(y)
        y = self.relu(y)
        y = self.w3(y)
        return y
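
# Minimal usage sketch for SmallLinear (illustrative only, not called anywhere in the
# pipeline): the module is a plain 3-layer MLP mapping a latent vector of size input_size
# to output_size values. The batch size and sizes below are assumptions for the example.
def _example_small_linear():
    model = SmallLinear(input_size=64, output_size=30, linear_size=128)
    z = torch.randn(4, 64)      # batch of 4 latent vectors
    out = model(z)              # -> (4, 30)
    return out.shape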

class MyConv1d(nn.Module):
    def __init__(self, input_size=37, output_size=30, start=True):
        super(MyConv1d, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.start = start
        self.weight = Parameter(torch.ones((self.output_size)))
        self.bias = Parameter(torch.zeros((self.output_size)))

    def forward(self, x):
        # select the first (start=True) or last (start=False) output_size entries
        # and apply a learnable element-wise scale and bias
        if self.start:
            y = x[:, :self.output_size]
        else:
            y = x[:, -self.output_size:]
        y = y * self.weight[None, :] + self.bias[None, :]
        return y
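
# Minimal usage sketch for MyConv1d (illustrative only): despite its name, the module is
# not a convolution; it slices part of the input vector and rescales it element-wise.
# With the default parameters it is initialized as an identity mapping on that slice.
def _example_my_conv1d():
    model = MyConv1d(input_size=37, output_size=30, start=True)
    x = torch.randn(2, 37)
    y = model(x)                # -> (2, 30)
    return y.shape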

class ModelShapeAndBreed(nn.Module):
    def __init__(self, smal_model_type, n_betas=10, n_betas_limbs=13, n_breeds=121, n_z=512, structure_z_to_betas='default'):
        super(ModelShapeAndBreed, self).__init__()
        self.n_betas = n_betas
        self.n_betas_limbs = n_betas_limbs      # n_betas_logscale
        self.n_breeds = n_breeds
        self.structure_z_to_betas = structure_z_to_betas
        if self.structure_z_to_betas == '1dconv':
            if not (n_z == self.n_betas + self.n_betas_limbs):
                raise ValueError("for structure_z_to_betas='1dconv', n_z must equal n_betas + n_betas_limbs")
        self.smal_model_type = smal_model_type
        # shape branch
        self.resnet = models.resnet34(pretrained=False)
        # replace the first layer (input is the RGB image plus one segmentation channel)
        n_in = 3 + 1
        self.resnet.conv1 = nn.Conv2d(n_in, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # replace the last layer
        self.resnet.fc = nn.Linear(512, n_z)
        # softmax
        self.soft_max = torch.nn.Softmax(dim=1)
        # fc network (and other versions) to connect z with betas
        p_dropout = 0.2
        if self.structure_z_to_betas == 'default':
            self.linear_betas = LinearModel(linear_size=1024,
                                            num_stage=1,
                                            p_dropout=p_dropout,
                                            input_size=n_z,
                                            output_size=self.n_betas)
            self.linear_betas_limbs = LinearModel(linear_size=1024,
                                                  num_stage=1,
                                                  p_dropout=p_dropout,
                                                  input_size=n_z,
                                                  output_size=self.n_betas_limbs)
        elif self.structure_z_to_betas == 'lin':
            self.linear_betas = nn.Linear(n_z, self.n_betas)
            self.linear_betas_limbs = nn.Linear(n_z, self.n_betas_limbs)
        elif self.structure_z_to_betas == 'fc_0':
            self.linear_betas = SmallLinear(linear_size=128,
                                            input_size=n_z,
                                            output_size=self.n_betas)
            self.linear_betas_limbs = SmallLinear(linear_size=128,
                                                  input_size=n_z,
                                                  output_size=self.n_betas_limbs)
        elif structure_z_to_betas == 'fc_1':
            self.linear_betas = LinearModel(linear_size=64,
                                            num_stage=1,
                                            p_dropout=0,
                                            input_size=n_z,
                                            output_size=self.n_betas)
            self.linear_betas_limbs = LinearModel(linear_size=64,
                                                  num_stage=1,
                                                  p_dropout=0,
                                                  input_size=n_z,
                                                  output_size=self.n_betas_limbs)
        elif self.structure_z_to_betas == '1dconv':
            self.linear_betas = MyConv1d(n_z, self.n_betas, start=True)
            self.linear_betas_limbs = MyConv1d(n_z, self.n_betas_limbs, start=False)
        elif self.structure_z_to_betas == 'inn':
            self.linear_betas_and_betas_limbs = INNForShape(self.n_betas, self.n_betas_limbs, betas_scale=1.0, betas_limbs_scale=1.0)
        else:
            raise ValueError('unknown structure_z_to_betas: ' + str(self.structure_z_to_betas))
        # network to connect latent shape vector z with dog breed classification
        self.linear_breeds = LinearModel(linear_size=1024,
                                         num_stage=1,
                                         p_dropout=p_dropout,
                                         input_size=n_z,
                                         output_size=self.n_breeds)
        # shape multiplicator
        self.shape_multiplicator_np = np.ones(self.n_betas)
        with open(SMAL_MODEL_CONFIG[self.smal_model_type]['smal_model_data_path'], 'rb') as file:
            u = pkl._Unpickler(file)
            u.encoding = 'latin1'
            res = u.load()
        # shape predictions are centered around the mean dog of our dog model
        if 'dog_cluster_mean' in res.keys():
            self.betas_mean_np = res['dog_cluster_mean']
        else:
            assert res['cluster_means'].shape[0] == 1
            self.betas_mean_np = res['cluster_means'][0, :]

    def forward(self, img, seg_raw=None, seg_prep=None):
        # img is the network input image
        # seg_raw is the segmentation before softmax and before subtracting 0.5
        # seg_prep would be the prepared segmentation
        if seg_prep is None:
            seg_prep = self.soft_max(seg_raw)[:, 1:2, :, :] - 0.5
        input_img_and_seg = torch.cat((img, seg_prep), axis=1)
        res_output = self.resnet(input_img_and_seg)
        dog_breed_output = self.linear_breeds(res_output)
        if self.structure_z_to_betas == 'inn':
            shape_output_orig, shape_limbs_output_orig = self.linear_betas_and_betas_limbs(res_output)
        else:
            shape_output_orig = self.linear_betas(res_output) * 0.1
            betas_mean = torch.tensor(self.betas_mean_np).float().to(img.device)
            shape_output = shape_output_orig + betas_mean[None, 0:self.n_betas]
            shape_limbs_output_orig = self.linear_betas_limbs(res_output)
            shape_limbs_output = shape_limbs_output_orig * 0.1
        output_dict = {'z': res_output,
                       'breeds': dog_breed_output,
                       'betas': shape_output_orig,
                       'betas_limbs': shape_limbs_output_orig}
        return output_dict
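
# Minimal usage sketch for ModelShapeAndBreed (illustrative only; it assumes that the SMAL
# model data referenced in SMAL_MODEL_CONFIG is available on disk and that '39dogs_norm',
# as used further below for the graph CNN, is a valid smal_model_type key). The network
# consumes a normalized RGB image plus raw segmentation logits and returns the latent
# shape code, breed logits and the betas / betas_limbs predictions.
def _example_shape_and_breed():
    model = ModelShapeAndBreed(smal_model_type='39dogs_norm', n_betas=10,
                               n_betas_limbs=7, n_breeds=121, n_z=512)
    img = torch.randn(2, 3, 256, 256)        # input images (assumed size 256x256)
    seg_raw = torch.randn(2, 2, 256, 256)    # segmentation logits, class 1 = dog
    out = model(img, seg_raw=seg_raw)
    return out['betas'].shape, out['betas_limbs'].shape, out['breeds'].shape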

class LearnableShapedirs(nn.Module):
    def __init__(self, sym_ids_dict, shapedirs_init, n_betas, n_betas_fixed=10):
        super(LearnableShapedirs, self).__init__()
        # shapedirs_init = self.smal.shapedirs.detach()
        self.n_betas = n_betas
        self.n_betas_fixed = n_betas_fixed
        self.sym_ids_dict = sym_ids_dict
        sym_left_ids = self.sym_ids_dict['left']
        sym_right_ids = self.sym_ids_dict['right']
        sym_center_ids = self.sym_ids_dict['center']
        self.n_center = sym_center_ids.shape[0]
        self.n_left = sym_left_ids.shape[0]
        self.n_sd = self.n_betas - self.n_betas_fixed      # number of learnable shapedirs
        # get indices to go from half_shapedirs to shapedirs
        inds_back = np.zeros((3889))
        for ind in range(0, sym_center_ids.shape[0]):
            ind_in_forward = sym_center_ids[ind]
            inds_back[ind_in_forward] = ind
        for ind in range(0, sym_left_ids.shape[0]):
            ind_in_forward = sym_left_ids[ind]
            inds_back[ind_in_forward] = sym_center_ids.shape[0] + ind
        for ind in range(0, sym_right_ids.shape[0]):
            ind_in_forward = sym_right_ids[ind]
            inds_back[ind_in_forward] = sym_center_ids.shape[0] + sym_left_ids.shape[0] + ind
        self.register_buffer('inds_back_torch', torch.Tensor(inds_back).long())
        # self.smal.shapedirs: (51, 11667)
        # shapedirs: (3889, 3, n_sd)
        # shapedirs_half: (2012, 3, n_sd)
        sd = shapedirs_init[:self.n_betas, :].permute((1, 0)).reshape((-1, 3, self.n_betas))
        self.register_buffer('sd', sd)
        sd_center = sd[sym_center_ids, :, self.n_betas_fixed:]
        sd_left = sd[sym_left_ids, :, self.n_betas_fixed:]
        self.register_parameter('learnable_half_shapedirs_c0', torch.nn.Parameter(sd_center[:, 0, :].detach()))
        self.register_parameter('learnable_half_shapedirs_c2', torch.nn.Parameter(sd_center[:, 2, :].detach()))
        self.register_parameter('learnable_half_shapedirs_l0', torch.nn.Parameter(sd_left[:, 0, :].detach()))
        self.register_parameter('learnable_half_shapedirs_l1', torch.nn.Parameter(sd_left[:, 1, :].detach()))
        self.register_parameter('learnable_half_shapedirs_l2', torch.nn.Parameter(sd_left[:, 2, :].detach()))

    def forward(self):
        device = self.learnable_half_shapedirs_c0.device
        half_shapedirs_center = torch.stack((self.learnable_half_shapedirs_c0,
                                             torch.zeros((self.n_center, self.n_sd)).to(device),
                                             self.learnable_half_shapedirs_c2), axis=1)
        half_shapedirs_left = torch.stack((self.learnable_half_shapedirs_l0,
                                           self.learnable_half_shapedirs_l1,
                                           self.learnable_half_shapedirs_l2), axis=1)
        half_shapedirs_right = torch.stack((self.learnable_half_shapedirs_l0,
                                            - self.learnable_half_shapedirs_l1,
                                            self.learnable_half_shapedirs_l2), axis=1)
        half_shapedirs_tot = torch.cat((half_shapedirs_center, half_shapedirs_left, half_shapedirs_right))
        shapedirs = torch.index_select(half_shapedirs_tot, dim=0, index=self.inds_back_torch)
        shapedirs_complete = torch.cat((self.sd[:, :, :self.n_betas_fixed], shapedirs), axis=2)      # (3889, 3, n_betas)
        shapedirs_complete_prepared = torch.cat((self.sd[:, :, :10], shapedirs), axis=2).reshape((-1, 30)).permute((1, 0))      # (30, 11667)
        return shapedirs_complete, shapedirs_complete_prepared
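
# Minimal usage sketch for LearnableShapedirs (illustrative only; it assumes a SMAL
# instance is available, mirroring how ModelImageTo3d_withshape_withproj constructs it
# below): the first n_betas_fixed shape directions stay frozen, the remaining ones are
# learnable and are kept left/right symmetric via the mirrored half_shapedirs.
def _example_learnable_shapedirs(smal):
    model = LearnableShapedirs(smal.sym_ids_dict, smal.shapedirs.detach(), n_betas=30, n_betas_fixed=10)
    shapedirs_complete, shapedirs_prepared = model()
    return shapedirs_complete.shape, shapedirs_prepared.shape      # (3889, 3, 30), (30, 11667)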

class ModelRefinement(nn.Module):
    def __init__(self, n_betas=10, n_betas_limbs=7, n_breeds=121, n_keyp=20, n_joints=35, ref_net_type='add', graphcnn_type='inexistent', isflat_type='inexistent', shaperef_type='inexistent'):
        super(ModelRefinement, self).__init__()
        self.n_betas = n_betas
        self.n_betas_limbs = n_betas_limbs
        self.n_breeds = n_breeds
        self.n_keyp = n_keyp
        self.n_joints = n_joints
        self.n_out_seg = 256
        self.n_out_keyp = 256
        self.n_out_enc = 256
        self.linear_size = 1024
        self.linear_size_small = 128
        self.ref_net_type = ref_net_type
        self.graphcnn_type = graphcnn_type
        self.isflat_type = isflat_type
        self.shaperef_type = shaperef_type
        p_dropout = 0.2
        # --- segmentation encoder
        if self.ref_net_type in ['multrot_res34', 'multrot01all_res34']:
            self.ref_res = models.resnet34(pretrained=False)
        else:
            self.ref_res = models.resnet18(pretrained=False)
        # replace the first layer (two channels: hourglass segmentation and rendered silhouette)
        self.ref_res.conv1 = nn.Conv2d(2, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        # replace the last layer
        self.ref_res.fc = nn.Linear(512, self.n_out_seg)
        # softmax
        self.soft_max = torch.nn.Softmax(dim=1)
        # --- keypoint encoder
        self.linear_keyp = LinearModel(linear_size=self.linear_size,
                                       num_stage=1,
                                       p_dropout=p_dropout,
                                       input_size=n_keyp*2*2,
                                       output_size=self.n_out_keyp)
        # --- decoder
        self.linear_combined = LinearModel(linear_size=self.linear_size,
                                           num_stage=1,
                                           p_dropout=p_dropout,
                                           input_size=self.n_out_seg+self.n_out_keyp,
                                           output_size=self.n_out_enc)
        # output info
        pose = {'name': 'pose', 'n': self.n_joints*6, 'out_shape': [self.n_joints, 6]}
        trans = {'name': 'trans_notnorm', 'n': 3}
        cam = {'name': 'flength_notnorm', 'n': 1}
        betas = {'name': 'betas', 'n': self.n_betas}
        betas_limbs = {'name': 'betas_limbs', 'n': self.n_betas_limbs}
        if self.shaperef_type == 'inexistent':
            self.output_info = [pose, trans, cam]
        else:
            self.output_info = [pose, trans, cam, betas, betas_limbs]
        # output branches
        self.output_info_linear_models = []
        for ind_el, element in enumerate(self.output_info):
            n_in = self.n_out_enc + element['n']
            self.output_info_linear_models.append(LinearModel(linear_size=self.linear_size,
                                                              num_stage=1,
                                                              p_dropout=p_dropout,
                                                              input_size=n_in,
                                                              output_size=element['n']))
            element['linear_model_index'] = ind_el
        self.output_info_linear_models = nn.ModuleList(self.output_info_linear_models)
        # new: predict if the ground is flat
        if not self.isflat_type == 'inexistent':
            self.linear_isflat = LinearModel(linear_size=self.linear_size_small,
                                             num_stage=1,
                                             p_dropout=p_dropout,
                                             input_size=self.n_out_enc,
                                             output_size=2)      # answer is just yes or no
        # new for ground contact prediction: graph cnn
        if not self.graphcnn_type == 'inexistent':
            num_downsampling = 1
            smal_model_type = '39dogs_norm'
            smal = SMAL(smal_model_type=smal_model_type, template_name='neutral')
            ROOT_smal_downsampling = os.path.join(os.path.dirname(__file__), './../../data/graphcmr_data/')
            smal_downsampling_npz_name = 'mesh_downsampling_' + os.path.basename(SMAL_MODEL_CONFIG[smal_model_type]['smal_model_path']).replace('.pkl', '_template.npz')
            smal_downsampling_npz_path = ROOT_smal_downsampling + smal_downsampling_npz_name
            self.my_custom_smal_dog_mesh = Mesh(filename=smal_downsampling_npz_path, num_downsampling=num_downsampling, nsize=1, body_model=smal)
            # create GraphCNN
            num_layers = 2      # <= len(my_custom_mesh._A)-1
            n_resnet_out = self.n_out_enc      # 256
            num_channels = 256
            self.graph_cnn = GraphCNNMS(mesh=self.my_custom_smal_dog_mesh,
                                        num_downsample=num_downsampling,
                                        num_layers=num_layers,
                                        n_resnet_out=n_resnet_out,
                                        num_channels=num_channels)

    def forward(self, keyp_sh, keyp_pred, in_pose_3x3, in_trans_notnorm, in_cam_notnorm, in_betas, in_betas_limbs, seg_pred_prep=None, seg_sh_raw=None, seg_sh_prep=None):
        # keyp_sh / seg_sh_* come from the stacked hourglass, keyp_pred / seg_pred_prep from the reprojected SMAL prediction
        # seg_sh_raw is before softmax and before subtracting 0.5
        # seg_sh_prep would be the prepared segmentation
        batch_size = in_pose_3x3.shape[0]
        device = in_pose_3x3.device
        dtype = in_pose_3x3.dtype
        # --- segmentation encoder
        if seg_sh_prep is None:
            seg_sh_prep = self.soft_max(seg_sh_raw)[:, 1:2, :, :] - 0.5      # class 1 is the dog
        input_seg_conc = torch.cat((seg_sh_prep, seg_pred_prep), axis=1)
        network_output_seg = self.ref_res(input_seg_conc)
        # --- keypoint encoder
        keyp_conc = torch.cat((keyp_sh.reshape((-1, keyp_sh.shape[1]*keyp_sh.shape[2])), keyp_pred.reshape((-1, keyp_sh.shape[1]*keyp_sh.shape[2]))), axis=1)
        network_output_keyp = self.linear_keyp(keyp_conc)
        # --- decoder
        x = torch.cat((network_output_seg, network_output_keyp), axis=1)
        y_comb = self.linear_combined(x)
        in_pose_6d = rotmat_to_rot6d(in_pose_3x3.reshape((-1, 3, 3))).reshape((in_pose_3x3.shape[0], -1, 6))
        in_dict = {'pose': in_pose_6d,
                   'trans_notnorm': in_trans_notnorm,
                   'flength_notnorm': in_cam_notnorm,
                   'betas': in_betas,
                   'betas_limbs': in_betas_limbs}
        results = {}
        for element in self.output_info:
            linear_model = self.output_info_linear_models[element['linear_model_index']]
            y = torch.cat((y_comb, in_dict[element['name']].reshape((-1, element['n']))), axis=1)
            if 'out_shape' in element.keys():
                if element['name'] == 'pose':
                    if self.ref_net_type in ['multrot', 'multrot01', 'multrot01all', 'multrotxx', 'multrot_res34', 'multrot01all_res34']:
                        # multiply the rotations with each other -> just predict a correction
                        # the correction should be initialized as identity
                        identity_rot6d = torch.tensor([1., 0., 0., 1., 0., 0.]).repeat((in_pose_3x3.shape[0]*in_pose_3x3.shape[1], 1)).to(device=device, dtype=dtype)
                        if self.ref_net_type in ['multrot01', 'multrot01all', 'multrot01all_res34']:
                            res_pose_out = identity_rot6d + 0.1*(linear_model(y)).reshape((-1, element['out_shape'][1]))
                        elif self.ref_net_type == 'multrotxx':
                            res_pose_out = identity_rot6d + 0.0*(linear_model(y)).reshape((-1, element['out_shape'][1]))
                        else:
                            res_pose_out = identity_rot6d + (linear_model(y)).reshape((-1, element['out_shape'][1]))
                        res_pose_rotmat = rot6d_to_rotmat(res_pose_out.reshape((-1, 6)))      # (bs*n_joints, 3, 3)
                        res_tot_rotmat = torch.bmm(res_pose_rotmat.reshape((-1, 3, 3)), in_pose_3x3.reshape((-1, 3, 3))).reshape((batch_size, -1, 3, 3))      # (bs, n_joints, 3, 3)
                        results['pose_rotmat'] = res_tot_rotmat
                    elif self.ref_net_type == 'add':
                        res_6d = (linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1])) + in_dict['pose']
                        results['pose_rotmat'] = rot6d_to_rotmat(res_6d.reshape((-1, 6))).reshape((batch_size, -1, 3, 3))
                    else:
                        raise ValueError('unknown ref_net_type: ' + str(self.ref_net_type))
                else:
                    if self.ref_net_type in ['multrot01all', 'multrot01all_res34']:
                        results[element['name']] = (0.1*linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1])) + in_dict[element['name']]
                    else:
                        results[element['name']] = (linear_model(y)).reshape((-1, element['out_shape'][0], element['out_shape'][1])) + in_dict[element['name']]
            else:
                if self.ref_net_type in ['multrot01all', 'multrot01all_res34']:
                    results[element['name']] = 0.1*linear_model(y) + in_dict[element['name']]
                else:
                    results[element['name']] = linear_model(y) + in_dict[element['name']]
        # add prediction if ground is flat
        if not self.isflat_type == 'inexistent':
            isflat = self.linear_isflat(y_comb)
            results['isflat'] = isflat
        # add graph cnn for vertex-wise ground contact
        if not self.graphcnn_type == 'inexistent':
            ground_contact_downsampled, ground_contact_all_stages_output = self.graph_cnn(y_comb)
            ground_contact = self.my_custom_smal_dog_mesh.upsample(ground_contact_downsampled.transpose(1, 2))
            results['vertexwise_ground_contact'] = ground_contact
        return results
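
# Minimal usage sketch for ModelRefinement (illustrative only; the batch size, image size
# and input statistics are assumptions, and ref_net_type='add' is just one of the
# supported variants): the refinement network combines the hourglass keypoints and
# segmentation with the reprojected SMAL prediction and outputs corrected pose,
# translation and focal-length parameters.
def _example_refinement():
    bs, n_keyp, n_joints, image_size = 2, 20, 35, 256
    model = ModelRefinement(ref_net_type='add')
    keyp_sh = torch.rand(bs, n_keyp, 2) * 2 - 1        # normalized keypoints from the hourglass
    keyp_pred = torch.rand(bs, n_keyp, 2) * 2 - 1      # reprojected keypoints from the SMAL fit
    in_pose_3x3 = torch.eye(3).repeat(bs, n_joints, 1, 1)
    results = model(keyp_sh, keyp_pred, in_pose_3x3,
                    in_trans_notnorm=torch.zeros(bs, 3),
                    in_cam_notnorm=torch.zeros(bs, 1),
                    in_betas=torch.zeros(bs, 10),
                    in_betas_limbs=torch.zeros(bs, 7),
                    seg_sh_raw=torch.randn(bs, 2, image_size, image_size),
                    seg_pred_prep=torch.rand(bs, 1, image_size, image_size) - 0.5)
    return results['pose_rotmat'].shape      # -> (bs, n_joints, 3, 3)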

class ModelImageToBreed(nn.Module):
    def __init__(self, smal_model_type, arch='hg8', n_joints=35, n_classes=20, n_partseg=15, n_keyp=20, n_bones=24, n_betas=10, n_betas_limbs=7, n_breeds=121, image_size=256, n_z=512, thr_keyp_sc=None, add_partseg=True):
        super(ModelImageToBreed, self).__init__()
        self.n_classes = n_classes
        self.n_partseg = n_partseg
        self.n_betas = n_betas
        self.n_betas_limbs = n_betas_limbs
        self.n_keyp = n_keyp
        self.n_bones = n_bones
        self.n_breeds = n_breeds
        self.image_size = image_size
        self.upsample_seg = True
        self.threshold_scores = thr_keyp_sc
        self.n_z = n_z
        self.add_partseg = add_partseg
        self.smal_model_type = smal_model_type
        # ------------------------------ STACKED HOUR GLASS ------------------------------
        if arch == 'hg8':
            self.stacked_hourglass = hg8(pretrained=False, num_classes=self.n_classes, num_partseg=self.n_partseg, upsample_seg=self.upsample_seg, add_partseg=self.add_partseg)
        else:
            raise Exception('unrecognised model architecture: ' + arch)
        # ------------------------------ SHAPE AND BREED MODEL ------------------------------
        self.breed_model = ModelShapeAndBreed(smal_model_type=self.smal_model_type, n_betas=self.n_betas, n_betas_limbs=self.n_betas_limbs, n_breeds=self.n_breeds, n_z=self.n_z)

    def forward(self, input_img, norm_dict=None, bone_lengths_prepared=None, betas=None):
        batch_size = input_img.shape[0]
        device = input_img.device
        # ------------------------------ STACKED HOUR GLASS ------------------------------
        hourglass_out_dict = self.stacked_hourglass(input_img)
        last_seg = hourglass_out_dict['seg_final']
        last_heatmap = hourglass_out_dict['out_list_kp'][-1]
        # - prepare keypoints (from heatmap)
        # normalize predictions -> from logits to probability distribution
        # last_heatmap_norm = dsnt.spatial_softmax2d(last_heatmap, temperature=torch.tensor(1))
        # keypoints = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=False) + 1      # (bs, 20, 2)
        # keypoints_norm = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=True)      # (bs, 20, 2)
        keypoints_norm, scores = get_preds_soft(last_heatmap, return_maxval=True, norm_coords=True)
        if self.threshold_scores is not None:
            scores[scores > self.threshold_scores] = 1.0
            scores[scores <= self.threshold_scores] = 0.0
        # ------------------------------ SHAPE AND BREED MODEL ------------------------------
        # breed_model takes as input the image as well as the predicted segmentation map
        # -> we need to split up ModelImageTo3d, such that we can use the silhouette
        resnet_output = self.breed_model(img=input_img, seg_raw=last_seg)
        pred_breed = resnet_output['breeds']      # (bs, n_breeds)
        pred_betas = resnet_output['betas']
        pred_betas_limbs = resnet_output['betas_limbs']
        small_output = {'keypoints_norm': keypoints_norm,
                        'keypoints_scores': scores}
        small_output_reproj = {'betas': pred_betas,
                               'betas_limbs': pred_betas_limbs,
                               'dog_breed': pred_breed}
        return small_output, None, small_output_reproj
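
# Minimal usage sketch for ModelImageToBreed (illustrative only; it assumes the SMAL model
# data referenced in SMAL_MODEL_CONFIG is available on disk and that '39dogs_norm' is a
# valid smal_model_type key). The model predicts 2d keypoints plus betas, betas_limbs and
# breed logits, but no 3d pose.
def _example_image_to_breed():
    model = ModelImageToBreed(smal_model_type='39dogs_norm')
    img = torch.randn(1, 3, 256, 256)      # assumed input size
    small_output, _, small_output_reproj = model(img)
    return small_output['keypoints_norm'].shape, small_output_reproj['dog_breed'].shape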

class ModelImageTo3d_withshape_withproj(nn.Module):
    def __init__(self, smal_model_type, smal_keyp_conf=None, arch='hg8', num_stage_comb=2, num_stage_heads=1,
                 num_stage_heads_pose=1, trans_sep=False, n_joints=35, n_classes=20, n_partseg=15, n_keyp=20,
                 n_bones=24, n_betas=10, n_betas_limbs=6, n_breeds=121, image_size=256, n_z=512, n_segbps=64*2,
                 thr_keyp_sc=None, add_z_to_3d_input=True, add_segbps_to_3d_input=False, add_partseg=True,
                 silh_no_tail=True, fix_flength=False, render_partseg=False, structure_z_to_betas='default',
                 structure_pose_net='default', nf_version=None, ref_net_type='add', ref_detach_shape=True,
                 graphcnn_type='inexistent', isflat_type='inexistent', shaperef_type='inexistent'):
        super(ModelImageTo3d_withshape_withproj, self).__init__()
        self.n_classes = n_classes
        self.n_partseg = n_partseg
        self.n_betas = n_betas
        self.n_betas_limbs = n_betas_limbs
        self.n_keyp = n_keyp
        self.n_joints = n_joints
        self.n_bones = n_bones
        self.n_breeds = n_breeds
        self.image_size = image_size
        self.threshold_scores = thr_keyp_sc
        self.upsample_seg = True
        self.silh_no_tail = silh_no_tail
        self.add_z_to_3d_input = add_z_to_3d_input
        self.add_segbps_to_3d_input = add_segbps_to_3d_input
        self.add_partseg = add_partseg
        self.ref_net_type = ref_net_type
        self.ref_detach_shape = ref_detach_shape
        self.graphcnn_type = graphcnn_type
        self.isflat_type = isflat_type
        self.shaperef_type = shaperef_type
        assert (not self.add_segbps_to_3d_input) or (not self.add_z_to_3d_input)
        self.n_z = n_z
        if add_segbps_to_3d_input:
            self.n_segbps = n_segbps      # 64
            self.segbps_model = SegBPS()
        else:
            self.n_segbps = 0
        self.fix_flength = fix_flength
        self.render_partseg = render_partseg
        self.structure_z_to_betas = structure_z_to_betas
        self.structure_pose_net = structure_pose_net
        assert self.structure_pose_net in ['default', 'vae', 'normflow']
        self.nf_version = nf_version
        self.smal_model_type = smal_model_type
        assert (smal_keyp_conf is not None)
        self.smal_keyp_conf = smal_keyp_conf
        self.register_buffer('betas_zeros', torch.zeros((1, self.n_betas)))
        self.register_buffer('mean_dog_bone_lengths', torch.tensor(MEAN_DOG_BONE_LENGTHS_NO_RED, dtype=torch.float32))
        p_dropout = 0.2      # 0.5
        # ------------------------------ SMAL MODEL ------------------------------
        self.smal = SMAL(smal_model_type=self.smal_model_type, template_name='neutral')
        print('SMAL model type: ' + self.smal.smal_model_type)
        # new for rendering without tail
        f_np = self.smal.faces.detach().cpu().numpy()
        self.f_no_tail_np = f_np[np.isin(f_np[:, :], VERTEX_IDS_TAIL).sum(axis=1) == 0, :]
        # in theory we could optimize for improved shapedirs, but we do not do that
        # -> would need to implement regularizations
        # -> there are better ways than changing the shapedirs
        self.model_learnable_shapedirs = LearnableShapedirs(self.smal.sym_ids_dict, self.smal.shapedirs.detach(), self.n_betas, 10)
        # ------------------------------ STACKED HOUR GLASS ------------------------------
        if arch == 'hg8':
            self.stacked_hourglass = hg8(pretrained=False, num_classes=self.n_classes, num_partseg=self.n_partseg, upsample_seg=self.upsample_seg, add_partseg=self.add_partseg)
        else:
            raise Exception('unrecognised model architecture: ' + arch)
        # ------------------------------ SHAPE AND BREED MODEL ------------------------------
        self.breed_model = ModelShapeAndBreed(self.smal_model_type, n_betas=self.n_betas, n_betas_limbs=self.n_betas_limbs, n_breeds=self.n_breeds, n_z=self.n_z, structure_z_to_betas=self.structure_z_to_betas)
        # ------------------------------ LINEAR 3D MODEL ------------------------------
        # 3d model -> from image to 3d parameters {2d keypoints from heatmap, pose, trans, flength}
        self.soft_max = torch.nn.Softmax(dim=1)
        input_size = self.n_keyp*3 + self.n_bones
        self.model_3d = LinearModelComplete(linear_size=1024,
                                            num_stage_comb=num_stage_comb,
                                            num_stage_heads=num_stage_heads,
                                            num_stage_heads_pose=num_stage_heads_pose,
                                            trans_sep=trans_sep,
                                            p_dropout=p_dropout,      # 0.5
                                            input_size=input_size,
                                            intermediate_size=1024,
                                            output_info=None,
                                            n_joints=self.n_joints,
                                            n_z=self.n_z,
                                            add_z_to_3d_input=self.add_z_to_3d_input,
                                            n_segbps=self.n_segbps,
                                            add_segbps_to_3d_input=self.add_segbps_to_3d_input,
                                            structure_pose_net=self.structure_pose_net,
                                            nf_version=self.nf_version)
        # ------------------------------ RENDERING ------------------------------
        self.silh_renderer = SilhRenderer(image_size)
        # ------------------------------ REFINEMENT -----------------------------
        self.refinement_model = ModelRefinement(n_betas=self.n_betas, n_betas_limbs=self.n_betas_limbs, n_breeds=self.n_breeds, n_keyp=self.n_keyp, n_joints=self.n_joints, ref_net_type=self.ref_net_type, graphcnn_type=self.graphcnn_type, isflat_type=self.isflat_type, shaperef_type=self.shaperef_type)

    def forward(self, input_img, norm_dict=None, bone_lengths_prepared=None, betas=None):
        batch_size = input_img.shape[0]
        device = input_img.device
        # ------------------------------ STACKED HOUR GLASS ------------------------------
        hourglass_out_dict = self.stacked_hourglass(input_img)
        last_seg = hourglass_out_dict['seg_final']
        last_heatmap = hourglass_out_dict['out_list_kp'][-1]
        # - prepare keypoints (from heatmap)
        # normalize predictions -> from logits to probability distribution
        # last_heatmap_norm = dsnt.spatial_softmax2d(last_heatmap, temperature=torch.tensor(1))
        # keypoints = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=False) + 1      # (bs, 20, 2)
        # keypoints_norm = dsnt.spatial_expectation2d(last_heatmap_norm, normalized_coordinates=True)      # (bs, 20, 2)
        keypoints_norm, scores = get_preds_soft(last_heatmap, return_maxval=True, norm_coords=True)
        if self.threshold_scores is not None:
            scores[scores > self.threshold_scores] = 1.0
            scores[scores <= self.threshold_scores] = 0.0
        # ------------------------------ LEARNABLE SHAPE MODEL ------------------------------
        # in our cvpr 2022 paper we do not change the shapedirs
        # learnable_sd_complete has shape (3889, 3, n_sd)
        # learnable_sd_complete_prepared has shape (n_sd, 11667)
        learnable_sd_complete, learnable_sd_complete_prepared = self.model_learnable_shapedirs()
        shapedirs_sel = learnable_sd_complete_prepared      # None
        # ------------------------------ SHAPE AND BREED MODEL ------------------------------
        # breed_model takes as input the image as well as the predicted segmentation map
        # -> we need to split up ModelImageTo3d, such that we can use the silhouette
        resnet_output = self.breed_model(img=input_img, seg_raw=last_seg)
        pred_breed = resnet_output['breeds']      # (bs, n_breeds)
        pred_z = resnet_output['z']
        # - prepare shape
        pred_betas = resnet_output['betas']
        pred_betas_limbs = resnet_output['betas_limbs']
        # - calculate bone lengths
        with torch.no_grad():
            use_mean_bone_lengths = False
            if use_mean_bone_lengths:
                bone_lengths_prepared = torch.cat(batch_size*[self.mean_dog_bone_lengths.reshape((1, -1))])
            else:
                assert (bone_lengths_prepared is None)
                bone_lengths_prepared = self.smal.caclulate_bone_lengths(pred_betas, pred_betas_limbs, shapedirs_sel=shapedirs_sel, short=True)
        # ------------------------------ LINEAR 3D MODEL ------------------------------
        # 3d model -> from image to 3d parameters {2d keypoints from heatmap, pose, trans, flength}
        # prepare input for 2d-to-3d network
        keypoints_prepared = torch.cat((keypoints_norm, scores), axis=2)
        if bone_lengths_prepared is None:
            bone_lengths_prepared = torch.cat(batch_size*[self.mean_dog_bone_lengths.reshape((1, -1))])
        # should we add silhouette to 3d input? should we add z?
        if self.add_segbps_to_3d_input:
            seg_raw = last_seg
            seg_prep_bps = self.soft_max(seg_raw)[:, 1, :, :]      # class 1 is the dog
            with torch.no_grad():
                seg_prep_np = seg_prep_bps.detach().cpu().numpy()
                bps_output_np = self.segbps_model.calculate_bps_points_batch(seg_prep_np)      # (bs, 64, 2)
                bps_output = torch.tensor(bps_output_np, dtype=torch.float32).to(device).reshape((batch_size, -1))
                bps_output_prep = bps_output * 2. - 1
            input_vec_keyp_bones = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1)
            input_vec = torch.cat((input_vec_keyp_bones, bps_output_prep), dim=1)
        elif self.add_z_to_3d_input:
            # we do not use this in our cvpr 2022 version
            input_vec_keyp_bones = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1)
            input_vec_additional = pred_z
            input_vec = torch.cat((input_vec_keyp_bones, input_vec_additional), dim=1)
        else:
            input_vec = torch.cat((keypoints_prepared.reshape((batch_size, -1)), bone_lengths_prepared), axis=1)
        # predict 3d parameters (those are normalized, we need to correct mean and std in a next step)
        output = self.model_3d(input_vec)
        # add predicted keypoints to the output dict
        output['keypoints_norm'] = keypoints_norm
        output['keypoints_scores'] = scores
        # add predicted segmentation to the output dict
        output['seg_hg'] = hourglass_out_dict['seg_final']
        # - denormalize 3d parameters -> so far predictions were normalized, now we denormalize them again
        pred_trans = output['trans'] * norm_dict['trans_std'][None, :] + norm_dict['trans_mean'][None, :]      # (bs, 3)
        if self.structure_pose_net == 'default':
            pred_pose_rot6d = output['pose'] + norm_dict['pose_rot6d_mean'][None, :]
        elif self.structure_pose_net == 'normflow':
            pose_rot6d_mean_zeros = torch.zeros_like(norm_dict['pose_rot6d_mean'][None, :])
            pose_rot6d_mean_zeros[:, 0, :] = norm_dict['pose_rot6d_mean'][None, 0, :]
            pred_pose_rot6d = output['pose'] + pose_rot6d_mean_zeros
        else:
            pose_rot6d_mean_zeros = torch.zeros_like(norm_dict['pose_rot6d_mean'][None, :])
            pose_rot6d_mean_zeros[:, 0, :] = norm_dict['pose_rot6d_mean'][None, 0, :]
            pred_pose_rot6d = output['pose'] + pose_rot6d_mean_zeros
        pred_pose_reshx33 = rot6d_to_rotmat(pred_pose_rot6d.reshape((-1, 6)))
        pred_pose = pred_pose_reshx33.reshape((batch_size, -1, 3, 3))
        pred_pose_rot6d = rotmat_to_rot6d(pred_pose_reshx33).reshape((batch_size, -1, 6))
        if self.fix_flength:
            output['flength'] = torch.zeros_like(output['flength'])
            pred_flength = torch.ones_like(output['flength'])*2100      # instead of norm_dict['flength_mean'][None, :]
        else:
            pred_flength_orig = output['flength'] * norm_dict['flength_std'][None, :] + norm_dict['flength_mean'][None, :]      # (bs, 1)
            pred_flength = pred_flength_orig.clone()      # torch.abs(pred_flength_orig)
            pred_flength[pred_flength_orig <= 0] = norm_dict['flength_mean'][None, :]
        # ------------------------------ RENDERING ------------------------------
        # get 3d model (SMAL)
        V, keyp_green_3d, _ = self.smal(beta=pred_betas, betas_limbs=pred_betas_limbs, pose=pred_pose, trans=pred_trans, get_skin=True, keyp_conf=self.smal_keyp_conf, shapedirs_sel=shapedirs_sel)
        keyp_3d = keyp_green_3d[:, :self.n_keyp, :]      # (bs, 20, 3)
        # render silhouette
        faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1))
        if not self.silh_no_tail:
            pred_silh_images, pred_keyp = self.silh_renderer(vertices=V,
                                                             points=keyp_3d, faces=faces_prep, focal_lengths=pred_flength)
        else:
            faces_no_tail_prep = torch.tensor(self.f_no_tail_np).to(device).expand((batch_size, -1, -1))
            pred_silh_images, pred_keyp = self.silh_renderer(vertices=V,
                                                             points=keyp_3d, faces=faces_no_tail_prep, focal_lengths=pred_flength)
        # get torch 'Meshes'
        torch_meshes = self.silh_renderer.get_torch_meshes(vertices=V, faces=faces_prep)
        # render body parts (not part of cvpr 2022 version)
        if self.render_partseg:
            raise NotImplementedError
        else:
            partseg_images = None
            partseg_images_hg = None
        # ------------------------------ REFINEMENT MODEL ------------------------------
        # refinement model
        pred_keyp_norm = (pred_keyp.detach() / (self.image_size - 1) - 0.5)*2
        output_ref = self.refinement_model(keypoints_norm.detach(), pred_keyp_norm,
                                           seg_sh_raw=last_seg.detach(), seg_pred_prep=pred_silh_images.detach()-0.5,
                                           in_pose_3x3=pred_pose.detach(), in_trans_notnorm=output['trans'].detach(),
                                           in_cam_notnorm=output['flength'].detach(), in_betas=pred_betas.detach(),
                                           in_betas_limbs=pred_betas_limbs.detach())
        # a better alternative would be to submit pred_pose_reshx33
        # nothing changes for betas or shapedirs or z (they should probably not be detached in the end)
        if self.shaperef_type == 'inexistent':
            if self.ref_detach_shape:
                output_ref['betas'] = pred_betas.detach()
                output_ref['betas_limbs'] = pred_betas_limbs.detach()
                output_ref['z'] = pred_z.detach()
                output_ref['shapedirs'] = shapedirs_sel.detach()
            else:
                output_ref['betas'] = pred_betas
                output_ref['betas_limbs'] = pred_betas_limbs
                output_ref['z'] = pred_z
                output_ref['shapedirs'] = shapedirs_sel
        else:
            assert ('betas' in output_ref.keys())
            assert ('betas_limbs' in output_ref.keys())
            output_ref['shapedirs'] = shapedirs_sel
        # we denormalize flength and trans, but pose is handled differently
        if self.fix_flength:
            output_ref['flength_notnorm'] = torch.zeros_like(output['flength'])
            ref_pred_flength = torch.ones_like(output['flength'])*2100      # instead of norm_dict['flength_mean'][None, :]
            raise ValueError('not sure if we want to have a fixed flength in refinement')
        else:
            ref_pred_flength_orig = output_ref['flength_notnorm'] * norm_dict['flength_std'][None, :] + norm_dict['flength_mean'][None, :]      # (bs, 1)
            ref_pred_flength = ref_pred_flength_orig.clone()      # torch.abs(pred_flength_orig)
            ref_pred_flength[ref_pred_flength_orig <= 0] = norm_dict['flength_mean'][None, :]
        ref_pred_trans = output_ref['trans_notnorm'] * norm_dict['trans_std'][None, :] + norm_dict['trans_mean'][None, :]      # (bs, 3)
        # ref_pred_pose_rot6d = output_ref['pose']
        # ref_pred_pose_reshx33 = rot6d_to_rotmat(output_ref['pose'].reshape((-1, 6))).reshape((batch_size, -1, 3, 3))
        ref_pred_pose_reshx33 = output_ref['pose_rotmat'].reshape((batch_size, -1, 3, 3))
        ref_pred_pose_rot6d = rotmat_to_rot6d(ref_pred_pose_reshx33.reshape((-1, 3, 3))).reshape((batch_size, -1, 6))
        ref_V, ref_keyp_green_3d, _ = self.smal(beta=output_ref['betas'], betas_limbs=output_ref['betas_limbs'],
                                                pose=ref_pred_pose_reshx33, trans=ref_pred_trans, get_skin=True, keyp_conf=self.smal_keyp_conf,
                                                shapedirs_sel=output_ref['shapedirs'])
        ref_keyp_3d = ref_keyp_green_3d[:, :self.n_keyp, :]      # (bs, 20, 3)
        if not self.silh_no_tail:
            faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1))
            ref_pred_silh_images, ref_pred_keyp = self.silh_renderer(vertices=ref_V,
                                                                     points=ref_keyp_3d, faces=faces_prep, focal_lengths=ref_pred_flength)
        else:
            faces_no_tail_prep = torch.tensor(self.f_no_tail_np).to(device).expand((batch_size, -1, -1))
            ref_pred_silh_images, ref_pred_keyp = self.silh_renderer(vertices=ref_V,
                                                                     points=ref_keyp_3d, faces=faces_no_tail_prep, focal_lengths=ref_pred_flength)
        output_ref_unnorm = {'vertices_smal': ref_V,
                             'keyp_3d': ref_keyp_3d,
                             'keyp_2d': ref_pred_keyp,
                             'silh': ref_pred_silh_images,
                             'trans': ref_pred_trans,
                             'flength': ref_pred_flength,
                             'betas': output_ref['betas'],
                             'betas_limbs': output_ref['betas_limbs'],
                             'pose_rot6d': ref_pred_pose_rot6d,
                             'pose_rotmat': ref_pred_pose_reshx33}
        if not self.graphcnn_type == 'inexistent':
            output_ref_unnorm['vertexwise_ground_contact'] = output_ref['vertexwise_ground_contact']
        if not self.isflat_type == 'inexistent':
            output_ref_unnorm['isflat'] = output_ref['isflat']
        if self.shaperef_type == 'inexistent':
            output_ref_unnorm['z'] = output_ref['z']
        # REMARK: we will want to have the predicted differences, for pose this would
        # be a rotation matrix, ...
        # -> TODO: adjust output_orig_ref_comparison
        output_orig_ref_comparison = {'old_pose_rotmat': pred_pose_reshx33,
                                      'old_trans_notnorm': output['trans'],
                                      'old_flength_notnorm': output['flength'],
                                      'ref_pose_rotmat': ref_pred_pose_reshx33,
                                      'ref_trans_notnorm': output_ref['trans_notnorm'],
                                      'ref_flength_notnorm': output_ref['flength_notnorm']}
        # ------------------------------ PREPARE OUTPUT ------------------------------
        # create output dictionaries
        #   output: contains all output from model_image_to_3d
        #   output_unnorm: same as output, but normalizations are undone
        #   output_reproj: smal output and reprojected keypoints as well as silhouette
        keypoints_heatmap_256 = (output['keypoints_norm'] / 2. + 0.5) * (self.image_size - 1)
        output_unnorm = {'pose_rotmat': pred_pose,
                         'flength': pred_flength,
                         'trans': pred_trans,
                         'keypoints': keypoints_heatmap_256}
        output_reproj = {'vertices_smal': V,
                         'torch_meshes': torch_meshes,
                         'keyp_3d': keyp_3d,
                         'keyp_2d': pred_keyp,
                         'silh': pred_silh_images,
                         'betas': pred_betas,
                         'betas_limbs': pred_betas_limbs,
                         'pose_rot6d': pred_pose_rot6d,      # used for the pose prior
                         'dog_breed': pred_breed,
                         'shapedirs': shapedirs_sel,
                         'z': pred_z,
                         'flength_unnorm': pred_flength,
                         'flength': output['flength'],
                         'partseg_images_rend': partseg_images,
                         'partseg_images_hg_nograd': partseg_images_hg,
                         'normflow_z': output['normflow_z']}
        return output, output_unnorm, output_reproj, output_ref_unnorm, output_orig_ref_comparison

    def forward_with_multiple_refinements(self, input_img, norm_dict=None, bone_lengths_prepared=None, betas=None):
        # run normal network part
        output, output_unnorm, output_reproj, output_ref_unnorm, output_orig_ref_comparison = self.forward(input_img, norm_dict=norm_dict, bone_lengths_prepared=bone_lengths_prepared, betas=betas)
        # prepare input for second refinement stage
        batch_size = output['keypoints_norm'].shape[0]
        device = input_img.device
        keypoints_norm = output['keypoints_norm']
        pred_keyp_norm = (output_ref_unnorm['keyp_2d'].detach() / (self.image_size - 1) - 0.5)*2
        last_seg = output['seg_hg']
        pred_silh_images = output_ref_unnorm['silh'].detach()
        trans_notnorm = output_orig_ref_comparison['ref_trans_notnorm']
        flength_notnorm = output_orig_ref_comparison['ref_flength_notnorm']
        pred_pose = output_ref_unnorm['pose_rotmat'].reshape((batch_size, -1, 3, 3))
        # run second refinement step
        output_ref_new = self.refinement_model(keypoints_norm.detach(), pred_keyp_norm,
                                               seg_sh_raw=last_seg.detach(), seg_pred_prep=pred_silh_images.detach()-0.5,
                                               in_pose_3x3=pred_pose.detach(), in_trans_notnorm=trans_notnorm.detach(), in_cam_notnorm=flength_notnorm.detach(),
                                               in_betas=output_ref_unnorm['betas'].detach(), in_betas_limbs=output_ref_unnorm['betas_limbs'].detach())
        # new shape
        if self.shaperef_type == 'inexistent':
            if self.ref_detach_shape:
                output_ref_new['betas'] = output_ref_unnorm['betas'].detach()
                output_ref_new['betas_limbs'] = output_ref_unnorm['betas_limbs'].detach()
                output_ref_new['z'] = output_ref_unnorm['z'].detach()
                output_ref_new['shapedirs'] = output_reproj['shapedirs'].detach()
            else:
                output_ref_new['betas'] = output_ref_unnorm['betas']
                output_ref_new['betas_limbs'] = output_ref_unnorm['betas_limbs']
                output_ref_new['z'] = output_ref_unnorm['z']
                output_ref_new['shapedirs'] = output_reproj['shapedirs']
        else:
            assert ('betas' in output_ref_new.keys())
            assert ('betas_limbs' in output_ref_new.keys())
            output_ref_new['shapedirs'] = output_reproj['shapedirs']
        # we denormalize flength and trans, but pose is handled differently
        if self.fix_flength:
            raise ValueError('not sure if we want to have a fixed flength in refinement')
        else:
            ref_pred_flength_orig = output_ref_new['flength_notnorm'] * norm_dict['flength_std'][None, :] + norm_dict['flength_mean'][None, :]      # (bs, 1)
            ref_pred_flength = ref_pred_flength_orig.clone()      # torch.abs(pred_flength_orig)
            ref_pred_flength[ref_pred_flength_orig <= 0] = norm_dict['flength_mean'][None, :]
        ref_pred_trans = output_ref_new['trans_notnorm'] * norm_dict['trans_std'][None, :] + norm_dict['trans_mean'][None, :]      # (bs, 3)
        ref_pred_pose_reshx33 = output_ref_new['pose_rotmat'].reshape((batch_size, -1, 3, 3))
        ref_pred_pose_rot6d = rotmat_to_rot6d(ref_pred_pose_reshx33.reshape((-1, 3, 3))).reshape((batch_size, -1, 6))
        ref_V, ref_keyp_green_3d, _ = self.smal(beta=output_ref_new['betas'], betas_limbs=output_ref_new['betas_limbs'],
                                                pose=ref_pred_pose_reshx33, trans=ref_pred_trans, get_skin=True, keyp_conf=self.smal_keyp_conf,
                                                shapedirs_sel=output_ref_new['shapedirs'])
        ref_keyp_3d = ref_keyp_green_3d[:, :self.n_keyp, :]      # (bs, 20, 3)
        if not self.silh_no_tail:
            faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1))
            ref_pred_silh_images, ref_pred_keyp = self.silh_renderer(vertices=ref_V,
                                                                     points=ref_keyp_3d, faces=faces_prep, focal_lengths=ref_pred_flength)
        else:
            faces_no_tail_prep = torch.tensor(self.f_no_tail_np).to(device).expand((batch_size, -1, -1))
            ref_pred_silh_images, ref_pred_keyp = self.silh_renderer(vertices=ref_V,
                                                                     points=ref_keyp_3d, faces=faces_no_tail_prep, focal_lengths=ref_pred_flength)
        output_ref_unnorm_new = {'vertices_smal': ref_V,
                                 'keyp_3d': ref_keyp_3d,
                                 'keyp_2d': ref_pred_keyp,
                                 'silh': ref_pred_silh_images,
                                 'trans': ref_pred_trans,
                                 'flength': ref_pred_flength,
                                 'betas': output_ref_new['betas'],
                                 'betas_limbs': output_ref_new['betas_limbs'],
                                 'pose_rot6d': ref_pred_pose_rot6d,
                                 'pose_rotmat': ref_pred_pose_reshx33}
        if not self.graphcnn_type == 'inexistent':
            output_ref_unnorm_new['vertexwise_ground_contact'] = output_ref_new['vertexwise_ground_contact']
        if not self.isflat_type == 'inexistent':
            output_ref_unnorm_new['isflat'] = output_ref_new['isflat']
        if self.shaperef_type == 'inexistent':
            output_ref_unnorm_new['z'] = output_ref_new['z']
        output_orig_ref_comparison_new = {'ref_pose_rotmat': ref_pred_pose_reshx33,
                                          'ref_trans_notnorm': output_ref_new['trans_notnorm'],
                                          'ref_flength_notnorm': output_ref_new['flength_notnorm']}
        results = {'output': output,
                   'output_unnorm': output_unnorm,
                   'output_reproj': output_reproj,
                   'output_ref_unnorm': output_ref_unnorm,
                   'output_orig_ref_comparison': output_orig_ref_comparison,
                   'output_ref_unnorm_new': output_ref_unnorm_new,
                   'output_orig_ref_comparison_new': output_orig_ref_comparison_new}
        return results

    def render_vis_nograd(self, vertices, focal_lengths, color=0):
        # this function is for visualization only
        # vertices: (bs, n_verts, 3)
        # focal_lengths: (bs, 1)
        # color: integer, either 0 or 1
        # returns a torch tensor of shape (bs, image_size, image_size, 3)
        with torch.no_grad():
            batch_size = vertices.shape[0]
            faces_prep = self.smal.faces.unsqueeze(0).expand((batch_size, -1, -1))
            visualizations = self.silh_renderer.get_visualization_nograd(vertices,
                                                                         faces_prep, focal_lengths, color=color)
        return visualizations
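
# Minimal end-to-end usage sketch (illustrative only). It assumes that the SMAL model
# data and the stacked hourglass dependencies are available, that '39dogs_norm' and the
# provided smal_keyp_conf are valid configuration values, and that norm_dict contains the
# keys used in forward() ('trans_mean', 'trans_std', 'flength_mean', 'flength_std',
# 'pose_rot6d_mean'); none of this is checked here.
def _example_full_model(norm_dict, smal_keyp_conf):
    model = ModelImageTo3d_withshape_withproj(smal_model_type='39dogs_norm',
                                              smal_keyp_conf=smal_keyp_conf)
    model.eval()
    img = torch.randn(1, 3, 256, 256)      # assumed input size
    with torch.no_grad():
        results = model.forward_with_multiple_refinements(img, norm_dict=norm_dict)
    return results['output_ref_unnorm_new']['vertices_smal'].shape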