import itertools

import torch
import torch.nn as nn

import pose_estimation


class MSE(nn.Module):
    def __init__(self, ignore=None):
        super().__init__()

        self.mse = nn.MSELoss(reduction="none")
        self.ignore = ignore if ignore is not None else []

    def forward(self, y_pred, y_data):
        # per-element squared error; rows listed in `ignore` are zeroed out
        # and excluded from the normalisation below
        loss = self.mse(y_pred, y_data)

        if len(self.ignore) > 0:
            loss[self.ignore] *= 0

        return loss.sum() / (len(loss) - len(self.ignore))
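
# Illustrative usage of MSE (a sketch, not part of the original interface);
# the keypoint count and tensor shapes below are assumptions made for the
# example only:
#
#     criterion = MSE(ignore=[0, 3])
#     y_pred = torch.randn(17, 2)       # e.g. 17 keypoints in 2D
#     y_data = torch.randn(17, 2)
#     loss = criterion(y_pred, y_data)  # rows 0 and 3 are zeroed out and the
#                                       # summed error is divided by the 15
#                                       # remaining rows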


class Parallel(nn.Module):
    def __init__(self, skeleton, ignore=None, ground_parallel=None):
        super().__init__()

        self.skeleton = skeleton
        if ignore is not None:
            self.ignore = set(ignore)
        else:
            self.ignore = set()

        self.ground_parallel = ground_parallel if ground_parallel is not None else []
        self.parallel_in_3d = []

        # optional per-bone 2D/3D length-ratio targets, indexed like `skeleton`;
        # left as None unless set on the module before calling `forward`
        self.cos = None

    def forward(self, y_pred3d, y_data, z, spine_j, global_step=0):
        # keep only the first two coordinates as the 2D prediction
        y_pred = y_pred3d[:, :2]
        rleg, lleg = spine_j

        # optional 2D contact-consistency term: `contact_2d` holds groups of
        # (src, dst, t) interpolation points and is only used when that
        # attribute has been attached to the module
        Lcon2d = Lcount = 0
        if hasattr(self, "contact_2d"):
            for c2d in self.contact_2d:
                for (
                    (src_1, dst_1, t_1),
                    (src_2, dst_2, t_2),
                ) in itertools.combinations(c2d, 2):

                    a_1 = torch.lerp(y_data[src_1], y_data[dst_1], t_1)
                    a_2 = torch.lerp(y_data[src_2], y_data[dst_2], t_2)
                    a = a_2 - a_1

                    b_1 = torch.lerp(y_pred[src_1], y_pred[dst_1], t_1)
                    b_2 = torch.lerp(y_pred[src_2], y_pred[dst_2], t_2)
                    b = b_2 - b_1

                    lcon2d = ((a - b) ** 2).sum()
                    Lcon2d = Lcon2d + lcon2d
                    Lcount += 1

        if Lcount > 0:
            Lcon2d = Lcon2d / Lcount

        Ltan = Lpar = Lcos = Lcount = 0
        Lspine = 0
        for i, bone in enumerate(self.skeleton):
            if bone in self.ignore:
                continue

            src, dst = bone

            # target 2D bone vector with unit tangent t and its perpendicular n
            b = y_data[dst] - y_data[src]
            t = nn.functional.normalize(b, dim=0)
            n = torch.stack([-t[1], t[0]])

            if src == 10 and dst == 11:  # right leg
                a = rleg
            elif src == 13 and dst == 14:  # left leg
                a = lleg
            else:
                a = y_pred[dst] - y_pred[src]

            bone_name = f"{pose_estimation.KPS[src]}_{pose_estimation.KPS[dst]}"
            c = a - b
            lcos_loc = ltan_loc = lpar_loc = 0
            if self.cos is not None:
                if bone not in [
                    (1, 2),  # Neck + Right Shoulder
                    (1, 5),  # Neck + Left Shoulder
                    (9, 10),  # Hips + Right Upper Leg
                    (9, 13),  # Hips + Left Upper Leg
                ]:
                    a = y_pred[dst] - y_pred[src]
                    l2d = torch.norm(a, dim=0)
                    l3d = torch.norm(y_pred3d[dst] - y_pred3d[src], dim=0)
                    lcos = self.cos[i]

                    lcos_loc = (l2d / l3d - lcos) ** 2
                    Lcos = Lcos + lcos_loc
                    lpar_loc = ((a / l2d) * n).sum() ** 2
                    Lpar = Lpar + lpar_loc
            else:
                ltan_loc = ((c * t).sum()) ** 2
                Ltan = Ltan + ltan_loc
                lpar_loc = (c * n).sum() ** 2
                Lpar = Lpar + lpar_loc

            Lcount += 1

        if Lcount > 0:
            Ltan = Ltan / Lcount
            Lcos = Lcos / Lcount
            Lpar = Lpar / Lcount
            Lspine = Lspine / Lcount

        Lgr = Lcount = 0
        for (src, dst), value in self.ground_parallel:
            bone = y_pred[dst] - y_pred[src]
            bone = nn.functional.normalize(bone, dim=0)
            l = (torch.abs(bone[0]) - value) ** 2

            Lgr = Lgr + l
            Lcount += 1

        if Lcount > 0:
            Lgr = Lgr / Lcount

        Lstraight3d = Lcount = 0
        for (i, j), (k, l) in self.parallel_in_3d:
            # penalise deviation from parallelism between the two 3D segments
            a = z[j] - z[i]
            a = nn.functional.normalize(a, dim=0)
            b = z[l] - z[k]
            b = nn.functional.normalize(b, dim=0)
            lo = (((a * b).sum() - 1) ** 2).sum()
            Lstraight3d = Lstraight3d + lo
            Lcount += 1

        if Lcount > 0:
            Lstraight3d = Lstraight3d / Lcount

        return Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d
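
# Illustrative usage of Parallel (a sketch, not part of the original
# interface). The two-bone skeleton and tensor names below are made up for
# the example; `y_pred3d`/`z` stand for 3D keypoints, `y_data` for the 2D
# targets, and `spine_j` for the externally computed (right leg, left leg)
# direction pair unpacked in `forward`:
#
#     skeleton = [(0, 1), (1, 2)]
#     parallel = Parallel(skeleton, ignore=[(1, 2)])
#     Ltan, Lcos, Lpar, Lspine, Lgr, Lstraight3d, Lcon2d = parallel(
#         y_pred3d, y_data, z, (rleg, lleg)
#     )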


class MimickedSelfContactLoss(nn.Module):
    def __init__(self, geodesics_mask):
        """
        Loss that lets vertices in contact on the presented mesh attract
        vertices that are close.
        """
        super().__init__()
        # geodesic distance mask
        self.register_buffer("geomask", geodesics_mask)

    def forward(
        self,
        presented_contact,
        vertices,
        v2v=None,
        contact_mode="dist_tanh",
        contact_thresh=1,
    ):

        contactloss = 0.0

        if v2v is None:
            # compute pairwise vertex-to-vertex distances
            # (expects `vertices` of shape (1, N, 3))
            verts = vertices.contiguous()
            nv = verts.shape[1]
            v2v = verts.squeeze().unsqueeze(1).expand(
                nv, nv, 3
            ) - verts.squeeze().unsqueeze(0).expand(nv, nv, 3)
            v2v = torch.norm(v2v, 2, 2)

        # loss for self-contact from mimic'ed pose
        if len(presented_contact) > 0:
            # for each vertex in the presented contact set, find the closest
            # other contact vertex; pairs ruled out by the geodesic mask are
            # never selected
            with torch.no_grad():
                cvertstobody = v2v[presented_contact, :]
                cvertstobody = cvertstobody[:, presented_contact]
                maskgeo = self.geomask[presented_contact, :]
                maskgeo = maskgeo[:, presented_contact]
                weights = torch.ones_like(cvertstobody)
                weights[~maskgeo] = float("inf")
                # the +1 offset prevents zero distances from turning the
                # infinite weights into NaN
                min_idx = torch.min((cvertstobody + 1) * weights, 1)[1]
                min_idx = presented_contact[min_idx.cpu().numpy()]

            v2v_min = v2v[presented_contact, min_idx]

            # tanh will not pull vertices that are more than ~contact_thresh apart
            if contact_mode == "dist_tanh":
                contactloss = contact_thresh * torch.tanh(v2v_min / contact_thresh)
                contactloss = contactloss.mean()
            else:
                contactloss = v2v_min.mean()

        return contactloss
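

# Minimal smoke test (a sketch, not from the original code base): exercises
# MimickedSelfContactLoss with a random toy mesh and a random boolean
# "geodesic" mask. The vertex count and contact indices are made up for
# illustration only.
if __name__ == "__main__":
    nv = 64
    geomask = torch.rand(nv, nv) > 0.5           # hypothetical geodesic mask
    contact_loss = MimickedSelfContactLoss(geomask)

    vertices = torch.randn(1, nv, 3)             # toy mesh, batch size 1
    presented_contact = torch.arange(10)         # pretend these vertices touch
    loss = contact_loss(presented_contact, vertices)
    print("toy self-contact loss:", float(loss))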