# MiniDPVO / mini_dpvo/lietorch/group_ops.py
# Uploaded by pablovela5620 ("initial commit with working dpvo", commit 899c526), 3.06 kB.
import lietorch_backends
import torch
import torch.nn.functional as F
class GroupOp(torch.autograd.Function):
    """Base class for group operations dispatched to backend kernels.

    Subclasses set two class attributes:

    * ``forward_op(group_id, *inputs)``: computes the output of the operation.
    * ``backward_op(group_id, grad, *inputs)``: returns an iterable holding one
      gradient per tensor in ``inputs``. It may be ``None`` when the operation
      does not support differentiation.

    Invoke through ``Subclass.apply(group_id, *inputs)``. ``group_id`` is
    passed unchanged to both kernels and never receives a gradient.
    """

    # Defaults, so a subclass that omits an op fails with a clear
    # NotImplementedError instead of an AttributeError.
    forward_op = None
    backward_op = None

    @classmethod
    def forward(cls, ctx, group_id, *inputs):
        """Run ``forward_op`` and stash what ``backward`` needs on ``ctx``."""
        if cls.forward_op is None:
            raise NotImplementedError(
                "Forward operation not implemented for {}".format(cls))
        ctx.group_id = group_id
        ctx.save_for_backward(*inputs)
        return cls.forward_op(ctx.group_id, *inputs)

    @classmethod
    def backward(cls, ctx, grad):
        """Return ``(None, *grads)``: no gradient for ``group_id``, then one per input.

        Raises:
            NotImplementedError: if the subclass's ``backward_op`` is ``None``.
                An explicit raise is used instead of ``assert``, so the check
                survives ``python -O``.
        """
        if cls.backward_op is None:
            raise NotImplementedError(
                "Backward operation not implemented for {}".format(cls))
        inputs = ctx.saved_tensors
        # The original always passes a contiguous grad; presumably the backend
        # kernels require contiguous memory.
        grad_inputs = cls.backward_op(ctx.group_id, grad.contiguous(), *inputs)
        return (None,) + tuple(grad_inputs)
class Exp(GroupOp):
    """Exponential map.

    Forward runs ``lietorch_backends.expm``; gradients come from
    ``lietorch_backends.expm_backward``.
    """

    forward_op = lietorch_backends.expm
    backward_op = lietorch_backends.expm_backward
class Log(GroupOp):
    """Logarithm map.

    Forward runs ``lietorch_backends.logm``; gradients come from
    ``lietorch_backends.logm_backward``.
    """

    forward_op = lietorch_backends.logm
    backward_op = lietorch_backends.logm_backward
class Inv(GroupOp):
    """Group inverse.

    Forward runs ``lietorch_backends.inv``; gradients come from
    ``lietorch_backends.inv_backward``.
    """

    forward_op = lietorch_backends.inv
    backward_op = lietorch_backends.inv_backward
class Mul(GroupOp):
    """Group multiplication.

    Forward runs ``lietorch_backends.mul``; gradients come from
    ``lietorch_backends.mul_backward``.
    """

    forward_op = lietorch_backends.mul
    backward_op = lietorch_backends.mul_backward
class Adj(GroupOp):
    """Adjoint operator.

    Forward runs ``lietorch_backends.adj``; gradients come from
    ``lietorch_backends.adj_backward``.
    """

    forward_op = lietorch_backends.adj
    backward_op = lietorch_backends.adj_backward
class AdjT(GroupOp):
    """Transposed adjoint operator.

    This is distinct from ``Adj``. Forward runs ``lietorch_backends.adjT``;
    gradients come from ``lietorch_backends.adjT_backward``.
    """

    forward_op = lietorch_backends.adjT
    backward_op = lietorch_backends.adjT_backward
class Act3(GroupOp):
    """Group action on points.

    Forward runs ``lietorch_backends.act``; gradients come from
    ``lietorch_backends.act_backward``.
    NOTE(review): the "3" presumably means 3-component points; confirm.
    """

    forward_op = lietorch_backends.act
    backward_op = lietorch_backends.act_backward
class Act4(GroupOp):
    """Group action on points, 4-component variant.

    Forward runs ``lietorch_backends.act4``; gradients come from
    ``lietorch_backends.act4_backward``.
    NOTE(review): presumably homogeneous 4-component points; confirm.
    """

    forward_op = lietorch_backends.act4
    backward_op = lietorch_backends.act4_backward
class Jinv(GroupOp):
    """Operator computed by ``lietorch_backends.Jinv``.

    NOTE(review): presumably an inverse Jacobian of the exponential map;
    confirm against the backend.

    Forward-only: ``backward_op`` is ``None``, so gradients cannot be
    propagated through this operation.
    """

    forward_op = lietorch_backends.Jinv
    backward_op = None
class ToMatrix(GroupOp):
    """Matrix representation of a group element.

    Forward runs ``lietorch_backends.as_matrix``. It is forward-only:
    ``backward_op`` is ``None``, so gradients are not supported.
    """

    forward_op = lietorch_backends.as_matrix
    backward_op = None
### conversion operations to/from Euclidean embeddings ###
class FromVec(torch.autograd.Function):
    """Wrap a raw embedding tensor as a group object.

    The forward pass passes the data through unchanged (returns the first
    input). The backward pass maps the incoming gradient through the
    pseudo-inverse of ``lietorch_backends.projector``.
    """

    @classmethod
    def forward(cls, ctx, group_id, *inputs):
        ctx.save_for_backward(*inputs)
        ctx.group_id = group_id
        embedded = inputs[0]
        return embedded

    @classmethod
    def backward(cls, ctx, grad):
        proj = lietorch_backends.projector(ctx.group_id, *ctx.saved_tensors)
        proj_pinv = torch.linalg.pinv(proj)
        # Treat the gradient as a row vector, multiply, then drop the row axis.
        row_grad = grad.unsqueeze(-2)
        vec_grad = (row_grad @ proj_pinv).squeeze(-2)
        return None, vec_grad
class ToVec(torch.autograd.Function):
    """Expose a group object's embedding as a plain tensor.

    The forward pass passes the data through unchanged (returns the first
    input). The backward pass multiplies the incoming gradient, treated as a
    row vector, by ``lietorch_backends.projector``.
    """

    @classmethod
    def forward(cls, ctx, group_id, *inputs):
        ctx.save_for_backward(*inputs)
        ctx.group_id = group_id
        embedded = inputs[0]
        return embedded

    @classmethod
    def backward(cls, ctx, grad):
        proj = lietorch_backends.projector(ctx.group_id, *ctx.saved_tensors)
        # Treat the gradient as a row vector, multiply, then drop the row axis.
        row_grad = grad.unsqueeze(-2)
        group_grad = (row_grad @ proj).squeeze(-2)
        return None, group_grad