MultiMAE / utils /datasets.py
Bachmann Roman Christian
Initial commit
3b49518
raw
history blame contribute delete
No virus
7.97 kB
# Copyright (c) EPFL VILAB.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Based on BEiT, timm, DINO, DeiT and MAE-priv code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# https://github.com/BUPT-PRIV/MAE-priv
# --------------------------------------------------------
import os
import random
import numpy as np
import torch
import torchvision.transforms.functional as TF
from torchvision import datasets, transforms
from utils import create_transform
from .data_constants import (IMAGE_TASKS, IMAGENET_DEFAULT_MEAN,
IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN,
IMAGENET_INCEPTION_STD)
from .dataset_folder import ImageFolder, MultiTaskImageFolder
def denormalize(img, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD):
    """Invert a per-channel ``(x - mean) / std`` normalization.

    The inverse is expressed as another call to ``TF.normalize``: using
    ``-m / s`` as the mean and ``1 / s`` as the std turns the normalized
    value back into ``x * s + m``.

    Args:
        img: Normalized image tensor. A clone is passed on, so the caller's
            tensor is left as it was.
        mean: Per-channel means that were used for the forward normalization.
        std: Per-channel standard deviations used for the forward normalization.

    Returns:
        Whatever ``TF.normalize`` returns for the cloned input.
    """
    working_copy = img.clone()
    inverse_mean = [-m / s for m, s in zip(mean, std)]
    inverse_std = [1 / s for s in std]
    return TF.normalize(working_copy, mean=inverse_mean, std=inverse_std)
class DataAugmentationForMAE(object):
    """Single-image augmentation pipeline used for MAE pre-training.

    The pipeline is, in order: ``RandomResizedCrop(args.input_size)``, an
    optional ``RandomHorizontalFlip(args.hflip)`` (only added when
    ``args.hflip > 0``), ``ToTensor`` and ``Normalize``.

    Args:
        args: Namespace-like object providing ``input_size``, ``hflip`` and
            ``imagenet_default_mean_and_std``. When the latter is truthy the
            ImageNet default statistics are used, otherwise the Inception ones.
    """

    def __init__(self, args):
        if args.imagenet_default_mean_and_std:
            mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
        else:
            mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

        trans = [transforms.RandomResizedCrop(args.input_size)]
        if args.hflip > 0.0:
            trans.append(transforms.RandomHorizontalFlip(args.hflip))
        trans.extend([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std),
            ),
        ])
        self.transform = transforms.Compose(trans)

    def __call__(self, image):
        """Apply the composed transform to ``image`` and return the result."""
        return self.transform(image)

    def __repr__(self):
        # The header previously read "DataAugmentationForBEiT" (left over from
        # the BEiT code base); it now names this class. A local called `repr`
        # is also avoided so the builtin is not shadowed.
        return "(DataAugmentationForMAE,\n transform = %s,\n)" % str(self.transform)
