| | import os |
| | import math |
| | import random |
| | import argparse |
| | import datetime |
| | import logging |
| | import inspect |
| | import subprocess |
| |
|
| | import torch |
| | import torch.distributed as dist |
| | from torch.distributed.device_mesh import init_device_mesh |
| | from einops import rearrange, repeat |
| |
|
| |
|
# Module-level context-parallel state.  Every entry starts out as None:
# init_context_parallel() fills in the sizes, groups, rank lists and ranks,
# and get_cp_stream() creates the stream on first request.

# Sizes of the "dp" and "cp" mesh dimensions.
dp_size = None
cp_size = None

# Process groups obtained from the device mesh for each dimension.
dp_group = None
cp_group = None

# Rank lists returned by get_process_group_ranks() for each group.
dp_ranks = None
cp_ranks = None

# This process's rank inside its dp / cp group.
dp_rank = None
cp_rank = None

# CUDA stream handed out by get_cp_stream().
cp_stream = None
| |
|
| |
|
| | def init_context_parallel(context_parallel_size: int = 1, |
| | global_rank: int = 1, |
| | world_size: int = 1,): |
| |
|
| | global dp_size |
| | global cp_size |
| | global dp_group |
| | global cp_group |
| | global dp_ranks |
| | global cp_ranks |
| | global dp_rank |
| | global cp_rank |
| |
|
| |
|
| | if world_size%context_parallel_size != 0: |
| | raise RuntimeError(f'world_size {world_size} must be multiple of context_parallel_size {context_parallel_size}!!!') |
| |
|
| |
|
| | cp_size = context_parallel_size |
| | dp_size = world_size//context_parallel_size |
| |
|
| |
|
| | print(f'[rank {global_rank}] init_device_mesh [dp_size x cp_size]: [{dp_size} x {cp_size}]') |
| |
|
| | mesh_2d = init_device_mesh("cuda", (dp_size, cp_size), mesh_dim_names=("dp", "cp")) |
| |
|
| | print(f'[rank {global_rank}] mesh_2d: {mesh_2d}') |
| |
|
| | dp_group = mesh_2d.get_group(mesh_dim="dp") |
| | cp_group = mesh_2d.get_group(mesh_dim="cp") |
| |
|
| | dp_ranks = torch.distributed.get_process_group_ranks(dp_group) |
| | cp_ranks = torch.distributed.get_process_group_ranks(cp_group) |
| |
|
| | dp_rank = dist.get_rank(group=dp_group) |
| | cp_rank = dist.get_rank(group=cp_group) |
| |
|
| | global_rank_1 = torch.distributed.get_rank() |
| | print(f'[rank {global_rank_1}] [dp_rank, cp_rank]: [{dp_rank}, {cp_rank}], dp_ranks: {dp_ranks}, cp_ranks: {cp_ranks}') |
| |
|
| |
|
def get_cp_size():
    """Return the module-level ``cp_size`` (``None`` until it has been set)."""
    # Reading a module global needs no ``global`` declaration.
    return cp_size
| |
|
def get_dp_size():
    """Return the module-level ``dp_size`` (``None`` until it has been set)."""
    # Reading a module global needs no ``global`` declaration.
    return dp_size
| |
|
def get_cp_stream():
    """Return the shared CUDA stream, creating it on first use.

    The stream is stored in the module-level ``cp_stream`` so every later
    call returns the same object.

    Returns:
        The cached ``torch.cuda.Stream``.
    """
    global cp_stream
    # Identity check: ``== None`` would go through the object's ``__eq__``.
    if cp_stream is None:
        cp_stream = torch.cuda.Stream()
    return cp_stream
| |
|
def get_dp_group():
    """Return the module-level ``dp_group`` (``None`` until it has been set)."""
    # Reading a module global needs no ``global`` declaration.
    return dp_group
| |
|
def get_cp_group():
    """Return the module-level ``cp_group`` (``None`` until it has been set)."""
    # Reading a module global needs no ``global`` declaration.
    return cp_group
| |
|
| |
|
def get_dp_rank():
    """Return the module-level ``dp_rank`` (``None`` until it has been set)."""
    # Only ``dp_rank`` is read here, and reads need no ``global`` declaration.
    return dp_rank
| |
|
| |
|
def get_cp_rank():
    """Return the module-level ``cp_rank`` (``None`` until it has been set)."""
    # Only ``cp_rank`` is read here, and reads need no ``global`` declaration.
    return cp_rank
