pablovela5620's picture
initial commit with working dpvo
899c526
raw
history blame
No virus
8.98 kB
import torch
import numpy as np
# group operations implemented in cuda
from .group_ops import Exp, Log, Inv, Mul, Adj, AdjT, Jinv, Act3, Act4, ToMatrix, ToVec, FromVec
from .broadcasting import broadcast_inputs
class LieGroupParameter(torch.Tensor):
    """Optimizable wrapper around a Lie group element.

    The tensor itself is a zero-initialised tangent vector with shape
    ``group.tangent_shape``; the group element it is attached to is kept in
    ``self.group``. The current element is obtained with :meth:`retr`, which
    delegates to ``group.retr(self)``.
    """

    # Disable __torch_function__ dispatch for this subclass.
    from torch._C import _disabled_torch_function_impl
    __torch_function__ = _disabled_torch_function_impl

    def __new__(cls, group, requires_grad=True):
        """Allocate the zero tangent tensor on the device/dtype of ``group.data``."""
        data = torch.zeros(group.tangent_shape,
                           device=group.data.device,
                           dtype=group.data.dtype,
                           requires_grad=True)
        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __init__(self, group, requires_grad=True):
        """Remember the wrapped group element.

        ``requires_grad`` is consumed by ``__new__``; it is accepted here as
        well because Python forwards the same constructor arguments to both
        ``__new__`` and ``__init__``. Without it, passing ``requires_grad``
        raised ``TypeError``.
        """
        self.group = group

    def retr(self):
        """Return ``self.group.retr(self)``: the group element updated by this tangent."""
        return self.group.retr(self)

    def log(self):
        """Logarithm of the retracted element."""
        return self.retr().log()

    def inv(self):
        """Inverse of the retracted element."""
        return self.retr().inv()

    def adj(self, a):
        """Apply the adjoint of the retracted element to ``a``."""
        return self.retr().adj(a)

    def __mul__(self, other):
        """Multiply the retracted element with ``other``.

        Another ``LieGroupParameter`` is retracted first; anything else is
        passed straight to the group's ``*`` operator.
        """
        if isinstance(other, LieGroupParameter):
            return self.retr() * other.retr()
        else:
            return self.retr() * other

    def add_(self, update, alpha):
        """Fold a scaled tangent update into the stored group element.

        Replaces ``self.group`` with ``exp(alpha * update) * self.group``.
        """
        self.group = self.group.exp(alpha*update) * self.group

    def __getitem__(self, index):
        """Index into the retracted element."""
        return self.retr().__getitem__(index)
