# guesstimatelocation/data/datamodule.py
# Uploaded by yunusserhat ("Upload 142 files", commit abd15df, verified).
import pytorch_lightning as L
from torch.utils.data import DataLoader, random_split
import torch
import time
class ImageDataModule(L.LightningDataModule):
    """Lightning datamodule that builds its train/val/test datasets lazily.

    The datasets are supplied as zero-argument builders (callables) and are
    only instantiated in :meth:`setup` (or, for the train split, on demand in
    :attr:`num_classes`).

    Args:
        train_dataset: callable returning the training dataset. The returned
            object must support ``len()`` and expose ``collate_fn_density``
            and ``num_classes``.
        val_dataset: callable returning the validation dataset. The returned
            object must support ``len()`` and expose ``collate_fn``.
        test_dataset: callable returning the test dataset. The returned
            object must support ``len()`` and expose ``collate_fn``.
        global_batch_size: total batch size summed over all nodes and devices.
        num_workers: number of ``DataLoader`` worker processes.
        num_nodes: number of nodes.
        num_devices: number of devices per node.
        val_proportion: stored on the instance but not used by this class.

    Raises:
        ValueError: if ``num_nodes * num_devices`` is not positive, or if
            ``global_batch_size`` is too small to give every device at least
            one sample.
    """

    def __init__(
        self,
        train_dataset,
        val_dataset,
        test_dataset,
        global_batch_size,
        num_workers,
        num_nodes=1,
        num_devices=1,
        val_proportion=0.1,
    ):
        super().__init__()
        # Builders are only called in setup(), so constructing the module is cheap.
        self._builders = {
            "train": train_dataset,
            "val": val_dataset,
            "test": test_dataset,
        }
        self.num_workers = num_workers

        world_size = num_nodes * num_devices
        if world_size < 1:
            raise ValueError(
                f"num_nodes * num_devices must be positive, got {world_size}"
            )
        # The global batch is split evenly across every device; any remainder
        # from the floor division is dropped.
        self.batch_size = global_batch_size // world_size
        if self.batch_size < 1:
            # DataLoader would otherwise reject batch_size=0 much later with
            # a less helpful message.
            raise ValueError(
                f"global_batch_size={global_batch_size} is smaller than the "
                f"number of devices ({world_size}); each device would receive "
                "an empty batch"
            )
        print(f"Each GPU will receive {self.batch_size} images")
        # Not read anywhere in this class.
        self.val_proportion = val_proportion

    @property
    def num_classes(self):
        """Number of classes reported by the training dataset.

        Uses the dataset created by :meth:`setup` when it exists; otherwise
        calls the train builder just to read the attribute (the freshly built
        dataset is not kept, so this may be expensive).
        """
        if hasattr(self, "train_dataset"):
            return self.train_dataset.num_classes
        return self._builders["train"]().num_classes

    def setup(self, stage=None):
        """Instantiate the datasets needed for ``stage``.

        Args:
            stage (str | None): Lightning stage.
                ``"fit"`` builds the train and val datasets,
                ``"validate"`` builds the val dataset,
                ``None`` builds all three datasets,
                and any other value (e.g. ``"test"`` or ``"predict"``) builds
                the test dataset.
        """
        print("Stage", stage)
        start_time = time.perf_counter()

        build_train = stage in ("fit", None)
        build_val = stage in ("fit", "validate", None)
        # Everything except fit/validate falls back to the test split, as the
        # original catch-all branch did; None additionally gets it so that
        # test_dataloader() works after a bare setup() call.
        build_test = stage not in ("fit", "validate")

        if build_train:
            self.train_dataset = self._builders["train"]()
            print(f"Train dataset size: {len(self.train_dataset)}")
        if build_val:
            self.val_dataset = self._builders["val"]()
            print(f"Val dataset size: {len(self.val_dataset)}")
        if build_test:
            self.test_dataset = self._builders["test"]()
            print(f"Test dataset size: {len(self.test_dataset)}")

        elapsed = time.perf_counter() - start_time
        print(f"Setup took {elapsed:.2f} seconds")

    def _make_dataloader(self, dataset, collate_fn, shuffle, drop_last=False):
        """Build a DataLoader using this module's per-device batch size and workers."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            pin_memory=False,
            drop_last=drop_last,
            num_workers=self.num_workers,
            collate_fn=collate_fn,
        )

    def train_dataloader(self):
        """Shuffled training loader; requires ``setup("fit")`` (or ``setup()``) first."""
        # drop_last=True so every training batch has exactly batch_size samples.
        # The train split uses a dedicated collate function (collate_fn_density)
        # rather than the plain collate_fn used for val/test.
        return self._make_dataloader(
            self.train_dataset,
            collate_fn=self.train_dataset.collate_fn_density,
            shuffle=True,
            drop_last=True,
        )

    def val_dataloader(self):
        """Unshuffled validation loader; requires a fit/validate setup first."""
        return self._make_dataloader(
            self.val_dataset,
            collate_fn=self.val_dataset.collate_fn,
            shuffle=False,
        )

    def test_dataloader(self):
        """Unshuffled test loader; requires ``setup("test")`` (or ``setup()``) first."""
        return self._make_dataloader(
            self.test_dataset,
            collate_fn=self.test_dataset.collate_fn,
            shuffle=False,
        )