Spaces:
Runtime error
Runtime error
| ''' | |
| * Copyright (c) 2023 Salesforce, Inc. | |
| * All rights reserved. | |
| * SPDX-License-Identifier: Apache License 2.0 | |
| * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/ | |
| * By Can Qin | |
| * Modified from ControlNet repo: https://github.com/lllyasviel/ControlNet | |
| * Copyright (c) 2023 Lvmin Zhang and Maneesh Agrawala | |
| * Modified from MMCV repo: From https://github.com/open-mmlab/mmcv | |
| * Copyright (c) OpenMMLab. All rights reserved. | |
| ''' | |
| from torch.utils.data.dataset import ConcatDataset as _ConcatDataset | |
| from .builder import DATASETS | |
class ConcatDataset(_ConcatDataset):
    """A wrapper of concatenated datasets.

    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but additionally
    exposes the ``CLASSES`` and ``PALETTE`` attributes taken from the first
    dataset. The metadata of the remaining datasets is not inspected.

    Args:
        datasets (Iterable[:obj:`Dataset`]): The datasets to concatenate.
            Any iterable is accepted (e.g. a list or a generator); it must
            be non-empty.
    """

    def __init__(self, datasets):
        super().__init__(datasets)
        # Read metadata from ``self.datasets`` (the list built by the parent
        # constructor), not from the argument: the argument may be a one-shot
        # iterable that is already exhausted and does not support indexing.
        first = self.datasets[0]
        self.CLASSES = first.CLASSES
        self.PALETTE = first.PALETTE
class RepeatDataset(object):
    """A wrapper of repeated dataset.

    The length of the repeated dataset is ``times`` times the length of the
    original dataset. This is useful when the data loading time is long but
    the dataset is small: using RepeatDataset can reduce the data loading
    time between epochs.

    ``CLASSES`` and ``PALETTE`` are copied from the wrapped dataset.

    Args:
        dataset (:obj:`Dataset`): The dataset to be repeated.
        times (int): Repeat times.
    """

    def __init__(self, dataset, times):
        self.dataset = dataset
        self.times = times
        self.CLASSES = dataset.CLASSES
        self.PALETTE = dataset.PALETTE
        # Cache the original length; it is used for every index lookup.
        self._ori_len = len(self.dataset)

    def __getitem__(self, idx):
        """Get item from original dataset.

        Args:
            idx (int): Index into the repeated dataset. Negative indices
                count from the end, as for a Python sequence.

        Returns:
            The item at position ``idx % len(original dataset)``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        length = self.times * self._ori_len
        # Bounds check is required for the sequence protocol: plain iteration
        # (``for x in ds`` / ``list(ds)``) only stops on IndexError, so
        # wrapping every index with ``%`` would iterate forever. It also
        # turns an empty wrapped dataset into IndexError instead of
        # ZeroDivisionError.
        if not -length <= idx < length:
            raise IndexError(
                f'index {idx} is out of range for RepeatDataset '
                f'of length {length}')
        return self.dataset[idx % self._ori_len]

    def __len__(self):
        """The length is multiplied by ``times``."""
        return self.times * self._ori_len