BayesCap / utils.py
udion's picture
BayesCap demo to EuroCrypt
5bd623f
raw
history blame contribute delete
No virus
3.85 kB
import random
from typing import Any, Optional
import numpy as np
import os
import cv2
from glob import glob
from PIL import Image, ImageDraw
from tqdm import tqdm
import kornia
import matplotlib.pyplot as plt
import seaborn as sns
import albumentations as albu
import functools
import math
import torch
import torch.nn as nn
from torch import Tensor
import torchvision as tv
import torchvision.models as models
from torchvision import transforms
from torchvision.transforms import functional as F
from losses import TempCombLoss
# Checkpoint file name -> Google Drive download URL.
# ``ensure_checkpoint_exists`` looks file names up in this table.
google_drive_paths = {
    "BayesCap_SRGAN.pth":
        "https://drive.google.com/uc?id=1d_5j1f8-vN79htZTfRUqP1ddHZIYsNvL",
    "BayesCap_ckpt.pth":
        "https://drive.google.com/uc?id=1Vg1r6gKgQ1J3M51n6BeKXYS8auT9NhA9",
}
def ensure_checkpoint_exists(model_weights_filename):
    """Make sure a checkpoint file is on disk, downloading it when possible.

    Args:
        model_weights_filename: Path of the checkpoint file. The same value is
            used as the lookup key in ``google_drive_paths`` and as the
            download destination.

    Behaviour:
        * File already exists: nothing happens.
        * File missing and listed in ``google_drive_paths``: it is fetched
          with ``gdown``; if ``gdown`` is not installed, a hint containing the
          URL is printed instead.
        * File missing and not listed: a "not found" notice is printed.
    """
    # Nothing to do when the weights are already present.
    if os.path.isfile(model_weights_filename):
        return

    # Unknown checkpoint: all we can do is tell the user.
    if model_weights_filename not in google_drive_paths:
        print(
            model_weights_filename,
            " not found, you may need to manually download the model weights."
        )
        return

    download_url = google_drive_paths[model_weights_filename]
    try:
        # Imported lazily so gdown stays an optional dependency.
        from gdown import download as fetch_from_drive
        fetch_from_drive(download_url, model_weights_filename, quiet=False)
    except ModuleNotFoundError:
        print(
            "gdown module not found.",
            "pip3 install gdown or, manually download the checkpoint file:",
            download_url
        )
def normalize(image: np.ndarray) -> np.ndarray:
    """Scale image data down by 255 as ``float64``.

    Args:
        image (np.ndarray): Image array, e.g. as returned by ``OpenCV.imread``
            or ``skimage.io.imread``.
    Returns:
        A new ``float64`` array equal to ``image / 255``; 8-bit input
        therefore ends up in the range [0, 1].
    """
    as_float = image.astype(np.float64)
    return as_float / 255.0
def unnormalize(image: np.ndarray) -> np.ndarray:
    """Scale image data up by 255 as ``float64`` (inverse of ``normalize``).

    Args:
        image (np.ndarray): Image array, typically holding values in [0, 1].
    Returns:
        A new ``float64`` array equal to ``image * 255``; [0, 1] input
        therefore ends up in the range [0, 255]. No rounding or clipping
        is applied.
    """
    as_float = image.astype(np.float64)
    return as_float * 255.0
def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
    """Convert an image to a ``torch.Tensor``.

    The conversion itself is delegated to
    ``torchvision.transforms.functional.to_tensor``; this function only adds
    the optional range shift and the optional half-precision cast.

    Args:
        image (np.ndarray): The image to convert. It is passed unchanged to
            ``to_tensor``, so any input that function accepts works (the
            example below passes a ``PIL.Image``).
        range_norm (bool): If True, map every converted value ``x`` to
            ``2 * x - 1`` (i.e. [0, 1] data becomes [-1, 1]).
        half (bool): If True, cast the result to ``torch.half``.
    Returns:
        The converted (and optionally shifted / cast) tensor.
    Examples:
        >>> image = Image.open("image.bmp")
        >>> tensor_image = image2tensor(image, range_norm=False, half=False)
    """
    tensor = F.to_tensor(image)
    if range_norm:
        # Out-of-place on purpose: the tensor returned by ``to_tensor`` may
        # share memory with a float ndarray input, and an in-place
        # ``mul_``/``sub_`` would then silently rewrite the caller's array.
        tensor = tensor * 2.0 - 1.0
    if half:
        tensor = tensor.half()
    return tensor
def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> Any:
    """Convert an image ``torch.Tensor`` to a ``uint8`` NumPy array.

    The input tensor is left untouched: every step works on a new tensor, so
    the caller's data keeps its shape and values.

    Args:
        tensor (torch.Tensor): Image tensor. After dropping a leading
            dimension of size 1 it must be 3-D, because its axes are
            reordered with ``permute(1, 2, 0)``.
        range_norm (bool): If True, map every value ``x`` to ``(x + 1) / 2``
            (i.e. [-1, 1] data becomes [0, 1]) before scaling.
        half (bool): If True, cast to ``torch.half`` before scaling.
    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` with the axes permuted by
        ``(1, 2, 0)``. Values are multiplied by 255, clamped to [0, 255] and
        truncated towards zero by the integer cast. The array can be handed
        to ``PIL.Image.fromarray`` if a PIL image is needed.
    Examples:
        >>> tensor = torch.randn([1, 3, 128, 128])
        >>> image = tensor2image(tensor, range_norm=False, half=False)
    """
    # Drop autograd tracking so tensors that require grad can be converted.
    tensor = tensor.detach()
    if range_norm:
        # Out-of-place: never rewrite the caller's tensor.
        tensor = (tensor + 1.0) / 2.0
    if half:
        tensor = tensor.half()
    # ``squeeze``/``mul``/``clamp`` (not their in-place ``_`` variants) keep
    # the argument's shape and contents intact.
    scaled = tensor.squeeze(0).permute(1, 2, 0).mul(255).clamp(0, 255)
    image = scaled.cpu().numpy().astype("uint8")
    return image