File size: 1,286 Bytes
8fd2f2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities.types import (EVAL_DATALOADERS)
from dataloader.dataset_factory import SingleImageDatasetFactory


class VideoDataModule(pl.LightningDataModule):
    """Lightning data module that exposes a single prediction dataset.

    The dataset is not built at construction time; it is created by
    ``setup("predict")`` from the supplied factory and then wrapped in a
    ``DataLoader`` by ``predict_dataloader``.
    """

    def __init__(self,
                 workers: int,
                 predict_dataset_factory: SingleImageDatasetFactory = None,
                 ) -> None:
        """Store the loader configuration; no dataset is created here.

        Args:
            workers: Value forwarded as ``num_workers`` to the prediction
                ``DataLoader``.
            predict_dataset_factory: Object whose ``get_dataset()`` is called
                in ``setup("predict")`` to build the prediction dataset. May
                be ``None`` as long as the predict stage is never set up.
        """
        super().__init__()
        self.num_workers = workers

        # Datasets keyed by stage name; only "predict" is populated here.
        self.video_data_module = {}
        # TODO read size from loaded unet via unet.sample_sizes
        self.predict_dataset_factory = predict_dataset_factory

    def setup(self, stage: str) -> None:
        """Build the dataset for ``stage``; only ``"predict"` is handled.

        Args:
            stage: Stage name. Any value other than ``"predict"`` is a no-op.

        Raises:
            ValueError: If ``stage`` is ``"predict"`` but no
                ``predict_dataset_factory`` was given to the constructor.
        """
        if stage != "predict":
            return

        # Fail with an actionable message instead of an AttributeError on None.
        if self.predict_dataset_factory is None:
            raise ValueError(
                "VideoDataModule.setup('predict') requires a "
                "predict_dataset_factory, but None was provided."
            )

        self.video_data_module["predict"] = (
            self.predict_dataset_factory.get_dataset()
        )

    def predict_dataloader(self) -> EVAL_DATALOADERS:
        """Return the ``DataLoader`` over the prediction dataset.

        Samples are served one per batch, in dataset order, with no batch
        dropped and the default collate function.

        Raises:
            KeyError: If ``setup("predict")`` has not stored a dataset yet.
        """
        return torch.utils.data.DataLoader(self.video_data_module["predict"],
                                           batch_size=1,
                                           pin_memory=True,
                                           num_workers=self.num_workers,
                                           collate_fn=None,
                                           shuffle=False,
                                           drop_last=False)