import numpy as np
import einops
import torch
from scipy.spatial.transform import Rotation as R

########################## 6DoF CaPE ####################################
class CaPE_6DoF:
    def cape_embed(self, f, P):
        """
        Apply CaPE to a feature vector.
        :param f: feature vector of shape [..., d], d divisible by 4
        :param P: 4x4 transformation matrix
        :return: feature f rotated by pose P: f @ P
        """
        # view f as [..., d/4, 4], right-multiply each 4-vector by P, then flatten back
        f = einops.rearrange(f, '... (d k) -> ... d k', k=4)
        return einops.rearrange(f @ P, '... d k -> ... (d k)', k=4)

    def attn_with_CaPE(self, f1, f2, p1, p2):
        """
        Do attention dot product with CaPE pose encoding.
        # query = cape_embed(query, p_out_inv)  # query f_q @ (p_out)^(-T)
        # key = cape_embed(key, p_in)  # key f_k @ p_in
        :param f1: query features [b, t1*l, d]
        :param f2: key features [b, t2*l, d]
        :param p1: query poses [b, t1, 4, 4]
        :param p2: key poses [b, t2, 4, 4]
        :return: attention score query @ key.T, shape [b, t1*l, t2*l]
        """
        l = f1.shape[1] // p1.shape[1]
        assert f1.shape[1] // p1.shape[1] == f2.shape[1] // p2.shape[1]
        p1_invT = einops.repeat(torch.inverse(p1).permute(0, 1, 3, 2), 'b t m n -> b (t l) m n', l=l)  # f1 [b, l*t1, d]
        query = self.cape_embed(f1, p1_invT)  # [b, l*t1, d] query: f1 @ (p1)^(-T), transpose the last two dim
        p2_copy = einops.repeat(p2, 'b t m n -> b (t l) m n', l=l) # f2 [b, l*t2, d]
        key = self.cape_embed(f2, p2_copy)  # [b, l*t2, d] key: f2 @ p2
        att = query @ key.permute(0, 2, 1)  # [b, l*t1, l*t2] attention: query@key^T
        return att
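
# Why the 6DoF score is invariant to a shared pose jitter (sketch): with
#   query = f1 @ P1^(-T)   and   key = f2 @ P2,
# the attention score is
#   att = f1 @ P1^(-T) @ P2^T @ f2^T.
# Replacing P1 -> P1 @ dP and P2 -> P2 @ dP gives
#   f1 @ P1^(-T) @ dP^(-T) @ dP^T @ P2^T @ f2^T = f1 @ P1^(-T) @ P2^T @ f2^T,
# since dP^(-T) @ dP^T = I. Only the relative pose P2 @ P1^(-1) affects the score,
# which is what the verification below checks numerically.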


################### 6DoF Verification ###################################
def euler_to_matrix(alpha, beta, gamma, x, y, z):
    # angles interpreted as degrees (scipy degrees=True)
    r = R.from_euler('xyz', [alpha, beta, gamma], degrees=True)
    t = np.array([[x], [y], [z]])
    rot_matrix = r.as_matrix()
    rot_matrix = np.concatenate([rot_matrix, t], axis=-1)
    rot_matrix = np.concatenate([rot_matrix, [[0, 0, 0, 1]]], axis=0)
    return rot_matrix
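
# A minimal sanity sketch (illustrative values of my own, not part of the verification
# below): the matrix built by euler_to_matrix should be a rigid transform, i.e. an
# orthonormal 3x3 rotation block plus a [0, 0, 0, 1] bottom row.
_m_check = euler_to_matrix(30, 45, 60, 0.1, 0.2, 0.3)
assert np.allclose(_m_check[:3, :3] @ _m_check[:3, :3].T, np.eye(3), atol=1e-6)
assert np.allclose(_m_check[3], [0, 0, 0, 1])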

def random_6dof_pose(B, T):
    pose_euler = torch.rand([B, T, 6]).numpy()  # [alpha, beta, gamma, x, y, z], each in [0, 1)
    pose_matrix = []
    for b in range(B):
        p = []
        for t in range(T):
            p.append(torch.from_numpy(euler_to_matrix(*pose_euler[b, t])))
        pose_matrix.append(torch.stack(p))
    pose_matrix = torch.stack(pose_matrix)
    return pose_matrix.float()

bs = 6  # batch size
t1 = 3  # num of target views in each batch, can be an arbitrary number
t2 = 5  # num of reference views in each batch, can be an arbitrary number
l = 10  # num of tokens per view
d = 16  # dim of token feature, must be divisible by 4 in this case
assert d % 4 == 0

# random init query and key
f1 = torch.rand(bs, t1, l, d)     # query
f2 = torch.rand(bs, t2, l, d)     # key
f1 = einops.rearrange(f1, 'b t l d -> b (t l) d')
f2 = einops.rearrange(f2, 'b t l d -> b (t l) d')

# random init pose p1, p2, delta_p, [bs, t, 4, 4]
p1 = random_6dof_pose(bs, t1)   # [bs, t1, 4, 4]
p2 = random_6dof_pose(bs, t2)   # [bs, t2, 4, 4]
p_delta = random_6dof_pose(bs, 1)   # [bs, 1, 4, 4]
# delta p is identical for p1 and p2 in each batch
p1_delta = einops.repeat(p_delta, 'b 1 m n -> b (1 t) m n', t=t1)
p2_delta = einops.repeat(p_delta, 'b 1 m n -> b (1 t) m n', t=t2)

# run attention with CaPE 6DoF
cape_6dof = CaPE_6DoF()
# att
att = cape_6dof.attn_with_CaPE(f1, f2, p1, p2)
# att_delta
att_delta = cape_6dof.attn_with_CaPE(f1, f2, p1@p1_delta, p2@p2_delta)

# condition: att scores should be identical, i.e. no effect from any delta_p
assert torch.allclose(att, att_delta, rtol=1e-3)
print("6DoF CaPE Verified")




########################## 4DoF CaPE ####################################
class CaPE_4DoF:
    def rotate_every_two(self, x):
        # pairwise map (x1, x2) -> (-x2, x1): rotate each consecutive 2D sub-vector by 90 degrees
        x = einops.rearrange(x, '... (d j) -> ... d j', j=2)
        x1, x2 = x.unbind(dim=-1)
        x = torch.stack((-x2, x1), dim=-1)
        return einops.rearrange(x, '... d j -> ... (d j)')

    def cape(self, x, p):
        d, l, n = x.shape[-1], p.shape[-2], p.shape[-1]
        assert d % (2 * n) == 0
        # broadcast each of the n pose components across d // n consecutive feature channels
        m = einops.repeat(p, 'b l n -> b l (n k)', k=d // n)
        return m

    def cape_embed(self, qq, kk, p1, p2):
        """
        Embed camera position encoding into attention map
        :param qq: query feature map   [b, l_q, feature_dim]
        :param kk: key feature map    [b, l_k, feature_dim]
        :param p1: query pose  [b, l_q, pose_dim]
        :param p2: key pose    [b, l_k, pose_dim]
        :return: cape embedded attention map    [b, l_q, l_k]
        """
        assert p1.shape[-1] == p2.shape[-1]
        assert qq.shape[-1] == kk.shape[-1]
        assert p1.shape[0] == p2.shape[0] == qq.shape[0] == kk.shape[0]
        assert p1.shape[1] == qq.shape[1]
        assert p2.shape[1] == kk.shape[1]

        m1 = self.cape(qq, p1)
        m2 = self.cape(kk, p2)

        q = (qq * m1.cos()) + (self.rotate_every_two(qq) * m1.sin())
        k = (kk * m2.cos()) + (self.rotate_every_two(kk) * m2.sin())

        return q, k

    def attn_with_CaPE(self, f1, f2, p1, p2):
        """
        Do attention dot product with CaPE pose encoding.
        # query and key are each rotated (RoPE-style) by their own pose angles;
        # the resulting score depends only on the pose difference p1 - p2
        :param f1: query features [b, t1*l, d]
        :param f2: key features [b, t2*l, d]
        :param p1: query poses [b, t1, 4]
        :param p2: key poses [b, t2, 4]
        :return: attention score query @ key.T, shape [b, t1*l, t2*l]
        """
        l = f1.shape[1] // p1.shape[1]
        assert f1.shape[1] // p1.shape[1] == f2.shape[1] // p2.shape[1]
        p1_reshape = einops.repeat(p1, 'b t m -> b (t l) m', l=l)  # [b, t1*l, 4]
        p2_reshape = einops.repeat(p2, 'b t m -> b (t l) m', l=l)  # [b, t2*l, 4]
        query, key = self.cape_embed(f1, f2, p1_reshape, p2_reshape)
        att = query @ key.permute(0, 2, 1)  # [b, l*t1, l*t2] attention: query@key^T
        return att
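
# Why the 4DoF score is invariant to a shared additive pose delta (sketch): viewing each
# consecutive feature pair as a complex number, cape_embed multiplies query tokens by
# e^(i * m1) and key tokens by e^(i * m2), so their dot product depends only on the pose
# difference m1 - m2. Adding the same 4DoF delta to both p1 and p2 therefore leaves the
# attention map unchanged, which is what the verification below checks numerically.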

################### 4DoF Verification ###################################
def random_4dof_pose(B, T):
    pose = torch.zeros([B, T, 4])
    pose[:, :, :3] = torch.rand([B, T, 3])  # angles in radians
    # theta in [0, pi], azimuth in [0, 2*pi], radius in [0, pi], last entry stays 0
    pose[:, :, 1] *= (2*torch.pi)
    pose[:, :, 0] *= torch.pi
    pose[:, :, 2] *= torch.pi
    return pose.float()

def look_at(origin, target, up):
    forward = (target - origin)
    forward = forward / torch.linalg.norm(forward, dim=-1, keepdim=True)
    right = torch.linalg.cross(forward, up)
    right = right / torch.linalg.norm(right, dim=-1, keepdim=True)
    new_up = torch.linalg.cross(forward, right)
    new_up = new_up / torch.linalg.norm(new_up, dim=-1, keepdim=True)
    # columns: [right, up, forward, translation(=target)]; then append [0, 0, 0, 1] as the bottom row
    rotation_matrix = torch.stack((right, new_up, forward, target), dim=-1)
    bottom_row = torch.tensor([[0, 0, 0, 1]], dtype=rotation_matrix.dtype).repeat(rotation_matrix.shape[0], rotation_matrix.shape[1], 1, 1)
    matrix = torch.cat([rotation_matrix, bottom_row], dim=-2)

    return matrix

def pose_4dof2matrix(pose_4dof):
    """

    :param pose_4dof: [b, t, 4]
    :return: pose 4x4 matrix: [b, t, 4, 4]
    """
    theta = pose_4dof[:, :, 0]
    azimuth = pose_4dof[:, :, 1]
    radius = pose_4dof[:, :, 2]
    xyz = torch.stack([torch.sin(theta) * torch.cos(azimuth), torch.sin(theta) * torch.sin(azimuth), torch.cos(theta)], dim=-1) * radius.unsqueeze(-1)
    origin = torch.zeros_like(xyz)
    up = torch.zeros_like(xyz)
    up[:, :, 2] = 1
    pose = look_at(origin, xyz, up)
    return pose

def pose_matrix24dof(pose_matrix):
    """

    :param pose_matrix: [b, t, 4, 4]
    :return: pose_4dof: [b, t, 4]   theta, azimuth, radius, 0, looking at origin
    """
    xyz = pose_matrix[..., :3, 3]
    xy = xyz[..., 0] ** 2 + xyz[..., 1] ** 2
    radius = torch.sqrt(xy + xyz[..., 2] ** 2)
    theta = torch.arctan2(torch.sqrt(xy), xyz[..., 2])  # for elevation angle defined from Z-axis down
    azimuth = torch.arctan2(xyz[..., 1], xyz[..., 0])
    pose = torch.stack([theta, azimuth, radius, torch.zeros_like(radius)], dim=-1)
    # move to [0, 2pi]
    pose %= (2 * torch.pi)

    return pose

bs = 6  # batch size
t1 = 3  # num of target views in each batch, can be an arbitrary number
t2 = 5  # num of reference views in each batch, can be an arbitrary number
l = 10  # num of tokens per view
d = 16  # dim of token feature, must be divisible by 2 * pose_dim = 8 in this case

# random init query and key
f1 = torch.rand(bs, t1, l, d)     # query
f2 = torch.rand(bs, t2, l, d)     # key
f1 = einops.rearrange(f1, 'b t l d -> b (t l) d')
f2 = einops.rearrange(f2, 'b t l d -> b (t l) d')

# random init 4DoF pose [bs, t, 4]: theta, azimuth, radius, 0
p1 = random_4dof_pose(bs, t1)   # [bs, t1, 4]
p2 = random_4dof_pose(bs, t2)   # [bs, t2, 4]
p1_matrix = pose_4dof2matrix(p1)
p1_4dof = pose_matrix24dof(p1_matrix)
assert torch.allclose(p1, p1_4dof)

p_delta_4dof = random_4dof_pose(bs, 1)
# delta p is identical for p1 and p2 in each batch
p1_delta_4dof = einops.repeat(p_delta_4dof, 'b 1 m -> b (1 t) m', t=t1)
p2_delta_4dof = einops.repeat(p_delta_4dof, 'b 1 m -> b (1 t) m', t=t2)

# run attention with CaPE 4DoF
cape_4dof = CaPE_4DoF()
# att
att = cape_4dof.attn_with_CaPE(f1, f2, p1, p2)
# att_delta
att_delta = cape_4dof.attn_with_CaPE(f1, f2, p1+p1_delta_4dof, p2+p2_delta_4dof)

# condition: att scores should be identical, i.e. no effect from any delta_p
assert torch.allclose(att, att_delta, rtol=1e-3)
print("4DoF CaPE Verified")

# print("You should get assertion error because 4DoF CaPE cannot handle 6DoF jitter")
# # att_delta_6dof, it cannot handle 6dof jitter
# p_delta_6dof = random_6dof_pose(bs, 1)   # [bs, 1, 4, 4] any delta transformation in 6DoF
# # delta p is identical to p1 and p2 in each batch
# p1_delta_6dof = einops.repeat(p_delta_6dof, 'b 1 m n -> b (1 t) m n', t=t1//1)
# p2_delta_6dof = einops.repeat(p_delta_6dof, 'b 1 m n -> b (1 t) m n', t=t2//1)
# # 4dof pose to 4x4 matrix
# p1_matrix = pose_4dof2matrix(p1)
# p2_matrix = pose_4dof2matrix(p2)
# att_delta_6dof = cape_4dof.attn_with_CaPE(f1, f2, pose_matrix24dof(p1_matrix@p1_delta_6dof), pose_matrix24dof(p2_matrix@p2_delta_6dof))
# # condition: att score should be the same i.e. non effect from any delta_p
# assert torch.allclose(att, att_delta_6dof, 1e-3)