# BayesCap / src/utils.py
# Author: udion
# Commit 5bd623f: "BayesCap demo to EuroCrypt"
import random
from typing import Any, Optional
import numpy as np
import os
import cv2
from glob import glob
from PIL import Image, ImageDraw
from tqdm import tqdm
import kornia
import matplotlib.pyplot as plt
import seaborn as sns
import albumentations as albu
import functools
import math
import torch
import torch.nn as nn
from torch import Tensor
import torchvision as tv
import torchvision.models as models
from torchvision import transforms
from torchvision.transforms import functional as F
from losses import TempCombLoss
########### DeblurGAN function
def get_norm_layer(norm_type='instance'):
    """Return a factory for the requested 2-D normalization layer.

    Args:
        norm_type: ``'batch'`` for an affine ``nn.BatchNorm2d`` or
            ``'instance'`` for a non-affine ``nn.InstanceNorm2d`` that tracks
            running statistics.

    Returns:
        A ``functools.partial`` that builds the layer when called with the
        remaining constructor arguments (e.g. the channel count).

    Raises:
        NotImplementedError: if ``norm_type`` is not one of the above.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True)
    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
def _array_to_batch(x):
x = np.transpose(x, (2, 0, 1))
x = np.expand_dims(x, 0)
return torch.from_numpy(x)
def get_normalize():
    """Build a paired normaliser that applies identical normalisation to two images.

    The albumentations pipeline uses ``Normalize`` with per-channel mean 0.5 and
    std 0.5, and registers ``target`` as a second ``image`` target so both
    inputs receive the same transform.

    Returns:
        A callable ``process(a, b)`` returning ``(normalized_a, normalized_b)``.
    """
    pipeline = albu.Compose(
        [albu.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])],
        additional_targets={'target': 'image'},
    )

    def process(a, b):
        out = pipeline(image=a, target=b)
        return out['image'], out['target']

    return process
def preprocess(x: np.ndarray, mask: Optional[np.ndarray]):
    """Normalise an ``HxWxC`` image and its mask and pad both for the network.

    Args:
        x: image array, normalised with ``get_normalize()``.
        mask: optional mask; values are divided by 255 and rounded. When
            ``None`` an all-ones float32 mask shaped like ``x`` is used.

    Returns:
        ``(batches, h, w)`` where ``batches`` is a lazy ``map`` yielding the
        padded image and mask as ``1xCxHxW`` tensors, and ``h``/``w`` are the
        unpadded height and width (used to crop the output later).
        Height and width are zero-padded at the bottom/right up to the next
        multiple of 32 strictly above the current size, so a full extra block
        is added when a side is already a multiple of 32.
    """
    x, _ = get_normalize()(x, x)
    if mask is not None:
        mask = np.round(mask.astype('float32') / 255)
    else:
        mask = np.ones_like(x, dtype=np.float32)

    h, w, _ = x.shape
    block_size = 32
    pad_h = (h // block_size + 1) * block_size - h
    pad_w = (w // block_size + 1) * block_size - w
    pad_width = ((0, pad_h), (0, pad_w), (0, 0))
    x_padded = np.pad(x, pad_width, mode='constant', constant_values=0)
    mask_padded = np.pad(mask, pad_width, mode='constant', constant_values=0)
    return map(_array_to_batch, (x_padded, mask_padded)), h, w
def postprocess(x: torch.Tensor) -> np.ndarray:
    """Convert a ``1xCxHxW`` tensor in [-1, 1] into an ``HxWxC`` uint8 image.

    The batch dimension must have size exactly 1. Values are mapped with
    ``(v + 1) / 2 * 255`` and clipped to [0, 255] before the uint8 cast, so
    out-of-range network outputs saturate instead of wrapping around.
    """
    x, = x  # unpack the single batch item
    arr = x.detach().cpu().float().numpy()
    arr = (np.transpose(arr, (1, 2, 0)) + 1) / 2.0 * 255.0
    return np.clip(arr, 0, 255).astype('uint8')
def sorted_glob(pattern):
    """Return the paths matching the glob ``pattern``, sorted lexicographically."""
    matches = glob(pattern)
    matches.sort()
    return matches
###########
def normalize(image: np.ndarray) -> np.ndarray:
    """Scale image data read by ``cv2.imread`` / ``skimage.io.imread`` into [0, 1].

    Args:
        image (np.ndarray): Image data in the [0, 255] range.

    Returns:
        A new ``float64`` array divided by 255; the input is not modified.
    """
    scaled = image.astype(np.float64)
    scaled /= 255.0
    return scaled
def unnormalize(image: np.ndarray) -> np.ndarray:
    """Inverse of ``normalize``: scale [0, 1] image data back to [0, 255].

    Args:
        image (np.ndarray): Image data in the [0, 1] range.

    Returns:
        A new ``float64`` array multiplied by 255; the input is not modified.
    """
    scaled = image.astype(np.float64)
    scaled *= 255.0
    return scaled
def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
    """Convert an image to a ``CxHxW`` tensor with torchvision's ``to_tensor``.

    Args:
        image (np.ndarray): Image data (e.g. a ``PIL.Image`` or HWC array).
        range_norm (bool): Map [0, 1] data to [-1, 1] via ``t * 2 - 1``.
        half (bool): Cast the result to ``torch.half``.

    Returns:
        The converted tensor.

    Examples:
        >>> image = Image.open("image.bmp")
        >>> tensor_image = image2tensor(image, range_norm=False, half=False)
    """
    result = F.to_tensor(image)
    if range_norm:
        # in place, exactly as produced by to_tensor
        result.mul_(2.0).sub_(1.0)
    return result.half() if half else result
def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> Any:
    """Convert an image tensor into a ``uint8`` numpy array in ``HxWxC`` layout.

    All operations are out of place, so the caller's tensor keeps its shape
    and values.

    Args:
        tensor (torch.Tensor): ``CxHxW`` image, or ``1xCxHxW`` (a leading
            dimension of size 1 is squeezed away).
        range_norm (bool): Scale [-1, 1] data to [0, 1] before conversion.
        half (bool): Cast to ``torch.half`` before scaling.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8``: values multiplied by 255,
        clamped to [0, 255] and truncated.

    Examples:
        >>> tensor = torch.rand([1, 3, 128, 128])
        >>> image = tensor2image(tensor, range_norm=False, half=False)
    """
    # detach() lets tensors that require grad be converted to numpy.
    image = tensor.detach()
    if range_norm:
        image = (image + 1.0) / 2.0
    if half:
        image = image.half()
    image = image.squeeze(0).permute(1, 2, 0).mul(255).clamp(0, 255)
    return image.cpu().numpy().astype("uint8")
def convert_rgb_to_y(image: Any) -> Any:
    """Compute the Y (luma) channel of an RGB image, ITU-R BT.601 style.

    Args:
        image: ``HxWx3`` numpy array, or ``3xHxW`` / ``1x3xHxW`` torch tensor,
            with RGB values (the coefficients assume the [0, 255] range).
            The input is not modified.

    Returns:
        ``HxW`` array/tensor of Y values.

    Raises:
        TypeError: for inputs that are neither ``np.ndarray`` nor ``torch.Tensor``
            (a subclass of ``Exception``, which callers may already catch).
    """
    if isinstance(image, np.ndarray):
        r, g, b = image[:, :, 0], image[:, :, 1], image[:, :, 2]
    elif isinstance(image, torch.Tensor):
        if image.dim() == 4:
            # out-of-place squeeze: do not change the caller's tensor shape
            image = image.squeeze(0)
        r, g, b = image[0], image[1], image[2]
    else:
        raise TypeError("Unknown Type", type(image))
    return 16. + (64.738 * r + 129.057 * g + 25.064 * b) / 256.
def convert_rgb_to_ycbcr(image: Any) -> Any:
    """Convert RGB image data to YCbCr (ITU-R BT.601 coefficients).

    Args:
        image: ``HxWx3`` numpy array, or ``3xHxW`` / ``1x3xHxW`` torch tensor,
            with RGB values (the coefficients assume the [0, 255] range).

    Returns:
        ``HxWx3`` YCbCr data of the same kind (array or tensor).

    Raises:
        TypeError: for inputs that are neither ``np.ndarray`` nor ``torch.Tensor``.
    """
    if isinstance(image, np.ndarray):
        r, g, b = image[:, :, 0], image[:, :, 1], image[:, :, 2]
    elif isinstance(image, torch.Tensor):
        if image.dim() == 4:
            image = image.squeeze(0)
        r, g, b = image[0], image[1], image[2]
    else:
        raise TypeError("Unknown Type", type(image))
    y = 16. + (64.738 * r + 129.057 * g + 25.064 * b) / 256.
    cb = 128. + (-37.945 * r - 74.494 * g + 112.439 * b) / 256.
    cr = 128. + (112.439 * r - 94.154 * g - 18.285 * b) / 256.
    if isinstance(image, np.ndarray):
        return np.array([y, cb, cr]).transpose([1, 2, 0])
    # y/cb/cr are HxW: stack (not cat) to get 3xHxW, then move channels last.
    return torch.stack([y, cb, cr], 0).permute(1, 2, 0)
def convert_ycbcr_to_rgb(image: Any) -> Any:
    """Convert YCbCr image data back to RGB (inverse of ``convert_rgb_to_ycbcr``).

    Args:
        image: ``HxWx3`` numpy array, or ``3xHxW`` / ``1x3xHxW`` torch tensor,
            with channels ordered Y, Cb, Cr.

    Returns:
        ``HxWx3`` RGB data of the same kind (array or tensor).

    Raises:
        TypeError: for inputs that are neither ``np.ndarray`` nor ``torch.Tensor``.
    """
    if isinstance(image, np.ndarray):
        y, cb, cr = image[:, :, 0], image[:, :, 1], image[:, :, 2]
    elif isinstance(image, torch.Tensor):
        if image.dim() == 4:
            image = image.squeeze(0)
        y, cb, cr = image[0], image[1], image[2]
    else:
        raise TypeError("Unknown Type", type(image))
    r = 298.082 * y / 256. + 408.583 * cr / 256. - 222.921
    g = 298.082 * y / 256. - 100.291 * cb / 256. - 208.120 * cr / 256. + 135.576
    b = 298.082 * y / 256. + 516.412 * cb / 256. - 276.836
    if isinstance(image, np.ndarray):
        return np.array([r, g, b]).transpose([1, 2, 0])
    # r/g/b are HxW: stack (not cat) to get 3xHxW, then move channels last.
    return torch.stack([r, g, b], 0).permute(1, 2, 0)
def center_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> [Any, Any]:
    """Crop the central ``image_size`` square from ``hr`` and the matching area from ``lr``.

    Args:
        lr: Low-resolution image (``PIL.Image``-like: ``.size`` and ``.crop``).
        hr: High-resolution image (``PIL.Image``-like).
        image_size (int): Side of the square crop, in high-resolution pixels.
        upscale_factor (int): HR/LR magnification; the LR box is the HR box
            integer-divided by this factor.

    Returns:
        ``(lr_crop, hr_crop)``.
    """
    hr_w, hr_h = hr.size
    x0 = (hr_w - image_size) // 2
    y0 = (hr_h - image_size) // 2
    hr_box = (x0, y0, x0 + image_size, y0 + image_size)
    lr_box = tuple(coord // upscale_factor for coord in hr_box)
    return lr.crop(lr_box), hr.crop(hr_box)
def random_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> [Any, Any]:
    """Crop a random ``image_size`` square from ``hr`` and the matching area from ``lr``.

    The top-left corner is drawn with ``torch.randint`` (x first, then y), so
    results follow ``torch.manual_seed``.

    Args:
        lr: Low-resolution image (``PIL.Image``-like: ``.size`` and ``.crop``).
        hr: High-resolution image (``PIL.Image``-like).
        image_size (int): Side of the square crop, in high-resolution pixels.
        upscale_factor (int): HR/LR magnification; the LR box is the HR box
            integer-divided by this factor.

    Returns:
        ``(lr_crop, hr_crop)``.
    """
    hr_w, hr_h = hr.size
    x0 = torch.randint(0, hr_w - image_size + 1, size=(1,)).item()
    y0 = torch.randint(0, hr_h - image_size + 1, size=(1,)).item()
    hr_box = (x0, y0, x0 + image_size, y0 + image_size)
    lr_box = tuple(coord // upscale_factor for coord in hr_box)
    return lr.crop(lr_box), hr.crop(hr_box)
def random_rotate(lr: Any, hr: Any, angle: int) -> [Any, Any]:
    """Rotate ``lr`` and ``hr`` by the same angle, ``+angle`` or ``-angle`` at random.

    Args:
        lr: Low-resolution image.
        hr: High-resolution image.
        angle (int): Rotation magnitude in degrees; the sign is picked with
            ``random.choice``.

    Returns:
        ``(rotated_lr, rotated_hr)``.
    """
    signed_angle = random.choice((angle, -angle))
    return F.rotate(lr, signed_angle), F.rotate(hr, signed_angle)
def random_horizontally_flip(lr: Any, hr: Any, p=0.5) -> [Any, Any]:
    """Horizontally flip ``lr`` and ``hr`` together with probability ``p``.

    Args:
        lr: Low-resolution image.
        hr: High-resolution image.
        p (optional, float): Flip probability. (Default: 0.5)

    Returns:
        ``(lr, hr)``, both flipped or both unchanged.
    """
    # Flip when the draw falls below p (torchvision RandomHorizontalFlip
    # convention), so p really is the flip probability.
    if torch.rand(1).item() < p:
        lr = F.hflip(lr)
        hr = F.hflip(hr)
    return lr, hr
def random_vertically_flip(lr: Any, hr: Any, p=0.5) -> [Any, Any]:
    """Vertically flip (upside down) ``lr`` and ``hr`` together with probability ``p``.

    Args:
        lr: Low-resolution image.
        hr: High-resolution image.
        p (optional, float): Flip probability. (Default: 0.5)

    Returns:
        ``(lr, hr)``, both flipped or both unchanged.
    """
    # Flip when the draw falls below p (torchvision RandomVerticalFlip
    # convention), so p really is the flip probability.
    if torch.rand(1).item() < p:
        lr = F.vflip(lr)
        hr = F.vflip(hr)
    return lr, hr
def random_adjust_brightness(lr: Any, hr: Any) -> [Any, Any]:
    """Apply one random brightness factor from [0.5, 2] to both ``lr`` and ``hr``.

    Args:
        lr: Low-resolution image.
        hr: High-resolution image.

    Returns:
        ``(lr, hr)`` with identical brightness adjustment.
    """
    gain = random.uniform(0.5, 2)
    return F.adjust_brightness(lr, gain), F.adjust_brightness(hr, gain)
def random_adjust_contrast(lr: Any, hr: Any) -> [Any, Any]:
    """Apply one random contrast factor from [0.5, 2] to both ``lr`` and ``hr``.

    Args:
        lr: Low-resolution image.
        hr: High-resolution image.

    Returns:
        ``(lr, hr)`` with identical contrast adjustment.
    """
    gain = random.uniform(0.5, 2)
    return F.adjust_contrast(lr, gain), F.adjust_contrast(hr, gain)
#### metrics to compute -- assumes single images, i.e., tensor of 3 dims
def img_mae(x1, x2):
    """Mean absolute error between two tensors (0-dim tensor)."""
    return (x1 - x2).abs().mean()
def img_mse(x1, x2):
    """Mean squared error between two tensors (0-dim tensor)."""
    abs_diff = (x1 - x2).abs()
    return (abs_diff ** 2).mean()
def img_psnr(x1, x2):
    """PSNR between two tensors via ``kornia.metrics.psnr`` with a peak value of 1."""
    return kornia.metrics.psnr(x1, x2, 1)
def img_ssim(x1, x2):
    """Mean SSIM between two single images (window size 5) via kornia.

    A batch dimension is added to each input before calling
    ``kornia.metrics.ssim``; the returned SSIM map is averaged.
    """
    ssim_map = kornia.metrics.ssim(x1[None], x2[None], 5)
    return ssim_map.mean()
def show_SR_w_uncer(xLR, xHR, xSR, xSRvar, elim=(0,0.01), ulim=(0,0.15)):
    '''
    Plot input, target, prediction, squared-error map and uncertainty map in one row.

    xLR/SR/HR: 3xHxW
    xSRvar: 1xHxW
    elim / ulim: colour limits for the error and uncertainty panels.
    Also prints the min/max of the error and uncertainty maps.
    '''
    def to_hwc(t):
        # CxHxW -> HxWxC for imshow
        return t.transpose(0, 2).transpose(0, 1)

    plt.figure(figsize=(30, 10))
    for pos, img in enumerate((xLR, xHR, xSR), start=1):
        plt.subplot(1, 5, pos)
        plt.imshow(to_hwc(img.to('cpu').data.clip(0, 1)))
        plt.axis('off')

    plt.subplot(1, 5, 4)
    sq_err = torch.mean(torch.pow(torch.abs(xSR - xHR), 2), dim=0)
    error_map = sq_err.to('cpu').data.unsqueeze(0)
    print('error', error_map.min(), error_map.max())
    plt.imshow(to_hwc(error_map), cmap='jet')
    plt.clim(elim[0], elim[1])
    plt.axis('off')

    plt.subplot(1, 5, 5)
    print('uncer', xSRvar.min(), xSRvar.max())
    plt.imshow(to_hwc(xSRvar.to('cpu').data), cmap='hot')
    plt.clim(ulim[0], ulim[1])
    plt.axis('off')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
def show_SR_w_err(xLR, xHR, xSR, elim=(0,0.01), task=None, xMask=None):
    '''
    Plot input, target, prediction and the squared-error map in one row.

    xLR/SR/HR: 3xHxW
    task == 'm': the three image panels use a gray colormap with clim (0, 0.9).
    task == 'inpainting': the error map is multiplied by xMask.
    Also prints the min/max of the error map.
    '''
    plt.figure(figsize=(30, 10))
    grayscale = task == 'm'
    for pos, img in enumerate((xLR, xHR, xSR), start=1):
        plt.subplot(1, 4, pos)
        hwc = img.to('cpu').data.clip(0, 1).transpose(0, 2).transpose(0, 1)
        if grayscale:
            plt.imshow(hwc, cmap='gray')
            plt.clim(0, 0.9)
        else:
            plt.imshow(hwc)
        plt.axis('off')

    plt.subplot(1, 4, 4)
    error_map = torch.mean(torch.pow(torch.abs(xSR - xHR), 2), dim=0).to('cpu').data.unsqueeze(0)
    if task == 'inpainting':
        error_map = error_map * xMask.to('cpu').data
    print('error', error_map.min(), error_map.max())
    plt.imshow(error_map.transpose(0, 2).transpose(0, 1), cmap='jet')
    plt.clim(elim[0], elim[1])
    plt.axis('off')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
def show_uncer4(xSRvar1, xSRvar2, xSRvar3, xSRvar4, ulim=(0,0.15)):
    '''
    Plot four uncertainty maps side by side with a shared 'hot' colour range.

    xSRvar: 1xHxW
    ulim: colour limits applied to every panel.
    Prints the min/max of each map.
    '''
    plt.figure(figsize=(30, 10))
    for pos, var_map in enumerate((xSRvar1, xSRvar2, xSRvar3, xSRvar4), start=1):
        plt.subplot(1, 4, pos)
        print('uncer', var_map.min(), var_map.max())
        plt.imshow(var_map.to('cpu').data.transpose(0, 2).transpose(0, 1), cmap='hot')
        plt.clim(ulim[0], ulim[1])
        plt.axis('off')
    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()
def get_UCE(list_err, list_yout_var, num_bins=100):
    """Uncertainty Calibration Error between errors and predicted variances.

    Errors are split into ``num_bins`` equal-width bins over
    ``[min(list_err), max(list_err)]``. For each bin, the mean error and mean
    variance of its points are computed, and their absolute gap is weighted by
    the bin's fraction of all points. The UCE is the sum over bins.

    Args:
        list_err: sequence of errors (e.g. per-pixel squared errors).
        list_yout_var: matching sequence of predicted variances, paired with
            ``list_err`` via ``zip``.
        num_bins: number of equal-width error bins.

    Returns:
        ``(bin_stats, uce)`` where ``bin_stats[i]`` holds ``start_idx``,
        ``end_idx``, ``num_points``, ``mean_err``, ``mean_var`` and
        ``uce_bin`` for bin ``i``.
    """
    import bisect

    err_min = np.min(list_err)
    err_max = np.max(list_err)
    err_len = (err_max - err_min) / num_bins
    num_points = len(list_err)

    # Bin boundaries; edges[i] / edges[i + 1] are bin i's start / end.
    edges = [err_min + i * err_len for i in range(num_bins + 1)]
    bin_stats = {
        i: {
            'start_idx': edges[i],
            'end_idx': edges[i + 1],
            'num_points': 0,
            'mean_err': 0,
            'mean_var': 0,
        }
        for i in range(num_bins)
    }

    last_bin = num_bins - 1
    for e, v in zip(list_err, list_yout_var):
        if not e >= err_min:  # skips NaN
            continue
        # Bins are half-open [start, end): bisect finds i with
        # edges[i] <= e < edges[i + 1] in O(log num_bins).
        i = bisect.bisect_right(edges, e) - 1
        # The maximum error sits on (or, by rounding, past) the last edge;
        # put it in the last bin instead of dropping it.
        i = min(i, last_bin)
        stats = bin_stats[i]
        stats['num_points'] += 1
        stats['mean_err'] += e
        stats['mean_var'] += v

    uce = 0
    eps = 1e-8  # keeps empty bins from dividing by zero
    for stats in bin_stats.values():
        stats['mean_err'] /= stats['num_points'] + eps
        stats['mean_var'] /= stats['num_points'] + eps
        stats['uce_bin'] = (stats['num_points'] / num_points) \
            * np.abs(stats['mean_err'] - stats['mean_var'])
        uce += stats['uce_bin']
    return bin_stats, uce
##################### training BayesCap
def train_BayesCap(
    NetC,
    NetG,
    train_loader,
    eval_loader,
    Cri=None,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    init_lr=1e-4,
    num_epochs=100,
    eval_every=1,
    ckpt_path='../ckpt/BayesCap',
    T1=1e0,
    T2=5e-2,
    task=None,
    max_iter=2000,
):
    """Train the BayesCap head ``NetC`` on top of a frozen generator ``NetG``.

    Each batch is passed through ``NetG`` without gradients. ``NetC`` maps the
    generator output to ``(mu, alpha, beta)``, and ``Cri`` is minimised with
    Adam under a cosine-annealing schedule (one scheduler step per epoch).

    Args:
        NetC: BayesCap network; ``NetC(xSR)`` returns ``(mu, alpha, beta)``.
        NetG: frozen generator. For ``task='inpainting'`` it is called as
            ``NetG(x, mask)`` and its second output is used; for
            ``task='depth'`` its ``("disp", 0)`` output is used.
        train_loader / eval_loader: iterables of batches ``(input, target, ...)``.
        Cri: loss called as ``Cri(mu, alpha, beta, target, T1=T1, T2=T2)``;
            defaults to a fresh ``TempCombLoss()``.
        device: device the networks and batches are moved to.
        dtype: tensor type passed to ``Tensor.type`` (the class, e.g.
            ``torch.cuda.FloatTensor``).
        init_lr: Adam learning rate.
        num_epochs: number of epochs (also the cosine-annealing period).
        eval_every: run ``eval_BayesCap`` when ``epoch % eval_every == 0``.
        ckpt_path: prefix for ``_last.pth`` (saved every epoch) and
            ``_best.pth`` (saved when the eval score is >= the best so far).
        T1, T2: forwarded unchanged to ``Cri``.
        task: ``None``, ``'inpainting'`` or ``'depth'``. For ``'depth'`` the
            loss target is the generator output itself instead of ``batch[1]``.
        max_iter: batches with index greater than this are skipped, i.e. at
            most ``max_iter + 1`` batches are used per epoch.
    """
    if Cri is None:
        # Created per call instead of once at import time as a shared default.
        Cri = TempCombLoss()
    NetC.to(device)
    NetG.to(device)
    NetG.eval()
    optimizer = torch.optim.Adam(list(NetC.parameters()), lr=init_lr)
    optim_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)
    score = -1e8
    all_loss = []
    for eph in range(num_epochs):
        # eval_BayesCap() switches NetC to eval mode; restore training mode
        # at the start of every epoch.
        NetC.train()
        eph_loss = 0
        num_batches = 0
        with tqdm(train_loader, unit='batch') as tepoch:
            for idx, batch in enumerate(tepoch):
                if idx > max_iter:
                    break
                tepoch.set_description('Epoch {}'.format(eph))
                xLR, xHR = batch[0].to(device), batch[1].to(device)
                xLR, xHR = xLR.type(dtype), xHR.type(dtype)
                if task == 'inpainting':
                    xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
                    xMask = xMask.to(device).type(dtype)
                # The generator is frozen: no graph is built through NetG.
                with torch.no_grad():
                    if task == 'inpainting':
                        _, xSR1 = NetG(xLR, xMask)
                    elif task == 'depth':
                        xSR1 = NetG(xLR)[("disp", 0)]
                    else:
                        xSR1 = NetG(xLR)
                xSR = xSR1.clone()
                xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
                optimizer.zero_grad()
                target = xSR if task == 'depth' else xHR
                loss = Cri(xSRC_mu, xSRC_alpha, xSRC_beta, target, T1=T1, T2=T2)
                loss.backward()
                optimizer.step()
                loss_value = loss.item()
                eph_loss += loss_value
                num_batches += 1
                tepoch.set_postfix(loss=loss_value)
        # Average over the batches actually used (max_iter may stop early).
        eph_loss /= max(num_batches, 1)
        all_loss.append(eph_loss)
        print('Avg. loss: {}'.format(eph_loss))
        # evaluate and save the models
        torch.save(NetC.state_dict(), ckpt_path + '_last.pth')
        if eph % eval_every == 0:
            curr_score = eval_BayesCap(
                NetC,
                NetG,
                eval_loader,
                device=device,
                dtype=dtype,
                task=task,
            )
            print('current score: {} | Last best score: {}'.format(curr_score, score))
            if curr_score >= score:
                score = curr_score
                torch.save(NetC.state_dict(), ckpt_path + '_best.pth')
        optim_scheduler.step()
#### get different uncertainty maps
def get_uncer_BayesCap(
    NetC,
    NetG,
    xin,
    task=None,
    xMask=None,
):
    """Per-pixel uncertainty map predicted by the BayesCap head.

    ``NetG`` produces the reconstruction (for ``task='inpainting'`` it is
    called as ``NetG(xin, xMask)`` and its second output is used). ``NetC``
    then returns ``(mu, alpha, beta)``, and the map is computed on the CPU as
    ``(1 / (alpha + 1e-5))**2 * Gamma(3 / (beta + 1e-2)) / Gamma(1 / (beta + 1e-2))``.
    This matches the variance of a generalized Gaussian with scale
    ``1/alpha`` and shape ``beta`` (presumably the intended model; confirm).
    The epsilons keep the divisions finite.
    """
    with torch.no_grad():
        if task != 'inpainting':
            xSR = NetG(xin)
        else:
            _, xSR = NetG(xin, xMask)
        _, alpha, beta = NetC(xSR)
        scale = (1 / (alpha + 1e-5)).to('cpu').data
        shape = beta.to('cpu').data + 1e-2
        gamma_ratio = torch.exp(torch.lgamma(3 / shape)) / torch.exp(torch.lgamma(1 / shape))
        return (scale ** 2) * gamma_ratio
def get_uncer_TTDAp(
    NetG,
    xin,
    p_mag=0.05,
    num_runs=50,
    task=None,
    xMask=None,
):
    """Uncertainty via test-time data augmentation with additive Gaussian noise.

    ``NetG`` is run ``num_runs`` times on ``xin + p_mag * xin.max() * N(0, 1)``
    (for ``task='inpainting'`` as ``NetG(noisy, xMask)``, using the second
    output). The per-pixel variance across runs is averaged over channels.

    Args:
        NetG: generator.
        xin: input batch, assumed ``BxCxHxW`` (TODO confirm for other tasks).
        p_mag: noise magnitude relative to ``xin.max()``.
        num_runs: number of noisy forward passes.
        task / xMask: see above.

    Returns:
        ``Bx1xHxW`` tensor. Each batch item gets its own map; runs are
        stacked on a new axis so different images are never mixed.
    """
    samples = []
    with torch.no_grad():
        noise_scale = p_mag * xin.max()
        for _ in range(num_runs):
            noisy = xin + noise_scale * torch.randn_like(xin)
            if task == 'inpainting':
                _, xSRz = NetG(noisy, xMask)
            else:
                xSRz = NetG(noisy)
            samples.append(xSRz)
    runs = torch.stack(samples, dim=0)  # (num_runs, B, C, H, W)
    return torch.var(runs, dim=0).mean(dim=1, keepdim=True)
def get_uncer_DO(
    NetG,
    xin,
    dop=0.2,
    num_runs=50,
    task=None,
    xMask=None,
):
    """Uncertainty via repeated stochastic forward passes of ``NetG(..., dop=dop)``.

    ``NetG`` is called ``num_runs`` times with the keyword ``dop`` (presumably
    a test-time dropout probability; confirm against the generator). For
    ``task='inpainting'`` it is called as ``NetG(xin, xMask, dop=dop)`` and the
    second output is used. The per-pixel variance across runs is averaged
    over channels.

    Returns:
        ``Bx1xHxW`` tensor, one map per batch item (runs are stacked on a new
        axis so different images are never mixed).
    """
    samples = []
    with torch.no_grad():
        for _ in range(num_runs):
            if task == 'inpainting':
                _, xSRz = NetG(xin, xMask, dop=dop)
            else:
                xSRz = NetG(xin, dop=dop)
            samples.append(xSRz)
    runs = torch.stack(samples, dim=0)  # (num_runs, B, C, H, W)
    return torch.var(runs, dim=0).mean(dim=1, keepdim=True)
################### Different eval functions
def eval_BayesCap(
    NetC,
    NetG,
    eval_loader,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    task=None,
    xMask=None,
):
    """Evaluate the BayesCap head on top of a frozen generator.

    For every image, the BayesCap mean is compared with the target (SSIM,
    PSNR, MSE and MAE, averaged over all images and printed). The input,
    target, generator output and BayesCap variance map are plotted with
    ``show_SR_w_uncer``.

    Args:
        NetC: BayesCap network; ``NetC(xSR)`` returns ``(mu, alpha, beta)``.
        NetG: generator. For ``task='inpainting'`` it is called as
            ``NetG(x, mask)`` and the second output is used; for
            ``task='depth'`` its ``("disp", 0)`` output is used.
        eval_loader: iterable of batches ``(input, target, ...)``.
        device: device the networks and batches are moved to.
        dtype: tensor type passed to ``Tensor.type``.
        task: ``None``, ``'inpainting'`` or ``'depth'``. For ``'depth'`` the
            generator output is used as the target.
        xMask: inpainting mask applied to every batch. When ``None``, a new
            random mask sized to each batch is drawn.

    Returns:
        Mean SSIM between the BayesCap mean and the target.
    """
    NetC.to(device)
    NetC.eval()
    NetG.to(device)
    NetG.eval()
    mean_ssim = 0
    mean_psnr = 0
    mean_mse = 0
    mean_mae = 0
    num_imgs = 0
    with tqdm(eval_loader, unit='batch') as tepoch:
        for idx, batch in enumerate(tepoch):
            tepoch.set_description('Validating ...')
            xLR, xHR = batch[0].to(device), batch[1].to(device)
            xLR, xHR = xLR.type(dtype), xHR.type(dtype)
            if task == 'inpainting':
                # Use a separate local so a mask drawn for one batch is not
                # reused for later batches (which may differ in size).
                if xMask is None:
                    batch_mask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
                else:
                    batch_mask = xMask
                batch_mask = batch_mask.to(device).type(dtype)
            with torch.no_grad():
                if task == 'inpainting':
                    _, xSR = NetG(xLR, batch_mask)
                elif task == 'depth':
                    xSR = NetG(xLR)[("disp", 0)]
                else:
                    xSR = NetG(xLR)
                xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
            # Variance of a generalized Gaussian with scale 1/alpha and shape
            # beta (presumably the BayesCap parameterisation; confirm).
            a_map = (1 / (xSRC_alpha + 1e-5)).to('cpu').data
            b_map = xSRC_beta.to('cpu').data
            xSRvar = (a_map ** 2) * (torch.exp(torch.lgamma(3 / (b_map + 1e-2))) / torch.exp(torch.lgamma(1 / (b_map + 1e-2))))
            n_batch = xSRC_mu.shape[0]
            if task == 'depth':
                xHR = xSR
            for j in range(n_batch):
                num_imgs += 1
                mean_ssim += img_ssim(xSRC_mu[j], xHR[j])
                mean_psnr += img_psnr(xSRC_mu[j], xHR[j])
                mean_mse += img_mse(xSRC_mu[j], xHR[j])
                mean_mae += img_mae(xSRC_mu[j], xHR[j])
                show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
    mean_ssim /= num_imgs
    mean_psnr /= num_imgs
    mean_mse /= num_imgs
    mean_mae /= num_imgs
    print(
        'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format(
            mean_ssim, mean_psnr, mean_mse, mean_mae
        )
    )
    return mean_ssim
def eval_TTDA_p(
    NetG,
    eval_loader,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    p_mag=0.05,
    num_runs=50,
    task = None,
    xMask = None,
):
    """Evaluate a generator with test-time-augmentation (additive noise) uncertainty.

    For each batch, the clean prediction ``xSR`` is scored against the target
    (SSIM, PSNR, MSE, MAE, averaged and printed). ``num_runs`` predictions on
    ``xLR + p_mag * xLR.max() * N(0, 1)`` give a per-image variance map
    (averaged over channels), which is plotted with ``show_SR_w_uncer``.
    For ``task='inpainting'`` every forward pass is ``NetG(x, xMask)`` and its
    second output is used. ``xMask`` must then be provided by the caller.

    Returns:
        Mean SSIM between the clean prediction and the target.
    """
    NetG.to(device)
    NetG.eval()
    mean_ssim = 0
    mean_psnr = 0
    mean_mse = 0
    mean_mae = 0
    num_imgs = 0
    with tqdm(eval_loader, unit='batch') as tepoch:
        for idx, batch in enumerate(tepoch):
            tepoch.set_description('Validating ...')
            xLR, xHR = batch[0].to(device), batch[1].to(device)
            xLR, xHR = xLR.type(dtype), xHR.type(dtype)
            samples = []
            with torch.no_grad():
                if task == 'inpainting':
                    _, xSR = NetG(xLR, xMask)
                else:
                    xSR = NetG(xLR)
                noise_scale = p_mag * xLR.max()
                for _ in range(num_runs):
                    noisy = xLR + noise_scale * torch.randn_like(xLR)
                    if task == 'inpainting':
                        _, xSRz = NetG(noisy, xMask)
                    else:
                        xSRz = NetG(noisy)
                    samples.append(xSRz)
            # (num_runs, B, C, H, W) -> per-image variance, channel-averaged: (B, 1, H, W)
            xSRvar = torch.stack(samples, dim=0).var(dim=0).mean(dim=1, keepdim=True)
            n_batch = xSR.shape[0]
            for j in range(n_batch):
                num_imgs += 1
                mean_ssim += img_ssim(xSR[j], xHR[j])
                mean_psnr += img_psnr(xSR[j], xHR[j])
                mean_mse += img_mse(xSR[j], xHR[j])
                mean_mae += img_mae(xSR[j], xHR[j])
                show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
    mean_ssim /= num_imgs
    mean_psnr /= num_imgs
    mean_mse /= num_imgs
    mean_mae /= num_imgs
    print(
        'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format(
            mean_ssim, mean_psnr, mean_mse, mean_mae
        )
    )
    return mean_ssim
def eval_DO(
    NetG,
    eval_loader,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    dop=0.2,
    num_runs=50,
    task=None,
    xMask=None,
):
    """Evaluate a generator with repeated ``NetG(..., dop=dop)`` uncertainty.

    For each batch, the plain prediction ``xSR`` is scored against the target
    (SSIM, PSNR, MSE, MAE, averaged and printed). ``num_runs`` calls with the
    ``dop`` keyword (presumably test-time dropout; confirm) give a per-image
    variance map (averaged over channels), plotted with ``show_SR_w_uncer``.
    For ``task='inpainting'`` the generator is called with ``xMask`` and its
    second output is used. ``xMask`` must then be provided by the caller.

    Returns:
        Mean SSIM between the plain prediction and the target.
    """
    NetG.to(device)
    NetG.eval()
    mean_ssim = 0
    mean_psnr = 0
    mean_mse = 0
    mean_mae = 0
    num_imgs = 0
    with tqdm(eval_loader, unit='batch') as tepoch:
        for idx, batch in enumerate(tepoch):
            tepoch.set_description('Validating ...')
            xLR, xHR = batch[0].to(device), batch[1].to(device)
            xLR, xHR = xLR.type(dtype), xHR.type(dtype)
            samples = []
            with torch.no_grad():
                if task == 'inpainting':
                    _, xSR = NetG(xLR, xMask)
                else:
                    xSR = NetG(xLR)
                for _ in range(num_runs):
                    if task == 'inpainting':
                        _, xSRz = NetG(xLR, xMask, dop=dop)
                    else:
                        xSRz = NetG(xLR, dop=dop)
                    samples.append(xSRz)
            # (num_runs, B, C, H, W) -> per-image variance, channel-averaged: (B, 1, H, W)
            xSRvar = torch.stack(samples, dim=0).var(dim=0).mean(dim=1, keepdim=True)
            n_batch = xSR.shape[0]
            for j in range(n_batch):
                num_imgs += 1
                mean_ssim += img_ssim(xSR[j], xHR[j])
                mean_psnr += img_psnr(xSR[j], xHR[j])
                mean_mse += img_mse(xSR[j], xHR[j])
                mean_mae += img_mae(xSR[j], xHR[j])
                show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
    mean_ssim /= num_imgs
    mean_psnr /= num_imgs
    mean_mse /= num_imgs
    mean_mae /= num_imgs
    print(
        'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format(
            mean_ssim, mean_psnr, mean_mse, mean_mae
        )
    )
    return mean_ssim
############### compare all function
def compare_all(
    NetC,
    NetG,
    eval_loader,
    p_mag = 0.05,
    dop = 0.2,
    num_runs = 100,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    task=None,
):
    """Visually compare TTDA, repeated-``dop`` and BayesCap uncertainty maps.

    For every batch, the three uncertainty estimates are computed with
    ``get_uncer_TTDAp``, ``get_uncer_DO`` and ``get_uncer_BayesCap``. Then,
    per image, the reconstruction/error panels and two rows of
    (power-transformed) uncertainty maps are plotted. ``task`` selects the
    plot layout: ``'s'``, ``'d'``, ``'inpainting'`` (a random mask is drawn
    per batch) or ``'m'``. Other values produce no per-image plots.
    """
    NetC.to(device)
    NetC.eval()
    NetG.to(device)
    NetG.eval()
    with tqdm(eval_loader, unit='batch') as tepoch:
        for idx, batch in enumerate(tepoch):
            tepoch.set_description('Comparing ...')
            xLR, xHR = batch[0].to(device), batch[1].to(device)
            xLR, xHR = xLR.type(dtype), xHR.type(dtype)
            xMask = None
            if task == 'inpainting':
                xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
                xMask = xMask.to(device).type(dtype)
            with torch.no_grad():
                if task == 'inpainting':
                    _, xSR = NetG(xLR, xMask)
                else:
                    xSR = NetG(xLR)
            # get_uncer_BayesCap runs NetC itself, so no separate NetC pass here.
            if task == 'inpainting':
                uncer_kwargs = {'task': 'inpainting', 'xMask': xMask}
            else:
                uncer_kwargs = {}
            xSRvar1 = get_uncer_TTDAp(NetG, xLR, p_mag=p_mag, num_runs=num_runs, **uncer_kwargs)
            xSRvar2 = get_uncer_DO(NetG, xLR, dop=dop, num_runs=num_runs, **uncer_kwargs)
            xSRvar3 = get_uncer_BayesCap(NetC, NetG, xLR, **uncer_kwargs)
            print('bdg', xSRvar1.shape, xSRvar2.shape, xSRvar3.shape)
            n_batch = xSR.shape[0]
            for j in range(n_batch):
                if task == 's':
                    show_SR_w_err(xLR[j], xHR[j], xSR[j])
                    show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
                    show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j])
                elif task == 'd':
                    show_SR_w_err(xLR[j], xHR[j], 0.5 * xSR[j] + 0.5 * xHR[j])
                    show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
                    show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
                elif task == 'inpainting':
                    show_SR_w_err(xLR[j] * (1 - xMask[j]), xHR[j], xSR[j], elim=(0, 0.25), task='inpainting', xMask=xMask[j])
                    show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.45), torch.pow(xSRvar1[j], 0.4))
                    show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
                elif task == 'm':
                    show_SR_w_err(xLR[j], xHR[j], xSR[j], elim=(0, 0.04), task='m')
                    show_uncer4(0.4 * xSRvar1[j] + 0.6 * xSRvar2[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42), ulim=(0.02, 0.15))
                    show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j], ulim=(0.02, 0.15))
################# Degrading Identity
def degrage_BayesCap_p(
    NetC,
    NetG,
    eval_loader,
    device='cuda',
    dtype=torch.cuda.FloatTensor,
    num_runs=50,
):
    """Probe how BayesCap degrades when its input is perturbed with noise.

    For each noise magnitude in ``p_mag_list``, the generator output ``xSR``
    is perturbed with Gaussian noise scaled by ``p_mag * xSR.max()`` before
    being passed to ``NetC``. For each magnitude, the function reports:
    similarity metrics between the BayesCap mean ``xSRC_mu`` and the clean
    ``xSR``, and two UCE values (every 100th pixel, 200 bins) of the BayesCap
    variance against (1) the squared error of ``xSR`` w.r.t. ``xHR`` and
    (2) the squared error of ``xSRC_mu`` w.r.t. ``xHR``. Results are printed
    and plotted with matplotlib.

    Args:
        NetC: BayesCap network; ``NetC(x)`` returns ``(mu, alpha, beta)``.
        NetG: generator, called as ``NetG(xLR)``.
        eval_loader: iterable of batches where ``batch[0]`` is the input and
            ``batch[1]`` the ground truth.
        device: device the networks and batches are moved to.
        dtype: tensor type passed to ``Tensor.type`` for each batch.
        num_runs: NOTE(review): not used by this function.
    """
    NetC.to(device)
    NetC.eval()
    NetG.to(device)
    NetG.eval()
    # Noise magnitudes, relative to xSR.max(), injected into the BayesCap input.
    p_mag_list = [0, 0.05, 0.1, 0.15, 0.2]
    # Per-magnitude results: mean SSIM, mean PSNR, UCE wrt SR error, UCE wrt BayesCap error.
    list_s = []
    list_p = []
    list_u1 = []
    list_u2 = []
    list_c = []  # NOTE(review): never used
    for p_mag in p_mag_list:
        mean_ssim = 0
        mean_psnr = 0
        mean_mse = 0
        mean_mae = 0
        num_imgs = 0
        # Per-pixel values collected over the whole loader for the UCE computation.
        list_error = []
        list_error2 = []
        list_var = []
        with tqdm(eval_loader, unit='batch') as tepoch:
            for (idx, batch) in enumerate(tepoch):
                tepoch.set_description('Validating ...')
                ##
                xLR, xHR = batch[0].to(device), batch[1].to(device)
                xLR, xHR = xLR.type(dtype), xHR.type(dtype)
                # pass them through the network
                with torch.no_grad():
                    xSR = NetG(xLR)
                    # BayesCap only sees a noisy copy of the generator output.
                    xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR + p_mag*xSR.max()*torch.randn_like(xSR))
                a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
                b_map = xSRC_beta.to('cpu').data
                # a^2 * Gamma(3/b) / Gamma(1/b): variance of a generalized Gaussian with
                # scale a and shape b (presumably the BayesCap model; confirm); epsilons avoid division by zero.
                xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
                n_batch = xSRC_mu.shape[0]
                for j in range(n_batch):
                    num_imgs += 1
                    # Fidelity of the BayesCap mean to the clean generator output.
                    mean_ssim += img_ssim(xSRC_mu[j], xSR[j])
                    mean_psnr += img_psnr(xSRC_mu[j], xSR[j])
                    mean_mse += img_mse(xSRC_mu[j], xSR[j])
                    mean_mae += img_mae(xSRC_mu[j], xSR[j])
                    # Channel-averaged squared errors against the ground truth, flattened per pixel.
                    error_map = torch.mean(torch.pow(torch.abs(xSR[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
                    error_map2 = torch.mean(torch.pow(torch.abs(xSRC_mu[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
                    var_map = xSRvar[j].to('cpu').data.reshape(-1)
                    list_error.extend(list(error_map.numpy()))
                    list_error2.extend(list(error_map2.numpy()))
                    list_var.extend(list(var_map.numpy()))
        ##
        mean_ssim /= num_imgs
        mean_psnr /= num_imgs
        mean_mse /= num_imgs
        mean_mae /= num_imgs
        print(
            'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
            (
                mean_ssim, mean_psnr, mean_mse, mean_mae
            )
        )
        # UCE on every 100th pixel (subsampled to keep get_UCE affordable), 200 bins.
        uce1 = get_UCE(list_error[::100], list_var[::100], num_bins=200)[1]
        uce2 = get_UCE(list_error2[::100], list_var[::100], num_bins=200)[1]
        print('UCE1: ', uce1)
        print('UCE2: ', uce2)
        list_s.append(mean_ssim.item())
        list_p.append(mean_psnr.item())
        list_u1.append(uce1)
        list_u2.append(uce2)
    # Quick per-metric plots against the index into p_mag_list.
    plt.plot(list_s)
    plt.show()
    plt.plot(list_p)
    plt.show()
    plt.plot(list_u1, label='wrt SR output')
    plt.plot(list_u2, label='wrt BayesCap output')
    plt.legend()
    plt.show()
    # Summary figure: SSIM (left axis) and both UCE curves (right axis) vs. noise magnitude.
    sns.set_style('darkgrid')
    fig,ax = plt.subplots()
    # make a plot
    ax.plot(p_mag_list, list_s, color="red", marker="o")
    # set x-axis label
    ax.set_xlabel("Reducing faithfulness of BayesCap Reconstruction",fontsize=10)
    # set y-axis label
    ax.set_ylabel("SSIM btwn BayesCap and SRGAN outputs", color="red",fontsize=10)
    # twin object for two different y-axis on the sample plot
    ax2=ax.twinx()
    # make a plot with different y-axis using second axis object
    ax2.plot(p_mag_list, list_u1, color="blue", marker="o", label='UCE wrt to error btwn SRGAN output and GT')
    ax2.plot(p_mag_list, list_u2, color="orange", marker="o", label='UCE wrt to error btwn BayesCap output and GT')
    ax2.set_ylabel("UCE", color="green", fontsize=10)
    plt.legend(fontsize=10)
    plt.tight_layout()
    plt.show()
################# DeepFill_v2
# ----------------------------------------
# PATH processing
# ----------------------------------------
def text_readlines(filename):
    """Read a text file and return its lines without the trailing newline.

    Only a real trailing ``'\\n'`` is removed, so a final line without a
    newline keeps its last character.

    Returns:
        List of lines, or ``[]`` if the file cannot be opened.
    """
    try:
        file = open(filename, 'r')
    except IOError:
        return []
    # The with-block closes the file even if reading fails.
    with file:
        return [line[:-1] if line.endswith('\n') else line for line in file]
def savetxt(name, loss_log):
    """Write ``loss_log`` to the file ``name`` using ``numpy.savetxt`` defaults."""
    np.savetxt(name, np.array(loss_log))
def get_files(path):
    """Return the full path of every file under ``path``, recursively (``os.walk`` order)."""
    return [
        os.path.join(root, file_name)
        for root, _subdirs, file_names in os.walk(path)
        for file_name in file_names
    ]
def get_names(path):
    """Return the base name of every file under ``path``, recursively (``os.walk`` order)."""
    return [
        file_name
        for _root, _subdirs, file_names in os.walk(path)
        for file_name in file_names
    ]
def text_save(content, filename, mode = 'a'):
    """Write ``str(item)`` for each item of ``content`` on its own line.

    Args:
        content: iterable of items to write.
        filename: destination path.
        mode: ``open`` mode; the default ``'a'`` appends to an existing file.
    """
    # The with-block closes the file even if a write fails.
    with open(filename, mode) as file:
        for item in content:
            file.write(str(item) + '\n')
def check_path(path):
    """Create directory ``path`` (including parents) if nothing exists there yet."""
    if not os.path.exists(path):
        # exist_ok covers another process creating the directory between the
        # existence check and this call.
        os.makedirs(path, exist_ok=True)
# ----------------------------------------
# Validation and Sample at training
# ----------------------------------------
def save_sample_png(sample_folder, sample_name, img_list, name_list, pixel_max_cnt = 255):
    """Save the first image of each NCHW tensor in ``img_list`` as a JPEG file.

    Each tensor is scaled by 255, clipped to ``[0, pixel_max_cnt]``, converted
    to uint8 BGR and written to
    ``<sample_folder>/<sample_name>_<name_list[i]>.jpg`` with OpenCV.
    (Despite the function name, the files get a ``.jpg`` extension.)
    """
    for i, img in enumerate(img_list):
        # x255 because the network's last layer is sigmoid-activated.
        scaled = img * 255
        # Work on a copy so the caller's tensor is untouched.
        hwc = scaled.clone().data.permute(0, 2, 3, 1)[0].cpu().numpy()
        hwc = np.clip(hwc, 0, pixel_max_cnt).astype(np.uint8)
        bgr = cv2.cvtColor(hwc, cv2.COLOR_RGB2BGR)
        file_name = sample_name + '_' + name_list[i] + '.jpg'
        cv2.imwrite(os.path.join(sample_folder, file_name), bgr)
def psnr(pred, target, pixel_max_cnt = 255):
    """Peak signal-to-noise ratio in dB: ``20 * log10(pixel_max_cnt / RMSE)``.

    Returns:
        PSNR as a float; ``inf`` when ``pred`` and ``target`` are identical
        (RMSE of 0) instead of raising ``ZeroDivisionError``.
    """
    mse = torch.mul(target - pred, target - pred)
    rmse_avg = (torch.mean(mse).item()) ** 0.5
    if rmse_avg == 0:
        return float('inf')
    return 20 * np.log10(pixel_max_cnt / rmse_avg)
def grey_psnr(pred, target, pixel_max_cnt = 255):
    """PSNR in dB after summing ``pred`` and ``target`` over dim 0.

    Dim 0 is presumably the channel axis of a ``CxHxW`` tensor (confirm).
    Because three channels are summed, the peak value is ``pixel_max_cnt * 3``.

    Returns:
        PSNR as a float; ``inf`` for identical inputs instead of raising
        ``ZeroDivisionError``.
    """
    pred = torch.sum(pred, dim = 0)
    target = torch.sum(target, dim = 0)
    mse = torch.mul(target - pred, target - pred)
    rmse_avg = (torch.mean(mse).item()) ** 0.5
    if rmse_avg == 0:
        return float('inf')
    return 20 * np.log10(pixel_max_cnt * 3 / rmse_avg)
def ssim(pred, target):
    """SSIM between the first images of two ``NxCxHxW`` batches, via scikit-image.

    Both tensors are converted to ``NxHxWxC`` numpy arrays, and the first
    item of each is compared over all channels.

    Returns:
        The SSIM value computed by scikit-image.
    """
    # scikit-image is not imported at module level, so import it here.
    pred_img = pred.clone().data.permute(0, 2, 3, 1).cpu().numpy()[0]
    target_img = target.clone().data.permute(0, 2, 3, 1).cpu().numpy()[0]
    kwargs = {}
    if np.issubdtype(target_img.dtype, np.floating):
        # Older scikit-image inferred data_range from the float dtype range
        # (-1, 1), i.e. 2.0; recent releases reject float input without an
        # explicit data_range. Pass the historical value to keep scores stable.
        kwargs['data_range'] = 2.0
    try:
        from skimage.metrics import structural_similarity
    except ImportError:
        # Very old scikit-image only provides the deprecated name.
        from skimage.measure import compare_ssim
        return compare_ssim(target_img, pred_img, multichannel=True, **kwargs)
    try:
        return structural_similarity(target_img, pred_img, channel_axis=-1, **kwargs)
    except TypeError:
        # scikit-image releases before channel_axis was introduced.
        return structural_similarity(target_img, pred_img, multichannel=True, **kwargs)
## for contextual attention
def extract_image_patches(images, ksizes, strides, rates, padding='same'):
    """
    Extract sliding-window patches from a 4-D batch and stack them along dim 1.

    :param images: rank-4 tensor [batch, channels, in_rows, in_cols].
    :param ksizes: [ksize_rows, ksize_cols], size of the sliding window.
    :param strides: [stride_rows, stride_cols].
    :param rates: [dilation_rows, dilation_cols].
    :param padding: 'same' zero-pads the input with ``same_padding`` before
        extraction; 'valid' extracts without padding.
    :return: tensor [N, C*k_rows*k_cols, L], where L is the number of blocks.
    :raises AssertionError: if ``images`` is not rank 4.
    :raises NotImplementedError: if ``padding`` is neither 'same' nor 'valid'.
    """
    assert len(images.size()) == 4
    if padding == 'same':
        images = same_padding(images, ksizes, strides, rates)
    elif padding != 'valid':
        # Previously unreachable: an `assert padding in [...]` fired first with
        # no message, and this message was mangled by a backslash continuation.
        raise NotImplementedError(
            'Unsupported padding type: {}. '
            'Only "same" or "valid" are supported.'.format(padding))
    unfold = torch.nn.Unfold(kernel_size=ksizes,
                             dilation=rates,
                             padding=0,
                             stride=strides)
    return unfold(images)  # [N, C*k*k, L]
def same_padding(images, ksizes, strides, rates):
    """Zero-pad a rank-4 tensor so a strided, dilated window covers every position.

    The amount follows TensorFlow-style 'SAME' padding. The number of output
    positions per spatial axis is ceil(size / stride). The input is padded
    just enough for the dilated kernel to fit at every one of them. Any odd
    extra pixel goes to the bottom/right.

    :param images: tensor [batch, channels, rows, cols].
    :param ksizes: [ksize_rows, ksize_cols].
    :param strides: [stride_rows, stride_cols].
    :param rates: [dilation_rows, dilation_cols].
    :return: the zero-padded tensor.
    """
    assert len(images.size()) == 4
    _, _, rows, cols = images.size()

    def _total_pad(size, kernel, stride, rate):
        # Total padding needed along one spatial axis.
        n_out = (size + stride - 1) // stride
        span = (kernel - 1) * rate + 1
        return max(0, (n_out - 1) * stride + span - size)

    pad_r = _total_pad(rows, ksizes[0], strides[0], rates[0])
    pad_c = _total_pad(cols, ksizes[1], strides[1], rates[1])
    top = pad_r // 2
    left = pad_c // 2
    # ZeroPad2d expects (left, right, top, bottom).
    return torch.nn.ZeroPad2d((left, pad_c - left, top, pad_r - top))(images)
def _normalize_axes(x, axis):
    """Turn ``axis`` into a tuple of dims: all dims if None/empty, a bare int wrapped."""
    if isinstance(axis, (int, np.integer)):
        return (int(axis),)
    if not axis:
        return tuple(range(len(x.shape)))
    return tuple(axis)


def reduce_mean(x, axis=None, keepdim=False):
    """Mean of ``x`` over ``axis`` (int or sequence of ints); all dims when None or empty.

    :param x: input tensor.
    :param axis: dimension(s) to reduce; negative indices are allowed.
    :param keepdim: keep reduced dims with size 1.
    :return: the reduced tensor (``x`` itself for a 0-d input).
    """
    dims = _normalize_axes(x, axis)
    if not dims:
        return x
    return torch.mean(x, dim=dims, keepdim=keepdim)


def reduce_std(x, axis=None, keepdim=False):
    """Standard deviation (torch.std default, unbiased) of ``x`` jointly over ``axis``.

    All listed dims are reduced in a single call. The previous per-axis loop
    computed a std of stds, which differs from the std over the combined axes.

    :param x: input tensor.
    :param axis: dimension(s) to reduce (int or sequence); all dims when None or empty.
    :param keepdim: keep reduced dims with size 1.
    :return: the reduced tensor (``x`` itself for a 0-d input).
    """
    dims = _normalize_axes(x, axis)
    if not dims:
        return x
    return torch.std(x, dim=dims, keepdim=keepdim)


def reduce_sum(x, axis=None, keepdim=False):
    """Sum of ``x`` over ``axis`` (int or sequence of ints); all dims when None or empty.

    :param x: input tensor.
    :param axis: dimension(s) to reduce; negative indices are allowed.
    :param keepdim: keep reduced dims with size 1.
    :return: the reduced tensor (``x`` itself for a 0-d input).
    """
    dims = _normalize_axes(x, axis)
    if not dims:
        return x
    return torch.sum(x, dim=dims, keepdim=keepdim)
def _random_rect_mask(image_height, image_width):
    """Return a (1, 1, H, W) float tensor holding one random filled rectangle of ones."""
    max_delta_height = image_height // 8
    max_delta_width = image_width // 8
    height = image_height // 4
    width = image_width // 4
    # Top-left corner of a quarter-size bounding box that stays inside the image.
    t = random.randint(0, image_height - height)
    l = random.randint(0, image_width - width)
    # Shrink the box by a random margin on every side.
    dh = random.randint(0, max_delta_height // 2)
    dw = random.randint(0, max_delta_width // 2)
    mask = torch.zeros((1, 1, image_height, image_width))
    mask[:, :, t + dh:t + height - dh, l + dw:l + width - dw] = 1
    return mask


def _random_brush_mask(H, W):
    """Return a (1, 1, H, W) float tensor in [0, 1] with 1-3 random brush strokes."""
    min_num_vertex = 4
    max_num_vertex = 12
    mean_angle = 2 * math.pi / 5
    angle_range = 2 * math.pi / 15
    min_width = 12
    max_width = 40
    average_radius = math.sqrt(H * H + W * W) / 8
    mask = Image.new('L', (W, H), 0)
    # PIL's Image.size is (width, height); the old code unpacked it as (h, w).
    w, h = mask.size
    for _ in range(np.random.randint(1, 4)):
        num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
        angle_min = mean_angle - np.random.uniform(0, angle_range)
        angle_max = mean_angle + np.random.uniform(0, angle_range)
        angles = []
        for i in range(num_vertex):
            if i % 2 == 0:
                angles.append(2 * math.pi - np.random.uniform(angle_min, angle_max))
            else:
                angles.append(np.random.uniform(angle_min, angle_max))
        # Random starting point, then a polyline of random-length segments clipped to the canvas.
        vertex = [(int(np.random.randint(0, w)), int(np.random.randint(0, h)))]
        for angle in angles:
            r = np.clip(
                np.random.normal(loc=average_radius, scale=average_radius // 2),
                0, 2 * average_radius)
            new_x = np.clip(vertex[-1][0] + r * math.cos(angle), 0, w)
            new_y = np.clip(vertex[-1][1] + r * math.sin(angle), 0, h)
            vertex.append((int(new_x), int(new_y)))
        draw = ImageDraw.Draw(mask)
        stroke_width = int(np.random.uniform(min_width, max_width))
        draw.line(vertex, fill=255, width=stroke_width)
        # Round the joints so the stroke has no gaps at the corners.
        half = stroke_width // 2
        for vx, vy in vertex:
            draw.ellipse((vx - half, vy - half, vx + half, vy + half), fill=255)
    # Image.transpose returns a new image; the old code discarded it, so it never flipped.
    if np.random.normal() > 0:
        mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
    if np.random.normal() > 0:
        mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
    # ToTensor maps the 0/255 'L' image to a float tensor in [0, 1].
    return transforms.ToTensor()(mask).reshape((1, 1, H, W))


def random_mask(num_batch=1, mask_shape=(256,256)):
    """Generate a batch of random masks, each the union of a rectangle and brush strokes.

    :param num_batch: number of masks to generate.
    :param mask_shape: (height, width) of each mask.
    :return: float tensor of shape (num_batch, 1, height, width), values in [0, 1]
        (1 presumably marks the region to fill; verify against callers).
    """
    image_height, image_width = mask_shape[0], mask_shape[1]
    if num_batch <= 0:
        # torch.cat would raise on an empty list.
        return torch.zeros((0, 1, image_height, image_width))
    list_mask = []
    for _ in range(num_batch):
        rect_mask = _random_rect_mask(image_height, image_width)
        brush_mask = _random_brush_mask(image_height, image_width)
        # Union of both masks: elementwise maximum over the two channels.
        mask = torch.cat([rect_mask, brush_mask], dim=1).max(dim=1, keepdim=True)[0]
        list_mask.append(mask)
    return torch.cat(list_mask, dim=0)