StreamingSVD / dataloader /video_data_module.py
lev1's picture
Initial commit
8fd2f2f
raw
history blame
1.29 kB
from typing import Optional

import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities.types import (EVAL_DATALOADERS)

from dataloader.dataset_factory import SingleImageDatasetFactory
class VideoDataModule(pl.LightningDataModule):
    """Lightning data module that serves a dataset for the ``predict`` stage.

    The predict dataset is created in :meth:`setup` by calling
    ``predict_dataset_factory.get_dataset()`` and is then wrapped in a
    ``torch.utils.data.DataLoader`` by :meth:`predict_dataloader`.

    Args:
        workers: Number of ``DataLoader`` worker processes.
        predict_dataset_factory: Factory whose ``get_dataset()`` returns the
            dataset used for prediction. It may be omitted when the module is
            never set up for the ``"predict"`` stage.
    """

    def __init__(self,
                 workers: int,
                 predict_dataset_factory: Optional[SingleImageDatasetFactory] = None,
                 ) -> None:
        super().__init__()
        self.num_workers = workers
        # Datasets keyed by stage name; this class only ever fills "predict".
        self.video_data_module = {}
        # TODO read size from loaded unet via unet.sample_sizes
        self.predict_dataset_factory = predict_dataset_factory

    def setup(self, stage: str) -> None:
        """Build the dataset for ``stage``.

        Only the ``"predict"`` stage is handled; any other stage is a no-op.

        Args:
            stage: Stage name passed in by the Lightning trainer.

        Raises:
            ValueError: If ``stage == "predict"`` but no
                ``predict_dataset_factory`` was given to the constructor.
        """
        if stage != "predict":
            return
        if self.predict_dataset_factory is None:
            # Fail with an actionable message instead of an AttributeError
            # on ``None.get_dataset``.
            raise ValueError(
                "VideoDataModule.setup('predict') requires a "
                "predict_dataset_factory, but none was provided."
            )
        self.video_data_module["predict"] = self.predict_dataset_factory.get_dataset()

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        """Return a ``DataLoader`` over the dataset built by :meth:`setup`.

        Samples are delivered one per batch (``batch_size=1``), in dataset
        order (``shuffle=False``), and the final sample is never dropped
        (``drop_last=False``). ``collate_fn=None`` keeps PyTorch's default
        collation.

        Returns:
            The prediction ``DataLoader``.

        Raises:
            KeyError: If :meth:`setup` has not been run for ``"predict"``.
        """
        return torch.utils.data.DataLoader(
            self.video_data_module["predict"],
            batch_size=1,
            # Page-locked host memory speeds up host-to-GPU copies.
            pin_memory=True,
            num_workers=self.num_workers,
            collate_fn=None,
            shuffle=False,
            drop_last=False,
        )