import pytorch_lightning as L
from torch.utils.data import DataLoader, random_split
import torch
import time


class ImageDataModule(L.LightningDataModule):
    """DataModule that builds image datasets lazily and derives the
    per-device batch size from a global batch size."""

    def __init__(
        self,
        train_dataset,
        val_dataset,
        test_dataset,
        global_batch_size,
        num_workers,
        num_nodes=1,
        num_devices=1,
        val_proportion=0.1,
    ):
        super().__init__()
        # The datasets are passed in as callables (builders) and are only
        # instantiated in setup(), so constructing the module stays cheap.
        self._builders = {
            "train": train_dataset,
            "val": val_dataset,
            "test": test_dataset,
        }
        self.num_workers = num_workers
        # Split the global batch size evenly across all devices.
        self.batch_size = global_batch_size // (num_nodes * num_devices)
        print(f"Each GPU will receive {self.batch_size} images")
        self.val_proportion = val_proportion

    def num_classes(self):
        # Prefer the already-built train dataset; otherwise build one on demand.
        if hasattr(self, "train_dataset"):
            return self.train_dataset.num_classes
        return self._builders["train"]().num_classes

    def setup(self, stage=None):
        """Set up the datamodule.

        Args:
            stage (str): Stage of the datamodule.
                One of "fit", "test", or None.
        """
        print("Stage", stage)
        start_time = time.time()
        if stage == "fit" or stage is None:
            self.train_dataset = self._builders["train"]()
            self.val_dataset = self._builders["val"]()
            print(f"Train dataset size: {len(self.train_dataset)}")
            print(f"Val dataset size: {len(self.val_dataset)}")
        else:
            self.test_dataset = self._builders["test"]()
            print(f"Test dataset size: {len(self.test_dataset)}")
        end_time = time.time()
        print(f"Setup took {(end_time - start_time):.2f} seconds")

    def train_dataloader(self):
        # Training batches use the dataset's own collate_fn_density and drop
        # the last incomplete batch.
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=False,
            drop_last=True,
            num_workers=self.num_workers,
            collate_fn=self.train_dataset.collate_fn_density,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=False,
            num_workers=self.num_workers,
            collate_fn=self.val_dataset.collate_fn,
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            pin_memory=False,
            num_workers=self.num_workers,
            collate_fn=self.test_dataset.collate_fn,
        )
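

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module above). The builder callables
# `build_train_set`, `build_val_set`, `build_test_set` and the LightningModule
# `MyModel` are hypothetical placeholders; each built dataset is assumed to
# expose `num_classes`, `collate_fn`, and (for training) `collate_fn_density`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    datamodule = ImageDataModule(
        train_dataset=build_train_set,   # callables, invoked lazily in setup()
        val_dataset=build_val_set,
        test_dataset=build_test_set,
        global_batch_size=256,           # split across num_nodes * num_devices
        num_workers=8,
        num_nodes=1,
        num_devices=4,                   # -> per-GPU batch size of 64
    )

    model = MyModel(num_classes=datamodule.num_classes())

    trainer = L.Trainer(accelerator="gpu", devices=4, num_nodes=1, max_epochs=10)
    trainer.fit(model, datamodule=datamodule)   # runs setup("fit")
    trainer.test(model, datamodule=datamodule)  # runs setup("test")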