# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from mmengine.utils.dl_utils import TORCH_VERSION
from mmengine.utils.version_utils import digit_version

from .averaged_model import (BaseAveragedModel, ExponentialMovingAverage,
                             MomentumAnnealingEMA, StochasticWeightAverage)
from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
from .test_time_aug import BaseTTAModel
from .utils import (convert_sync_batchnorm, detect_anomalous_params,
                    merge_dict, revert_sync_batchnorm, stack_batch)
from .weight_init import (BaseInit, Caffe2XavierInit, ConstantInit,
                          KaimingInit, NormalInit, PretrainedInit,
                          TruncNormalInit, UniformInit, XavierInit,
                          bias_init_with_prob, caffe2_xavier_init,
                          constant_init, initialize, kaiming_init, normal_init,
                          trunc_normal_init, uniform_init, update_init_info,
                          xavier_init)
from .wrappers import (MMDistributedDataParallel,
                       MMSeparateDistributedDataParallel, is_model_wrapper)

__all__ = [
    'MMDistributedDataParallel', 'is_model_wrapper', 'BaseAveragedModel',
    'StochasticWeightAverage', 'ExponentialMovingAverage',
    'MomentumAnnealingEMA', 'BaseModel', 'BaseDataPreprocessor',
    'ImgDataPreprocessor', 'MMSeparateDistributedDataParallel', 'BaseModule',
    'stack_batch', 'merge_dict', 'detect_anomalous_params', 'ModuleList',
    'ModuleDict', 'Sequential', 'revert_sync_batchnorm', 'update_init_info',
    'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
    'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
    'bias_init_with_prob', 'BaseInit', 'ConstantInit', 'XavierInit',
    'NormalInit', 'TruncNormalInit', 'UniformInit', 'KaimingInit',
    'Caffe2XavierInit', 'PretrainedInit', 'initialize',
    'convert_sync_batchnorm', 'BaseTTAModel'
]

# MMFullyShardedDataParallel wraps PyTorch's FSDP, so only import and export
# it when the installed torch version is new enough to provide it.
if digit_version(TORCH_VERSION) >= digit_version('2.0.0'):
    from .wrappers import MMFullyShardedDataParallel  # noqa:F401
    __all__.append('MMFullyShardedDataParallel')