|
{
    "device": "$torch.device(f'cuda:{dist.get_rank()}')",
    "network": {
        "_target_": "torch.nn.parallel.DistributedDataParallel",
        "module": "$@network_def.to(@device)",
        "device_ids": [
            "@device"
        ]
    },
    "train#sampler": {
        "_target_": "DistributedSampler",
        "dataset": "@train#dataset",
        "even_divisible": true,
        "shuffle": true
    },
    "train#dataloader#sampler": "@train#sampler",
    "train#dataloader#shuffle": false,
    "train#trainer#train_handlers": "$@train#handlers[: -2 if dist.get_rank() > 0 else None]",
    "validate#sampler": {
        "_target_": "DistributedSampler",
        "dataset": "@validate#dataset",
        "even_divisible": false,
        "shuffle": false
    },
    "validate#dataloader#sampler": "@validate#sampler",
    "validate#evaluator#val_handlers": "$None if dist.get_rank() > 0 else @validate#handlers",
    "training": [
        "$import sys",
        "$sys.path.append(@bundle_root)",
        "$import torch.distributed as dist",
        "$dist.init_process_group(backend='nccl')",
        "$torch.cuda.set_device(@device)",
        "$monai.utils.set_determinism(seed=123)",
        "$setattr(torch.backends.cudnn, 'benchmark', True)",
        "$import logging",
        "$@train#trainer.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)",
        "$@validate#evaluator.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)",
        "$@train#trainer.run()",
        "$dist.destroy_process_group()"
    ]
}