# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import datetime
import functools
import os
import subprocess
from collections.abc import Iterable, Mapping
from typing import Callable, Optional, Tuple, Union

import numpy as np
import torch
import torch.multiprocessing as mp
from torch import Tensor
from torch import distributed as torch_dist
from torch.distributed import ProcessGroup

from mmengine.device import is_mlu_available, is_npu_available

_LOCAL_PROCESS_GROUP = None
def is_distributed() -> bool:
    """Return True if distributed environment has been initialized."""
    return torch_dist.is_available() and torch_dist.is_initialized()
def get_local_group() -> Optional[ProcessGroup]:
    """Return local process group."""
    if not is_distributed():
        return None

    if _LOCAL_PROCESS_GROUP is None:
        raise RuntimeError('Local process group is not created, please use '
                           '`init_local_group` to setup local process group.')

    return _LOCAL_PROCESS_GROUP
def get_default_group() -> Optional[ProcessGroup]:
    """Return default process group."""
    return torch_dist.distributed_c10d._get_default_group()
def infer_launcher():
    if 'WORLD_SIZE' in os.environ:
        return 'pytorch'
    elif 'SLURM_NTASKS' in os.environ:
        return 'slurm'
    elif 'OMPI_COMM_WORLD_LOCAL_RANK' in os.environ:
        return 'mpi'
    else:
        return 'none'
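# Example (illustrative sketch, not part of the original module): the launcher
# is inferred purely from environment variables, so a process started with
# `torchrun` (which exports WORLD_SIZE) is detected as 'pytorch', a process in
# an `srun` allocation (SLURM_NTASKS) as 'slurm', and one started by `mpirun`
# (OMPI_COMM_WORLD_LOCAL_RANK) as 'mpi'. With a hypothetical value:
#
#   os.environ['WORLD_SIZE'] = '8'
#   assert infer_launcher() == 'pytorch'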
def init_dist(launcher,
              backend='nccl',
              init_backend='torch',
              **kwargs) -> None:
    """Initialize distributed environment.

    Args:
        launcher (str): Way to launch multiple processes. Supported launchers
            are 'pytorch', 'mpi' and 'slurm'.
        backend (str): Communication backend. Supported backends are 'nccl',
            'gloo' and 'mpi'. Defaults to 'nccl'.
        init_backend (str): The initialization method used to set up the
            distributed environment. Supported values are 'torch',
            'deepspeed' and 'colossalai'. Defaults to 'torch'.
        **kwargs: keyword arguments are passed to ``init_process_group``.
    """
    timeout = kwargs.get('timeout', None)
    if timeout is not None:
        # If a timeout (in seconds) is specified, it must be converted
        # to a timedelta object before forwarding the call to
        # the respective backend, because they expect a timedelta object.
        try:
            kwargs['timeout'] = datetime.timedelta(seconds=timeout)
        except TypeError as exception:
            raise TypeError(
                f'Timeout for distributed training must be provided as '
                f"timeout in seconds, but we've received the type "
                f'{type(timeout)}. Please specify the timeout like this: '
                f"dist_cfg=dict(backend='nccl', timeout=1800)") from exception
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, init_backend=init_backend, **kwargs)
    elif launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, init_backend=init_backend, **kwargs)
    else:
        raise ValueError(f'Invalid launcher type: {launcher}')
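# Usage sketch (illustrative, not part of the original module): a training
# script launched with `torchrun` would typically initialize the distributed
# environment once at startup; the timeout is given in seconds and converted
# to a `datetime.timedelta` internally.
#
#   launcher = infer_launcher()
#   if launcher != 'none':
#       init_dist(launcher, backend='nccl', timeout=1800)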
def _init_dist_pytorch(backend, init_backend='torch', **kwargs) -> None:
    """Initialize distributed environment with PyTorch launcher.

    Args:
        backend (str): Backend of torch.distributed. Supported backends are
            'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
        **kwargs: keyword arguments are passed to ``init_process_group``.
    """
    rank = int(os.environ['RANK'])
    if is_mlu_available():
        import torch_mlu  # noqa: F401
        local_rank = int(os.environ['LOCAL_RANK'])
        torch.mlu.set_device(local_rank)
        torch_dist.init_process_group(
            backend='cncl',
            rank=rank,
            world_size=int(os.environ['WORLD_SIZE']),
            **kwargs)
    elif is_npu_available():
        import torch_npu  # noqa: F401
        torch.npu.set_device(rank)
        torch_dist.init_process_group(
            backend='hccl',
            rank=rank,
            world_size=int(os.environ['WORLD_SIZE']),
            **kwargs)
    else:
        # LOCAL_RANK is set by `torch.distributed.launch` since PyTorch 1.1
        local_rank = int(os.environ['LOCAL_RANK'])
        torch.cuda.set_device(local_rank)
        if init_backend == 'torch':
            torch_dist.init_process_group(backend=backend, **kwargs)
        elif init_backend == 'deepspeed':
            import deepspeed
            deepspeed.init_distributed(dist_backend=backend, **kwargs)
        elif init_backend == 'colossalai':
            import colossalai
            colossalai.launch_from_torch(backend=backend, **kwargs)
        else:
            raise ValueError(
                'supported "init_backend" is "torch", "deepspeed" or '
                f'"colossalai", but got {init_backend}')
def _init_dist_mpi(backend, **kwargs) -> None:
    """Initialize distributed environment with MPI launcher.

    Args:
        backend (str): Backend of torch.distributed. Supported backends are
            'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
        **kwargs: keyword arguments are passed to ``init_process_group``.
    """
    if backend == 'smddp':
        try:
            import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                'Please use an Amazon SageMaker DLC to access smdistributed: '
                'https://github.com/aws/deep-learning-containers/blob/master'
                '/available_images.md#sagemaker-framework-containers'
                '-sm-support-only') from e
    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    if 'MASTER_PORT' not in os.environ:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    if 'MASTER_ADDR' not in os.environ:
        raise KeyError('The environment variable MASTER_ADDR is not set')
    os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
    os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
    torch_dist.init_process_group(backend=backend, **kwargs)
def _init_dist_slurm(backend,
                     port=None,
                     init_backend='torch',
                     **kwargs) -> None:
    """Initialize slurm distributed training environment.

    If the ``port`` argument is not specified, the master port will be read
    from the ``MASTER_PORT`` environment variable. If ``MASTER_PORT`` is not
    set either, the default port ``29500`` will be used.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    # Not sure when this environment variable could be None, so use a fallback
    local_rank_env = os.environ.get('SLURM_LOCALID', None)
    if local_rank_env is not None:
        local_rank = int(local_rank_env)
    else:
        num_gpus = torch.cuda.device_count()
        local_rank = proc_id % num_gpus
    torch.cuda.set_device(local_rank)
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT in the environment variable
    else:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    # use MASTER_ADDR in the environment variable if it already exists
    if 'MASTER_ADDR' not in os.environ:
        os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(local_rank)
    os.environ['RANK'] = str(proc_id)
    if init_backend == 'torch':
        torch_dist.init_process_group(backend=backend, **kwargs)
    elif init_backend == 'deepspeed':
        import deepspeed
        deepspeed.init_distributed(dist_backend=backend, **kwargs)
    elif init_backend == 'colossalai':
        import colossalai
        colossalai.launch_from_slurm(
            backend=backend,
            host=os.environ['MASTER_ADDR'],
            port=os.environ['MASTER_PORT'],
            **kwargs,
        )
    else:
        raise ValueError('supported "init_backend" is "torch", "deepspeed" '
                         f'or "colossalai", but got {init_backend}')
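# Example (illustrative sketch, hypothetical values): for the task with
# SLURM_PROCID=9 in a two-node job with 8 GPUs per node, `_init_dist_slurm`
# maps the SLURM environment onto the variables expected by torch.distributed
# roughly like this:
#
#   SLURM_PROCID=9, SLURM_NTASKS=16, SLURM_LOCALID=1, SLURM_NODELIST=node[0-1]
#   -> RANK=9, WORLD_SIZE=16, LOCAL_RANK=1, MASTER_ADDR=node0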
def init_local_group(node_rank: int, num_gpus_per_node: int):
    """Set up the local process group.

    Set up a process group which only includes processes that are on the same
    machine as the current process.

    The code is modified from
    https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py

    Args:
        node_rank (int): Rank of the machine (node) used for training.
        num_gpus_per_node (int): Number of gpus used for training in a single
            machine.
    """  # noqa: W501
    global _LOCAL_PROCESS_GROUP
    assert _LOCAL_PROCESS_GROUP is None

    ranks = list(
        range(node_rank * num_gpus_per_node,
              (node_rank + 1) * num_gpus_per_node))
    _LOCAL_PROCESS_GROUP = torch_dist.new_group(ranks)
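# Example (illustrative sketch, hypothetical numbers): on the second node of a
# job with 8 GPUs per node (node_rank=1, num_gpus_per_node=8), the local group
# contains global ranks 8..15:
#
#   init_local_group(node_rank=1, num_gpus_per_node=8)
#   # ranks == [8, 9, 10, 11, 12, 13, 14, 15]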
def get_backend(group: Optional[ProcessGroup] = None) -> Optional[str]:
    """Return the backend of the given process group.

    Note:
        Calling ``get_backend`` in non-distributed environment will return
        None.

    Args:
        group (ProcessGroup, optional): The process group to work on. The
            default is the general main process group. If another specific
            group is specified, the calling process must be part of
            :attr:`group`. Defaults to None.

    Returns:
        str or None: Return the backend of the given process group as a lower
        case string if in distributed environment, otherwise None.
    """
    if is_distributed():
        # handle low versions of torch like 1.5.0 which does not support
        # passing in None for group argument
        if group is None:
            group = get_default_group()
        return torch_dist.get_backend(group)
    else:
        return None
def get_world_size(group: Optional[ProcessGroup] = None) -> int:
    """Return the number of processes in the given process group.

    Note:
        Calling ``get_world_size`` in non-distributed environment will return
        1.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        int: Return the number of processes of the given process group if in
        distributed environment, otherwise 1.
    """
    if is_distributed():
        # handle low versions of torch like 1.5.0 which does not support
        # passing in None for group argument
        if group is None:
            group = get_default_group()
        return torch_dist.get_world_size(group)
    else:
        return 1
def get_rank(group: Optional[ProcessGroup] = None) -> int:
    """Return the rank of the given process group.

    Rank is a unique identifier assigned to each process within a distributed
    process group. Ranks are always consecutive integers ranging from 0 to
    ``world_size - 1``.

    Note:
        Calling ``get_rank`` in non-distributed environment will return 0.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        int: Return the rank of the process group if in distributed
        environment, otherwise 0.
    """
    if is_distributed():
        # handle low versions of torch like 1.5.0 which does not support
        # passing in None for group argument
        if group is None:
            group = get_default_group()
        return torch_dist.get_rank(group)
    else:
        return 0
def get_local_size() -> int:
    """Return the number of processes in the current node.

    Returns:
        int: Return the number of processes in the current node if in
        distributed environment, otherwise 1.
    """
    if not is_distributed():
        return 1

    if _LOCAL_PROCESS_GROUP is None:
        raise RuntimeError('Local process group is not created, please use '
                           '`init_local_group` to setup local process group.')

    return torch_dist.get_world_size(_LOCAL_PROCESS_GROUP)
def get_local_rank() -> int:
    """Return the rank of the current process in the current node.

    Returns:
        int: Return the rank of the current process in the current node if in
        distributed environment, otherwise 0.
    """
    if not is_distributed():
        return 0

    if _LOCAL_PROCESS_GROUP is None:
        raise RuntimeError('Local process group is not created, please use '
                           '`init_local_group` to setup local process group.')

    return torch_dist.get_rank(_LOCAL_PROCESS_GROUP)
def get_dist_info(group: Optional[ProcessGroup] = None) -> Tuple[int, int]:
    """Get distributed information of the given process group.

    Note:
        Calling ``get_dist_info`` in non-distributed environment will return
        (0, 1).

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        tuple[int, int]: Return a tuple containing the ``rank`` and
        ``world_size``.
    """
    world_size = get_world_size(group)
    rank = get_rank(group)
    return rank, world_size
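# Example (illustrative sketch): a common pattern for rank-dependent work,
# which degrades gracefully to (0, 1) when torch.distributed has not been
# initialized:
#
#   rank, world_size = get_dist_info()
#   if rank == 0:
#       print(f'training with {world_size} process(es)')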
def is_main_process(group: Optional[ProcessGroup] = None) -> bool:
    """Whether the current rank of the given process group is equal to 0.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.

    Returns:
        bool: Return True if the current rank of the given process group is
        equal to 0, otherwise False.
    """
    return get_rank(group) == 0
def master_only(func: Callable) -> Callable:
    """Decorate those methods which should be executed in the master process.

    Args:
        func (callable): Function to be decorated.

    Returns:
        callable: Return decorated function.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_main_process():
            return func(*args, **kwargs)

    return wrapper
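# Usage sketch (illustrative, hypothetical function): guard work that should
# only happen once per job, e.g. writing a checkpoint, so that non-master
# ranks silently skip it (the wrapper returns None on those ranks):
#
#   @master_only
#   def save_checkpoint(path: str) -> None:
#       torch.save({'step': 0}, path)
#
#   save_checkpoint('work_dir/latest.pth')  # only rank 0 writes the file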
def barrier(group: Optional[ProcessGroup] = None) -> None:
    """Synchronize all processes from the given process group.

    This collective blocks processes until the whole group enters this
    function.

    Note:
        Calling ``barrier`` in non-distributed environment will do nothing.

    Args:
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used. Defaults to None.
    """
    if is_distributed():
        # handle low versions of torch like 1.5.0 which does not support
        # passing in None for group argument
        if group is None:
            group = get_default_group()
        torch_dist.barrier(group)
def get_data_device(data: Union[Tensor, Mapping, Iterable]) -> torch.device:
    """Return the device of ``data``.

    If ``data`` is a sequence of Tensor, all items in ``data`` should be on
    the same device.

    If ``data`` is a dict whose values are Tensor, all values should be on
    the same device.

    Args:
        data (Tensor or Sequence or dict): Data whose device will be inferred.

    Returns:
        torch.device: The device of ``data``.

    Examples:
        >>> import torch
        >>> from mmengine.dist import get_data_device
        >>> # data is a Tensor
        >>> data = torch.tensor([0, 1])
        >>> get_data_device(data)
        device(type='cpu')
        >>> # data is a list of Tensor
        >>> data = [torch.tensor([0, 1]), torch.tensor([2, 3])]
        >>> get_data_device(data)
        device(type='cpu')
        >>> # data is a dict
        >>> data = {'key1': torch.tensor([0, 1]), 'key2': torch.tensor([0, 1])}
        >>> get_data_device(data)
        device(type='cpu')
    """
    if isinstance(data, Tensor):
        return data.device
    elif isinstance(data, Mapping):
        pre = None
        for v in data.values():
            cur = get_data_device(v)
            if pre is None:
                pre = cur
            else:
                if cur != pre:
                    raise ValueError(
                        'device type in data should be consistent, but got '
                        f'{cur} and {pre}')
        if pre is None:
            raise ValueError('data should not be empty.')
        return pre
    elif isinstance(data, Iterable) and not isinstance(data, str):
        pre = None
        for item in data:
            cur = get_data_device(item)
            if pre is None:
                pre = cur
            else:
                if cur != pre:
                    raise ValueError(
                        'device type in data should be consistent, but got '
                        f'{cur} and {pre}')
        if pre is None:
            raise ValueError('data should not be empty.')
        return pre
    else:
        raise TypeError('data should be a Tensor, sequence of tensor or dict, '
                        f'but got {data}')
def get_comm_device(group: Optional[ProcessGroup] = None) -> torch.device:
    """Return the device for communication among groups.

    Args:
        group (ProcessGroup, optional): The process group to work on.

    Returns:
        torch.device: The device of backend.
    """
    backend = get_backend(group)
    if backend == 'hccl':
        import torch_npu  # noqa: F401
        return torch.device('npu', torch.npu.current_device())
    elif backend == torch_dist.Backend.NCCL:
        return torch.device('cuda', torch.cuda.current_device())
    elif backend == 'cncl':
        import torch_mlu  # noqa: F401
        return torch.device('mlu', torch.mlu.current_device())
    elif backend == 'smddp':
        return torch.device('cuda', torch.cuda.current_device())
    else:
        # GLOO and MPI backends use cpu device by default
        return torch.device('cpu')
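# Example (illustrative sketch): moving data onto the communication device
# before a collective, so that e.g. the NCCL backend operates on a CUDA
# tensor while the GLOO backend keeps the data on the CPU:
#
#   tensor = torch.arange(4)
#   tensor = cast_data_device(tensor, get_comm_device())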
def cast_data_device(
        data: Union[Tensor, Mapping, Iterable],
        device: torch.device,
        out: Optional[Union[Tensor, Mapping, Iterable]] = None
) -> Union[Tensor, Mapping, Iterable]:
    """Recursively convert Tensor in ``data`` to ``device``.

    If ``data`` is already on ``device``, it will not be cast again.

    Args:
        data (Tensor or list or dict): Inputs to be cast.
        device (torch.device): Destination device type.
        out (Tensor or list or dict, optional): If ``out`` is specified, its
            value will be equal to ``data``. Defaults to None.

    Returns:
        Tensor or list or dict: ``data`` cast to ``device``.
    """
    if out is not None:
        if type(data) != type(out):
            raise TypeError(
                'out should be the same type as data, but got data is '
                f'{type(data)} and out is {type(out)}')

        if isinstance(out, set):
            raise TypeError('out should not be a set')

    if isinstance(data, Tensor):
        if get_data_device(data) == device:
            data_on_device = data
        else:
            data_on_device = data.to(device)

        if out is not None:
            # modify the value of out inplace
            out.copy_(data_on_device)  # type: ignore

        return data_on_device
    elif isinstance(data, Mapping):
        data_on_device = {}
        if out is not None:
            data_len = len(data)
            out_len = len(out)  # type: ignore
            if data_len != out_len:
                raise ValueError('length of data and out should be same, '
                                 f'but got {data_len} and {out_len}')

            for k, v in data.items():
                data_on_device[k] = cast_data_device(v, device,
                                                     out[k])  # type: ignore
        else:
            for k, v in data.items():
                data_on_device[k] = cast_data_device(v, device)

        if len(data_on_device) == 0:
            raise ValueError('data should not be empty')

        # To ensure the output has the same type as the input, we use
        # `type(data)` to wrap the output
        return type(data)(data_on_device)  # type: ignore
    elif isinstance(data, Iterable) and not isinstance(
            data, str) and not isinstance(data, np.ndarray):
        data_on_device = []
        if out is not None:
            for v1, v2 in zip(data, out):
                data_on_device.append(cast_data_device(v1, device, v2))
        else:
            for v in data:
                data_on_device.append(cast_data_device(v, device))

        if len(data_on_device) == 0:
            raise ValueError('data should not be empty')

        return type(data)(data_on_device)  # type: ignore
    else:
        raise TypeError('data should be a Tensor, list of tensor or dict, '
                        f'but got {data}')
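# Usage sketch (illustrative, assumes a CUDA device is available): casting a
# nested structure keeps its container types, and tensors already on the
# target device are returned as-is:
#
#   batch = {'inputs': torch.rand(2, 3),
#            'labels': [torch.tensor(0), torch.tensor(1)]}
#   batch = cast_data_device(batch, torch.device('cuda', 0))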