sketch2pose / src /losses.py
kbrodt's picture
Upload losses.py
74d6764
import itertools
import torch
import torch.nn as nn
import pose_estimation
class MSE(nn.Module):
    """Squared-error loss with optional exclusion of rows along dim 0.

    The element-wise squared error is summed over all elements and divided
    by the number of rows (entries along the first dimension) that are not
    listed in ``ignore``.

    Args:
        ignore: optional collection of integer indices along dim 0 whose
            rows do not contribute to the loss.
    """

    def __init__(self, ignore=None):
        super().__init__()
        self.mse = torch.nn.MSELoss(reduction="none")
        self.ignore = ignore if ignore is not None else []

    def forward(self, y_pred, y_data):
        """Return ``sum(squared error over kept rows) / number of kept rows``.

        Args:
            y_pred: predicted tensor, at least 1-D.
            y_data: target tensor, broadcastable against ``y_pred``.

        Returns:
            A 0-d tensor. If every row is ignored the result is 0 rather
            than NaN.
        """
        if len(self.ignore) == 0:
            loss = self.mse(y_pred, y_data)
            return loss.sum() / len(loss)

        # Same broadcasting MSELoss would apply, done up front so that both
        # operands can be row-filtered with one mask.
        y_pred, y_data = torch.broadcast_tensors(y_pred, y_data)

        ignore = self.ignore
        if not torch.is_tensor(ignore):
            ignore = torch.as_tensor(list(ignore), dtype=torch.long)
        keep = torch.ones(y_pred.shape[0], dtype=torch.bool, device=y_pred.device)
        keep[ignore.to(keep.device)] = False

        # Drop ignored rows *before* computing the error: multiplying by zero
        # afterwards would turn NaN/inf values in ignored rows into NaN in
        # both the loss and its gradient.
        loss = self.mse(y_pred[keep], y_data[keep])

        # Count kept rows from the mask so duplicate indices in `ignore` are
        # not subtracted twice; clamp to avoid 0/0 when everything is ignored.
        return loss.sum() / keep.sum().clamp_min(1)
