glenn-jocher committed on
Commit
fa201f9
1 Parent(s): 6d6e2ca

Update `train(hyp, *args)` to accept `hyp` file or dict (#3668)

Browse files
Files changed (1) hide show
  1. train.py +10 -7
train.py CHANGED
@@ -39,12 +39,11 @@ from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
39
  logger = logging.getLogger(__name__)
40
 
41
 
42
- def train(hyp,
43
  opt,
44
  device,
45
  tb_writer=None
46
  ):
47
- logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
48
  save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
49
  Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
50
  opt.single_cls
@@ -56,6 +55,12 @@ def train(hyp,
56
  best = wdir / 'best.pt'
57
  results_file = save_dir / 'results.txt'
58
 
 
 
 
 
 
 
59
  # Save run settings
60
  with open(save_dir / 'hyp.yaml', 'w') as f:
61
  yaml.safe_dump(hyp, f, sort_keys=False)
@@ -529,10 +534,6 @@ if __name__ == '__main__':
529
  assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
530
  opt.batch_size = opt.total_batch_size // opt.world_size
531
 
532
- # Hyperparameters
533
- with open(opt.hyp) as f:
534
- hyp = yaml.safe_load(f) # load hyps
535
-
536
  # Train
537
  logger.info(opt)
538
  if not opt.evolve:
@@ -541,7 +542,7 @@ if __name__ == '__main__':
541
  prefix = colorstr('tensorboard: ')
542
  logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
543
  tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
544
- train(hyp, opt, device, tb_writer)
545
 
546
  # Evolve hyperparameters (optional)
547
  else:
@@ -575,6 +576,8 @@ if __name__ == '__main__':
575
  'mosaic': (1, 0.0, 1.0), # image mixup (probability)
576
  'mixup': (1, 0.0, 1.0)} # image mixup (probability)
577
 
 
 
578
  assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
579
  opt.notest, opt.nosave = True, True # only test/save final epoch
580
  # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
 
39
  logger = logging.getLogger(__name__)
40
 
41
 
42
+ def train(hyp, # path/to/hyp.yaml or hyp dictionary
43
  opt,
44
  device,
45
  tb_writer=None
46
  ):
 
47
  save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
48
  Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
49
  opt.single_cls
 
55
  best = wdir / 'best.pt'
56
  results_file = save_dir / 'results.txt'
57
 
58
+ # Hyperparameters
59
+ if isinstance(hyp, str):
60
+ with open(hyp) as f:
61
+ hyp = yaml.safe_load(f) # load hyps dict
62
+ logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
63
+
64
  # Save run settings
65
  with open(save_dir / 'hyp.yaml', 'w') as f:
66
  yaml.safe_dump(hyp, f, sort_keys=False)
 
534
  assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
535
  opt.batch_size = opt.total_batch_size // opt.world_size
536
 
 
 
 
 
537
  # Train
538
  logger.info(opt)
539
  if not opt.evolve:
 
542
  prefix = colorstr('tensorboard: ')
543
  logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
544
  tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
545
+ train(opt.hyp, opt, device, tb_writer)
546
 
547
  # Evolve hyperparameters (optional)
548
  else:
 
576
  'mosaic': (1, 0.0, 1.0), # image mixup (probability)
577
  'mixup': (1, 0.0, 1.0)} # image mixup (probability)
578
 
579
+ with open(opt.hyp) as f:
580
+ hyp = yaml.safe_load(f) # load hyps dict
581
  assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
582
  opt.notest, opt.nosave = True, True # only test/save final epoch
583
  # ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices