glenn-jocher commited on
Commit
fad27c0
1 Parent(s): 5bab9a2

Update DDP for `torch.distributed.run` with `gloo` backend (#3680)

Browse files

* Update DDP for `torch.distributed.run`

* Add LOCAL_RANK

* remove opt.local_rank

* backend="gloo|nccl"

* print

* print

* debug

* debug

* os.getenv

* gloo

* gloo

* gloo

* cleanup

* fix getenv

* cleanup

* cleanup destroy

* try nccl

* return opt

* add --local_rank

* add timeout

* add init_method

* gloo

* move destroy

* move destroy

* move print(opt) under if RANK

* destroy only RANK 0

* move destroy inside train()

* restore destroy outside train()

* update print(opt)

* cleanup

* nccl

* gloo with 60 second timeout

* update namespace printing

detect.py CHANGED
@@ -8,8 +8,8 @@ import torch.backends.cudnn as cudnn
8
 
9
  from models.experimental import attempt_load
10
  from utils.datasets import LoadStreams, LoadImages
11
- from utils.general import check_img_size, check_requirements, check_imshow, non_max_suppression, apply_classifier, \
12
- scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
13
  from utils.plots import colors, plot_one_box
14
  from utils.torch_utils import select_device, load_classifier, time_synchronized
15
 
@@ -202,7 +202,7 @@ def parse_opt():
202
 
203
 
204
  def main(opt):
205
- print(opt)
206
  check_requirements(exclude=('tensorboard', 'thop'))
207
  detect(**vars(opt))
208
 
 
8
 
9
  from models.experimental import attempt_load
10
  from utils.datasets import LoadStreams, LoadImages
11
+ from utils.general import check_img_size, check_requirements, check_imshow, colorstr, non_max_suppression, \
12
+ apply_classifier, scale_coords, xyxy2xywh, strip_optimizer, set_logging, increment_path, save_one_box
13
  from utils.plots import colors, plot_one_box
14
  from utils.torch_utils import select_device, load_classifier, time_synchronized
15
 
 
202
 
203
 
204
  def main(opt):
205
+ print(colorstr('detect: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
206
  check_requirements(exclude=('tensorboard', 'thop'))
207
  detect(**vars(opt))
208
 
models/export.py CHANGED
@@ -163,8 +163,8 @@ def parse_opt():
163
 
164
 
165
  def main(opt):
166
- print(opt)
167
  set_logging()
 
168
  export(**vars(opt))
169
 
170
 
 
163
 
164
 
165
  def main(opt):
 
166
  set_logging()
167
+ print(colorstr('export: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
168
  export(**vars(opt))
169
 
170
 
test.py CHANGED
@@ -51,7 +51,6 @@ def test(data,
51
  device = next(model.parameters()).device # get model device
52
 
53
  else: # called directly
54
- set_logging()
55
  device = select_device(device, batch_size=batch_size)
56
 
57
  # Directories
@@ -323,7 +322,8 @@ def parse_opt():
323
 
324
 
325
  def main(opt):
326
- print(opt)
 
327
  check_requirements(exclude=('tensorboard', 'thop'))
328
 
329
  if opt.task in ('train', 'val', 'test'): # run normally
 
51
  device = next(model.parameters()).device # get model device
52
 
53
  else: # called directly
 
54
  device = select_device(device, batch_size=batch_size)
55
 
56
  # Directories
 
322
 
323
 
324
  def main(opt):
325
+ set_logging()
326
+ print(colorstr('test: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
327
  check_requirements(exclude=('tensorboard', 'thop'))
328
 
329
  if opt.task in ('train', 'val', 'test'): # run normally
train.py CHANGED
@@ -37,15 +37,17 @@ from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_di
37
  from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
38
 
39
  logger = logging.getLogger(__name__)
 
 
 
40
 
41
 
42
  def train(hyp, # path/to/hyp.yaml or hyp dictionary
43
  opt,
44
  device,
45
  ):
46
- save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
47
- Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
48
- opt.single_cls
49
 
50
  # Directories
51
  wdir = save_dir / 'weights'
@@ -69,13 +71,13 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
69
  # Configure
70
  plots = not opt.evolve # create plots
71
  cuda = device.type != 'cpu'
72
- init_seeds(2 + rank)
73
  with open(opt.data) as f:
74
  data_dict = yaml.safe_load(f) # data dict
75
 
76
  # Loggers
77
  loggers = {'wandb': None, 'tb': None} # loggers dict
78
- if rank in [-1, 0]:
79
  # TensorBoard
80
  if not opt.evolve:
81
  prefix = colorstr('tensorboard: ')
@@ -99,7 +101,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
99
  # Model
100
  pretrained = weights.endswith('.pt')
101
  if pretrained:
102
- with torch_distributed_zero_first(rank):
103
  weights = attempt_download(weights) # download if not found locally
104
  ckpt = torch.load(weights, map_location=device) # load checkpoint
105
  model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
@@ -110,7 +112,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
110
  logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
111
  else:
112
  model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
113
- with torch_distributed_zero_first(rank):
114
  check_dataset(data_dict) # check
115
  train_path = data_dict['train']
116
  test_path = data_dict['val']
@@ -158,7 +160,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
158
  # plot_lr_scheduler(optimizer, scheduler, epochs)
159
 
160
  # EMA
161
- ema = ModelEMA(model) if rank in [-1, 0] else None
162
 
163
  # Resume
164
  start_epoch, best_fitness = 0, 0.0
@@ -194,28 +196,28 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
194
  imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
195
 
196
  # DP mode
197
- if cuda and rank == -1 and torch.cuda.device_count() > 1:
198
  model = torch.nn.DataParallel(model)
199
 
200
  # SyncBatchNorm
201
- if opt.sync_bn and cuda and rank != -1:
202
  model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
203
  logger.info('Using SyncBatchNorm()')
204
 
205
  # Trainloader
206
  dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls,
207
- hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=rank,
208
- world_size=opt.world_size, workers=opt.workers,
209
  image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
210
  mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
211
  nb = len(dataloader) # number of batches
212
  assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
213
 
214
  # Process 0
215
- if rank in [-1, 0]:
216
  testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, single_cls,
217
  hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
218
- world_size=opt.world_size, workers=opt.workers,
219
  pad=0.5, prefix=colorstr('val: '))[0]
220
 
221
  if not opt.resume:
@@ -234,8 +236,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
234
  model.half().float() # pre-reduce anchor precision
235
 
236
  # DDP mode
237
- if cuda and rank != -1:
238
- model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank,
239
  # nn.MultiheadAttention incompatibility with DDP https://github.com/pytorch/pytorch/issues/26698
240
  find_unused_parameters=any(isinstance(layer, nn.MultiheadAttention) for layer in model.modules()))
241
 
@@ -269,15 +271,15 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
269
  # Update image weights (optional)
270
  if opt.image_weights:
271
  # Generate indices
272
- if rank in [-1, 0]:
273
  cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
274
  iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
275
  dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
276
  # Broadcast if DDP
277
- if rank != -1:
278
- indices = (torch.tensor(dataset.indices) if rank == 0 else torch.zeros(dataset.n)).int()
279
  dist.broadcast(indices, 0)
280
- if rank != 0:
281
  dataset.indices = indices.cpu().numpy()
282
 
283
  # Update mosaic border
@@ -285,11 +287,11 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
285
  # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
286
 
287
  mloss = torch.zeros(4, device=device) # mean losses
288
- if rank != -1:
289
  dataloader.sampler.set_epoch(epoch)
290
  pbar = enumerate(dataloader)
291
  logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
292
- if rank in [-1, 0]:
293
  pbar = tqdm(pbar, total=nb) # progress bar
294
  optimizer.zero_grad()
295
  for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
@@ -319,8 +321,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
319
  with amp.autocast(enabled=cuda):
320
  pred = model(imgs) # forward
321
  loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
322
- if rank != -1:
323
- loss *= opt.world_size # gradient averaged between devices in DDP mode
324
  if opt.quad:
325
  loss *= 4.
326
 
@@ -336,7 +338,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
336
  ema.update(model)
337
 
338
  # Print
339
- if rank in [-1, 0]:
340
  mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
341
  mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
342
  s = ('%10s' * 2 + '%10.4g' * 6) % (
@@ -362,7 +364,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
362
  scheduler.step()
363
 
364
  # DDP process 0 or single-GPU
365
- if rank in [-1, 0]:
366
  # mAP
367
  ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
368
  final_epoch = epoch + 1 == epochs
@@ -424,7 +426,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
424
 
425
  # end epoch ----------------------------------------------------------------------------------------------------
426
  # end training -----------------------------------------------------------------------------------------------------
427
- if rank in [-1, 0]:
428
  logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
429
  if plots:
430
  plot_results(save_dir=save_dir) # save as results.png
@@ -457,8 +459,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
457
  name='run_' + wandb_logger.wandb_run.id + '_model',
458
  aliases=['latest', 'best', 'stripped'])
459
  wandb_logger.finish_run()
460
- else:
461
- dist.destroy_process_group()
462
  torch.cuda.empty_cache()
463
  return results
464
 
@@ -486,7 +487,6 @@ def parse_opt():
486
  parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
487
  parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
488
  parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
489
- parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
490
  parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
491
  parser.add_argument('--project', default='runs/train', help='save to project/name')
492
  parser.add_argument('--entity', default=None, help='W&B entity')
@@ -499,18 +499,15 @@ def parse_opt():
499
  parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
500
  parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
501
  parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
 
502
  opt = parser.parse_args()
503
-
504
- # Set DDP variables
505
- opt.world_size = int(getattr(os.environ, 'WORLD_SIZE', 1))
506
- opt.global_rank = int(getattr(os.environ, 'RANK', -1))
507
  return opt
508
 
509
 
510
  def main(opt):
511
- print(opt)
512
- set_logging(opt.global_rank)
513
- if opt.global_rank in [-1, 0]:
514
  check_git_status()
515
  check_requirements(exclude=['thop'])
516
 
@@ -519,11 +516,9 @@ def main(opt):
519
  if opt.resume and not wandb_run: # resume an interrupted run
520
  ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
521
  assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
522
- apriori = opt.global_rank, opt.local_rank
523
  with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
524
  opt = argparse.Namespace(**yaml.safe_load(f)) # replace
525
- opt.cfg, opt.weights, opt.resume, opt.batch_size, opt.global_rank, opt.local_rank = \
526
- '', ckpt, True, opt.total_batch_size, *apriori # reinstate
527
  logger.info('Resuming training from %s' % ckpt)
528
  else:
529
  # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
@@ -536,19 +531,21 @@ def main(opt):
536
  # DDP mode
537
  opt.total_batch_size = opt.batch_size
538
  device = select_device(opt.device, batch_size=opt.batch_size)
539
- if opt.local_rank != -1:
540
- assert torch.cuda.device_count() > opt.local_rank
541
- torch.cuda.set_device(opt.local_rank)
542
- device = torch.device('cuda', opt.local_rank)
543
- dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
544
- assert opt.batch_size % opt.world_size == 0, '--batch-size must be multiple of CUDA device count'
 
545
  assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
546
- opt.batch_size = opt.total_batch_size // opt.world_size
547
 
548
  # Train
549
- logger.info(opt)
550
  if not opt.evolve:
551
  train(opt.hyp, opt, device)
 
 
552
 
553
  # Evolve hyperparameters (optional)
554
  else:
@@ -584,7 +581,7 @@ def main(opt):
584
 
585
  with open(opt.hyp) as f:
586
  hyp = yaml.safe_load(f) # load hyps dict
587
- assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
588
  opt.notest, opt.nosave = True, True # only test/save final epoch
589
  # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
590
  yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
 
37
  from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
38
 
39
  logger = logging.getLogger(__name__)
40
+ LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
41
+ RANK = int(os.getenv('RANK', -1))
42
+ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
43
 
44
 
45
  def train(hyp, # path/to/hyp.yaml or hyp dictionary
46
  opt,
47
  device,
48
  ):
49
+ save_dir, epochs, batch_size, total_batch_size, weights, single_cls = \
50
+ Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.single_cls
 
51
 
52
  # Directories
53
  wdir = save_dir / 'weights'
 
71
  # Configure
72
  plots = not opt.evolve # create plots
73
  cuda = device.type != 'cpu'
74
+ init_seeds(2 + RANK)
75
  with open(opt.data) as f:
76
  data_dict = yaml.safe_load(f) # data dict
77
 
78
  # Loggers
79
  loggers = {'wandb': None, 'tb': None} # loggers dict
80
+ if RANK in [-1, 0]:
81
  # TensorBoard
82
  if not opt.evolve:
83
  prefix = colorstr('tensorboard: ')
 
101
  # Model
102
  pretrained = weights.endswith('.pt')
103
  if pretrained:
104
+ with torch_distributed_zero_first(RANK):
105
  weights = attempt_download(weights) # download if not found locally
106
  ckpt = torch.load(weights, map_location=device) # load checkpoint
107
  model = Model(opt.cfg or ckpt['model'].yaml, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
 
112
  logger.info('Transferred %g/%g items from %s' % (len(state_dict), len(model.state_dict()), weights)) # report
113
  else:
114
  model = Model(opt.cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device) # create
115
+ with torch_distributed_zero_first(RANK):
116
  check_dataset(data_dict) # check
117
  train_path = data_dict['train']
118
  test_path = data_dict['val']
 
160
  # plot_lr_scheduler(optimizer, scheduler, epochs)
161
 
162
  # EMA
163
+ ema = ModelEMA(model) if RANK in [-1, 0] else None
164
 
165
  # Resume
166
  start_epoch, best_fitness = 0, 0.0
 
196
  imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size] # verify imgsz are gs-multiples
197
 
198
  # DP mode
199
+ if cuda and RANK == -1 and torch.cuda.device_count() > 1:
200
  model = torch.nn.DataParallel(model)
201
 
202
  # SyncBatchNorm
203
+ if opt.sync_bn and cuda and RANK != -1:
204
  model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
205
  logger.info('Using SyncBatchNorm()')
206
 
207
  # Trainloader
208
  dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, single_cls,
209
+ hyp=hyp, augment=True, cache=opt.cache_images, rect=opt.rect, rank=RANK,
210
+ workers=opt.workers,
211
  image_weights=opt.image_weights, quad=opt.quad, prefix=colorstr('train: '))
212
  mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
213
  nb = len(dataloader) # number of batches
214
  assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
215
 
216
  # Process 0
217
+ if RANK in [-1, 0]:
218
  testloader = create_dataloader(test_path, imgsz_test, batch_size * 2, gs, single_cls,
219
  hyp=hyp, cache=opt.cache_images and not opt.notest, rect=True, rank=-1,
220
+ workers=opt.workers,
221
  pad=0.5, prefix=colorstr('val: '))[0]
222
 
223
  if not opt.resume:
 
236
  model.half().float() # pre-reduce anchor precision
237
 
238
  # DDP mode
239
+ if cuda and RANK != -1:
240
+ model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK,
241
  # nn.MultiheadAttention incompatibility with DDP https://github.com/pytorch/pytorch/issues/26698
242
  find_unused_parameters=any(isinstance(layer, nn.MultiheadAttention) for layer in model.modules()))
243
 
 
271
  # Update image weights (optional)
272
  if opt.image_weights:
273
  # Generate indices
274
+ if RANK in [-1, 0]:
275
  cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc # class weights
276
  iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw) # image weights
277
  dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n) # rand weighted idx
278
  # Broadcast if DDP
279
+ if RANK != -1:
280
+ indices = (torch.tensor(dataset.indices) if RANK == 0 else torch.zeros(dataset.n)).int()
281
  dist.broadcast(indices, 0)
282
+ if RANK != 0:
283
  dataset.indices = indices.cpu().numpy()
284
 
285
  # Update mosaic border
 
287
  # dataset.mosaic_border = [b - imgsz, -b] # height, width borders
288
 
289
  mloss = torch.zeros(4, device=device) # mean losses
290
+ if RANK != -1:
291
  dataloader.sampler.set_epoch(epoch)
292
  pbar = enumerate(dataloader)
293
  logger.info(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'box', 'obj', 'cls', 'total', 'labels', 'img_size'))
294
+ if RANK in [-1, 0]:
295
  pbar = tqdm(pbar, total=nb) # progress bar
296
  optimizer.zero_grad()
297
  for i, (imgs, targets, paths, _) in pbar: # batch -------------------------------------------------------------
 
321
  with amp.autocast(enabled=cuda):
322
  pred = model(imgs) # forward
323
  loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
324
+ if RANK != -1:
325
+ loss *= WORLD_SIZE # gradient averaged between devices in DDP mode
326
  if opt.quad:
327
  loss *= 4.
328
 
 
338
  ema.update(model)
339
 
340
  # Print
341
+ if RANK in [-1, 0]:
342
  mloss = (mloss * i + loss_items) / (i + 1) # update mean losses
343
  mem = '%.3gG' % (torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0) # (GB)
344
  s = ('%10s' * 2 + '%10.4g' * 6) % (
 
364
  scheduler.step()
365
 
366
  # DDP process 0 or single-GPU
367
+ if RANK in [-1, 0]:
368
  # mAP
369
  ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'gr', 'names', 'stride', 'class_weights'])
