import itertools import torch import torch.nn as nn import pose_estimation class MSE(nn.Module): def __init__(self, ignore=None): super().__init__() self.mse = torch.nn.MSELoss(reduction="none") self.ignore = ignore if ignore is not None else [] def forward(self, y_pred, y_data): loss = self.mse(y_pred, y_data) if len(self.ignore) > 0: loss[self.ignore] *= 0 return loss.sum() / (len(loss) - len(self.ignore)) class Parallel(nn.Module): def __init__(self, skeleton, ignore=None, ground_parallel=None): super().__init__() self.skeleton = skeleton if ignore is not None: self.ignore = set(ignore) else: self.ignore = set() self.ground_parallel = ground_parallel if ground_parallel is not None else [] self.parallel_in_3d = [] self.cos = None def forward(self, y_pred3d, y_data, z, spine_j, global_step=0): y_pred = y_pred3d[:, :2] rleg, lleg = spine_j Lcon2d = Lcount = 0 if hasattr(self, "contact_2d"): for c2d in self.contact_2d: for ( (src_1, dst_1, t_1), (src_2, dst_2, t_2), ) in itertools.combinations(c2d, 2): a_1 = torch.lerp(y_data[src_1], y_data[dst_1], t_1) a_2 = torch.lerp(y_data[src_2], y_data[dst_2], t_2) a = a_2 - a_1 b_1 = torch.lerp(y_pred[src_1], y_pred[dst_1], t_1) b_2 = torch.lerp(y_pred[src_2], y_pred[dst_2], t_2) b = b_2 - b_1 lcon2d = ((a - b) ** 2).sum() Lcon2d = Lcon2d + lcon2d Lcount += 1 if Lcount > 0: Lcon2d = Lcon2d / Lcount Ltan = Lpar = Lcos = Lcount = 0 Lspine = 0 for i, bone in enumerate(self.skeleton): if bone in self.ignore: continue src, dst = bone b = y_data[dst] - y_data[src] t = nn.functional.normalize(b, dim=0) n = torch.stack([-t[1], t[0]]) if src == 10 and dst == 11: # right leg a = rleg elif src == 13 and dst == 14: # left leg a = lleg else: a = y_pred[dst] - y_pred[src] bone_name = f"{pose_estimation.KPS[src]}_{pose_estimation.KPS[dst]}" c = a - b lcos_loc = ltan_loc = lpar_loc = 0 if self.cos is not None: if bone not in [ (1, 2), # Neck + Right Shoulder (1, 5), # Neck + Left Shoulder (9, 10), # Hips + Right Upper Leg (9, 13), # Hips + Left Upper Leg ]: a = y_pred[dst] - y_pred[src] l2d = torch.norm(a, dim=0) l3d = torch.norm(y_pred3d[dst] - y_pred3d[src], dim=0) lcos = self.cos[i] lcos_loc = (l2d / l3d - lcos) ** 2 Lcos = Lcos + lcos_loc lpar_loc = ((a / l2d) * n).sum() ** 2 Lpar = Lpar + lpar_loc else: ltan_loc = ((c * t).sum()) ** 2 Ltan = Ltan + ltan_loc lpar_loc = (c * n).sum() ** 2 Lpar = Lpar + lpar_loc Lcount += 1 if Lcount > 0: Ltan = Ltan / Lcount Lcos = Lcos / Lcount Lpar = Lpar / Lcount Lspine = Lspine / Lcount Lgr = Lcount = 0 for (src, dst), value in self.ground_parallel: bone = y_pred[dst] - y_pred[src] bone = nn.functional.normalize(bone, dim=0) l = (torch.abs(bone[0]) - value) ** 2 Lgr = Lgr + l Lcount += 1 if Lcount > 0: Lgr = Lgr / Lcount Lstraight3d = Lcount = 0 for (i, j), (k, l) in self.parallel_in_3d: a = z[j] - z[i] a = nn.functional.normalize(a, dim=0) b = z[l] - z[k] b = nn.functional.normalize(b, dim=0) lo = (((a * b).sum() - 1) ** 2).sum() Lstraight3d = Lstraight3d + lo Lcount += 1 b = y_data[1] - y_data[8] b = nn.functional.normalize(b, dim=0) if Lcount > 0: Lstraight3d = Lstraight3d / Lcount return Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d class MimickedSelfContactLoss(nn.Module): def __init__(self, geodesics_mask): super().__init__() """ Loss that lets vertices in contact on presented mesh attract vertices that are close. """ # geodesic distance mask self.register_buffer("geomask", geodesics_mask) def forward( self, presented_contact, vertices, v2v=None, contact_mode="dist_tanh", contact_thresh=1, ): contactloss = 0.0 if v2v is None: # compute pairwise distances verts = vertices.contiguous() nv = verts.shape[1] v2v = verts.squeeze().unsqueeze(1).expand( nv, nv, 3 ) - verts.squeeze().unsqueeze(0).expand(nv, nv, 3) v2v = torch.norm(v2v, 2, 2) # loss for self-contact from mimic'ed pose if len(presented_contact) > 0: # without geodesic distance mask, compute distances # between each pair of verts in contact with torch.no_grad(): cvertstobody = v2v[presented_contact, :] cvertstobody = cvertstobody[:, presented_contact] maskgeo = self.geomask[presented_contact, :] maskgeo = maskgeo[:, presented_contact] weights = torch.ones_like(cvertstobody).to(verts.device) weights[~maskgeo] = float("inf") min_idx = torch.min((cvertstobody + 1) * weights, 1)[1] min_idx = presented_contact[min_idx.cpu().numpy()] v2v_min = v2v[presented_contact, min_idx] # tanh will not pull vertices that are ~more than contact_thres far apart if contact_mode == "dist_tanh": contactloss = contact_thresh * torch.tanh(v2v_min / contact_thresh) contactloss = contactloss.mean() else: contactloss = v2v_min.mean() return contactloss