import functools

import torch
from einops import rearrange


def video_to_image(func):
    """Decorate a method so that it can be applied frame-by-frame to videos.

    When the wrapped method receives a 5-D tensor laid out as
    ``(b, c, t, h, w)``, the ``t`` axis is folded into the batch axis, the
    method is called on the resulting ``(b * t, c, h, w)`` tensor, and the
    result is unfolded back to ``(b, c, t, h, w)``.

    Any input that is not 5-D is passed to the wrapped method unchanged, so
    the decorated method keeps working on inputs that need no folding.

    Args:
        func: A method with signature ``func(self, x, *args, **kwargs)``.
            For 5-D inputs its result must be a 4-D tensor whose leading
            axis is still ``b * t`` so that it can be unfolded again.

    Returns:
        The wrapping method, carrying the name and docstring of ``func``.
    """

    @functools.wraps(func)
    def wrapper(self, x, *args, **kwargs):
        if x.dim() != 5:
            # Nothing to fold: still run the wrapped method instead of
            # skipping it or relying on a frame count that was never set.
            return func(self, x, *args, **kwargs)

        # Remember the frame count; it is needed to split the batch axis again.
        t = x.shape[2]
        x = rearrange(x, "b c t h w -> (b t) c h w")
        x = func(self, x, *args, **kwargs)
        return rearrange(x, "(b t) c h w -> b c t h w", t=t)

    return wrapper


def nonlinearity(x: torch.Tensor) -> torch.Tensor:
    """Return ``x * sigmoid(x)`` computed element-wise (swish / SiLU)."""
    return x * torch.sigmoid(x)


def cast_tuple(t, length=1):
    """Return ``t`` as a tuple.

    Args:
        t: Any value. A tuple is returned as-is, whatever its length.
        length: Number of times a non-tuple ``t`` is repeated.

    Returns:
        ``t`` itself if it is a tuple, otherwise ``(t,) * length``. Only
        ``tuple`` instances are passed through; a list, for example, is
        treated as a single value and repeated.
    """
    return t if isinstance(t, tuple) else ((t,) * length)


def shift_dim(
    x: torch.Tensor,
    src_dim: int = -1,
    dest_dim: int = -1,
    make_contiguous: bool = True,
) -> torch.Tensor:
    """Move one dimension of ``x`` to a new position.

    The dimension at ``src_dim`` ends up at index ``dest_dim``; all other
    dimensions keep their relative order.

    Args:
        x: Tensor to permute.
        src_dim: Index of the dimension to move. Negative values count from
            the end.
        dest_dim: Index the dimension should have in the result. Negative
            values count from the end.
        make_contiguous: If true, call ``.contiguous()`` on the permuted
            tensor before returning it.

    Returns:
        The permuted tensor.

    Raises:
        AssertionError: If either index is out of range for ``x``.
    """
    n_dims = len(x.shape)
    if src_dim < 0:
        src_dim = n_dims + src_dim
    if dest_dim < 0:
        dest_dim = n_dims + dest_dim

    assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims

    # Take the source dimension out of the identity order and re-insert it at
    # the destination; the remaining dimensions keep their relative order.
    permutation = list(range(n_dims))
    permutation.pop(src_dim)
    permutation.insert(dest_dim, src_dim)

    x = x.permute(permutation)
    if make_contiguous:
        x = x.contiguous()
    return x