import torch


def tensor_to_size(source, dest_size):
    """Pad or truncate ``source`` along dim 0 so it has ``dest_size`` entries.

    Padding repeats the last entry of ``source``; truncation keeps the first
    ``dest_size`` entries. All other dimensions are left unchanged.

    Args:
        source: Tensor with at least one dimension.
        dest_size: Target length along dim 0, given either as an int or as a
            tensor whose ``shape[0]`` is used as the target length.

    Returns:
        ``source`` itself when it already has ``dest_size`` entries, otherwise
        a new tensor with ``dest_size`` entries along dim 0.

    Raises:
        ValueError: if ``source`` has no entries along dim 0 but padding is
            required, because there is no last entry to repeat.
    """
    if isinstance(dest_size, torch.Tensor):
        dest_size = dest_size.shape[0]
    source_size = source.shape[0]
    if source_size < dest_size:
        if source_size == 0:
            # source[-1:] would be empty here, so the "padded" result would
            # silently stay at length 0 instead of reaching dest_size.
            raise ValueError("cannot pad an empty tensor: no last entry to repeat")
        # Repeat only along dim 0; every trailing dim keeps its size (factor 1).
        shape = [dest_size - source_size] + [1] * (source.dim() - 1)
        source = torch.cat((source, source[-1:].repeat(shape)), dim=0)
    elif source_size > dest_size:
        source = source[:dest_size]
    return source


def tensor_to_image(tensor):
    """Convert a float tensor (expected range [0, 1]) to a uint8 NumPy array.

    Values are scaled by 255, clamped to [0, 255] and rounded to the nearest
    integer. The last dimension is treated as the channel axis and reordered
    as ``[2, 1, 0]`` (e.g. RGB <-> BGR). It must have at least 3 entries, and
    only the first three are kept in the output.

    Args:
        tensor: Float tensor on any device, channels in the last dimension.

    Returns:
        A ``numpy.ndarray`` of dtype uint8 on the CPU.
    """
    # Round before casting: .byte() truncates, which would map 0.999 to 254 and
    # lets float error in x / 255 * 255 drop an image->tensor->image level by 1.
    image = tensor.mul(255).clamp(0, 255).round().byte().cpu()
    image = image[..., [2, 1, 0]].numpy()
    return image


def image_to_tensor(image):
    """Convert a NumPy image to a float32 tensor with values in [0, 1].

    Values are divided by 255 and clamped to [0, 1]. The last dimension is
    treated as the channel axis and reordered as ``[2, 1, 0]`` (e.g. BGR <->
    RGB). It must have at least 3 entries, and only the first three are kept.

    Args:
        image: NumPy array with channels in the last dimension (typically
            uint8 in [0, 255]). Negative-stride views are accepted.

    Returns:
        A float32 ``torch.Tensor`` on the CPU.
    """
    if any(stride < 0 for stride in image.strides):
        # torch.from_numpy rejects negative strides (e.g. views produced by
        # image[..., ::-1] or np.flip); materialize a C-ordered copy first.
        image = image.copy()
    tensor = torch.clamp(torch.from_numpy(image).float() / 255., 0, 1)
    tensor = tensor[..., [2, 1, 0]]
    return tensor