# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import signal
import threading

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

from fairseq.distributed import (
    DistributedTimeoutWrapper,
    LegacyDistributedDataParallel,
    ModuleProxyWrapper,
    TPUDistributedDataParallel,
)

logger = logging.getLogger(__name__)

_SLOWMO_DDP_DISABLED = False
try:
    from fairscale.experimental.nn.data_parallel import (
        SlowMoBaseAlgorithm,
        SlowMoDistributedDataParallel,
    )
except ImportError:
    _SLOWMO_DDP_DISABLED = True


def DistributedFairseqModel(args, model, process_group, device):
    """
    Wrap a *model* to support distributed data parallel training.

    This is similar to the built-in DistributedDataParallel, but allows
    additional configuration of the DistributedDataParallel class to
    use, and also provides easier access to the wrapped model by
    forwarding requests for missing attributes to the wrapped model.

    A minimal usage sketch appears at the bottom of this module.

    Args:
        args (argparse.Namespace): fairseq args
        model (BaseFairseqModel): model to wrap
        process_group: the c10d process group to be used for distributed data
            parallel all-reduction.
        device: device to move model to
    """
    assert isinstance(model, nn.Module)

    if args.tpu:
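        # TPU path: gradient all-reduce is handled by fairseq's
        # TPUDistributedDataParallel (built on torch_xla collectives).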
        wrapped_model = TPUDistributedDataParallel(
            module=model.to(device),
            process_group=process_group,
        )
        # forward missing getattr and state_dict/load_state_dict to orig model
        wrapped_model = ModuleProxyWrapper(wrapped_model)
    elif args.ddp_backend in {"c10d", "pytorch_ddp"}:
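        # PyTorch's native DistributedDataParallel (the c10d backend).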
        wrapped_model = DistributedDataParallel(
            module=model.to(device),
            device_ids=[args.device_id],
            output_device=args.device_id,
            broadcast_buffers=args.broadcast_buffers,
            bucket_cap_mb=args.bucket_cap_mb,
            process_group=process_group,
            find_unused_parameters=args.find_unused_parameters,
            gradient_as_bucket_view=args.gradient_as_bucket_view,
        )
        if args.ddp_comm_hook == "fp16":
            logger.info("enable fp16 communication hook in DDP")
            try:
                from torch.distributed.algorithms.ddp_comm_hooks import (
                    DDPCommHookType,
                    register_ddp_comm_hook,
                )
            except ImportError:
                logger.error(
                    "Could not import from torch.distributed.algorithms.ddp_comm_hooks; you may need to update your pytorch version"
                )
                raise

            register_ddp_comm_hook(DDPCommHookType.FP16_COMPRESS, wrapped_model)

        # forward missing getattr and state_dict/load_state_dict to orig model
        wrapped_model = ModuleProxyWrapper(wrapped_model)
    elif args.ddp_backend in {"no_c10d", "legacy_ddp"}:
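        # fairseq's own (pre-c10d) gradient all-reduce implementation: it
        # accumulates gradients into a flat buffer and all-reduces them manually.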
        wrapped_model = LegacyDistributedDataParallel(
            module=model.to(device),
            buffer_size=2**28,
            process_group=process_group,
        )
        # forward missing getattr and state_dict/load_state_dict to orig model
        wrapped_model = ModuleProxyWrapper(wrapped_model)
    elif args.ddp_backend == "slowmo":
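        # fairscale's SlowMo: a local-SGD/SGP base algorithm plus a slow
        # momentum step, trading exact per-step synchronization for less
        # communication between workers.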
        if _SLOWMO_DDP_DISABLED:
            raise ImportError(
                "Cannot find SlowMoDistributedDataParallel. "
                "Please install fairscale with: pip install fairscale"
            )

        # The values of slowmo_momentum below were obtained by tuning on the
        # En-De 16 dataset by training the transformer_wmt_en_de_large model
        if args.slowmo_momentum is None:
            if args.distributed_world_size <= 16:
                args.slowmo_momentum = 0.0
            elif args.distributed_world_size <= 32:
                args.slowmo_momentum = 0.2
            elif args.distributed_world_size <= 64:
                args.slowmo_momentum = 0.5
            else:
                args.slowmo_momentum = 0.6
        slowmo_base_algorithm = SlowMoBaseAlgorithm[args.slowmo_base_algorithm.upper()]

        wrapped_model = SlowMoDistributedDataParallel(
            module=model.to(device),
            broadcast_buffers=args.broadcast_buffers,
            nprocs_per_node=args.nprocs_per_node,
            slowmo_momentum=args.slowmo_momentum,
            slowmo_base_algorithm=slowmo_base_algorithm,
            localsgd_frequency=args.localsgd_frequency,
        )
        # forward missing getattr and state_dict/load_state_dict to orig model
        wrapped_model = ModuleProxyWrapper(wrapped_model)
    elif args.ddp_backend == "fully_sharded":
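        # fairscale's FullyShardedDataParallel shards parameters, gradients and
        # optimizer state across workers; the model is expected to arrive
        # already wrapped in FSDP (see the assert below), so only dtype and
        # device placement are adjusted here.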
        try:
            from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
        except ImportError:
            raise ImportError(
                "Cannot find FullyShardedDataParallel. "
                "Please install fairscale with: pip install fairscale"
            )

        assert isinstance(model, FSDP), "expected model to already be wrapped in FSDP"
        wrapped_model = model
        if args.memory_efficient_fp16:
            wrapped_model = wrapped_model.half()
        if not args.cpu_offload:
            wrapped_model = wrapped_model.to(device=device)
    else:
        raise ValueError("Unknown --ddp-backend: " + args.ddp_backend)

    # kill hung distributed jobs after a timeout
    if getattr(args, "heartbeat_timeout", -1) > 0:
        wrapped_model = DistributedTimeoutWrapper(
            wrapped_model, timeout=getattr(args, "heartbeat_timeout", -1)
        )

    return wrapped_model
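

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): wraps a toy model with the
# "legacy_ddp" backend inside a single-process gloo group on CPU. The
# Namespace below carries only the fields this particular code path reads;
# a real fairseq run passes the full training config, real ranks and
# (usually) CUDA devices.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    import torch.distributed as dist

    # single-process "world", purely for demonstration
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://localhost:29500",
        rank=0,
        world_size=1,
    )
    toy_args = argparse.Namespace(tpu=False, ddp_backend="legacy_ddp")
    toy_model = nn.Linear(8, 4)
    wrapped = DistributedFairseqModel(
        toy_args,
        toy_model,
        process_group=dist.group.WORLD,
        device=torch.device("cpu"),
    )
    # ModuleProxyWrapper forwards missing attributes to the original module,
    # and forward() passes through to the wrapped nn.Linear
    print(type(wrapped).__name__, wrapped(torch.randn(2, 8)).shape)
    dist.destroy_process_group()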