| |
|
| |
|
| |
|
def get_cp_rank_list():
    """Return the rank list of the context-parallel group.

    If the module-level ``cp_ranks`` has not been set yet, it is computed
    from ``cp_group`` with ``torch.distributed.get_process_group_ranks`` and
    cached.

    Returns:
        The cached list stored in ``cp_ranks``.
    """
    global cp_ranks
    # Identity check: ``== None`` on a list is an element-wise-style equality
    # test and not the idiomatic way to detect "unset".
    if cp_ranks is None:
        cp_ranks = torch.distributed.get_process_group_ranks(cp_group)
    return cp_ranks
| |
|
| |
|
def cp_broadcast(tensor, cp_index=0):
    """Broadcast ``tensor`` within the context-parallel group.

    Args:
        tensor: Tensor handed to ``torch.distributed.broadcast``.
        cp_index: Index into the list from ``get_cp_rank_list()``; the entry
            found there is used as the broadcast source rank.
    """
    source_rank = get_cp_rank_list()[cp_index]
    dist.broadcast(tensor, source_rank, group=cp_group)
| |
|
| |
|
| |
|
| |
|
def cp_broadcast_objects(tensor):
    """Placeholder for object broadcast; not implemented.

    Raises:
        NotImplementedError: Always.
    """
    message = "cp_broadcast_objects method is not yet implemented!!!"
    raise NotImplementedError(message)
