# Copyright (c) EPFL VILAB.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Based on BEiT, timm, DINO, DeiT and MAE-priv code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# https://github.com/BUPT-PRIV/MAE-priv
# --------------------------------------------------------
import numpy as np
import torch

try:
    import albumentations as A
    import cv2  # Needed for cv2.BORDER_CONSTANT below
    from albumentations.pytorch import ToTensorV2
except ImportError:
    print('albumentations and/or opencv not installed')

from utils import (IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, NYU_MEAN,
                   NYU_STD, PAD_MASK_VALUE)
from utils.dataset_folder import ImageFolder, MultiTaskImageFolder


def nyu_transform(train, additional_targets, input_size=512, color_aug=False):
    """Albumentations pipeline for NYUv2 depth regression."""
    if train:
        augs = [
            A.SmallestMaxSize(max_size=input_size, p=1),
            A.HorizontalFlip(p=0.5),
        ]
        if color_aug:
            augs += [
                # Color jittering from BYOL https://arxiv.org/pdf/2006.07733.pdf
                A.ColorJitter(
                    brightness=0.1255,
                    contrast=0.4,
                    saturation=[0.5, 1.5],
                    hue=[-0.2, 0.2],
                    p=0.5
                ),
                A.ToGray(p=0.3),
            ]
        augs += [
            A.RandomCrop(height=input_size, width=input_size, p=1),
            A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
            ToTensorV2(),
        ]
        transform = A.Compose(augs, additional_targets=additional_targets)
    else:
        transform = A.Compose([
            A.SmallestMaxSize(max_size=input_size, p=1),
            A.CenterCrop(height=input_size, width=input_size),
            A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
            ToTensorV2(),
        ], additional_targets=additional_targets)

    return transform


def simple_regression_transform(train, additional_targets, input_size=512,
                                pad_value=(128, 128, 128), pad_mask_value=PAD_MASK_VALUE):
    """Albumentations pipeline for generic dense regression with padding."""
    if train:
        transform = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.LongestMaxSize(max_size=input_size, p=1),
            # Color jittering from MoCo-v3 / DINO
            A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.5),
            # This is large-scale jittering (LSJ) with scale range (0.1, 2.0)
            A.RandomScale(scale_limit=(0.1 - 1, 2.0 - 1), p=1),
            A.PadIfNeeded(min_height=input_size, min_width=input_size,
                          position=A.augmentations.PadIfNeeded.PositionType.TOP_LEFT,
                          border_mode=cv2.BORDER_CONSTANT,
                          value=pad_value, mask_value=pad_mask_value),
            A.RandomCrop(height=input_size, width=input_size, p=1),
            A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
            ToTensorV2(),
        ], additional_targets=additional_targets)
    else:
        transform = A.Compose([
            A.LongestMaxSize(max_size=input_size, p=1),
            A.PadIfNeeded(min_height=input_size, min_width=input_size,
                          position=A.augmentations.PadIfNeeded.PositionType.TOP_LEFT,
                          border_mode=cv2.BORDER_CONSTANT,
                          value=pad_value, mask_value=pad_mask_value),
            A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
            ToTensorV2(),
        ], additional_targets=additional_targets)

    return transform
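
# --------------------------------------------------------
# Illustrative sketch (not part of the original code base):
# applying the two pipelines above to dummy inputs. The
# `additional_targets` mapping is an assumption; it must name
# every extra key passed at call time. Registering a target
# as 'mask' makes albumentations apply only the spatial
# (not photometric) transforms to it.
# --------------------------------------------------------
def _example_transforms():
    """Minimal sketch: run both transform pipelines on random data."""
    rgb = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
    depth = np.random.rand(480, 640).astype(np.float32)

    transform = nyu_transform(train=True,
                              additional_targets={'depth': 'mask'},
                              input_size=512)
    out = transform(image=rgb, depth=depth)
    # out['image']: float tensor (3, 512, 512); out['depth']: tensor (512, 512)

    transform = simple_regression_transform(train=False,
                                            additional_targets={'depth': 'mask'},
                                            input_size=512)
    # Depth is padded with PAD_MASK_VALUE (via mask_value), the RGB with gray
    out = transform(image=rgb, depth=depth)
    return out
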
class DataAugmentationForRegression(object):

    def __init__(self, transform, mask_value=0.0):
        self.transform = transform
        self.mask_value = mask_value

    def __call__(self, task_dict):
        # Albumentations expects the key 'image', so temporarily rename 'rgb'
        task_dict['image'] = task_dict.pop('rgb')

        # Convert all modalities to np.array
        task_dict = {k: np.array(v) for k, v in task_dict.items()}

        task_dict = self.transform(**task_dict)

        # Standardize depth with the NYU statistics
        task_dict['depth'] = (task_dict['depth'].float() - NYU_MEAN) / NYU_STD

        # Rename 'image' back to 'rgb'
        task_dict['rgb'] = task_dict.pop('image')

        # Validity mask is stored as {0, 255}; convert to bool and add a channel dim
        task_dict['mask_valid'] = (task_dict['mask_valid'] == 255)[None]

        for task in task_dict:
            if task == 'depth':
                img = task_dict[task]
                if 'mask_valid' in task_dict:
                    mask_valid = task_dict['mask_valid'].squeeze()
                    # Overwrite invalid depth pixels with the fill value
                    img[~mask_valid] = self.mask_value
                task_dict[task] = img.unsqueeze(0)
            elif task == 'rgb':
                task_dict[task] = task_dict[task].to(torch.float)

        return task_dict


def build_regression_dataset(args, data_path, transform, max_images=None):
    transform = DataAugmentationForRegression(transform=transform)

    return MultiTaskImageFolder(data_path, args.all_domains, transform=transform,
                                prefixes=None, max_images=max_images)
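

# --------------------------------------------------------
# Illustrative sketch (assumption, not original training code):
# wiring the dataset into a DataLoader. The `args` fields, the
# data path, and the on-disk folder layout expected by
# MultiTaskImageFolder are assumptions here.
# --------------------------------------------------------
def _example_build_loader(data_path='/path/to/nyu/train'):
    """Minimal sketch: build the regression dataset and fetch one batch."""
    from types import SimpleNamespace
    from torch.utils.data import DataLoader

    # `all_domains` lists the tasks (sub-folders) to load per sample
    args = SimpleNamespace(all_domains=['rgb', 'depth', 'mask_valid'])
    transform = nyu_transform(
        train=True,
        additional_targets={'depth': 'mask', 'mask_valid': 'mask'})
    dataset = build_regression_dataset(args, data_path, transform)
    loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
    batch = next(iter(loader))
    return batch['rgb'], batch['depth']  # (B, 3, H, W), (B, 1, H, W)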