370
  final_epoch = epoch + 1 == epochs
 
426
 
427
  # end epoch ----------------------------------------------------------------------------------------------------
428
  # end training -----------------------------------------------------------------------------------------------------
429
+ if RANK in [-1, 0]:
430
  logger.info(f'{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.\n')
431
  if plots:
432
  plot_results(save_dir=save_dir) # save as results.png
 
459
  name='run_' + wandb_logger.wandb_run.id + '_model',
460
  aliases=['latest', 'best', 'stripped'])
461
  wandb_logger.finish_run()
462
+
 
463
  torch.cuda.empty_cache()
464
  return results
465
 
 
487
  parser.add_argument('--single-cls', action='store_true', help='train multi-class data as single-class')
488
  parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
489
  parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
 
490
  parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
491
  parser.add_argument('--project', default='runs/train', help='save to project/name')
492
  parser.add_argument('--entity', default=None, help='W&B entity')
 
499
  parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
500
  parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
501
  parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
502
+ parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
503
  opt = parser.parse_args()
 
 
 
 
504
  return opt
505
 
506
 
507
  def main(opt):
508
+ set_logging(RANK)
509
+ if RANK in [-1, 0]:
510
+ print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
511
  check_git_status()
512
  check_requirements(exclude=['thop'])
513
 
 
516
  if opt.resume and not wandb_run: # resume an interrupted run
