from .. import constants import functools from scipy.spatial.transform.rotation import Rotation from ..custom_types import * def quat_to_rot(q): shape = q.shape q = q.view(-1, 4) q_sq = 2 * q[:, :, None] * q[:, None, :] m00 = 1 - q_sq[:, 1, 1] - q_sq[:, 2, 2] m01 = q_sq[:, 0, 1] - q_sq[:, 2, 3] m02 = q_sq[:, 0, 2] + q_sq[:, 1, 3] m10 = q_sq[:, 0, 1] + q_sq[:, 2, 3] m11 = 1 - q_sq[:, 0, 0] - q_sq[:, 2, 2] m12 = q_sq[:, 1, 2] - q_sq[:, 0, 3] m20 = q_sq[:, 0, 2] - q_sq[:, 1, 3] m21 = q_sq[:, 1, 2] + q_sq[:, 0, 3] m22 = 1 - q_sq[:, 0, 0] - q_sq[:, 1, 1] r = torch.stack((m00, m01, m02, m10, m11, m12, m20, m21, m22), dim=1) r = r.view(*shape[:-1], 3, 3) return r def rot_to_quat(r): shape = r.shape r = r.view(-1, 3, 3) qw = .5 * (1 + r[:, 0, 0] + r[:, 1, 1] + r[:, 2, 2]).sqrt() qx = (r[:, 2, 1] - r[:, 1, 2]) / (4 * qw) qy = (r[:, 0, 2] - r[:, 2, 0]) / (4 * qw) qz = (r[:, 1, 0] - r[:, 0, 1]) / (4 * qw) q = torch.stack((qx, qy, qz, qw), -1) q = q.view(*shape[:-2], 4) return q @functools.lru_cache(10) def get_rotation_matrix(theta: float, axis: float, degree: bool = False) -> ARRAY: if degree: theta = theta * np.pi / 180 rotate_mat = np.eye(3) rotate_mat[axis, axis] = 1 cos_theta, sin_theta = np.cos(theta), np.sin(theta) rotate_mat[(axis + 1) % 3, (axis + 1) % 3] = cos_theta rotate_mat[(axis + 2) % 3, (axis + 2) % 3] = cos_theta rotate_mat[(axis + 1) % 3, (axis + 2) % 3] = sin_theta rotate_mat[(axis + 2) % 3, (axis + 1) % 3] = -sin_theta return rotate_mat def get_random_rotation(batch_size: int) -> T: r = Rotation.random(batch_size).as_matrix().astype(np.float32) Rotation.random() return torch.from_numpy(r) def rand_bounded_rotation_matrix(cache_size: int, theta_range: float = .1): def create_cache(): # from http://www.realtimerendering.com/resources/GraphicsGems/gemsiii/rand_rotation.c with torch.no_grad(): theta, phi, z = torch.rand(cache_size, 3).split((1, 1, 1), dim=1) theta = (2 * theta - 1) * theta_range + 1 theta = np.pi * theta # Rotation about the pole (Z). phi = phi * 2 * np.pi # For direction of pole deflection. z = 2 * z * theta_range # For magnitude of pole deflection. r = z.sqrt() v = torch.cat((torch.sin(phi) * r, torch.cos(phi) * r, torch.sqrt(2.0 - z)), dim=1) st = torch.sin(theta).squeeze(1) ct = torch.cos(theta).squeeze(1) rot_ = torch.zeros(cache_size, 3, 3) rot_[:, 0, 0] = ct rot_[:, 1, 1] = ct rot_[:, 0, 1] = st rot_[:, 1, 0] = -st rot_[:, 2, 2] = 1 rot = (torch.einsum('ba,bd->bad', v, v) - torch.eye(3)[None, :, :]).bmm(rot_) det = rot.det() assert (det.gt(0.99) * det.lt(1.0001)).all().item() return rot def get_batch_rot(batch_size): nonlocal cache select = torch.randint(cache_size, size=(batch_size,)) return cache[select] cache = create_cache() return get_batch_rot def transform_rotation(points: T, one_axis=False, max_angle=-1): r = get_random_rotation(one_axis, max_angle) transformed = torch.einsum('nd,rd->nr', points, r) return transformed def tb_to_rot(abc: T) -> T: c, s = torch.cos(abc), torch.sin(abc) aa = c[:, 0] * c[:, 1] ab = c[:, 0] * s[:, 1] * s[:, 2] - c[:, 2] * s[:, 0] ac = s[:, 0] * s[:, 2] + c[:, 0] * c[:, 2] * s[:, 1] ba = c[:, 1] * s[:, 0] bb = c[:, 0] * c[:, 2] + s.prod(-1) bc = c[:, 2] * s[:, 0] * s[:, 1] - c[:, 0] * s[:, 2] ca = -s[:, 1] cb = c[:, 1] * s[:, 2] cc = c[:, 1] * c[:, 2] return torch.stack((aa, ab, ac, ba, bb, bc, ca, cb, cc), 1).view(-1, 3, 3) def rot_to_tb(rot: T) -> T: sy = torch.sqrt(rot[:, 0, 0] * rot[:, 0, 0] + rot[:, 1, 0] * rot[:, 1, 0]) out = torch.zeros(rot.shape[0], 3, device = rot.device) mask = sy.gt(1e-6) z = torch.atan2(rot[mask, 2, 1], rot[mask, 2, 2]) y = torch.atan2(-rot[mask, 2, 0], sy[mask]) x = torch.atan2(rot[mask, 1, 0], rot[mask, 0, 0]) out[mask] = torch.stack((x, y, z), dim=1) if not mask.all(): mask = ~mask z = torch.atan2(-rot[mask, 1, 2], rot[mask, 1, 1]) y = torch.atan2(-rot[mask, 2, 0], sy[mask]) x = torch.zeros(x.shape) out[mask] = torch.stack((x, y, z), dim=1) return out def apply_gmm_affine(gmms: TS, affine: T): mu, p, phi, eigen = gmms if affine.dim() == 2: affine = affine.unsqueeze(0).expand(mu.shape[0], *affine.shape) mu_r = torch.einsum('bad, bpnd->bpna', affine, mu) p_r = torch.einsum('bad, bpncd->bpnca', affine, p) return mu_r, p_r, phi, eigen def get_reflection(reflect_axes: Tuple[bool, ...]) -> T: reflect = torch.eye(constants.DIM) for i in range(constants.DIM): if reflect_axes[i]: reflect[i, i] = -1 return reflect def get_tait_bryan_from_p(p: T) -> T: # p = p.squeeze(1) shape = p.shape rot = p.reshape(-1, 3, 3).permute(0, 2, 1) angles = rot_to_tb(rot) angles = angles / np.pi angles[:, 1] = angles[:, 1] * 2 angles = angles.view(*shape[:2], 3) return angles