import os

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class ImageDataset(Dataset):
    """Dataset yielding (grayscale, RGB) versions of every image file in a directory.

    Each item is produced from a single file: the image is converted once to
    mode ``'L'`` and once to mode ``'RGB'``. If a transform is supplied it is
    applied to both versions independently.
    """

    def __init__(self, dir, transform=None) -> None:
        """Index the files found directly inside ``dir``.

        Args:
            dir: Path of the directory holding the image files.
            transform: Optional callable applied to each PIL image before it
                is returned from ``__getitem__``.
        """
        super().__init__()
        self.dir = dir
        self.transform = transform
        # os.listdir also returns sub-directories, which Image.open cannot
        # read; keep regular files only so indexing never lands on a folder.
        self.file_list = sorted(
            name
            for name in os.listdir(self.dir)
            if os.path.isfile(os.path.join(self.dir, name))
        )

    def __len__(self) -> int:
        """Return the number of indexed files."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Return the ``(grayscale, colorized)`` pair for the file at ``idx``.

        Both elements are PIL images when no transform is set; otherwise they
        are whatever the transform returns.
        """
        image_name = self.file_list[idx]
        image_path = os.path.join(self.dir, image_name)

        # Open the file once and derive both modes from it; the context
        # manager closes the underlying file handle deterministically.
        # convert() always returns a new image, so the results stay valid
        # after the source image is closed.
        with Image.open(image_path) as source_image:
            grayscale_image = source_image.convert('L')
            colorized_image = source_image.convert('RGB')

        if self.transform is not None:
            grayscale_image = self.transform(grayscale_image)
            colorized_image = self.transform(colorized_image)

        return grayscale_image, colorized_image


def show_image(image_tensor):
    """Draw an image on the current matplotlib axes (best effort).

    An input whose first dimension has length 1 is drawn as a grayscale image
    from its first entry; anything else is converted with ``.numpy()`` and its
    axes are reordered with ``transpose(1, 2, 0)`` before drawing. This
    function does not call ``plt.show()``.

    Any exception raised while drawing is printed rather than propagated.
    """
    try:
        # Tensors that track gradients or live on an accelerator cannot be
        # converted to NumPy directly, so move a detached copy to the CPU.
        if isinstance(image_tensor, torch.Tensor):
            image_tensor = image_tensor.detach().cpu()

        if len(image_tensor) == 1:
            plt.imshow(image_tensor[0], cmap="gray")
        else:
            plt.imshow(image_tensor.numpy().transpose(1, 2, 0))
    except Exception as e:
        print(f"Exception when showing image: {e}")


def adjust_output_shape(output_tensor, target_tensor):
    """Resize ``output_tensor`` to the trailing dimensions of ``target_tensor``.

    Useful for computing an MSE loss when the output tensor has a different
    shape from the target tensor. The new size is ``target_tensor.shape[2:]``
    and resampling is bilinear with ``align_corners=False``.

    NOTE(review): bilinear mode implies 4-D inputs laid out as
    (batch, channels, height, width) — confirm against callers.

    Returns:
        The resized tensor.
    """
    adjusted_tensor = torch.nn.functional.interpolate(
        output_tensor,
        size=target_tensor.shape[2:],
        mode="bilinear",
        align_corners=False,
    )
    return adjusted_tensor


def pil_to_torch(pil_image):
    """Convert a PIL image to a tensor with an extra leading dimension of size 1."""
    transform = transforms.ToTensor()
    return transform(pil_image).unsqueeze(0)


def torch_to_pil(torch_image):
    """Convert a tensor to a PIL image after squeezing dimension 0 (if it has size 1)."""
    transform = transforms.ToPILImage()
    return transform(torch_image.squeeze(0))