517
  ckpt = opt.resume if isinstance(opt.resume, str) else get_latest_run() # specified or most recent path
518
  assert os.path.isfile(ckpt), 'ERROR: --resume checkpoint does not exist'
 
519
  with open(Path(ckpt).parent.parent / 'opt.yaml') as f:
520
  opt = argparse.Namespace(**yaml.safe_load(f)) # replace
521
+ opt.cfg, opt.weights, opt.resume, opt.batch_size = '', ckpt, True, opt.total_batch_size # reinstate
 
522
  logger.info('Resuming training from %s' % ckpt)
523
  else:
524
  # opt.hyp = opt.hyp or ('hyp.finetune.yaml' if opt.weights else 'hyp.scratch.yaml')
 
531
  # DDP mode
532
  opt.total_batch_size = opt.batch_size
533
  device = select_device(opt.device, batch_size=opt.batch_size)
534
+ if LOCAL_RANK != -1:
535
+ from datetime import timedelta
536
+ assert torch.cuda.device_count() > LOCAL_RANK, 'too few GPUS for DDP command'
537
+ torch.cuda.set_device(LOCAL_RANK)
538
+ device = torch.device('cuda', LOCAL_RANK)
539
+ dist.init_process_group(backend="gloo", timeout=timedelta(seconds=60))
540
+ assert opt.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
541
  assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
