lev1's picture
Initial commit
8fd2f2f
raw
history blame
2.97 kB
from typing import Optional
import torchdata.datapipes.iter
import webdataset as wds
from omegaconf import DictConfig
from pytorch_lightning import LightningDataModule
try:
from sdata import create_dataset, create_dummy_dataset, create_loader
except ImportError as e:
print("#" * 100)
print("Datasets not yet available")
print("to enable, we need to add stable-datasets as a submodule")
print("please use ``git submodule update --init --recursive``")
print("and do ``pip install -e stable-datasets/`` from the root of this repo")
print("#" * 100)
exit(1)
class StableDataModuleFromConfig(LightningDataModule):
def __init__(
self,
train: DictConfig,
validation: Optional[DictConfig] = None,
test: Optional[DictConfig] = None,
skip_val_loader: bool = False,
dummy: bool = False,
):
super().__init__()
self.train_config = train
assert (
"datapipeline" in self.train_config and "loader" in self.train_config
), "train config requires the fields `datapipeline` and `loader`"
self.val_config = validation
if not skip_val_loader:
if self.val_config is not None:
assert (
"datapipeline" in self.val_config and "loader" in self.val_config
), "validation config requires the fields `datapipeline` and `loader`"
else:
print(
"Warning: No Validation datapipeline defined, using that one from training"
)
self.val_config = train
self.test_config = test
if self.test_config is not None:
assert (
"datapipeline" in self.test_config and "loader" in self.test_config
), "test config requires the fields `datapipeline` and `loader`"
self.dummy = dummy
if self.dummy:
print("#" * 100)
print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
print("#" * 100)
def setup(self, stage: str) -> None:
print("Preparing datasets")
if self.dummy:
data_fn = create_dummy_dataset
else:
data_fn = create_dataset
self.train_datapipeline = data_fn(**self.train_config.datapipeline)
if self.val_config:
self.val_datapipeline = data_fn(**self.val_config.datapipeline)
if self.test_config:
self.test_datapipeline = data_fn(**self.test_config.datapipeline)
def train_dataloader(self) -> torchdata.datapipes.iter.IterDataPipe:
loader = create_loader(self.train_datapipeline, **self.train_config.loader)
return loader
def val_dataloader(self) -> wds.DataPipeline:
return create_loader(self.val_datapipeline, **self.val_config.loader)
def test_dataloader(self) -> wds.DataPipeline:
return create_loader(self.test_datapipeline, **self.test_config.loader)