import lietorch_backends
import torch
import torch.nn.functional as F


class GroupOp(torch.autograd.Function):
    """ group operation base class

    Concrete operations subclass this and set two class attributes:

    * ``forward_op``  -- callable invoked as ``forward_op(group_id, *inputs)``
    * ``backward_op`` -- callable invoked as
      ``backward_op(group_id, grad, *inputs)``, or ``None`` when the
      operation provides no backward pass.
    """

    @classmethod
    def forward(cls, ctx, group_id, *inputs):
        """ apply ``cls.forward_op`` and record what backward needs on ``ctx``

        Args:
            ctx: autograd context; receives ``group_id`` and the saved inputs.
            group_id: identifier forwarded unchanged to the backend.
            *inputs: tensors forwarded to the backend and saved for backward.

        Returns:
            Whatever ``cls.forward_op(group_id, *inputs)`` returns.
        """
        ctx.group_id = group_id
        ctx.save_for_backward(*inputs)
        return cls.forward_op(ctx.group_id, *inputs)

    @classmethod
    def backward(cls, ctx, grad):
        """ apply ``cls.backward_op`` to the incoming gradient

        Args:
            ctx: autograd context populated by ``forward``.
            grad: gradient with respect to the forward output.

        Returns:
            Tuple ``(None, *grad_inputs)``; the leading ``None`` is the
            (non-differentiable) ``group_id`` slot of ``forward``.

        Raises:
            NotImplementedError: if the subclass sets ``backward_op = None``.
        """
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``, which would turn this into an opaque
        # "'NoneType' object is not callable" error further down.
        if cls.backward_op is None:
            raise NotImplementedError(
                "Backward operation not implemented for {}".format(cls))

        inputs = ctx.saved_tensors
        # presumably the backend requires contiguous memory -- TODO confirm
        grad = grad.contiguous()
        grad_inputs = cls.backward_op(ctx.group_id, grad, *inputs)
        return (None, ) + tuple(grad_inputs)


class Exp(GroupOp):
    """ exponential map """
    forward_op, backward_op = lietorch_backends.expm, lietorch_backends.expm_backward


class Log(GroupOp):
    """ logarithm map """
    forward_op, backward_op = lietorch_backends.logm, lietorch_backends.logm_backward


class Inv(GroupOp):
    """ group inverse """
    forward_op, backward_op = lietorch_backends.inv, lietorch_backends.inv_backward


class Mul(GroupOp):
    """ group multiplication """
    forward_op, backward_op = lietorch_backends.mul, lietorch_backends.mul_backward


class Adj(GroupOp):
    """ adjoint operator """
    forward_op, backward_op = lietorch_backends.adj, lietorch_backends.adj_backward


class AdjT(GroupOp):
    """ backend ``adjT`` operator (presumably the transposed adjoint -- TODO confirm) """
    forward_op, backward_op = lietorch_backends.adjT, lietorch_backends.adjT_backward


class Act3(GroupOp):
    """ action on point (backend ``act``) """
    forward_op, backward_op = lietorch_backends.act, lietorch_backends.act_backward


class Act4(GroupOp):
    """ action on point (backend ``act4``) """
    forward_op, backward_op = lietorch_backends.act4, lietorch_backends.act4_backward


class Jinv(GroupOp):
    """ backend ``Jinv`` operator (presumably an inverse Jacobian -- TODO confirm)

    No backward pass is provided; differentiating through it raises.
    """
    forward_op, backward_op = lietorch_backends.Jinv, None


class ToMatrix(GroupOp):
    """ convert to matrix representation

    No backward pass is provided; differentiating through it raises.
    """
    forward_op, backward_op = lietorch_backends.as_matrix, None


### conversion operations to/from Euclidean embeddings ###
class FromVec(torch.autograd.Function):
    """ convert vector into group object

    The forward pass is the identity on the first input; only the gradient
    is transformed. ``backward`` returns a single tensor gradient, so a
    single tensor input is supported.
    """

    @classmethod
    def forward(cls, ctx, group_id, *inputs):
        """ save ``group_id`` and ``inputs`` on ``ctx``; return ``inputs[0]`` """
        ctx.group_id = group_id
        ctx.save_for_backward(*inputs)
        return inputs[0]

    @classmethod
    def backward(cls, ctx, grad):
        """ right-multiply ``grad`` by the pseudo-inverse of the projector

        Returns:
            ``(None, grad_input)``; ``None`` is the ``group_id`` slot.
        """
        inputs = ctx.saved_tensors
        J = lietorch_backends.projector(ctx.group_id, *inputs)
        # treat grad as a row vector: (..., 1, n) @ pinv(J), then drop the 1
        row = torch.matmul(grad.unsqueeze(-2), torch.linalg.pinv(J))
        return None, row.squeeze(-2)


class ToVec(torch.autograd.Function):
    """ convert group object to vector

    The forward pass is the identity on the first input; only the gradient
    is transformed. ``backward`` returns a single tensor gradient, so a
    single tensor input is supported.
    """

    @classmethod
    def forward(cls, ctx, group_id, *inputs):
        """ save ``group_id`` and ``inputs`` on ``ctx``; return ``inputs[0]`` """
        ctx.group_id = group_id
        ctx.save_for_backward(*inputs)
        return inputs[0]

    @classmethod
    def backward(cls, ctx, grad):
        """ right-multiply ``grad`` by the projector

        Returns:
            ``(None, grad_input)``; ``None`` is the ``group_id`` slot.
        """
        inputs = ctx.saved_tensors
        J = lietorch_backends.projector(ctx.group_id, *inputs)
        # treat grad as a row vector: (..., 1, n) @ J, then drop the 1
        row = torch.matmul(grad.unsqueeze(-2), J)
        return None, row.squeeze(-2)