# Mugs / src/multicropdataset.py
# Uploaded by zhoupans ("Upload 13 files", commit 3c849be)
# Copyright 2022 Garena Online Private Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Multi-crop data augmentation (weak and strong views) together with the
dataset / data-loader / GPU prefetcher utilities built on top of it.
"""
import copy
import random
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageFilter, ImageOps
from src.dataset import ImageFolder
from src.RandAugment import rand_augment_transform
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data.random_erasing import RandomErasing
from timm.data.transforms import _pil_interp
class GaussianBlur(object):
    """
    Randomly apply a Gaussian blur to a PIL image.

    Args:
        p: probability of blurring the image; p=0.0 never blurs and p=1.0
            always blurs.
        radius_min: lower bound of the blur radius.
        radius_max: upper bound of the blur radius; the radius is drawn
            uniformly from [radius_min, radius_max] on every blurred call.
    """

    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.0):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        # Strict `<` so that p=0.0 truly disables the blur (random.random()
        # may return exactly 0.0); since random.random() < 1.0 always holds,
        # p=1.0 still blurs every image. Same convention as Solarization.
        if random.random() >= self.prob:
            return img
        radius = random.uniform(self.radius_min, self.radius_max)
        return img.filter(ImageFilter.GaussianBlur(radius=radius))
class Solarization(object):
    """
    Randomly solarize a PIL image.

    Args:
        p: probability of applying ``ImageOps.solarize`` (called with its
            default threshold); otherwise the image is returned unchanged.
    """

    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        should_solarize = random.random() < self.p
        return ImageOps.solarize(img) if should_solarize else img
def strong_transforms(
    img_size=224,
    scale=(0.08, 1.0),
    ratio=(0.75, 1.3333333333333333),
    hflip=0.5,
    vflip=0.0,
    color_jitter=0.4,
    auto_augment="rand-m9-mstd0.5-inc1",
    interpolation="random",
    use_prefetcher=True,
    mean=IMAGENET_DEFAULT_MEAN,  # (0.485, 0.456, 0.406)
    std=IMAGENET_DEFAULT_STD,  # (0.229, 0.224, 0.225)
    re_prob=0.25,
    re_mode="pixel",
    re_count=1,
    re_num_splits=0,
    color_aug=False,
    strong_ratio=0.45,
):
    """
    Build the "strong" augmentation pipeline used for multi-crop views.

    The returned ``transforms.Compose`` applies three stages in order:
      1. primary: random horizontal / vertical flips (added only when
         ``hflip`` / ``vflip`` are > 0);
      2. secondary: RandAugment built from ``auto_augment`` (only when the
         policy string starts with "rand"), optionally followed by color
         jitter + random grayscale when ``color_aug`` is true;
      3. final: ``RandomResizedCrop`` to ``img_size`` (always bicubic),
         ``ToTensor`` (plus ``Normalize`` when ``use_prefetcher`` is false)
         and, when ``re_prob > 0``, timm ``RandomErasing``.

    Args:
        img_size: output crop size; if a tuple, its minimum is used to size
            RandAugment's maximum translation.
        scale: RandomResizedCrop scale range; falsy -> (0.08, 1.0).
        ratio: RandomResizedCrop aspect-ratio range; falsy -> (3/4, 4/3).
        hflip: probability of a horizontal flip.
        vflip: probability of a vertical flip.
        color_jitter: only checked against ``None``; the jitter strengths are
            fixed (brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1).
        auto_augment: RandAugment policy string; strings not starting with
            "rand" add no RandAugment op.
        interpolation: interpolation handed to RandAugment unless it is
            "random" or falsy; it does not affect the final crop.
        use_prefetcher: if true, skip ``Normalize`` (normalization is then
            expected to happen later, e.g. in ``data_prefetcher``).
        mean, std: normalization statistics; ``mean`` also sets
            RandAugment's fill color.
        re_prob, re_mode, re_count, re_num_splits: timm RandomErasing options.
        color_aug: append color jitter / grayscale after RandAugment.
        strong_ratio: fraction of the (minimum) image size used as
            RandAugment's ``translate_const``.

    Returns:
        A ``transforms.Compose`` over the three stages above.
    """
    scale = tuple(scale or (0.08, 1.0))  # default ImageNet scale range
    ratio = tuple(ratio or (3.0 / 4.0, 4.0 / 3.0))  # default ImageNet ratio range

    # ---- stage 1: flips -------------------------------------------------
    primary_tfl = []
    if hflip > 0.0:
        primary_tfl.append(transforms.RandomHorizontalFlip(p=hflip))
    if vflip > 0.0:
        primary_tfl.append(transforms.RandomVerticalFlip(p=vflip))

    # ---- stage 2: RandAugment (+ optional color augmentation) -----------
    secondary_tfl = []
    if auto_augment:
        assert isinstance(auto_augment, str)
        img_size_min = min(img_size) if isinstance(img_size, tuple) else img_size
        aa_params = dict(
            translate_const=int(img_size_min * strong_ratio),
            img_mean=tuple(min(255, round(255 * x)) for x in mean),
        )
        if interpolation and interpolation != "random":
            aa_params["interpolation"] = _pil_interp(interpolation)
        if auto_augment.startswith("rand"):
            secondary_tfl.append(rand_augment_transform(auto_augment, aa_params))
    if color_jitter is not None and color_aug:
        # Extra color augmentation applied *after* RandAugment, not instead of it.
        secondary_tfl += [
            transforms.RandomApply(
                [
                    transforms.ColorJitter(
                        brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1
                    )
                ],
                p=0.8,
            ),
            transforms.RandomGrayscale(p=0.2),
        ]

    # ---- stage 3: crop, tensor conversion, (normalize), random erasing --
    # `interpolation` only configures RandAugment above; the crop is fixed to
    # bicubic because torchvision's RandomResizedCrop takes a single mode.
    final_tfl = [
        transforms.RandomResizedCrop(
            size=img_size, scale=scale, ratio=ratio, interpolation=Image.BICUBIC
        ),
        transforms.ToTensor(),
    ]
    if not use_prefetcher:
        # Without the prefetcher nothing normalizes later, so do it here.
        final_tfl.append(
            transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))
        )
    if re_prob > 0.0:
        final_tfl.append(
            RandomErasing(
                re_prob,
                mode=re_mode,
                max_count=re_count,
                num_splits=re_num_splits,
                device="cpu",
            )
        )
    return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
class DataAugmentation(object):
    """
    Multi-crop data augmentation.

    ``__call__`` returns ``(crops, weak_flag)`` where ``crops`` holds:
      * 2 weakly augmented 224-sized global crops (teacher views);
      * 2 more 224-sized global crops (student views), either copies of the
        teacher views, freshly weak-augmented, or strong-augmented;
      * ``local_crops_number`` small local crops, each weak or strong.
    ``weak_flag`` is True when the student's global crops are weakly
    augmented.

    Args:
        global_crops_scale: RandomResizedCrop scale range for the 224-sized
            global crops.
        local_crops_scale: RandomResizedCrop scale range for the small local
            crops.
        local_crops_number: number of small local views to generate.
        prob: probability of using the *weak* augmentation for the student's
            global pair and, independently, for each local crop; strong
            augmentation is used with probability ``1 - prob``. ``prob >= 1.0``
            disables strong augmentation entirely.
        vanilla_weak_augmentation: if true, the student's global crops are
            deep copies of the teacher's weakly augmented crops (DINO-style,
            weak augmentation only). Local crops still follow ``prob``.
        color_aug: after RandAugment, additionally apply color jitter /
            random grayscale in the strong pipelines.
        local_crop_size: sequence whose first element is the local crop size;
            only 96 is currently accepted.
        timm_auto_augment_par: RandAugment policy string (as used in DeiT).
        strong_ratio: fraction of the crop size used as RandAugment's
            maximum translation.
        re_prob: random-erasing probability of the strong pipelines.
        use_prefetcher: if true, skip normalization here (``data_prefetcher``
            normalizes on the GPU instead), which can speed up training.
    """

    def __init__(
        self,
        global_crops_scale,
        local_crops_scale,
        local_crops_number,
        prob=0.5,
        vanilla_weak_augmentation=False,
        color_aug=False,
        local_crop_size=(96,),
        timm_auto_augment_par="rand-m9-mstd0.5-inc1",
        strong_ratio=0.45,
        re_prob=0.25,
        use_prefetcher=False,
    ):
        # Probability of the weak augmentation (strong is used with 1 - prob).
        self.prob = prob
        # Whether to use only the common weak augmentations (e.g. DINO / MoCo-v3).
        self.vanilla_weak_augmentation = vanilla_weak_augmentation

        flip_and_color_jitter = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomApply(
                    [
                        transforms.ColorJitter(
                            brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1
                        )
                    ],
                    p=0.8,
                ),
                transforms.RandomGrayscale(p=0.2),
            ]
        )
        # With the prefetcher only ToTensor is applied; normalization happens
        # later on the GPU.
        to_tensor_ops = [transforms.ToTensor()]
        if not use_prefetcher:
            to_tensor_ops.append(
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
            )
        normalize = transforms.Compose(to_tensor_ops)

        # Options shared by the global and local strong pipelines.
        strong_kwargs = dict(
            ratio=(0.75, 1.3333333333333333),
            hflip=0.5,
            vflip=0.0,
            color_jitter=0.4,
            auto_augment=timm_auto_augment_par,
            interpolation="random",
            use_prefetcher=use_prefetcher,
            mean=IMAGENET_DEFAULT_MEAN,
            std=IMAGENET_DEFAULT_STD,
            re_prob=re_prob,
            re_mode="pixel",
            re_count=1,
            re_num_splits=0,
            color_aug=color_aug,
            strong_ratio=strong_ratio,
        )

        # ====== global crops (224-sized) ======
        # first weak global crop: blur always applied
        self.global_transfo1 = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    224, scale=global_crops_scale, interpolation=Image.BICUBIC
                ),
                flip_and_color_jitter,
                GaussianBlur(1.0),
                normalize,
            ]
        )
        # second weak global crop: rare blur plus occasional solarization
        self.global_transfo2 = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    224, scale=global_crops_scale, interpolation=Image.BICUBIC
                ),
                flip_and_color_jitter,
                GaussianBlur(0.1),
                Solarization(0.2),
                normalize,
            ]
        )
        # strong global crop, used for the student when strong aug is drawn
        self.global_transfo3 = strong_transforms(
            img_size=224, scale=global_crops_scale, **strong_kwargs
        )

        # ====== local crops (small, e.g. 96-sized) ======
        self.local_crops_number = local_crops_number
        # NOTE(review): only 96 is accepted; other code presumably relies on
        # this size — verify before relaxing the check.
        assert local_crop_size[0] == 96
        # weak local crop
        self.local_transfo = transforms.Compose(
            [
                transforms.RandomResizedCrop(
                    local_crop_size[0],
                    scale=local_crops_scale,
                    interpolation=Image.BICUBIC,
                ),
                flip_and_color_jitter,
                GaussianBlur(p=0.5),
                normalize,
            ]
        )
        # strong local crop
        self.local_transfo2 = strong_transforms(
            img_size=local_crop_size[0], scale=local_crops_scale, **strong_kwargs
        )

    def _use_strong_augmentation(self):
        """Return True with probability ``1 - self.prob`` (never if prob >= 1)."""
        # Short-circuit so no random number is drawn when strong aug is disabled.
        return self.prob < 1.0 and random.random() > self.prob

    def __call__(self, image):
        """
        Generate 2 global crops for the teacher, 2 global crops for the
        student and ``local_crops_number`` local crops.

        Returns:
            (crops, weak_flag): the list of augmented crops and whether the
            student's global crops were weakly augmented. Per the original
            design notes, only weakly augmented 224-sized images are used to
            update the memory for local-group aggregation.
        """
        crops = []
        # ====== teacher views: two weakly augmented global crops ======
        img1 = self.global_transfo1(image)
        img2 = self.global_transfo2(image)
        crops.append(img1)
        crops.append(img2)

        # ====== student views: two global crops ======
        weak_flag = False
        if self.vanilla_weak_augmentation is True:
            # reuse the teacher's weak crops
            crops.append(copy.deepcopy(img1))
            crops.append(copy.deepcopy(img2))
            weak_flag = True
        elif self._use_strong_augmentation():
            crops.append(self.global_transfo3(image))
            crops.append(self.global_transfo3(image))
        else:
            crops.append(self.global_transfo1(image))
            crops.append(self.global_transfo2(image))
            weak_flag = True

        # ====== student views: local crops, weak/strong drawn per crop ======
        for _ in range(self.local_crops_number):
            if self._use_strong_augmentation():
                crops.append(self.local_transfo2(image))
            else:
                crops.append(self.local_transfo(image))
        return crops, weak_flag
def get_dataset(args):
    """
    Build the multi-crop augmentation, the training ImageFolder that applies
    it, and a distributed DataLoader over that dataset.

    Args:
        args: parsed options. Reads global_crops_scale, local_crops_scale,
            local_crops_number, vanilla_weak_augmentation, prob, color_aug,
            size_crops, timm_auto_augment_par, strong_ratio, re_prob,
            use_prefetcher, debug, data_path, batch_size_per_gpu and
            num_workers.

    Returns:
        A ``torch.utils.data.DataLoader`` using a shuffling
        ``DistributedSampler``, pinned memory and ``drop_last=True``.
        Batches presumably unpack to ``(multi_crops, weak_flag)``, as
        ``data_prefetcher`` expects — verify against ImageFolder.
    """
    # weak + strong multi-crop augmentation
    augmentation = DataAugmentation(
        global_crops_scale=args.global_crops_scale,
        local_crops_scale=args.local_crops_scale,
        local_crops_number=args.local_crops_number,
        vanilla_weak_augmentation=args.vanilla_weak_augmentation,
        prob=args.prob,
        color_aug=args.color_aug,
        local_crop_size=args.size_crops,
        timm_auto_augment_par=args.timm_auto_augment_par,
        strong_ratio=args.strong_ratio,
        re_prob=args.re_prob,
        use_prefetcher=args.use_prefetcher,
    )

    # In debug mode only the first two classes are loaded to cut reading
    # time; otherwise all 1000 classes are used for pretraining.
    if args.debug:
        num_classes = 2
    else:
        num_classes = 1000
    train_set = ImageFolder(
        args.data_path, transform=augmentation, class_num=num_classes
    )

    distributed_sampler = torch.utils.data.DistributedSampler(train_set, shuffle=True)
    loader_options = {
        "batch_size": args.batch_size_per_gpu,
        "num_workers": args.num_workers,
        "pin_memory": True,
        "drop_last": True,
    }
    return torch.utils.data.DataLoader(
        train_set, sampler=distributed_sampler, **loader_options
    )
class data_prefetcher:
    """
    Prefetch batches onto the GPU and normalize them there.

    While the caller works on the current batch, the next batch is copied to
    the GPU on a side CUDA stream, converted to fp16/fp32 and normalized with
    the ImageNet mean/std, so this work is done on the GPU instead of the CPU.

    Args:
        loader: iterable whose items unpack to ``(multi_crops, weak_flag)``;
            ``multi_crops`` is a sequence of image tensors (assumed NCHW with
            3 channels, un-normalized ToTensor output — TODO confirm).
        fp16: if true, convert crops (and the mean/std) to half precision.

    Call ``next()`` repeatedly; it returns ``(None, None)`` once the loader
    is exhausted.
    """

    def __init__(self, loader, fp16=True):
        self.loader = iter(loader)
        self.fp16 = fp16
        self.stream = torch.cuda.Stream()
        # ImageNet statistics shaped (1, 3, 1, 1) to broadcast over NCHW batches.
        self.mean = torch.tensor([0.485, 0.456, 0.406]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).cuda().view(1, 3, 1, 1)
        if fp16:
            self.mean = self.mean.half()
            self.std = self.std.half()
        self.preload()

    def _to_gpu_normalized(self, crop):
        """Copy one crop to the GPU, cast it and normalize it in place."""
        crop = crop.cuda(non_blocking=True)
        crop = crop.half() if self.fp16 else crop.float()
        return crop.sub_(self.mean).div_(self.std)

    def preload(self):
        """
        Fetch the next minibatch and start its GPU transfer/normalization on
        the side stream; sets both attributes to None when exhausted.
        """
        try:
            self.multi_crops, self.weak_flag = next(self.loader)
        except StopIteration:
            self.multi_crops, self.weak_flag = None, None
            return
        with torch.cuda.stream(self.stream):
            self.multi_crops = [
                self._to_gpu_normalized(crop) for crop in self.multi_crops
            ]

    def next(self):
        """
        Return the prefetched ``(multi_crops, weak_flag)`` and start
        prefetching the following minibatch.
        """
        current = torch.cuda.current_stream()
        # Make the consumer stream wait until the side-stream copies finish.
        current.wait_stream(self.stream)
        multi_crops, weak_flags = self.multi_crops, self.weak_flag
        if multi_crops is not None:
            # These tensors were allocated on self.stream but are used on the
            # current stream; record that use so the caching allocator does
            # not recycle their memory for self.stream while still in use.
            for crop in multi_crops:
                crop.record_stream(current)
        self.preload()
        return multi_crops, weak_flags