MiniDPVO / mini_dpvo /lietorch /broadcasting.py
pablovela5620's picture
initial commit with working dpvo
899c526
raw
history blame
No virus
965 Bytes
import torch
import numpy as np
def check_broadcastable(x, y):
    """Validate that the batch dimensions of ``x`` and ``y`` are broadcastable.

    Both tensors must have the same number of dimensions. Every dimension
    except the last must either match or be 1 in one of the two tensors.
    The last dimension is not compared, so it may differ between ``x`` and ``y``.

    Raises:
        AssertionError: if the shapes are incompatible. It is raised explicitly
            (not via ``assert``) so the check still runs under ``python -O``.
            The exception type is unchanged for existing callers.
    """
    if len(x.shape) != len(y.shape):
        raise AssertionError(
            f"rank mismatch: {tuple(x.shape)} vs {tuple(y.shape)}"
        )
    for dim, (n, m) in enumerate(zip(x.shape[:-1], y.shape[:-1])):
        if not (n == m or n == 1 or m == 1):
            raise AssertionError(
                f"dimension {dim} not broadcastable: "
                f"{tuple(x.shape)} vs {tuple(y.shape)}"
            )
def broadcast_inputs(x, y):
    """Flatten (and, if needed, broadcast) inputs into contiguous 2-D batches.

    All dimensions except the last are treated as batch dimensions. The last
    dimension is kept per tensor, so it may differ between ``x`` and ``y``.

    Args:
        x: tensor of shape ``(*xs, xd)``.
        y: ``None``, or a tensor of shape ``(*ys, yd)`` with the same rank as
            ``x``. Each batch size pair must be equal or contain a 1 (checked
            by ``check_broadcastable``).

    Returns:
        ``(inputs, out_shape)``:

        - ``inputs`` is ``(x1,)`` when ``y`` is ``None``, otherwise ``(x1, y1)``.
          Each entry is a contiguous tensor of shape ``(-1, d)``, where ``d``
          is that input's last dimension.
        - ``out_shape`` is the batch shape: ``x.shape[:-1]`` when ``y`` is
          ``None``, otherwise a tuple of element-wise maxima of the batch sizes.
    """
    xd = x.shape[-1]
    if y is None:
        # reshape (not view) so non-contiguous inputs are accepted as well.
        return (x.reshape(-1, xd).contiguous(), ), x.shape[:-1]

    check_broadcastable(x, y)

    xs, ys = x.shape[:-1], y.shape[:-1]
    yd = y.shape[-1]
    out_shape = [max(n, m) for (n, m) in zip(xs, ys)]

    if xs == ys:
        # Batch shapes already agree: flatten without tiling.
        # (Previously compared xs with an int, so this path was never taken.)
        x1 = x.reshape(-1, xd).contiguous()
        y1 = y.reshape(-1, yd).contiguous()
    else:
        # Tile every size-1 batch dim up to the other tensor's size along it.
        x_expand = [m if n == 1 else 1 for (n, m) in zip(xs, ys)]
        y_expand = [n if m == 1 else 1 for (n, m) in zip(xs, ys)]
        x1 = x.repeat(x_expand + [1]).reshape(-1, xd).contiguous()
        y1 = y.repeat(y_expand + [1]).reshape(-1, yd).contiguous()

    return (x1, y1), tuple(out_shape)