# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import io
import logging
import os
import pickle
import random
import socket
import struct
import subprocess
import warnings
from argparse import Namespace
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Dict, List, Mapping, Optional

import torch
import torch.distributed as dist
from fairseq.dataclass.configs import DistributedTrainingConfig, FairseqConfig
from omegaconf import open_dict

try:
    import torch_xla.core.xla_model as xm
except ImportError:
    xm = None


# Flag to indicate if we're using Megatron
# NOTE: this is a temporary hack until we move away from Megatron's model parallel init
_USE_MEGATRON = False

# Whether to use XLA ops (e.g., on TPUs) instead of CUDA ops.
_USE_XLA = False

logger = logging.getLogger(__name__)


def is_master(cfg: DistributedTrainingConfig):
    return cfg.distributed_rank == 0


def infer_init_method(cfg: DistributedTrainingConfig, force_distributed=False):
    if cfg.distributed_init_method is not None or cfg.tpu:
        return

    num_pipelines_per_node = None
    if cfg.pipeline_model_parallel:
        num_pipeline_devices, num_pipelines_per_node = _pipeline_parallel_pre_init(cfg)

    if all(
        key in os.environ
        for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"]
    ):
        # support torch.distributed.launch
        _infer_torch_distributed_launch_init(cfg)
    elif cfg.distributed_port > 0:
        # we can determine the init method automatically for Slurm
        _infer_slurm_init(cfg, num_pipelines_per_node)
    elif cfg.distributed_world_size > 1 or force_distributed:
        # fallback for single node with multiple GPUs
        _infer_single_node_init(cfg)

    if cfg.pipeline_model_parallel:
        _pipeline_parallel_post_init(cfg, num_pipeline_devices, num_pipelines_per_node)
    elif not cfg.distributed_no_spawn:
        with open_dict(cfg):
            cfg.distributed_num_procs = min(
                torch.cuda.device_count(), cfg.distributed_world_size
            )
def _infer_torch_distributed_launch_init(cfg: DistributedTrainingConfig):
    cfg.distributed_init_method = "env://"
    cfg.distributed_world_size = int(os.environ["WORLD_SIZE"])
    cfg.distributed_rank = int(os.environ["RANK"])
    # processes are created by torch.distributed.launch
    cfg.distributed_no_spawn = True


def _infer_slurm_init(cfg: DistributedTrainingConfig, num_pipelines_per_node):
    node_list = os.environ.get("SLURM_STEP_NODELIST")
    if node_list is None:
        node_list = os.environ.get("SLURM_JOB_NODELIST")
    if node_list is not None:
        try:
            hostnames = subprocess.check_output(
                ["scontrol", "show", "hostnames", node_list]
            )
            cfg.distributed_init_method = "tcp://{host}:{port}".format(
                host=hostnames.split()[0].decode("utf-8"),
                port=cfg.distributed_port,
            )
            nnodes = int(os.environ.get("SLURM_NNODES"))
            ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE")
            if ntasks_per_node is not None:
                ntasks_per_node = int(ntasks_per_node)
            else:
                ntasks = int(os.environ.get("SLURM_NTASKS"))
                nnodes = int(os.environ.get("SLURM_NNODES"))
                assert ntasks % nnodes == 0
                ntasks_per_node = int(ntasks / nnodes)
            if ntasks_per_node == 1:
                gpus_per_node = torch.cuda.device_count()
                node_id = int(os.environ.get("SLURM_NODEID"))
                cfg.distributed_rank = node_id * gpus_per_node
                cfg.distributed_world_size = nnodes * gpus_per_node
            elif cfg.pipeline_model_parallel:
                assert ntasks_per_node == num_pipelines_per_node, (
                    "SLURM --ntasks-per-node must match number of pipelines per "
                    "node (={})".format(num_pipelines_per_node)
                )
                cfg.distributed_no_spawn = True
                # For 4-way MP on nodes with 8 GPUs, ranks will be [0, 1] on
                # the first node, [2, 3] on the second node, etc. This
                # matches torch.distributed.launch.
                node_id = int(os.environ.get("SLURM_NODEID"))
                local_id = int(os.environ.get("SLURM_LOCALID"))
                cfg.distributed_rank = node_id * num_pipelines_per_node + local_id
                # In the above example, device_id will always be in [0, 1],
                # which also matches torch.distributed.launch.
                cfg.device_id = local_id
                # We also want to set distributed_world_size to be the total
                # number of pipelines across all nodes.
                cfg.distributed_world_size = nnodes * num_pipelines_per_node
            else:
                assert ntasks_per_node == cfg.distributed_world_size // nnodes
                cfg.distributed_no_spawn = True
                cfg.distributed_rank = int(os.environ.get("SLURM_PROCID"))
                cfg.device_id = int(os.environ.get("SLURM_LOCALID"))
        except subprocess.CalledProcessError as e:  # scontrol failed
            raise e
        except FileNotFoundError:  # Slurm is not installed
            pass
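# Usage sketch for _infer_slurm_init: the Slurm environment it reads, with
# hypothetical values for a 2-node, 8-GPU-per-node job submitted with one
# task per node (hostnames and port below are made up for illustration).
#
#   SLURM_STEP_NODELIST=node[01-02]   # resolved via `scontrol show hostnames`
#   SLURM_NNODES=2
#   SLURM_NTASKS_PER_NODE=1           # one task per node, so fairseq itself
#                                     # spawns one process per GPU later
#   SLURM_NODEID=1                    # index of the current node
#
# With cfg.distributed_port=12345, the helper would derive:
#   cfg.distributed_init_method = "tcp://node01:12345"
#   cfg.distributed_rank        = 1 * 8 = 8    # first rank on the second node
#   cfg.distributed_world_size  = 2 * 8 = 16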
def _infer_single_node_init(cfg: DistributedTrainingConfig):
    assert (
        cfg.distributed_world_size <= torch.cuda.device_count()
    ), f"world size is {cfg.distributed_world_size} but only {torch.cuda.device_count()} CUDA devices are available"
    port = random.randint(10000, 20000)
    cfg.distributed_init_method = "tcp://localhost:{port}".format(port=port)


def _pipeline_parallel_pre_init(cfg: DistributedTrainingConfig):
    from fairseq import utils

    balance_exists = (
        cfg.pipeline_balance is not None
        or cfg.pipeline_encoder_balance is not None
        or cfg.pipeline_decoder_balance is not None
    )
    devices_exist = (
        cfg.pipeline_devices is not None
        or cfg.pipeline_encoder_devices is not None
        or cfg.pipeline_decoder_devices is not None
    )
    if not balance_exists:
        raise ValueError(
            "--pipeline-balance is currently required for pipeline model parallelism"
        )
    if not devices_exist:
        raise ValueError(
            "--pipeline-devices is currently required for pipeline model parallelism"
        )

    cfg.pipeline_balance = utils.eval_str_list(cfg.pipeline_balance, type=int)
    if cfg.pipeline_devices is not None:
        cfg.pipeline_devices = utils.eval_str_list(cfg.pipeline_devices, type=int)
        num_pipeline_devices = len(set(cfg.pipeline_devices))
    else:
        cfg.pipeline_encoder_devices = utils.eval_str_list(
            cfg.pipeline_encoder_devices, type=int
        )
        cfg.pipeline_decoder_devices = utils.eval_str_list(
            cfg.pipeline_decoder_devices, type=int
        )
        num_pipeline_devices = len(
            set(cfg.pipeline_encoder_devices + cfg.pipeline_decoder_devices)
        )
    gpus_per_node = torch.cuda.device_count()
    assert (
        gpus_per_node >= num_pipeline_devices
        and gpus_per_node % num_pipeline_devices == 0
    ), (
        "the number of unique device IDs in --pipeline-devices must evenly divide "
        "the number of GPUs per node (multi-node pipelining is not yet supported)"
    )
    num_pipelines_per_node = gpus_per_node // num_pipeline_devices
    return num_pipeline_devices, num_pipelines_per_node


def _pipeline_parallel_post_init(
    cfg: DistributedTrainingConfig, num_pipeline_devices, num_pipelines_per_node
):
    if not cfg.distributed_no_spawn:
        # When distributed_no_spawn is False, we expect distributed_rank and
        # distributed_world_size to be based on the total number of GPUs, so
        # we need to correct them to be based on the number of pipelines.
        assert cfg.distributed_world_size % num_pipeline_devices == 0
        cfg.distributed_world_size = cfg.distributed_world_size // num_pipeline_devices
        # In the case of 4-way MP on nodes with 8 GPUs, we want
        # distributed_rank to be the starting GPU index for each pipeline
        # i.e., 0, 2, ...
        gpus_per_node = torch.cuda.device_count()
        assert cfg.distributed_rank % gpus_per_node == 0
        assert cfg.distributed_rank % num_pipeline_devices == 0

        with open_dict(cfg):
            cfg.distributed_rank = cfg.distributed_rank // num_pipeline_devices
            # launch one process per pipeline
            cfg.distributed_num_procs = num_pipelines_per_node

    # if we have 4-way MP on a node with 8 GPUs, we want device_ids to be 0
    # and 4, indicating the starting device IDs for each pipeline
    cfg.device_id *= num_pipeline_devices

    if cfg.device_id > 0:
        # if there are multiple pipelines on a node (e.g., 4-way MP on an 8
        # GPU node), we need to adjust pipeline_devices accordingly
        logger.debug(
            "setting CUDA device={} on rank {}".format(
                cfg.device_id, cfg.distributed_rank
            )
        )
        torch.cuda.set_device(cfg.device_id)
        with open_dict(cfg):
            cfg.pipeline_devices = [cfg.device_id + d for d in cfg.pipeline_devices]
        logger.info(
            "setting pipeline_devices={} on rank {}".format(
                cfg.pipeline_devices, cfg.distributed_rank
            )
        )
def distributed_init(cfg: FairseqConfig):
    if isinstance(cfg, Namespace):
        from fairseq.dataclass.utils import convert_namespace_to_omegaconf

        cfg = convert_namespace_to_omegaconf(cfg)

    if not cfg.common.tpu:
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            warnings.warn(
                "Distributed is already initialized, cannot initialize twice!"
            )
        else:
            logger.info(
                "distributed init (rank {}): {}".format(
                    cfg.distributed_training.distributed_rank,
                    cfg.distributed_training.distributed_init_method,
                )
            )
            dist.init_process_group(
                backend=cfg.distributed_training.distributed_backend,
                init_method=cfg.distributed_training.distributed_init_method,
                world_size=cfg.distributed_training.distributed_world_size,
                rank=cfg.distributed_training.distributed_rank,
            )
            logger.info(
                "initialized host {} as rank {}".format(
                    socket.gethostname(),
                    cfg.distributed_training.distributed_rank,
                )
            )

            # perform a dummy all-reduce to initialize the NCCL communicator
            if torch.cuda.is_available():
                dist.all_reduce(torch.zeros(1).cuda())

        cfg.distributed_training.distributed_rank = torch.distributed.get_rank()
    else:
        assert xm.xrt_world_size() == cfg.distributed_training.distributed_world_size
        global _USE_XLA
        _USE_XLA = True
        cfg.distributed_training.device_id = xm.get_local_ordinal()
        cfg.distributed_training.distributed_rank = xm.get_ordinal()
        xm.rendezvous("distributed_init")  # wait for all workers

    if is_master(cfg.distributed_training):
        logging.getLogger().setLevel(logging.INFO)
    else:
        logging.getLogger().setLevel(logging.WARNING)

    if cfg.common.model_parallel_size > 1:
        try:
            from fairseq.model_parallel.megatron.mpu import (
                initialize_model_parallel,
                model_parallel_cuda_manual_seed,
            )
        except ImportError:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n git submodule update --init "
                "fairseq/model_parallel/megatron"
            )
        global _USE_MEGATRON
        _USE_MEGATRON = True
        initialize_model_parallel(cfg.common.model_parallel_size)
        model_parallel_cuda_manual_seed(cfg.common.seed)
        model_part_number = get_model_parallel_rank()
        cfg.checkpoint.checkpoint_suffix += "-model_part-{0}".format(model_part_number)

    if hasattr(cfg, "model") and getattr(cfg.model, "base_layers", 0) > 0:
        cfg.checkpoint.checkpoint_suffix = (
            f"-rank-{cfg.distributed_training.distributed_rank}"
        )

    return cfg.distributed_training.distributed_rank


def distributed_main(i, main, cfg: FairseqConfig, kwargs):
    cfg.distributed_training.device_id = i
    if torch.cuda.is_available() and not cfg.common.cpu and not cfg.common.tpu:
        torch.cuda.set_device(cfg.distributed_training.device_id)
    if cfg.distributed_training.distributed_rank is None:  # torch.multiprocessing.spawn
        cfg.distributed_training.distributed_rank = kwargs.pop("start_rank", 0) + i

    cfg.distributed_training.distributed_rank = distributed_init(cfg)

    after_distributed_init_fn = kwargs.pop("after_distributed_init_fn", None)
    if after_distributed_init_fn:
        cfg = after_distributed_init_fn(cfg)

    main(cfg, **kwargs)

    if torch.distributed.is_initialized():
        torch.distributed.barrier(get_global_group())


def call_main(cfg: FairseqConfig, main, **kwargs):
    if cfg.distributed_training.distributed_init_method is None:
        infer_init_method(cfg.distributed_training)

    if cfg.distributed_training.distributed_init_method is not None:
        # distributed training
        if not cfg.distributed_training.distributed_no_spawn:
            start_rank = cfg.distributed_training.distributed_rank
            cfg.distributed_training.distributed_rank = None  # assign automatically
            kwargs["start_rank"] = start_rank
            torch.multiprocessing.spawn(
                fn=distributed_main,
                args=(main, cfg, kwargs),
                nprocs=min(
                    torch.cuda.device_count(),
                    cfg.distributed_training.distributed_world_size,
                ),
                join=True,
            )
        else:
            distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs)
    elif cfg.common.tpu and cfg.distributed_training.distributed_world_size > 1:
        import torch_xla.distributed.xla_multiprocessing as xmp

        torch.multiprocessing.set_sharing_strategy("file_system")
        xmp.spawn(
            fn=distributed_main,
            args=(main, cfg, kwargs),
            # tpu-comment: 8 devices in one TPU VM is the maximum number of
            # processes to spawn here; the rest is driven by the xla_dist
            # launcher.
            nprocs=min(cfg.distributed_training.distributed_world_size, 8),
        )
    else:
        # single GPU main
        main(cfg, **kwargs)
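# Usage sketch for call_main: how a training entry point typically hands
# control to it. `train_main` and the way cfg is obtained are hypothetical
# placeholders, not part of the fairseq API.
def _example_call_main(cfg: FairseqConfig):
    def train_main(cfg, **kwargs):
        # runs on every rank once distributed_init has completed
        logger.info(
            "running on rank %d", cfg.distributed_training.distributed_rank
        )

    # infers the init method, spawns one process per GPU if needed, and
    # invokes train_main on each rank
    call_main(cfg, train_main)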
def use_xla():
    global _USE_XLA
    return _USE_XLA


def new_groups(grouped_ranks: List[List[int]]):
    if use_xla():
        return ("tpu", grouped_ranks)
    else:
        groups = [dist.new_group(g) for g in grouped_ranks]
        my_group_idx = _find_my_group_index(grouped_ranks)
        return groups[my_group_idx]
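# Sketch: carving a world of 8 ranks into two groups of 4. On XLA this
# returns a ("tpu", grouped_ranks) tuple; on CUDA it returns the
# torch.distributed group that the calling rank belongs to.
def _example_new_groups():
    grouped_ranks = [[0, 1, 2, 3], [4, 5, 6, 7]]
    return new_groups(grouped_ranks)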
def _find_my_group_index(grouped_ranks):
    my_rank = get_global_rank()
    for i, group in enumerate(grouped_ranks):
        if my_rank in group:
            return i
    raise RuntimeError


def _find_my_group(grouped_ranks):
    index = _find_my_group_index(grouped_ranks)
    return grouped_ranks[index]


def get_rank(group):
    if use_xla():
        assert group[0] == "tpu"
        my_group = _find_my_group(group[1])
        return my_group.index(get_global_rank())
    else:
        return dist.get_rank(group=group)


def get_world_size(group):
    if use_xla():
        assert group[0] == "tpu"
        my_group = _find_my_group(group[1])
        return len(my_group)
    elif torch.distributed.is_initialized():
        return dist.get_world_size(group=group)
    else:
        return 1


def get_global_group():
    if use_xla():
        return new_groups([list(range(get_global_world_size()))])
    elif torch.distributed.is_initialized():
        if not hasattr(get_global_group, "_global_group"):
            # ideally we could use torch.distributed.group.WORLD, but it seems
            # to cause random NCCL hangs in some cases
            get_global_group._global_group = dist.new_group()
        return get_global_group._global_group
    else:
        return None


def get_global_rank():
    if use_xla():
        return xm.get_ordinal()
    elif torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    else:
        return 0


def get_global_world_size():
    if use_xla():
        return xm.xrt_world_size()
    elif torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    else:
        return 1


def get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to."""
    global _USE_MEGATRON
    if _USE_MEGATRON:
        from fairseq.model_parallel.megatron import mpu

        return mpu.get_data_parallel_group()
    else:
        return get_global_group()


def get_data_parallel_rank():
    """Return my rank for the data parallel group."""
    return get_rank(get_data_parallel_group())


def get_data_parallel_world_size():
    """Return world size for the data parallel group."""
    return get_world_size(get_data_parallel_group())


def get_model_parallel_group():
    global _USE_MEGATRON
    if _USE_MEGATRON:
        from fairseq.model_parallel.megatron import mpu

        return mpu.get_model_parallel_group()
    else:
        return None


def get_model_parallel_rank():
    """Return my rank for the model parallel group."""
    return get_rank(get_model_parallel_group())


def get_model_parallel_world_size():
    """Return world size for the model parallel group."""
    return get_world_size(get_model_parallel_group())


def all_reduce(tensor, group, op="sum"):
    if use_xla():
        assert isinstance(group, tuple) and group[0] == "tpu"
        tensor = [tensor]  # wrap in a list to make xm.all_reduce in-place
        return xm.all_reduce(op, tensor, groups=group[1])[0]
    else:
        if op == "sum":
            op = dist.ReduceOp.SUM
        elif op == "max":
            op = dist.ReduceOp.MAX
        else:
            raise NotImplementedError
        dist.all_reduce(tensor, op=op, group=group)
        return tensor
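# Usage sketch for all_reduce: summing a per-rank scalar across the global
# group, a common pattern for aggregating statistics. Assumes distributed
# training has already been initialized.
def _example_all_reduce():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    local_count = torch.tensor([3.0], device=device)
    total = all_reduce(local_count, group=get_global_group(), op="sum")
    return total  # identical on every rank after the reduction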
def broadcast(tensor, src, group):
    if use_xla():
        # XLA doesn't support broadcast, hack it with all_reduce
        if get_rank(group) != src:
            tensor.zero_()
        all_reduce(tensor, group)
    else:
        dist.broadcast(tensor, src=src, group=group)


def all_to_all(tensor, group):
    """Perform an all-to-all operation on a 1D Tensor."""
    assert tensor.dim() == 1
    split_count = get_world_size(group=group)
    assert tensor.numel() % split_count == 0
    if use_xla():
        assert isinstance(group, tuple) and group[0] == "tpu"
        return xm.all_to_all(
            tensor,
            split_dimension=0,
            concat_dimension=0,
            split_count=split_count,
            groups=group[1],
        )
    else:
        output = torch.zeros_like(tensor)
        dist.all_to_all_single(output, tensor, group=group)
        return output


def all_gather(tensor, group, return_tensor=False):
    """Perform an all-gather operation."""
    if use_xla():
        result = xm.all_gather(tensor, groups=group[1])
        world_size = get_world_size(group=group)
        result = result.view(world_size, *tensor.size())
        if return_tensor:
            return result
        else:
            return [result[i] for i in range(world_size)]
    else:
        world_size = get_world_size(group=group)
        rank = get_rank(group=group)
        tensor_list = [
            tensor if i == rank else torch.empty_like(tensor) for i in range(world_size)
        ]
        dist.all_gather(tensor_list, tensor, group=group)
        if return_tensor:
            return torch.stack(tensor_list, dim=0)
        else:
            return tensor_list
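# Usage sketch for all_gather: stacking one tensor per rank into a single
# [world_size, ...] tensor. Assumes distributed training has already been
# initialized.
def _example_all_gather():
    group = get_global_group()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    local = torch.full((2,), float(get_global_rank()), device=device)
    stacked = all_gather(local, group, return_tensor=True)
    return stacked  # stacked[i] holds the tensor contributed by rank i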
def all_gather_list(data, group=None, max_size=16384):
    """Gathers arbitrary data from all nodes into a list.

    Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python
    data. Note that *data* must be picklable and any CUDA tensors will be moved
    to CPU and returned on CPU as well.

    Args:
        data (Any): data from the local worker to be gathered on other workers
        group: group of the collective
        max_size (int, optional): maximum size of the data to be gathered
            across workers
    """
    from fairseq import utils

    if group is None:
        group = get_global_group()
    rank = get_rank(group=group)
    world_size = get_world_size(group=group)

    buffer_size = max_size * world_size
    if (
        not hasattr(all_gather_list, "_buffer")
        or all_gather_list._buffer.numel() < buffer_size
    ):
        all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size)
        all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory()
    buffer = all_gather_list._buffer
    buffer.zero_()
    cpu_buffer = all_gather_list._cpu_buffer

    data = utils.move_to_cpu(data)
    enc = pickle.dumps(data)
    enc_size = len(enc)
    header_size = 4  # size of header that contains the length of the encoded data
    size = header_size + enc_size
    if size > max_size:
        raise ValueError(
            "encoded data size ({}) exceeds max_size ({})".format(size, max_size)
        )

    header = struct.pack(">I", enc_size)
    cpu_buffer[:size] = torch.ByteTensor(list(header + enc))
    start = rank * max_size
    buffer[start : start + size].copy_(cpu_buffer[:size])

    all_reduce(buffer, group=group)

    buffer = buffer.cpu()
    try:
        result = []
        for i in range(world_size):
            out_buffer = buffer[i * max_size : (i + 1) * max_size]
            (enc_size,) = struct.unpack(">I", bytes(out_buffer[:header_size].tolist()))
            if enc_size > 0:
                result.append(
                    pickle.loads(
                        bytes(out_buffer[header_size : header_size + enc_size].tolist())
                    )
                )
        return result
    except pickle.UnpicklingError:
        raise Exception(
            "Unable to unpickle data from other workers. all_gather_list requires all "
            "workers to enter the function together, so this error usually indicates "
            "that the workers have fallen out of sync somehow. Workers can fall out of "
            "sync if one of them runs out of memory, or if there are other conditions "
            "in your training script that can cause one worker to finish an epoch "
            "while other workers are still iterating over their portions of the data. "
            "Try rerunning with --ddp-backend=legacy_ddp and see if that helps."
        )
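# Usage sketch for all_gather_list: collecting arbitrary picklable per-worker
# stats, e.g. for logging on the master rank. The stats dict is a
# hypothetical payload.
def _example_all_gather_list():
    stats = {"ntokens": 1024, "nsentences": 32}
    gathered = all_gather_list(stats, group=get_global_group(), max_size=16384)
    return gathered  # gathered[i] is the dict contributed by rank i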
def all_reduce_dict(data: Mapping[str, Any], device, group) -> Dict[str, Any]:
    """
    AllReduce a dictionary of values across workers. We separately
    reduce items that are already on the device and items on CPU for
    better performance.

    Args:
        data (Mapping[str, Any]): dictionary of data to all-reduce, but
            cannot be a nested dictionary
        device (torch.device): device for the reduction
        group: group of the collective
    """
    data_keys = list(data.keys())

    # We want to separately reduce items that are already on the
    # device and items on CPU for performance reasons.
    cpu_data = OrderedDict()
    device_data = OrderedDict()
    for k in data_keys:
        t = data[k]
        if not torch.is_tensor(t):
            cpu_data[k] = torch.tensor(t, dtype=torch.double)
        elif t.device.type != device.type:
            cpu_data[k] = t.to(dtype=torch.double)
        else:
            device_data[k] = t.to(dtype=torch.double)

    def _all_reduce_dict(data: OrderedDict):
        if len(data) == 0:
            return data
        buf = torch.cat([t.view(-1) for t in data.values()]).to(device=device)
        all_reduce(buf, group=group)
        split_buf = torch.split(buf.clone(), [t.numel() for t in data.values()])
        reduced_data = [t.view_as(orig) for t, orig in zip(split_buf, data.values())]
        return OrderedDict(zip(data.keys(), reduced_data))

    cpu_data = _all_reduce_dict(cpu_data)
    device_data = _all_reduce_dict(device_data)

    def get_from_stack(key):
        if key in cpu_data:
            return cpu_data[key]
        elif key in device_data:
            return device_data[key]
        raise KeyError

    return OrderedDict([(key, get_from_stack(key)) for key in data_keys])
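# Usage sketch for all_reduce_dict: reducing a flat dict of logging values.
# Plain Python numbers and off-device tensors go through the CPU bucket,
# while tensors already on `device` are reduced directly, as described above.
def _example_all_reduce_dict():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data = {
        "loss": torch.tensor(2.5, device=device),  # reduced on the device
        "ntokens": 1024,  # reduced via the CPU bucket
    }
    return all_reduce_dict(data, device=device, group=get_global_group())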
def broadcast_tensors(
    tensors: Optional[List[torch.Tensor]],
    src_rank: int,
    group: object,
    dist_device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
    """
    Broadcasts a list of tensors without other (non-src) ranks needing to know
    the dtypes/shapes of the tensors.
    """
    if dist_device is None:
        if torch.distributed.get_backend(group) == "nccl":
            dist_device = torch.device("cuda")
        else:
            dist_device = torch.device("cpu")

    # share metadata first to simplify transfer
    is_src_rank = get_rank(group) == src_rank
    if is_src_rank:
        metadata = [
            {"size": t.size(), "dtype": t.dtype, "device": t.device} for t in tensors
        ]
        metadata = _broadcast_object_slow(metadata, src_rank, group, dist_device)
    else:
        metadata = _broadcast_object_slow(None, src_rank, group, dist_device)

    out_tensors = []
    for i, meta in enumerate(metadata):
        if is_src_rank:
            tensor = tensors[i]
            broadcast(tensors[i].to(dist_device), src=src_rank, group=group)
        else:
            tensor = torch.zeros(
                [meta["size"].numel()], dtype=meta["dtype"], device=dist_device
            )
            broadcast(tensor, src=src_rank, group=group)
        tensor = tensor.view(meta["size"]).to(meta["device"])
        out_tensors.append(tensor)
    return out_tensors


def broadcast_object(
    obj: Any,
    src_rank: int,
    group: object,
    dist_device: Optional[torch.device] = None,
) -> Any:
    """Broadcast an arbitrary Python object to other workers."""
    if dist_device is None:
        if torch.distributed.get_backend(group) == "nccl":
            dist_device = torch.device("cuda")
        else:
            dist_device = torch.device("cpu")

    if get_rank(group) == src_rank:
        # split the tensors from the non-tensors so we can broadcast them
        # directly, avoiding unnecessary serialization/deserialization
        tensors = []
        obj = _split_tensors_from_obj(obj, tensors)
        obj = _broadcast_object_slow(obj, src_rank, group, dist_device)
        tensors = broadcast_tensors(tensors, src_rank, group, dist_device)
    else:
        obj = _broadcast_object_slow(None, src_rank, group, dist_device)
        tensors = broadcast_tensors(None, src_rank, group, dist_device)
    return _put_tensors_in_obj(obj, tensors)
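# Usage sketch for broadcast_object: sending a nested Python object
# (including tensors) from rank 0 to all other ranks. The payload below is
# hypothetical; non-src ranks do not need to know its structure.
def _example_broadcast_object():
    group = get_global_group()
    obj = None
    if get_rank(group) == 0:
        obj = {"epoch": 3, "weights": torch.randn(4)}
    return broadcast_object(obj, src_rank=0, group=group)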
def _broadcast_object_slow(
    obj: Any,
    src_rank: int,
    group: object,
    dist_device: torch.device,
) -> Any:
    if get_rank(group) == src_rank:
        # Emit data
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        buffer = torch.ByteTensor(buffer.getbuffer()).to(dist_device)
        length = torch.LongTensor([len(buffer)]).to(dist_device)
        broadcast(length, src=src_rank, group=group)
        broadcast(buffer, src=src_rank, group=group)
    else:
        # Fetch from the source
        length = torch.LongTensor([0]).to(dist_device)
        broadcast(length, src=src_rank, group=group)
        buffer = torch.ByteTensor(int(length.item())).to(dist_device)
        broadcast(buffer, src=src_rank, group=group)
        buffer = io.BytesIO(buffer.cpu().numpy())
        obj = torch.load(buffer, map_location="cpu")
    return obj


@dataclass(frozen=True)
class _TensorPlaceholder:
    index: int


def _split_tensors_from_obj(obj: Any, tensors: List[torch.Tensor]) -> Any:
    if torch.is_tensor(obj):
        placeholder = _TensorPlaceholder(index=len(tensors))
        tensors.append(obj)
        return placeholder
    elif isinstance(obj, dict):
        return {k: _split_tensors_from_obj(v, tensors) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [_split_tensors_from_obj(v, tensors) for v in obj]
    elif isinstance(obj, tuple):
        return tuple(_split_tensors_from_obj(v, tensors) for v in obj)
    elif isinstance(obj, set):
        return {_split_tensors_from_obj(v, tensors) for v in obj}
    else:
        return obj


def _put_tensors_in_obj(obj: Any, tensors: List[torch.Tensor]) -> Any:
    if isinstance(obj, _TensorPlaceholder):
        return tensors[obj.index]
    elif isinstance(obj, dict):
        return {k: _put_tensors_in_obj(v, tensors) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [_put_tensors_in_obj(v, tensors) for v in obj]
    elif isinstance(obj, tuple):
        return tuple(_put_tensors_in_obj(v, tensors) for v in obj)
    elif isinstance(obj, set):
        return {_put_tensors_in_obj(v, tensors) for v in obj}
    else:
        return obj
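# Sketch of how the two helpers above round-trip an object: tensors are
# replaced by placeholders before pickling and restored afterwards, which is
# what lets broadcast_object send tensor payloads out-of-band.
def _example_split_and_put():
    obj = {"step": 7, "params": [torch.ones(2), torch.zeros(3)]}
    tensors: List[torch.Tensor] = []
    stripped = _split_tensors_from_obj(obj, tensors)  # tensors -> placeholders
    restored = _put_tensors_in_obj(stripped, tensors)  # placeholders -> tensors
    return restored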