Ayush Chaurasia committed on
Commit
1bf9365
1 Parent(s): 0d891c6

W&B DDP fix (#2574)

Browse files
Files changed (2) hide show
  1. train.py +5 -3
  2. utils/wandb_logging/wandb_utils.py +4 -1
train.py CHANGED
@@ -66,14 +66,16 @@ def train(hyp, opt, device, tb_writer=None):
66
  is_coco = opt.data.endswith('coco.yaml')
67
 
68
  # Logging- Doing this before checking the dataset. Might update data_dict
 
69
  if rank in [-1, 0]:
70
  opt.hyp = hyp # add hyperparameters
71
  run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
72
  wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
 
73
  data_dict = wandb_logger.data_dict
74
  if wandb_logger.wandb:
75
  weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
76
- loggers = {'wandb': wandb_logger.wandb} # loggers dict
77
  nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
78
  names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
79
  assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
@@ -381,6 +383,7 @@ def train(hyp, opt, device, tb_writer=None):
381
  fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
382
  if fi > best_fitness:
383
  best_fitness = fi
 
384
 
385
  # Save model
386
  if (not opt.nosave) or (final_epoch and not opt.evolve): # if save
@@ -402,7 +405,6 @@ def train(hyp, opt, device, tb_writer=None):
402
  wandb_logger.log_model(
403
  last.parent, opt, epoch, fi, best_model=best_fitness == fi)
404
  del ckpt
405
- wandb_logger.end_epoch(best_result=best_fitness == fi)
406
 
407
  # end epoch ----------------------------------------------------------------------------------------------------
408
  # end training
@@ -442,10 +444,10 @@ def train(hyp, opt, device, tb_writer=None):
442
  wandb_logger.wandb.log_artifact(str(final), type='model',
443
  name='run_' + wandb_logger.wandb_run.id + '_model',
444
  aliases=['last', 'best', 'stripped'])
 
445
  else:
446
  dist.destroy_process_group()
447
  torch.cuda.empty_cache()
448
- wandb_logger.finish_run()
449
  return results
450
 
451
 
 
66
  is_coco = opt.data.endswith('coco.yaml')
67
 
68
  # Logging- Doing this before checking the dataset. Might update data_dict
69
+ loggers = {'wandb': None} # loggers dict
70
  if rank in [-1, 0]:
71
  opt.hyp = hyp # add hyperparameters
72
  run_id = torch.load(weights).get('wandb_id') if weights.endswith('.pt') and os.path.isfile(weights) else None
73
  wandb_logger = WandbLogger(opt, Path(opt.save_dir).stem, run_id, data_dict)
74
+ loggers['wandb'] = wandb_logger.wandb
75
  data_dict = wandb_logger.data_dict
76
  if wandb_logger.wandb:
77
  weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp # WandbLogger might update weights, epochs if resuming
78
+
79
  nc = 1 if opt.single_cls else int(data_dict['nc']) # number of classes
80
  names = ['item'] if opt.single_cls and len(data_dict['names']) != 1 else data_dict['names'] # class names
81
  assert len(names) == nc, '%g names found for nc=%g dataset in %s' % (len(names), nc, opt.data) # check
 
383
  fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, mAP@.5, mAP@.5-.95]
384
  if fi > best_fitness:
385
  best_fitness = fi
386
+ wandb_logger.end_epoch(best_result=best_fitness == fi)
387
 
388
  # Save model
389
  if (not opt.nosave) or (final_epoch and not opt.evolve): # if save
 
405
  wandb_logger.log_model(
406
  last.parent, opt, epoch, fi, best_model=best_fitness == fi)
407
  del ckpt
 
408
 
409
  # end epoch ----------------------------------------------------------------------------------------------------
410
  # end training
 
444
  wandb_logger.wandb.log_artifact(str(final), type='model',
445
  name='run_' + wandb_logger.wandb_run.id + '_model',
446
  aliases=['last', 'best', 'stripped'])
447
+ wandb_logger.finish_run()
448
  else:
449
  dist.destroy_process_group()
450
  torch.cuda.empty_cache()
 
451
  return results
452
 
453
 
utils/wandb_logging/wandb_utils.py CHANGED
@@ -16,9 +16,9 @@ from utils.general import colorstr, xywh2xyxy, check_dataset
16
 
17
  try:
18
  import wandb
 
19
  except ImportError:
20
  wandb = None
21
- print(f"{colorstr('wandb: ')}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
22
 
23
  WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
24
 
@@ -71,6 +71,9 @@ class WandbLogger():
71
  self.data_dict = self.setup_training(opt, data_dict)
72
  if self.job_type == 'Dataset Creation':
73
  self.data_dict = self.check_and_upload_dataset(opt)
 
 
 
74
 
75
  def check_and_upload_dataset(self, opt):
76
  assert wandb, 'Install wandb to upload dataset'
 
16
 
17
  try:
18
  import wandb
19
+ from wandb import init, finish
20
  except ImportError:
21
  wandb = None
 
22
 
23
  WANDB_ARTIFACT_PREFIX = 'wandb-artifact://'
24
 
 
71
  self.data_dict = self.setup_training(opt, data_dict)
72
  if self.job_type == 'Dataset Creation':
73
  self.data_dict = self.check_and_upload_dataset(opt)
74
+ else:
75
+ print(f"{colorstr('wandb: ')}Install Weights & Biases for YOLOv5 logging with 'pip install wandb' (recommended)")
76
+
77
 
78
  def check_and_upload_dataset(self, opt):
79
  assert wandb, 'Install wandb to upload dataset'