import torch from PIL import Image from torchvision import transforms from src.lpips import LPIPS import torch.nn as nn dev = 'cuda' to_tensor_transform = transforms.Compose([transforms.ToTensor()]) mse_loss = nn.MSELoss() def calculate_l2_difference(image1, image2, device = 'cuda'): if isinstance(image1, Image.Image): image1 = to_tensor_transform(image1).to(device) if isinstance(image2, Image.Image): image2 = to_tensor_transform(image2).to(device) mse = mse_loss(image1, image2).item() return mse def calculate_psnr(image1, image2, device = 'cuda'): max_value = 1.0 if isinstance(image1, Image.Image): image1 = to_tensor_transform(image1).to(device) if isinstance(image2, Image.Image): image2 = to_tensor_transform(image2).to(device) mse = mse_loss(image1, image2) psnr = 10 * torch.log10(max_value**2 / mse).item() return psnr loss_fn = LPIPS(net_type='vgg').to(dev).eval() def calculate_lpips(image1, image2, device = 'cuda'): if isinstance(image1, Image.Image): image1 = to_tensor_transform(image1).to(device) if isinstance(image2, Image.Image): image2 = to_tensor_transform(image2).to(device) loss = loss_fn(image1, image2).item() return loss def calculate_metrics(image1, image2, device = 'cuda', size=(512, 512)): if isinstance(image1, Image.Image): image1 = image1.resize(size) image1 = to_tensor_transform(image1).to(device) if isinstance(image2, Image.Image): image2 = image2.resize(size) image2 = to_tensor_transform(image2).to(device) l2 = calculate_l2_difference(image1, image2, device) psnr = calculate_psnr(image1, image2, device) lpips = calculate_lpips(image1, image2, device) return {"l2": l2, "psnr": psnr, "lpips": lpips} def get_empty_metrics(): return {"l2": 0, "psnr": 0, "lpips": 0} def print_results(results): print(f"Reconstruction Metrics: L2: {results['l2']},\t PSNR: {results['psnr']},\t LPIPS: {results['lpips']}")