Spaces:
Sleeping
Sleeping
File size: 965 Bytes
899c526 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import torch
import numpy as np
def check_broadcastable(x, y):
    """Validate that ``x`` and ``y`` can be broadcast over their leading dims.

    Both tensors must have the same rank. Every dimension except the last
    must either be equal in both tensors or be 1 in one of them. The last
    dimension is not compared, so it may differ between ``x`` and ``y``.

    Args:
        x: first tensor (anything exposing a ``.shape`` sequence).
        y: second tensor.

    Raises:
        AssertionError: if the ranks differ, or if a pair of leading
            dimensions is neither equal nor contains a 1.
    """
    # Raise explicitly instead of using ``assert`` so the check still runs
    # under ``python -O``. AssertionError is kept so callers that relied on
    # the previous assert statements see the same exception type.
    if len(x.shape) != len(y.shape):
        raise AssertionError(
            f"cannot broadcast tensors of different rank: "
            f"{tuple(x.shape)} vs {tuple(y.shape)}"
        )
    for dim, (n, m) in enumerate(zip(x.shape[:-1], y.shape[:-1])):
        if n != m and n != 1 and m != 1:
            raise AssertionError(
                f"cannot broadcast dimension {dim}: sizes {n} and {m} "
                f"in shapes {tuple(x.shape)} vs {tuple(y.shape)}"
            )
def broadcast_inputs(x, y):
    """Broadcast two tensors over their leading dims and flatten them to 2-D.

    Every dimension except the last is broadcast: a size of 1 is expanded
    to match the other tensor. The last dimension of each tensor is kept
    as-is and may differ between ``x`` and ``y``. The broadcast leading
    dimensions are then flattened into one, so each returned tensor is 2-D
    and contiguous.

    Args:
        x: tensor of shape ``(*xs, xd)``.
        y: tensor of shape ``(*ys, yd)`` with the same rank as ``x``, or
            ``None`` to just flatten ``x``.

    Returns:
        If ``y`` is None: ``((x1,), x.shape[:-1])``, where ``x1`` has shape
        ``(prod(xs), xd)``.
        Otherwise: ``((x1, y1), out_shape)``, where ``out_shape`` is the
        tuple of broadcast leading sizes. ``x1`` has shape
        ``(prod(out_shape), xd)`` and ``y1`` has shape
        ``(prod(out_shape), yd)``.

    Raises:
        AssertionError: raised by ``check_broadcastable`` when the shapes
            cannot be broadcast.
    """
    if y is None:
        xd = x.shape[-1]
        # reshape (not view) also accepts non-contiguous inputs; it still
        # returns a view without copying whenever the layout allows it.
        return (x.reshape(-1, xd).contiguous(),), x.shape[:-1]

    check_broadcastable(x, y)

    xs, xd = x.shape[:-1], x.shape[-1]
    ys, yd = y.shape[:-1], y.shape[-1]
    # Broadcast rule: a size-1 dim takes the other tensor's size. Unlike
    # max(n, m), this is also correct for zero-size dims (0 vs 1 -> 0).
    out_shape = tuple(m if n == 1 else n for n, m in zip(xs, ys))

    # expand() broadcasts size-1 dims without allocating memory. When the
    # leading shapes already agree it is a no-op view. reshape() +
    # contiguous() then copy only when the layout actually requires it.
    x1 = x.expand(*out_shape, xd).reshape(-1, xd).contiguous()
    y1 = y.expand(*out_shape, yd).reshape(-1, yd).contiguous()
    return (x1, y1), out_shape
|