import numpy as np
import torch


def np2th(ndarray):
    """Convert a NumPy array or a torch tensor into a torch tensor.

    Args:
        ndarray: Either a ``torch.Tensor`` or a ``np.ndarray``.

    Returns:
        For a ``torch.Tensor`` input: the tensor detached from the autograd
        graph and moved to the CPU (its dtype is left untouched).
        For a ``np.ndarray`` input: a new tensor built from the array's data
        and cast to ``torch.float32``.

    Raises:
        ValueError: If the input is neither a ``torch.Tensor`` nor a
            ``np.ndarray``.
    """
    # Tensors are checked first, mirroring the original branch order.
    if isinstance(ndarray, torch.Tensor):
        detached = ndarray.detach()
        return detached.cpu()

    if isinstance(ndarray, np.ndarray):
        # torch.tensor copies the array's data; .float() then casts to float32.
        converted = torch.tensor(ndarray)
        return converted.float()

    raise ValueError("Input should be either torch.Tensor or np.ndarray")