import math
import os
import random

import blobfile as bf
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset


def load_data(
    *,
    dataset_mode,
    data_dir,
    batch_size,
    image_size,
    class_cond=False,
    deterministic=False,
    random_crop=True,
    random_flip=True,
    is_train=True,
):
    """
    For a dataset, create a DataLoader yielding (images, kwargs) pairs, and
    return it together with the underlying dataset.

    Each image is an NCHW float tensor, and the kwargs dict contains zero or
    more keys, each of which maps to a batched Tensor of its own. The kwargs
    dict can be used for class labels, in which case the key is "y" and the
    values are integer tensors of class labels.

    :param dataset_mode: which dataset layout to load ('cityscapes', 'ade20k',
                         'celeba', 'crack500', or 'thincrack').
    :param data_dir: a dataset directory.
    :param batch_size: the batch size of each returned pair.
    :param image_size: the size to which images are resized.
    :param class_cond: if True, include a "y" key in returned dicts for class
                       labels. If classes are not available and this is True,
                       an exception will be raised. (Not used by this loader.)
    :param deterministic: if True, yield results in a deterministic order.
    :param random_crop: if True, randomly crop the images for augmentation.
    :param random_flip: if True, randomly flip the images for augmentation.
    :param is_train: if True, load the training split, otherwise the
                     validation/test split.
    """
    if not data_dir:
        raise ValueError("unspecified data directory")

    if dataset_mode == 'cityscapes':
        all_files = _list_image_files_recursively(
            os.path.join(data_dir, 'leftImg8bit', 'train' if is_train else 'val'))
        labels_file = _list_image_files_recursively(
            os.path.join(data_dir, 'gtFine', 'train' if is_train else 'val'))
        classes = [x for x in labels_file if x.endswith('_labelIds.png')]
        instances = [x for x in labels_file if x.endswith('_instanceIds.png')]
    elif dataset_mode == 'ade20k':
        all_files = _list_image_files_recursively(
            os.path.join(data_dir, 'images', 'training' if is_train else 'validation'))
        classes = _list_image_files_recursively(
            os.path.join(data_dir, 'annotations', 'training' if is_train else 'validation'))
        instances = None
    elif dataset_mode == 'celeba':
        # The edge maps are computed from the instances. However, the edges
        # derived from the labels and from the instances are identical on
        # CelebA, so either can be used as the instance input.
        all_files = _list_image_files_recursively(
            os.path.join(data_dir, 'train' if is_train else 'test', 'images'))
        classes = _list_image_files_recursively(
            os.path.join(data_dir, 'train' if is_train else 'test', 'labels'))
        instances = _list_image_files_recursively(
            os.path.join(data_dir, 'train' if is_train else 'test', 'labels'))
    elif dataset_mode == 'crack500':
        all_files = _list_image_files_recursively(
            os.path.join(data_dir, 'train' if is_train else 'validation', 'images'))
        classes = _list_image_files_recursively(
            os.path.join(data_dir, 'train' if is_train else 'validation', 'annotations'))
        instances = None
    elif dataset_mode == 'thincrack':
        # thincrack has no separate validation split here: the train
        # directory is used for both training and evaluation.
        all_files = _list_image_files_recursively(
            os.path.join(data_dir, 'train', 'images'))
        classes = _list_image_files_recursively(
            os.path.join(data_dir, 'train', 'annotations'))
        instances = None
    else:
        raise NotImplementedError('{} not implemented'.format(dataset_mode))

    print("Dataset length:", len(all_files))

    dataset = ImageDataset(
        dataset_mode,
        image_size,
        all_files,
        classes=classes,
        instances=instances,
        random_crop=random_crop,
        random_flip=random_flip,
        is_train=is_train,
    )

    if deterministic:
        loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
        )
    else:
        loader = DataLoader(
            dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
        )
    return loader, dataset
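# Example call (an illustrative sketch; '/path/to/cityscapes' is a
# placeholder). For 'cityscapes', the loader yields (images, cond) pairs
# where cond carries the 'label', 'label_ori', 'instance', and 'path' keys
# built in ImageDataset below:
#
#   loader, dataset = load_data(
#       dataset_mode='cityscapes',
#       data_dir='/path/to/cityscapes',
#       batch_size=4,
#       image_size=256,
#       is_train=True,
#   )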
def _list_image_files_recursively(data_dir):
    results = []
    for entry in sorted(bf.listdir(data_dir)):
        full_path = bf.join(data_dir, entry)
        ext = entry.split(".")[-1]
        if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
            results.append(full_path)
        elif bf.isdir(full_path):
            results.extend(_list_image_files_recursively(full_path))
    return results


class ImageDataset(Dataset):
    def __init__(
        self,
        dataset_mode,
        resolution,
        image_paths,
        classes=None,
        instances=None,
        shard=0,
        num_shards=1,
        random_crop=False,
        random_flip=True,
        is_train=True,
    ):
        super().__init__()
        self.is_train = is_train
        self.dataset_mode = dataset_mode
        self.resolution = resolution
        self.local_images = image_paths[shard:][::num_shards]
        self.local_classes = None if classes is None else classes[shard:][::num_shards]
        self.local_instances = None if instances is None else instances[shard:][::num_shards]
        self.random_crop = random_crop
        self.random_flip = random_flip

    def __len__(self):
        return len(self.local_images)

    def __getitem__(self, idx):
        path = self.local_images[idx]
        with bf.BlobFile(path, "rb") as f:
            pil_image = Image.open(f)
            pil_image.load()
        pil_image = pil_image.convert("RGB")

        out_dict = {}
        class_path = self.local_classes[idx]
        with bf.BlobFile(class_path, "rb") as f:
            pil_class = Image.open(f)
            pil_class.load()
        pil_class = pil_class.convert("L")

        if self.local_instances is not None:
            # DEBUG: from classes to instances, may affect CelebA
            instance_path = self.local_instances[idx]
            with bf.BlobFile(instance_path, "rb") as f:
                pil_instance = Image.open(f)
                pil_instance.load()
            pil_instance = pil_instance.convert("L")
        else:
            pil_instance = None

        if self.dataset_mode == 'cityscapes':
            arr_image, arr_class, arr_instance = resize_arr(
                [pil_image, pil_class, pil_instance], self.resolution)
        else:
            if self.is_train:
                if self.random_crop:
                    arr_image, arr_class, arr_instance = random_crop_arr(
                        [pil_image, pil_class, pil_instance], self.resolution)
                else:
                    arr_image, arr_class, arr_instance = center_crop_arr(
                        [pil_image, pil_class, pil_instance], self.resolution)
            else:
                arr_image, arr_class, arr_instance = resize_arr(
                    [pil_image, pil_class, pil_instance], self.resolution,
                    keep_aspect=False)

        # Random horizontal flip, applied consistently to the image, the
        # label map, and the instance map.
        if self.random_flip and random.random() < 0.5:
            arr_image = arr_image[:, ::-1].copy()
            arr_class = arr_class[:, ::-1].copy()
            arr_instance = arr_instance[:, ::-1].copy() if arr_instance is not None else None

        # Scale pixel values from [0, 255] to [-1, 1].
        arr_image = arr_image.astype(np.float32) / 127.5 - 1

        out_dict['path'] = path
        out_dict['label_ori'] = arr_class.copy()

        if self.dataset_mode == 'ade20k':
            # ADE20K labels are 1..150 with 0 = unlabeled; shift to 0..149 and
            # map the unlabeled value (which wraps to 255 under uint8) to 150.
            arr_class = arr_class - 1
            arr_class[arr_class == 255] = 150
        elif self.dataset_mode == 'coco':
            # Map the unlabeled value 255 to class 182.
            arr_class[arr_class == 255] = 182
        elif self.dataset_mode == 'crack500':
            # Binary crack mask: map foreground value 255 to class 1.
            arr_class[arr_class == 255] = 1
        elif self.dataset_mode == 'thincrack':
            arr_class[arr_class == 255] = 1

        out_dict['label'] = arr_class[None, ]

        if arr_instance is not None:
            out_dict['instance'] = arr_instance[None, ]

        return np.transpose(arr_image, [2, 0, 1]), out_dict
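# Sharding note (illustrative): `image_paths[shard:][::num_shards]` gives each
# of `num_shards` workers a disjoint, interleaved slice of the file list. For
# example, with paths [a, b, c, d, e] and num_shards=2:
#   shard 0 -> [a, c, e]
#   shard 1 -> [b, d]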
def resize_arr(pil_list, image_size, keep_aspect=True):
    # We are not on a new enough PIL to support the `reducing_gap`
    # argument, which uses BOX downsampling at powers of two first.
    # Thus, we do it by hand to improve downsample quality.
    pil_image, pil_class, pil_instance = pil_list

    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    if keep_aspect:
        # Resize so the smaller side matches image_size, preserving the
        # aspect ratio.
        scale = image_size / min(*pil_image.size)
        pil_image = pil_image.resize(
            tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
        )
    else:
        # Squash to a square image_size x image_size output.
        pil_image = pil_image.resize((image_size, image_size), resample=Image.BICUBIC)

    pil_class = pil_class.resize(pil_image.size, resample=Image.NEAREST)
    if pil_instance is not None:
        pil_instance = pil_instance.resize(pil_image.size, resample=Image.NEAREST)

    arr_image = np.array(pil_image)
    arr_class = np.array(pil_class)
    arr_instance = np.array(pil_instance) if pil_instance is not None else None
    return arr_image, arr_class, arr_instance


def center_crop_arr(pil_list, image_size):
    # We are not on a new enough PIL to support the `reducing_gap`
    # argument, which uses BOX downsampling at powers of two first.
    # Thus, we do it by hand to improve downsample quality.
    pil_image, pil_class, pil_instance = pil_list

    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    pil_class = pil_class.resize(pil_image.size, resample=Image.NEAREST)
    if pil_instance is not None:
        pil_instance = pil_instance.resize(pil_image.size, resample=Image.NEAREST)

    arr_image = np.array(pil_image)
    arr_class = np.array(pil_class)
    arr_instance = np.array(pil_instance) if pil_instance is not None else None

    crop_y = (arr_image.shape[0] - image_size) // 2
    crop_x = (arr_image.shape[1] - image_size) // 2
    return (
        arr_image[crop_y : crop_y + image_size, crop_x : crop_x + image_size],
        arr_class[crop_y : crop_y + image_size, crop_x : crop_x + image_size],
        arr_instance[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
        if arr_instance is not None
        else None,
    )
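# Downsampling note (illustrative): the halving loop in the functions above
# emulates PIL's `reducing_gap`: repeatedly BOX-downsample by 2x while the
# smaller side is still at least 2 * image_size, then finish with a single
# BICUBIC resize. For a 2048x1024 input and image_size=256, that is
# 2048x1024 -> 1024x512 -> 512x256, followed by the BICUBIC step.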
def random_crop_arr(pil_list, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
    min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
    max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
    smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)

    # We are not on a new enough PIL to support the `reducing_gap`
    # argument, which uses BOX downsampling at powers of two first.
    # Thus, we do it by hand to improve downsample quality.
    pil_image, pil_class, pil_instance = pil_list

    while min(*pil_image.size) >= 2 * smaller_dim_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = smaller_dim_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    pil_class = pil_class.resize(pil_image.size, resample=Image.NEAREST)
    if pil_instance is not None:
        pil_instance = pil_instance.resize(pil_image.size, resample=Image.NEAREST)

    arr_image = np.array(pil_image)
    arr_class = np.array(pil_class)
    arr_instance = np.array(pil_instance) if pil_instance is not None else None

    crop_y = random.randrange(arr_image.shape[0] - image_size + 1)
    crop_x = random.randrange(arr_image.shape[1] - image_size + 1)
    return (
        arr_image[crop_y : crop_y + image_size, crop_x : crop_x + image_size],
        arr_class[crop_y : crop_y + image_size, crop_x : crop_x + image_size],
        arr_instance[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
        if arr_instance is not None
        else None,
    )
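# Crop-size note (illustrative): with the defaults min_crop_frac=0.8,
# max_crop_frac=1.0, and image_size=256, random_crop_arr first resizes the
# smaller image dimension to a random value in
# [ceil(256 / 1.0), ceil(256 / 0.8)] = [256, 320], so the final 256-pixel
# crop covers between 80% and 100% of that dimension.


if __name__ == "__main__":
    # Minimal smoke test; an illustrative sketch, not part of the original
    # module. '/path/to/crack500' is a placeholder and assumes the crack500
    # layout (train/{images,annotations}) handled in load_data above.
    loader, _ = load_data(
        dataset_mode="crack500",
        data_dir="/path/to/crack500",
        batch_size=2,
        image_size=256,
        is_train=True,
    )
    images, cond = next(iter(loader))
    print(images.shape)         # torch.Size([2, 3, 256, 256]), values in [-1, 1]
    print(cond["label"].shape)  # torch.Size([2, 1, 256, 256])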