Spaces:
Runtime error
Runtime error
File size: 1,286 Bytes
8fd2f2f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
from typing import Optional

import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities.types import (EVAL_DATALOADERS)

from dataloader.dataset_factory import SingleImageDatasetFactory
class VideoDataModule(pl.LightningDataModule):
    """LightningDataModule that serves a prediction dataset built by a factory.

    Only the ``"predict"`` stage is handled: :meth:`setup` builds the dataset
    via ``predict_dataset_factory.get_dataset()`` and :meth:`predict_dataloader`
    wraps it in a :class:`torch.utils.data.DataLoader`.

    Args:
        workers: Number of DataLoader worker processes (``num_workers``).
        predict_dataset_factory: Factory whose ``get_dataset()`` returns the
            dataset used for prediction. It may be omitted at construction
            time, but it must be set before ``setup("predict")`` is called.
    """

    def __init__(self,
                 workers: int,
                 predict_dataset_factory: Optional[SingleImageDatasetFactory] = None,
                 ) -> None:
        super().__init__()
        self.num_workers = workers
        # Maps stage name -> dataset; setup() only ever populates "predict".
        self.video_data_module = {}
        # TODO read size from loaded unet via unet.sample_sizes
        self.predict_dataset_factory = predict_dataset_factory

    def setup(self, stage: str) -> None:
        """Build the dataset for ``stage``.

        Stages other than ``"predict"`` are ignored.

        Raises:
            ValueError: If ``stage == "predict"`` but no
                ``predict_dataset_factory`` is available.
        """
        if stage != "predict":
            return
        # Without this guard a missing factory surfaces as an opaque
        # AttributeError on ``None.get_dataset``.
        if self.predict_dataset_factory is None:
            raise ValueError(
                "VideoDataModule.setup('predict') requires a "
                "predict_dataset_factory, but none was provided."
            )
        self.video_data_module["predict"] = self.predict_dataset_factory.get_dataset()

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        """Return a DataLoader over the dataset built by ``setup("predict")``.

        Raises:
            KeyError: If ``setup("predict")`` has not been run yet.
        """
        return torch.utils.data.DataLoader(
            self.video_data_module["predict"],
            batch_size=1,        # one sample per predict step
            pin_memory=True,     # page-locked host memory for faster host->device copies
            num_workers=self.num_workers,
            collate_fn=None,     # None -> PyTorch's default collate function
            shuffle=False,       # keep dataset order so outputs line up with inputs
            drop_last=False,     # never discard trailing samples during prediction
        )
|