class Parallel(nn.Module):
    """Bone-direction, foreshortening, ground, 3D-parallel and 2D-contact losses.

    Args:
        skeleton: sequence of ``(src, dst)`` joint-index pairs (bones).
        ignore: optional iterable of bones to skip in the per-bone losses.
        ground_parallel: optional sequence of ``((src, dst), value)`` items
            used by the ground term.

    Attributes set after construction by the caller:
        parallel_in_3d: list of ``((i, j), (k, l))`` index pairs into ``z``
            whose directions are pushed to be parallel.
        cos: ``None`` or a per-bone sequence indexed like ``skeleton``; when
            set, the foreshortening branch replaces the tangent branch.
        contact_2d: optional iterable of groups of ``(src, dst, t)`` triples;
            read only if the attribute exists.
    """

    # Bones for which the foreshortening term is not evaluated.
    _NO_COS_BONES = (
        (1, 2),  # Neck + Right Shoulder
        (1, 5),  # Neck + Left Shoulder
        (9, 10),  # Hips + Right Upper Leg
        (9, 13),  # Hips + Left Upper Leg
    )
    # Lower bound for bone lengths used as divisors (same default as
    # torch.nn.functional.normalize).
    _EPS = 1e-12

    def __init__(self, skeleton, ignore=None, ground_parallel=None):
        super().__init__()
        self.skeleton = skeleton
        self.ignore = set(ignore) if ignore is not None else set()
        self.ground_parallel = ground_parallel if ground_parallel is not None else []
        self.parallel_in_3d = []
        self.cos = None

    def forward(self, y_pred3d, y_data, z, spine_j, global_step=0):
        """Compute all loss terms.

        Args:
            y_pred3d: predicted joints; the first two columns are used as the
                2D prediction.
            y_data: target 2D joints, indexed by joint.
            z: tensor indexed by joint, used only by the ``parallel_in_3d`` term.
            spine_j: pair ``(rleg, lleg)`` of 2D vectors that stand in for the
                predicted bone vector of bones ``(10, 11)`` and ``(13, 14)``
                in the tangent branch.
            global_step: unused; kept for interface compatibility.

        Returns:
            Tuple ``(Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d)``.
            Each term is averaged over the items it was evaluated on and is a
            plain ``0`` when there were none. ``Lspine`` is always zero.
        """
        y_pred = y_pred3d[:, :2]

        Lcon2d = self._contact_2d_loss(y_pred, y_data)
        Ltan, Lcos, Lpar, Lspine = self._bone_losses(y_pred3d, y_pred, y_data, spine_j)
        Lgr = self._ground_loss(y_pred)
        Lstraight3d = self._straight_3d_loss(z)

        return Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d

    def _contact_2d_loss(self, y_pred, y_data):
        """Mean squared mismatch of offsets between pairs of points on bones."""
        total = count = 0
        # `contact_2d` is not created in __init__, so it may be absent.
        for c2d in getattr(self, "contact_2d", ()):
            for (src_1, dst_1, t_1), (src_2, dst_2, t_2) in itertools.combinations(c2d, 2):
                # Offset between the two interpolated points in the target ...
                a = torch.lerp(y_data[src_2], y_data[dst_2], t_2) - torch.lerp(
                    y_data[src_1], y_data[dst_1], t_1
                )
                # ... and the same offset in the prediction.
                b = torch.lerp(y_pred[src_2], y_pred[dst_2], t_2) - torch.lerp(
                    y_pred[src_1], y_pred[dst_1], t_1
                )
                total = total + ((a - b) ** 2).sum()
                count += 1

        return total / count if count > 0 else total

    def _bone_losses(self, y_pred3d, y_pred, y_data, spine_j):
        """Per-bone tangent / foreshortening / parallel terms, averaged over bones."""
        rleg, lleg = spine_j

        Ltan = Lpar = Lcos = Lspine = 0
        count = 0
        for i, bone in enumerate(self.skeleton):
            if bone in self.ignore:
                continue

            src, dst = bone
            b = y_data[dst] - y_data[src]
            t = nn.functional.normalize(b, dim=0)  # unit tangent of target bone
            n = torch.stack([-t[1], t[0]])  # its 2D normal

            # NOTE(review): the two branches below are taken to be selected by
            # `self.cos`; confirm against the original indentation.
            if self.cos is not None:
                if bone not in self._NO_COS_BONES:
                    a = y_pred[dst] - y_pred[src]
                    l2d = torch.norm(a, dim=0)
                    l3d = torch.norm(y_pred3d[dst] - y_pred3d[src], dim=0)

                    # Clamp divisors so a zero-length bone yields a finite
                    # value instead of NaN.
                    ratio = l2d / l3d.clamp_min(self._EPS)
                    Lcos = Lcos + (ratio - self.cos[i]) ** 2
                    Lpar = Lpar + ((a / l2d.clamp_min(self._EPS)) * n).sum() ** 2
            else:
                if src == 10 and dst == 11:  # right leg
                    a = rleg
                elif src == 13 and dst == 14:  # left leg
                    a = lleg
                else:
                    a = y_pred[dst] - y_pred[src]

                c = a - b
                Ltan = Ltan + (c * t).sum() ** 2
                Lpar = Lpar + (c * n).sum() ** 2

            # Bones skipped by the foreshortening filter still count towards
            # the average.
            count += 1

        if count > 0:
            Ltan = Ltan / count
            Lcos = Lcos / count
            Lpar = Lpar / count
            Lspine = Lspine / count

        return Ltan, Lcos, Lpar, Lspine

    def _ground_loss(self, y_pred):
        """Mean squared gap between |x-component| of unit 2D bones and their targets."""
        total = count = 0
        for (src, dst), value in self.ground_parallel:
            bone = nn.functional.normalize(y_pred[dst] - y_pred[src], dim=0)
            total = total + (torch.abs(bone[0]) - value) ** 2
            count += 1

        return total / count if count > 0 else total

    def _straight_3d_loss(self, z):
        """Mean of ``(cos(angle) - 1) ** 2`` over the ``parallel_in_3d`` pairs."""
        total = count = 0
        for (i, j), (k, l) in self.parallel_in_3d:
            a = nn.functional.normalize(z[j] - z[i], dim=0)
            b = nn.functional.normalize(z[l] - z[k], dim=0)
            total = total + ((a * b).sum() - 1) ** 2
            count += 1

        return total / count if count > 0 else total