542
+ opt.batch_size = opt.total_batch_size // WORLD_SIZE
543
 
544
  # Train
 
545
  if not opt.evolve:
546
  train(opt.hyp, opt, device)
547
+ if WORLD_SIZE > 1 and RANK == 0:
548
+ _ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')]
549
 
550
  # Evolve hyperparameters (optional)
551
  else:
 
581
 
582
  with open(opt.hyp) as f:
583
  hyp = yaml.safe_load(f) # load hyps dict
584
+ assert LOCAL_RANK == -1, 'DDP mode not implemented for --evolve'
585
  opt.notest, opt.nosave = True, True # only test/save final epoch
586
  # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
587
  yaml_file = Path(opt.save_dir) / 'hyp_evolved.yaml' # save best result here
utils/datasets.py CHANGED
@@ -64,7 +64,7 @@ def exif_size(img):
64
 
65
 
66
  def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
67
- rect=False, rank=-1, world_size=1, workers=8, image_weights=False, quad=False, prefix=''):
68
  # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
69
  with torch_distributed_zero_first(rank):
70
  dataset = LoadImagesAndLabels(path, imgsz, batch_size,
@@ -79,7 +79,7 @@ def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=Non
79
  prefix=prefix)
80
 
81
  batch_size = min(batch_size, len(dataset))
