Spaces:
Runtime error
Runtime error
import os
from functools import partial

import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset, random_split

from ldm.util import instantiate_from_config
class WrappedDataset(Dataset):
    """Adapt any object exposing ``__len__`` and ``__getitem__`` to a torch ``Dataset``.

    The wrapped object is stored unchanged on ``self.data``; length and item
    access are forwarded to it as-is.
    """

    def __init__(self, dataset):
        # Only a reference is kept: no copy or conversion happens here.
        self.data = dataset

    def __getitem__(self, idx):
        source = self.data
        return source[idx]

    def __len__(self):
        source = self.data
        return len(source)
class DataModuleFromConfig(pl.LightningDataModule):
    """LightningDataModule whose datasets are instantiated from config nodes.

    Each of ``train``, ``validation``, ``test`` and ``predict`` is an optional
    config handed to ``instantiate_from_config`` in :meth:`setup`. Only splits
    that were given a config get their ``*_dataloader`` hook bound on the
    instance; the others keep the base-class default.

    Args:
        batch_size: Batch size used by every dataloader.
        train: Optional config for the training dataset.
        validation: Optional config for the validation dataset.
        test: Optional config for the test dataset.
        predict: Optional config for the prediction dataset.
        wrap: If True, wrap every instantiated dataset in ``WrappedDataset``.
        num_workers: DataLoader worker count; ``None`` means ``2 * batch_size``.
        shuffle_test_loader: Shuffle the test loader (never applied to an
            ``IterableDataset``).
        use_worker_init_fn: Pass ``worker_init_fn`` to every loader, not only
            to loaders over an ``IterableDataset``.
        shuffle_val_dataloader: Shuffle the validation loader (never applied
            to an ``IterableDataset``).
        worker_init_fn: Optional callable forwarded as
            ``DataLoader(worker_init_fn=...)`` for iterable datasets, or for
            every dataset when ``use_worker_init_fn`` is True. ``None`` (the
            default) means no per-worker initialisation.
    """

    def __init__(self, batch_size, train=None, validation=None, test=None, predict=None,
                 wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
                 shuffle_val_dataloader=False, worker_init_fn=None):
        super().__init__()
        self.batch_size = batch_size
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        self.use_worker_init_fn = use_worker_init_fn
        # Callable for DataLoader(worker_init_fn=...); None disables it.
        self._worker_init_fn = worker_init_fn
        # Bind a dataloader hook only for the splits that are configured.
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
        if test is not None:
            self.dataset_configs["test"] = test
            self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
        if predict is not None:
            self.dataset_configs["predict"] = predict
            self.predict_dataloader = self._predict_dataloader
        self.wrap = wrap

    def prepare_data(self):
        """Instantiate every configured dataset once and discard the result.

        Presumably this lets dataset constructors download or cache data
        before :meth:`setup` runs (TODO confirm against the dataset classes).
        """
        for data_cfg in self.dataset_configs.values():
            instantiate_from_config(data_cfg)

    def setup(self, stage=None):
        """(Re)build ``self.datasets`` for every configured split.

        ``stage`` is accepted for Lightning compatibility but ignored: all
        configured splits are instantiated on every call.
        """
        self.datasets = {
            split: instantiate_from_config(cfg)
            for split, cfg in self.dataset_configs.items()
        }
        if self.wrap:
            self.datasets = {
                split: WrappedDataset(dataset)
                for split, dataset in self.datasets.items()
            }

    def _make_loader(self, split, shuffle=False):
        """Build the DataLoader for ``self.datasets[split]``.

        Shuffling is forced off for an ``IterableDataset``, because
        ``DataLoader`` raises ``ValueError`` when ``shuffle=True`` is combined
        with an iterable-style dataset.
        """
        dataset = self.datasets[split]
        is_iterable = isinstance(dataset, IterableDataset)
        if is_iterable or self.use_worker_init_fn:
            init_fn = self._worker_init_fn
        else:
            init_fn = None
        return DataLoader(dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          worker_init_fn=init_fn,
                          shuffle=shuffle and not is_iterable)

    def _train_dataloader(self):
        """Training loader; shuffled unless the dataset is iterable-style."""
        return self._make_loader("train", shuffle=True)

    def _val_dataloader(self, shuffle=False):
        """Validation loader; ``shuffle`` is ignored for iterable datasets."""
        return self._make_loader("validation", shuffle=shuffle)

    def _test_dataloader(self, shuffle=False):
        """Test loader; ``shuffle`` is ignored for iterable datasets."""
        return self._make_loader("test", shuffle=shuffle)

    def _predict_dataloader(self, shuffle=False):
        """Prediction loader; ``shuffle`` is ignored for iterable datasets."""
        return self._make_loader("predict", shuffle=shuffle)
def create_data(config):
    """Instantiate ``config.data`` and return it with its datasets already built.

    ``prepare_data()`` and then ``setup()`` are invoked on the freshly
    instantiated object before it is returned.
    """
    datamodule = instantiate_from_config(config.data)
    # NOTE: per https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
    # Lightning should call these hooks itself, but in practice they must be
    # invoked here; Lightning still handles the multiprocessing side.
    for hook in (datamodule.prepare_data, datamodule.setup):
        hook()
    return datamodule