class LieGroup:
    """Base class for Lie group elements stored as a batched embedding tensor.

    ``data`` holds one embedding per element along its last axis; all leading
    axes are batch axes (see :attr:`shape`). The methods below read the class
    attributes ``group_name``, ``group_id``, ``manifold_dim``,
    ``embedded_dim`` and ``id_elem``, which concrete subclasses must define.
    """

    def __init__(self, data):
        self.data = data

    def __repr__(self):
        return "{}: size={}, device={}, dtype={}".format(
            self.group_name, self.shape, self.device, self.dtype)

    @property
    def shape(self):
        """Batch shape: every axis of ``data`` except the trailing embedding axis."""
        return self.data.shape[:-1]

    @property
    def device(self):
        return self.data.device

    @property
    def dtype(self):
        return self.data.dtype

    def vec(self):
        """Apply the ``ToVec`` operator to the element data."""
        return self.apply_op(ToVec, self.data)

    @property
    def tangent_shape(self):
        """Batch shape followed by ``manifold_dim``."""
        return self.data.shape[:-1] + (self.manifold_dim,)

    @staticmethod
    def _as_batch_shape(batch_shape):
        """Normalise ``*batch_shape`` varargs into a plain tuple.

        Accepts ``f(2, 3)``, ``f((2, 3))``, ``f([2, 3])`` and ``f()``; the
        last one denotes a scalar (empty) batch shape.
        """
        if batch_shape and isinstance(batch_shape[0], (tuple, list)):
            return tuple(batch_shape[0])
        return tuple(batch_shape)

    @classmethod
    def Identity(cls, *batch_shape, **kwargs):
        """ Construct identity element with batch shape

        Optional ``device`` and ``dtype`` keyword arguments are applied to the
        identity embedding before it is tiled.
        """
        batch_shape = cls._as_batch_shape(batch_shape)
        # np.prod(()) is the float 1.0, which Tensor.repeat rejects; force int.
        numel = int(np.prod(batch_shape))
        data = cls.id_elem.reshape(1,-1)

        if 'device' in kwargs:
            data = data.to(kwargs['device'])

        if 'dtype' in kwargs:
            data = data.type(kwargs['dtype'])

        data = data.repeat(numel, 1)
        return cls(data).view(batch_shape)

    @classmethod
    def IdentityLike(cls, G):
        """Identity element with the batch shape, device and dtype of ``G``."""
        return cls.Identity(G.shape, device=G.data.device, dtype=G.data.dtype)

    @classmethod
    def InitFromVec(cls, data):
        """Build an element by applying the ``FromVec`` operator to ``data``."""
        return cls(cls.apply_op(FromVec, data))

    @classmethod
    def Random(cls, *batch_shape, sigma=1.0, **kwargs):
        """ Construct random element with batch_shape by random sampling in tangent space

        ``kwargs`` are forwarded to ``torch.randn``; the sample is scaled by
        ``sigma`` before the exponential map.
        """
        batch_shape = cls._as_batch_shape(batch_shape)
        tangent_shape = batch_shape + (cls.manifold_dim,)
        xi = torch.randn(tangent_shape, **kwargs)
        return cls.exp(sigma * xi)

    @classmethod
    def apply_op(cls, op, x, y=None):
        """ Apply group operator

        Inputs go through ``broadcast_inputs``; the operator is invoked with
        this class's ``group_id`` and the result is viewed back to the
        broadcast output shape with a trailing feature axis.
        """
        inputs, out_shape = broadcast_inputs(x, y)

        data = op.apply(cls.group_id, *inputs)
        return data.view(out_shape + (-1,))

    @classmethod
    def exp(cls, x):
        """ exponential map: x -> X """
        return cls(cls.apply_op(Exp, x))

    def quaternion(self):
        """ extract quaternion

        Returns the four quaternion entries of the embedding as a slice of
        ``data``. The previous implementation called an operator named
        ``Quat`` that is not imported in this module and therefore always
        raised ``NameError``.
        """
        # Layout follows the subclasses in this file: SE3 / Sim3 (embedded_dim
        # 7 / 8) store three translation entries before the quaternion, while
        # SO3 / RxSO3 (embedded_dim 4 / 5) start with the quaternion.
        # NOTE(review): confirm for any group defined outside this file.
        start = 3 if self.embedded_dim >= 7 else 0
        return self.data[..., start:start + 4]

    def log(self):
        """ logarithm map """
        return self.apply_op(Log, self.data)

    def inv(self):
        """ group inverse """
        return self.__class__(self.apply_op(Inv, self.data))

    def mul(self, other):
        """ group multiplication """
        return self.__class__(self.apply_op(Mul, self.data, other.data))

    def retr(self, a):
        """ retraction: Exp(a) * X """
        dX = self.__class__.apply_op(Exp, a)
        return self.__class__(self.apply_op(Mul, dX, self.data))

    def adj(self, a):
        """ adjoint operator: b = A(X) * a """
        return self.apply_op(Adj, self.data, a)

    def adjT(self, a):
        """ transposed adjoint operator: b = a * A(X) """
        return self.apply_op(AdjT, self.data, a)

    def Jinv(self, a):
        """Apply the ``Jinv`` operator to this element and ``a``."""
        return self.apply_op(Jinv, self.data, a)

    def act(self, p):
        """ action on a point cloud

        ``p`` must have a trailing axis of size 3 (points) or 4 (homogeneous
        points).

        Raises:
            ValueError: if the trailing axis of ``p`` is neither 3 nor 4.
                Previously this case silently returned ``None``.
        """
        # action on point
        if p.shape[-1] == 3:
            return self.apply_op(Act3, self.data, p)

        # action on homogeneous point
        elif p.shape[-1] == 4:
            return self.apply_op(Act4, self.data, p)

        raise ValueError(
            "act() expects points with last dimension 3 or 4, got {}".format(
                p.shape[-1]))

    def matrix(self):
        """ convert element to 4x4 matrix """
        # Act on the rows of a 4x4 identity as homogeneous points, then
        # transpose the last two axes of the result.
        I = torch.eye(4, dtype=self.dtype, device=self.device)
        I = I.view([1] * (len(self.data.shape) - 1) + [4, 4])
        return self.__class__(self.data[...,None,:]).act(I).transpose(-1,-2)

    def translation(self):
        """ extract translation component """
        # Result of acting on the homogeneous point (0, 0, 0, 1).
        p = torch.as_tensor([0.0, 0.0, 0.0, 1.0], dtype=self.dtype, device=self.device)
        p = p.view([1] * (len(self.data.shape) - 1) + [4,])
        return self.apply_op(Act4, self.data, p)

    def detach(self):
        """Same element with ``data`` detached from the autograd graph."""
        return self.__class__(self.data.detach())

    def view(self, dims):
        """View the batch axes as ``dims`` (tuple, list or ``torch.Size``)."""
        # tuple() lets callers pass a list; the embedding axis is appended.
        data_reshaped = self.data.view(tuple(dims) + (self.embedded_dim,))
        return self.__class__(data_reshaped)

    def __mul__(self, other):
        """``X * Y`` multiplies group elements; ``X * p`` acts on a tensor of points.

        Any other operand type yields ``NotImplemented`` so Python raises the
        usual ``TypeError`` instead of the expression evaluating to ``None``.
        """
        # group multiplication
        if isinstance(other, LieGroup):
            return self.mul(other)

        # action on point
        elif isinstance(other, torch.Tensor):
            return self.act(other)

        return NotImplemented

    def __getitem__(self, index):
        return self.__class__(self.data[index])

    def __setitem__(self, index, item):
        self.data[index] = item.data

    def to(self, *args, **kwargs):
        """Forward to ``Tensor.to`` on ``data`` and rewrap the result."""
        return self.__class__(self.data.to(*args, **kwargs))

    def cpu(self):
        return self.__class__(self.data.cpu())

    def cuda(self):
        return self.__class__(self.data.cuda())

    def float(self, device=None):
        """Cast ``data`` to float32.

        ``device`` is not used; it is kept (now optional) only so existing
        calls that pass it keep working.
        """
        return self.__class__(self.data.float())

    def double(self, device=None):
        """Cast ``data`` to float64.

        ``device`` is not used; it is kept (now optional) only so existing
        calls that pass it keep working.
        """
        return self.__class__(self.data.double())

    def unbind(self, dim=0):
        """Split along ``dim`` of ``data`` into a list of elements."""
        return [self.__class__(x) for x in self.data.unbind(dim=dim)]