| |
|
| |
|
| |
|
| |
|
def split_tensor_in_cp(input, seq_dim):
    """Return this rank's equal-sized slice of ``input`` along ``seq_dim``.

    The dimension is cut into ``cp_size`` pieces and the piece whose index
    equals ``get_cp_rank()`` is returned.

    Args:
        input: Tensor to slice.
        seq_dim: Dimension to slice along.

    Returns:
        The selected piece, as produced by ``torch.split``.

    Raises:
        RuntimeError: If the length of ``seq_dim`` is not a multiple of
            ``cp_size``.
    """
    seq_size = input.shape[seq_dim]
    if seq_size % cp_size:
        raise RuntimeError(f'seq_length {seq_size} in dim {seq_dim} must be multiple of cp_size {cp_size}!!!')

    pieces = torch.split(input, seq_size // cp_size, dim=seq_dim)
    return pieces[get_cp_rank()]
| |
|
| |
|
| |
|
| |
|
| |
|
class GatherFunction(torch.autograd.Function):
    """All-gather a ``B (T S) C`` tensor across the group in forward.

    Backward multiplies the incoming gradient by ``cp_size`` and returns
    only this rank's slice of it (via ``split_tensor_in_cp``).
    """

    @staticmethod
    def forward(ctx, input, process_group, seq_dim, frames):
        """Gather every rank's tensor and concatenate along ``seq_dim``.

        ``input`` is viewed as ``B T S C`` with ``T=frames`` before the
        gather, so ``seq_dim`` indexes that 4-D layout; the result is
        flattened back to ``B (T S) C``.
        """
        ctx.cp_group = process_group
        ctx.seq_dim = seq_dim
        ctx.frames = frames
        ctx.cp_size = get_cp_size()

        local = rearrange(input, "B (T S) C -> B T S C", T=frames)

        with torch.no_grad():
            # Made contiguous before being handed to the collective.
            local = local.contiguous()
            gathered = [torch.zeros_like(local) for _ in range(ctx.cp_size)]
            dist.all_gather(gathered, local, group=ctx.cp_group)
            merged = torch.cat(gathered, dim=seq_dim)

        return rearrange(merged, "B T S C -> B (T S) C", T=frames)

    @staticmethod
    def backward(ctx, grad_output):
        """Scale the gradient by ``cp_size`` and keep this rank's slice."""
        with torch.no_grad():
            scaled = grad_output * ctx.cp_size
            scaled = rearrange(scaled, "B (T S) C -> B T S C", T=ctx.frames)
            local_grad = split_tensor_in_cp(scaled, ctx.seq_dim)
            grad_input = rearrange(local_grad, "B T S C -> B (T S) C", T=ctx.frames)

        # One gradient slot per forward argument; only ``input`` gets one.
        return grad_input, None, None, None
| |
|
| |
|
| |
|
| |
|
class SplitFunction(torch.autograd.Function):
    """Keep this rank's slice along ``seq_dim`` in forward.

    Backward divides the incoming gradient by ``cp_size``, all-gathers the
    result across the group and concatenates the pieces along ``seq_dim``.
    """

    @staticmethod
    def forward(ctx, input, process_group, seq_dim):
        """Return ``split_tensor_in_cp(input, seq_dim)`` and stash group info."""
        ctx.cp_group = process_group
        ctx.seq_dim = seq_dim
        ctx.cp_size = get_cp_size()

        return split_tensor_in_cp(input, ctx.seq_dim)

    @staticmethod
    def backward(ctx, grad_output):
        """All-gather the scaled gradient slices into a full-length gradient."""
        with torch.no_grad():
            # NOTE(review): the 1/cp_size factor looks like the counterpart of
            # the *cp_size factor in GatherFunction.backward - confirm intent.
            # The incoming gradient may be non-contiguous (e.g. a permuted
            # view) and element-wise division keeps those strides, so make it
            # contiguous before the collective, as GatherFunction.forward does.
            grad_local = (grad_output / ctx.cp_size).contiguous()

            gathered = [torch.zeros_like(grad_local) for _ in range(ctx.cp_size)]
            dist.all_gather(gathered, grad_local, group=ctx.cp_group)
            grad_input = torch.cat(gathered, dim=ctx.seq_dim)

        # One gradient slot per forward argument; only ``input`` gets one.
        return grad_input, None, None
| |
|
| |
|
| |
|
def gather_cp(input, frames):
    """Gather ``input`` across the context-parallel group.

    Applies ``GatherFunction`` with the group from ``get_cp_group()`` and a
    fixed ``seq_dim`` of 2, i.e. the ``S`` axis of the ``B T S C`` layout
    that ``GatherFunction`` uses internally.
    """
    return GatherFunction.apply(input, get_cp_group(), 2, frames)
| |
|
| |
|
def split_cp(input, seq_dim):
    """Slice ``input`` along ``seq_dim`` for this context-parallel rank.

    Applies ``SplitFunction`` with the group from ``get_cp_group()``.
    """
    return SplitFunction.apply(input, get_cp_group(), seq_dim)
| |
|
| |
|
| |
|
| |
|
class ReduceFunction(torch.autograd.Function):
    """All-reduce in forward, pass the gradient through in backward."""

    @staticmethod
    def forward(ctx, input, process_group):
        """Return an all-reduced copy of ``input``; ``input`` is not modified."""
        ctx.cp_group = process_group

        # Work on a detached copy so the collective never writes into the
        # caller's tensor.
        reduced = input.detach().clone()
        dist.all_reduce(reduced, group=ctx.cp_group)
        return reduced

    @staticmethod
    def backward(ctx, grad_output):
        """Return a detached copy of ``grad_output`` for ``input``."""
        return grad_output.detach().clone(), None
| | |
| |
|
| |
|
class ReplicateFunction(torch.autograd.Function):
    """Copy in forward, all-reduce the gradient in backward."""

    @staticmethod
    def forward(ctx, input, process_group):
        """Return a detached copy of ``input``; no communication happens."""
        ctx.cp_group = process_group
        return input.detach().clone()

    @staticmethod
    def backward(ctx, grad_output):
        """All-reduce a copy of ``grad_output`` over the stored group."""
        # Reduce into a copy so the incoming gradient is left untouched.
        reduced_grad = grad_output.detach().clone()
        dist.all_reduce(reduced_grad, group=ctx.cp_group)
        return reduced_grad, None
| |
|
| |
|
def reduce_cp(partial_sum, partial_square_sum):
    """All-reduce both tensors over the context-parallel group.

    Each argument goes through ``ReduceFunction`` separately, with the group
    from ``get_cp_group()``.

    Returns:
        Tuple ``(all_sum, all_square_sum)``.
    """
    group = get_cp_group()
    all_sum = ReduceFunction.apply(partial_sum, group)
    all_square_sum = ReduceFunction.apply(partial_square_sum, group)
    return all_sum, all_square_sum
| |
|
| |
|
def replicate_cp(all_mean, all_var):
    """Pass both tensors through ``ReplicateFunction``.

    Each argument is handled separately, with the group from
    ``get_cp_group()``.

    Returns:
        Tuple ``(all_mean, all_var)`` of the function outputs.
    """
    group = get_cp_group()
    replicated_mean = ReplicateFunction.apply(all_mean, group)
    replicated_var = ReplicateFunction.apply(all_var, group)
    return replicated_mean, replicated_var
| |
|
| |
|
| |
|
def _all_to_all_func(input_, world_size, group, scatter_dim, gather_dim):
    """Run one all-to-all exchange over ``group``.

    ``input_`` is cut into ``world_size`` chunks along ``scatter_dim``; the
    chunks received from ``dist.all_to_all`` are concatenated along
    ``gather_dim``.

    Returns:
        The concatenated, contiguous result tensor.
    """
    chunks = torch.tensor_split(input_, world_size, scatter_dim)
    send_buffers = [chunk.contiguous() for chunk in chunks]

    # Every receive buffer is shaped like the first send chunk.
    template = send_buffers[0]
    recv_buffers = [torch.empty_like(template) for _ in range(world_size)]

    dist.all_to_all(recv_buffers, send_buffers, group=group)

    merged = torch.cat(recv_buffers, dim=gather_dim)
    return merged.contiguous()
| |
|
| |
|
class _AllToAll(torch.autograd.Function):
    """Differentiable all-to-all communication.

    Args:
        input_: input matrix
        process_group: communication group
        scatter_dim: scatter dimension
        gather_dim: gather dimension
    """

    @staticmethod
    def forward(ctx, input_, process_group, scatter_dim, gather_dim):
        """Exchange chunks of ``input_`` via ``_all_to_all_func``."""
        ctx.process_group = process_group
        ctx.scatter_dim = scatter_dim
        ctx.gather_dim = gather_dim

        group_size = dist.get_world_size(process_group)
        return _all_to_all_func(input_, group_size, process_group, scatter_dim, gather_dim)

    @staticmethod
    def backward(ctx, *grad_output):
        """Send the gradient back through an all-to-all with the dims swapped."""
        # Scatter along the forward gather dim and gather along the forward
        # scatter dim.
        grad_input = _AllToAll.apply(
            *grad_output,
            ctx.process_group,
            ctx.gather_dim,
            ctx.scatter_dim,
        )
        return (grad_input, None, None, None)
| |
|
| |
|
def all_to_all_with_pad(
    input_: torch.Tensor,
    process_group: dist.ProcessGroup,
    scatter_dim: int = 2,
    gather_dim: int = 1,
    scatter_pad: int = 0,
    gather_pad: int = 0,
):
    """All-to-all with optional zero padding before and trimming after.

    Args:
        input_: Tensor to exchange.
        process_group: Group the exchange runs over.
        scatter_dim: Dimension that is split across the group.
        gather_dim: Dimension along which received chunks are concatenated.
        scatter_pad: Number of zero entries appended along ``scatter_dim``
            before the exchange (skipped when not positive).
        gather_pad: Number of trailing entries removed along ``gather_dim``
            after the exchange (skipped when not positive).

    Returns:
        The exchanged tensor.
    """
    if scatter_pad > 0:
        # Same shape as the input except along the scatter dim.
        filler_shape = list(input_.shape)
        filler_shape[scatter_dim] = scatter_pad
        filler = torch.zeros(filler_shape, device=input_.device, dtype=input_.dtype)
        input_ = torch.cat([input_, filler], dim=scatter_dim)

    assert (
        input_.shape[scatter_dim] % dist.get_world_size(process_group) == 0
    ), f"Dimension to scatter ({input_.shape[scatter_dim]}) is not divisible by world size ({dist.get_world_size(process_group)})"

    input_ = _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)

    if gather_pad > 0:
        # Keep everything except the last ``gather_pad`` entries.
        kept_length = input_.size(gather_dim) - gather_pad
        input_ = input_.narrow(gather_dim, 0, kept_length)

    return input_
| |
|
| |
|
def dynamic_switch(x, scatter_dim, gather_dim):
    """All-to-all ``x`` over the context-parallel group without padding.

    Calls ``all_to_all_with_pad`` with the group from ``get_cp_group()`` and
    both pad amounts fixed at zero.
    """
    return all_to_all_with_pad(
        x,
        get_cp_group(),
        scatter_dim=scatter_dim,
        gather_dim=gather_dim,
        scatter_pad=0,
        gather_pad=0,
    )
| |
|