class DataAugmentationForMultiMAE(object):
    """Joint augmentation for a dict of per-task images.

    One random resized crop and one horizontal-flip decision are drawn per
    call and applied to every entry whose key is in ``IMAGE_TASKS``, so all
    modalities stay spatially aligned. Afterwards the ``depth``, ``rgb``,
    ``semseg`` and ``semseg_coco`` entries are converted to tensors.

    Args:
        args: Namespace-like object providing ``input_size``, ``hflip`` and
            ``imagenet_default_mean_and_std``. When the latter is truthy the
            ImageNet default statistics are used for RGB, otherwise the
            Inception ones.
    """

    def __init__(self, args):
        if args.imagenet_default_mean_and_std:
            self.rgb_mean = IMAGENET_DEFAULT_MEAN
            self.rgb_std = IMAGENET_DEFAULT_STD
        else:
            self.rgb_mean = IMAGENET_INCEPTION_MEAN
            self.rgb_std = IMAGENET_INCEPTION_STD
        self.input_size = args.input_size
        self.hflip = args.hflip

    def __call__(self, task_dict):
        """Augment ``task_dict`` in place and return it.

        Args:
            task_dict: Mapping from task name to image. The image entries are
                used through ``.resize`` and ``TF.pil_to_tensor``, so they are
                presumably PIL images -- confirm against the dataset loader.

        Returns:
            The same dict object, with the handled entries replaced by
            tensors. Entries whose key has no conversion rule are left as
            they were after the crop/flip stage.
        """
        flip = random.random() < self.hflip  # Stores whether to flip all images or not
        ijhw = None  # Stores crop coordinates used for all tasks

        # Crop and flip all tasks randomly, but consistently for all tasks
        for task in task_dict:
            if task not in IMAGE_TASKS:
                continue
            if ijhw is None:
                # Official MAE code uses (0.2, 1.0) for scale and (0.75, 1.3333) for ratio
                ijhw = transforms.RandomResizedCrop.get_params(
                    task_dict[task], scale=(0.2, 1.0), ratio=(0.75, 1.3333)
                )
            i, j, h, w = ijhw
            task_dict[task] = TF.crop(task_dict[task], i, j, h, w)
            # NOTE(review): no resampling filter is passed, so the library
            # default is used for every task, including label maps -- confirm
            # that this is intended for semantic segmentation inputs.
            task_dict[task] = task_dict[task].resize((self.input_size, self.input_size))
            if flip:
                task_dict[task] = TF.hflip(task_dict[task])

        # Convert to Tensor
        for task in task_dict:
            if task == 'depth':
                # Scale by 2 ** 16 (presumably 16-bit depth maps -- confirm).
                img = torch.Tensor(np.array(task_dict[task]) / 2 ** 16)
                img = img.unsqueeze(0)  # 1 x H x W
            elif task == 'rgb':
                img = TF.to_tensor(task_dict[task])
                img = TF.normalize(img, mean=self.rgb_mean, std=self.rgb_std)
            elif task in ('semseg', 'semseg_coco'):
                # TODO: add this to a config instead
                # Rescale to 0.25x size (stride 4)
                scale_factor = 0.25
                target = int(self.input_size * scale_factor)
                img = task_dict[task].resize((target, target))
                # Using pil_to_tensor keeps it in uint8, to_tensor converts it to float (rescaled to [0, 1])
                img = TF.pil_to_tensor(img).to(torch.long).squeeze(0)
            else:
                # No conversion rule for this key. Previously the assignment
                # below still ran, which either raised UnboundLocalError or
                # stored the previous task's tensor under this key. Leave the
                # entry untouched instead.
                continue
            task_dict[task] = img

        return task_dict

    def __repr__(self):
        return "(DataAugmentationForMultiMAE,\n)"
def build_pretraining_dataset(args):
    """Create the single-modality pre-training dataset.

    Builds a ``DataAugmentationForMAE`` from ``args``, prints its string
    form, and wraps ``args.data_path`` in an ``ImageFolder`` that uses it as
    the transform.

    Args:
        args: Namespace-like object; ``data_path`` is read here and the
            whole object is forwarded to ``DataAugmentationForMAE``.

    Returns:
        The constructed ``ImageFolder`` dataset.
    """
    augmentation = DataAugmentationForMAE(args)
    print("Data Aug = {}".format(str(augmentation)))
    dataset = ImageFolder(args.data_path, transform=augmentation)
    return dataset
def build_multimae_pretraining_dataset(args):
    """Create the multi-task pre-training dataset.

    Args:
        args: Namespace-like object; ``data_path`` and ``all_domains`` are
            read here and the whole object is forwarded to
            ``DataAugmentationForMultiMAE``.

    Returns:
        A ``MultiTaskImageFolder`` over ``args.data_path`` for the tasks in
        ``args.all_domains``, using the joint augmentation as its transform.
    """
    joint_augmentation = DataAugmentationForMultiMAE(args)
    dataset = MultiTaskImageFolder(
        args.data_path, args.all_domains, transform=joint_augmentation
    )
    return dataset
