import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.init as init
import torchvision.models as models
import nvdiffrast.torch as dr
import numpy as np
import matplotlib.pyplot as plt
import os
import os.path as osp
import pickle
from video3d.render.regularizer import get_edge_length, normal_consistency, laplace_regularizer_const
from . import networks
from .renderer import *
from .utils import misc, meters, flow_viz, arap, custom_loss
from .dataloaders import get_sequence_loader, get_image_loader
from .dataloaders_ddp import get_sequence_loader_ddp, get_image_loader_ddp
from .cub_dataloaders import get_cub_loader
from .cub_dataloaders_ddp import get_cub_loader_ddp
from .utils.skinning_v4 import estimate_bones, skinning
import lpips
from einops import rearrange, repeat
# import clip
import torchvision.transforms.functional as tvf
from . import discriminator_architecture
from .geometry.dmtet import DMTetGeometry
from .geometry.dlmesh import DLMesh
from .triplane_texture.triplane_predictor import TriPlaneTex
from .render import renderutils as ru
from .render import material
from .render import mlptexture
from .render import util
from .render import mesh
from .render import light
from .render import render

EPS = 1e-7


def get_optimizer(model, lr=0.0001, betas=(0.9, 0.999), weight_decay=0):
    return torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=lr, betas=betas, weight_decay=weight_decay)
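
# Usage sketch (hypothetical toy model; comments only, nothing executes on import):
#   net = nn.Linear(8, 2)
#   optim = get_optimizer(net, lr=1e-4, weight_decay=1e-5)
# Only parameters with requires_grad=True are handed to Adam, so submodules
# frozen via set_requires_grad(module, False) below are skipped automatically.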

def set_requires_grad(model, requires_grad):
    if model is not None:
        for param in model.parameters():
            param.requires_grad = requires_grad

def forward_to_matrix(vec_forward, up=[0, 1, 0]):
    up = torch.FloatTensor(up).to(vec_forward.device)
    # vec_forward = nn.functional.normalize(vec_forward, p=2, dim=-1)  # x right, y up, z forward
    vec_right = up.expand_as(vec_forward).cross(vec_forward, dim=-1)
    vec_right = nn.functional.normalize(vec_right, p=2, dim=-1)
    vec_up = vec_forward.cross(vec_right, dim=-1)
    vec_up = nn.functional.normalize(vec_up, p=2, dim=-1)
    rot_mat = torch.stack([vec_right, vec_up, vec_forward], -2)
    return rot_mat
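
# Usage sketch (assumed shapes): given N normalized forward vectors, build N
# rotation matrices whose rows are (right, up, forward):
#   vec_forward = nn.functional.normalize(torch.randn(4, 3), p=2, dim=-1)
#   R = forward_to_matrix(vec_forward)  # shape (4, 3, 3)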

def sample_pose_hypothesis_from_quad_prediction(poses_raw, total_iter, batch_size, num_frames, pose_xflip_recon=False,
                                                input_image_xflip_flag=None, rot_temp_scalar=1., num_hypos=4,
                                                naive_probs_iter=2000, best_pose_start_iter=6000, random_sample=True,
                                                temp_clip_low=1., temp_clip_high=100.):
    rots_pred = poses_raw[..., :num_hypos*4].view(-1, num_hypos, 4)
    rots_logits = rots_pred[..., 0]  # N x K
    # temp = 1 / np.clip(total_iter / 1000 / rot_temp_scalar, 1., 100.)
    temp = 1 / np.clip(total_iter / 1000 / rot_temp_scalar, temp_clip_low, temp_clip_high)
    rots_probs = torch.nn.functional.softmax(-rots_logits / temp, dim=1)  # N x K
    # naive_probs = torch.FloatTensor([10] + [1] * (num_hypos - 1)).to(rots_logits.device)
    naive_probs = torch.ones(num_hypos).to(rots_logits.device)
    naive_probs = naive_probs / naive_probs.sum()
    naive_probs_weight = np.clip(1 - (total_iter - naive_probs_iter) / 2000, 0, 1)
    rots_probs = naive_probs.view(1, num_hypos) * naive_probs_weight + rots_probs * (1 - naive_probs_weight)
    rots_pred = rots_pred[..., 1:4]
    trans_pred = poses_raw[..., -3:]
    best_rot_idx = torch.argmax(rots_probs, dim=1)  # N
    if random_sample:
        # rand_rot_idx = torch.randint(0, 4, (batch_size * num_frames,), device=poses_raw.device)  # N
        rand_rot_idx = torch.randperm(batch_size * num_frames, device=poses_raw.device) % num_hypos  # N
        # rand_rot_idx = torch.randperm(batch_size, device=poses_raw.device)[:, None].repeat(1, num_frames).view(-1) % 4  # N
        best_flag = (torch.randperm(batch_size * num_frames, device=poses_raw.device) / (batch_size * num_frames) < np.clip((total_iter - best_pose_start_iter) / 2000, 0, 0.8)).long()
        rand_flag = 1 - best_flag
        # best_flag = torch.zeros_like(best_rot_idx)
        rot_idx = best_rot_idx * best_flag + rand_rot_idx * (1 - best_flag)
    else:
        rand_flag = torch.zeros_like(best_rot_idx)
        # rot_idx = torch.full_like(torch.argmax(rots_probs, dim=1), torch.argmax(rots_probs), device=poses_raw.device)
        rot_idx = best_rot_idx
    rot_pred = torch.gather(rots_pred, 1, rot_idx[:, None, None].expand(-1, 1, 3))[:, 0]  # Nx3
    pose_raw = torch.cat([rot_pred, trans_pred], -1)
    rot_prob = torch.gather(rots_probs, 1, rot_idx[:, None].expand(-1, 1))[:, 0]  # N
    rot_logit = torch.gather(rots_logits, 1, rot_idx[:, None].expand(-1, 1))[:, 0]  # N
    if pose_xflip_recon:
        raise NotImplementedError
    # up = torch.FloatTensor([0, 1, 0]).to(pose_raw.device)
    rot_mat = forward_to_matrix(pose_raw[:, :3], up=[0, 1, 0])
    pose = torch.cat([rot_mat.view(batch_size * num_frames, -1), pose_raw[:, 3:]], -1)
    return pose_raw, pose, rot_idx, rot_prob, rot_logit, rots_probs, rand_flag
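
# Shape sketch (assumed sizes, for orientation only): with B=2 sequences, F=4
# frames and K=4 hypotheses, poses_raw is (B*F, 4*K + 3); the function returns
#   pose_raw (B*F, 6)   -- sampled forward vector + translation
#   pose     (B*F, 12)  -- flattened 3x3 rotation + translation
# plus the sampled hypothesis index, its probability and logit, the full K-way
# probabilities, and a flag marking which samples were drawn at random.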

def get_joints_20_bones(bones, aux):
    # the bones shape is [1, 1, 20, 2, 3]
    body_bones_to_joints = aux['bones_to_joints']
    body_bones = bones[:, :, :len(body_bones_to_joints), :, :]
    body_joints = torch.empty(bones.shape[0], bones.shape[1], len(body_bones_to_joints) + 1, 3)
    for i, (a, b) in enumerate(body_bones_to_joints):
        body_joints[:, :, a, :] = body_bones[:, :, i, 0, :]
        body_joints[:, :, b, :] = body_bones[:, :, i, 1, :]
    leg_aux = aux['legs']
    all_leg_joints = []
    for i in range(len(leg_aux)):
        leg_bones = bones[:, :, 8+i*3:11+i*3, :, :]
        leg_joints = torch.empty(bones.shape[0], bones.shape[1], len(leg_aux[i]['leg_bones_to_joints']), 3)
        for j in range(len(leg_aux[i]['leg_bones_to_joints'])-1):
            leg_joint_idx_a = leg_aux[i]['leg_bones_to_joints'][j][0]
            leg_joint_idx_b = leg_aux[i]['leg_bones_to_joints'][j][1]
            leg_joints[:, :, leg_joint_idx_a, :] = leg_bones[:, :, j, 0, :]
            leg_joints[:, :, leg_joint_idx_b, :] = leg_bones[:, :, j, 1, :]
        all_leg_joints.append(leg_joints)
    all_joints = [body_joints] + all_leg_joints
    all_joints = torch.cat(all_joints, dim=2)
    return all_joints

def get_20_bones_joints(joints, aux):
    # the joints shape is [1, 1, 21, 3]
    body_bones_to_joints = aux['bones_to_joints']
    body_bones = []
    for a, b in body_bones_to_joints:
        body_bones += [torch.stack([joints[:, :, a, :], joints[:, :, b, :]], dim=2)]
    body_bones = torch.stack(body_bones, dim=2)  # [1, 1, 8, 2, 3]
    legs_bones = []
    legs_aux = aux['legs']
    for i in range(len(legs_aux)):
        leg_aux = legs_aux[i]
        leg_bones = []
        leg_bones_to_joints = leg_aux['leg_bones_to_joints']
        for j in range(len(leg_bones_to_joints)-1):
            leg_bones += [torch.stack([joints[:, :, 9+i*3+leg_bones_to_joints[j][0], :], joints[:, :, 9+i*3+leg_bones_to_joints[j][1], :]], dim=2)]
        # the last bone is attached to the body
        leg_bones += [torch.stack([
            body_bones[:, :, leg_aux['body_bone_idx'], 1, :], joints[:, :, 9+i*3+leg_bones_to_joints[-1][1], :]
        ], dim=2)]
        leg_bones = torch.stack(leg_bones, dim=2)
        legs_bones.append(leg_bones)
    bones = torch.cat([body_bones] + legs_bones, dim=2)
    return bones
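
# Round-trip sketch (assumed layout: 8 body bones + 4 legs x 3 bones = 20 bones
# and 9 + 4*3 = 21 joints, with `aux` as produced by estimate_bones): the two
# helpers above are approximate inverses on this representation:
#   joints = get_joints_20_bones(bones, aux)      # (1, 1, 21, 3)
#   bones_rec = get_20_bones_joints(joints, aux)  # (1, 1, 20, 2, 3)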

class FixedDirectionLight(torch.nn.Module):
    def __init__(self, direction, amb, diff):
        super(FixedDirectionLight, self).__init__()
        self.light_dir = direction
        self.amb = amb
        self.diff = diff
        self.is_hacking = not isinstance(self.amb, (float, int))

    def forward(self, feat):
        batch_size = feat.shape[0]
        if self.is_hacking:
            return torch.concat([self.light_dir, self.amb, self.diff], -1)
        else:
            return torch.concat([self.light_dir, torch.FloatTensor([self.amb, self.diff]).to(self.light_dir.device)], -1).expand(batch_size, -1)

    def shade(self, feat, kd, normal):
        light_params = self.forward(feat)
        light_dir = light_params[..., :3][:, None, None, :]
        int_amb = light_params[..., 3:4][:, None, None, :]
        int_diff = light_params[..., 4:5][:, None, None, :]
        shading = int_amb + int_diff * torch.clamp(util.dot(light_dir, normal), min=0.0)
        shaded = shading * kd
        return shaded, shading
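
# Shading sketch (hypothetical tensors): FixedDirectionLight implements a
# Lambertian model, shaded = kd * (amb + diff * max(dot(l, n), 0)):
#   lit = FixedDirectionLight(torch.FloatTensor([0., 0., 1.]), amb=0.5, diff=0.5)
#   feat = torch.zeros(2, 1)  # only the batch size is read from feat
#   kd = torch.rand(2, 64, 64, 3)  # per-pixel albedo
#   normal = torch.nn.functional.normalize(torch.randn(2, 64, 64, 3), dim=-1)
#   shaded, shading = lit.shade(feat, kd, normal)  # (2, 64, 64, 3) each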

class SmoothLoss(nn.Module):
    def __init__(self, dim=0, smooth_type=None, loss_type="l2"):
        super(SmoothLoss, self).__init__()
        self.dim = dim
        # None is allowed and makes forward() a no-op, matching the None branch below
        supported_smooth_types = [None, 'mid_frame', 'dislocation', 'avg']
        assert smooth_type in supported_smooth_types, f"supported smooth type: {supported_smooth_types}"
        self.smooth_type = smooth_type
        supported_loss_types = ['l2', 'mse', 'l1']
        assert loss_type in supported_loss_types, f"supported loss type: {supported_loss_types}"
        self.loss_type = loss_type
        if self.loss_type in ['l2', 'mse']:
            self.loss_fn = torch.nn.MSELoss(reduction='mean')
        elif self.loss_type in ['l1']:
            self.loss_fn = torch.nn.L1Loss()
        else:
            raise NotImplementedError

    def mid_frame_smooth(self, inputs):
        # penalize deviation from the middle frame
        nframe = inputs.shape[self.dim]
        mid_num = (nframe-1) // 2
        mid_frame = torch.index_select(inputs, self.dim, torch.tensor([mid_num], device=inputs.device))
        repeat_num = self.get_repeat_num(inputs)
        smooth = mid_frame.repeat(repeat_num)
        loss = self.loss_fn(inputs, smooth)
        return loss

    def dislocation_smooth(self, inputs):
        # penalize the difference between consecutive frames
        nframe = inputs.shape[self.dim]
        t = torch.index_select(inputs, self.dim, torch.arange(0, nframe-1).to(inputs.device))
        t_1 = torch.index_select(inputs, self.dim, torch.arange(1, nframe).to(inputs.device))
        loss = self.loss_fn(t, t_1)
        return loss

    def avg_smooth(self, inputs):
        # penalize deviation from the temporal mean
        avg = inputs.mean(dim=self.dim, keepdim=True)
        repeat_num = self.get_repeat_num(inputs)
        smooth = avg.repeat(repeat_num)
        loss = self.loss_fn(inputs, smooth)
        return loss

    def get_repeat_num(self, inputs):
        repeat_num = [1] * inputs.dim()
        repeat_num[self.dim] = inputs.shape[self.dim]
        return repeat_num

    def forward(self, inputs):
        if self.smooth_type is None:
            return 0.
        elif self.smooth_type == 'mid_frame':
            return self.mid_frame_smooth(inputs)
        elif self.smooth_type == 'dislocation':
            return self.dislocation_smooth(inputs)
        elif self.smooth_type == 'avg':
            return self.avg_smooth(inputs)
        else:
            raise NotImplementedError()
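
# Usage sketch (assumed layout: frame axis at dim=1), e.g. to penalize
# frame-to-frame jitter of predicted articulation angles:
#   smooth_fn = SmoothLoss(dim=1, smooth_type='dislocation', loss_type='l2')
#   arti = torch.randn(2, 8, 20, 3)  # (batch, frames, bones, xyz angles)
#   loss = smooth_fn(arti)           # mean squared difference of consecutive frames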

class PriorPredictor(nn.Module):
    def __init__(self, cfgs):
        super().__init__()
        dmtet_grid = cfgs.get('dmtet_grid', 64)
        grid_scale = cfgs.get('grid_scale', 5)
        prior_sdf_mode = cfgs.get('prior_sdf_mode', 'mlp')
        num_layers_shape = cfgs.get('num_layers_shape', 5)
        hidden_size = cfgs.get('hidden_size', 64)
        embedder_freq_shape = cfgs.get('embedder_freq_shape', 8)
        embed_concat_pts = cfgs.get('embed_concat_pts', True)
        init_sdf = cfgs.get('init_sdf', None)
        jitter_grid = cfgs.get('jitter_grid', 0.)
        perturb_sdf_iter = cfgs.get('perturb_sdf_iter', 10000)
        sym_prior_shape = cfgs.get('sym_prior_shape', False)

        train_data_dir = cfgs.get("train_data_dir", None)
        if isinstance(train_data_dir, dict):
            # multi-category training: one class id per data directory
            self.category_id_map = {}
            num_of_classes = len(train_data_dir)
            for i, (k, _) in enumerate(train_data_dir.items()):
                self.category_id_map[k] = i
        else:
            # a single data directory (or None) means a single category
            num_of_classes = 1
        dim_of_classes = cfgs.get('dim_of_classes', 256) if num_of_classes > 1 else 0
        condition_choice = cfgs.get('prior_condition_choice', 'concat')

        self.netShape = DMTetGeometry(
            dmtet_grid, grid_scale, prior_sdf_mode, num_layers=num_layers_shape, hidden_size=hidden_size,
            embedder_freq=embedder_freq_shape, embed_concat_pts=embed_concat_pts, init_sdf=init_sdf,
            jitter_grid=jitter_grid, perturb_sdf_iter=perturb_sdf_iter, sym_prior_shape=sym_prior_shape,
            dim_of_classes=dim_of_classes, condition_choice=condition_choice)

        mlp_hidden_size = cfgs.get('hidden_size', 64)
        tet_bbox = self.netShape.getAABB()
        self.render_dino_mode = cfgs.get('render_dino_mode', None)
        num_layers_dino = cfgs.get("num_layers_dino", 5)
        dino_feature_recon_dim = cfgs.get('dino_feature_recon_dim', 64)
        sym_dino = cfgs.get("sym_dino", False)
        dino_min = torch.zeros(dino_feature_recon_dim) + cfgs.get('dino_min', 0.)
        dino_max = torch.zeros(dino_feature_recon_dim) + cfgs.get('dino_max', 1.)
        min_max = torch.stack((dino_min, dino_max), dim=0)
        if self.render_dino_mode is None:
            self.netDINO = None
        elif self.render_dino_mode == 'feature_mlpnv':
            # MLPTexture3D predicts the DINO feature for each sampled point
            self.netDINO = mlptexture.MLPTexture3D(tet_bbox, channels=dino_feature_recon_dim, internal_dims=mlp_hidden_size, hidden=num_layers_dino-1, feat_dim=0, min_max=min_max, bsdf=None, perturb_normal=False, symmetrize=sym_dino)
        elif self.render_dino_mode == 'feature_mlp':
            embedder_scaler = 2 * np.pi / grid_scale * 0.9  # originally (-0.5*s, 0.5*s), rescaled to (-pi, pi) * 0.9
            embed_concat_pts = cfgs.get('embed_concat_pts', True)
            self.netDINO = networks.MLPTextureSimple(
                3,  # x, y, z coordinates
                dino_feature_recon_dim,
                num_layers_dino,
                nf=mlp_hidden_size,
                dropout=0,
                activation="sigmoid",
                min_max=min_max,
                n_harmonic_functions=cfgs.get('embedder_freq_dino', 8),
                omega0=embedder_scaler,
                extra_dim=dim_of_classes,
                embed_concat_pts=embed_concat_pts,
                perturb_normal=False,
                symmetrize=sym_dino
            )
        elif self.render_dino_mode == 'cluster':
            num_layers_dino = cfgs.get("num_layers_dino", 5)
            dino_cluster_dim = cfgs.get('dino_cluster_dim', 64)
            self.netDINO = mlptexture.MLPTexture3D(tet_bbox, channels=dino_cluster_dim, internal_dims=mlp_hidden_size, hidden=num_layers_dino-1, feat_dim=0, min_max=None, bsdf=None, perturb_normal=False, symmetrize=sym_dino)
        else:
            raise NotImplementedError

        self.classes_vectors = None
        if num_of_classes > 1:
            self.classes_vectors = torch.nn.Parameter(torch.nn.init.uniform_(torch.empty(num_of_classes, dim_of_classes), a=-0.05, b=0.05))

    def forward(self, category_name=None, perturb_sdf=False, total_iter=None, is_training=True, class_embedding=None):
        class_vector = None
        return_classes_vectors = None
        if category_name is not None:
            if class_embedding is not None:
                class_vector = class_embedding[0]  # [128]
                return_classes_vectors = class_vector
            else:
                class_vector = self.classes_vectors[self.category_id_map[category_name]]
                return_classes_vectors = self.classes_vectors
        prior_shape = self.netShape.getMesh(perturb_sdf=perturb_sdf, total_iter=total_iter, jitter_grid=is_training, class_vector=class_vector)
        return prior_shape, self.netDINO, return_classes_vectors
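
# Usage sketch (hypothetical multi-category config where `train_data_dir` is a
# dict, so `classes_vectors` exists): the prior returns the DMTet mesh, the
# DINO feature field (None when render_dino_mode is None), and class vectors:
#   prior = PriorPredictor(cfgs)
#   prior_shape, netDINO, class_vecs = prior(category_name='horse', total_iter=0)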

class InstancePredictor(nn.Module):
    def __init__(self, cfgs, tet_bbox=None):
        super().__init__()
        self.cfgs = cfgs
        self.grid_scale = cfgs.get('grid_scale', 5)
        # hoisted out of the encoder block so the pose and articulation branches
        # below can rely on these even when the encoder is disabled
        self.dino_feature_input = cfgs.get('dino_feature_input', False)
        in_image_size = cfgs.get('in_image_size', 256)

        self.enable_encoder = cfgs.get('enable_encoder', False)
        if self.enable_encoder:
            encoder_latent_dim = cfgs.get('latent_dim', 256)
            encoder_pretrained = cfgs.get('encoder_pretrained', False)
            encoder_frozen = cfgs.get('encoder_frozen', False)
            encoder_arch = cfgs.get('encoder_arch', 'simple')
            dino_feature_dim = cfgs.get('dino_feature_dim', 64)
            if encoder_arch == 'simple':
                if self.dino_feature_input:
                    self.netEncoder = networks.EncoderWithDINO(cin_rgb=3, cin_dino=dino_feature_dim, cout=encoder_latent_dim, in_size=in_image_size, zdim=None, nf=64, activation=None)
                else:
                    self.netEncoder = networks.Encoder(cin=3, cout=encoder_latent_dim, in_size=in_image_size, zdim=None, nf=64, activation=None)
            elif encoder_arch == 'vgg':
                self.netEncoder = networks.VGGEncoder(cout=encoder_latent_dim, pretrained=encoder_pretrained)
            elif encoder_arch == 'resnet':
                self.netEncoder = networks.ResnetEncoder(cout=encoder_latent_dim, pretrained=encoder_pretrained)
            elif encoder_arch == 'vit':
                which_vit = cfgs.get('which_vit', 'dino_vits8')
                vit_final_layer_type = cfgs.get('vit_final_layer_type', 'conv')
                root_dir = cfgs.get('root_dir', '/root')
                self.netEncoder = networks.ViTEncoder(cout=encoder_latent_dim, which_vit=which_vit, pretrained=encoder_pretrained, frozen=encoder_frozen, in_size=in_image_size, final_layer_type=vit_final_layer_type, root=root_dir)
            else:
                raise NotImplementedError
        else:
            encoder_latent_dim = 0

        mlp_hidden_size = cfgs.get('hidden_size', 64)

        bsdf = cfgs.get("bsdf", 'diffuse')
        num_layers_tex = cfgs.get("num_layers_tex", 5)
        feat_dim = cfgs.get("latent_dim", 64) if self.enable_encoder else 0
        perturb_normal = cfgs.get("perturb_normal", False)
        sym_texture = cfgs.get("sym_texture", False)
        kd_min = torch.FloatTensor(cfgs.get('kd_min', [0., 0., 0., 0.]))
        kd_max = torch.FloatTensor(cfgs.get('kd_max', [1., 1., 1., 1.]))
        ks_min = torch.FloatTensor(cfgs.get('ks_min', [0., 0., 0.]))
        ks_max = torch.FloatTensor(cfgs.get('ks_max', [0., 0., 0.]))
        nrm_min = torch.FloatTensor(cfgs.get('nrm_min', [-1., -1., 0.]))
        nrm_max = torch.FloatTensor(cfgs.get('nrm_max', [1., 1., 1.]))
        mlp_min = torch.cat((kd_min[0:3], ks_min, nrm_min), dim=0)
        mlp_max = torch.cat((kd_max[0:3], ks_max, nrm_max), dim=0)
        min_max = torch.stack((mlp_min, mlp_max), dim=0)
        out_chn = 9  # 3 kd + 3 ks + 3 normal channels
        # TODO: if the tet verts are deforming, we need to recompute tet_bbox
        texture_mode = cfgs.get("texture_mode", 'mlp')
        # hoisted before the branch so self.texture_way exists for every texture_mode
        self.texture_way = cfgs.get('texture_way', None)
        if texture_mode == 'mlpnv':
            self.netTexture = mlptexture.MLPTexture3D(tet_bbox, channels=out_chn, internal_dims=mlp_hidden_size, hidden=num_layers_tex-1, feat_dim=feat_dim, min_max=min_max, bsdf=bsdf, perturb_normal=perturb_normal, symmetrize=sym_texture)
        elif texture_mode == 'mlp':
            embedder_scaler = 2 * np.pi / self.grid_scale * 0.9  # originally (-0.5*s, 0.5*s), rescaled to (-pi, pi) * 0.9
            embed_concat_pts = cfgs.get('embed_concat_pts', True)
            if self.texture_way is None:
                texture_act = cfgs.get('texture_act', 'relu')
                texture_bias = cfgs.get('texture_bias', False)
                self.netTexture = networks.MLPTextureSimple(
                    3,  # x, y, z coordinates
                    out_chn,
                    num_layers_tex,
                    nf=mlp_hidden_size,
                    dropout=0,
                    activation="sigmoid",
                    min_max=min_max,
                    n_harmonic_functions=cfgs.get('embedder_freq_tex', 10),
                    omega0=embedder_scaler,
                    extra_dim=feat_dim,
                    embed_concat_pts=embed_concat_pts,
                    perturb_normal=perturb_normal,
                    symmetrize=sym_texture,
                    texture_act=texture_act,
                    linear_bias=texture_bias
                )
            else:
                self.netTexture = networks.MLPTextureTriplane(
                    3,  # x, y, z coordinates
                    out_chn,
                    num_layers_tex,
                    nf=mlp_hidden_size,
                    dropout=0,
                    activation="sigmoid",
                    min_max=min_max,
                    n_harmonic_functions=cfgs.get('embedder_freq_tex', 10),
                    omega0=embedder_scaler,
                    extra_dim=feat_dim,
                    embed_concat_pts=embed_concat_pts,
                    perturb_normal=perturb_normal,
                    symmetrize=sym_texture,
                    texture_act='relu',
                    linear_bias=False,
                    cam_pos_z_offset=cfgs.get('cam_pos_z_offset', 10.),
                    grid_scale=self.grid_scale
                )
                # if 'lift' in self.texture_way:
                #     # GET3D uses a global feature to get a tri-plane
                #     self.netTexture = TriPlaneTex(
                #         w_dim=512,
                #         img_channels=out_chn,
                #         tri_plane_resolution=256,
                #         device=cfgs.get('device', 'cpu'),
                #         mlp_latent_channel=32,
                #         n_implicit_layer=1,
                #         feat_dim=256,
                #         n_mapping_layer=8,
                #         sym_texture=sym_texture,
                #         grid_scale=self.grid_scale,
                #         min_max=min_max,
                #         perturb_normal=perturb_normal
                #     )
                #     # # project the local feature map into a grid
                #     # self.netTexture = networks.LiftTexture(
                #     #     3,  # x, y, z coordinates
                #     #     out_chn,
                #     #     num_layers_tex,
                #     #     nf=mlp_hidden_size,
                #     #     dropout=0,
                #     #     activation="sigmoid",
                #     #     min_max=min_max,
                #     #     n_harmonic_functions=cfgs.get('embedder_freq_tex', 10),
                #     #     omega0=embedder_scaler,
                #     #     extra_dim=feat_dim,
                #     #     embed_concat_pts=embed_concat_pts,
                #     #     perturb_normal=perturb_normal,
                #     #     symmetrize=sym_texture,
                #     #     texture_way=self.texture_way,
                #     #     cam_pos_z_offset=cfgs.get('cam_pos_z_offset', 10.),
                #     #     grid_scale=self.grid_scale,
                #     #     local_feat_dim=cfgs.get("lift_local_feat_dim", 128),
                #     #     grid_size=cfgs.get("lift_grid_size", 32),
                #     #     optim_latent=cfgs.get("lift_optim_latent", False)
                #     # )
                # else:
                #     # a texture mlp with a local feature map from patch_out
                #     self.netTexture = networks.MLPTextureLocal(
                #         3,  # x, y, z coordinates
                #         out_chn,
                #         num_layers_tex,
                #         nf=mlp_hidden_size,
                #         dropout=0,
                #         activation="sigmoid",
                #         min_max=min_max,
                #         n_harmonic_functions=cfgs.get('embedder_freq_tex', 10),
                #         omega0=embedder_scaler,
                #         extra_dim=feat_dim,
                #         embed_concat_pts=embed_concat_pts,
                #         perturb_normal=perturb_normal,
                #         symmetrize=sym_texture,
                #         texture_way=self.texture_way,
                #         larger_tex_dim=cfgs.get('larger_tex_dim', False),
                #         cam_pos_z_offset=cfgs.get('cam_pos_z_offset', 10.),
                #         grid_scale=self.grid_scale
                #     )
        else:
            raise NotImplementedError
        self.rot_rep = cfgs.get('rot_rep', 'euler_angle')
        self.enable_pose = cfgs.get('enable_pose', False)
        if self.enable_pose:
            cam_pos_z_offset = cfgs.get('cam_pos_z_offset', 10.)
            fov = cfgs.get('crop_fov_approx', 25)
            half_range = np.tan(fov / 2 / 180 * np.pi) * cam_pos_z_offset  # ~2.22 for the defaults
            self.max_trans_xy_range = half_range * cfgs.get('max_trans_xy_range_ratio', 1.)
            self.max_trans_z_range = half_range * cfgs.get('max_trans_z_range_ratio', 1.)
            self.lookat_init = cfgs.get('lookat_init', None)
            self.lookat_zeroy = cfgs.get('lookat_zeroy', False)
            self.rot_temp_scalar = cfgs.get('rot_temp_scalar', 1.)
            self.naive_probs_iter = cfgs.get('naive_probs_iter', 2000)
            self.best_pose_start_iter = cfgs.get('best_pose_start_iter', 6000)

            if self.rot_rep == 'euler_angle':
                pose_cout = 6
            elif self.rot_rep == 'quaternion':
                pose_cout = 7
            elif self.rot_rep == 'lookat':
                pose_cout = 6
            elif self.rot_rep == 'quadlookat':
                self.num_pose_hypos = 4
                pose_cout = (3 + 1) * self.num_pose_hypos + 3  # 4 forward vectors for 4 quadrants, 4 quadrant classification logits, 3 for translation
                self.orthant_signs = torch.FloatTensor([[1, 1, 1], [-1, 1, 1], [-1, 1, -1], [1, 1, -1]])
            elif self.rot_rep == 'octlookat':
                self.num_pose_hypos = 8
                pose_cout = (3 + 1) * self.num_pose_hypos + 3  # 8 forward vectors for 8 octants, 8 octant classification logits, 3 for translation
                self.orthant_signs = torch.stack(torch.meshgrid([torch.arange(1, -2, -2)] * 3), -1).view(-1, 3)  # 8x3
            else:
                raise NotImplementedError

            # read which_vit from cfgs here, so the dino_patch pose branches work
            # even when the image encoder is not a ViT
            which_vit = cfgs.get('which_vit', 'dino_vits8')
            self.pose_arch = cfgs.get('pose_arch', 'mlp')
            if self.pose_arch == 'mlp':
                num_layers_pose = cfgs.get('num_layers_pose', 5)
                self.netPose = networks.MLP(
                    encoder_latent_dim,
                    pose_cout,
                    num_layers_pose,
                    nf=mlp_hidden_size,
                    dropout=0,
                    activation=None
                )
            elif self.pose_arch == 'encoder':
                if self.dino_feature_input:
                    dino_feature_dim = cfgs.get('dino_feature_dim', 64)
                    self.netPose = networks.EncoderWithDINO(cin_rgb=3, cin_dino=dino_feature_dim, cout=pose_cout, in_size=in_image_size, zdim=None, nf=64, activation=None)
                else:
                    self.netPose = networks.Encoder(cin=3, cout=pose_cout, in_size=in_image_size, zdim=None, nf=64, activation=None)
            elif self.pose_arch in ['encoder_dino_patch_out', 'encoder_dino_patch_key']:
                if which_vit in ['dino_vits8', 'dinov2_vits14']:
                    dino_feat_dim = 384
                elif which_vit == 'dino_vitb8':
                    dino_feat_dim = 768
                else:
                    raise NotImplementedError
                self.netPose = networks.Encoder32(cin=dino_feat_dim, cout=pose_cout, nf=256, activation=None)
            elif self.pose_arch == 'vit':
                encoder_pretrained = cfgs.get('encoder_pretrained', False)
                encoder_frozen = cfgs.get('encoder_frozen', False)
                vit_final_layer_type = cfgs.get('vit_final_layer_type', 'conv')
                root_dir = cfgs.get('root_dir', '/root')
                self.netPose = networks.ViTEncoder(cout=encoder_latent_dim, which_vit=which_vit, pretrained=encoder_pretrained, frozen=encoder_frozen, in_size=in_image_size, final_layer_type=vit_final_layer_type, root=root_dir)
            else:
                raise NotImplementedError
        self.enable_deform = cfgs.get('enable_deform', False)
        if self.enable_deform:
            embedder_scaler = 2 * np.pi / self.grid_scale * 0.9  # originally (-0.5*s, 0.5*s), rescaled to (-pi, pi) * 0.9
            embed_concat_pts = cfgs.get('embed_concat_pts', True)
            num_layers_deform = cfgs.get('num_layers_deform', 5)
            self.deform_epochs = np.arange(*cfgs.get('deform_epochs', [0, 0]))
            sym_deform = cfgs.get("sym_deform", False)
            self.netDeform = networks.MLPWithPositionalEncoding(
                3,  # x, y, z coordinates
                3,  # dx, dy, dz deformation
                num_layers_deform,
                nf=mlp_hidden_size,
                dropout=0,
                activation=None,
                n_harmonic_functions=cfgs.get('embedder_freq_deform', 10),
                omega0=embedder_scaler,
                extra_dim=encoder_latent_dim,
                embed_concat_pts=embed_concat_pts,
                symmetrize=sym_deform
            )
            # self.avg_deform = cfgs.get('avg_deform', False)

        self.enable_articulation = cfgs.get('enable_articulation', False)
        if self.enable_articulation:
            self.num_body_bones = cfgs.get('num_body_bones', 4)
            self.articulation_multiplier = cfgs.get('articulation_multiplier', 1)
            self.static_root_bones = cfgs.get('static_root_bones', False)
            self.skinning_temperature = cfgs.get('skinning_temperature', 1)
            self.articulation_epochs = np.arange(*cfgs.get('articulation_epochs', [0, 0]))
            self.num_legs = cfgs.get('num_legs', 0)
            self.num_leg_bones = cfgs.get('num_leg_bones', 0)
            self.body_bones_type = cfgs.get('body_bones_type', 'z_minmax')
            self.perturb_articulation_epochs = np.arange(*cfgs.get('perturb_articulation_epochs', [0, 0]))
            self.num_bones = self.num_body_bones + self.num_legs * self.num_leg_bones
            self.constrain_legs = cfgs.get('constrain_legs', False)
            self.attach_legs_to_body_epochs = np.arange(*cfgs.get('attach_legs_to_body_epochs', [0, 0]))
            self.max_arti_angle = cfgs.get('max_arti_angle', 60)

            num_layers_arti = cfgs.get('num_layers_arti', 5)
            which_vit = cfgs.get('which_vit', 'dino_vits8')
            if which_vit == 'dino_vits8':
                dino_feat_dim = 384
            elif which_vit == 'dino_vitb8':
                dino_feat_dim = 768
            self.articulation_arch = cfgs.get('articulation_arch', 'mlp')
            self.articulation_feature_mode = cfgs.get('articulation_feature_mode', 'sample')
            embedder_freq_arti = cfgs.get('embedder_freq_arti', 8)
            if self.articulation_feature_mode == 'global':
                feat_dim = encoder_latent_dim
            elif self.articulation_feature_mode == 'sample':
                feat_dim = dino_feat_dim
            elif self.articulation_feature_mode == 'sample+global':
                feat_dim = encoder_latent_dim + dino_feat_dim
            if self.articulation_feature_mode == 'attention':
                arti_feat_attn_zdim = cfgs.get('arti_feat_attn_zdim', 128)
                pos_dim = 1 + 2 + 3*2  # bone index + projected midpoint uv + two 3D endpoints
                self.netFeatureAttn = networks.FeatureAttention(which_vit, pos_dim, embedder_freq_arti, arti_feat_attn_zdim, img_size=in_image_size)
            embedder_scaler = np.pi * 0.9  # originally (-1, 1), rescaled to (-pi, pi) * 0.9
            enable_articulation_idadd = cfgs.get('enable_articulation_idadd', False)
            self.netArticulation = networks.ArticulationNetwork(
                self.articulation_arch, feat_dim, 1+2+3*2, num_layers_arti, mlp_hidden_size,
                n_harmonic_functions=embedder_freq_arti, omega0=embedder_scaler,
                enable_articulation_idadd=enable_articulation_idadd)
            self.kinematic_tree_epoch = -1

        self.enable_lighting = cfgs.get('enable_lighting', False)
        if self.enable_lighting:
            num_layers_light = cfgs.get('num_layers_light', 5)
            amb_diff_min = torch.FloatTensor(cfgs.get('amb_diff_min', [0., 0.]))
            amb_diff_max = torch.FloatTensor(cfgs.get('amb_diff_max', [1., 1.]))
            intensity_min_max = torch.stack((amb_diff_min, amb_diff_max), dim=0)
            self.netLight = light.DirectionalLight(encoder_latent_dim, num_layers_light, mlp_hidden_size, intensity_min_max=intensity_min_max)

        self.cam_pos_z_offset = cfgs.get('cam_pos_z_offset', 10.)
        self.crop_fov_approx = cfgs.get("crop_fov_approx", 25)
        self.temp_clip_low = cfgs.get('temp_clip_low', 1.)
        self.temp_clip_high = cfgs.get('temp_clip_high', 100.)
        # if the articulation and deformation schedules are given as iterations,
        # use the iteration, not the epoch, to decide
        self.iter_articulation_start = cfgs.get('iter_articulation_start', None)
        self.iter_deformation_start = cfgs.get('iter_deformation_start', None)
        self.iter_nozeroy_start = cfgs.get('iter_nozeroy_start', None)
        self.iter_attach_leg_to_body_start = cfgs.get('iter_attach_leg_to_body_start', None)

    def forward_encoder(self, images, dino_features=None):
        images_in = images.view(-1, *images.shape[2:]) * 2 - 1  # rescale to (-1, 1)
        # feat_key and the patch features are only produced by the ViT encoder;
        # initialize all outputs so every branch returns a complete tuple
        feat_out = feat_key = patch_out = patch_key = None
        if self.dino_feature_input and self.cfgs.get('encoder_arch', 'simple') != 'vit':
            dino_features_in = dino_features.view(-1, *dino_features.shape[2:]) * 2 - 1  # rescale to (-1, 1)
            feat_out = self.netEncoder(images_in, dino_features_in)  # Shape: (B, latent_dim)
        elif self.cfgs.get('encoder_arch', 'simple') == 'vit':
            feat_out, feat_key, patch_out, patch_key = self.netEncoder(images_in, return_patches=True)
        else:
            feat_out = self.netEncoder(images_in)  # Shape: (B, latent_dim)
        return feat_out, feat_key, patch_out, patch_key

    def forward_pose(self, images, feat, patch_out, patch_key, dino_features):
        if self.pose_arch == 'mlp':
            pose = self.netPose(feat)
        elif self.pose_arch == 'encoder':
            images_in = images.view(-1, *images.shape[2:]) * 2 - 1  # rescale to (-1, 1)
            if self.dino_feature_input:
                dino_features_in = dino_features.view(-1, *dino_features.shape[2:]) * 2 - 1  # rescale to (-1, 1)
                pose = self.netPose(images_in, dino_features_in)  # Shape: (B, pose_cout)
            else:
                pose = self.netPose(images_in)  # Shape: (B, pose_cout)
        elif self.pose_arch == 'vit':
            images_in = images.view(-1, *images.shape[2:]) * 2 - 1  # rescale to (-1, 1)
            pose = self.netPose(images_in)
        elif self.pose_arch == 'encoder_dino_patch_out':
            pose = self.netPose(patch_out)  # Shape: (B, pose_cout)
        elif self.pose_arch == 'encoder_dino_patch_key':
            pose = self.netPose(patch_key)  # Shape: (B, pose_cout)
        else:
            raise NotImplementedError
        trans_pred = pose[..., -3:].tanh() * torch.FloatTensor([self.max_trans_xy_range, self.max_trans_xy_range, self.max_trans_z_range]).to(pose.device)
        if self.rot_rep == 'euler_angle':
            multiplier = 1.
            # NOTE: legacy branch; it assumes `gradually_expand_yaw`, an `iteration`
            # counter and the max_rot_*_range attributes are provided externally
            if getattr(self, 'gradually_expand_yaw', False):
                # multiplier += (min(iteration, 20000) // 500) * 0.25
                multiplier *= 1.2 ** (min(iteration, 20000) // 500)  # 1.2**40 ~= 1470
            rot_pred = torch.cat([pose[..., :1], pose[..., 1:2] * multiplier, pose[..., 2:3]], -1).tanh()
            rot_pred = rot_pred * torch.FloatTensor([self.max_rot_x_range, self.max_rot_y_range, self.max_rot_z_range]).to(pose.device) / 180 * np.pi
        elif self.rot_rep == 'quaternion':
            quat_init = torch.FloatTensor([0.01, 0, 0, 0]).to(pose.device)
            rot_pred = pose[..., :4] + quat_init
            rot_pred = nn.functional.normalize(rot_pred, p=2, dim=-1)
            # rot_pred = torch.cat([rot_pred[..., :1].abs(), rot_pred[..., 1:]], -1)  # make real part non-negative
            rot_pred = rot_pred * rot_pred[..., :1].sign()  # make real part non-negative
        elif self.rot_rep == 'lookat':
            vec_forward_raw = pose[..., :3]
            if self.lookat_init is not None:
                vec_forward_raw = vec_forward_raw + torch.FloatTensor(self.lookat_init).to(pose.device)
            if self.lookat_zeroy:
                vec_forward_raw = vec_forward_raw * torch.FloatTensor([1, 0, 1]).to(pose.device)
            vec_forward_raw = nn.functional.normalize(vec_forward_raw, p=2, dim=-1)  # x right, y up, z forward
            rot_pred = vec_forward_raw
        elif self.rot_rep in ['quadlookat', 'octlookat']:
            rots_pred = pose[..., :self.num_pose_hypos*4].view(-1, self.num_pose_hypos, 4)  # (B*T, K, 4)
            rots_logits = rots_pred[..., :1]
            vec_forward_raw = rots_pred[..., 1:4]
            xs, ys, zs = vec_forward_raw.unbind(-1)
            margin = 0.
            xs = nn.functional.softplus(xs, beta=np.log(2)/(0.5+margin)) - margin  # initialize to 0.5
            if self.rot_rep == 'octlookat':
                ys = nn.functional.softplus(ys, beta=np.log(2)/(0.5+margin)) - margin  # initialize to 0.5
            if self.lookat_zeroy:
                ys = ys * 0
            zs = nn.functional.softplus(zs, beta=2*np.log(2))  # initialize to 0.5
            vec_forward_raw = torch.stack([xs, ys, zs], -1)
            vec_forward_raw = vec_forward_raw * self.orthant_signs.to(pose.device)
            vec_forward_raw = nn.functional.normalize(vec_forward_raw, p=2, dim=-1)  # x right, y up, z forward
            rot_pred = torch.cat([rots_logits, vec_forward_raw], -1).view(-1, self.num_pose_hypos*4)
        else:
            raise NotImplementedError
        pose = torch.cat([rot_pred, trans_pred], -1)
        return pose

    def forward_deformation(self, shape, feat=None, batch_size=None, num_frames=None):
        original_verts = shape.v_pos
        num_verts = original_verts.shape[1]
        if feat is not None:
            deform_feat = feat[:, None, :].repeat(1, num_verts, 1)  # Shape: (B, num_verts, latent_dim)
            original_verts = original_verts.repeat(len(feat), 1, 1)
        deformation = self.netDeform(original_verts, deform_feat) * 0.1  # Shape: (B, num_verts, 3)
        # if self.avg_deform:
        #     assert batch_size is not None and num_frames is not None
        #     assert deformation.shape[0] == batch_size * num_frames
        #     deformation = deformation.view(batch_size, num_frames, *deformation.shape[1:])
        #     deformation = deformation.mean(dim=1, keepdim=True)
        #     deformation = deformation.repeat(1, num_frames, *[1]*(deformation.dim()-2))
        #     deformation = deformation.view(batch_size*num_frames, *deformation.shape[2:])
        shape = shape.deform(deformation)
        return shape, deformation

    def forward_articulation(self, shape, feat, patch_feat, mvp, w2c, batch_size, num_frames, epoch, category, total_iter=None):
        """
        Forward propagation of articulation. For each bone, the network takes:
        1) the 3D location of the bone; 2) the feature of the patch that the bone
        projects onto; and 3) an encoding of the bone's index, and predicts the
        bone's rotation (represented as Euler angles).

        Args:
            shape: a Mesh object, whose v_pos has batch size B*F or 1.
            feat: the feature of the patches. Shape: (B*F, feat_dim, num_patches_per_axis, num_patches_per_axis)
            mvp: the model-view-projection matrix. Shape: (B*F, 4, 4)

        Returns:
            shape: a Mesh object, whose v_pos has batch size B*F (collapsed).
            articulation_angles: the predicted bone rotations. Shape: (B, F, num_bones, 3)
            aux: a dictionary containing auxiliary information.
        """
        verts = shape.v_pos
        if len(verts) == 1:
            verts = verts[None]
        else:
            verts = verts.view(batch_size, num_frames, *verts.shape[1:])

        if self.kinematic_tree_epoch != epoch:
            # if (epoch == self.articulation_epochs[0]) and (self.kinematic_tree_epoch != epoch):
            # if (epoch in [self.articulation_epochs[0], self.articulation_epochs[0]+2, self.articulation_epochs[0]+4]) and (self.kinematic_tree_epoch != epoch):
            if total_iter is not None and self.iter_attach_leg_to_body_start is not None:
                attach_legs_to_body = total_iter > self.iter_attach_leg_to_body_start
            else:
                attach_legs_to_body = epoch in self.attach_legs_to_body_epochs
            # bone_y_thresh = None if category is None or not category == "giraffe" else 0.1
            bone_y_thresh = self.cfgs.get('bone_y_thresh', None)

            body_bone_idx_preset_cfg = self.cfgs.get('body_bone_idx_preset', [0, 0, 0, 0])
            if isinstance(body_bone_idx_preset_cfg, list):
                body_bone_idx_preset = body_bone_idx_preset_cfg
            elif isinstance(body_bone_idx_preset_cfg, dict):
                # the dict is assumed to have two entries: key 0 (used from iteration 0)
                # and one switch-over iteration
                iter_point = list(body_bone_idx_preset_cfg.keys())[1]
                if total_iter <= iter_point:
                    body_bone_idx_preset = body_bone_idx_preset_cfg[0]
                else:
                    body_bone_idx_preset = body_bone_idx_preset_cfg[iter_point]
            else:
                raise NotImplementedError

            bones, self.kinematic_tree, self.bone_aux = estimate_bones(verts.detach(), self.num_body_bones, n_legs=self.num_legs, n_leg_bones=self.num_leg_bones, body_bones_type=self.body_bones_type, compute_kinematic_chain=True, attach_legs_to_body=attach_legs_to_body, bone_y_threshold=bone_y_thresh, body_bone_idx_preset=body_bone_idx_preset)
            # self.kinematic_tree_epoch = epoch
        else:
            bones = estimate_bones(verts.detach(), self.num_body_bones, n_legs=self.num_legs, n_leg_bones=self.num_leg_bones, body_bones_type=self.body_bones_type, compute_kinematic_chain=False, aux=self.bone_aux)

        bones_pos = bones  # Shape: (B, F, K, 2, 3)
        if batch_size > bones_pos.shape[0] or num_frames > bones_pos.shape[1]:
            assert bones_pos.shape[0] == 1 and bones_pos.shape[1] == 1, "If there is a mismatch, then there must be only one canonical mesh."
            bones_pos = bones_pos.repeat(batch_size, num_frames, 1, 1, 1)
        num_bones = bones_pos.shape[2]
        bones_pos = bones_pos.view(batch_size*num_frames, num_bones, 2, 3)  # NxKx2x3
        bones_mid_pos = bones_pos.mean(2)  # NxKx3
        bones_idx = torch.arange(num_bones).to(bones_pos.device)

        # project bone midpoints to image (uv) space
        bones_mid_pos_world4 = torch.cat([bones_mid_pos, torch.ones_like(bones_mid_pos[..., :1])], -1)  # NxKx4
        bones_mid_pos_clip4 = bones_mid_pos_world4 @ mvp.transpose(-1, -2)
        bones_mid_pos_uv = bones_mid_pos_clip4[..., :2] / bones_mid_pos_clip4[..., 3:4]
        bones_mid_pos_uv = bones_mid_pos_uv.detach()

        # bone endpoints in camera space, shifted by the camera z-offset
        bones_pos_world4 = torch.cat([bones_pos, torch.ones_like(bones_pos[..., :1])], -1)  # NxKx2x4
        bones_pos_cam4 = bones_pos_world4 @ w2c[:, None].transpose(-1, -2)
        bones_pos_cam3 = bones_pos_cam4[..., :3] / bones_pos_cam4[..., 3:4]
        bones_pos_cam3 = bones_pos_cam3 + torch.FloatTensor([0, 0, self.cam_pos_z_offset]).to(bones_pos_cam3.device).view(1, 1, 1, 3)
        bones_pos_in = bones_pos_cam3.view(batch_size*num_frames, num_bones, 2*3) / self.grid_scale * 2  # (-1, 1), NxKx(2*3)

        bones_idx_in = ((bones_idx[None, :, None] + 0.5) / num_bones * 2 - 1).repeat(batch_size * num_frames, 1, 1)  # (-1, 1)
        bones_pos_in = torch.cat([bones_mid_pos_uv, bones_pos_in, bones_idx_in], -1).detach()

        if self.articulation_feature_mode == 'global':
            bones_patch_features = feat[:, None].repeat(1, num_bones, 1)  # (B*F, K, feat_dim)
        elif self.articulation_feature_mode == 'sample':
            bones_patch_features = F.grid_sample(patch_feat, bones_mid_pos_uv.view(batch_size * num_frames, 1, -1, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1)  # (B*F, K, feat_dim)
        elif self.articulation_feature_mode == 'sample+global':
            bones_patch_features = F.grid_sample(patch_feat, bones_mid_pos_uv.view(batch_size * num_frames, 1, -1, 2), mode='bilinear').squeeze(dim=-2).permute(0, 2, 1)  # (B*F, K, feat_dim)
            bones_patch_features = torch.cat([feat[:, None].repeat(1, num_bones, 1), bones_patch_features], -1)
        elif self.articulation_feature_mode == 'attention':
            bones_patch_features = self.netFeatureAttn(bones_pos_in, patch_feat)
        else:
            raise NotImplementedError

        articulation_angles = self.netArticulation(bones_patch_features, bones_pos_in).view(batch_size, num_frames, num_bones, 3) * self.articulation_multiplier

        if self.static_root_bones:
            root_bones = [self.num_body_bones // 2 - 1, self.num_body_bones - 1]
            tmp_mask = torch.ones_like(articulation_angles)
            tmp_mask[:, :, root_bones] = 0
            articulation_angles = articulation_angles * tmp_mask
        articulation_angles = articulation_angles.tanh()

        if self.cfgs.get('iter_leg_rotation_start', -1) > 0:
            # keep the legs constrained until leg rotation is enabled
            self.constrain_legs = total_iter <= self.cfgs.get('iter_leg_rotation_start', -1)
        if self.constrain_legs:
            leg_bones_posx = [self.num_body_bones + i for i in range(self.num_leg_bones * self.num_legs // 2)]
            leg_bones_negx = [self.num_body_bones + self.num_leg_bones * self.num_legs // 2 + i for i in range(self.num_leg_bones * self.num_legs // 2)]

            tmp_mask = torch.zeros_like(articulation_angles)
            tmp_mask[:, :, leg_bones_posx + leg_bones_negx, 2] = 1
            articulation_angles = tmp_mask * (articulation_angles * 0.3) + (1 - tmp_mask) * articulation_angles  # no twist

            tmp_mask = torch.zeros_like(articulation_angles)
            tmp_mask[:, :, leg_bones_posx + leg_bones_negx, 1] = 1
            articulation_angles = tmp_mask * (articulation_angles * 0.3) + (1 - tmp_mask) * articulation_angles  # (-0.4, 0.4), limit side bending

        # additional regularization: the bottom two bones of each leg may only
        # rotate around the x-axis, and the topmost bone of each leg is
        # restricted to a smaller angle range
        if (self.cfgs.get('iter_leg_rotation_start', -1) > 0) and (total_iter > self.cfgs.get('iter_leg_rotation_start', -1)):
            if self.cfgs.get('forbid_leg_rotate', False):
                if self.cfgs.get('small_leg_angle', False):
                    # regularize the rotation angles of the topmost leg bones
                    leg_bones_top = [8, 11, 14, 17]
                    # leg_bones_top = [10, 13, 16, 19]
                    tmp_mask = torch.zeros_like(articulation_angles)
                    tmp_mask[:, :, leg_bones_top, 1] = 1
                    tmp_mask[:, :, leg_bones_top, 2] = 1
                    articulation_angles = tmp_mask * (articulation_angles * 0.05) + (1 - tmp_mask) * articulation_angles

                leg_bones_bottom = [9, 10, 12, 13, 15, 16, 18, 19]
                # leg_bones_bottom = [8, 9, 11, 12, 14, 15, 17, 18]
                tmp_mask = torch.ones_like(articulation_angles)
                tmp_mask[:, :, leg_bones_bottom, 1] = 0
                tmp_mask[:, :, leg_bones_bottom, 2] = 0
                # tmp_mask[:, :, leg_bones_bottom, 0] = 0.3
                articulation_angles = tmp_mask * articulation_angles

        if epoch in self.perturb_articulation_epochs:
            articulation_angles = articulation_angles + torch.randn_like(articulation_angles) * 0.1
        articulation_angles = articulation_angles * self.max_arti_angle / 180 * np.pi

        # optionally regularize the z-rotation of body bones: either only the
        # bones that connect to legs, or all body bones, scaled by
        # reg_body_rotate_mult (0.1 in pi-space by default)
        body_rotate_mult = self.cfgs.get('reg_body_rotate_mult', 0.1)
        body_rotate_mult = body_rotate_mult * 180 * 1.0 / (self.max_arti_angle * np.pi)  # the max angle = mult * original_max_angle
        body_rotate_reg_mode = self.cfgs.get('body_rotate_reg_mode', 'nothing')
        if body_rotate_reg_mode == 'leg-connect':
            body_bones_mask = [2, 3, 4, 5]
            tmp_body_mask = torch.zeros_like(articulation_angles)
            tmp_body_mask[:, :, body_bones_mask, 2] = 1
            articulation_angles = tmp_body_mask * (articulation_angles * body_rotate_mult) + (1 - tmp_body_mask) * articulation_angles
        elif body_rotate_reg_mode == 'all-bones':
            body_bones_mask = [0, 1, 2, 3, 4, 5, 6, 7]
            tmp_body_mask = torch.zeros_like(articulation_angles)
            tmp_body_mask[:, :, body_bones_mask, 2] = 1
            articulation_angles = tmp_body_mask * (articulation_angles * body_rotate_mult) + (1 - tmp_body_mask) * articulation_angles
        elif body_rotate_reg_mode == 'nothing':
            pass
        else:
            raise NotImplementedError

        verts_articulated, aux = skinning(verts, bones, self.kinematic_tree, articulation_angles, output_posed_bones=True, temperature=self.skinning_temperature)
        verts_articulated = verts_articulated.view(batch_size*num_frames, *verts_articulated.shape[2:])
        v_tex = shape.v_tex
        if len(v_tex) != len(verts_articulated):
            v_tex = v_tex.repeat(len(verts_articulated), 1, 1)
        shape = mesh.make_mesh(
            verts_articulated,
            shape.t_pos_idx,
            v_tex,
            shape.t_tex_idx,
            shape.material)
        return shape, articulation_angles, aux

    def get_camera_extrinsics_from_pose(self, pose, znear=0.1, zfar=1000., crop_fov_approx=None, offset_extra=None):
        if crop_fov_approx is None:
            crop_fov_approx = self.crop_fov_approx
        N = len(pose)
        if offset_extra is not None:
            cam_pos_offset = torch.FloatTensor([0, 0, -self.cam_pos_z_offset - offset_extra]).to(pose.device)
        else:
            cam_pos_offset = torch.FloatTensor([0, 0, -self.cam_pos_z_offset]).to(pose.device)
        pose_R = pose[:, :9].view(N, 3, 3).transpose(2, 1)
        pose_T = pose[:, -3:] + cam_pos_offset[None, None, :]
        pose_T = pose_T.view(N, 3, 1)
        pose_RT = torch.cat([pose_R, pose_T], axis=2)  # Nx3x4
        w2c = torch.cat([pose_RT, torch.FloatTensor([0, 0, 0, 1]).repeat(N, 1, 1).to(pose.device)], axis=1)  # Nx4x4
        # we assume the images are perfect squares
        if isinstance(crop_fov_approx, (float, int)):
            proj = util.perspective(crop_fov_approx / 180 * np.pi, 1, znear, zfar)[None].to(pose.device)
        elif isinstance(crop_fov_approx, torch.Tensor):
            proj = util.batched_perspective(crop_fov_approx / 180 * np.pi, 1, znear, zfar).to(pose.device)
        else:
            raise ValueError('crop_fov_approx must be float or torch.Tensor')
        mvp = torch.matmul(proj, w2c)
        campos = -torch.matmul(pose_R.transpose(2, 1), pose_T).view(N, 3)
        return mvp, w2c, campos
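
    # Convention sketch (assumed): `pose` packs a flattened 3x3 rotation (as
    # produced by forward_to_matrix above) followed by a 3-vector translation,
    # so with the quad-lookat path:
    #   pose_raw, pose, *_ = sample_pose_hypothesis_from_quad_prediction(poses_raw, total_iter, B, F)
    #   mvp, w2c, campos = self.get_camera_extrinsics_from_pose(pose)  # (B*F, 4, 4), (B*F, 4, 4), (B*F, 3)
    # mvp maps world-space homogeneous points to clip space: x_clip = mvp @ x_world.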

    def forward(self, category=None, images=None, prior_shape=None, epoch=None, dino_features=None, dino_clusters=None, total_iter=None, is_training=True):
        batch_size, num_frames = images.shape[:2]
        if self.enable_encoder:
            feat_out, feat_key, patch_out, patch_key = self.forward_encoder(images, dino_features)
        else:
            feat_out = feat_key = patch_out = patch_key = None
        shape = prior_shape
        texture = self.netTexture

        multi_hypothesis_aux = {}
        if self.iter_nozeroy_start is not None and total_iter >= self.iter_nozeroy_start:
            self.lookat_zeroy = False
        if self.enable_pose:
            poses_raw = self.forward_pose(images, feat_out, patch_out, patch_key, dino_features)
            pose_raw, pose, rot_idx, rot_prob, rot_logit, rots_probs, rand_pose_flag = sample_pose_hypothesis_from_quad_prediction(
                poses_raw, total_iter, batch_size, num_frames, rot_temp_scalar=self.rot_temp_scalar,
                num_hypos=self.num_pose_hypos, naive_probs_iter=self.naive_probs_iter,
                best_pose_start_iter=self.best_pose_start_iter, random_sample=is_training,
                temp_clip_low=self.temp_clip_low, temp_clip_high=self.temp_clip_high)
            multi_hypothesis_aux['rot_idx'] = rot_idx
            multi_hypothesis_aux['rot_prob'] = rot_prob
            multi_hypothesis_aux['rot_logit'] = rot_logit
            multi_hypothesis_aux['rots_probs'] = rots_probs
            multi_hypothesis_aux['rand_pose_flag'] = rand_pose_flag
        else:
            raise NotImplementedError
        mvp, w2c, campos = self.get_camera_extrinsics_from_pose(pose)

        deformation = None
        if self.iter_deformation_start is not None:
            if self.enable_deform and total_iter >= self.iter_deformation_start:
                shape, deformation = self.forward_deformation(shape, feat_key, batch_size, num_frames)
        else:
            if self.enable_deform and epoch in self.deform_epochs:
                shape, deformation = self.forward_deformation(shape, feat_key, batch_size, num_frames)

        arti_params, articulation_aux = None, {}
        if self.iter_articulation_start is not None:
            if self.enable_articulation and total_iter >= self.iter_articulation_start:
                shape, arti_params, articulation_aux = self.forward_articulation(shape, feat_key, patch_key, mvp, w2c, batch_size, num_frames, epoch, category, total_iter=total_iter)
        else:
            if self.enable_articulation and epoch in self.articulation_epochs:
                shape, arti_params, articulation_aux = self.forward_articulation(shape, feat_key, patch_key, mvp, w2c, batch_size, num_frames, epoch, category, total_iter=None)

        if self.enable_lighting:
            light = self.netLight
        else:
            light = None

        aux = articulation_aux
        aux.update(multi_hypothesis_aux)
        # when texture_way controls a local texture, also output patch_out
        if self.texture_way is None:
            return shape, pose_raw, pose, mvp, w2c, campos, texture, feat_out, patch_key, deformation, arti_params, light, aux
        else:
            return shape, pose_raw, pose, mvp, w2c, campos, texture, feat_out, patch_key, deformation, arti_params, light, aux, patch_out
class Unsup3DDDP: | |
def __init__(self, cfgs): | |
self.cfgs = cfgs | |
self.device = cfgs.get('device', 'cpu') | |
self.in_image_size = cfgs.get('in_image_size', 128) | |
self.out_image_size = cfgs.get('out_image_size', 128) | |
self.num_epochs = cfgs.get('num_epochs', 10) | |
self.lr = cfgs.get('lr', 1e-4) | |
self.use_scheduler = cfgs.get('use_scheduler', False) | |
if self.use_scheduler: | |
scheduler_milestone = cfgs.get('scheduler_milestone', [1,2,3,4,5]) | |
scheduler_gamma = cfgs.get('scheduler_gamma', 0.5) | |
self.make_scheduler = lambda optim: torch.optim.lr_scheduler.MultiStepLR(optim, milestones=scheduler_milestone, gamma=scheduler_gamma) | |
self.cam_pos_z_offset = cfgs.get('cam_pos_z_offset', 10.) | |
self.full_size_h = cfgs.get('full_size_h', 1080) | |
self.full_size_w = cfgs.get('full_size_w', 1920) | |
# self.fov_w = cfgs.get('fov_w', 60) | |
# self.fov_h = np.arctan(np.tan(self.fov_w /2 /180*np.pi) / self.full_size_w * self.full_size_h) *2 /np.pi*180 # 36 | |
self.crop_fov_approx = cfgs.get("crop_fov_approx", 25) | |
self.mesh_regularization_mode = cfgs.get('mesh_regularization_mode', 'seq') | |
self.enable_prior = cfgs.get('enable_prior', False) | |
if self.enable_prior: | |
self.netPrior = PriorPredictor(self.cfgs) #DOR - add label | |
self.prior_lr = cfgs.get('prior_lr', self.lr) | |
self.prior_weight_decay = cfgs.get('prior_weight_decay', 0.) | |
self.prior_only_epochs = cfgs.get('prior_only_epochs', 0) | |
self.netInstance = InstancePredictor(self.cfgs, tet_bbox=self.netPrior.netShape.getAABB()) | |
self.perturb_sdf = cfgs.get('perturb_sdf', False) | |
self.blur_mask = cfgs.get('blur_mask', False) | |
self.blur_mask_iter = cfgs.get('blur_mask_iter', 1) | |
self.seqshape_epochs = np.arange(*cfgs.get('seqshape_epochs', [0, self.num_epochs])) | |
self.avg_texture_epochs = np.arange(*cfgs.get('avg_texture_epochs', [0, 0])) | |
self.swap_texture_epochs = np.arange(*cfgs.get('swap_texture_epochs', [0, 0])) | |
self.swap_priorshape_epochs = np.arange(*cfgs.get('swap_priorshape_epochs', [0, 0])) | |
self.avg_seqshape_epochs = np.arange(*cfgs.get('avg_seqshape_epochs', [0, 0])) | |
self.swap_seqshape_epochs = np.arange(*cfgs.get('swap_seqshape_epochs', [0, 0])) | |
self.pose_epochs = np.arange(*cfgs.get('pose_epochs', [0, 0])) | |
self.pose_iters = cfgs.get('pose_iters', 0) | |
self.deform_type = cfgs.get('deform_type', None) | |
self.mesh_reg_decay_epoch = cfgs.get('mesh_reg_decay_epoch', 0) | |
self.sdf_reg_decay_start_iter = cfgs.get('sdf_reg_decay_start_iter', 0) | |
self.mesh_reg_decay_rate = cfgs.get('mesh_reg_decay_rate', 1) | |
self.texture_epochs = np.arange(*cfgs.get('texture_epochs', [0, self.num_epochs])) | |
self.zflip_epochs = np.arange(*cfgs.get('zflip_epochs', [0, self.num_epochs])) | |
self.lookat_zflip_loss_epochs = np.arange(*cfgs.get('lookat_zflip_loss_epochs', [0, self.num_epochs])) | |
self.lookat_zflip_no_other_losses = cfgs.get('lookat_zflip_no_other_losses', False) | |
self.flow_loss_epochs = np.arange(*cfgs.get('flow_loss_epochs', [0, self.num_epochs])) | |
self.sdf_inflate_reg_loss_epochs = np.arange(*cfgs.get('sdf_inflate_reg_loss_epochs', [0, self.num_epochs])) | |
self.arti_reg_loss_epochs = np.arange(*cfgs.get('arti_reg_loss_epochs', [0, self.num_epochs])) | |
self.background_mode = cfgs.get('background_mode', 'background') | |
self.shape_prior_type = cfgs.get('shape_prior_type', 'deform') | |
self.backward_prior = cfgs.get('backward_prior', True) | |
self.resume_prior_optim = cfgs.get('resume_prior_optim', True) | |
self.dmtet_grid_smaller_epoch = cfgs.get('dmtet_grid_smaller_epoch', 0) | |
self.dmtet_grid_smaller = cfgs.get('dmtet_grid_smaller', 128) | |
self.dmtet_grid = cfgs.get('dmtet_grid', 256) | |
self.pose_xflip_recon_epochs = np.arange(*cfgs.get('pose_xflip_recon_epochs', [0, 0])) | |
self.rot_rand_quad_epochs = np.arange(*cfgs.get('rot_rand_quad_epochs', [0, 0])) | |
self.rot_all_quad_epochs = np.arange(*cfgs.get('rot_all_quad_epochs', [0, 0])) | |
self.calc_dino_features = cfgs.get('calc_dino_features', False) | |
# self.smooth_type = cfgs.get('smooth_type', 'None') | |
# print(f"****smooth_type: {self.smooth_type}****") | |
## smooth losses | |
# smooth articulation | |
self.arti_smooth_type = cfgs.get('arti_smooth_type', None) | |
self.arti_smooth_loss_type = cfgs.get('arti_smooth_loss_type', None) | |
self.arti_smooth_loss_weight = cfgs.get('arti_smooth_loss_weight', 0.) | |
self.using_arti_smooth_loss = self.arti_smooth_type and self.arti_smooth_loss_type and self.arti_smooth_loss_weight > 0. | |
if self.using_arti_smooth_loss: | |
self.arti_smooth_loss_fn = SmoothLoss(dim=1, smooth_type=self.arti_smooth_type, loss_type=self.arti_smooth_loss_type) | |
else: | |
self.arti_smooth_loss_fn = None | |
# smooth deformation | |
self.deform_smooth_type = cfgs.get('deform_smooth_type', None) | |
self.deform_smooth_loss_type = cfgs.get('deform_smooth_loss_type', None) | |
self.deform_smooth_loss_weight = cfgs.get('deform_smooth_loss_weight', 0.) | |
self.using_deform_smooth_loss = self.deform_smooth_type and self.deform_smooth_loss_type and self.deform_smooth_loss_weight > 0. | |
if self.using_deform_smooth_loss: | |
self.deform_smooth_loss_fn = SmoothLoss(dim=1, smooth_type=self.deform_smooth_type, loss_type=self.deform_smooth_loss_type) | |
else: | |
self.deform_smooth_loss_fn = None | |
# smooth camera pose | |
self.campos_smooth_type = cfgs.get('campos_smooth_type', None) | |
self.campos_smooth_loss_type = cfgs.get('campos_smooth_loss_type', None) | |
self.campos_smooth_loss_weight = cfgs.get('campos_smooth_loss_weight', 0.) | |
self.using_campos_smooth_loss = self.campos_smooth_type and self.campos_smooth_loss_type and self.campos_smooth_loss_weight > 0. | |
if self.using_campos_smooth_loss: | |
self.campos_smooth_loss_fn = SmoothLoss(dim=1, smooth_type=self.campos_smooth_type, loss_type=self.campos_smooth_loss_type) | |
else: | |
self.campos_smooth_loss_fn = None | |
# smooth articulation velocity | |
self.artivel_smooth_type = cfgs.get('artivel_smooth_type', None) | |
self.artivel_smooth_loss_type = cfgs.get('artivel_smooth_loss_type', None) | |
self.artivel_smooth_loss_weight = cfgs.get('artivel_smooth_loss_weight', 0.) | |
self.using_artivel_smooth_loss = self.artivel_smooth_type and self.artivel_smooth_loss_type and self.artivel_smooth_loss_weight > 0. | |
if self.using_artivel_smooth_loss: | |
self.artivel_smooth_loss_fn = SmoothLoss(dim=1, smooth_type=self.artivel_smooth_type, loss_type=self.artivel_smooth_loss_type) | |
else: | |
self.artivel_smooth_loss_fn = None | |
# smooth bone | |
self.bone_smooth_type = cfgs.get('bone_smooth_type', None) | |
self.bone_smooth_loss_type = cfgs.get('bone_smooth_loss_type', None) | |
self.bone_smooth_loss_weight = cfgs.get('bone_smooth_loss_weight', 0.) | |
self.using_bone_smooth_loss = self.bone_smooth_type and self.bone_smooth_loss_type and self.bone_smooth_loss_weight > 0. | |
if self.using_bone_smooth_loss: | |
self.bone_smooth_loss_fn = SmoothLoss(dim=1, smooth_type=self.bone_smooth_type, loss_type=self.bone_smooth_loss_type) | |
else: | |
self.bone_smooth_loss_fn = None | |
# smooth bone velocity | |
self.bonevel_smooth_type = cfgs.get('bonevel_smooth_type', None) | |
self.bonevel_smooth_loss_type = cfgs.get('bonevel_smooth_loss_type', None) | |
self.bonevel_smooth_loss_weight = cfgs.get('bonevel_smooth_loss_weight', 0.) | |
self.using_bonevel_smooth_loss = self.bonevel_smooth_type and self.bonevel_smooth_loss_type and self.bonevel_smooth_loss_weight > 0. | |
if self.using_bonevel_smooth_loss: | |
self.bonevel_smooth_loss_fn = SmoothLoss(dim=1, smooth_type=self.bonevel_smooth_type, loss_type=self.bonevel_smooth_loss_type) | |
else: | |
self.bonevel_smooth_loss_fn = None | |
## perceptual loss | |
if cfgs.get('perceptual_loss_weight', 0.) > 0: | |
self.perceptual_loss_use_lin = cfgs.get('perceptual_loss_use_lin', True) | |
self.perceptual_loss = lpips.LPIPS(net='vgg', lpips=self.perceptual_loss_use_lin) | |
# self.glctx = dr.RasterizeGLContext() | |
self.glctx = dr.RasterizeCudaContext() | |
self.render_flow = self.cfgs.get('flow_loss_weight', 0.) > 0. | |
self.extra_renders = cfgs.get('extra_renders', []) | |
self.renderer_spp = cfgs.get('renderer_spp', 1) | |
self.dino_feature_recon_dim = cfgs.get('dino_feature_recon_dim', 64) | |
self.total_loss = 0. | |
self.all_scores = torch.Tensor() | |
self.checkpoint_dir = cfgs.get('checkpoint_dir', 'results') | |
# iter | |
self.iter_arti_reg_loss_start = cfgs.get('iter_arti_reg_loss_start', None) | |
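# NOTE: several experimental features below are hard-disabled: the flags read from the
# config (mask distribution, CLIP, discriminator, few-shot GAN/CLIP texture, SDS/VSD)
# are immediately overridden with False.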
# mask distribution | |
self.enable_mask_distribution = cfgs.get('enable_mask_distribution', False) | |
self.enable_mask_distribution = False | |
self.random_mask_law = cfgs.get('random_mask_law', 'batch_swap_noy') # batch_swap, batch_swap_noy, # random_azimuth # random_all | |
self.mask_distribution_path = cfgs.get('mask_distribution_path', None) | |
self.enable_clip = cfgs.get('enable_clip', False) | |
self.enable_clip = False | |
self.enable_disc = cfgs.get('enable_disc', False) | |
self.enable_disc = False | |
self.few_shot_gan_tex = False | |
self.few_shot_clip_tex = False | |
self.enable_sds = cfgs.get('enable_sds', False) | |
self.enable_vsd = cfgs.get('enable_vsd', False) | |
self.enable_sds = False | |
self.enable_vsd = False | |
def get_data_loaders(cfgs, dataset, in_image_size=256, out_image_size=256, batch_size=64, num_workers=4, run_train=False, run_test=False, train_data_dir=None, val_data_dir=None, test_data_dir=None, flow_bool=False): | |
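# Build (train_loader, val_loader, test_loader) for the requested dataset; any loader
# that is not requested remains None.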
train_loader = val_loader = test_loader = None | |
color_jitter_train = cfgs.get('color_jitter_train', None) | |
color_jitter_val = cfgs.get('color_jitter_val', None) | |
random_flip_train = cfgs.get('random_flip_train', False) | |
## video dataset | |
if dataset == 'video': | |
data_loader_mode = cfgs.get('data_loader_mode', 'n_frame') | |
skip_beginning = cfgs.get('skip_beginning', 4) | |
skip_end = cfgs.get('skip_end', 4) | |
num_sample_frames = cfgs.get('num_sample_frames', 2) | |
min_seq_len = cfgs.get('min_seq_len', 10) | |
max_seq_len = cfgs.get('max_seq_len', 10) | |
debug_seq = cfgs.get('debug_seq', False) | |
random_sample_train_frames = cfgs.get('random_sample_train_frames', False) | |
shuffle_train_seqs = cfgs.get('shuffle_train_seqs', False) | |
random_sample_val_frames = cfgs.get('random_sample_val_frames', False) | |
load_background = cfgs.get('background_mode', 'none') == 'background' | |
rgb_suffix = cfgs.get('rgb_suffix', '.png') | |
load_dino_feature = cfgs.get('load_dino_feature', False) | |
load_dino_cluster = cfgs.get('load_dino_cluster', False) | |
dino_feature_dim = cfgs.get('dino_feature_dim', 64) | |
get_loader = lambda **kwargs: get_sequence_loader( | |
mode=data_loader_mode, | |
batch_size=batch_size, | |
num_workers=num_workers, | |
in_image_size=in_image_size, | |
out_image_size=out_image_size, | |
debug_seq=debug_seq, | |
skip_beginning=skip_beginning, | |
skip_end=skip_end, | |
num_sample_frames=num_sample_frames, | |
min_seq_len=min_seq_len, | |
max_seq_len=max_seq_len, | |
load_background=load_background, | |
rgb_suffix=rgb_suffix, | |
load_dino_feature=load_dino_feature, | |
load_dino_cluster=load_dino_cluster, | |
dino_feature_dim=dino_feature_dim, | |
flow_bool=flow_bool, | |
**kwargs) | |
if run_train: | |
assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}" | |
print(f"Loading training data from {train_data_dir}") | |
train_loader = get_loader(data_dir=train_data_dir, is_validation=False, random_sample=random_sample_train_frames, shuffle=shuffle_train_seqs, dense_sample=True, color_jitter=color_jitter_train, random_flip=random_flip_train) | |
if val_data_dir is not None: | |
assert osp.isdir(val_data_dir), f"Validation data directory does not exist: {val_data_dir}" | |
print(f"Loading validation data from {val_data_dir}") | |
val_loader = get_loader(data_dir=val_data_dir, is_validation=True, random_sample=random_sample_val_frames, shuffle=False, dense_sample=False, color_jitter=color_jitter_val, random_flip=False) | |
if run_test: | |
assert osp.isdir(test_data_dir), f"Testing data directory does not exist: {test_data_dir}" | |
print(f"Loading testing data from {test_data_dir}") | |
test_loader = get_loader(data_dir=test_data_dir, is_validation=True, dense_sample=False, color_jitter=None, random_flip=False) | |
## CUB dataset | |
elif dataset == 'cub': | |
get_loader = lambda **kwargs: get_cub_loader( | |
batch_size=batch_size, | |
num_workers=num_workers, | |
image_size=in_image_size, | |
**kwargs) | |
if run_train: | |
assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}" | |
print(f"Loading training data from {train_data_dir}") | |
train_loader = get_loader(data_dir=train_data_dir, split='train', is_validation=False) | |
val_loader = get_loader(data_dir=val_data_dir, split='val', is_validation=True) | |
if run_test: | |
assert osp.isdir(test_data_dir), f"Testing data directory does not exist: {test_data_dir}" | |
print(f"Loading testing data from {test_data_dir}") | |
test_loader = get_loader(data_dir=test_data_dir, split='test', is_validation=True) | |
## other datasets | |
else: | |
get_loader = lambda **kwargs: get_image_loader( | |
batch_size=batch_size, | |
num_workers=num_workers, | |
image_size=in_image_size, | |
**kwargs) | |
if run_train: | |
assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}" | |
print(f"Loading training data from {train_data_dir}") | |
train_loader = get_loader(data_dir=train_data_dir, is_validation=False, color_jitter=color_jitter_train) | |
if val_data_dir is not None: | |
assert osp.isdir(val_data_dir), f"Validation data directory does not exist: {val_data_dir}" | |
print(f"Loading validation data from {val_data_dir}") | |
val_loader = get_loader(data_dir=val_data_dir, is_validation=True, color_jitter=color_jitter_val) | |
if run_test: | |
assert osp.isdir(test_data_dir), f"Testing data directory does not exist: {test_data_dir}" | |
print(f"Loading testing data from {test_data_dir}") | |
test_loader = get_loader(data_dir=test_data_dir, is_validation=True, color_jitter=None) | |
return train_loader, val_loader, test_loader | |
def get_data_loaders_ddp(cfgs, dataset, rank, world_size, in_image_size=256, out_image_size=256, batch_size=64, num_workers=4, run_train=False, run_test=False, train_data_dir=None, val_data_dir=None, test_data_dir=None, flow_bool=False): | |
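# DDP variant of get_data_loaders: the training and test loaders shard data across
# ranks using the given rank and world_size.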
train_loader = val_loader = test_loader = None | |
color_jitter_train = cfgs.get('color_jitter_train', None) | |
color_jitter_val = cfgs.get('color_jitter_val', None) | |
random_flip_train = cfgs.get('random_flip_train', False) | |
## video dataset | |
if dataset == 'video': | |
data_loader_mode = cfgs.get('data_loader_mode', 'n_frame') | |
skip_beginning = cfgs.get('skip_beginning', 4) | |
skip_end = cfgs.get('skip_end', 4) | |
num_sample_frames = cfgs.get('num_sample_frames', 2) | |
min_seq_len = cfgs.get('min_seq_len', 10) | |
max_seq_len = cfgs.get('max_seq_len', 10) | |
debug_seq = cfgs.get('debug_seq', False) | |
random_sample_train_frames = cfgs.get('random_sample_train_frames', False) | |
shuffle_train_seqs = cfgs.get('shuffle_train_seqs', False) | |
random_sample_val_frames = cfgs.get('random_sample_val_frames', False) | |
load_background = cfgs.get('background_mode', 'none') == 'background' | |
rgb_suffix = cfgs.get('rgb_suffix', '.png') | |
load_dino_feature = cfgs.get('load_dino_feature', False) | |
load_dino_cluster = cfgs.get('load_dino_cluster', False) | |
dino_feature_dim = cfgs.get('dino_feature_dim', 64) | |
get_loader_ddp = lambda **kwargs: get_sequence_loader_ddp( | |
mode=data_loader_mode, | |
batch_size=batch_size, | |
num_workers=num_workers, | |
in_image_size=in_image_size, | |
out_image_size=out_image_size, | |
debug_seq=debug_seq, | |
skip_beginning=skip_beginning, | |
skip_end=skip_end, | |
num_sample_frames=num_sample_frames, | |
min_seq_len=min_seq_len, | |
max_seq_len=max_seq_len, | |
load_background=load_background, | |
rgb_suffix=rgb_suffix, | |
load_dino_feature=load_dino_feature, | |
load_dino_cluster=load_dino_cluster, | |
dino_feature_dim=dino_feature_dim, | |
flow_bool=flow_bool, | |
**kwargs) | |
get_loader = lambda **kwargs: get_sequence_loader( | |
mode=data_loader_mode, | |
batch_size=batch_size, | |
num_workers=num_workers, | |
in_image_size=in_image_size, | |
out_image_size=out_image_size, | |
debug_seq=debug_seq, | |
skip_beginning=skip_beginning, | |
skip_end=skip_end, | |
num_sample_frames=num_sample_frames, | |
min_seq_len=min_seq_len, | |
max_seq_len=max_seq_len, | |
load_background=load_background, | |
rgb_suffix=rgb_suffix, | |
load_dino_feature=load_dino_feature, | |
load_dino_cluster=load_dino_cluster, | |
dino_feature_dim=dino_feature_dim, | |
**kwargs) | |
if run_train: | |
if isinstance(train_data_dir, dict): | |
for data_path in train_data_dir.values(): | |
assert osp.isdir(data_path), f"Training data directory does not exist: {data_path}" | |
elif isinstance(train_data_dir, str): | |
assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}" | |
else: | |
raise ValueError("train_data_dir must be a string or a dict of strings") | |
print(f"Loading training data...") | |
train_loader = get_loader_ddp(data_dir=train_data_dir, rank=rank, world_size=world_size, is_validation=False, random_sample=random_sample_train_frames, shuffle=shuffle_train_seqs, dense_sample=True, color_jitter=color_jitter_train, random_flip=random_flip_train) | |
if val_data_dir is not None: | |
if isinstance(val_data_dir, dict):
for data_path in val_data_dir.values():
assert osp.isdir(data_path), f"Validation data directory does not exist: {data_path}"
elif isinstance(val_data_dir, str):
assert osp.isdir(val_data_dir), f"Validation data directory does not exist: {val_data_dir}"
else:
raise ValueError("val_data_dir must be a string or a dict of strings")
print(f"Loading validation data...") | |
# No need for data parallel for the validation data loader. | |
val_loader = get_loader(data_dir=val_data_dir, is_validation=True, random_sample=random_sample_val_frames, shuffle=False, dense_sample=False, color_jitter=color_jitter_val, random_flip=False) | |
if run_test: | |
assert osp.isdir(test_data_dir), f"Testing data directory does not exist: {test_data_dir}" | |
print(f"Loading testing data from {test_data_dir}") | |
test_loader = get_loader_ddp(data_dir=test_data_dir, rank=rank, world_size=world_size, is_validation=True, dense_sample=False, color_jitter=None, random_flip=False) | |
## CUB dataset | |
elif dataset == 'cub': | |
get_loader = lambda **kwargs: get_cub_loader_ddp( | |
batch_size=batch_size, | |
num_workers=num_workers, | |
image_size=in_image_size, | |
**kwargs) | |
if run_train: | |
assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}" | |
print(f"Loading training data from {train_data_dir}") | |
train_loader = get_loader(data_dir=train_data_dir, rank=rank, world_size=world_size, split='train', is_validation=False) | |
val_loader = get_loader(data_dir=val_data_dir, rank=rank, world_size=world_size, split='val', is_validation=True) | |
if run_test: | |
assert osp.isdir(test_data_dir), f"Testing data directory does not exist: {test_data_dir}" | |
print(f"Loading testing data from {test_data_dir}") | |
test_loader = get_loader(data_dir=test_data_dir, rank=rank, world_size=world_size, split='test', is_validation=True) | |
## other datasets | |
else: | |
get_loader = lambda **kwargs: get_image_loader_ddp( | |
batch_size=batch_size, | |
num_workers=num_workers, | |
image_size=in_image_size, | |
**kwargs) | |
if run_train: | |
assert osp.isdir(train_data_dir), f"Training data directory does not exist: {train_data_dir}" | |
print(f"Loading training data from {train_data_dir}") | |
train_loader = get_loader(data_dir=train_data_dir, rank=rank, world_size=world_size, is_validation=False, color_jitter=color_jitter_train) | |
if val_data_dir is not None: | |
assert osp.isdir(val_data_dir), f"Validation data directory does not exist: {val_data_dir}" | |
print(f"Loading validation data from {val_data_dir}") | |
val_loader = get_loader(data_dir=val_data_dir, rank=rank, world_size=world_size, is_validation=True, color_jitter=color_jitter_val) | |
if run_test: | |
assert osp.isdir(test_data_dir), f"Testing data directory does not exist: {test_data_dir}" | |
print(f"Loading testing data from {test_data_dir}") | |
test_loader = get_loader(data_dir=test_data_dir, rank=rank, world_size=world_size, is_validation=True, color_jitter=None) | |
return train_loader, val_loader, test_loader | |
def load_model_state(self, cp): | |
# TODO: hacky: when using a local texture network (usually fine-tuned from a global texture),
# the texture weights in netInstance need special handling when loading.
if (self.netInstance.texture_way is not None) or (self.cfgs.get('texture_act', 'relu') != 'relu'): | |
new_netInstance_weights = {k: v for k, v in cp['netInstance'].items() if 'netTexture' not in k} | |
# find the new texture weights
texture_weights = self.netInstance.netTexture.state_dict() | |
# add the new texture weights to the new model state dict
for k, v in texture_weights.items(): | |
new_netInstance_weights['netTexture.' + k] = v | |
self.netInstance.load_state_dict(new_netInstance_weights) | |
else: | |
self.netInstance.load_state_dict(cp["netInstance"]) | |
if self.enable_disc and "net_mask_disc" in cp: | |
self.mask_disc.load_state_dict(cp["net_mask_disc"]) | |
if self.enable_prior: | |
self.netPrior.load_state_dict(cp["netPrior"]) | |
def load_optimizer_state(self, cp): | |
# TODO: also very hacky, like load_model_state above: this branch rebuilds the state dict
# from the freshly created optimizer rather than from cp, so the checkpoint's optimizer
# state is effectively not restored when fine-tuning the texture.
if self.netInstance.texture_way is not None:
opt_state_dict = self.optimizerInstance.state_dict() | |
param_ids = [id(p) for p in self.netInstance.netTexture.parameters()] | |
new_opt_state_dict = {} | |
new_opt_state_dict['state'] = {k: v for k, v in opt_state_dict['state'].items() if k not in param_ids} | |
new_param_groups = [] | |
for param_group in opt_state_dict['param_groups']: | |
new_param_group = {k: v for k, v in param_group.items() if k != 'params'} | |
new_param_group['params'] = [p_id for p_id in param_group['params'] if p_id not in param_ids] | |
new_param_groups.append(new_param_group) | |
new_opt_state_dict['param_groups'] = new_param_groups | |
self.optimizerInstance.load_state_dict(new_opt_state_dict) | |
else: | |
self.optimizerInstance.load_state_dict(cp["optimizerInstance"]) | |
# add parameters into optimizerInstance here | |
# if self.enable_disc: | |
# print('add mask discriminator parameters to Instance optimizer') | |
# self.optimizerInstance.add_param_group({'params': self.mask_disc.parameters()}) | |
if self.use_scheduler: | |
if 'schedulerInstance' in cp: | |
self.schedulerInstance.load_state_dict(cp["schedulerInstance"]) | |
if self.enable_disc and "optimizerDiscriminator" in cp: | |
self.optimizerDiscriminator.load_state_dict(cp["optimizerDiscriminator"]) | |
if self.enable_prior and self.resume_prior_optim: | |
self.optimizerPrior.load_state_dict(cp["optimizerPrior"]) | |
if self.use_scheduler: | |
if 'schedulerPrior' in cp: | |
self.schedulerPrior.load_state_dict(cp["schedulerPrior"]) | |
def get_model_state(self): | |
state = {"netInstance": self.netInstance.state_dict()} | |
if self.enable_disc: | |
state["net_mask_disc"] = self.mask_disc.state_dict() | |
if self.enable_prior: | |
state["netPrior"] = self.netPrior.state_dict() | |
return state | |
def get_optimizer_state(self): | |
state = {"optimizerInstance": self.optimizerInstance.state_dict()} | |
if self.enable_disc: | |
state['optimizerDiscriminator'] = self.optimizerDiscriminator.state_dict() | |
if self.use_scheduler: | |
state["schedulerInstance"] = self.schedulerInstance.state_dict() | |
if self.enable_prior: | |
state["optimizerPrior"] = self.optimizerPrior.state_dict() | |
if self.use_scheduler: | |
state["schedulerPrior"] = self.schedulerPrior.state_dict() | |
return state | |
def to(self, device): | |
self.device = device | |
self.netInstance.to(device) | |
if self.enable_prior: | |
self.netPrior.to(device) | |
for v in vars(self.netPrior.netShape): | |
attr = getattr(self.netPrior.netShape,v) | |
if type(attr) == torch.Tensor: | |
setattr(self.netPrior.netShape, v, attr.to(device)) | |
if hasattr(self, 'perceptual_loss'): | |
self.perceptual_loss.to(device) | |
def ddp(self, rank, world_size): | |
self.rank = rank | |
self.world_size = world_size | |
if self.world_size > 1: | |
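# find_unused_parameters tolerates forward passes that skip some submodules;
# _set_static_graph additionally assumes the set of used parameters stays fixed across iterations.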
self.netInstance_ddp = DDP( | |
self.netInstance, device_ids=[rank], | |
find_unused_parameters=True) | |
self.netInstance_ddp._set_static_graph() | |
self.netInstance = self.netInstance_ddp.module | |
if self.enable_prior: | |
self.netPrior_ddp = DDP( | |
self.netPrior, device_ids=[rank], | |
find_unused_parameters=True) | |
self.netPrior_ddp._set_static_graph() | |
self.netPrior = self.netPrior_ddp.module | |
if hasattr(self, 'perceptual_loss'): | |
self.perceptual_loss_ddp = DDP( | |
self.perceptual_loss, device_ids=[rank], | |
find_unused_parameters=True) | |
self.perceptual_loss = self.perceptual_loss_ddp.module | |
else: | |
print('world_size is 1, not wrapping the model in DDP')
def set_train(self): | |
if self.world_size > 1: | |
self.netInstance_ddp.train() | |
if self.enable_prior: | |
self.netPrior_ddp.train() | |
else: | |
self.netInstance.train() | |
if self.enable_disc: | |
self.mask_disc.train() | |
if self.enable_prior: | |
self.netPrior.train() | |
def set_eval(self): | |
if self.world_size > 1: | |
self.netInstance_ddp.eval() | |
if self.enable_prior: | |
self.netPrior_ddp.eval() | |
else: | |
self.netInstance.eval() | |
if self.enable_disc: | |
self.mask_disc.eval() | |
if self.enable_prior: | |
self.netPrior.eval() | |
def reset_optimizers(self): | |
print("Resetting optimizers...") | |
self.optimizerInstance = get_optimizer(self.netInstance, self.lr) | |
if self.enable_disc: | |
self.optimizerDiscriminator = get_optimizer(self.mask_disc, self.lr) | |
if self.use_scheduler: | |
self.schedulerInstance = self.make_scheduler(self.optimizerInstance) | |
if self.enable_prior: | |
self.optimizerPrior = get_optimizer(self.netPrior, lr=self.prior_lr, weight_decay=self.prior_weight_decay) | |
if self.use_scheduler: | |
self.schedulerPrior = self.make_scheduler(self.optimizerPrior) | |
def reset_only_disc_optimizer(self): | |
if self.enable_disc: | |
self.optimizerDiscriminator = get_optimizer(self.mask_disc, self.lr) | |
def backward(self): | |
self.optimizerInstance.zero_grad() | |
if self.backward_prior: | |
self.optimizerPrior.zero_grad() | |
# self.total_loss = self.add_unused() | |
self.total_loss.backward() | |
self.optimizerInstance.step() | |
if self.backward_prior: | |
self.optimizerPrior.step() | |
self.total_loss = 0. | |
def scheduler_step(self): | |
if self.use_scheduler: | |
self.schedulerInstance.step() | |
if self.enable_prior: | |
self.schedulerPrior.step() | |
def zflip_pose(self, pose): | |
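# Flip the pose along z by negating the z component of the look-at forward vector and
# rebuilding the rotation matrix; only the 'lookat' rotation representation is supported.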
if self.rot_rep == 'lookat': | |
vec_forward = pose[:,:,6:9] | |
vec_forward = vec_forward * torch.FloatTensor([1,1,-1]).view(1,1,3).to(vec_forward.device) | |
up = torch.FloatTensor([0,1,0]).to(pose.device).view(1,1,3) | |
vec_right = up.expand_as(vec_forward).cross(vec_forward, dim=-1) | |
vec_right = nn.functional.normalize(vec_right, p=2, dim=-1) | |
vec_up = vec_forward.cross(vec_right, dim=-1) | |
vec_up = nn.functional.normalize(vec_up, p=2, dim=-1) | |
rot_mat = torch.stack([vec_right, vec_up, vec_forward], 2) | |
rot_pred = rot_mat.reshape(*pose.shape[:-1], -1) | |
pose_zflip = torch.cat([rot_pred, pose[:,:,9:]], -1) | |
else: | |
raise NotImplementedError | |
return pose_zflip | |
def render(self, shape, texture, mvp, w2c, campos, resolution, background='none', im_features=None, light=None, prior_shape=None, render_flow=False, dino_pred=None, class_vector=None, render_mode='diffuse', two_sided_shading=True, num_frames=None, spp=1, bg_image=None, im_features_map=None): | |
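# Rasterize the posed mesh with nvdiffrast and return image, mask, optical flow,
# DINO feature, albedo, and shading maps (entries are None if not requested).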
h, w = resolution | |
N = len(mvp) | |
if bg_image is None: | |
if background in ['none', 'black']: | |
bg_image = torch.zeros((N, h, w, 3), device=mvp.device) | |
elif background == 'white': | |
bg_image = torch.ones((N, h, w, 3), device=mvp.device) | |
elif background == 'checkerboard': | |
bg_image = torch.FloatTensor(util.checkerboard((h, w), 8)).to(self.device).repeat(N, 1, 1, 1) # NxHxWxC
elif background == 'random': | |
bg_image = torch.rand((N, h, w, 3), device=mvp.device) # NxHxWxC | |
elif background == 'random-pure': | |
random_values = torch.rand(N) | |
bg_image = random_values[..., None, None, None].repeat(1, h, w, 3).to(self.device) | |
else: | |
raise NotImplementedError | |
# inside render.render_mesh -> render_layer -> shade
frame_rendered = render.render_mesh( | |
self.glctx, | |
shape, | |
mtx_in=mvp, | |
w2c=w2c, | |
view_pos=campos, | |
material=texture, | |
lgt=light, | |
resolution=resolution, | |
spp=spp, | |
msaa=True, | |
background=bg_image, | |
bsdf=render_mode, | |
feat=im_features, | |
prior_mesh=prior_shape, | |
two_sided_shading=two_sided_shading, | |
render_flow=render_flow, | |
dino_pred=dino_pred, | |
class_vector=class_vector, | |
num_frames=num_frames, | |
im_features_map=im_features_map) | |
shaded = frame_rendered['shaded'].permute(0, 3, 1, 2) | |
image_pred = shaded[:, :3, :, :] | |
mask_pred = shaded[:, 3, :, :] | |
albedo = frame_rendered['kd'].permute(0, 3, 1, 2)[:, :3, :, :] | |
if 'shading' in frame_rendered: | |
shading = frame_rendered['shading'].permute(0, 3, 1, 2)[:, :1, :, :] | |
else: | |
shading = None | |
if render_flow: | |
flow_pred = frame_rendered['flow'] | |
flow_pred = flow_pred.permute(0, 3, 1, 2)[:, :2, :, :] | |
else: | |
flow_pred = None | |
if dino_pred is not None: | |
dino_feat_im_pred = frame_rendered['dino_feat_im_pred'] | |
dino_feat_im_pred = dino_feat_im_pred.permute(0, 3, 1, 2)[:, :-1] | |
else: | |
dino_feat_im_pred = None | |
return image_pred, mask_pred, flow_pred, dino_feat_im_pred, albedo, shading | |
def compute_reconstruction_losses(self, image_pred, image_gt, mask_pred, mask_gt, mask_dt, mask_valid, flow_pred, flow_gt, dino_feat_im_gt, dino_feat_im_pred, background_mode='none', reduce=False): | |
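# All reconstruction losses are computed per frame (shape BxF, or Bx(F-1) for flow);
# pass reduce=True to average them into scalars.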
losses = {} | |
batch_size, num_frames, _, h, w = image_pred.shape # BxFxCxHxW | |
# image_loss = (image_pred - image_gt) ** 2 | |
image_loss = (image_pred - image_gt).abs() | |
## silhouette loss | |
mask_pred_valid = mask_pred * mask_valid | |
# mask_pred_valid = mask_pred | |
# losses["silhouette_loss"] = ((mask_pred - mask_gt) ** 2).mean() | |
# mask_loss_mask = (image_loss.mean(2).detach() > 0.05).float() | |
mask_loss = (mask_pred_valid - mask_gt) ** 2 | |
# mask_loss = nn.functional.mse_loss(mask_pred, mask_gt) | |
# num_mask_pixels = mask_loss_mask.reshape(batch_size*num_frames, -1).sum(1).clamp(min=1) | |
# losses["silhouette_loss"] = (mask_loss.reshape(batch_size*num_frames, -1).sum(1) / num_mask_pixels).mean() | |
losses['silhouette_loss'] = mask_loss.view(batch_size, num_frames, -1).mean(2) | |
losses['silhouette_dt_loss'] = (mask_pred * mask_dt[:,:,1]).view(batch_size, num_frames, -1).mean(2) | |
losses['silhouette_inv_dt_loss'] = ((1-mask_pred) * mask_dt[:,:,0]).view(batch_size, num_frames, -1).mean(2) | |
mask_pred_binary = (mask_pred_valid > 0.).float().detach() | |
mask_both_binary = (mask_pred_binary * mask_gt).view(batch_size*num_frames, 1, *mask_pred.shape[2:]) | |
mask_both_binary = (nn.functional.avg_pool2d(mask_both_binary, 3, stride=1, padding=1).view(batch_size, num_frames, *mask_pred.shape[2:]) > 0.99).float().detach() # erode by 1 pixel | |
## reconstruction loss | |
# image_loss_mask = (mask_pred*mask_gt).unsqueeze(2).expand_as(image_gt) | |
# image_loss = image_loss * image_loss_mask | |
# num_mask_pixels = image_loss_mask.reshape(batch_size*num_frames, -1).sum(1).clamp(min=1) | |
# losses["rgb_loss"] = (image_loss.reshape(batch_size*num_frames, -1).sum(1) / num_mask_pixels).mean() | |
if background_mode in ['background', 'input']: | |
pass | |
else: | |
image_loss = image_loss * mask_both_binary.unsqueeze(2) | |
losses['rgb_loss'] = image_loss.reshape(batch_size, num_frames, -1).mean(2) | |
if self.cfgs.get('perceptual_loss_weight', 0.) > 0: | |
if background_mode in ['background', 'input']: | |
perc_image_pred = image_pred | |
perc_image_gt = image_gt | |
else: | |
perc_image_pred = image_pred * mask_pred_binary.unsqueeze(2) + 0.5 * (1-mask_pred_binary.unsqueeze(2)) | |
perc_image_gt = image_gt * mask_pred_binary.unsqueeze(2) + 0.5 * (1-mask_pred_binary.unsqueeze(2)) | |
losses['perceptual_loss'] = self.perceptual_loss(perc_image_pred.view(-1, *image_pred.shape[2:]) *2-1, perc_image_gt.view(-1, *image_gt.shape[2:]) *2-1).view(batch_size, num_frames) | |
## flow loss - between consecutive frames
if flow_pred is not None: | |
flow_loss = (flow_pred - flow_gt).abs() | |
flow_loss_mask = mask_both_binary[:,:-1].unsqueeze(2).expand_as(flow_gt).detach() | |
## ignore frames where GT flow is too large (likely inaccurate) | |
large_flow = (flow_gt.abs() > 0.5).float() * flow_loss_mask | |
large_flow = (large_flow.view(batch_size, num_frames-1, -1).sum(2) > 0).float() | |
self.large_flow = large_flow | |
flow_loss = flow_loss * flow_loss_mask * (1 - large_flow[:,:,None,None,None]) | |
num_mask_pixels = flow_loss_mask.reshape(batch_size, num_frames-1, -1).sum(2).clamp(min=1) | |
losses['flow_loss'] = (flow_loss.reshape(batch_size, num_frames-1, -1).sum(2) / num_mask_pixels) | |
# losses["flow_loss"] = flow_loss.mean() | |
if dino_feat_im_pred is not None and dino_feat_im_gt is not None: | |
dino_feat_loss = (dino_feat_im_pred - dino_feat_im_gt) ** 2 | |
dino_feat_loss = dino_feat_loss * mask_both_binary.unsqueeze(2) | |
losses['dino_feat_im_loss'] = dino_feat_loss.reshape(batch_size, num_frames, -1).mean(2) | |
if reduce: | |
for k, v in losses.items():
losses[k] = v.mean() | |
return losses | |
def compute_pose_xflip_reg_loss(self, input_image, dino_feat_im, pose_raw, input_image_xflip_flag=None): | |
image_xflip = input_image.flip(4) | |
if dino_feat_im is not None: | |
dino_feat_im_xflip = dino_feat_im.flip(4) | |
else: | |
dino_feat_im_xflip = None | |
if self.world_size > 1: | |
netInst = self.netInstance_ddp | |
else: | |
netInst = self.netInstance | |
# feat_xflip, _ = self.netInstance_ddp.forward_encoder(image_xflip, dino_feat_im_xflip) | |
feat_xflip, _ = netInst.forward_encoder(image_xflip, dino_feat_im_xflip) | |
batch_size, num_frames = input_image.shape[:2] | |
# pose_xflip_raw = self.netInstance_ddp.forward_pose(image_xflip, feat_xflip, dino_feat_im_xflip) | |
pose_xflip_raw = netInst.forward_pose(image_xflip, feat_xflip, dino_feat_im_xflip) | |
if input_image_xflip_flag is not None: | |
pose_xflip_raw_xflip = pose_xflip_raw * torch.FloatTensor([-1,1,1,-1,1,1]).to(pose_raw.device) # forward x, trans x | |
pose_xflip_raw = pose_xflip_raw * (1 - input_image_xflip_flag.view(batch_size * num_frames, 1)) + pose_xflip_raw_xflip * input_image_xflip_flag.view(batch_size * num_frames, 1) | |
# rot_rep = self.netInstance_ddp.rot_rep | |
rot_rep = netInst.rot_rep | |
if rot_rep == 'euler_angle' or rot_rep == 'soft_calss': | |
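# NOTE: the euler_angle / quaternion branches reference pose, pose_xflip and pytorch3d,
# which are not defined or imported in this scope; they appear to be legacy code paths.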
pose_xflip_xflip = pose_xflip * torch.FloatTensor([1,-1,-1,-1,1,1]).to(pose_xflip.device) # rot y+z, trans x | |
pose_xflip_reg_loss = ((pose_xflip_xflip - pose) ** 2.).mean() | |
elif rot_rep == 'quaternion': | |
rot_euler = pytorch3d.transforms.matrix_to_euler_angles(pytorch3d.transforms.quaternion_to_matrix(pose[...,:4]), convention='XYZ') | |
pose_euler = torch.cat([rot_euler, pose[...,4:]], -1) | |
rot_xflip_euler = pytorch3d.transforms.matrix_to_euler_angles(pytorch3d.transforms.quaternion_to_matrix(pose_xflip[...,:4]), convention='XYZ') | |
pose_xflip_euler = torch.cat([rot_xflip_euler, pose_xflip[...,4:]], -1) | |
pose_xflip_euler_xflip = pose_xflip_euler * torch.FloatTensor([1,-1,-1,-1,1,1]).to(pose_xflip.device) # rot y+z, trans x | |
pose_xflip_reg_loss = ((pose_xflip_euler_xflip - pose_euler) ** 2.).mean() | |
elif rot_rep == 'lookat': | |
pose_xflip_raw_xflip = pose_xflip_raw * torch.FloatTensor([-1,1,1,-1,1,1]).to(pose_raw.device) # forward x, trans x | |
pose_xflip_reg_loss = ((pose_xflip_raw_xflip - pose_raw)[...,0] ** 2.) # compute x only | |
# if epoch >= self.nolookat_zflip_loss_epochs and self.lookat_zflip_no_other_losses: | |
# pose_xflip_reg_loss = pose_xflip_reg_loss.mean(1) * is_pose_1_better | |
pose_xflip_reg_loss = pose_xflip_reg_loss.mean() | |
return pose_xflip_reg_loss, pose_xflip_raw | |
def compute_edge_length_reg_loss(self, mesh, prior_mesh): | |
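# Penalize mesh edges that grow beyond 1.1x the longest edge of the prior mesh.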
prior_edge_lengths = get_edge_length(prior_mesh.v_pos, prior_mesh.t_pos_idx) | |
max_length = prior_edge_lengths.max().detach() *1.1 | |
edge_lengths = get_edge_length(mesh.v_pos, mesh.t_pos_idx) | |
mesh_edge_length_loss = ((edge_lengths - max_length).clamp(min=0)**2).mean() | |
return mesh_edge_length_loss, edge_lengths | |
def compute_regularizers(self, mesh, prior_mesh, input_image, dino_feat_im, pose_raw, input_image_xflip_flag=None, arti_params=None, deformation=None, mid_img_idx=0, posed_bones=None, class_vector=None): | |
losses = {} | |
aux = {} | |
if self.enable_prior: | |
losses.update(self.netPrior.netShape.get_sdf_reg_loss(class_vector=class_vector)) | |
if self.cfgs.get('pose_xflip_reg_loss_weight', 0.) > 0: | |
losses["pose_xflip_reg_loss"], aux['pose_xflip_raw'] = self.compute_pose_xflip_reg_loss(input_image, dino_feat_im, pose_raw, input_image_xflip_flag) | |
if self.using_campos_smooth_loss: | |
# from IPython import embed; embed() | |
pose_raw_ = pose_raw.view(self.bs, self.nf, *pose_raw.shape[1:]) | |
losses['campos_smooth_loss'] = self.campos_smooth_loss_fn(pose_raw_) | |
b, f = input_image.shape[:2] | |
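# Pose entropy: penalize alignment between the forward vectors predicted for the two
# halves of the batch, encouraging diverse viewpoint predictions.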
if b >= 2: | |
vec_forward = pose_raw[..., :3] | |
losses['pose_entropy_loss'] = (vec_forward[:b//2] * vec_forward[b//2:(b//2)*2]).sum(-1).mean() | |
else: | |
losses['pose_entropy_loss'] = 0. | |
losses['mesh_normal_consistency_loss'] = normal_consistency(mesh.v_pos, mesh.t_pos_idx) | |
losses['mesh_laplacian_consistency_loss'] = laplace_regularizer_const(mesh.v_pos, mesh.t_pos_idx) | |
losses['mesh_edge_length_loss'], aux['edge_lengths'] = self.compute_edge_length_reg_loss(mesh, prior_mesh) | |
if arti_params is not None: | |
losses['arti_reg_loss'] = (arti_params ** 2).mean() # L2 regularizer on the articulation parameters
if arti_params is not None and self.using_arti_smooth_loss: | |
arti_smooth_loss = self.arti_smooth_loss_fn(arti_params) | |
losses['arti_smooth_loss'] = arti_smooth_loss | |
# if arti_params is not None and self.cfgs.get('arti_smooth_loss_weight', 0.) > 0: | |
# if self.smooth_type == 'loss' and mid_img_idx > 0: | |
# # print("+++++++++++++++++add smooth to *articulation* loss") | |
# # from IPython import embed; embed() | |
# arti_smooth_loss = ( | |
# ((arti_params[:,mid_img_idx,:,:] - arti_params[:,0:mid_img_idx,:,:])**2) | |
# + ((arti_params[:,mid_img_idx,:,:] - arti_params[:,mid_img_idx+1:2*mid_img_idx+1,:,:])**2) | |
# ).mean() | |
# losses['arti_smooth_loss'] = arti_smooth_loss | |
if arti_params is not None and self.using_artivel_smooth_loss: | |
# from IPython import embed; embed() | |
_, nf, _, _= arti_params.shape | |
arti_vel = arti_params[:,1:nf,:,:] - arti_params[:,:(nf-1),:,:] | |
artivel_smooth_loss = self.artivel_smooth_loss_fn(arti_vel) | |
losses['artivel_smooth_loss'] = artivel_smooth_loss | |
if deformation is not None: | |
losses['deformation_reg_loss'] = (deformation ** 2).mean() # L2 regularizer on the vertex deformations
d1 = deformation[:, mesh.t_pos_idx[0, :, 0], :] | |
d2 = deformation[:, mesh.t_pos_idx[0, :, 1], :] | |
d3 = deformation[:, mesh.t_pos_idx[0, :, 2], :] | |
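# Sample random subsets of faces and penalize squared differences between the
# deformations of each face's vertices, encouraging locally smooth deformation.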
num_samples = 5000 | |
sample_idx1 = torch.randperm(d1.shape[1])[:num_samples].to(self.device) | |
sample_idx2 = torch.randperm(d1.shape[1])[:num_samples].to(self.device) | |
sample_idx3 = torch.randperm(d1.shape[1])[:num_samples].to(self.device) | |
dist1 = ((d1[:, sample_idx1, :] - d2[:, sample_idx1, :]) ** 2).mean() | |
dist2 = ((d2[:, sample_idx2, :] - d3[:, sample_idx2, :]) ** 2).mean() | |
dist3 = ((d3[:, sample_idx3, :] - d1[:, sample_idx3, :]) ** 2).mean() | |
losses['smooth_deformation_loss'] = dist1 + dist2 + dist3 | |
if deformation is not None and self.using_deform_smooth_loss: | |
deformation_ = deformation.view(self.bs, self.nf, *deformation.shape[1:]) | |
losses['deform_smooth_loss'] = self.deform_smooth_loss_fn(deformation_) | |
# if deformation is not None and self.cfgs.get('deformation_smooth_loss_weight', 0.) > 0: | |
# if self.smooth_type == 'loss' and mid_img_idx > 0: | |
# # print("+++++++++++++++++add smooth to *deformation* loss") | |
# deformation = deformation.view(self.bs, self.nf, *deformation.shape[1:]) | |
# deformation_smooth_loss = ( | |
# ((deformation[:, mid_img_idx,:,:] - deformation[:, 0:mid_img_idx,:,:]) ** 2) | |
# + ((deformation[:, mid_img_idx,:,:] - deformation[:, mid_img_idx+1:2*mid_img_idx+1,:,:]) ** 2) | |
# ).mean() | |
# losses['deformation_smooth_loss'] = deformation_smooth_loss | |
# # deformation = deformation.view(self.bs * self.nf, *deformation.shape[2:]) | |
# # losses['deformation_reg_loss'] = deformation.abs().mean() | |
## posed bones. | |
if posed_bones is not None and self.using_bone_smooth_loss: | |
bone_smooth_loss = self.bone_smooth_loss_fn(posed_bones) | |
losses['bone_smooth_loss'] = bone_smooth_loss | |
if posed_bones is not None and self.using_bonevel_smooth_loss: | |
_, nf, _, _, _= posed_bones.shape | |
bone_vel = posed_bones[:,1:nf,...] - posed_bones[:,:(nf-1),...] | |
bonevel_smooth_loss = self.bonevel_smooth_loss_fn(bone_vel) | |
losses['bonevel_smooth_loss'] = bonevel_smooth_loss | |
return losses, aux | |
def parse_dict_definition(self, dict_config, total_iter): | |
''' | |
dict_config is a dictionary-based schedule whose keys are in ascending order.
Each key: value pair is ITERATION_FROM_WHICH_THE_WEIGHT_APPLIES: WEIGHT.
For example,
{0: 0.1, 1000: 0.2, 10000: 0.3}
means the weight is 0.1 at the beginning, 0.2 from 1k iterations, and 0.3 after 10k.
''' | |
length = len(dict_config) | |
all_iters = list(dict_config.keys()) | |
all_weights = list(dict_config.values()) | |
weight = all_weights[-1] | |
for i in range(length-1): | |
# the loop only matters for dicts with at least two entries; a single-entry dict reduces to a constant weight
iter_num = all_iters[i] | |
iter_num_next = all_iters[i+1] | |
if iter_num <= total_iter and total_iter < iter_num_next: | |
weight = all_weights[i] | |
break | |
return weight | |
def compute_clip_loss(self, random_image_pred, image_pred, category): | |
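# CLIP-space consistency: encode both renders with CLIP and maximize their cosine
# similarity, optionally also pulling both towards a per-category text embedding.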
# image preprocess for CLIP | |
random_image = torch.nn.functional.interpolate(random_image_pred, (self.clip_reso, self.clip_reso), mode='bilinear') | |
image_pred = torch.nn.functional.interpolate(image_pred.squeeze(1), (self.clip_reso, self.clip_reso), mode='bilinear') | |
random_image = tvf.normalize(random_image, self.clip_mean, self.clip_std) | |
image_pred = tvf.normalize(image_pred, self.clip_mean, self.clip_std) | |
feat_img_1 = self.clip_model.encode_image(random_image) | |
feat_img_2 = self.clip_model.encode_image(image_pred) | |
clip_all_loss = torch.nn.functional.cosine_similarity(feat_img_1, feat_img_2) | |
clip_all_loss = 1 - clip_all_loss.mean() | |
# feat_img_1 = torch.mean(feat_img_1, dim=0) | |
# feat_img_2 = torch.mean(feat_img_2, dim=0) | |
# clip_all_loss = torch.nn.functional.cosine_similarity(feat_img_1, feat_img_2, dim=0) | |
# clip_all_loss = 1 - clip_all_loss | |
if self.enable_clip_text: | |
text_feature = self.clip_text_feature[category].repeat(feat_img_1.shape[0], 1) | |
text_loss_1 = torch.nn.functional.cosine_similarity(feat_img_1, text_feature).mean() | |
text_loss_2 = torch.nn.functional.cosine_similarity(feat_img_2, text_feature).mean() | |
# text_feature = self.clip_text_feature[category][0] | |
# text_loss_1 = torch.nn.functional.cosine_similarity(feat_img_1, text_feature, dim=0) | |
# text_loss_2 = torch.nn.functional.cosine_similarity(feat_img_2, text_feature, dim=0) | |
clip_all_loss = clip_all_loss + (1 - text_loss_1) + (1 - text_loss_2) | |
return {'clip_all_loss': clip_all_loss} | |
def generate_patch_crop(self, images, masks, patch_size=128, patch_num_per_mask=1): | |
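# Crop up to patch_num_per_mask square patches per image, centered on foreground
# (mask > 0) pixels far enough from the borders, and resize them to the GAN patch size.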
b, _, H, W = masks.shape | |
patches = [] | |
for i in range(masks.shape[0]): | |
mask = masks[i] | |
# mask: [1, H, W] | |
nonzero_indices = torch.nonzero(mask > 0, as_tuple=False) # [K', 3] | |
valid_mask = (nonzero_indices[:, 1] > patch_size // 2) & (nonzero_indices[:, 1] < (H - 1 - patch_size // 2)) & (nonzero_indices[:, 2] > patch_size // 2) & (nonzero_indices[:, 2] < (W - 1 - patch_size // 2)) | |
valid_idx = nonzero_indices[valid_mask] | |
patch_idx = valid_idx[torch.randperm(valid_idx.shape[0])[:patch_num_per_mask]] # [K, 3] | |
if patch_idx.shape[0] < patch_num_per_mask: | |
patches_this_img = torch.zeros(patch_num_per_mask, 3, self.few_shot_gan_tex_patch, self.few_shot_gan_tex_patch).to(self.device) | |
else: | |
patches_this_img = [] | |
for idx in range(patch_idx.shape[0]): | |
_, y, x = patch_idx[idx] | |
y_start = max(0, y - patch_size // 2) | |
y_end = min(H, y_start + patch_size) | |
x_start = max(0, x - patch_size // 2) | |
x_end = min(W, x_start + patch_size) | |
patch_content = images[i, :, y_start:y_end, x_start:x_end] | |
patch = F.interpolate(patch_content.unsqueeze(0), size=self.few_shot_gan_tex_patch, mode='bilinear') # [1, 3, ps, ps] | |
patches_this_img.append(patch) | |
patches_this_img = torch.cat(patches_this_img, dim=0) # [K, 3, ps, ps] | |
patches.append(patches_this_img) | |
patches = torch.concat(patches, dim=0) # [B*K, 3, ps, ps] | |
return patches | |
def compute_gan_tex_loss(self, category, image_gt, mask_gt, iv_image_pred, iv_mask_pred, w2c_pred, campos_pred, shape, prior_shape, texture, dino_pred, im_features, light, class_vector, num_frames, im_features_map, bins=360): | |
''' | |
GAN training on the texture. This is meant to be used only during fine-tuning, with the
local texture network; ideally this loss contributes only to the texture.
''' | |
delta_angle = 2 * np.pi / bins | |
b = len(shape) | |
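# sample azimuth offsets in [120, 240) bins (degrees when bins=360), i.e. roughly the
# far side of the object relative to the predicted view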
rand_degree = torch.randint(120, [b]) | |
rand_degree = rand_degree + 120 | |
# rand_degree = torch.ones(b) * 180 # we want to see the reversed side | |
delta_angle = delta_angle * rand_degree | |
delta_rot_matrix = [] | |
for i in range(b): | |
angle = delta_angle[i].item() | |
angle_matrix = torch.FloatTensor([ | |
[np.cos(angle), 0, np.sin(angle), 0], | |
[0, 1, 0, 0], | |
[-np.sin(angle), 0, np.cos(angle), 0], | |
[0, 0, 0, 1], | |
]).to(self.device) | |
delta_rot_matrix.append(angle_matrix) | |
delta_rot_matrix = torch.stack(delta_rot_matrix, dim=0) | |
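# each matrix rotates the camera around the vertical (y) axis by the sampled angle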
proj = util.perspective(self.crop_fov_approx / 180 * np.pi, 1, n=0.1, f=1000.0).repeat(b, 1, 1).to(self.device) | |
original_mvp = torch.bmm(proj, w2c_pred) | |
# original_campos = -w2c_pred[:, :3, 3] | |
original_campos = campos_pred | |
mvp = torch.matmul(original_mvp, delta_rot_matrix) | |
campos = torch.matmul(delta_rot_matrix[:,:3,:3].transpose(2,1), original_campos[:,:,None])[:,:,0] | |
w2c = w2c_pred | |
resolution = (self.few_shot_gan_tex_reso, self.few_shot_gan_tex_reso) | |
# only train the texture | |
safe_detach = lambda x: x.detach() if x is not None else None | |
mesh = safe_detach(shape) | |
im_features = safe_detach(im_features) | |
im_features_map = safe_detach(im_features_map) | |
class_vector = safe_detach(class_vector) | |
set_requires_grad(texture, True) | |
set_requires_grad(dino_pred, False) | |
set_requires_grad(light, False) | |
background_for_reverse = 'none' | |
# background_for_reverse = 'random-pure' | |
image_pred, mask_pred, _, _, _, _ = self.render( | |
mesh, | |
texture, | |
mvp, | |
w2c, | |
campos, | |
resolution, | |
background=background_for_reverse, | |
im_features=im_features, | |
light=light, | |
prior_shape=prior_shape, | |
render_flow=False, | |
dino_pred=dino_pred, | |
spp=self.renderer_spp, | |
class_vector=class_vector, | |
render_mode='diffuse', | |
two_sided_shading=False, | |
num_frames=num_frames, | |
im_features_map={"original_mvp": original_mvp, "im_features_map": im_features_map} if im_features_map is not None else None # in other views we need to pass the original mvp | |
) | |
mask_pred = mask_pred.unsqueeze(1) | |
if self.few_shot_gan_tex_reso != self.out_image_size: | |
image_pred = torch.nn.functional.interpolate(image_pred, (self.out_image_size, self.out_image_size), mode='bilinear') | |
mask_pred = torch.nn.functional.interpolate(mask_pred, (self.out_image_size, self.out_image_size), mode='bilinear') | |
# image_pred = image_pred.clamp(0, 1) | |
# mask_pred = mask_pred.clamp(0, 1) # [B, 1, H, W] | |
if background_for_reverse == 'random': | |
# since the random view is rendered over a random background, composite the input view
# over one too; backgrounds are generated at low resolution and then upsampled
random_bg = torch.rand(self.bs, self.nf, 3, self.few_shot_gan_tex_reso, self.few_shot_gan_tex_reso).to(self.device) | |
random_bg = torch.nn.functional.interpolate(random_bg.squeeze(1), (self.out_image_size, self.out_image_size), mode='bilinear').unsqueeze(1) | |
iv_mask_pred = iv_mask_pred.unsqueeze(2).repeat(1, 1, 3, 1, 1) | |
iv_image_pred = iv_image_pred * iv_mask_pred + random_bg * (1. - iv_mask_pred) | |
iv_image_pred = iv_image_pred.squeeze(1) | |
random_bg_gt = torch.rand(self.bs, self.nf, 3, self.few_shot_gan_tex_reso, self.few_shot_gan_tex_reso).to(self.device) | |
random_bg_gt = torch.nn.functional.interpolate(random_bg_gt.squeeze(1), (self.out_image_size, self.out_image_size), mode='bilinear').unsqueeze(1) | |
mask_gt = mask_gt.unsqueeze(2).repeat(1, 1, 3, 1, 1) | |
image_gt = image_gt * mask_gt + random_bg_gt * (1. - mask_gt) | |
image_gt = image_gt.squeeze(1) | |
elif background_for_reverse == 'random-pure': | |
# the background is random but with one color | |
random_values = torch.rand(b) | |
random_bg = random_values[..., None, None, None, None].repeat(1, 1, 3, self.few_shot_gan_tex_reso, self.few_shot_gan_tex_reso).to(self.device) | |
random_bg = torch.nn.functional.interpolate(random_bg.squeeze(1), (self.out_image_size, self.out_image_size), mode='bilinear').unsqueeze(1) | |
iv_mask_pred = iv_mask_pred.unsqueeze(2).repeat(1, 1, 3, 1, 1) | |
iv_image_pred = iv_image_pred * iv_mask_pred + random_bg * (1. - iv_mask_pred) | |
iv_image_pred = iv_image_pred.squeeze(1) | |
random_values_gt = torch.rand(b) | |
random_bg_gt = random_values_gt[..., None, None, None, None].repeat(1, 1, 3, self.few_shot_gan_tex_reso, self.few_shot_gan_tex_reso).to(self.device) | |
random_bg_gt = torch.nn.functional.interpolate(random_bg_gt.squeeze(1), (self.out_image_size, self.out_image_size), mode='bilinear').unsqueeze(1) | |
mask_gt = mask_gt.unsqueeze(2).repeat(1, 1, 3, 1, 1) | |
image_gt = image_gt * mask_gt + random_bg_gt * (1. - mask_gt) | |
image_gt = image_gt.squeeze(1) | |
elif background_for_reverse == 'none': | |
iv_image_pred = iv_image_pred.squeeze(1) | |
iv_mask_pred = iv_mask_pred.unsqueeze(2).repeat(1, 1, 3, 1, 1) | |
# image_gt = image_gt * mask_gt + random_bg_gt * (1. - mask_gt) | |
mask_gt = mask_gt.unsqueeze(2).repeat(1, 1, 3, 1, 1) | |
image_gt = image_gt * mask_gt | |
image_gt = image_gt.squeeze(1) | |
else: | |
raise NotImplementedError | |
# image_gt = torch.nn.functional.interpolate(image_gt, (32, 32), mode='bilinear') | |
# image_gt = torch.nn.functional.interpolate(image_gt, (256, 256), mode='bilinear') | |
# train the generator so the discriminator classifies this reverse view as a real sample
if self.cfgs.get('few_shot_gan_tex_patch', 0) > 0: | |
patch_size = torch.randint(self.few_shot_gan_tex_patch, self.few_shot_gan_tex_patch_max, (1,)).item() | |
# random view | |
image_pred = self.generate_patch_crop(image_pred, mask_pred, patch_size, self.few_shot_gan_tex_patch_num) | |
# input view | |
iv_image_pred = self.generate_patch_crop(iv_image_pred, iv_mask_pred.squeeze(1)[:, 0:1, :, :], patch_size, self.few_shot_gan_tex_patch_num) | |
# gt view | |
image_gt = self.generate_patch_crop(image_gt, mask_gt.squeeze(1)[:, 0:1, :, :], patch_size, self.few_shot_gan_tex_patch_num) | |
return_loss = {} | |
if self.few_shot_gan_tex: | |
# generator loss: fake samples should be classified as real
gan_tex_loss = 0.0 | |
if 'rv' in self.few_shot_gan_tex_fake: | |
d_rv = self.discriminator_texture(image_pred) | |
gan_tex_loss_rv = discriminator_architecture.bce_loss_target(d_rv, 1) | |
gan_tex_loss += gan_tex_loss_rv | |
if 'iv' in self.few_shot_gan_tex_fake: | |
d_iv = self.discriminator_texture(iv_image_pred) | |
gan_tex_loss_iv = discriminator_architecture.bce_loss_target(d_iv, 1) | |
gan_tex_loss += gan_tex_loss_iv | |
return_loss['gan_tex_loss'] = gan_tex_loss | |
if self.few_shot_clip_tex: | |
clip_tex_loss_rv_iv = self.compute_clip_loss(image_pred, iv_image_pred.unsqueeze(1), category='none') | |
clip_tex_loss_rv_gt = self.compute_clip_loss(image_pred, image_gt.unsqueeze(1), category='none') | |
clip_tex_loss = clip_tex_loss_rv_iv['clip_all_loss'] + clip_tex_loss_rv_gt['clip_all_loss'] | |
return_loss['clip_tex_loss'] = clip_tex_loss | |
return_aux = { | |
'gan_tex_render_image': image_pred.clone().clamp(0, 1), | |
'gan_tex_inpview_image': iv_image_pred.clone().clamp(0, 1), | |
'gan_tex_gt_image': image_gt.clone().clamp(0, 1) | |
} | |
with torch.no_grad(): | |
# self.record_image_iv = iv_image_pred.clone().clamp(0, 1) | |
# self.record_image_rv = image_pred.clone().clamp(0, 1) | |
# self.record_image_gt = image_gt.clone().clamp(0, 1) | |
self.record_image_iv = iv_image_pred.clone() | |
self.record_image_rv = image_pred.clone() | |
self.record_image_gt = image_gt.clone() | |
return return_loss, return_aux | |
def compute_mask_distribution_loss(self, category, w2c_pred, shape, prior_shape, texture, dino_pred, im_features, light, class_vector, num_frames, im_features_map, bins=360): | |
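# Render silhouettes from swapped or randomized camera poses and compare them with a
# reference silhouette distribution via binary cross-entropy.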
delta_angle = 2 * np.pi / bins | |
b = len(shape) | |
if self.random_mask_law == 'batch_swap': | |
# shuffle in predicted poses | |
rand_degree_1 = torch.randperm(int(w2c_pred.shape[0] // 2)) | |
rand_degree_2 = torch.randperm(w2c_pred.shape[0] - int(w2c_pred.shape[0] // 2)) + int(w2c_pred.shape[0] // 2) | |
rand_degree = torch.cat([rand_degree_2, rand_degree_1], dim=0).long().to(w2c_pred.device) | |
w2c = w2c_pred[rand_degree] | |
proj = util.perspective(self.crop_fov_approx / 180 * np.pi, 1, n=0.1, f=1000.0).repeat(b, 1, 1).to(self.device) | |
mvp = torch.bmm(proj, w2c) | |
campos = -w2c[:, :3, 3] | |
elif self.random_mask_law == 'batch_swap_noy': | |
# shuffle in predicted poses | |
rand_degree_1 = torch.randperm(int(w2c_pred.shape[0] // 2)) | |
rand_degree_2 = torch.randperm(w2c_pred.shape[0] - int(w2c_pred.shape[0] // 2)) + int(w2c_pred.shape[0] // 2) | |
rand_degree = torch.cat([rand_degree_2, rand_degree_1], dim=0).long().to(w2c_pred.device) | |
w2c = w2c_pred[rand_degree] | |
# we do not randomly swap the y-translation in the discriminator loss
w2c[:, 1, 3] = w2c_pred[:, 1, 3] | |
proj = util.perspective(self.crop_fov_approx / 180 * np.pi, 1, n=0.1, f=1000.0).repeat(b, 1, 1).to(self.device) | |
mvp = torch.bmm(proj, w2c) | |
campos = -w2c[:, :3, 3] | |
elif self.random_mask_law == 'random_azimuth': | |
# randomize the azimuth: a canonical rotation plus a random y-axis rotation, keeping the predicted translation
rand_degree = torch.randint(bins, [b]) | |
delta_angle = delta_angle * rand_degree | |
delta_rot_matrix = [] | |
for i in range(b): | |
angle = delta_angle[i].item() | |
angle_matrix = torch.FloatTensor([ | |
[np.cos(angle), 0, np.sin(angle), 0], | |
[0, 1, 0, 0], | |
[-np.sin(angle), 0, np.cos(angle), 0], | |
[0, 0, 0, 1], | |
]).to(self.device) | |
delta_rot_matrix.append(angle_matrix) | |
delta_rot_matrix = torch.stack(delta_rot_matrix, dim=0) | |
w2c = torch.FloatTensor(np.diag([1., 1., 1., 1])) | |
w2c[:3, 3] = torch.FloatTensor([0, 0, -self.cam_pos_z_offset *1.4]) | |
w2c = w2c.repeat(b, 1, 1).to(self.device) | |
# use the predicted translation
w2c_pred = w2c_pred.detach() | |
w2c[:, :3, 3] = w2c_pred[:b][:, :3, 3] | |
proj = util.perspective(self.crop_fov_approx / 180 * np.pi, 1, n=0.1, f=1000.0).repeat(b, 1, 1).to(self.device) | |
mvp = torch.bmm(proj, w2c) | |
campos = -w2c[:, :3, 3] | |
mvp = torch.matmul(mvp, delta_rot_matrix) | |
campos = torch.matmul(delta_rot_matrix[:,:3,:3].transpose(2,1), campos[:,:,None])[:,:,0] | |
elif self.random_mask_law == 'random_all': | |
# randomize the azimuth and use a pre-set translation instead of the predicted one
rand_degree = torch.randint(bins, [b]) | |
delta_angle = delta_angle * rand_degree | |
delta_rot_matrix = [] | |
for i in range(b): | |
angle = delta_angle[i].item() | |
angle_matrix = torch.FloatTensor([ | |
[np.cos(angle), 0, np.sin(angle), 0], | |
[0, 1, 0, 0], | |
[-np.sin(angle), 0, np.cos(angle), 0], | |
[0, 0, 0, 1], | |
]).to(self.device) | |
delta_rot_matrix.append(angle_matrix) | |
delta_rot_matrix = torch.stack(delta_rot_matrix, dim=0) | |
w2c = torch.FloatTensor(np.diag([1., 1., 1., 1])) | |
w2c[:3, 3] = torch.FloatTensor([0, 0, -self.cam_pos_z_offset *1.4]) | |
w2c = w2c.repeat(b, 1, 1).to(self.device) | |
proj = util.perspective(self.crop_fov_approx / 180 * np.pi, 1, n=0.1, f=1000.0).repeat(b, 1, 1).to(self.device) | |
mvp = torch.bmm(proj, w2c) | |
campos = -w2c[:, :3, 3] | |
mvp = torch.matmul(mvp, delta_rot_matrix) | |
campos = torch.matmul(delta_rot_matrix[:,:3,:3].transpose(2,1), campos[:,:,None])[:,:,0] | |
else: | |
raise NotImplementedError | |
resolution = (self.out_image_size, self.out_image_size) | |
# render the articulated shape | |
mesh = shape | |
if self.enable_clip: | |
resolution = (self.clip_render_size, self.clip_render_size) | |
set_requires_grad(texture, False) | |
image_pred, mask_pred, _, _, _, _ = self.render( | |
mesh, | |
texture, | |
mvp, | |
w2c, | |
campos, | |
resolution, | |
background='none', | |
im_features=im_features, | |
light=light, | |
prior_shape=prior_shape, | |
render_flow=False, | |
dino_pred=dino_pred, | |
spp=self.renderer_spp, | |
class_vector=class_vector, | |
render_mode='diffuse', | |
two_sided_shading=False, | |
num_frames=num_frames, | |
im_features_map=im_features_map | |
) | |
if resolution[0] != self.out_image_size: | |
image_pred = torch.nn.functional.interpolate(image_pred, (self.out_image_size, self.out_image_size), mode='bilinear') | |
mask_pred = torch.nn.functional.interpolate(mask_pred.unsqueeze(1), (self.out_image_size, self.out_image_size), mode='bilinear').squeeze(1) | |
else: | |
_, mask_pred, _, _, _, _ = self.render( | |
mesh, | |
None, | |
mvp, | |
w2c, | |
campos, | |
resolution, | |
background='none', | |
im_features=None, | |
light=None, | |
prior_shape=prior_shape, | |
render_flow=False, | |
dino_pred=None, | |
class_vector=class_vector, | |
render_mode='diffuse', | |
two_sided_shading=False, | |
num_frames=num_frames, | |
im_features_map=None | |
) | |
image_pred = None | |
# TODO: disable mask distribution and isolate mask discriminator loss | |
# mask_distribution = self.class_mask_distribution[category] | |
# mask_distribution = torch.Tensor(mask_distribution).to(self.device).unsqueeze(0).repeat(b, 1, 1) | |
mask_distribution = torch.Tensor(self.class_mask_distribution["zebra"]).to(self.device).unsqueeze(0).repeat(b, 1, 1) | |
if self.mask_distribution_average: | |
# with mask_distribution_average, average the masks across the batch before computing the loss
mask_pred = mask_pred.mean(dim=0).unsqueeze(0).repeat(b, 1, 1) | |
mask_pred = mask_pred.clamp(0,1) | |
mask_distribution = mask_distribution.clamp(0,1) | |
distribution_loss = torch.nn.functional.binary_cross_entropy(mask_pred, mask_distribution) | |
out_loss = {'mask_distribution_loss': 0 * distribution_loss} | |
out_aux = { | |
'mask_random_pred': mask_pred.unsqueeze(1), | |
'mask_distribution': mask_distribution.unsqueeze(1), | |
'rand_degree': rand_degree | |
} | |
if self.enable_clip: | |
out_aux.update({'random_render_image': image_pred}) | |
return out_loss, out_aux | |
def use_line_correct_valid_mask(self, mask_valid, p1, p2, mvp, mask_gt): | |
line = torch.cat([p1.unsqueeze(-2), p2.unsqueeze(-2)], dim=-2) # [B, 2, 3] | |
line_world4 = torch.cat([line, torch.ones_like(line[..., :1])], -1) | |
line_clip4 = line_world4 @ mvp.transpose(-1, -2) | |
line_uv = line_clip4[..., :2] / line_clip4[..., 3:4] | |
line_uv = line_uv.detach() | |
b, _, n_uv = line_uv.shape | |
line_uv = line_uv * torch.Tensor([mask_valid.shape[-2] // 2, mask_valid.shape[-1] // 2]).to(line_uv.device).unsqueeze(0).unsqueeze(-1).repeat(b, 1, n_uv) | |
line_uv = line_uv + torch.Tensor([mask_valid.shape[-2] // 2, mask_valid.shape[-1] // 2]).to(line_uv.device).unsqueeze(0).unsqueeze(-1).repeat(b, 1, n_uv) | |
line_slope = (line_uv[:, 0, 1] - line_uv[:, 1, 1]) / (line_uv[:, 0, 0] - line_uv[:, 1, 0])
uv = np.mgrid[0:mask_valid.shape[-2], 0:mask_valid.shape[-1]].astype(np.int32)
uv = torch.from_numpy(np.flip(uv, axis=0).copy()).float().unsqueeze(0).repeat(b, 1, 1, 1) # [B, 2, 256, 256]
tmp_u = uv[:, 0, ...][mask_gt[:, 0, ...].bool()]
tmp_v = uv[:, 1, ...][mask_gt[:, 0, ...].bool()]
# NOTE: the line slope and foreground pixel coordinates above are computed but not yet
# used; the valid mask is currently returned unchanged.
return mask_valid
def discriminator_step(self): | |
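# One optimization step for the mask discriminator, using the masks recorded during
# the generator pass in compute_mask_disc_loss_gen.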
mask_gt = self.record_mask_gt | |
mask_pred = self.record_mask_iv | |
mask_random_pred = self.record_mask_rv | |
self.optimizerDiscriminator.zero_grad() | |
# the random-view masks are fake samples
d_random_pred = self.mask_disc(mask_random_pred)
disc_loss = discriminator_architecture.bce_loss_target(d_random_pred, 0) # train the discriminator to classify them as fake
grad_loss = 0.0 | |
count = 1 | |
discriminator_loss_rv = disc_loss.detach() | |
discriminator_loss_gt = 0.0 | |
discriminator_loss_iv = 0. | |
d_gt = None | |
d_iv = None | |
if self.disc_gt: | |
mask_gt.requires_grad_() | |
d_gt = self.mask_disc(mask_gt) | |
if d_gt.requires_grad is False: | |
# in the test case | |
disc_gt_loss = discriminator_architecture.bce_loss_target(d_gt, 1) | |
else: | |
grad_penalty = self.disc_reg_mul * discriminator_architecture.compute_grad2(d_gt, mask_gt) | |
disc_gt_loss = discriminator_architecture.bce_loss_target(d_gt, 1) + grad_penalty | |
grad_loss += grad_penalty | |
disc_loss = disc_loss + disc_gt_loss | |
discriminator_loss_gt = disc_gt_loss | |
count = count + 1 | |
if self.disc_iv: | |
mask_pred.requires_grad_() | |
d_iv = self.mask_disc(mask_pred) | |
if self.disc_iv_label == 'Real': | |
if d_iv.requires_grad is False: | |
# in the test case | |
disc_iv_loss = discriminator_architecture.bce_loss_target(d_iv, 1) | |
else: | |
grad_penalty = self.disc_reg_mul * discriminator_architecture.compute_grad2(d_iv, mask_pred) | |
disc_iv_loss = discriminator_architecture.bce_loss_target(d_iv, 1) + grad_penalty | |
grad_loss += grad_penalty | |
else: | |
disc_iv_loss = discriminator_architecture.bce_loss_target(d_iv, 0) | |
disc_loss = disc_loss + disc_iv_loss | |
count = count + 1 | |
discriminator_loss_iv = disc_iv_loss | |
disc_loss = disc_loss / count | |
grad_loss = grad_loss / count | |
self.discriminator_loss = disc_loss * self.discriminator_loss_weight | |
self.discriminator_loss.backward() | |
self.optimizerDiscriminator.step() | |
self.discriminator_loss = 0. | |
return { | |
'discriminator_loss': disc_loss, | |
'discriminator_loss_rv': discriminator_loss_rv, | |
'discriminator_loss_iv': discriminator_loss_iv, | |
'discriminator_loss_gt': discriminator_loss_gt, | |
'd_rv': d_random_pred, | |
'd_iv': d_iv if d_iv is not None else None, | |
'd_gt': d_gt if d_gt is not None else None, | |
}, grad_loss | |
def compute_mask_disc_loss_gen(self, mask_gt, mask_pred, mask_random_pred, category_name=None, condition_feat=None): | |
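# Generator-side discriminator loss: condition the masks on the class (one-hot or
# feature vector) and reward renders that the discriminator classifies as real.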
# mask_gt[mask_gt < 1.] = 0. | |
# mask_pred[mask_pred > 0.] = 1. | |
# mask_random_pred[mask_random_pred > 0.] = 1. | |
if not self.mask_disc_feat_condition: | |
try: | |
class_idx = list(self.netPrior.category_id_map.keys()).index(category_name) | |
except ValueError:
class_idx = 100 # fallback index for an unknown category
num_classes = len(list(self.netPrior.category_id_map.keys())) | |
class_idx = torch.LongTensor([class_idx]) | |
# class_one_hot = torch.nn.functional.one_hot(class_idx, num_classes=7).unsqueeze(-1).unsqueeze(-1).to(mask_gt.device) # [1, 7, 1, 1] | |
class_one_hot = torch.nn.functional.one_hot(class_idx, num_classes=num_classes).unsqueeze(-1).unsqueeze(-1).to(mask_gt.device) | |
class_one_hot = class_one_hot.repeat(mask_gt.shape[0], 1, mask_gt.shape[-2], mask_gt.shape[-1]) | |
# TODO: hack: truncate the one-hot to fit the discriminator's condition channels
class_one_hot = class_one_hot[:, :(self.mask_disc.in_dim-1), :, :] | |
else: | |
class_one_hot = condition_feat.detach() | |
class_one_hot = class_one_hot.reshape(1, -1, 1, 1).repeat(mask_gt.shape[0], 1, mask_gt.shape[-2], mask_gt.shape[-1]) | |
# concat | |
mask_gt = torch.cat([mask_gt, class_one_hot], dim=1) | |
mask_pred = torch.cat([mask_pred, class_one_hot], dim=1) | |
mask_random_pred = torch.cat([mask_random_pred, class_one_hot], dim=1) | |
# after concatenation the masks are [B, 1+C, 256, 256]
# the random-view masks are rendered fakes that the generator should make look real
d_random_pred = self.mask_disc(mask_random_pred) | |
disc_loss = discriminator_architecture.bce_loss_target(d_random_pred, 1) # in gen loss, train it to be real | |
count = 1 | |
disc_loss_rv = disc_loss.detach() | |
disc_loss_iv = 0.0 | |
if self.disc_iv: | |
if self.disc_iv_label != 'Real': # consider the input view also fake | |
d_iv = self.mask_disc(mask_pred) | |
disc_iv_loss = discriminator_architecture.bce_loss_target(d_iv, 1) # so now we need to train them to be real | |
disc_loss = disc_loss + disc_iv_loss | |
count = count + 1 | |
disc_loss_iv = disc_iv_loss.detach() | |
disc_loss = disc_loss / count | |
# record the masks for discriminator training | |
self.record_mask_gt = mask_gt.clone().detach() | |
self.record_mask_iv = mask_pred.clone().detach() | |
self.record_mask_rv = mask_random_pred.clone().detach() | |
return { | |
'mask_disc_loss': disc_loss, | |
'mask_disc_loss_rv': disc_loss_rv, | |
'mask_disc_loss_iv': disc_loss_iv, | |
} | |
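# --- Illustrative sketch (ours, hypothetical): class-conditioning the mask | |
# discriminator by concatenating a per-class one-hot map as extra input channels, | |
# mirroring the branch above that builds class_one_hot from category_id_map. | |
def _sketch_condition_mask_on_class(mask, class_idx, num_classes): | |
    # mask: [B, 1, H, W]; returns [B, 1 + num_classes, H, W] | |
    one_hot = torch.nn.functional.one_hot(torch.tensor([class_idx]), num_classes=num_classes) | |
    one_hot = one_hot.float().view(1, num_classes, 1, 1).to(mask.device) | |
    one_hot = one_hot.expand(mask.shape[0], -1, mask.shape[-2], mask.shape[-1]) | |
    return torch.cat([mask, one_hot], dim=1) | |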
def forward(self, batch, epoch, iter, is_train=True, viz_logger=None, total_iter=None, save_results=False, save_dir=None, which_data='', logger_prefix='', is_training=True, bank_embedding=None): | |
batch = [x.to(self.device) if x is not None and isinstance(x, torch.Tensor) else x for x in batch] | |
input_image, mask_gt, mask_dt, mask_valid, flow_gt, bbox, bg_image, dino_feat_im, dino_cluster_im, seq_idx, frame_idx, category_name = batch | |
# if save_results: | |
# save_for_pkl = { | |
# "image": input_image.cpu(), | |
# "mask_gt": mask_gt.cpu(), | |
# "mask_dt": mask_dt.cpu(), | |
# "mask_valid": mask_valid.cpu(), | |
# "flow_gt": None, | |
# "bbox": bbox.cpu(), | |
# "bg_image": bg_image.cpu(), | |
# "dino_feat_im": dino_feat_im.cpu(), | |
# "dino_cluster_im": dino_cluster_im.cpu(), | |
# "seq_idx": seq_idx.cpu(), | |
# "frame_idx": frame_idx.cpu(), | |
# "category_name": category_name | |
# } | |
batch_size, num_frames, _, h0, w0 = input_image.shape # BxFxCxHxW | |
self.bs = batch_size | |
self.nf = num_frames | |
mid_img_idx = int((input_image.shape[1]-1)//2) | |
# print(f"mid_img_idx: {mid_img_idx}") | |
h = w = self.out_image_size | |
def collapseF(x): | |
return None if x is None else x.view(batch_size * num_frames, *x.shape[2:]) | |
def expandF(x): | |
return None if x is None else x.view(batch_size, num_frames, *x.shape[1:]) | |
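# Shape note (hypothetical sizes): with batch_size=2 and num_frames=3, a tensor of | |
# shape [2, 3, C, H, W] becomes [6, C, H, W] under collapseF, and expandF inverts | |
# this, so per-frame ops can run on a flat batch and be regrouped afterwards. | |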
if flow_gt.dim() == 2: # dummy tensor for not loading flow | |
flow_gt = None | |
if dino_cluster_im.dim() == 2: # dummy tensor for not loading dino clusters | |
dino_cluster_im = None | |
dino_cluster_im_gt = None | |
else: | |
dino_cluster_im_gt = expandF(torch.nn.functional.interpolate(collapseF(dino_cluster_im), size=[h, w], mode="nearest")) | |
seq_idx = seq_idx.squeeze(1) | |
# seq_idx = seq_idx * 0 # single sequence model | |
frame_id, crop_x0, crop_y0, crop_w, crop_h, full_w, full_h, sharpness, label = bbox.unbind(2) # BxFx9 | |
bbox = torch.stack([crop_x0, crop_y0, crop_w, crop_h], 2) | |
mask_gt = (mask_gt[:, :, 0, :, :] > 0.9).float() # BxFxHxW | |
mask_dt = mask_dt / self.in_image_size | |
if which_data != 'video': | |
flow_gt = None | |
aux_viz = {} | |
sds_random_images, sds_aux = None, None # not produced in this code path; guards the SDS visualization below | |
## GT | |
image_gt = input_image | |
if self.out_image_size != self.in_image_size: | |
image_gt = expandF(torch.nn.functional.interpolate(collapseF(image_gt), size=[h, w], mode='bilinear')) | |
if flow_gt is not None: | |
flow_gt = torch.nn.functional.interpolate(flow_gt.view(batch_size*(num_frames-1), 2, h0, w0), size=[h, w], mode="bilinear").view(batch_size, num_frames-1, 2, h, w) | |
self.train_pose_only = False | |
if epoch in self.pose_epochs: | |
if (total_iter // self.pose_iters) % 2 == 0: | |
self.train_pose_only = True | |
## flip input and pose | |
if epoch in self.pose_xflip_recon_epochs: | |
input_image_xflip = input_image.flip(-1) | |
input_image_xflip_flag = torch.randint(0, 2, (batch_size, num_frames), device=input_image.device) | |
input_image = input_image * (1 - input_image_xflip_flag[:,:,None,None,None]) + input_image_xflip * input_image_xflip_flag[:,:,None,None,None] | |
else: | |
input_image_xflip_flag = None | |
## 1st pose hypothesis with original predictions | |
# ============================================================================================== | |
# Predict prior mesh. | |
# ============================================================================================== | |
if self.enable_prior: | |
if self.world_size > 1: | |
if epoch < self.dmtet_grid_smaller_epoch: | |
if self.netPrior_ddp.module.netShape.grid_res != self.dmtet_grid_smaller: | |
self.netPrior_ddp.module.netShape.load_tets(self.dmtet_grid_smaller) | |
else: | |
if self.netPrior_ddp.module.netShape.grid_res != self.dmtet_grid: | |
self.netPrior_ddp.module.netShape.load_tets(self.dmtet_grid) | |
else: | |
if epoch < self.dmtet_grid_smaller_epoch: | |
if self.netPrior.netShape.grid_res != self.dmtet_grid_smaller: | |
self.netPrior.netShape.load_tets(self.dmtet_grid_smaller) | |
else: | |
if self.netPrior.netShape.grid_res != self.dmtet_grid: | |
self.netPrior.netShape.load_tets(self.dmtet_grid) | |
perturb_sdf = self.perturb_sdf if is_train else False | |
# DINO prior category specific - DOR | |
if self.world_size > 1: | |
prior_shape, dino_pred, classes_vectors = self.netPrior_ddp(category_name=category_name[0], perturb_sdf=perturb_sdf, total_iter=total_iter, is_training=is_training, class_embedding=bank_embedding) | |
else: | |
prior_shape, dino_pred, classes_vectors = self.netPrior(category_name=category_name[0], perturb_sdf=perturb_sdf, total_iter=total_iter, is_training=is_training, class_embedding=bank_embedding) | |
else: | |
prior_shape = None | |
raise NotImplementedError | |
if self.world_size > 1: | |
shape, pose_raw, pose, mvp, w2c, campos, texture, im_features, dino_feat_im_calc, deformation, arti_params, light, forward_aux = self.netInstance_ddp(category_name, input_image, prior_shape, epoch, dino_feat_im, dino_cluster_im, total_iter, is_training=is_training) # frame dim collapsed N=(B*F) | |
im_features_map = None # the DDP path returns no patch features; keep the variable defined for the render call below | |
else: | |
Instance_out = self.netInstance(category_name, input_image, prior_shape, epoch, dino_feat_im, dino_cluster_im, total_iter, is_training=is_training) # frame dim collapsed N=(B*F) | |
# if netInstance does not return patch features (patch_out), set im_features_map to None below | |
if len(Instance_out) == 13: | |
shape, pose_raw, pose, mvp, w2c, campos, texture, im_features, dino_feat_im_calc, deformation, arti_params, light, forward_aux = Instance_out | |
im_features_map = None | |
else: | |
shape, pose_raw, pose, mvp, w2c, campos, texture, im_features, dino_feat_im_calc, deformation, arti_params, light, forward_aux, im_features_map = Instance_out | |
# if save_results: | |
# save_for_pkl.update( | |
# { | |
# "pose_raw": pose_raw.cpu(), | |
# "pose": pose.cpu(), | |
# "mvp": mvp.cpu(), | |
# "w2c": w2c.cpu(), | |
# "campos": campos.cpu(), | |
# "campos_z_offset": self.netInstance.cam_pos_z_offset | |
# } | |
# ) | |
if self.calc_dino_features: | |
# dino_feat_im_calc is N x C x H x W (e.g. 3 x 384 x 32 x 32); use local names so the outer batch_size is not shadowed | |
n_imgs, n_ch, fh, fw = dino_feat_im_calc.shape | |
# flatten the spatial dimensions and L2-normalize each channel map | |
dino_feat_im = dino_feat_im_calc.reshape(n_imgs, n_ch, -1) | |
norm = torch.norm(dino_feat_im, dim=-1, keepdim=True) | |
dino_feat_im = dino_feat_im / norm | |
# restore the spatial layout and add a singleton frame dimension | |
dino_feat_im = dino_feat_im.reshape(n_imgs, n_ch, fh, fw) | |
dino_feat_im = dino_feat_im.unsqueeze(1) | |
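# Normalization note: the block above L2-normalizes each of the C channel maps over | |
# its flattened H*W support, e.g. a [N, 384, 32, 32] feature becomes [N, 384, 1024], | |
# is divided by its per-channel norm, and is reshaped back before adding a frame dim. | |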
if dino_feat_im.dim() == 2: # dummy tensor for not loading dino features | |
dino_feat_im = None | |
dino_feat_im_gt = None | |
else: | |
dino_feat_im_gt = expandF(torch.nn.functional.interpolate(collapseF(dino_feat_im), size=[h, w], mode="bilinear"))[:, :, :self.dino_feature_recon_dim] | |
rot_logit = forward_aux['rot_logit'] | |
rot_idx = forward_aux['rot_idx'] | |
rot_prob = forward_aux['rot_prob'] | |
if self.using_bonevel_smooth_loss: | |
posed_bones = forward_aux['posed_bones'] | |
else: | |
posed_bones = None | |
aux_viz.update(forward_aux) | |
if self.train_pose_only: | |
safe_detach = lambda x: x.detach() if x is not None else None | |
prior_shape = safe_detach(prior_shape) | |
shape = safe_detach(shape) | |
im_features = safe_detach(im_features) | |
arti_params = safe_detach(arti_params) | |
deformation = safe_detach(deformation) | |
set_requires_grad(texture, False) | |
set_requires_grad(light, False) | |
set_requires_grad(dino_pred, False) | |
else: | |
set_requires_grad(texture, True) | |
set_requires_grad(light, True) | |
set_requires_grad(dino_pred, True) | |
render_flow = self.render_flow and num_frames > 1 # False for single-frame batches | |
# from IPython import embed; embed() | |
# if num_frames > 1 and self.smooth_type == 'rend': | |
# print("rendererr smoothness !!!!") | |
# image_pred, mask_pred, flow_pred, dino_feat_im_pred, albedo, shading = self.render(shape, texture, mvp, w2c, campos, (h, w), background=self.background_mode, im_features=im_features[torch.randperm(im_features.size(0))], light=light, prior_shape=prior_shape, render_flow=render_flow, dino_pred=dino_pred, num_frames=num_frames, spp=self.renderer_spp) #the real rendering process | |
# else: | |
# print("regular render") | |
#print("a cecond before rendering .... need to get the correct label and the correct vector") | |
#print("label", label) | |
#print("classes_vectors", classes_vectors) | |
#print("im_features", im_features.shape) | |
class_vector = None | |
if classes_vectors is not None: | |
if len(classes_vectors.shape) == 1: | |
class_vector = classes_vectors | |
else: | |
class_vector = classes_vectors[self.netPrior.category_id_map[category_name[0]], :] | |
image_pred, mask_pred, flow_pred, dino_feat_im_pred, albedo, shading = self.render(shape, texture, mvp, w2c, campos, (h, w), background=self.background_mode, im_features=im_features, light=light, prior_shape=prior_shape, render_flow=render_flow, dino_pred=dino_pred, class_vector=class_vector[None, :].expand(batch_size * num_frames, -1) if class_vector is not None else None, num_frames=num_frames, spp=self.renderer_spp, im_features_map=im_features_map) # the real rendering process; guard class_vector, which is None when no class embedding is predicted | |
image_pred, mask_pred, flow_pred, dino_feat_im_pred = map(expandF, (image_pred, mask_pred, flow_pred, dino_feat_im_pred)) | |
if flow_pred is not None: | |
flow_pred = flow_pred[:, :-1] # Bx(F-1)x2xHxW | |
if self.blur_mask: | |
sigma = max(0.5, 3 * (1 - total_iter / self.blur_mask_iter)) | |
if sigma > 0.5: | |
mask_gt = util.blur_image(mask_gt, kernel_size=9, sigma=sigma, mode='gaussian') | |
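# Schedule note: sigma decays linearly from 3 at iteration 0 and reaches the 0.5 | |
# floor at 5/6 of blur_mask_iter (3 * (1 - t/T) = 0.5 at t = 5T/6), after which the | |
# mask is no longer blurred. | |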
# mask_pred = util.blur_image(mask_pred, kernel_size=7, mode='average') | |
# back_line_p1 = forward_aux['posed_bones'][:, :, 3, -1].squeeze(1) # [8, 3] | |
# back_line_p2 = forward_aux['posed_bones'][:, :, 7, -1].squeeze(1) | |
# mask_valid = self.use_line_correct_valid_mask(mask_valid, back_line_p1, back_line_p2, mvp, mask_gt) | |
losses = self.compute_reconstruction_losses(image_pred, image_gt, mask_pred, mask_gt, mask_dt, mask_valid, flow_pred, flow_gt, dino_feat_im_gt, dino_feat_im_pred, background_mode=self.background_mode, reduce=False) | |
## TODO: assume flow loss is not used | |
logit_loss_target = torch.zeros_like(expandF(rot_logit)) | |
final_losses = {} | |
for name, loss in losses.items(): | |
if name == 'flow_loss': | |
continue | |
loss_weight_logit = self.cfgs.get(f"{name}_weight", 0.) | |
if isinstance(loss_weight_logit, dict): | |
loss_weight_logit = self.parse_dict_definition(loss_weight_logit, total_iter) | |
# from IPython import embed; embed() | |
# print("-"*10) | |
# print(f"{name}_weight: {loss_weight_logit}.") | |
# print(f"logit_loss_target.shape: {logit_loss_target.shape}.") | |
# print(f"loss.shape: {loss.shape}.") | |
# if (name in ['flow_loss'] and epoch not in self.flow_loss_epochs) or (name in ['rgb_loss', 'perceptual_loss'] and epoch not in self.texture_epochs): | |
# if name in ['flow_loss', 'rgb_loss', 'perceptual_loss']: | |
# loss_weight_logit = 0. | |
if name in ['sdf_bce_reg_loss', 'sdf_gradient_reg_loss', 'sdf_inflate_reg_loss']: | |
if total_iter >= self.sdf_reg_decay_start_iter: | |
decay_rate = max(0, 1 - (total_iter-self.sdf_reg_decay_start_iter) / 10000) | |
loss_weight_logit = max(loss_weight_logit * decay_rate, self.cfgs.get(f"{name}_min_weight", 0.)) | |
if name in ['dino_feat_im_loss']: | |
dino_feat_im_loss_multiplier = self.cfgs.get("logit_loss_dino_feat_im_loss_multiplier", 1.) | |
if isinstance(dino_feat_im_loss_multiplier, dict): | |
dino_feat_im_loss_multiplier = self.parse_dict_definition(dino_feat_im_loss_multiplier, total_iter) | |
loss_weight_logit = loss_weight_logit * dino_feat_im_loss_multiplier | |
# loss_weight_logit = loss_weight_logit * self.cfgs.get("logit_loss_dino_feat_im_loss_multiplier", 1.) | |
if loss_weight_logit > 0: | |
logit_loss_target += loss * loss_weight_logit | |
if self.netInstance.rot_rep in ['quadlookat', 'octlookat']: | |
loss = loss * rot_prob.detach().view(batch_size, num_frames)[:, :loss.shape[1]] * self.netInstance.num_pose_hypos | |
if name == 'flow_loss' and num_frames > 1: | |
ri = rot_idx.view(batch_size, num_frames) | |
same_rot_idx = (ri[:, 1:] == ri[:, :-1]).float() | |
loss = loss * same_rot_idx | |
final_losses[name] = loss.mean() | |
final_losses['logit_loss'] = ((expandF(rot_logit) - logit_loss_target.detach())**2.).mean() | |
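# Mechanism note: each per-sample loss is importance-weighted by the probability of | |
# the sampled pose hypothesis, and logit_loss regresses the hypothesis logit onto | |
# the weighted sum of reconstruction losses, so logits come to predict the expected | |
# loss of their hypothesis (lower predicted loss -> higher sampling probability). | |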
## mask distribution loss | |
mask_distribution_aux = None | |
if self.enable_mask_distribution: | |
if total_iter % self.mask_distribution_loss_freq == 0: | |
mask_distribution_loss, mask_distribution_aux = self.compute_mask_distribution_loss(category_name[0], w2c, shape, prior_shape, texture, dino_pred, im_features, light, class_vector[None, :].expand(batch_size * num_frames, -1), num_frames, im_features_map) | |
final_losses.update(mask_distribution_loss) | |
# this also follows the iteration frequency | |
if self.enable_clip: | |
random_render_image = mask_distribution_aux["random_render_image"] | |
clip_all_loss = self.compute_clip_loss(random_render_image, image_pred, category_name[0]) # a dict | |
final_losses.update(clip_all_loss) | |
# implement the mask discriminator | |
if self.enable_disc and (self.mask_discriminator_iter[0] < total_iter) and (self.mask_discriminator_iter[1] > total_iter): | |
disc_loss = self.compute_mask_disc_loss_gen(mask_gt, mask_pred, mask_distribution_aux['mask_random_pred'], category_name=category_name[0], condition_feat=class_vector) | |
final_losses.update(disc_loss) | |
# implement the gan training for local texture in fine-tuning | |
gan_tex_aux = None | |
if self.few_shot_gan_tex and (viz_logger is None or logger_prefix == 'train_'): | |
gan_tex_loss, gan_tex_aux = self.compute_gan_tex_loss(category_name[0], image_gt, mask_gt, image_pred, mask_pred, w2c, campos, shape, prior_shape, texture, dino_pred, im_features, light, class_vector[None, :].expand(batch_size * num_frames, -1), num_frames, im_features_map) | |
final_losses.update(gan_tex_loss) | |
# implement the memory bank related loss | |
if bank_embedding is not None: | |
batch_embedding = bank_embedding[0] # [d] | |
embeddings = bank_embedding[1] # [B, d] | |
bank_mean_dist = torch.nn.functional.mse_loss(embeddings, batch_embedding.unsqueeze(0).repeat(batch_size, 1)) | |
final_losses.update({'bank_mean_dist_loss': bank_mean_dist}) | |
## regularizers | |
regularizers, aux = self.compute_regularizers(shape, prior_shape, input_image, dino_feat_im, pose_raw, input_image_xflip_flag, arti_params, deformation, mid_img_idx, posed_bones=posed_bones, class_vector=class_vector.detach() if class_vector is not None else None) | |
final_losses.update(regularizers) | |
aux_viz.update(aux) | |
total_loss = 0 | |
for name, loss in final_losses.items(): | |
loss_weight = self.cfgs.get(f"{name}_weight", 0.) | |
if isinstance(loss_weight, dict): | |
loss_weight = self.parse_dict_definition(loss_weight, total_iter) | |
if loss_weight <= 0: | |
continue | |
if self.train_pose_only: | |
if name not in ['silhouette_loss', 'silhouette_dt_loss', 'silhouette_inv_dt_loss', 'flow_loss', 'pose_xflip_reg_loss', 'lookat_zflip_loss', 'dino_feat_im_loss']: | |
continue | |
if epoch not in self.flow_loss_epochs: | |
if name in ['flow_loss']: | |
continue | |
if epoch not in self.texture_epochs: | |
if name in ['rgb_loss', 'perceptual_loss']: | |
continue | |
if epoch not in self.lookat_zflip_loss_epochs: | |
if name in ['lookat_zflip_loss']: | |
continue | |
if name in ['mesh_laplacian_smoothing_loss', 'mesh_normal_consistency_loss']: | |
if total_iter < self.cfgs.get('mesh_reg_start_iter', 0): | |
continue | |
if epoch >= self.mesh_reg_decay_epoch: | |
decay_rate = self.mesh_reg_decay_rate ** (epoch - self.mesh_reg_decay_epoch) | |
loss_weight = max(loss_weight * decay_rate, self.cfgs.get(f"{name}_min_weight", 0.)) | |
if epoch not in self.sdf_inflate_reg_loss_epochs: | |
if name in ['sdf_inflate_reg_loss']: | |
continue | |
if self.iter_arti_reg_loss_start is not None: | |
if total_iter <= self.iter_arti_reg_loss_start: | |
if name in ['arti_reg_loss']: | |
continue | |
else: | |
if epoch not in self.arti_reg_loss_epochs: | |
if name in ['arti_reg_loss']: | |
continue | |
if name in ['sdf_bce_reg_loss', 'sdf_gradient_reg_loss', 'sdf_inflate_reg_loss']: | |
if total_iter >= self.sdf_reg_decay_start_iter: | |
decay_rate = max(0, 1 - (total_iter-self.sdf_reg_decay_start_iter) / 10000) | |
loss_weight = max(loss_weight * decay_rate, self.cfgs.get(f"{name}_min_weight", 0.)) | |
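# Decay example (hypothetical config): if sdf_reg_decay_start_iter were 10000, then | |
# at total_iter = 15000 decay_rate = 1 - 5000/10000 = 0.5, so these SDF regularizers | |
# fade out linearly over 10k iterations down to their configured minimum weights. | |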
total_loss += loss * loss_weight | |
self.total_loss += total_loss # reset to 0 in backward step | |
if torch.isnan(self.total_loss): | |
print("NaN in loss...") | |
import ipdb; ipdb.set_trace() | |
final_losses['logit_loss_target'] = logit_loss_target.mean() | |
metrics = {'loss': total_loss, **final_losses} | |
## log visuals | |
if viz_logger is not None: | |
b0 = max(min(batch_size, 16//num_frames), 1) | |
viz_logger.add_image(logger_prefix+'image/image_gt', misc.image_grid(image_gt.detach().cpu()[:b0,:].reshape(-1,*input_image.shape[2:]).clamp(0,1)), total_iter) | |
viz_logger.add_image(logger_prefix+'image/image_pred', misc.image_grid(image_pred.detach().cpu()[:b0,:].reshape(-1,*image_pred.shape[2:]).clamp(0,1)), total_iter) | |
# viz_logger.add_image(logger_prefix+'image/flow_loss_mask', misc.image_grid(flow_loss_mask[:b0,:,:1].reshape(-1,1,*flow_loss_mask.shape[3:]).repeat(1,3,1,1).clamp(0,1)), total_iter) | |
viz_logger.add_image(logger_prefix+'image/mask_gt', misc.image_grid(mask_gt.detach().cpu()[:b0,:].reshape(-1,*mask_gt.shape[2:]).unsqueeze(1).repeat(1,3,1,1).clamp(0,1)), total_iter) | |
viz_logger.add_image(logger_prefix+'image/mask_pred', misc.image_grid(mask_pred.detach().cpu()[:b0,:].reshape(-1,*mask_pred.shape[2:]).unsqueeze(1).repeat(1,3,1,1).clamp(0,1)), total_iter) | |
if self.render_flow and flow_gt is not None: | |
# if False: | |
flow_gt = flow_gt.detach().cpu() | |
flow_gt_viz = torch.cat([flow_gt[:b0], torch.zeros_like(flow_gt[:b0,:,:1])], 2) + 0.5 # -0.5~1.5 | |
flow_gt_viz = torch.nn.functional.pad(flow_gt_viz, pad=[0, 0, 0, 0, 0, 0, 0, 1]) | |
# ## draw marker on large flow frames | |
# large_flow_marker_mask = torch.zeros_like(flow_gt_viz) | |
# large_flow_marker_mask[:,:,:,:8,:8] = 1. | |
# large_flow = torch.cat([self.large_flow, self.large_flow[:,:1] *0.], 1).detach().cpu()[:b0] | |
# large_flow_marker_mask = large_flow_marker_mask * large_flow[:,:,None,None,None] | |
# red = torch.FloatTensor([1,0,0])[None,None,:,None,None] | |
# flow_gt_viz = large_flow_marker_mask * red + (1-large_flow_marker_mask) * flow_gt_viz | |
viz_logger.add_image(logger_prefix+'image/flow_gt', misc.image_grid(flow_gt_viz.reshape(-1,*flow_gt_viz.shape[2:])), total_iter) | |
if self.render_flow and flow_pred is not None: | |
# if False | |
flow_pred = flow_pred.detach().cpu() | |
flow_pred_viz = torch.cat([flow_pred[:b0], torch.zeros_like(flow_pred[:b0,:,:1])], 2) + 0.5 # -0.5~1.5 | |
flow_pred_viz = torch.nn.functional.pad(flow_pred_viz, pad=[0, 0, 0, 0, 0, 0, 0, 1]) | |
viz_logger.add_image(logger_prefix+'image/flow_pred', misc.image_grid(flow_pred_viz.reshape(-1,*flow_pred_viz.shape[2:])), total_iter) | |
if sds_random_images is not None: | |
viz_logger.add_image( | |
logger_prefix + 'image/sds_image', | |
self.vis_sds_image(sds_random_images, sds_aux), | |
total_iter) | |
viz_logger.add_image( | |
logger_prefix + 'image/sds_grad', | |
self.vis_sds_grads(sds_aux), total_iter) | |
if mask_distribution_aux is not None: | |
degree_text = mask_distribution_aux['rand_degree'] | |
mask_random_pred = mask_distribution_aux['mask_random_pred'].detach().cpu().clamp(0, 1) | |
mask_distribution_data = mask_distribution_aux['mask_distribution'].detach().cpu().clamp(0, 1) | |
mask_random_pred_image = [misc.add_text_to_image(img, str(text.item())) for img, text in zip(mask_random_pred, degree_text)] | |
mask_random_pred_image = misc.image_grid(mask_random_pred_image) | |
mask_distribution_image = misc.image_grid(mask_distribution_data) | |
viz_logger.add_image( | |
logger_prefix + 'image/mask_random_pred', | |
mask_random_pred_image, | |
total_iter) | |
viz_logger.add_image( | |
logger_prefix + 'image/mask_distribution', | |
mask_distribution_image, | |
total_iter) | |
if gan_tex_aux is not None: | |
gan_tex_render_image = gan_tex_aux['gan_tex_render_image'].detach().cpu().clamp(0, 1) | |
gan_tex_render_image = misc.image_grid(gan_tex_render_image) | |
viz_logger.add_image( | |
logger_prefix + 'image/gan_tex_render_image', | |
gan_tex_render_image, | |
total_iter) | |
gan_tex_render_image_iv = gan_tex_aux['gan_tex_inpview_image'].detach().cpu().clamp(0, 1) | |
gan_tex_render_image_iv = misc.image_grid(gan_tex_render_image_iv) | |
viz_logger.add_image( | |
logger_prefix + 'image/gan_tex_inpview_image', | |
gan_tex_render_image_iv, | |
total_iter) | |
gan_tex_render_image_gt = gan_tex_aux['gan_tex_gt_image'].detach().cpu().clamp(0, 1) | |
gan_tex_render_image_gt = misc.image_grid(gan_tex_render_image_gt) | |
viz_logger.add_image( | |
logger_prefix + 'image/gan_tex_gt_image', | |
gan_tex_render_image_gt, | |
total_iter) | |
# if self.render_flow and flow_gt is not None and flow_pred is not None: | |
# flow_gt = flow_gt.detach().cpu() | |
# # flow_gt_viz = torch.cat([flow_gt[:b0], torch.zeros_like(flow_gt[:b0,:,:1])], 2) + 0.5 # -0.5~1.5 | |
# # flow_gt_viz = torch.nn.functional.pad(flow_gt_viz, pad=[0, 0, 0, 0, 0, 0, 0, 1]) | |
# # ## draw marker on large flow frames | |
# # large_flow_marker_mask = torch.zeros_like(flow_gt_viz) | |
# # large_flow_marker_mask[:,:,:,:8,:8] = 1. | |
# # large_flow = torch.cat([self.large_flow, self.large_flow[:,:1] *0.], 1).detach().cpu()[:b0] | |
# # large_flow_marker_mask = large_flow_marker_mask * large_flow[:,:,None,None,None] | |
# # red = torch.FloatTensor([1,0,0])[None,None,:,None,None] | |
# # flow_gt_viz = large_flow_marker_mask * red + (1-large_flow_marker_mask) * flow_gt_viz | |
# # viz_logger.add_image(logger_prefix+'image/flow_gt', misc.image_grid(flow_gt_viz.reshape(-1,*flow_gt_viz.shape[2:])), total_iter) | |
# flow_pred = flow_pred.detach().cpu() | |
# # flow_pred_viz = torch.cat([flow_pred[:b0], torch.zeros_like(flow_pred[:b0,:,:1])], 2) + 0.5 # -0.5~1.5 | |
# # flow_pred_viz = torch.nn.functional.pad(flow_pred_viz, pad=[0, 0, 0, 0, 0, 0, 0, 1]) | |
# flow_gt_pred = torch.cat([flow_gt, flow_pred], dim=-1) | |
# flow_gt_pred = flow_gt_pred.permute(0,1,3,4,2).detach().cpu().reshape(flow_gt_pred.shape[0]*flow_gt_pred.shape[1],*flow_gt_pred.shape[2:]) | |
# flow_gt_pred = flow_viz.flow_batch_to_images(flow_gt_pred) | |
# # flow_gt_pred = torch.tensor(flow_gt_pred).permute(0,3,1,2) | |
# # viz_logger.add_image(logger_prefix+'image/flow_gt_pred', misc.image_grid(flow_gt_pred.reshape(-1,*flow_gt_pred.shape[2:])), total_iter) | |
# viz_logger.add_image(logger_prefix+'image/flow_gt_pred', misc.image_grid(flow_gt_pred), total_iter) | |
if light is not None: | |
param_names = ['dir_x', 'dir_y', 'dir_z', 'int_ambient', 'int_diffuse'] | |
for name, param in zip(param_names, light.light_params.unbind(-1)): | |
viz_logger.add_histogram(logger_prefix+'light/'+name, param, total_iter) | |
viz_logger.add_image( | |
logger_prefix + f'image/albedo', | |
misc.image_grid(expandF(albedo)[:b0, ...].view(-1, *albedo.shape[1:])), | |
total_iter) | |
viz_logger.add_image( | |
logger_prefix + f'image/shading', | |
misc.image_grid(expandF(shading)[:b0, ...].view(-1, *shading.shape[1:]).repeat(1, 3, 1, 1) /2.), | |
total_iter) | |
viz_logger.add_histogram(logger_prefix+'sdf', self.netPrior.netShape.get_sdf(perturb_sdf=False, class_vector=class_vector), total_iter) | |
viz_logger.add_histogram(logger_prefix+'coordinates', shape.v_pos, total_iter) | |
if arti_params is not None: | |
viz_logger.add_histogram(logger_prefix+'arti_params', arti_params, total_iter) | |
viz_logger.add_histogram(logger_prefix+'edge_lengths', aux_viz['edge_lengths'], total_iter) | |
if deformation is not None: | |
viz_logger.add_histogram(logger_prefix+'deformation', deformation, total_iter) | |
rot_rep = self.netInstance.rot_rep | |
if rot_rep == 'euler_angle' or rot_rep == 'soft_calss': # 'soft_calss' (sic) matches the identifier used elsewhere in the codebase | |
for i, name in enumerate(['rot_x', 'rot_y', 'rot_z', 'trans_x', 'trans_y', 'trans_z']): | |
viz_logger.add_histogram(logger_prefix+'pose/'+name, pose[...,i], total_iter) | |
elif rot_rep == 'quaternion': | |
import pytorch3d.transforms # local import: pytorch3d is not imported at the top of this file (assumed available) | |
for i, name in enumerate(['qt_0', 'qt_1', 'qt_2', 'qt_3', 'trans_x', 'trans_y', 'trans_z']): | |
viz_logger.add_histogram(logger_prefix+'pose/'+name, pose[...,i], total_iter) | |
rot_euler = pytorch3d.transforms.matrix_to_euler_angles(pytorch3d.transforms.quaternion_to_matrix(pose.detach().cpu()[...,:4]), convention='XYZ') | |
for i, name in enumerate(['rot_x', 'rot_y', 'rot_z']): | |
viz_logger.add_histogram(logger_prefix+'pose/'+name, rot_euler[...,i], total_iter) | |
elif rot_rep in ['lookat', 'quadlookat', 'octlookat']: | |
for i, name in enumerate(['fwd_x', 'fwd_y', 'fwd_z']): | |
viz_logger.add_histogram(logger_prefix+'pose/'+name, pose_raw[...,i], total_iter) | |
for i, name in enumerate(['trans_x', 'trans_y', 'trans_z']): | |
viz_logger.add_histogram(logger_prefix+'pose/'+name, pose_raw[...,-3+i], total_iter) | |
if rot_rep in ['quadlookat', 'octlookat']: | |
for i, rp in enumerate(forward_aux['rots_probs'].unbind(-1)): | |
viz_logger.add_histogram(logger_prefix+'pose/rot_prob_%d'%i, rp, total_iter) | |
if bank_embedding is not None: | |
weights_for_emb = bank_embedding[2]['weights'] # [B, k] | |
for i, weight_for_emb in enumerate(weights_for_emb.unbind(-1)): | |
viz_logger.add_histogram(logger_prefix+'bank_embedding/emb_weight_%d'%i, weight_for_emb, total_iter) | |
indices_for_emb = bank_embedding[2]['pick_idx'] # [B, k] | |
for i, idx_for_emb in enumerate(indices_for_emb.unbind(-1)): | |
viz_logger.add_histogram(logger_prefix+'bank_embedding/emb_idx_%d'%i, idx_for_emb, total_iter) | |
if 'pose_xflip_raw' in aux_viz: | |
pose_xflip_raw = aux_viz['pose_xflip_raw'] | |
pose_xflip = aux_viz.get('pose_xflip', pose_xflip_raw) # 'pose_xflip' is used below but never defined here; fall back to the raw pose (assumption) | |
if rot_rep == 'euler_angle' or rot_rep == 'soft_calss': # 'soft_calss' (sic), as above | |
for i, name in enumerate(['rot_x', 'rot_y', 'rot_z', 'trans_x', 'trans_y', 'trans_z']): | |
viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, pose_xflip[...,i], total_iter) | |
elif rot_rep == 'quaternion': | |
import pytorch3d.transforms # local import, as above | |
for i, name in enumerate(['qt_0', 'qt_1', 'qt_2', 'qt_3', 'trans_x', 'trans_y', 'trans_z']): | |
viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, pose_xflip[...,i], total_iter) | |
rot_euler = pytorch3d.transforms.matrix_to_euler_angles(pytorch3d.transforms.quaternion_to_matrix(pose_xflip.detach().cpu()[...,:4]), convention='XYZ') | |
for i, name in enumerate(['rot_x', 'rot_y', 'rot_z']): | |
viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, rot_euler[...,i], total_iter) | |
elif rot_rep in ['lookat', 'quadlookat', 'octlookat']: | |
for i, name in enumerate(['fwd_x', 'fwd_y', 'fwd_z']): | |
viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, pose_xflip_raw[...,i], total_iter) | |
for i, name in enumerate(['trans_x', 'trans_y', 'trans_z']): | |
viz_logger.add_histogram(logger_prefix+'pose_xflip/'+name, pose_xflip_raw[...,-3+i], total_iter) | |
if dino_feat_im_gt is not None: | |
dino_feat_im_gt_first3 = dino_feat_im_gt[:,:,:3] | |
viz_logger.add_image(logger_prefix+'image/dino_feat_im_gt', misc.image_grid(dino_feat_im_gt_first3.detach().cpu()[:b0,:].reshape(-1,*dino_feat_im_gt_first3.shape[2:]).clamp(0,1)), total_iter) | |
if dino_cluster_im_gt is not None: | |
viz_logger.add_image(logger_prefix+'image/dino_cluster_im_gt', misc.image_grid(dino_cluster_im_gt.detach().cpu()[:b0,:].reshape(-1,*dino_cluster_im_gt.shape[2:]).clamp(0,1)), total_iter) | |
if dino_feat_im_pred is not None: | |
dino_feat_im_pred_first3 = dino_feat_im_pred[:,:,:3] | |
viz_logger.add_image(logger_prefix+'image/dino_feat_im_pred', misc.image_grid(dino_feat_im_pred_first3.detach().cpu()[:b0,:].reshape(-1,*dino_feat_im_pred_first3.shape[2:]).clamp(0,1)), total_iter) | |
for which_shape, modes in self.extra_renders.items(): | |
# This is wrong | |
# if which_shape == "prior": | |
# shape_to_render = prior_shape.extend(im_features.shape[0]) | |
# needed_im_features = None | |
if which_shape == "instance": | |
shape_to_render = shape | |
needed_im_features = im_features | |
else: | |
raise NotImplementedError | |
for mode in modes: | |
if mode in ['gray']: | |
gray_light = FixedDirectionLight(direction=torch.FloatTensor([0, 0, 1]).to(self.device), amb=0.2, diff=0.7) | |
_, render_mask, _, _, _, rendered = self.render(shape_to_render, texture, mvp, w2c, campos, (h, w), background=self.background_mode, im_features=needed_im_features, prior_shape=prior_shape, render_mode='diffuse', light=gray_light, render_flow=False, dino_pred=None, im_features_map=im_features_map) #renderer for visualization only!!! | |
if self.background_mode == 'white': | |
# shading renders on a black background; composite onto white to match background_mode | |
render_mask = render_mask.unsqueeze(1) | |
rendered[render_mask == 0] = 1 | |
rendered = rendered.repeat(1, 3, 1, 1) | |
else: | |
rendered, _, _, _, _, _ = self.render(shape_to_render, texture, mvp, w2c, campos, (h, w), background=self.background_mode, im_features=needed_im_features, prior_shape=prior_shape, render_mode=mode, render_flow=False, dino_pred=None, im_features_map=im_features_map) #renderer for visualization only!!! | |
if 'kd' in mode: | |
rendered = util.rgb_to_srgb(rendered) | |
rendered = rendered.detach().cpu() | |
rendered_wo_bones = None | |
if 'posed_bones' in aux_viz: | |
rendered_wo_bones = rendered # keep the raw render before overlaying bones | |
rendered_bone_image = self.render_bones(mvp, aux_viz['posed_bones'], (h, w)) | |
rendered_bone_image_mask = (rendered_bone_image < 1).any(1, keepdim=True).float() | |
# viz_logger.add_image(logger_prefix+'image/articulation_bones', misc.image_grid(self.render_bones(mvp, aux_viz['posed_bones'])), total_iter) | |
rendered = rendered_bone_image_mask*0.8 * rendered_bone_image + (1-rendered_bone_image_mask*0.8) * rendered | |
if rot_rep in ['quadlookat', 'octlookat']: | |
rand_pose_flag = forward_aux['rand_pose_flag'].detach().cpu() | |
rand_pose_marker_mask = torch.zeros_like(rendered) | |
rand_pose_marker_mask[:,:,:16,:16] = 1. | |
rand_pose_marker_mask = rand_pose_marker_mask * rand_pose_flag[:,None,None,None] | |
red = torch.FloatTensor([1,0,0])[None,:,None,None] | |
rendered = rand_pose_marker_mask * red + (1-rand_pose_marker_mask) * rendered | |
viz_logger.add_image( | |
logger_prefix + f'image/{which_shape}_{mode}', | |
misc.image_grid(expandF(rendered)[:b0, ...].view(-1, *rendered.shape[1:])), | |
total_iter) | |
if rendered_wo_bones is not None: | |
viz_logger.add_image( | |
logger_prefix + f'image/{which_shape}_{mode}_raw', | |
misc.image_grid(expandF(rendered_wo_bones)[:b0, ...].view(-1, *rendered_wo_bones.shape[1:])), | |
total_iter) | |
if mode in ['gray']: | |
viz_logger.add_video( | |
logger_prefix + f'animation/{which_shape}_{mode}', | |
self.render_rotation_frames(shape_to_render, texture, gray_light, (h, w), background=self.background_mode, im_features=needed_im_features, prior_shape=prior_shape, num_frames=15, render_mode='diffuse', b=1, im_features_map=im_features_map, original_mvp=mvp, original_w2c=w2c, original_campos=campos, render_gray=True).detach().cpu().unsqueeze(0), | |
total_iter, | |
fps=2) | |
else: | |
viz_logger.add_video( | |
logger_prefix + f'animation/{which_shape}_{mode}', | |
self.render_rotation_frames(shape_to_render, texture, light, (h, w), background=self.background_mode, im_features=needed_im_features, prior_shape=prior_shape, num_frames=15, render_mode=mode, b=1, im_features_map=im_features_map, original_mvp=mvp, original_w2c=w2c, original_campos=campos).detach().cpu().unsqueeze(0), | |
total_iter, | |
fps=2) | |
viz_logger.add_video( | |
logger_prefix+'animation/prior_image_rotation', | |
self.render_rotation_frames(prior_shape, texture, light, (h, w), background=self.background_mode, im_features=im_features, num_frames=15, b=1, text=category_name[0], im_features_map=im_features_map, original_mvp=mvp).detach().cpu().unsqueeze(0).clamp(0,1), | |
total_iter, | |
fps=2) | |
viz_logger.add_video( | |
logger_prefix+'animation/prior_normal_rotation', | |
self.render_rotation_frames(prior_shape, texture, light, (h, w), background=self.background_mode, im_features=im_features, num_frames=15, render_mode='geo_normal', b=1, text=category_name[0], im_features_map=im_features_map, original_mvp=mvp).detach().cpu().unsqueeze(0), | |
total_iter, | |
fps=2) | |
if save_results and self.rank == 0: | |
b0 = self.cfgs.get('num_saved_from_each_batch', batch_size*num_frames) | |
# from IPython import embed; embed() | |
fnames = [f'{total_iter:07d}_{fid:010d}' for fid in collapseF(frame_id.int())][:b0] | |
# pkl_str = osp.join(save_dir, f'{total_iter:07d}_animal_data.pkl') | |
os.makedirs(save_dir, exist_ok=True) | |
# with open(pkl_str, 'wb') as fpkl: | |
# pickle.dump(save_for_pkl, fpkl) | |
# fpkl.close() | |
misc.save_images(save_dir, collapseF(image_gt)[:b0].clamp(0,1).detach().cpu().numpy(), suffix='image_gt', fnames=fnames) | |
misc.save_images(save_dir, collapseF(image_pred)[:b0].clamp(0,1).detach().cpu().numpy(), suffix='image_pred', fnames=fnames) | |
misc.save_images(save_dir, collapseF(mask_gt)[:b0].unsqueeze(1).repeat(1,3,1,1).clamp(0,1).detach().cpu().numpy(), suffix='mask_gt', fnames=fnames) | |
misc.save_images(save_dir, collapseF(mask_pred)[:b0].unsqueeze(1).repeat(1,3,1,1).clamp(0,1).detach().cpu().numpy(), suffix='mask_pred', fnames=fnames) | |
# tmp_shape = shape.first_n(b0).clone() | |
# tmp_shape.material = texture | |
# feat = im_features[:b0] if im_features is not None else None | |
# misc.save_obj(save_dir, tmp_shape, save_material=False, feat=feat, suffix="mesh", fnames=fnames) # Save the first mesh. | |
if self.render_flow and flow_gt is not None: | |
flow_gt_viz = torch.cat([flow_gt, torch.zeros_like(flow_gt[:,:,:1])], 2) + 0.5 # -0.5~1.5 | |
flow_gt_viz = flow_gt_viz.view(-1, *flow_gt_viz.shape[2:]) | |
misc.save_images(save_dir, flow_gt_viz[:b0].clamp(0,1).detach().cpu().numpy(), suffix='flow_gt', fnames=fnames) | |
if flow_pred is not None: | |
flow_pred_viz = torch.cat([flow_pred, torch.zeros_like(flow_pred[:,:,:1])], 2) + 0.5 # -0.5~1.5 | |
flow_pred_viz = flow_pred_viz.view(-1, *flow_pred_viz.shape[2:]) | |
misc.save_images(save_dir, flow_pred_viz[:b0].clamp(0,1).detach().cpu().numpy(), suffix='flow_pred', fnames=fnames) | |
misc.save_txt(save_dir, pose[:b0].detach().cpu().numpy(), suffix='pose', fnames=fnames) | |
return metrics | |
def save_scores(self, path): | |
header = 'mask_mse, mask_iou, image_mse, flow_mse' | |
mean = self.all_scores.mean(0) | |
std = self.all_scores.std(0) | |
header = header + '\nMean: ' + ',\t'.join(['%.8f'%x for x in mean]) | |
header = header + '\nStd: ' + ',\t'.join(['%.8f'%x for x in std]) | |
misc.save_scores(path, self.all_scores, header=header) | |
print(header) | |
def render_rotation_frames(self, mesh, texture, light, resolution, background='none', im_features=None, prior_shape=None, num_frames=36, render_mode='diffuse', b=None, text=None, im_features_map=None, original_mvp=None, original_w2c=None, original_campos=None, render_gray=False): | |
frames = [] | |
if b is None: | |
b = len(mesh) | |
else: | |
mesh = mesh.first_n(b) | |
feat = im_features[:b] if im_features is not None else None | |
im_features_map = im_features_map[:b] if im_features_map is not None else None | |
original_mvp = original_mvp[:b] if original_mvp is not None else None # [b, 4, 4] | |
if im_features_map is not None: | |
im_features_map = {'im_features_map': im_features_map, 'original_mvp':original_mvp} | |
delta_angle = 2 * np.pi / num_frames # one full turn over num_frames | |
delta_rot_matrix = torch.FloatTensor([ | |
[np.cos(delta_angle), 0, np.sin(delta_angle), 0], | |
[0, 1, 0, 0], | |
[-np.sin(delta_angle), 0, np.cos(delta_angle), 0], | |
[0, 0, 0, 1], | |
]).to(self.device).repeat(b, 1, 1) | |
w2c = torch.eye(4) # identity rotation; the translation below places the camera on the +z axis | |
w2c[:3, 3] = torch.FloatTensor([0, 0, -self.cam_pos_z_offset *1.1]) | |
w2c = w2c.repeat(b, 1, 1).to(self.device) | |
proj = util.perspective(self.crop_fov_approx / 180 * np.pi, 1, n=0.1, f=1000.0).repeat(b, 1, 1).to(self.device) | |
mvp = torch.bmm(proj, w2c) | |
campos = -w2c[:, :3, 3] | |
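# Camera note: with an identity rotation in w2c, the camera world position is just | |
# the negated translation; in general it would be campos = -R^T @ t for w2c = [R|t]. | |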
if original_w2c is not None and original_campos is not None and original_mvp is not None: | |
w2c = original_w2c[:b] | |
campos = original_campos[:b] | |
mvp = original_mvp[:b] | |
def rotate_pose(mvp, campos): | |
mvp = torch.matmul(mvp, delta_rot_matrix) | |
campos = torch.matmul(delta_rot_matrix[:,:3,:3].transpose(2,1), campos[:,:,None])[:,:,0] | |
return mvp, campos | |
for _ in range(num_frames): | |
if render_gray: | |
_, render_mask, _, _, _, image_pred = self.render(mesh, texture, mvp, w2c, campos, resolution, background=background, im_features=feat, light=light, prior_shape=prior_shape, render_flow=False, dino_pred=None, render_mode=render_mode, two_sided_shading=False, im_features_map=im_features_map) | |
if self.background_mode == 'white': | |
# shading renders on a black background; composite onto white to match background_mode | |
render_mask = render_mask.unsqueeze(1) | |
image_pred[render_mask == 0] = 1 | |
image_pred = image_pred.repeat(1, 3, 1, 1) | |
else: | |
image_pred, _, _, _, _, _ = self.render(mesh, texture, mvp, w2c, campos, resolution, background=background, im_features=feat, light=light, prior_shape=prior_shape, render_flow=False, dino_pred=None, render_mode=render_mode, two_sided_shading=False, im_features_map=im_features_map) #for rotation frames only! | |
image_pred = image_pred.clamp(0, 1) | |
frames += [misc.image_grid(image_pred)] | |
mvp, campos = rotate_pose(mvp, campos) | |
if text is not None: | |
frames = [torch.Tensor(misc.add_text_to_image(f, text)).permute(2, 0, 1) for f in frames] | |
return torch.stack(frames, dim=0) # Shape: (T, C, H, W) | |
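# --- Illustrative sketch (ours): one yaw step of the turntable above. Post- | |
# multiplying mvp by a world-space rotation R is equivalent to rotating the mesh | |
# by R, and the camera position moves by the inverse rotation R^T. | |
def _sketch_yaw_step(mvp, campos, angle): | |
    c, s = np.cos(angle), np.sin(angle) | |
    R = torch.tensor([[c, 0., s, 0.], [0., 1., 0., 0.], [-s, 0., c, 0.], [0., 0., 0., 1.]], dtype=mvp.dtype, device=mvp.device) | |
    return mvp @ R, campos @ R[:3, :3]  # campos @ R equals (R^T @ campos^T)^T row-wise | |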
def render_bones(self, mvp, bones_pred, size=(256, 256)): | |
bone_world4 = torch.cat([bones_pred, torch.ones_like(bones_pred[..., :1])], dim=-1) # homogeneous coordinates; ones_like already matches device | |
b, f, num_bones = bone_world4.shape[:3] | |
bones_clip4 = (bone_world4.view(b, f, num_bones*2, 1, 4) @ mvp.transpose(-1, -2).reshape(b, f, 1, 4, 4)).view(b, f, num_bones, 2, 4) | |
bones_uv = bones_clip4[..., :2] / bones_clip4[..., 3:4] # b, f, num_bones, 2, 2 | |
dpi = 32 | |
fx, fy = size[1] // dpi, size[0] // dpi | |
rendered = [] | |
for b_idx in range(b): | |
for f_idx in range(f): | |
frame_bones_uv = bones_uv[b_idx, f_idx].cpu().numpy() | |
fig = plt.figure(figsize=(fx, fy), dpi=dpi, frameon=False) | |
ax = plt.Axes(fig, [0., 0., 1., 1.]) | |
ax.set_axis_off() | |
for bone in frame_bones_uv: | |
ax.plot(bone[:, 0], bone[:, 1], marker='o', linewidth=8, markersize=20) | |
ax.set_xlim(-1, 1) | |
ax.set_ylim(-1, 1) | |
ax.invert_yaxis() | |
# Convert to image | |
fig.add_axes(ax) | |
fig.canvas.draw() # force an actual render so the RGB buffer below is populated | |
image = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) | |
w, h = fig.canvas.get_width_height() | |
image = image.reshape(h, w, 3) # frombuffer arrays do not own their data, so reshape rather than resize in place | |
plt.close(fig) # release the figure to avoid leaking memory across frames | |
rendered += [image / 255.] | |
return torch.from_numpy(np.stack(rendered, 0).transpose(0, 3, 1, 2)) | |
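# --- Illustrative sketch (ours): the clip-space projection used in render_bones. | |
# World points are homogenized, multiplied by mvp, and perspective-divided by w to | |
# obtain NDC coordinates in [-1, 1]. | |
def _sketch_project_to_ndc(points_world, mvp): | |
    # points_world: [N, 3]; mvp: [4, 4]; returns [N, 2] NDC xy | |
    ones = torch.ones_like(points_world[:, :1]) | |
    points_h = torch.cat([points_world, ones], dim=-1)  # [N, 4] | |
    clip = points_h @ mvp.transpose(0, 1)  # row-vector convention, as in render_bones | |
    return clip[:, :2] / clip[:, 3:4] | |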
def render_deformation_frames(self, mesh, texture, batch_size, num_frames, resolution, background='none', im_features=None, render_mode='diffuse', b=None): | |
# frames = [] | |
# if b is None: | |
# b = batch_size | |
# im_features = im_features[] | |
# mesh = mesh.first_n(num_frames * b) | |
# for i in range(b): | |
# tmp_mesh = mesh.get_m_to_n(i*num_frames:(i+1)*num_frames) | |
pass | |
def vis_sds_image(self, sds_image, sds_aux): | |
sds_image = sds_image.detach().cpu().clamp(0, 1) | |
sds_image = [misc.add_text_to_image(img, text) for img, text in zip(sds_image, sds_aux['dirs'])] | |
return misc.image_grid(sds_image) | |
def vis_sds_grads(self, sds_aux): | |
grads = sds_aux['sd_aux']['grad'] | |
grads = grads.detach().cpu() | |
# compute norm | |
grads_norm = grads.norm(dim=1, keepdim=True) | |
# interpolate to 4x size | |
grads_norm = F.interpolate(grads_norm, scale_factor=4, mode='nearest') | |
# add time step and weight | |
t = sds_aux['sd_aux']['t'] | |
w = sds_aux['sd_aux']['w'] | |
# max norm for each sample over dim (1, 2, 3) | |
n = grads_norm.view(grads_norm.shape[0], -1).max(dim=1)[0] | |
texts = [f"t: {t_} w: {w_:.2f} n: {n_:.2e}" for t_, w_ , n_ in zip(t, w, n)] | |
return misc.image_grid_multi_channel(grads_norm, texts=texts, font_scale=0.5) |