82
- nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
83
  sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
84
  loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
85
  # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
 
64
 
65
 
66
  def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
67
+ rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
68
  # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
69
  with torch_distributed_zero_first(rank):
70
  dataset = LoadImagesAndLabels(path, imgsz, batch_size,
 
79
  prefix=prefix)
80
 
81
  batch_size = min(batch_size, len(dataset))
82
+ nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, workers]) # number of workers
83
  sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
84
  loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
85
  # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
utils/torch_utils.py CHANGED
@@ -13,6 +13,7 @@ from pathlib import Path
13
 
14
  import torch
15
  import torch.backends.cudnn as cudnn
 
16
  import torch.nn as nn
17
  import torch.nn.functional as F
18
  import torchvision
@@ -30,10 +31,10 @@ def torch_distributed_zero_first(local_rank: int):
30
  Decorator to make all processes in distributed training wait for each local_master to do something.
31
  """
32
  if local_rank not in [-1, 0]:
33
- torch.distributed.barrier()
34
  yield
35
  if local_rank == 0:
36
- torch.distributed.barrier()
37
 
38
 
39
  def init_torch_seeds(seed=0):
 
13
 
14
  import torch
15
  import torch.backends.cudnn as cudnn
16
+ import torch.distributed as dist
17
  import torch.nn as nn
18
  import torch.nn.functional as F
19
  import torchvision
 
31
  Decorator to make all processes in distributed training wait for each local_master to do something.
32
  """