def build_dataset(is_train, args):
    """Build a classification dataset together with its class count.

    The transform comes from ``build_transform(is_train, args)`` and its
    individual steps are printed before the dataset is created.

    Args:
        is_train: Selects the training or evaluation transform and, for
            ``'IMNET'`` and ``'image_folder'``, whether ``args.data_path`` or
            ``args.eval_data_path`` is used as the root. For ``'CIFAR'`` it is
            passed on as the ``train`` flag.
        args: Namespace-like object providing ``data_set`` (one of
            ``'CIFAR'``, ``'IMNET'``, ``'image_folder'``), ``data_path``,
            ``eval_data_path`` and ``nb_classes``, plus whatever
            ``build_transform`` reads.

    Returns:
        Tuple ``(dataset, nb_classes)``.

    Raises:
        NotImplementedError: If ``args.data_set`` is not a supported name.
        AssertionError: If the class count does not match ``args.nb_classes``.
    """
    transform = build_transform(is_train, args)

    print("Transform = ")
    if isinstance(transform, tuple):
        for trans in transform:
            print(" - - - - - - - - - - ")
            for t in trans.transforms:
                print(t)
    else:
        for t in transform.transforms:
            print(t)
    print("---------------------------")

    if args.data_set == 'CIFAR':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        root = args.data_path if is_train else args.eval_data_path
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    elif args.data_set == "image_folder":
        root = args.data_path if is_train else args.eval_data_path
        dataset = ImageFolder(root, transform=transform)
        nb_classes = args.nb_classes
        assert len(dataset.class_to_idx) == nb_classes, (
            "Found %d classes under %r but args.nb_classes is %r"
            % (len(dataset.class_to_idx), root, nb_classes)
        )
    else:
        # Same exception type as before, now naming the rejected value.
        raise NotImplementedError(
            "Unsupported data_set %r (expected 'CIFAR', 'IMNET' or 'image_folder')"
            % (args.data_set,)
        )

    assert nb_classes == args.nb_classes, (
        "Dataset %r has %d classes but args.nb_classes is %r"
        % (args.data_set, nb_classes, args.nb_classes)
    )
    print("Number of the class = %d" % args.nb_classes)

    return dataset, nb_classes
def build_transform(is_train, args):
    """Build the training or evaluation transform for classification.

    Training uses ``create_transform`` with the augmentation settings from
    ``args``; for inputs of size 32 or smaller its first step is replaced by
    a padded ``RandomCrop``. Evaluation uses an optional resize + center
    crop followed by ``ToTensor`` and ``Normalize``.

    Args:
        is_train: Whether to build the training transform.
        args: Namespace-like object providing ``input_size`` and
            ``imagenet_default_mean_and_std``; for training also
            ``color_jitter``, ``aa``, ``train_interpolation``, ``reprob``,
            ``remode`` and ``recount``; for evaluation also ``crop_pct``.
            When ``crop_pct`` is ``None`` on the evaluation path with a
            resize, it is filled in on ``args`` itself.

    Returns:
        The transform object from ``create_transform`` for training, or a
        ``transforms.Compose`` for evaluation.
    """
    resize_im = args.input_size > 32
    if args.imagenet_default_mean_and_std:
        mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    else:
        mean, std = IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

    if not is_train:
        eval_steps = []
        if resize_im:
            if args.crop_pct is None:
                # Default crop fraction depends on the input resolution.
                args.crop_pct = 224 / 256 if args.input_size < 384 else 1.0
            size = int(args.input_size / args.crop_pct)
            # to maintain same ratio w.r.t. 224 images
            eval_steps.append(transforms.Resize(size, interpolation=3))
            eval_steps.append(transforms.CenterCrop(args.input_size))
        eval_steps.append(transforms.ToTensor())
        eval_steps.append(transforms.Normalize(mean, std))
        return transforms.Compose(eval_steps)

    # this should always dispatch to transforms_imagenet_train
    train_kwargs = dict(
        input_size=args.input_size,
        is_training=True,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation=args.train_interpolation,
        re_prob=args.reprob,
        re_mode=args.remode,
        re_count=args.recount,
        mean=mean,
        std=std,
    )
    transform = create_transform(**train_kwargs)
    if not resize_im:
        # replace RandomResizedCropAndInterpolation with RandomCrop
        transform.transforms[0] = transforms.RandomCrop(args.input_size, padding=4)
    return transform