import random
from typing import Any, Optional
import numpy as np
import os
import cv2
from glob import glob
from PIL import Image, ImageDraw
from tqdm import tqdm
import kornia
import matplotlib.pyplot as plt
import seaborn as sns
import albumentations as albu
import functools
import math

import torch
import torch.nn as nn
from torch import Tensor
import torchvision as tv
import torchvision.models as models
from torchvision import transforms
from torchvision.transforms import functional as F
from losses import TempCombLoss


# Checkpoints that can be fetched automatically from Google Drive, keyed by the
# local filename that ``ensure_checkpoint_exists`` is called with.
google_drive_paths = {
    "BayesCap_SRGAN.pth": "https://drive.google.com/uc?id=1d_5j1f8-vN79htZTfRUqP1ddHZIYsNvL",
    "BayesCap_ckpt.pth": "https://drive.google.com/uc?id=1Vg1r6gKgQ1J3M51n6BeKXYS8auT9NhA9",
}


def ensure_checkpoint_exists(model_weights_filename):
    """Make sure a model checkpoint file is present on local disk.

    If the file does not exist and its name is a key of ``google_drive_paths``,
    try to download it with ``gdown``. If ``gdown`` is not installed, print
    install/download instructions (including the URL) instead of raising.
    If the file does not exist and is not a known key, print a hint that it
    has to be obtained manually. Nothing is returned and nothing is raised
    for a missing file.

    Args:
        model_weights_filename (str): Local path of the checkpoint. It is also
            used as the lookup key into ``google_drive_paths``.
    """
    if not os.path.isfile(model_weights_filename) and (
        model_weights_filename in google_drive_paths
    ):
        gdrive_url = google_drive_paths[model_weights_filename]
        try:
            # Imported lazily so gdown is only needed when a download is required.
            from gdown import download as drive_download

            drive_download(gdrive_url, model_weights_filename, quiet=False)
        except ModuleNotFoundError:
            print(
                "gdown module not found.",
                "pip3 install gdown or, manually download the checkpoint file:",
                gdrive_url,
            )

    if not os.path.isfile(model_weights_filename) and (
        model_weights_filename not in google_drive_paths
    ):
        print(
            model_weights_filename,
            " not found, you may need to manually download the model weights.",
        )


def normalize(image: np.ndarray) -> np.ndarray:
    """Normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.

    Args:
        image (np.ndarray): The image data read by ``OpenCV.imread`` or
            ``skimage.io.imread``.

    Returns:
        A new float64 array equal to ``image / 255``. For 8-bit input the
        data range is [0, 1].
    """
    return image.astype(np.float64) / 255.0


def unnormalize(image: np.ndarray) -> np.ndarray:
    """Un-normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.

    Inverse of ``normalize``.

    Args:
        image (np.ndarray): Normalized image data, nominally in [0, 1].

    Returns:
        A new float64 array equal to ``image * 255``. For [0, 1] input the
        data range is [0, 255]. It is not cast back to an integer dtype.
    """
    return image.astype(np.float64) * 255.0


def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
    """Convert an image (``PIL.Image`` or ``np.ndarray``) to a ``torch.Tensor``.

    Args:
        image: The image to convert. It is passed unchanged to
            ``torchvision.transforms.functional.to_tensor``, which accepts a
            ``PIL.Image`` or an ``np.ndarray``.
        range_norm (bool): If True, map every value ``x`` to ``2 * x - 1``,
            i.e. scale [0, 1] data to [-1, 1].
        half (bool): If True, cast the result to ``torch.half``.

    Returns:
        The tensor produced by ``to_tensor``, optionally range-normalized and
        cast to half precision. The input ``image`` is never modified.

    Examples:
        >>> image = Image.open("image.bmp")
        >>> tensor_image = image2tensor(image, range_norm=False, half=False)
    """
    tensor = F.to_tensor(image)

    if range_norm:
        # Out-of-place on purpose: for some float ndarray inputs (e.g. 2-D /
        # single-channel) ``to_tensor`` returns a tensor sharing memory with
        # ``image``, so in-place ``mul_``/``sub_`` would rewrite the caller's
        # array.
        tensor = tensor * 2.0 - 1.0

    if half:
        tensor = tensor.half()

    return tensor


def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> np.ndarray:
    """Convert an image tensor to an HWC ``uint8`` NumPy array.

    The input tensor is left untouched (no in-place ops are applied to it).

    Args:
        tensor (torch.Tensor): Image tensor of shape (1, C, H, W) or
            (C, H, W). Only a size-1 leading dimension is squeezed away, and
            the result must then be 3-D for ``permute(1, 2, 0)``.
        range_norm (bool): If True, map [-1, 1] data to [0, 1] via
            ``(x + 1) / 2`` before scaling.
        half (bool): If True, cast to ``torch.half`` before scaling to
            [0, 255].

    Returns:
        An ``np.ndarray`` of dtype ``uint8`` and shape (H, W, C), computed as
        ``clamp(x * 255, 0, 255)``. The float-to-uint8 cast truncates toward
        zero rather than rounding. Can be wrapped with ``Image.fromarray``.

    Examples:
        >>> tensor = torch.randn([1, 3, 128, 128])
        >>> image = tensor2image(tensor, range_norm=False, half=False)
    """
    # ``detach`` lets tensors that require grad reach ``.numpy()``. All later
    # ops are out-of-place: ``permute``/``squeeze`` return views, so an
    # in-place ``mul_``/``clamp_`` after them would write straight back into
    # the caller's storage.
    tensor = tensor.detach()

    if range_norm:
        tensor = (tensor + 1.0) / 2.0

    if half:
        tensor = tensor.half()

    image = (
        tensor.squeeze(0)
        .permute(1, 2, 0)
        .mul(255)
        .clamp(0, 255)
        .cpu()
        .numpy()
        .astype("uint8")
    )
    return image