33
  if local_rank not in [-1, 0]:
34
+ dist.barrier()
35
  yield
36
  if local_rank == 0:
37
+ dist.barrier()
38
 
39
 
40
  def init_torch_seeds(seed=0):
utils/wandb_logging/wandb_utils.py CHANGED
@@ -1,5 +1,6 @@
1
  """Utilities and tools for tracking runs with Weights & Biases."""
2
  import logging
 
3
  import sys
4
  from contextlib import contextmanager
5
  from pathlib import Path
@@ -18,6 +19,7 @@ try:
18
  except ImportError:
19
  wandb = None
20
 
 
21
  WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
22
 
23
 
@@ -42,10 +44,10 @@ def get_run_info(run_path):
42
 
43
 
44
  def check_wandb_resume(opt):
45
- process_wandb_config_ddp_mode(opt) if opt.global_rank not in [-1, 0] else None
46
  if isinstance(opt.resume, str):
47
  if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
48
- if opt.global_rank not in [-1, 0]: # For resuming DDP runs
49
  entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
50
  api = wandb.Api()
51
  artifact = api.artifact(entity + '/' + project + '/' + model_artifact_name + ':latest')
 
1
  """Utilities and tools for tracking runs with Weights & Biases."""
2
  import logging
3
+ import os
4
  import sys
5
  from contextlib import contextmanager
6
  from pathlib import Path
 
19
  except ImportError:
20
  wandb = None
21
 
22
+ RANK = int(os.getenv('RANK', -1))
23
  WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
24
 
25
 
 
44
 
45
 
46
  def check_wandb_resume(opt):
47
+ process_wandb_config_ddp_mode(opt) if RANK not in [-1, 0] else None
48
  if isinstance(opt.resume, str):
49
  if opt.resume.startswith(WANDB_ARTIFACT_PREFIX):
50
+ if RANK not in [-1, 0]: # For resuming DDP runs
51
  entity, project, run_id, model_artifact_name = get_run_info(opt.resume)
52
  api = wandb.Api()
53
  artifact = api.artifact(entity + '/' + project + '/' + model_artifact_name + ':latest')