class MimickedSelfContactLoss(nn.Module):
    """Loss that lets vertices in contact on presented mesh attract vertices that are close.

    Args:
        geodesics_mask: boolean tensor indexed ``[vertex, vertex]``; it is
            registered as the buffer ``geomask``. Only pairs where the mask is
            True are eligible as a contact partner.
    """

    def __init__(self, geodesics_mask):
        super().__init__()
        # geodesic distance mask
        self.register_buffer("geomask", geodesics_mask)

    def forward(
        self,
        presented_contact,
        vertices,
        v2v=None,
        contact_mode="dist_tanh",
        contact_thresh=1,
    ):
        """Return the mean distance from each contact vertex to its partner.

        Args:
            presented_contact: indices of the vertices that should be in
                contact. It is indexed with a NumPy integer array, so it must
                support that (e.g. a NumPy array).
            vertices: mesh vertices of shape ``(1, nv, 3)``; only used when
                ``v2v`` is not given.
            v2v: optional precomputed ``(nv, nv)`` pairwise distance matrix.
            contact_mode: ``"dist_tanh"`` saturates the distance with
                ``contact_thresh * tanh(d / contact_thresh)``; any other value
                uses the raw distance.
            contact_thresh: saturation scale for ``"dist_tanh"``.

        Returns:
            A 0-d tensor, or the float ``0.0`` when ``presented_contact`` is
            empty.
        """
        contactloss = 0.0

        # Nothing to attract: skip the O(nv^2) distance computation.
        if len(presented_contact) == 0:
            return contactloss

        if v2v is None:
            # compute pairwise distances
            verts = vertices.contiguous()
            nv = verts.shape[1]
            # Explicit reshape instead of squeeze() so a single-vertex input
            # does not lose its vertex dimension.
            verts = verts.reshape(nv, 3)
            v2v = (verts.unsqueeze(1) - verts.unsqueeze(0)).norm(p=2, dim=2)

        # Partner selection is a discrete choice, so it is done without
        # gradient tracking; gradients flow only through `v2v_min` below.
        with torch.no_grad():
            cvertstobody = v2v[presented_contact, :]
            cvertstobody = cvertstobody[:, presented_contact]
            maskgeo = self.geomask[presented_contact, :]
            maskgeo = maskgeo[:, presented_contact]
            # ones_like already matches v2v's device and dtype; the previous
            # `.to(verts.device)` failed when the caller supplied v2v because
            # `verts` was never bound on that path.
            weights = torch.ones_like(cvertstobody)
            weights[~maskgeo] = float("inf")
            # +1 keeps zero distances from cancelling the infinite weight
            # (0 * inf would be NaN).
            min_idx = torch.min((cvertstobody + 1) * weights, 1)[1]
            min_idx = presented_contact[min_idx.cpu().numpy()]

        v2v_min = v2v[presented_contact, min_idx]

        # tanh will not pull vertices that are ~more than contact_thres far apart
        if contact_mode == "dist_tanh":
            contactloss = contact_thresh * torch.tanh(v2v_min / contact_thresh)
            contactloss = contactloss.mean()
        else:
            contactloss = v2v_min.mean()

        return contactloss