Spaces:
Runtime error
Runtime error
from typing import Optional

import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities.types import EVAL_DATALOADERS

from dataloader.dataset_factory import SingleImageDatasetFactory
class VideoDataModule(pl.LightningDataModule):
    """Lightning data module that serves a single dataset for prediction.

    Only the ``"predict"`` stage is handled: :meth:`setup` builds the dataset
    from ``predict_dataset_factory.get_dataset()``, and
    :meth:`predict_dataloader` wraps it in a ``torch.utils.data.DataLoader``
    (batch size 1, no shuffling, last partial batch kept). ``setup`` ignores
    every other stage.

    Args:
        workers: Number of DataLoader worker processes (``num_workers``).
        predict_dataset_factory: Object whose ``get_dataset()`` returns the
            dataset used for prediction. It may be omitted at construction
            time, but it must be set before ``setup("predict")`` runs.
    """

    def __init__(
        self,
        workers: int,
        predict_dataset_factory: Optional[SingleImageDatasetFactory] = None,
    ) -> None:
        super().__init__()
        self.num_workers = workers
        # Datasets keyed by stage name; only the "predict" key is ever populated.
        self.video_data_module = {}
        # TODO read size from loaded unet via unet.sample_sizes
        self.predict_dataset_factory = predict_dataset_factory

    def setup(self, stage: str) -> None:
        """Build the dataset for ``stage``. Only ``"predict"`` has any effect.

        Args:
            stage: Stage name supplied by Lightning.

        Raises:
            ValueError: If ``stage == "predict"`` and no
                ``predict_dataset_factory`` has been provided.
        """
        if stage != "predict":
            return
        if self.predict_dataset_factory is None:
            # Fail with an actionable message rather than an AttributeError on None.
            raise ValueError(
                "VideoDataModule.setup('predict') requires a predict_dataset_factory"
            )
        self.video_data_module["predict"] = self.predict_dataset_factory.get_dataset()

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        """Return a DataLoader over the dataset built by ``setup("predict")``.

        Raises:
            RuntimeError: If ``setup("predict")`` has not been called yet.
        """
        try:
            dataset = self.video_data_module["predict"]
        except KeyError:
            raise RuntimeError(
                "VideoDataModule.predict_dataloader() called before setup('predict')"
            ) from None
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=1,
            pin_memory=True,
            num_workers=self.num_workers,
            collate_fn=None,
            shuffle=False,
            drop_last=False,
        )