Spaces:
Paused
Paused
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import inspect | |
| import threading | |
| import warnings | |
| from collections import OrderedDict | |
| from typing import Type, TypeVar | |
# Type variable for annotating ``get_instance`` so that its return type is the
# class it is invoked on.
T = TypeVar('T')
# Module-level re-entrant lock serializing access to shared data (the
# per-class instance registries); taken via ``_accquire_lock`` and released
# via ``_release_lock``.
_lock = threading.RLock()
def _accquire_lock() -> None:
    """Take the module-level lock that serializes access to shared data.

    Each call must be balanced by a matching call to ``_release_lock()``.
    """
    lock = _lock
    if not lock:
        # Defensive guard: only skipped if ``_lock`` has been rebound to a
        # falsy value (e.g. None); an RLock object is normally truthy.
        return
    lock.acquire()
def _release_lock() -> None:
    """Give back the module-level lock previously taken by _accquire_lock()."""
    lock = _lock
    if not lock:
        # Mirror of the guard in ``_accquire_lock``: nothing to release when
        # ``_lock`` has been rebound to a falsy value.
        return
    lock.release()
class ManagerMeta(type):
    """The metaclass for globally accessible classes.

    Every class created with ``ManagerMeta`` (including each subclass) gets
    its own ``_instance_dict``, an ``OrderedDict`` used as the registry of
    named instances. The constructor of such a class must accept a ``name``
    argument, either as a regular parameter or as a keyword-only parameter,
    because instances are built with ``cls(name=...)``.

    Class keyword arguments (e.g. ``class Foo(Base, option=1)``) are
    forwarded to ``type.__init__`` so they can be consumed by
    ``__init_subclass__``.

    Examples:
        >>> class SubClass1(metaclass=ManagerMeta):
        >>>     def __init__(self, *args, **kwargs):
        >>>         pass
        AssertionError: <class '__main__.SubClass1'> must have the `name`
        argument
        >>> class SubClass2(metaclass=ManagerMeta):
        >>>     def __init__(self, name):
        >>>         pass
        >>> # valid format.
        >>> class SubClass3(metaclass=ManagerMeta):
        >>>     def __init__(self, *, name):
        >>>         pass
        >>> # keyword-only ``name`` is also valid.

    Raises:
        AssertionError: If the class constructor has no ``name`` parameter.
    """

    def __init__(cls, *args, **kwargs):
        # A fresh registry per class so subclasses never share instances.
        cls._instance_dict = OrderedDict()
        # For a class, getfullargspec inspects its ``__init__`` (``self``
        # included). ``name`` may be positional-or-keyword or keyword-only.
        spec = inspect.getfullargspec(cls)
        accepted = set(spec.args) | set(spec.kwonlyargs)
        assert 'name' in accepted, f'{cls} must have the `name` argument'
        # Forward class keywords too; type.__init__ accepts them alongside
        # (name, bases, namespace), whereas dropping them raised TypeError.
        super().__init__(*args, **kwargs)
class ManagerMixin(metaclass=ManagerMeta):
    """``ManagerMixin`` is the base class for classes that have global access
    requirements.

    The subclasses inheriting from ``ManagerMixin`` can get their
    global instances by name.

    Examples:
        >>> class GlobalAccessible(ManagerMixin):
        >>>     def __init__(self, name=''):
        >>>         super().__init__(name)
        >>>
        >>> GlobalAccessible.get_instance('name')
        >>> instance_1 = GlobalAccessible.get_instance('name')
        >>> instance_2 = GlobalAccessible.get_instance('name')
        >>> assert id(instance_1) == id(instance_2)

    Args:
        name (str): Name of the instance. Defaults to '', but the constructor
            asserts that the name is a non-empty string, so a name must be
            supplied.
    """

    def __init__(self, name: str = '', **kwargs):
        # ``kwargs`` is accepted but not used here.
        assert isinstance(name, str) and name, \
            'name argument must be an non-empty string.'
        self._instance_name = name

    @classmethod
    def get_instance(cls: Type[T], name: str, **kwargs) -> T:
        """Get subclass instance by name if the name exists.

        If corresponding name instance has not been created, ``get_instance``
        will create an instance, otherwise ``get_instance`` will return the
        corresponding instance.

        Examples
            >>> instance1 = GlobalAccessible.get_instance('name1')
            >>> # Create name1 instance.
            >>> instance1.instance_name
            name1
            >>> instance2 = GlobalAccessible.get_instance('name1')
            >>> # Get name1 instance.
            >>> assert id(instance1) == id(instance2)

        Args:
            name (str): Name of instance.
            **kwargs: Passed to the constructor when the instance is created.
                If the instance already exists they are ignored and a warning
                is emitted.

        Returns:
            object: The instance registered under ``name``.

        Raises:
            AssertionError: If ``name`` is not a string.
        """
        # Validate before taking the lock; no shared state is touched yet.
        assert isinstance(name, str), \
            f'type of name should be str, but got {type(name)}'
        _accquire_lock()
        try:
            instance_dict = cls._instance_dict  # type: ignore
            if name not in instance_dict:
                instance = cls(name=name, **kwargs)  # type: ignore
                instance_dict[name] = instance  # type: ignore
            elif kwargs:
                warnings.warn(
                    f'{cls} instance named of {name} has been created, '
                    'the method `get_instance` should not accept any other '
                    'arguments')
            return instance_dict[name]
        finally:
            # Release even if the constructor raised; otherwise other threads
            # would block on the lock forever.
            _release_lock()

    @classmethod
    def get_current_instance(cls):
        """Get latest created instance.

        Before calling ``get_current_instance``, The subclass must have called
        ``get_instance(xxx)`` at least once.

        Examples
            >>> instance = GlobalAccessible.get_current_instance()
            RuntimeError: Before calling GlobalAccessible.get_current_instance(),
            you should call get_instance(name=xxx) at least once.
            >>> instance = GlobalAccessible.get_instance('name1')
            >>> instance.instance_name
            name1
            >>> instance = GlobalAccessible.get_current_instance()
            >>> instance.instance_name
            name1

        Returns:
            object: Latest created instance.

        Raises:
            RuntimeError: If no instance has been created for this class yet.
        """
        _accquire_lock()
        try:
            if not cls._instance_dict:
                raise RuntimeError(
                    f'Before calling {cls.__name__}.get_current_instance(), you '
                    'should call get_instance(name=xxx) at least once.')
            # Entries are only inserted on creation, so the last key in
            # insertion order is the most recently created instance.
            name = next(iter(reversed(cls._instance_dict)))
            return cls._instance_dict[name]
        finally:
            # Released on both the success path and the RuntimeError path.
            _release_lock()

    @classmethod
    def check_instance_created(cls, name: str) -> bool:
        """Check whether the name corresponding instance exists.

        Args:
            name (str): Name of instance.

        Returns:
            bool: Whether the name corresponding instance exists.
        """
        return name in cls._instance_dict

    @property
    def instance_name(self) -> str:
        """Get the name of instance.

        Returns:
            str: Name of instance.
        """
        return self._instance_name