Spaces:
Runtime error
Runtime error
| import pytorch_lightning as pl | |
| import torch | |
| from pytorch_lightning.utilities.types import (EVAL_DATALOADERS) | |
| from dataloader.dataset_factory import SingleImageDatasetFactory | |
| class VideoDataModule(pl.LightningDataModule): | |
| def __init__(self, | |
| workers: int, | |
| predict_dataset_factory: SingleImageDatasetFactory = None, | |
| ) -> None: | |
| super().__init__() | |
| self.num_workers = workers | |
| self.video_data_module = {} | |
| # TODO read size from loaded unet via unet.sample_sizes | |
| self.predict_dataset_factory = predict_dataset_factory | |
| def setup(self, stage: str) -> None: | |
| if stage == "predict": | |
| self.video_data_module["predict"] = self.predict_dataset_factory.get_dataset( | |
| ) | |
| def predict_dataloader(self) -> EVAL_DATALOADERS: | |
| return torch.utils.data.DataLoader(self.video_data_module["predict"], | |
| batch_size=1, | |
| pin_memory=True, | |
| num_workers=self.num_workers, | |
| collate_fn=None, | |
| shuffle=False, | |
| drop_last=False) | |