# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
"""
This file contains primitives for multi-gpu communication.
This is useful when doing distributed training.
"""
import gc
import os
import pickle
import shutil

import mmcv
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info

def is_distributed():
    return get_world_size() > 1


def get_world_size():
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def get_local_rank():
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    return local_rank


def is_master():
    return get_rank() == 0


def is_local_master():
    return get_local_rank() == 0

def get_local_proc_group(group_size=8):
    world_size = get_world_size()
    if world_size <= group_size or group_size == 1:
        return None
    assert (
        world_size % group_size == 0
    ), f"world size ({world_size}) must be divisible by group size ({group_size})."
    # Process groups are created once per group_size and cached on the function object.
    process_groups = getattr(get_local_proc_group, "process_groups", dict())
    if group_size not in process_groups:
        num_groups = world_size // group_size
        groups = [list(range(i * group_size, (i + 1) * group_size)) for i in range(num_groups)]
        process_groups.update({group_size: [torch.distributed.new_group(group) for group in groups]})
        get_local_proc_group.process_groups = process_groups
    # Return the cached group that contains the current rank.
    group_idx = get_rank() // group_size
    return get_local_proc_group.process_groups.get(group_size)[group_idx]
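
# Hypothetical usage sketch (not part of the original file): reduce a tensor only within
# the node-local group returned by get_local_proc_group(). The tensor shape and the
# group_size value below are illustrative assumptions.
def _example_local_group_reduce():
    group = get_local_proc_group(group_size=8)  # None when world_size <= group_size
    value = torch.ones(1, device="cuda") * get_rank()
    # dist.all_reduce accepts group=None, which falls back to the default (global) group.
    dist.all_reduce(value, group=group)
    return value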

def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training.
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()

def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).

    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """
    to_device = torch.device("cuda")
    # to_device = torch.device("cpu")
    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialize to a ByteTensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(to_device)

    # obtain the tensor size of each rank
    local_size = torch.LongTensor([tensor.numel()]).to(to_device)
    size_list = [torch.LongTensor([0]).to(to_device) for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving tensors from all ranks: we pad the tensor because torch all_gather
    # does not support gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to(to_device))
    if local_size != max_size:
        padding = torch.ByteTensor(size=(max_size - local_size,)).to(to_device)
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))
    return data_list
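
# Hypothetical usage sketch (not part of the original file): gather arbitrary picklable
# per-rank results (here a plain dict) with all_gather() and merge them on the master
# process. The payload fields are illustrative.
def _example_all_gather_results():
    part = {"rank": get_rank(), "num_samples": 10}  # illustrative per-rank payload
    parts = all_gather(part)  # list with world_size entries, identical on every rank
    if is_master():
        total = sum(p["num_samples"] for p in parts)
        print(f"gathered {len(parts)} parts, {total} samples in total")
    return parts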

def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that the process with
    rank 0 has the reduced (summed or averaged) results. Returns a dict with the same
    fields as input_dict, after reduction.

    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only the main process holds the accumulated sum, so only divide by
            # world_size there
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
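
# Hypothetical usage sketch (not part of the original file): average a loss dict across
# ranks so that only rank 0 logs the reduced values. The loss names are illustrative;
# reduce_dict expects same-shaped (here scalar) tensors on every rank.
def _example_log_reduced_losses(loss_dict):
    # e.g. loss_dict = {"loss_cls": torch.tensor(0.5, device="cuda"), "loss_reg": ...}
    reduced = reduce_dict(loss_dict, average=True)
    if is_master():
        print({k: v.item() for k, v in reduced.items()})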

def broadcast(data, **kwargs):
    if get_world_size() == 1:
        return data
    data = [data]
    dist.broadcast_object_list(data, **kwargs)
    return data[0]
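
# Hypothetical usage sketch (not part of the original file): broadcast a Python object
# (here a sampled seed) from rank 0 to all other ranks via broadcast_object_list.
def _example_broadcast_seed():
    seed = int(torch.randint(0, 2**31, (1,)).item()) if is_master() else None
    seed = broadcast(seed, src=0)  # non-master ranks receive rank 0's value
    return seed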

def all_gather_cpu(result_part, tmpdir=None, collect_by_master=True):
    rank, world_size = get_dist_info()
    if tmpdir is None:
        tmpdir = "./tmp"
    if rank == 0:
        mmcv.mkdir_or_exist(tmpdir)
    synchronize()
    # dump the part result to the tmp dir
    mmcv.dump(result_part, os.path.join(tmpdir, f"part_{rank}.pkl"))
    synchronize()
    # collect all parts
    if collect_by_master and rank != 0:
        return None
    else:
        # load results of all parts from the tmp dir
        results = []
        for i in range(world_size):
            part_file = os.path.join(tmpdir, f"part_{i}.pkl")
            results.append(mmcv.load(part_file))
        if not collect_by_master:
            synchronize()
        # remove the tmp dir
        if rank == 0:
            shutil.rmtree(tmpdir)
        return results
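
# Hypothetical usage sketch (not part of the original file): collect per-rank evaluation
# outputs through the filesystem instead of NCCL, which avoids moving large CPU objects
# onto the GPU. The tmpdir path is an illustrative assumption.
def _example_collect_eval_results(part_results):
    results = all_gather_cpu(part_results, tmpdir="./tmp_eval", collect_by_master=True)
    if is_master():
        # results has one entry per rank; flatten when each entry is itself a list
        return [item for part in results for item in part]
    return None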

def all_gather_tensor(tensor, group_size=None, group=None):
    if group_size is None:
        group_size = get_world_size()
    if group_size == 1:
        output = [tensor]
    else:
        output = [torch.zeros_like(tensor) for _ in range(group_size)]
        dist.all_gather(output, tensor, group=group)
    return output

def gather_difflen_tensor(feat, num_samples_list, concat=True, group=None, group_size=None):
    world_size = get_world_size()
    if world_size == 1:
        if not concat:
            return [feat]
        return feat
    num_samples, *feat_dim = feat.size()
    # padding to the max number of samples across ranks
    feat_padding = feat.new_zeros((max(num_samples_list), *feat_dim))
    feat_padding[:num_samples] = feat
    # gather
    feat_gather = all_gather_tensor(feat_padding, group=group, group_size=group_size)
    for r, num in enumerate(num_samples_list):
        feat_gather[r] = feat_gather[r][:num]
    if concat:
        feat_gather = torch.cat(feat_gather)
    return feat_gather
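
# Hypothetical usage sketch (not part of the original file): gather feature tensors whose
# batch dimension differs across ranks. The per-rank sample counts are exchanged first
# with all_gather_tensor, mirroring what GatherLayer.forward does below.
def _example_gather_variable_batches(feat):
    num = torch.tensor(feat.size(0), dtype=torch.long, device=feat.device)
    num_list = all_gather_tensor(num)  # list of scalar tensors, one per rank
    return gather_difflen_tensor(feat, num_list, concat=True)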

class GatherLayer(torch.autograd.Function):
    """Gather tensors from all processes, supporting backward propagation."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        num_samples = torch.tensor(input.size(0), dtype=torch.long, device=input.device)
        ctx.num_samples_list = all_gather_tensor(num_samples)
        output = gather_difflen_tensor(input, ctx.num_samples_list, concat=False)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):  # gradients of tuple(output)
        (input,) = ctx.saved_tensors
        num_samples_list = ctx.num_samples_list
        rank = get_rank()
        start, end = sum(num_samples_list[:rank]), sum(num_samples_list[: rank + 1])
        grads = torch.cat(grads)
        if is_distributed():
            dist.all_reduce(grads)
        grad_out = torch.zeros_like(input)
        grad_out[:] = grads[start:end]
        # forward takes a single tensor argument, so a single gradient is returned
        return grad_out

class GatherLayerWithGroup(torch.autograd.Function):
    """Gather tensors from all processes in a group, supporting backward propagation."""

    @staticmethod
    def forward(ctx, input, group, group_size):
        ctx.save_for_backward(input)
        ctx.group_size = group_size
        output = all_gather_tensor(input, group=group, group_size=group_size)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):  # gradients of tuple(output)
        (input,) = ctx.saved_tensors
        grads = torch.stack(grads)
        if is_distributed():
            dist.all_reduce(grads)
        grad_out = torch.zeros_like(input)
        grad_out[:] = grads[get_rank() % ctx.group_size]
        # group and group_size are non-tensor arguments and need no gradient
        return grad_out, None, None

def gather_layer_with_group(data, group=None, group_size=None):
    if group_size is None:
        group_size = get_world_size()
    output = GatherLayerWithGroup.apply(data, group, group_size)
    return output
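
# Hypothetical usage sketch (not part of the original file): a contrastive-style loss head
# that gathers embeddings from all ranks while keeping gradients flowing back to the local
# shard via GatherLayer. The temperature value is an illustrative assumption.
def _example_contrastive_logits(local_emb, temperature=0.07):
    if is_distributed():
        all_emb = torch.cat(GatherLayer.apply(local_emb), dim=0)
    else:
        all_emb = local_emb
    logits = local_emb @ all_emb.t() / temperature
    return logits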

import math
from typing import Union

# from torch.distributed.fsdp.fully_sharded_data_parallel import TrainingState_, _calc_grad_norm


def clip_grad_norm_(self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0) -> torch.Tensor:
    # NOTE: intended to be bound to an FSDP instance; it relies on the private helpers
    # TrainingState_ and _calc_grad_norm from the import commented out above.
    self._lazy_init()
    self._wait_for_previous_optim_step()
    assert self._is_root, "clip_grad_norm should only be called on the root (parent) instance"
    self._assert_state(TrainingState_.IDLE)

    max_norm = float(max_norm)
    norm_type = float(norm_type)
    # Compute the norm of this shard's gradients and reduce it across workers.
    local_norm = _calc_grad_norm(self.params_with_grad, norm_type).cuda()  # type: ignore[arg-type]
    if norm_type == math.inf:
        total_norm = local_norm
        dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group)
    else:
        total_norm = local_norm**norm_type
        dist.all_reduce(total_norm, group=self.process_group)
        total_norm = total_norm ** (1.0 / norm_type)

    clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
    if clip_coef < 1:
        # multiply gradients by clip_coef, i.e. max_norm / total_norm
        for p in self.params_with_grad:
            assert p.grad is not None
            p.grad.detach().mul_(clip_coef.to(p.grad.device))
    return total_norm

def flush():
    gc.collect()
    torch.cuda.empty_cache()