from typing import Optional

import torch
from torch import Tensor
from torch.distributed import ProcessGroup

# `all_gather_into_tensor` and `reduce_scatter_tensor` are the new names for
# `_all_gather_base` and `_reduce_scatter_base`. They require a recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    output = torch.empty(
        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.all_gather_into_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle
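

# Usage sketch for the async path, assuming torch.distributed is already
# initialized and `local_chunk` / `process_group` exist in the caller:
# the returned work handle must be waited on before `gathered` is read,
# which lets the all-gather overlap with independent compute.
#
#     gathered, handle = all_gather_raw(local_chunk, process_group, async_op=True)
#     ...  # independent work can run while the all-gather is in flight
#     handle.wait()  # `gathered` is only valid after this returns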


# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    world_size = torch.distributed.get_world_size(process_group)
    assert input_.shape[0] % world_size == 0
    output = torch.empty(
        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
    )
    handle = torch.distributed.reduce_scatter_tensor(
        output, input_.contiguous(), group=process_group, async_op=async_op
    )
    return output, handle


# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
    input_ = input_.contiguous()
    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
    return input_, handle


class AllGatherFunc(torch.autograd.Function):
    """Gather the input from the sequence parallel region and concatenate."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_gather_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply
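

# Usage sketch for the autograd wrapper, assuming a sequence-parallel layout
# where each rank holds a contiguous slice of the sequence dimension and
# `seq_group` is the corresponding ProcessGroup (both assumed here):
#
#     # x_local: (seq_len // world_size, batch, hidden), requires_grad=True
#     x_full = all_gather(x_local, seq_group)  # (seq_len, batch, hidden)
#     # In backward, AllGatherFunc reduce-scatters the gradient so each rank
#     # receives only the gradient slice for its own sequence chunk.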


class ReduceScatterFunc(torch.autograd.Function):
    """Reduce-scatter the input from the sequence parallel region."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = reduce_scatter_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
        return grad_input, None


# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply


class AllReduceFunc(torch.autograd.Function):
    """All-reduce the input from the sequence parallel region."""

    @staticmethod
    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
        ctx.process_group = process_group
        output, _ = all_reduce_raw(input_, process_group)
        return output

    @staticmethod
    def backward(ctx, grad_output: Tensor):
        return grad_output, None


# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply
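

# Note on gradients: AllReduceFunc.backward returns grad_output unchanged,
# because the summation already happened in the forward all-reduce. Sketch,
# assuming `tp_group` is an existing ProcessGroup:
#
#     y = all_reduce(x, tp_group)  # forward: sum of x across the group
#     y.sum().backward()           # backward: the gradient flows to x as-is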


def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _shared_params=True in the same order,
    # as different ranks might have a different number of parameters (e.g., only rank 0 has bias).
    params_shared = {
        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
    }
    for _, p in sorted(params_shared.items()):
        with torch.no_grad():
            # Broadcast needs src to be the global rank, not the group rank
            torch.distributed.broadcast(
                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
            )


# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
    # We want to iterate over parameters with _sequence_parallel=True in the same order,
    # as different ranks might have a different number of parameters (e.g., only rank 0 has bias).
    params_seqparallel = {
        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
    }
    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
    if grads:
        with torch.no_grad():
            coalesced = torch._utils._flatten_dense_tensors(grads)
            torch.distributed.all_reduce(coalesced, group=process_group)
            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
                buf.copy_(synced)
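

# Usage sketch: in a sequence-parallel training step, gradients of parameters
# tagged `_sequence_parallel=True` (e.g. norm weights whose inputs are split
# along the sequence) need to be summed across the group after backward and
# before the optimizer step. `seq_group` and `optimizer` are assumed to exist
# in the surrounding training loop:
#
#     loss.backward()
#     allreduce_sequence_parallel_grad(model, seq_group)
#     optimizer.step()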


def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
    """Get the dim for the local rank derived from splitting dim on world_size processes.

    The split may not be even across the world_size processes.
    """
    multiple = dim // multiple_of
    div = multiple // world_size
    mod = multiple % world_size
    local_multiple = div + int(local_rank < mod)
    return local_multiple * multiple_of
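

# Worked example: dim=13, world_size=4, multiple_of=1 gives multiple=13,
# div=3, mod=1, so rank 0 gets div + 1 = 4 and ranks 1..3 get 3:
#
#     >>> [get_dim_for_local_rank(13, 4, r) for r in range(4)]
#     [4, 3, 3, 3]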