# Adapted from https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/datamodules/imagenet_datamodule.py
import os
from pathlib import Path
from typing import Any, List, Union, Callable, Optional
import torch
from torch.utils.data import Dataset, DataLoader, SequentialSampler
from torch.utils.data.dataloader import default_collate
from torch.utils.data.distributed import DistributedSampler
from pytorch_lightning import LightningDataModule
from torchvision import transforms
from torchvision.datasets import ImageFolder
class DictDataset(Dataset):
def __init__(self, dataset_dict, length=None):
"""dataset_dict: dictionary mapping from index to batch
length is used in the case of DistributedSampler: e.g. the dataset could have size 1k, but
with 8 GPUs the dataset_dict would only have 125 items.
"""
super().__init__()
self.dataset_dict = dataset_dict
self.length = length or len(self.dataset_dict)
def __getitem__(self, index):
return self.dataset_dict[index]
def __len__(self):
return self.length
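# Illustrative sketch (not part of the pipeline below; names are hypothetical): DictDataset wraps
# samples that have already been materialized, keyed by their original dataset index, e.g.
#   cached = DictDataset({0: sample_0, 3: sample_3}, length=1000)
# Reporting `length=1000` keeps `len(cached)` equal to the full dataset size, so a DistributedSampler
# built on top of it partitions indices exactly as it would for the original dataset
# (see the val-set caching in ImagenetDataModule.val_dataloader below).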
# From https://github.com/PyTorchLightning/lightning-bolts/blob/2415b49a2b405693cd499e09162c89f807abbdc4/pl_bolts/transforms/dataset_normalizations.py#L10
def imagenet_normalization():
return transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
class ImagenetDataModule(LightningDataModule):
"""
.. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2017/08/
Sample-of-Images-from-the-ImageNet-Dataset-used-in-the-ILSVRC-Challenge.png
:width: 400
:alt: Imagenet
Specs:
- 1000 classes
- Each image is (3 x varies x varies) (here we default to 3 x 224 x 224)
Imagenet train, val and test dataloaders.
    The train set is the imagenet train split (the `train` folder).
    The val and test sets both use the official imagenet validation split (the `val` folder).
Example::
from pl_bolts.datamodules import ImagenetDataModule
dm = ImagenetDataModule(IMAGENET_PATH)
model = LitModel()
Trainer().fit(model, datamodule=dm)
"""
name = "imagenet"
def __init__(
self,
data_dir: str,
image_size: int = 224,
train_transforms=None,
val_transforms=None,
test_transforms=None,
        img_dtype='float32',  # Using str since OmegaConf doesn't support non-primitive types
cache_val_dataset=False,
mixup: Optional[Callable] = None,
num_aug_repeats: int = 0,
num_workers: int = 0,
batch_size: int = 32,
batch_size_eval: Optional[int] = None,
shuffle: bool = True,
pin_memory: bool = True,
drop_last: bool = False,
*args: Any,
**kwargs: Any,
) -> None:
"""
Args:
            data_dir: path to the imagenet dataset root (a folder containing `train` and `val`)
            image_size: final image size
            train_transforms: transforms for the train set; defaults to `train_transform()`
            val_transforms: transforms for the val set; defaults to `val_transform()`
            test_transforms: transforms for the test set; defaults to `val_transform()`
            img_dtype: image dtype as a string ('float32', 'float16' or 'bfloat16')
            cache_val_dataset: If true, materializes the transformed validation set in memory
                the first time the val dataloader is built
            mixup: optional callable applied to each collated train batch (inputs, targets)
            num_aug_repeats: if nonzero, use timm's RepeatAugSampler with this many repeats
                for the train dataloader
            num_workers: how many data workers
            batch_size: batch size for the train dataloader
            batch_size_eval: batch size for the val and test dataloaders; defaults to `batch_size`
            shuffle: If true, shuffles the train data every epoch
            pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before
                returning them
            drop_last: If true, drops the last incomplete batch
"""
super().__init__(*args, **kwargs)
self.image_size = image_size
self.train_transforms = train_transforms
self.val_transforms = val_transforms
self.test_transforms = test_transforms
assert img_dtype in ['float32', 'float16', 'bfloat16']
        self.img_dtype = getattr(torch, img_dtype)
self.cache_val_dataset = cache_val_dataset
self.mixup = mixup
self.num_aug_repeats = num_aug_repeats
self.dims = (3, self.image_size, self.image_size)
self.data_dir = Path(data_dir).expanduser()
self.num_workers = num_workers
self.batch_size = batch_size
self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size
self.shuffle = shuffle
self.pin_memory = pin_memory
self.drop_last = drop_last
@property
def num_classes(self) -> int:
"""
Return:
1000
"""
return 1000
def _verify_splits(self, data_dir: str, split: str) -> None:
dirs = os.listdir(data_dir)
if split not in dirs:
raise FileNotFoundError(
f"a {split} Imagenet split was not found in {data_dir},"
f" make sure the folder contains a subfolder named {split}"
)
def prepare_data(self) -> None:
"""This method already assumes you have imagenet2012 downloaded. It validates the data using the meta.bin.
.. warning:: Please download imagenet on your own first.
"""
self._verify_splits(self.data_dir, "train")
self._verify_splits(self.data_dir, "val")
def setup(self, stage: Optional[str] = None) -> None:
"""Creates train, val, and test dataset."""
if stage == "fit" or stage is None:
train_transforms = (self.train_transform() if self.train_transforms is None
else self.train_transforms)
val_transforms = (self.val_transform() if self.val_transforms is None
else self.val_transforms)
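            # For non-float32 image dtypes, append a cast to the Compose transforms so the
            # conversion happens inside the data pipeline rather than in the training step.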
if self.img_dtype is not torch.float32:
assert isinstance(train_transforms, transforms.Compose)
assert isinstance(val_transforms, transforms.Compose)
convert_dtype = transforms.Lambda(lambda x: x.to(dtype=self.img_dtype))
train_transforms.transforms.append(convert_dtype)
val_transforms.transforms.append(convert_dtype)
self.dataset_train = ImageFolder(self.data_dir / 'train', transform=train_transforms)
self.dataset_val = ImageFolder(self.data_dir / 'val', transform=val_transforms)
if stage == "test" or stage is None:
test_transforms = (self.val_transform() if self.test_transforms is None
else self.test_transforms)
if self.img_dtype is not torch.float32:
assert isinstance(test_transforms, transforms.Compose)
convert_dtype = transforms.Lambda(lambda x: x.to(dtype=self.img_dtype))
test_transforms.transforms.append(convert_dtype)
self.dataset_test = ImageFolder(self.data_dir / 'val', transform=test_transforms)
def train_transform(self) -> Callable:
"""The standard imagenet transforms.
.. code-block:: python
transforms.Compose([
transforms.RandomResizedCrop(self.image_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])
"""
preprocessing = transforms.Compose(
[
transforms.RandomResizedCrop(self.image_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
imagenet_normalization(),
]
)
return preprocessing
def val_transform(self) -> Callable:
"""The standard imagenet transforms for validation.
.. code-block:: python
transforms.Compose([
transforms.Resize(self.image_size + 32),
transforms.CenterCrop(self.image_size),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])
"""
preprocessing = transforms.Compose(
[
transforms.Resize(self.image_size + 32),
transforms.CenterCrop(self.image_size),
transforms.ToTensor(),
imagenet_normalization(),
]
)
return preprocessing
def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader:
""" The train dataloader """
if self.num_aug_repeats == 0:
shuffle = self.shuffle
sampler = None
else:
shuffle = False
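            # Repeated augmentation: timm's RepeatAugSampler yields each sample `num_aug_repeats`
            # times per epoch and handles (distributed) shuffling itself, so DataLoader-level
            # shuffling is disabled here.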
from timm.data.distributed_sampler import RepeatAugSampler
sampler = RepeatAugSampler(self.dataset_train, num_repeats=self.num_aug_repeats)
return self._data_loader(self.dataset_train, batch_size=self.batch_size,
shuffle=shuffle, mixup=self.mixup, sampler=sampler)
def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
""" The val dataloader """
# If using RepeatAugment, we set trainer.replace_sampler_ddp=False, so we have to
# construct the DistributedSampler ourselves.
if not self.cache_val_dataset:
sampler = (DistributedSampler(self.dataset_val, shuffle=False, drop_last=self.drop_last)
if self.num_aug_repeats != 0 else None)
return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval,
sampler=sampler)
else:
print('Caching val dataset')
sampler = (SequentialSampler(self.dataset_val) if self.trainer.world_size <= 1
else DistributedSampler(self.dataset_val, shuffle=False,
drop_last=self.drop_last))
indices = list(iter(sampler))
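            # With batch_size=None automatic batching is disabled, so the loader below yields one
            # fully transformed sample per sampled index; iterating it once materializes the cache.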
loader = DataLoader(self.dataset_val, batch_size=None, shuffle=False, sampler=sampler,
num_workers=self.num_workers, drop_last=self.drop_last)
batches = list(loader)
assert len(batches) == len(indices)
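            # Swap the on-disk val dataset for the in-memory cache keyed by the original indices;
            # keeping the full length lets a fresh DistributedSampler (same settings) partition the
            # indices identically, so this rank only looks up keys it has cached.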
self.dataset_val = DictDataset(dict(zip(indices, batches)),
length=len(self.dataset_val))
sampler = (DistributedSampler(self.dataset_val, shuffle=False, drop_last=self.drop_last)
if self.num_aug_repeats != 0 else None)
return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval,
sampler=sampler)
def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]:
""" The test dataloader """
sampler = (DistributedSampler(self.dataset_test, shuffle=False, drop_last=self.drop_last)
if self.num_aug_repeats != 0 else None)
return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval, sampler=sampler)
def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: bool = False,
mixup: Optional[Callable] = None, sampler=None) -> DataLoader:
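        # If a mixup/cutmix-style callable was provided, apply it to the collated
        # (inputs, targets) batch; otherwise fall back to the default collate function.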
collate_fn = ((lambda batch: mixup(*default_collate(batch))) if mixup is not None
else default_collate)
return DataLoader(
dataset,
collate_fn=collate_fn,
batch_size=batch_size,
shuffle=shuffle,
sampler=sampler,
num_workers=self.num_workers,
drop_last=self.drop_last,
pin_memory=self.pin_memory,
            persistent_workers=self.num_workers > 0,  # persistent workers require num_workers > 0
)
class Imagenet21kPDataModule(ImagenetDataModule):
"""ImageNet-21k (winter 21) processed with https://github.com/Alibaba-MIIL/ImageNet21K
"""
@property
def num_classes(self) -> int:
"""
Return:
10450
"""
return 10450
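# Minimal usage sketch (hypothetical paths; assumes the processed dataset is laid out as `train/`
# and `val/` ImageFolder directories, which is what ImagenetDataModule expects):
#
#   dm = Imagenet21kPDataModule('/path/to/imagenet21k_processed', img_dtype='bfloat16',
#                               batch_size=256, num_workers=8)
#   dm.prepare_data()
#   dm.setup('fit')
#   train_loader = dm.train_dataloader()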