import os

import torch
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class ImageDataset(Dataset):
    """Dataset that yields (grayscale input, color target) pairs of the same image."""

    def __init__(self, dir, transform=None) -> None:
        self.dir = dir
        self.transform = transform
        # Sort the listing so that indices map to files deterministically across runs
        self.file_list = sorted(os.listdir(self.dir))

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        image_name = self.file_list[idx]
        image_path = os.path.join(self.dir, image_name)
        # Open the same file twice: once as the single-channel input, once as the RGB target
        grayscale_image = Image.open(image_path).convert('L')
        colorized_image = Image.open(image_path).convert('RGB')
        if self.transform:
            grayscale_image = self.transform(grayscale_image)
            colorized_image = self.transform(colorized_image)
        return grayscale_image, colorized_image

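# A minimal usage sketch of ImageDataset with a DataLoader. The "images" directory name,
# the 256x256 resize, and the batch size are assumptions for illustration only; they do
# not come from the original code. Wrapped in a function so it does not run at import time.
def _example_dataloader():
    from torch.utils.data import DataLoader

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ])
    dataset = ImageDataset("images", transform=transform)
    loader = DataLoader(dataset, batch_size=16, shuffle=True)
    # Each batch is a (grayscale, colorized) pair with shapes
    # (16, 1, 256, 256) and (16, 3, 256, 256) respectively
    grayscale_batch, colorized_batch = next(iter(loader))
    return grayscale_batch, colorized_batch
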
def show_image(image_tensor):
    try:
        if len(image_tensor) == 1:
            # Single-channel (1, H, W) tensor: drop the channel dim and plot as grayscale
            plt.imshow(image_tensor[0], cmap="gray")
        else:
            # Multi-channel (C, H, W) tensor: convert to (H, W, C) for matplotlib
            plt.imshow(image_tensor.numpy().transpose(1, 2, 0))
    except Exception as e:
        print(f"Exception when showing image: {e}")

# Allows computing the MSE loss when the output tensor's shape differs from the target tensor's
def adjust_output_shape(output_tensor, target_tensor):
    # Resize the (N, C, H, W) output to the target's spatial size via bilinear interpolation
    adjusted_tensor = torch.nn.functional.interpolate(
        output_tensor, size=target_tensor.shape[2:], mode="bilinear", align_corners=False
    )
    return adjusted_tensor

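# Illustrative check of adjust_output_shape (the tensor shapes below are made up):
# a (1, 3, 128, 128) "output" is resized to match a (1, 3, 256, 256) "target" so that
# mse_loss can be computed without a shape mismatch.
def _example_mse_with_resize():
    output = torch.rand(1, 3, 128, 128)
    target = torch.rand(1, 3, 256, 256)
    resized = adjust_output_shape(output, target)  # -> (1, 3, 256, 256)
    return torch.nn.functional.mse_loss(resized, target)
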
def pil_to_torch(pil_image):
    # PIL image -> float tensor in [0, 1] with an added batch dimension: (1, C, H, W)
    transform = transforms.ToTensor()
    return transform(pil_image).unsqueeze(0)


def torch_to_pil(torch_image):
    # (1, C, H, W) or (C, H, W) tensor -> PIL image; squeeze(0) removes the batch dimension
    transform = transforms.ToPILImage()
    return transform(torch_image.squeeze(0))

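# A small self-contained demo of the conversion helpers. It uses a random image so it does
# not depend on any files on disk; the image size is arbitrary and purely illustrative.
if __name__ == "__main__":
    pil_image = transforms.ToPILImage()(torch.rand(3, 64, 64))
    tensor = pil_to_torch(pil_image)        # (1, 3, 64, 64), values in [0, 1]
    round_tripped = torch_to_pil(tensor)    # back to a PIL image
    show_image(tensor.squeeze(0))           # show_image expects a (C, H, W) tensor
    plt.show()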