class SO3(LieGroup):
    """Group with a 4-entry embedding (unit quaternion) and 3-dim tangent space."""

    group_name = 'SO3'
    group_id = 1
    manifold_dim = 3
    embedded_dim = 4

    # identity embedding: unit quaternion
    id_elem = torch.as_tensor([0.0, 0.0, 0.0, 1.0])

    def __init__(self, data):
        """Wrap ``data``; an ``SE3`` input contributes entries 3..6 of its embedding."""
        payload = data.data[..., 3:7] if isinstance(data, SE3) else data
        super().__init__(payload)
class RxSO3(LieGroup):
    """Group with a 5-entry embedding and 4-dim tangent space."""

    group_name = 'RxSO3'
    group_id = 2
    manifold_dim = 4
    embedded_dim = 5

    # identity embedding: unit quaternion followed by one extra entry equal to
    # 1.0 (presumably the scale, matching the tail of the Sim3 embedding)
    id_elem = torch.as_tensor([0.0, 0.0, 0.0, 1.0, 1.0])

    def __init__(self, data):
        """Wrap ``data``; a ``Sim3`` input contributes entries 3..7 of its embedding."""
        payload = data.data[..., 3:8] if isinstance(data, Sim3) else data
        super().__init__(payload)
class SE3(LieGroup):
    """Group with a 7-entry embedding (translation, unit quaternion) and 6-dim tangent space."""

    group_name = 'SE3'
    group_id = 3
    manifold_dim = 6
    embedded_dim = 7

    # identity embedding: zero translation followed by the unit quaternion
    id_elem = torch.as_tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0])

    def __init__(self, data):
        """Wrap ``data``; an ``SO3`` input is promoted by prepending a zero translation."""
        if isinstance(data, SO3):
            rotation = data.data
            zero_translation = torch.zeros_like(rotation[..., :3])
            data = torch.cat([zero_translation, rotation], -1)
        super().__init__(data)

    def scale(self, s):
        """Return a new ``SE3`` whose translation entries are multiplied by ``s``.

        ``s`` gets a trailing axis added before the multiplication; the
        quaternion entries are copied through unchanged.
        """
        translation, quat = self.data.split([3, 4], -1)
        scaled = translation * s.unsqueeze(-1)
        return SE3(torch.cat([scaled, quat], dim=-1))
