Spaces:
Paused
Paused
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
| from abc import ABCMeta, abstractmethod | |
| from typing import Any, Dict, Union | |
| from torch.utils.data import DataLoader | |
class BaseLoop(metaclass=ABCMeta):
    """Base loop class.

    All subclasses inherited from ``BaseLoop`` should overwrite the
    :meth:`run` method.

    Args:
        runner (Runner): A reference of runner.
        dataloader (DataLoader or dict): An iterator to generate one batch of
            dataset each iteration. If a ``dict`` is given, it is treated as
            a config and passed to ``runner.build_dataloader`` together with
            ``runner.seed`` and the runner's ``diff_rank_seed`` randomness
            setting. Any other object is stored as-is.
    """

    def __init__(self, runner, dataloader: Union[DataLoader, Dict]) -> None:
        self._runner = runner
        if isinstance(dataloader, dict):
            # Determine whether or not different ranks use different seed.
            # Defaults to ``False`` when the randomness config omits the key.
            diff_rank_seed = runner._randomness_cfg.get(
                'diff_rank_seed', False)
            self.dataloader = runner.build_dataloader(
                dataloader, seed=runner.seed, diff_rank_seed=diff_rank_seed)
        else:
            # Already-built dataloader (or any iterable): use it directly.
            self.dataloader = dataloader

    @property
    def runner(self):
        """The runner this loop was constructed with (read-only)."""
        return self._runner

    @abstractmethod
    def run(self) -> Any:
        """Execute loop.

        Subclasses must implement this method. Because it is abstract,
        ``BaseLoop`` itself (or a subclass that does not override ``run``)
        cannot be instantiated.
        """