""" |
|
Train a network across multiple GPUs. |
|
""" |
|
|
|
import contextlib |
|
import logging |
|
import sys |
|
import time |
|
from argparse import Namespace |
|
from itertools import chain |
|
from typing import Any, Dict, List |
|
|
|
import torch |
|
from fairseq import checkpoint_utils, models, optim, utils |
|
from fairseq.dataclass.configs import FairseqConfig |
|
from fairseq.dataclass.utils import convert_namespace_to_omegaconf |
|
from fairseq.distributed import utils as distributed_utils |
|
from fairseq.file_io import PathManager |
|
from fairseq.logging import meters, metrics |
|
from fairseq.nan_detector import NanDetector |
|
from fairseq.optim import lr_scheduler |
|
from omegaconf import OmegaConf |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class Trainer(object): |
|
"""Main class for data parallel training. |
|
|
|
This class supports synchronous distributed data parallel training, |
|
where multiple workers each have a full model replica and gradients |
|
are accumulated across workers before each update. We use |
|
:class:`~torch.nn.parallel.DistributedDataParallel` to handle |
|
communication of the gradients across workers. |
|
""" |
|
|
|
def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None): |
|
|
|
if isinstance(cfg, Namespace): |
|
logger.warning( |
|
"argparse.Namespace configuration is deprecated! Automatically converting to OmegaConf" |
|
) |
|
cfg = convert_namespace_to_omegaconf(cfg) |
|
|
|
self.cfg = cfg |
|
self.task = task |
|
|
|
|
|
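        # catalog parameters that are shared between modules, so they can be
        # re-tied after any dtype/device conversion below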
shared_params = _catalog_shared_params(model) |
|
self.tpu = cfg.common.tpu |
|
self.cuda = torch.cuda.is_available() and not cfg.common.cpu and not self.tpu |
|
if self.cuda: |
|
self.device = torch.device("cuda") |
|
elif self.tpu: |
|
self.device = utils.get_tpu_device() |
|
else: |
|
self.device = torch.device("cpu") |
|
|
|
if self.cfg.distributed_training.ddp_backend == "fully_sharded": |
|
if self.cfg.common.bf16: |
|
raise ValueError( |
|
"FullyShardedDataParallel is not compatible with --bf16 or " |
|
"--memory-efficient-bf16" |
|
) |
|
if self.cfg.distributed_training.zero_sharding != "none": |
|
raise ValueError( |
|
"FullyShardedDataParallel is not compatible with --zero-sharding " |
|
"option (it's already built in)" |
|
) |
|
else: |
|
if ( |
|
hasattr(self.cfg.distributed_training, "cpu_offload") |
|
and self.cfg.distributed_training.cpu_offload |
|
): |
|
raise ValueError("--cpu-offload requires --ddp-backend=fully_sharded") |
|
|
|
|
|
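        # convert the model and criterion to the requested dtype and move them
        # to the training device (FSDP and the distributed wrappers handle
        # parts of this themselves, see the checks below)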
self._criterion = criterion |
|
self._model = model |
|
if cfg.distributed_training.ddp_backend != "fully_sharded": |
|
if cfg.common.fp16: |
|
assert not cfg.common.amp, "Cannot use fp16 and AMP together" |
|
self._criterion = self._criterion.half() |
|
self._model = self._model.half() |
|
elif cfg.common.bf16: |
|
self._criterion = self._criterion.to(dtype=torch.bfloat16) |
|
self._model = self._model.to(dtype=torch.bfloat16) |
|
elif cfg.common.amp: |
|
self._amp_retries = 0 |
|
if ( |
|
not cfg.distributed_training.pipeline_model_parallel |
|
|
|
|
|
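            # the DistributedFairseqModel wrapper will handle moving the model
            # to the correct device, so only handle cases without the wrapper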
and not self.use_distributed_wrapper |
|
): |
|
self._criterion = self._criterion.to(device=self.device) |
|
self._model = self._model.to(device=self.device) |
|
self.pipeline_model_parallel = cfg.distributed_training.pipeline_model_parallel |
|
self.last_device = None |
|
if self.cuda and self.pipeline_model_parallel: |
|
self.last_device = torch.device( |
|
cfg.distributed_training.pipeline_devices[-1] |
|
) |
|
|
|
|
|
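        # re-tie shared parameters, since the dtype/device conversion above can
        # break parameter sharing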
for shared_param in shared_params: |
|
ref = _get_module_by_path(self._model, shared_param[0]) |
|
for path in shared_param[1:]: |
|
logger.info( |
|
"detected shared parameter: {} <- {}".format(shared_param[0], path) |
|
) |
|
_set_module_by_path(self._model, path, ref) |
|
|
|
self._dummy_batch = None |
|
self._lr_scheduler = None |
|
self._num_updates = 0 |
|
self._num_xla_compiles = 0 |
|
self._optim_history = None |
|
self._optimizer = None |
|
self._warn_once = set() |
|
self._wrapped_criterion = None |
|
self._wrapped_model = None |
|
|
|
|
|
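        # pre-allocate the buffer used by _check_grad_norms to verify that
        # gradient norms are consistent across workers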
if self.cuda and self.data_parallel_world_size > 1: |
|
self._grad_norm_buf = torch.cuda.DoubleTensor(self.data_parallel_world_size) |
|
else: |
|
self._grad_norm_buf = None |
|
|
|
self.quantizer = quantizer |
|
if self.quantizer is not None: |
|
self.quantizer.set_trainer(self) |
|
|
|
|
|
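        # collect the CUDA environment of all workers and pretty-print it on
        # the data parallel master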
if self.cuda: |
|
self.cuda_env = utils.CudaEnvironment() |
|
if self.data_parallel_world_size > 1: |
|
self.cuda_env_arr = distributed_utils.all_gather_list( |
|
self.cuda_env, group=distributed_utils.get_global_group() |
|
) |
|
else: |
|
self.cuda_env_arr = [self.cuda_env] |
|
if self.data_parallel_rank == 0: |
|
utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr) |
|
else: |
|
self.cuda_env = None |
|
self.cuda_env_arr = None |
|
|
|
metrics.log_start_time("wall", priority=790, round=0) |
|
|
|
self._start_time = time.time() |
|
self._previous_training_time = 0 |
|
self._cumulative_training_time = None |
|
|
|
def reinitialize(self): |
|
"""Reinitialize the Trainer, typically after model params change.""" |
|
self._lr_scheduler = None |
|
self._optimizer = None |
|
self._wrapped_criterion = None |
|
self._wrapped_model = None |
|
|
|
@property |
|
def data_parallel_world_size(self): |
|
if self.cfg.distributed_training.distributed_world_size == 1: |
|
return 1 |
|
return distributed_utils.get_data_parallel_world_size() |
|
|
|
@property |
|
def data_parallel_process_group(self): |
|
return distributed_utils.get_data_parallel_group() |
|
|
|
@property |
|
def data_parallel_rank(self): |
|
if self.cfg.distributed_training.distributed_world_size == 1: |
|
return 0 |
|
return distributed_utils.get_data_parallel_rank() |
|
|
|
@property |
|
def is_data_parallel_master(self): |
|
|
|
|
|
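        # NOTE: this returns True for every model-parallel replica that has
        # data parallel rank 0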
return self.data_parallel_rank == 0 |
|
|
|
@property |
|
def use_distributed_wrapper(self) -> bool: |
|
return ( |
|
self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf |
|
) or ( |
|
self.cfg.distributed_training.ddp_backend == "fully_sharded" |
|
and self.cfg.distributed_training.cpu_offload |
|
) |
|
|
|
@property |
|
def should_save_checkpoint_on_current_rank(self) -> bool: |
|
"""Indicates whether to save checkpoints on the current DDP rank.""" |
|
if ( |
|
self.cfg.distributed_training.ddp_backend == "fully_sharded" |
|
and self.cfg.distributed_training.use_sharded_state |
|
) or getattr(self.cfg.model, "base_layers", 0) > 0: |
|
return True |
|
else: |
|
return self.is_data_parallel_master |
|
|
|
@property |
|
def always_call_state_dict_during_save_checkpoint(self) -> bool: |
|
if ( |
|
self.cfg.distributed_training.ddp_backend == "fully_sharded" |
|
and not self.cfg.distributed_training.use_sharded_state |
|
): |
|
|
|
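            # FSDP without sharded state calls a communication collective when
            # consolidating the checkpoint, so state_dict() must run on every rank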
return True |
|
else: |
|
return False |
|
|
|
@property |
|
def checkpoint_suffix(self) -> str: |
|
"""Suffix to add to the checkpoint file name.""" |
|
if ( |
|
self.cfg.distributed_training.ddp_backend == "fully_sharded" |
|
and self.cfg.distributed_training.use_sharded_state |
|
): |
|
return self.cfg.checkpoint.checkpoint_suffix + "-shard{0}".format( |
|
self.data_parallel_rank |
|
) |
|
else: |
|
return self.cfg.checkpoint.checkpoint_suffix or "" |
|
|
|
@property |
|
def criterion(self): |
|
if self._wrapped_criterion is None: |
|
if utils.has_parameters(self._criterion) and self.use_distributed_wrapper: |
|
self._wrapped_criterion = models.DistributedFairseqModel( |
|
self.cfg.distributed_training, |
|
self._criterion, |
|
process_group=self.data_parallel_process_group, |
|
device=self.device, |
|
) |
|
else: |
|
self._wrapped_criterion = self._criterion |
|
return self._wrapped_criterion |
|
|
|
@property |
|
def model(self): |
|
if self._wrapped_model is None: |
|
if self.use_distributed_wrapper: |
|
self._wrapped_model = models.DistributedFairseqModel( |
|
self.cfg.distributed_training, |
|
self._model, |
|
process_group=self.data_parallel_process_group, |
|
device=self.device, |
|
) |
|
else: |
|
self._wrapped_model = self._model |
|
return self._wrapped_model |
|
|
|
@property |
|
def optimizer(self): |
|
if self._optimizer is None: |
|
self._build_optimizer() |
|
return self._optimizer |
|
|
|
@property |
|
def lr_scheduler(self): |
|
if self._lr_scheduler is None: |
|
self._build_optimizer() |
|
return self._lr_scheduler |
|
|
|
def _build_optimizer(self): |
|
params = list( |
|
filter( |
|
lambda p: p.requires_grad, |
|
chain(self.model.parameters(), self.criterion.parameters()), |
|
) |
|
) |
|
|
|
if ( |
|
self.cfg.distributed_training.ddp_backend == "fully_sharded" |
|
and self.cfg.common.fp16 |
|
): |
|
|
|
|
|
|
|
|
|
|
|
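            # FSDP always uses the MemoryEfficientFP16Optimizer wrapper (mainly
            # for grad scaling); if --memory-efficient-fp16 was not requested we
            # are effectively doing regular --fp16, so allow optimizers that the
            # memory-efficient wrapper would otherwise reject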
allow_unsupported = not self.cfg.common.memory_efficient_fp16 |
|
self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer( |
|
self.cfg, params, allow_unsupported=allow_unsupported |
|
) |
|
elif self.cfg.common.fp16 or self.cfg.common.bf16 or self.cfg.common.amp: |
|
if self.cuda and torch.cuda.get_device_capability(0)[0] < 7: |
|
logger.info( |
|
"NOTE: your device does NOT support faster training with --fp16 or --amp, " |
|
"please switch to FP32 which is likely to be faster" |
|
) |
|
if ( |
|
self.cfg.common.memory_efficient_fp16 |
|
or self.cfg.common.memory_efficient_bf16 |
|
): |
|
self._optimizer = optim.MemoryEfficientFP16Optimizer.build_optimizer( |
|
self.cfg, params |
|
) |
|
elif self.cfg.common.amp: |
|
self._optimizer = optim.AMPOptimizer.build_optimizer(self.cfg, params) |
|
else: |
|
self._optimizer = optim.FP16Optimizer.build_optimizer(self.cfg, params) |
|
else: |
|
if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7: |
|
logger.info("NOTE: your device may support faster training with --fp16 or --amp") |
|
self._optimizer = optim.build_optimizer(self.cfg.optimizer, params) |
|
|
|
if self.cfg.distributed_training.ddp_backend == "fully_sharded": |
|
assert ( |
|
not self.cfg.optimization.use_bmuf |
|
), "--ddp-backend=fully_sharded is not compatible with BMUF" |
|
assert self._optimizer.supports_flat_params, ( |
|
"--ddp-backend=fully_sharded is only compatible with pointwise " |
|
"optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.). " |
|
"However, the sharding will result in slightly different results when " |
|
"using non-pointwise optimizers (e.g., Adagrad, Adafactor, LAMB)" |
|
) |
|
|
|
if self.cfg.optimization.use_bmuf: |
|
self._optimizer = optim.FairseqBMUF( |
|
self.cfg.bmuf, |
|
self._optimizer, |
|
) |
|
|
|
if self.cfg.distributed_training.zero_sharding == "os": |
|
if ( |
|
self.cfg.common.fp16 |
|
and not self.cfg.common.memory_efficient_fp16 |
|
and not self.cfg.common.memory_efficient_bf16 |
|
) and not self.cfg.common.fp16_no_flatten_grads: |
|
raise ValueError( |
|
"ZeRO is incomptabile with fp16 and flattened grads. " |
|
"Please use --fp16-no-flatten-grads" |
|
) |
|
else: |
|
optim.shard_(self._optimizer, self.data_parallel_process_group) |
|
|
|
|
|
|
|
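        # initialize the LR scheduler right after building the optimizer, so
        # that the initial learning rate gets set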
self._lr_scheduler = lr_scheduler.build_lr_scheduler( |
|
self.cfg.lr_scheduler, |
|
self.optimizer, |
|
) |
|
self._lr_scheduler.step_update(0) |
|
|
|
def consolidate_optimizer(self): |
|
"""For OSS, we need to consolidate the state dict.""" |
|
if self.cfg.checkpoint.no_save_optimizer_state: |
|
return |
|
self._gathered_optim_state = None |
|
if hasattr(self.optimizer.optimizer, "consolidate_state_dict"): |
|
self.optimizer.optimizer.consolidate_state_dict() |
|
|
|
elif ( |
|
self.cfg.distributed_training.ddp_backend == "fully_sharded" |
|
and not self.model.use_sharded_state |
|
): |
|
st = self.model.gather_full_optim_state_dict( |
|
self.optimizer |
|
) |
|
self._gathered_optim_state = st |
|
|
|
def state_dict(self): |
|
state_dict = { |
|
"args": None, |
|
"cfg": ( |
|
OmegaConf.to_container(self.cfg, resolve=True, enum_to_str=True) |
|
if OmegaConf.is_config(self.cfg) |
|
else self.cfg |
|
), |
|
"model": self.model.state_dict(), |
|
"criterion": ( |
|
self.criterion.state_dict() |
|
if utils.has_parameters(self.criterion) |
|
else None |
|
), |
|
"optimizer_history": (self._optim_history or []) |
|
+ [ |
|
{ |
|
"criterion_name": self.get_criterion().__class__.__name__, |
|
"optimizer_name": self.optimizer.__class__.__name__, |
|
"lr_scheduler_state": self.lr_scheduler.state_dict(), |
|
"num_updates": self.get_num_updates(), |
|
} |
|
], |
|
"task_state": self.task.state_dict() if self.task is not None else {}, |
|
"extra_state": { |
|
"metrics": metrics.state_dict(), |
|
"previous_training_time": self.cumulative_training_time(), |
|
}, |
|
} |
|
if not self.cfg.checkpoint.no_save_optimizer_state: |
|
if self._gathered_optim_state is not None: |
|
state_dict["last_optimizer_state"] = self._gathered_optim_state |
|
self._gathered_optim_state = None |
|
else: |
|
state_dict["last_optimizer_state"] = self.optimizer.state_dict() |
|
if self.cfg.distributed_training.ddp_backend == "fully_sharded": |
|
|
|
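                # save metadata needed to recombine the sharded FSDP optimizer
                # state when loading the checkpoint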
state_dict["fsdp_metadata"] = self.model.local_metadata_dict() |
|
return state_dict |
|
|
|
def save_checkpoint(self, filename, extra_state): |
|
"""Save all training state in a checkpoint file.""" |
|
logger.info(f"Saving checkpoint to {filename}") |
|
|
|
state_dict = utils.move_to_cpu(self.state_dict()) |
|
state_dict["extra_state"].update(extra_state) |
|
if self.should_save_checkpoint_on_current_rank: |
|
checkpoint_utils.torch_persistent_save( |
|
state_dict, |
|
filename, |
|
async_write=self.cfg.checkpoint.write_checkpoints_asynchronously, |
|
) |
|
logger.info(f"Finished saving checkpoint to {filename}") |
|
|
|
def load_checkpoint( |
|
self, |
|
filename, |
|
reset_optimizer=False, |
|
reset_lr_scheduler=False, |
|
optimizer_overrides=None, |
|
reset_meters=False, |
|
): |
|
""" |
|
Load all training state from a checkpoint file. |
|
rank = 0 will load the checkpoint, and then broadcast it to all |
|
other ranks. |
|
""" |
|
extra_state, self._optim_history, last_optim_state = None, [], None |
|
|
|
logger.info(f"Preparing to load checkpoint {filename}") |
|
is_distributed = self.data_parallel_world_size > 1 |
|
bexists = PathManager.isfile(filename) |
|
if bexists: |
|
load_on_all_ranks = ( |
|
self.cfg.checkpoint.load_checkpoint_on_all_dp_ranks |
|
|
|
|
|
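                # TPUs don't support broadcast yet, so load checkpoints
                # on every worker for now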
or self.tpu |
|
|
|
or ( |
|
self.cfg.distributed_training.ddp_backend == "fully_sharded" |
|
and self.cfg.distributed_training.use_sharded_state |
|
) |
|
or getattr(self.cfg.model, "base_layers", 0) > 0 |
|
) |
|
|
|
if load_on_all_ranks or self.data_parallel_rank == 0: |
|
state = checkpoint_utils.load_checkpoint_to_cpu( |
|
filename, load_on_all_ranks=load_on_all_ranks |
|
) |
|
last_optim_state = state.get("last_optimizer_state", None) |
|
|
|
|
|
|
|
|
|
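                # when using --zero-sharding=os, don't broadcast the global
                # optimizer state here; sharded states are broadcast to each
                # rank later to avoid exhausting memory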
if ( |
|
not load_on_all_ranks |
|
and self.cfg.distributed_training.zero_sharding == "os" |
|
and "last_optimizer_state" in state |
|
and is_distributed |
|
): |
|
state["last_optimizer_state"] = "SHARDED" |
|
else: |
|
last_optim_state = None |
|
state = None |
|
|
|
if is_distributed and not load_on_all_ranks: |
|
state = distributed_utils.broadcast_object( |
|
state, |
|
src_rank=0, |
|
group=self.data_parallel_process_group, |
|
dist_device=self.device, |
|
) |
|
if self.data_parallel_rank > 0: |
|
last_optim_state = state.get("last_optimizer_state", None) |
|
|
|
|
|
try: |
|
self.model.load_state_dict( |
|
state["model"], strict=True, model_cfg=self.cfg.model |
|
) |
|
|
|
del state["model"] |
|
if utils.has_parameters(self.get_criterion()): |
|
self.get_criterion().load_state_dict( |
|
state["criterion"], strict=True |
|
) |
|
del state["criterion"] |
|
|
|
except Exception: |
|
raise Exception( |
|
"Cannot load model parameters from checkpoint {}; " |
|
"please ensure that the architectures match.".format(filename) |
|
) |
|
extra_state = state["extra_state"] |
|
self._optim_history = state["optimizer_history"] |
|
|
|
if last_optim_state is not None and not reset_optimizer: |
|
|
|
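            # rebuild the optimizer after loading the model, since the
            # parameters may have changed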
self._build_optimizer() |
|
|
|
|
|
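            # only reload the optimizer and lr_scheduler state if they match
            # what was saved in the checkpoint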
last_optim = self._optim_history[-1] |
|
assert ( |
|
last_optim["criterion_name"] == self.get_criterion().__class__.__name__ |
|
), f"Criterion does not match; please reset the optimizer (--reset-optimizer). {last_optim['criterion_name']} vs {self.get_criterion().__class__.__name__}" |
|
assert ( |
|
last_optim["optimizer_name"] == self.optimizer.__class__.__name__ |
|
), f"Optimizer does not match; please reset the optimizer (--reset-optimizer). {last_optim['optimizer_name']} vs {self.optimizer.__class__.__name__}" |
|
|
|
if not reset_lr_scheduler: |
|
self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"]) |
|
|
|
if ( |
|
self.cfg.distributed_training.ddp_backend == "fully_sharded" |
|
and not self.model.use_sharded_state |
|
): |
|
|
|
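                # extract this rank's shard from the full (consolidated)
                # optimizer state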
last_optim_state = self.model.get_shard_from_optim_state_dict( |
|
last_optim_state |
|
) |
|
elif not load_on_all_ranks and is_distributed: |
|
last_optim_state = self.optimizer.broadcast_global_state_dict( |
|
last_optim_state |
|
) |
|
|
|
self.optimizer.load_state_dict(last_optim_state, optimizer_overrides) |
|
|
|
self.set_num_updates(last_optim["num_updates"]) |
|
|
|
if extra_state is not None: |
|
itr_state = extra_state["train_iterator"] |
|
epoch = itr_state["epoch"] |
|
|
|
if "previous_training_time" in extra_state: |
|
self._previous_training_time = extra_state["previous_training_time"] |
|
self._start_time = time.time() |
|
|
|
self.lr_step(epoch) |
|
|
|
if ( |
|
itr_state.get("version", 1) >= 2 |
|
and itr_state["iterations_in_epoch"] == 0 |
|
): |
|
|
|
reset_meters = True |
|
|
|
if "metrics" in extra_state and not reset_meters: |
|
metrics.load_state_dict(extra_state["metrics"]) |
|
|
|
|
|
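                # reset TimeMeters, since their start times no longer make
                # sense after loading a checkpoint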
for meter in metrics.get_meters("default"): |
|
if isinstance(meter, meters.TimeMeter): |
|
meter.reset() |
|
|
|
logger.info( |
|
"Loaded checkpoint {} (epoch {} @ {} updates)".format( |
|
filename, epoch, self.get_num_updates() |
|
) |
|
) |
|
|
|
else: |
|
logger.info("No existing checkpoint found {}".format(filename)) |
|
|
|
return extra_state |
|
|
|
def get_train_iterator( |
|
self, |
|
epoch, |
|
combine=True, |
|
load_dataset=True, |
|
data_selector=None, |
|
shard_batch_itr=True, |
|
disable_iterator_cache=False, |
|
): |
|
"""Return an EpochBatchIterator over the training set for a given epoch.""" |
|
if load_dataset: |
|
logger.info("loading train data for epoch {}".format(epoch)) |
|
self.task.load_dataset( |
|
self.cfg.dataset.train_subset, |
|
epoch=epoch, |
|
combine=combine, |
|
data_selector=data_selector, |
|
tpu=self.tpu, |
|
) |
|
batch_iterator = self.task.get_batch_iterator( |
|
dataset=self.task.dataset(self.cfg.dataset.train_subset), |
|
max_tokens=self.cfg.dataset.max_tokens, |
|
max_sentences=self.cfg.dataset.batch_size, |
|
max_positions=utils.resolve_max_positions( |
|
self.task.max_positions(), |
|
self.model.max_positions(), |
|
self.cfg.dataset.max_tokens, |
|
), |
|
ignore_invalid_inputs=True, |
|
required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, |
|
seed=self.cfg.common.seed, |
|
num_shards=self.data_parallel_world_size if shard_batch_itr else 1, |
|
shard_id=self.data_parallel_rank if shard_batch_itr else 0, |
|
num_workers=self.cfg.dataset.num_workers, |
|
epoch=epoch, |
|
data_buffer_size=self.cfg.dataset.data_buffer_size, |
|
disable_iterator_cache=disable_iterator_cache, |
|
) |
|
self.reset_dummy_batch(batch_iterator.first_batch) |
|
return batch_iterator |
|
|
|
def get_valid_iterator( |
|
self, |
|
subset, |
|
disable_iterator_cache=False, |
|
): |
|
"""Return an EpochBatchIterator over given validation subset for a given epoch.""" |
|
batch_iterator = self.task.get_batch_iterator( |
|
dataset=self.task.dataset(subset), |
|
max_tokens=self.cfg.dataset.max_tokens_valid, |
|
max_sentences=self.cfg.dataset.batch_size_valid, |
|
max_positions=utils.resolve_max_positions( |
|
self.task.max_positions(), |
|
self.model.max_positions(), |
|
), |
|
ignore_invalid_inputs=self.cfg.dataset.skip_invalid_size_inputs_valid_test, |
|
required_batch_size_multiple=self.cfg.dataset.required_batch_size_multiple, |
|
seed=self.cfg.common.seed, |
|
num_shards=self.data_parallel_world_size, |
|
shard_id=self.data_parallel_rank, |
|
num_workers=self.cfg.dataset.num_workers, |
|
|
|
|
|
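            # always pass a fixed "epoch" so validation batches stay identical
            # across training epochs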
epoch=1, |
|
data_buffer_size=self.cfg.dataset.data_buffer_size, |
|
disable_iterator_cache=disable_iterator_cache, |
|
) |
|
self.reset_dummy_batch(batch_iterator.first_batch) |
|
return batch_iterator |
|
|
|
def begin_epoch(self, epoch): |
|
"""Called at the beginning of each epoch.""" |
|
logger.info("begin training epoch {}".format(epoch)) |
|
|
|
self.lr_step_begin_epoch(epoch) |
|
|
|
if self.quantizer is not None: |
|
self.quantizer.begin_epoch(epoch) |
|
|
|
|
|
self.task.begin_epoch(epoch, self.get_model()) |
|
|
|
if self.tpu: |
|
import torch_xla.core.xla_model as xm |
|
|
|
xm.rendezvous("begin_epoch") |
|
xm.mark_step() |
|
|
|
def begin_valid_epoch(self, epoch): |
|
"""Called at the beginning of each validation epoch.""" |
|
|
|
|
|
self.task.begin_valid_epoch(epoch, self.get_model()) |
|
|
|
def reset_dummy_batch(self, batch): |
|
self._dummy_batch = batch |
|
|
|
@metrics.aggregate("train") |
|
def train_step(self, samples, raise_oom=False): |
|
"""Do forward, backward and parameter update.""" |
|
self._set_seed() |
|
self.model.train() |
|
self.criterion.train() |
|
self.zero_grad() |
|
|
|
metrics.log_start_time("train_wall", priority=800, round=0) |
|
|
|
|
|
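        # forward and backward passes over all mini-batches in *samples*,
        # accumulating gradients before the parameter update below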
logging_outputs, sample_size, ooms = [], 0, 0 |
|
for i, sample in enumerate(samples): |
|
sample, is_dummy_batch = self._prepare_sample(sample) |
|
|
|
def maybe_no_sync(): |
|
""" |
|
Whenever *samples* contains more than one mini-batch, we |
|
want to accumulate gradients locally and only call |
|
all-reduce in the last backwards pass. |
|
""" |
|
if ( |
|
self.data_parallel_world_size > 1 |
|
and hasattr(self.model, "no_sync") |
|
and i < len(samples) - 1 |
|
): |
|
return self.model.no_sync() |
|
else: |
|
return contextlib.ExitStack() |
|
|
|
try: |
|
with maybe_no_sync(): |
|
|
|
loss, sample_size_i, logging_output = self.task.train_step( |
|
sample=sample, |
|
model=self.model, |
|
criterion=self.criterion, |
|
optimizer=self.optimizer, |
|
update_num=self.get_num_updates(), |
|
ignore_grad=is_dummy_batch, |
|
) |
|
del loss |
|
|
|
logging_outputs.append(logging_output) |
|
sample_size += sample_size_i |
|
|
|
|
|
|
|
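                    # emptying the CUDA cache after the first step can
                    # reduce the chance of OOM later in training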
if self.cuda and self.get_num_updates() == 0: |
|
torch.cuda.empty_cache() |
|
except RuntimeError as e: |
|
if "out of memory" in str(e): |
|
self._log_oom(e) |
|
if raise_oom: |
|
raise e |
|
logger.warning( |
|
"attempting to recover from OOM in forward/backward pass" |
|
) |
|
ooms += 1 |
|
self.zero_grad() |
|
if self.cuda: |
|
torch.cuda.empty_cache() |
|
if self.cfg.distributed_training.distributed_world_size == 1: |
|
return None |
|
else: |
|
raise e |
|
|
|
if self.tpu and i < len(samples) - 1: |
|
|
|
|
|
|
|
|
|
|
|
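                # on TPU, explicitly mark a step after every forward pass when
                # accumulating gradients; otherwise the growing XLA graph can
                # lead to OOM errors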
self._xla_markstep_and_send_to_cpu() |
|
|
|
if is_dummy_batch: |
|
if torch.is_tensor(sample_size): |
|
sample_size.zero_() |
|
else: |
|
sample_size *= 0.0 |
|
|
|
if torch.is_tensor(sample_size): |
|
sample_size = sample_size.float() |
|
else: |
|
sample_size = float(sample_size) |
|
|
|
|
|
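        # gather logging outputs from all replicas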
if self._sync_stats(): |
|
train_time = self._local_cumulative_training_time() |
|
logging_outputs, ( |
|
sample_size, |
|
ooms, |
|
total_train_time, |
|
) = self._aggregate_logging_outputs( |
|
logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch |
|
) |
|
self._cumulative_training_time = ( |
|
total_train_time / self.data_parallel_world_size |
|
) |
|
|
|
overflow = False |
|
try: |
|
with torch.autograd.profiler.record_function("reduce-grads"): |
|
|
|
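                # reduce gradients across workers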
self.optimizer.all_reduce_grads(self.model) |
|
if utils.has_parameters(self.criterion): |
|
self.optimizer.all_reduce_grads(self.criterion) |
|
|
|
with torch.autograd.profiler.record_function("multiply-grads"): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
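                # multiply gradients by (data_parallel_size / sample_size),
                # since DDP normalizes by the number of data parallel workers;
                # overall this yields (sum_of_gradients / sample_size). In fp16
                # mode this step also undoes the loss scaling.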
numer = ( |
|
self.data_parallel_world_size |
|
if not self.cfg.optimization.use_bmuf or self._sync_stats() |
|
else 1 |
|
) |
|
self.optimizer.multiply_grads(numer / (sample_size or 1.0)) |
|
|
|
|
|
|
|
|
|
with torch.autograd.profiler.record_function("clip-grads"): |
|
|
|
grad_norm = self.clip_grad_norm(self.cfg.optimization.clip_norm) |
|
|
|
|
|
|
|
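            # check grad norms for consistency across workers and for NaN/Inf
            # (skipped on TPU, where these checks are slow)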
if not self.tpu: |
|
if ( |
|
not self.cfg.optimization.use_bmuf |
|
and self.cfg.distributed_training.ddp_backend != "slow_mo" |
|
): |
|
self._check_grad_norms(grad_norm) |
|
if not torch.isfinite(grad_norm).all(): |
|
|
|
|
|
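                        # with AMP, NaN/Inf gradients just mean the grad scaler
                        # must back off; the optimizer step is still required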
if self.cfg.common.amp: |
|
overflow = True |
|
else: |
|
|
|
                        raise FloatingPointError("gradients are NaN/Inf")
|
|
|
with torch.autograd.profiler.record_function("optimizer"): |
|
|
|
self.task.optimizer_step( |
|
self.optimizer, model=self.model, update_num=self.get_num_updates() |
|
) |
|
if self.cfg.common.amp and overflow: |
|
if self._amp_retries == self.cfg.common.amp_batch_retries: |
|
logger.info("AMP: skipping this batch.") |
|
self._amp_retries = 0 |
|
else: |
|
self._amp_retries += 1 |
|
return self.train_step(samples, raise_oom) |
|
|
|
except FloatingPointError: |
|
|
|
|
|
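            # re-run the forward and backward pass with NanDetector hooks
            # attached to report where the non-finite values were produced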
self.zero_grad() |
|
with NanDetector(self.get_model()): |
|
for _, sample in enumerate(samples): |
|
sample, _ = self._prepare_sample(sample) |
|
self.task.train_step( |
|
sample, |
|
self.model, |
|
self.criterion, |
|
self.optimizer, |
|
self.get_num_updates(), |
|
ignore_grad=False, |
|
) |
|
raise |
|
except OverflowError as e: |
|
overflow = True |
|
logger.info( |
|
f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}" |
|
) |
|
grad_norm = torch.tensor(0.0).cuda() |
|
self.zero_grad() |
|
except RuntimeError as e: |
|
if "out of memory" in str(e): |
|
self._log_oom(e) |
|
logger.error("OOM during optimization, irrecoverable") |
|
raise e |
|
|
|
|
|
|
|
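        # some distributed wrappers (e.g., SlowMo) need access to the optimizer
        # after the step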
if hasattr(self.model, "perform_additional_optimizer_actions"): |
|
if hasattr(self.optimizer, "fp32_params"): |
|
self.model.perform_additional_optimizer_actions( |
|
self.optimizer.optimizer, self.optimizer.fp32_params |
|
) |
|
else: |
|
self.model.perform_additional_optimizer_actions( |
|
self.optimizer.optimizer |
|
) |
|
|
|
logging_output = None |
|
if not overflow or self.cfg.distributed_training.ddp_backend == "slow_mo": |
|
self.set_num_updates(self.get_num_updates() + 1) |
|
|
|
if self.tpu: |
|
import torch_xla.core.xla_model as xm |
|
|
|
|
|
self._xla_markstep_and_send_to_cpu() |
|
|
|
|
|
|
|
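                # only log stats every log_interval steps; this causes wps to
                # be misreported when log_interval > 1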
logging_output = {} |
|
if self.get_num_updates() % self.cfg.common.log_interval == 0: |
|
|
|
mem_info = xm.get_memory_info(self.device) |
|
gb_free = mem_info["kb_free"] / 1024 / 1024 |
|
gb_total = mem_info["kb_total"] / 1024 / 1024 |
|
metrics.log_scalar( |
|
"gb_free", gb_free, priority=1500, round=1, weight=0 |
|
) |
|
metrics.log_scalar( |
|
"gb_total", gb_total, priority=1600, round=1, weight=0 |
|
) |
|
logging_outputs = self._xla_markstep_and_send_to_cpu( |
|
logging_outputs |
|
) |
|
logging_output = self._reduce_and_log_stats( |
|
logging_outputs, sample_size, grad_norm |
|
) |
|
|
|
|
|
|
|
|
|
self._check_xla_compilation() |
|
else: |
|
if self.cuda and self.cuda_env is not None: |
|
|
|
gb_used = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024 |
|
torch.cuda.reset_peak_memory_stats() |
|
gb_free = self.cuda_env.total_memory_in_GB - gb_used |
|
metrics.log_scalar( |
|
"gb_free", gb_free, priority=1500, round=1, weight=0 |
|
) |
|
|
|
|
|
logging_output = self._reduce_and_log_stats( |
|
logging_outputs, sample_size, grad_norm |
|
) |
|
|
|
|
|
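            # periodically clear the CUDA cache to help reduce memory
            # fragmentation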
if ( |
|
self.cuda |
|
and self.cfg.common.empty_cache_freq > 0 |
|
and ( |
|
(self.get_num_updates() + self.cfg.common.empty_cache_freq - 1) |
|
% self.cfg.common.empty_cache_freq |
|
) |
|
== 0 |
|
): |
|
torch.cuda.empty_cache() |
|
|
|
if self.cfg.common.fp16 or self.cfg.common.amp: |
|
metrics.log_scalar( |
|
"loss_scale", |
|
( |
|
self.optimizer.scaler.loss_scale |
|
if self.cfg.common.fp16 |
|
else self.optimizer.scaler.get_scale() |
|
), |
|
priority=700, |
|
round=4, |
|
weight=0, |
|
) |
|
|
|
metrics.log_stop_time("train_wall") |
|
return logging_output |
|
|
|
@metrics.aggregate("valid") |
|
def valid_step(self, sample, raise_oom=False): |
|
"""Do forward pass in evaluation mode.""" |
|
if self.tpu: |
|
import torch_xla.core.xla_model as xm |
|
|
|
xm.rendezvous("valid_step") |
|
|
|
with torch.no_grad(): |
|
self.model.eval() |
|
self.criterion.eval() |
|
|
|
sample, is_dummy_batch = self._prepare_sample(sample) |
|
|
|
try: |
|
_loss, sample_size, logging_output = self.task.valid_step( |
|
sample, self.model, self.criterion |
|
) |
|
except RuntimeError as e: |
|
if "out of memory" in str(e): |
|
self._log_oom(e) |
|
if not raise_oom: |
|
logger.warning( |
|
"ran out of memory in validation step, retrying batch" |
|
) |
|
for p in self.model.parameters(): |
|
if p.grad is not None: |
|
p.grad = None |
|
if self.cuda: |
|
torch.cuda.empty_cache() |
|
return self.valid_step(sample, raise_oom=True) |
|
raise e |
|
|
|
logging_outputs = [logging_output] |
|
if is_dummy_batch: |
|
if torch.is_tensor(sample_size): |
|
sample_size.zero_() |
|
else: |
|
sample_size *= 0.0 |
|
|
|
|
|
if self.data_parallel_world_size > 1: |
|
logging_outputs, (sample_size,) = self._aggregate_logging_outputs( |
|
logging_outputs, |
|
sample_size, |
|
ignore=is_dummy_batch, |
|
) |
|
|
|
|
|
if self.tpu: |
|
logging_outputs = self._xla_markstep_and_send_to_cpu(logging_outputs) |
|
logging_output = self._reduce_and_log_stats(logging_outputs, sample_size) |
|
|
|
return logging_output |
|
|
|
def zero_grad(self): |
|
self.optimizer.zero_grad() |
|
|
|
def lr_step_begin_epoch(self, epoch): |
|
"""Adjust the learning rate at the beginning of the epoch.""" |
|
self.lr_scheduler.step_begin_epoch(epoch) |
|
|
|
return self.lr_step_update() |
|
|
|
def lr_step(self, epoch, val_loss=None): |
|
"""Adjust the learning rate at the end of the epoch.""" |
|
self.lr_scheduler.step(epoch, val_loss) |
|
|
|
return self.lr_step_update() |
|
|
|
def lr_step_update(self): |
|
"""Update the learning rate after each update.""" |
|
new_lr = self.lr_scheduler.step_update(self.get_num_updates()) |
|
if isinstance(new_lr, dict): |
|
for k, v in new_lr.items(): |
|
metrics.log_scalar(f"lr_{k}", v, weight=0, priority=300) |
|
new_lr = new_lr.get("default", next(iter(new_lr.values()))) |
|
else: |
|
metrics.log_scalar("lr", new_lr, weight=0, priority=300) |
|
return new_lr |
|
|
|
def get_lr(self): |
|
"""Get the current learning rate.""" |
|
return self.optimizer.get_lr() |
|
|
|
def get_model(self): |
|
"""Get the (non-wrapped) model instance.""" |
|
return self._model |
|
|
|
def get_criterion(self): |
|
"""Get the (non-wrapped) criterion instance.""" |
|
return self._criterion |
|
|
|
def get_meter(self, name): |
|
"""[deprecated] Get a specific meter by name.""" |
|
from fairseq import meters |
|
|
|
if "get_meter" not in self._warn_once: |
|
self._warn_once.add("get_meter") |
|
utils.deprecation_warning( |
|
"Trainer.get_meter is deprecated. Please use fairseq.metrics instead." |
|
) |
|
|
|
train_meters = metrics.get_meters("train") |
|
if train_meters is None: |
|
train_meters = {} |
|
|
|
if name == "train_loss" and "loss" in train_meters: |
|
return train_meters["loss"] |
|
elif name == "train_nll_loss": |
|
|
|
|
|
m = train_meters.get("nll_loss", None) |
|
return m or meters.AverageMeter() |
|
elif name == "wall": |
|
|
|
|
|
m = metrics.get_meter("default", "wall") |
|
return m or meters.TimeMeter() |
|
elif name == "wps": |
|
m = metrics.get_meter("train", "wps") |
|
return m or meters.TimeMeter() |
|
elif name in {"valid_loss", "valid_nll_loss"}: |
|
|
|
|
|
k = name[len("valid_") :] |
|
m = metrics.get_meter("valid", k) |
|
return m or meters.AverageMeter() |
|
elif name == "oom": |
|
return meters.AverageMeter() |
|
elif name in train_meters: |
|
return train_meters[name] |
|
return None |
|
|
|
def get_num_updates(self): |
|
"""Get the number of parameters updates.""" |
|
return self._num_updates |
|
|
|
def set_num_updates(self, num_updates): |
|
"""Set the number of parameters updates.""" |
|
self._num_updates = num_updates |
|
self.lr_step_update() |
|
if self.quantizer: |
|
self.quantizer.step_update(self._num_updates) |
|
metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200) |
|
|
|
def clip_grad_norm(self, clip_norm): |
|
def agg_norm_fn(total_norm): |
|
total_norm = total_norm.cuda().float() ** 2 |
|
total_norm = distributed_utils.all_reduce( |
|
total_norm, group=self.data_parallel_process_group |
|
) |
|
return total_norm ** 0.5 |
|
|
|
should_agg_norm = ( |
|
self.cfg.distributed_training.ddp_backend == "fully_sharded" |
|
and ( |
|
self.data_parallel_process_group is not None |
|
or torch.distributed.is_initialized() |
|
) |
|
) |
|
return self.optimizer.clip_grad_norm( |
|
clip_norm, aggregate_norm_fn=agg_norm_fn if should_agg_norm else None |
|
) |
|
|
|
def cumulative_training_time(self): |
|
if self._cumulative_training_time is None: |
|
|
|
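            # no aggregated training time yet (single worker, or before the
            # first sync); fall back to the local value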
return self._local_cumulative_training_time() |
|
else: |
|
return self._cumulative_training_time |
|
|
|
def _local_cumulative_training_time(self): |
|
"""Aggregate training time in seconds.""" |
|
return time.time() - self._start_time + self._previous_training_time |
|
|
|
def _fp_convert_sample(self, sample): |
|
def apply_half(t): |
|
if t.dtype is torch.float32: |
|
return t.to(dtype=torch.half) |
|
return t |
|
|
|
def apply_bfloat16(t): |
|
if t.dtype is torch.float32: |
|
return t.to(dtype=torch.bfloat16) |
|
return t |
|
|
|
if self.cfg.common.fp16: |
|
sample = utils.apply_to_sample(apply_half, sample) |
|
|
|
if self.cfg.common.bf16: |
|
sample = utils.apply_to_sample(apply_bfloat16, sample) |
|
|
|
return sample |
|
|
|
def _prepare_sample(self, sample, is_dummy=False): |
|
if sample == "DUMMY": |
|
raise Exception( |
|
"Trying to use an uninitialized 'dummy' batch. This usually indicates " |
|
"that the total number of batches is smaller than the number of " |
|
"participating GPUs. Try reducing the batch size or using fewer GPUs." |
|
) |
|
|
|
if sample is None or len(sample) == 0: |
|
assert ( |
|
self._dummy_batch is not None and len(self._dummy_batch) > 0 |
|
), "Invalid dummy batch: {}".format(self._dummy_batch) |
|
sample, _ = self._prepare_sample(self._dummy_batch, is_dummy=True) |
|
return sample, True |
|
|
|
|
|
|
|
|
|
|
|
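        # PCIe/NVLink bandwidth is much smaller than DRAM bandwidth, so when
        # requested we convert precision on the CPU first and only then move
        # the (smaller) buffer to the device; this also saves GPU memory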
if self.cfg.common.on_cpu_convert_precision: |
|
sample = self._fp_convert_sample(sample) |
|
|
|
if self.cuda: |
|
if self.pipeline_model_parallel: |
|
if 'target' in sample: |
|
sample['target'] = utils.move_to_cuda(sample['target'], device=self.last_device) |
|
else: |
|
sample = utils.move_to_cuda(sample) |
|
elif self.tpu and is_dummy: |
|
|
|
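            # the dummy batch may not be on the appropriate device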
sample = utils.move_to_cuda(sample, device=self.device) |
|
|
|
if not self.cfg.common.on_cpu_convert_precision: |
|
sample = self._fp_convert_sample(sample) |
|
|
|
if self._dummy_batch == "DUMMY": |
|
self._dummy_batch = sample |
|
|
|
return sample, False |
|
|
|
def _set_seed(self): |
|
|
|
|
|
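        # set the seed based on cfg.common.seed and the update number, so that
        # we get reproducible results when resuming from checkpoints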
seed = self.cfg.common.seed + self.get_num_updates() |
|
utils.set_torch_seed(seed) |
|
|
|
def _sync_stats(self): |
|
|
|
|
|
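        # return True when stats should be synced across workers: always for
        # plain multi-GPU DDP, and only on global sync steps (after warmup)
        # when using BMUF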
if self.data_parallel_world_size == 1: |
|
return False |
|
elif self.cfg.optimization.use_bmuf: |
|
return ( |
|
self.get_num_updates() + 1 |
|
) % self.cfg.bmuf.global_sync_iter == 0 and ( |
|
self.get_num_updates() + 1 |
|
) > self.cfg.bmuf.warmup_iterations |
|
else: |
|
return True |
|
|
|
def _log_oom(self, exc): |
|
msg = "OOM: Ran out of memory with exception: {}".format(exc) |
|
logger.warning(msg) |
|
if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"): |
|
for device_idx in range(torch.cuda.device_count()): |
|
logger.warning(torch.cuda.memory_summary(device=device_idx)) |
|
sys.stderr.flush() |
|
|
|
def _aggregate_logging_outputs( |
|
self, |
|
logging_outputs: List[Dict[str, Any]], |
|
*extra_stats_to_sum, |
|
ignore=False, |
|
): |
|
if self.task.__class__.logging_outputs_can_be_summed(self.get_criterion()): |
|
return self._fast_stat_sync_sum( |
|
logging_outputs, *extra_stats_to_sum, ignore=ignore |
|
) |
|
else: |
|
return self._all_gather_list_sync( |
|
logging_outputs, *extra_stats_to_sum, ignore=ignore |
|
) |
|
|
|
def _all_gather_list_sync( |
|
self, |
|
logging_outputs: List[Dict[str, Any]], |
|
*extra_stats_to_sum, |
|
ignore=False, |
|
): |
|
""" |
|
Sync logging outputs across workers. all_gather_list_sync is |
|
suitable when logging outputs are complex types. |
|
""" |
|
if self.tpu: |
|
raise NotImplementedError |
|
if ignore: |
|
logging_outputs = [] |
|
results = list( |
|
zip( |
|
*distributed_utils.all_gather_list( |
|
[logging_outputs] + list(extra_stats_to_sum), |
|
max_size=getattr(self.cfg.common, "all_gather_list_size", 16384), |
|
group=self.data_parallel_process_group, |
|
) |
|
) |
|
) |
|
logging_outputs, extra_stats_to_sum = results[0], results[1:] |
|
logging_outputs = list(chain.from_iterable(logging_outputs)) |
|
extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum] |
|
return logging_outputs, extra_stats_to_sum |
|
|
|
def _fast_stat_sync_sum( |
|
self, |
|
logging_outputs: List[Dict[str, Any]], |
|
*extra_stats_to_sum, |
|
ignore=False, |
|
): |
|
""" |
|
Sync logging outputs across workers. fast_stat_sync_sum is |
|
faster than all_gather_list_sync, but is only suitable when |
|
logging outputs are scalars and can be summed. Note that |
|
*logging_outputs* cannot contain any nested dicts/lists. |
|
""" |
|
data = {} |
|
for i, stat in enumerate(extra_stats_to_sum): |
|
data["extra_stats_" + str(i)] = stat |
|
if len(logging_outputs) > 0: |
|
log_keys = list(logging_outputs[0].keys()) |
|
for k in log_keys: |
|
if not ignore: |
|
v = sum(log[k] for log in logging_outputs if k in log) |
|
else: |
|
v = logging_outputs[0][k] |
|
v = torch.zeros_like(v) if torch.is_tensor(v) else 0 |
|
data["logging_outputs_" + k] = v |
|
else: |
|
log_keys = None |
|
|
|
data = distributed_utils.all_reduce_dict( |
|
data, device=self.device, group=self.data_parallel_process_group |
|
) |
|
|
|
extra_stats_to_sum = [ |
|
data["extra_stats_" + str(i)] for i in range(len(extra_stats_to_sum)) |
|
] |
|
if log_keys is not None: |
|
logging_outputs = [{k: data["logging_outputs_" + k] for k in log_keys}] |
|
else: |
|
logging_outputs = [] |
|
return logging_outputs, extra_stats_to_sum |
|
|
|
def _check_grad_norms(self, grad_norm): |
|
"""Check that grad norms are consistent across workers.""" |
|
if self._grad_norm_buf is not None: |
|
self._grad_norm_buf.zero_() |
|
self._grad_norm_buf[self.data_parallel_rank] = grad_norm |
|
distributed_utils.all_reduce( |
|
self._grad_norm_buf, group=self.data_parallel_process_group |
|
) |
|
|
|
def is_consistent(tensor): |
|
max_abs_diff = torch.max(torch.abs(tensor - tensor[0])) |
|
return ( |
|
(torch.isfinite(tensor).all() |
|
and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all()) |
|
or |
|
(self.cfg.common.amp and not torch.isfinite(tensor).all()) |
|
|
|
) |
|
|
|
if not is_consistent(self._grad_norm_buf): |
|
pretty_detail = "\n".join( |
|
"rank {:3d} = {:.8f}".format(r, n) |
|
for r, n in enumerate(self._grad_norm_buf.tolist()) |
|
) |
|
error_detail = "grad_norm across the workers:\n{}\n".format( |
|
pretty_detail |
|
) |
|
|
|
raise FloatingPointError( |
|
"Fatal error: gradients are inconsistent between workers. " |
|
"Try --ddp-backend=legacy_ddp. " |
|
"Or are you mixing up different generation of GPUs in training?" |
|
+ "\n" |
|
+ "-" * 80 |
|
+ "\n{}\n".format(error_detail) |
|
+ "-" * 80 |
|
) |
|
|
|
def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None): |
|
if grad_norm is not None and ( |
|
not torch.is_tensor(grad_norm) or torch.isfinite(grad_norm) |
|
): |
|
metrics.log_speed("ups", 1.0, priority=100, round=2) |
|
metrics.log_scalar("gnorm", grad_norm, priority=400, round=3) |
|
if self.cfg.optimization.clip_norm > 0: |
|
metrics.log_scalar( |
|
"clip", |
|
torch.where( |
|
grad_norm > self.cfg.optimization.clip_norm, |
|
grad_norm.new_tensor(100), |
|
grad_norm.new_tensor(0), |
|
), |
|
priority=500, |
|
round=1, |
|
) |
|
|
|
with metrics.aggregate() as agg: |
|
if logging_outputs is not None: |
|
self.task.reduce_metrics(logging_outputs, self.get_criterion()) |
|
del logging_outputs |
|
|
|
|
|
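            # extra warning for criterions that don't properly report a 'loss'
            # value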
if "loss" not in agg: |
|
if "loss" not in self._warn_once: |
|
self._warn_once.add("loss") |
|
logger.warning( |
|
"Criterion.reduce_metrics did not log a 'loss' value, " |
|
"which may break some functionality" |
|
) |
|
metrics.log_scalar("loss", -1) |
|
|
|
|
|
if self.tpu: |
|
logging_output = {} |
|
else: |
|
logging_output = agg.get_smoothed_values() |
|
logging_output["sample_size"] = sample_size |
|
for key_to_delete in ["ppl", "wps", "wpb", "bsz"]: |
|
if key_to_delete in logging_output: |
|
del logging_output[key_to_delete] |
|
return logging_output |
|
|
|
def _check_xla_compilation(self): |
|
import torch_xla.debug.metrics as met |
|
|
|
compile_stats = met.metric_data("CompileTime") |
|
if compile_stats is None: |
|
return |
|
num_xla_compiles = compile_stats[0] |
|
if num_xla_compiles > self._num_xla_compiles: |
|
logger.warning( |
|
"XLA compilation detected on device #{}; too many of these can lead " |
|
"to slow training, but we expect a few in the beginning".format( |
|
self.cfg.distributed_training.distributed_rank |
|
) |
|
) |
|
self._num_xla_compiles = num_xla_compiles |
|
|
|
def _xla_markstep_and_send_to_cpu(self, data=None): |
|
import torch_xla.core.xla_model as xm |
|
|
|
xm.mark_step() |
|
if data is not None: |
|
from fairseq.utils import xla_device_to_cpu |
|
|
|
return xla_device_to_cpu(data) |
|
|
|
|
|
def _catalog_shared_params(module, memo=None, prefix=""): |
|
if memo is None: |
|
first_call = True |
|
memo = {} |
|
else: |
|
first_call = False |
|
for name, param in module._parameters.items(): |
|
param_prefix = prefix + ("." if prefix else "") + name |
|
if param not in memo: |
|
memo[param] = [] |
|
memo[param].append(param_prefix) |
|
for name, m in module._modules.items(): |
|
if m is None: |
|
continue |
|
submodule_prefix = prefix + ("." if prefix else "") + name |
|
_catalog_shared_params(m, memo, submodule_prefix) |
|
if first_call: |
|
return [x for x in memo.values() if len(x) > 1] |
|
|
|
|
|
def _get_module_by_path(module, path): |
|
path = path.split(".") |
|
for name in path: |
|
module = getattr(module, name) |
|
return module |
|
|
|
|
|
def _set_module_by_path(module, path, value): |
|
path = path.split(".") |
|
for name in path[:-1]: |
|
module = getattr(module, name) |
|
setattr(module, path[-1], value) |
|
|