class Sim3(LieGroup):
    """Group with an 8-entry embedding (translation, unit quaternion, scale) and 7-dim tangent space."""

    group_name = 'Sim3'
    group_id = 4
    manifold_dim = 7
    embedded_dim = 8

    # translation, unit quaternion, scale
    id_elem = torch.as_tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0])

    def __init__(self, data):
        """Wrap ``data``, promoting other group types to the Sim3 embedding.

        * ``SO3``  -> zero translation is prepended and unit scale appended.
        * ``SE3``  -> unit scale is appended.
        * ``Sim3`` -> its embedding tensor is reused as is.
        * anything else is stored unchanged.
        """
        if isinstance(data, SO3):
            # Read the tensor from the *instance*; the previous code used the
            # class attribute ``SO3.data``, which does not exist.
            rotation = data.data
            scale = torch.ones_like(rotation[...,:1])
            translation = torch.zeros_like(rotation[...,:3])
            data = torch.cat([translation, rotation, scale], -1)

        elif isinstance(data, SE3):
            scale = torch.ones_like(data.data[...,:1])
            data = torch.cat([data.data, scale], -1)

        elif isinstance(data, Sim3):
            data = data.data

        super().__init__(data)
def cat(group_list, dim):
    """Concatenate group elements along an existing axis.

    ``dim`` is passed unchanged to ``torch.cat`` on the elements' ``data``
    tensors, so it indexes the raw data tensor. The result is wrapped in the
    class of the first element of ``group_list``.
    """
    merged = torch.cat([elem.data for elem in group_list], dim=dim)
    group_cls = group_list[0].__class__
    return group_cls(merged)
def stack(group_list, dim):
    """Stack group elements along a new axis.

    ``dim`` is passed unchanged to ``torch.stack`` on the elements' ``data``
    tensors, so it indexes the raw data tensor. The result is wrapped in the
    class of the first element of ``group_list``.
    """
    stacked = torch.stack([elem.data for elem in group_list], dim=dim)
    group_cls = group_list[0].__class__
    return group_cls(stacked)