brats_mri_axial_slices_generative_diffusion/configs/multi_gpu_train_autoencoder.json
{
    "device": "$torch.device('cuda:' + os.environ['LOCAL_RANK'])",
    "gnetwork": {
        "_target_": "torch.nn.parallel.DistributedDataParallel",
        "module": "$@autoencoder_def.to(@device)",
        "device_ids": [
            "@device"
        ],
        "find_unused_parameters": true
    },
    "dnetwork": {
        "_target_": "torch.nn.parallel.DistributedDataParallel",
        "module": "$@discriminator_def.to(@device)",
        "device_ids": [
            "@device"
        ],
        "find_unused_parameters": true
    },
    "train#sampler": {
        "_target_": "DistributedSampler",
        "dataset": "@train#dataset",
        "even_divisible": true,
        "shuffle": true
    },
    "train#dataloader#sampler": "@train#sampler",
    "train#dataloader#shuffle": false,
    "train#trainer#train_handlers": "$@train#handlers[: -2 if dist.get_rank() > 0 else None]",
    "initialize": [
        "$import torch.distributed as dist",
        "$import os",
        "$dist.is_initialized() or dist.init_process_group(backend='nccl')",
        "$torch.cuda.set_device(@device)",
        "$monai.utils.set_determinism(seed=123)",
        "$import logging",
        "$@train#trainer.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)"
    ],
    "run": [
        "$@train#trainer.run()"
    ],
    "finalize": [
        "$dist.is_initialized() and dist.destroy_process_group()"
    ]
}
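
Usage note (not part of the JSON above): this file is an override that is merged on top of the bundle's single-GPU autoencoder training config, wrapping the generator and discriminator in DistributedDataParallel, swapping in a DistributedSampler, and trimming the last two train handlers (typically the checkpoint/stats writers in the base config) on non-zero ranks so that only rank 0 writes them. It expects one process per GPU, with LOCAL_RANK set by the launcher so the "device" expression resolves correctly. A minimal launch sketch in Python, assuming the base config is configs/train_autoencoder.json and the script below (launch_autoencoder.py, a hypothetical name) is started via torchrun, e.g. torchrun --standalone --nnodes=1 --nproc_per_node=<num_gpus> launch_autoencoder.py:

# launch_autoencoder.py -- sketch only, not part of the bundle.
# torchrun sets LOCAL_RANK plus the rendezvous env vars that
# dist.init_process_group(backend='nccl') in "initialize" relies on.
from monai.bundle import run

if __name__ == "__main__":
    run(
        config_file=[
            "configs/train_autoencoder.json",            # assumed single-GPU base config
            "configs/multi_gpu_train_autoencoder.json",  # this file, applied as an override
        ],
    )

With the default section names, monai.bundle.run executes the "initialize", "run", and "finalize" lists defined above in that order on every process.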