Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import inspect | |
| import logging | |
| import sys | |
| from collections.abc import Callable | |
| from contextlib import contextmanager | |
| from importlib import import_module | |
| from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union | |
| from rich.console import Console | |
| from rich.table import Table | |
| from mmengine.config.utils import MODULE2PACKAGE | |
| from mmengine.utils import get_object_from_string, is_seq_of | |
| from .default_scope import DefaultScope | |
class Registry:
    """A registry to map strings to classes or functions.

    Registered objects can be built from the registry, and registered
    functions can be called from the registry.

    Args:
        name (str): Registry name.
        build_func (callable, optional): A function to construct instances
            from the registry. :func:`build_from_cfg` is used if neither
            ``parent`` nor ``build_func`` is specified. If ``parent`` is
            specified and ``build_func`` is not given, ``build_func`` is
            inherited from ``parent``. Defaults to None.
        parent (:obj:`Registry`, optional): Parent registry. The class
            registered in a child registry can be built from the parent.
            Defaults to None.
        scope (str, optional): The scope of the registry. It is the key used
            to search for child registries. If not specified, the scope is
            the name of the package where the registry is defined, e.g.
            mmdet, mmcls, mmseg. Defaults to None.
        locations (list, optional): The locations to import the modules
            registered in this registry. Defaults to None, which is treated
            as an empty list. New in version 0.4.0.

    Examples:
        >>> # define a registry
        >>> MODELS = Registry('models')
        >>> # register the `ResNet` to `MODELS`
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> # build model from `MODELS`
        >>> resnet = MODELS.build(dict(type='ResNet'))
        >>> @MODELS.register_module()
        >>> def resnet50():
        >>>     pass
        >>> resnet = MODELS.build(dict(type='resnet50'))

        >>> # hierarchical registry
        >>> DETECTORS = Registry('detectors', parent=MODELS, scope='det')
        >>> @DETECTORS.register_module()
        >>> class FasterRCNN:
        >>>     pass
        >>> fasterrcnn = DETECTORS.build(dict(type='FasterRCNN'))

        >>> # add locations to enable auto import
        >>> DETECTORS = Registry('detectors', parent=MODELS,
        >>>     scope='det', locations=['det.models.detectors'])
        >>> # define this class in 'det.models.detectors'
        >>> @DETECTORS.register_module()
        >>> class MaskRCNN:
        >>>     pass
        >>> # The registry will auto import det.models.detectors.MaskRCNN
        >>> maskrcnn = DETECTORS.build(dict(type='det.MaskRCNN'))

    More advanced usages can be found at
    https://mmengine.readthedocs.io/en/latest/advanced_tutorials/registry.html.
    """

    def __init__(self,
                 name: str,
                 build_func: Optional[Callable] = None,
                 parent: Optional['Registry'] = None,
                 scope: Optional[str] = None,
                 locations: Optional[List] = None):
        # Imported lazily, like the other intra-package imports in this
        # class (presumably to avoid a circular import).
        from .build_functions import build_from_cfg

        self._name = name
        self._module_dict: Dict[str, Type] = dict()
        self._children: Dict[str, 'Registry'] = dict()
        # Use a fresh list per instance instead of a shared mutable default.
        self._locations = [] if locations is None else locations
        self._imported = False

        if scope is not None:
            assert isinstance(scope, str)
            self._scope = scope
        else:
            # NOTE: `infer_scope` inspects the call stack at a fixed depth,
            # so it must be called directly from `__init__`.
            self._scope = self.infer_scope()

        # See https://mypy.readthedocs.io/en/stable/common_issues.html#
        # variables-vs-type-aliases for the use
        self.parent: Optional['Registry']
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_child(self)
            self.parent = parent
        else:
            self.parent = None

        # self.build_func will be set with the following priority:
        # 1. build_func
        # 2. parent.build_func
        # 3. build_from_cfg
        self.build_func: Callable
        if build_func is None:
            if self.parent is not None:
                self.build_func = self.parent.build_func
            else:
                self.build_func = build_from_cfg
        else:
            self.build_func = build_func

    def __len__(self):
        """Return the number of modules registered directly in this node."""
        return len(self._module_dict)

    def __contains__(self, key):
        """Return True if :meth:`get` resolves ``key`` to an object."""
        return self.get(key) is not None

    def __repr__(self):
        """Render the registered names and objects as a ``rich`` table."""
        table = Table(title=f'Registry of {self._name}')
        table.add_column('Names', justify='left', style='cyan')
        table.add_column('Objects', justify='left', style='green')

        for name, obj in sorted(self._module_dict.items()):
            table.add_row(name, str(obj))

        # Capture the console output so the table is returned, not printed.
        console = Console()
        with console.capture() as capture:
            console.print(table, end='')

        return capture.get()

    @staticmethod
    def infer_scope() -> str:
        """Infer the scope of registry.

        The name of the package where registry is defined will be returned.

        Returns:
            str: The inferred scope name.

        Examples:
            >>> # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            >>> # The scope of ``ResNet`` will be ``mmdet``.
        """
        from ..logging import print_log

        # `sys._getframe` returns the frame object that many calls below the
        # top of the stack. The call stack for `infer_scope` can be listed as
        # follow:
        # frame-0: `infer_scope` itself
        # frame-1: `__init__` of `Registry` which calls the `infer_scope`
        # frame-2: Where the `Registry(...)` is called
        module = inspect.getmodule(sys._getframe(2))
        if module is not None:
            # The top-level package name is the first dotted component.
            filename = module.__name__
            split_filename = filename.split('.')
            scope = split_filename[0]
        else:
            # use "mmengine" to handle some cases which can not infer the scope
            # like initializing Registry in interactive mode
            scope = 'mmengine'
            print_log(
                'set scope as "mmengine" when scope can not be inferred. You '
                'can silence this warning by passing a "scope" argument to '
                'Registry like `Registry(name, scope="toy")`',
                logger='current',
                level=logging.WARNING)

        return scope

    @staticmethod
    def split_scope_key(key: str) -> Tuple[Optional[str], str]:
        """Split scope and key.

        The first scope will be split from key.

        Returns:
            tuple[str | None, str]: The former element is the first scope of
            the key, which can be ``None``. The latter is the remaining key.

        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'
        """
        split_index = key.find('.')
        if split_index != -1:
            return key[:split_index], key[split_index + 1:]
        else:
            return None, key

    @property
    def name(self):
        """str: The name of this registry."""
        return self._name

    @property
    def scope(self):
        """str: The scope of this registry."""
        return self._scope

    @property
    def module_dict(self):
        """dict: Mapping from registered names to registered objects."""
        return self._module_dict

    @property
    def children(self):
        """dict: Mapping from child scope to child registry."""
        return self._children

    @property
    def root(self):
        """Registry: The root of the registry tree this node belongs to."""
        return self._get_root_registry()

    @contextmanager
    def switch_scope_and_registry(self, scope: Optional[str]) -> Generator:
        """Temporarily switch default scope to the target scope, and get the
        corresponding registry.

        If the registry of the corresponding scope exists, yield the
        registry, otherwise yield the current itself.

        Args:
            scope (str, optional): The target scope.

        Examples:
            >>> from mmengine.registry import Registry, DefaultScope, MODELS
            >>> import time
            >>> # External Registry
            >>> MMDET_MODELS = Registry('mmdet_model', scope='mmdet',
            >>>     parent=MODELS)
            >>> MMCLS_MODELS = Registry('mmcls_model', scope='mmcls',
            >>>     parent=MODELS)
            >>> # Local Registry
            >>> CUSTOM_MODELS = Registry('custom_model', scope='custom',
            >>>     parent=MODELS)
            >>>
            >>> # Initiate DefaultScope
            >>> DefaultScope.get_instance(f'scope_{time.time()}',
            >>>     scope_name='custom')
            >>> # Check default scope
            >>> DefaultScope.get_current_instance().scope_name
            custom
            >>> # Switch to mmcls scope and get `MMCLS_MODELS` registry.
            >>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmcls') as registry:
            >>>     DefaultScope.get_current_instance().scope_name
            mmcls
            >>>     registry.scope
            mmcls
            >>> # Nested switch scope
            >>> with CUSTOM_MODELS.switch_scope_and_registry(scope='mmdet') as mmdet_registry:
            >>>     DefaultScope.get_current_instance().scope_name
            mmdet
            >>>     mmdet_registry.scope
            mmdet
            >>>     with CUSTOM_MODELS.switch_scope_and_registry(scope='mmcls') as mmcls_registry:
            >>>         DefaultScope.get_current_instance().scope_name
            mmcls
            >>>         mmcls_registry.scope
            mmcls
            >>>
            >>> # Check switch back to original scope.
            >>> DefaultScope.get_current_instance().scope_name
            custom
        """  # noqa: E501
        from ..logging import print_log

        # Switch to the given scope temporarily. If the corresponding registry
        # can be found in root registry, return the registry under the scope,
        # otherwise return the registry itself.
        with DefaultScope.overwrite_default_scope(scope):
            # Get the global default scope
            default_scope = DefaultScope.get_current_instance()
            # Get registry by scope
            if default_scope is not None:
                scope_name = default_scope.scope_name
                try:
                    import_module(f'{scope_name}.registry')
                except (ImportError, AttributeError, ModuleNotFoundError):
                    # Report the scope whose import was attempted
                    # (`scope_name`); the `scope` argument may be None.
                    if scope_name in MODULE2PACKAGE:
                        print_log(
                            f'{scope_name} is not installed and its '
                            'modules will not be registered. If you '
                            'want to use modules defined in '
                            f'{scope_name}, Please install {scope_name} by '
                            f'`pip install {MODULE2PACKAGE[scope_name]}.',
                            logger='current',
                            level=logging.WARNING)
                    else:
                        print_log(
                            f'Failed to import `{scope_name}.registry` '
                            'make sure the registry.py exists in '
                            f'`{scope_name}` package.',
                            logger='current',
                            level=logging.WARNING)
                root = self._get_root_registry()
                registry = root._search_child(scope_name)
                if registry is None:
                    # if `default_scope` can not be found, fallback to argument
                    # `registry`
                    print_log(
                        f'Failed to search registry with scope "{scope_name}" '
                        f'in the "{root.name}" registry tree. '
                        f'As a workaround, the current "{self.name}" registry '
                        f'in "{self.scope}" is used to build instance. This '
                        'may cause unexpected failure when running the built '
                        f'modules. Please check whether "{scope_name}" is a '
                        'correct scope, or whether the registry is '
                        'initialized.',
                        logger='current',
                        level=logging.WARNING)
                    registry = self
            # If there is no built default scope, just return current registry.
            else:
                registry = self
            yield registry

    def _get_root_registry(self) -> 'Registry':
        """Return the root registry."""
        root = self
        while root.parent is not None:
            root = root.parent
        return root

    def import_from_location(self) -> None:
        """Import modules from the pre-defined locations in
        ``self._locations``.

        The import happens at most once per registry; subsequent calls are
        no-ops once ``self._imported`` is set.
        """
        if not self._imported:
            # Avoid circular import
            from ..logging import print_log

            # avoid BC breaking
            if len(self._locations) == 0 and self.scope in MODULE2PACKAGE:
                print_log(
                    f'The "{self.name}" registry in {self.scope} did not '
                    'set import location. Fallback to call '
                    f'`{self.scope}.utils.register_all_modules` '
                    'instead.',
                    logger='current',
                    level=logging.DEBUG)
                try:
                    module = import_module(f'{self.scope}.utils')
                except (ImportError, AttributeError, ModuleNotFoundError):
                    # NOTE(review): the enclosing condition already requires
                    # ``self.scope in MODULE2PACKAGE``, so the ``else`` branch
                    # below looks unreachable; kept unchanged.
                    if self.scope in MODULE2PACKAGE:
                        print_log(
                            f'{self.scope} is not installed and its '
                            'modules will not be registered. If you '
                            'want to use modules defined in '
                            f'{self.scope}, Please install {self.scope} by '
                            f'`pip install {MODULE2PACKAGE[self.scope]}.',
                            logger='current',
                            level=logging.WARNING)
                    else:
                        print_log(
                            f'Failed to import {self.scope} and register '
                            'its modules, please make sure you '
                            'have registered the module manually.',
                            logger='current',
                            level=logging.WARNING)
                else:
                    # The import errors triggered during the registration
                    # may be more complex, here just throwing
                    # the error to avoid causing more implicit registry errors
                    # like `xxx` not found in `yyy` registry.
                    module.register_all_modules(False)  # type: ignore

            for loc in self._locations:
                import_module(loc)
                print_log(
                    f"Modules of {self.scope}'s {self.name} registry have "
                    f'been automatically imported from {loc}',
                    logger='current',
                    level=logging.DEBUG)
            self._imported = True

    def get(self, key: str) -> Optional[Type]:
        """Get the registry record.

        If ``key`` represents the whole object name with its module
        information, for example, ``mmengine.model.BaseModel``, ``get``
        will directly return the class object :class:`BaseModel`.

        Otherwise, it will first parse ``key`` and check whether it
        contains a scope name. The logic to search for ``key``:

        - ``key`` does not contain a scope name, i.e., it is purely a module
          name like "ResNet": :meth:`get` will search for ``ResNet`` from the
          current registry to its parent or ancestors until finding it.

        - ``key`` contains a scope name and it is equal to the scope of the
          current registry (e.g., "mmcls"), e.g., "mmcls.ResNet": :meth:`get`
          will only search for ``ResNet`` in the current registry.

        - ``key`` contains a scope name and it is not equal to the scope of
          the current registry (e.g., "mmdet"), e.g., "mmcls.FCNet": If the
          scope exists in its children, :meth:`get` will get "FCNet" from
          them. If not, :meth:`get` will first get the root registry and root
          registry call its own :meth:`get` method.

        Args:
            key (str): Name of the registered item, e.g., the class name in
                string format.

        Returns:
            Type or None: Return the corresponding class if ``key`` exists,
            otherwise return None.

        Raises:
            TypeError: If ``key`` is not a str.
            RuntimeError: If the fallback lookup of ``key`` as a dotted
                object path raises.

        Examples:
            >>> # define a registry
            >>> MODELS = Registry('models')
            >>> # register `ResNet` to `MODELS`
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet_cls = MODELS.get('ResNet')

            >>> # hierarchical registry
            >>> DETECTORS = Registry('detector', parent=MODELS, scope='det')
            >>> # `ResNet` does not exist in `DETECTORS` but `get` method
            >>> # will try to search from its parents or ancestors
            >>> resnet_cls = DETECTORS.get('ResNet')
            >>> CLASSIFIER = Registry('classifier', parent=MODELS, scope='cls')
            >>> @CLASSIFIER.register_module()
            >>> class MobileNet:
            >>>     pass
            >>> # `get` from its sibling registries
            >>> mobilenet_cls = DETECTORS.get('cls.MobileNet')
        """
        # Avoid circular import
        from ..logging import print_log

        if not isinstance(key, str):
            raise TypeError(
                'The key argument of `Registry.get` must be a str, '
                f'got {type(key)}')

        scope, real_key = self.split_scope_key(key)
        obj_cls = None
        # Track where the object was found; only used for the debug log.
        registry_name = self.name
        scope_name = self.scope

        # lazy import the modules to register them into the registry
        self.import_from_location()

        if scope is None or scope == self._scope:
            # get from self
            if real_key in self._module_dict:
                obj_cls = self._module_dict[real_key]
            elif scope is None:
                # try to get the target from its parent or ancestors
                parent = self.parent
                while parent is not None:
                    if real_key in parent._module_dict:
                        obj_cls = parent._module_dict[real_key]
                        registry_name = parent.name
                        scope_name = parent.scope
                        break
                    parent = parent.parent
        else:
            # import the registry to add the nodes into the registry tree
            try:
                import_module(f'{scope}.registry')
                print_log(
                    f'Registry node of {scope} has been automatically '
                    'imported.',
                    logger='current',
                    level=logging.DEBUG)
            except (ImportError, AttributeError, ModuleNotFoundError):
                print_log(
                    f'Cannot auto import {scope}.registry, please check '
                    f'whether the package "{scope}" is installed correctly '
                    'or import the registry manually.',
                    logger='current',
                    level=logging.DEBUG)
            # get from self._children
            if scope in self._children:
                obj_cls = self._children[scope].get(real_key)
                registry_name = self._children[scope].name
                scope_name = scope
            else:
                root = self._get_root_registry()

                if scope != root._scope and scope not in root._children:
                    # If not skip directly, `root.get(key)` will recursively
                    # call itself until RecursionError is thrown.
                    pass
                else:
                    obj_cls = root.get(key)

        if obj_cls is None:
            # Actually, it's strange to implement this `try ... except` to
            # get the object by its name in `Registry.get`. However, If we
            # want to build the model using a configuration like
            # `dict(type='mmengine.model.BaseModel')`, which can
            # be dumped by lazy import config, we need this code snippet
            # for `Registry.get` to work.
            try:
                obj_cls = get_object_from_string(key)
            except Exception as err:
                # Chain the original error so the real cause stays visible.
                raise RuntimeError(f'Failed to get {key}') from err

        if obj_cls is not None:
            # For some rare cases (e.g. obj_cls is a partial function), obj_cls
            # doesn't have `__name__`. Use default value to prevent error
            cls_name = getattr(obj_cls, '__name__', str(obj_cls))
            print_log(
                f'Get class `{cls_name}` from "{registry_name}"'
                f' registry in "{scope_name}"',
                logger='current',
                level=logging.DEBUG)

        return obj_cls

    def _search_child(self, scope: str) -> Optional['Registry']:
        """Depth-first search for the corresponding registry in its children.

        Note that the method only search for the corresponding registry from
        the current registry. Therefore, if we want to search from the root
        registry, :meth:`_get_root_registry` should be called to get the
        root registry first.

        Args:
            scope (str): The scope name used for searching for its
                corresponding registry.

        Returns:
            Registry or None: Return the corresponding registry if ``scope``
            exists, otherwise return None.
        """
        if self._scope == scope:
            return self

        for child in self._children.values():
            registry = child._search_child(scope)
            if registry is not None:
                return registry

        return None

    def build(self, cfg: dict, *args, **kwargs) -> Any:
        """Build an instance.

        Build an instance by calling :attr:`build_func`.

        Args:
            cfg (dict): Config dict needs to be built.

        Returns:
            Any: The constructed object.

        Examples:
            >>> from mmengine import Registry
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     def __init__(self, depth, stages=4):
            >>>         self.depth = depth
            >>>         self.stages = stages
            >>> cfg = dict(type='ResNet', depth=50)
            >>> model = MODELS.build(cfg)
        """
        return self.build_func(cfg, *args, **kwargs, registry=self)

    def _add_child(self, registry: 'Registry') -> None:
        """Add a child for a registry.

        Args:
            registry (:obj:`Registry`): The ``registry`` will be added as a
                child of the ``self``.
        """

        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    def _register_module(self,
                         module: Type,
                         module_name: Optional[Union[str, List[str]]] = None,
                         force: bool = False) -> None:
        """Register a module.

        Args:
            module (type): Module to be registered. Typically a class or a
                function, but generally all ``Callable`` are acceptable.
            module_name (str or list of str, optional): The module name to be
                registered. If not specified, the class name will be used.
                Defaults to None.
            force (bool): Whether to override an existing class with the same
                name. Defaults to False.

        Raises:
            TypeError: If ``module`` is not callable.
            KeyError: If a name is already registered and ``force`` is False.
        """
        if not callable(module):
            raise TypeError(f'module must be Callable, but got {type(module)}')

        if module_name is None:
            module_name = module.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        for name in module_name:
            if not force and name in self._module_dict:
                existed_module = self.module_dict[name]
                raise KeyError(f'{name} is already registered in {self.name} '
                               f'at {existed_module.__module__}')
            self._module_dict[name] = module

    def register_module(
            self,
            name: Optional[Union[str, List[str]]] = None,
            force: bool = False,
            module: Optional[Type] = None) -> Union[type, Callable]:
        """Register a module.

        A record will be added to ``self._module_dict``, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Args:
            name (str or list of str, optional): The module name to be
                registered. If not specified, the class name will be used.
            force (bool): Whether to override an existing class with the same
                name. Defaults to False.
            module (type, optional): Module class or function to be registered.
                Defaults to None.

        Returns:
            type or Callable: ``module`` itself when it is given, otherwise a
            decorator that registers and returns the decorated object.

        Raises:
            TypeError: If ``force`` is not a bool, or ``name`` is not None,
                a str, or a sequence of str.

        Examples:
            >>> backbones = Registry('backbone')
            >>> # as a decorator
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> # as a normal function
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(module=ResNet)
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')

        # raise the error ahead of time
        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
            raise TypeError(
                'name must be None, an instance of str, or a sequence of str, '
                f'but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(module=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(module):
            self._register_module(module=module, module_name=name, force=force)
            return module

        return _register