glenn-jocher
commited on
Commit
•
fa201f9
1
Parent(s):
6d6e2ca
Update `train(hyp, *args)` to accept `hyp` file or dict (#3668)
Browse files
train.py
CHANGED
@@ -39,12 +39,11 @@ from utils.wandb_logging.wandb_utils import WandbLogger, check_wandb_resume
|
|
39 |
logger = logging.getLogger(__name__)
|
40 |
|
41 |
|
42 |
-
def train(hyp,
|
43 |
opt,
|
44 |
device,
|
45 |
tb_writer=None
|
46 |
):
|
47 |
-
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
|
48 |
save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
|
49 |
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
|
50 |
opt.single_cls
|
@@ -56,6 +55,12 @@ def train(hyp,
|
|
56 |
best = wdir / 'best.pt'
|
57 |
results_file = save_dir / 'results.txt'
|
58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
# Save run settings
|
60 |
with open(save_dir / 'hyp.yaml', 'w') as f:
|
61 |
yaml.safe_dump(hyp, f, sort_keys=False)
|
@@ -529,10 +534,6 @@ if __name__ == '__main__':
|
|
529 |
assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
|
530 |
opt.batch_size = opt.total_batch_size // opt.world_size
|
531 |
|
532 |
-
# Hyperparameters
|
533 |
-
with open(opt.hyp) as f:
|
534 |
-
hyp = yaml.safe_load(f) # load hyps
|
535 |
-
|
536 |
# Train
|
537 |
logger.info(opt)
|
538 |
if not opt.evolve:
|
@@ -541,7 +542,7 @@ if __name__ == '__main__':
|
|
541 |
prefix = colorstr('tensorboard: ')
|
542 |
logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
|
543 |
tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
|
544 |
-
train(hyp, opt, device, tb_writer)
|
545 |
|
546 |
# Evolve hyperparameters (optional)
|
547 |
else:
|
@@ -575,6 +576,8 @@ if __name__ == '__main__':
|
|
575 |
'mosaic': (1, 0.0, 1.0), # image mixup (probability)
|
576 |
'mixup': (1, 0.0, 1.0)} # image mixup (probability)
|
577 |
|
|
|
|
|
578 |
assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
|
579 |
opt.notest, opt.nosave = True, True # only test/save final epoch
|
580 |
# ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
|
|
|
39 |
logger = logging.getLogger(__name__)
|
40 |
|
41 |
|
42 |
+
def train(hyp, # path/to/hyp.yaml or hyp dictionary
|
43 |
opt,
|
44 |
device,
|
45 |
tb_writer=None
|
46 |
):
|
|
|
47 |
save_dir, epochs, batch_size, total_batch_size, weights, rank, single_cls = \
|
48 |
Path(opt.save_dir), opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank, \
|
49 |
opt.single_cls
|
|
|
55 |
best = wdir / 'best.pt'
|
56 |
results_file = save_dir / 'results.txt'
|
57 |
|
58 |
+
# Hyperparameters
|
59 |
+
if isinstance(hyp, str):
|
60 |
+
with open(hyp) as f:
|
61 |
+
hyp = yaml.safe_load(f) # load hyps dict
|
62 |
+
logger.info(colorstr('hyperparameters: ') + ', '.join(f'{k}={v}' for k, v in hyp.items()))
|
63 |
+
|
64 |
# Save run settings
|
65 |
with open(save_dir / 'hyp.yaml', 'w') as f:
|
66 |
yaml.safe_dump(hyp, f, sort_keys=False)
|
|
|
534 |
assert not opt.image_weights, '--image-weights argument is not compatible with DDP training'
|
535 |
opt.batch_size = opt.total_batch_size // opt.world_size
|
536 |
|
|
|
|
|
|
|
|
|
537 |
# Train
|
538 |
logger.info(opt)
|
539 |
if not opt.evolve:
|
|
|
542 |
prefix = colorstr('tensorboard: ')
|
543 |
logger.info(f"{prefix}Start with 'tensorboard --logdir {opt.project}', view at http://localhost:6006/")
|
544 |
tb_writer = SummaryWriter(opt.save_dir) # Tensorboard
|
545 |
+
train(opt.hyp, opt, device, tb_writer)
|
546 |
|
547 |
# Evolve hyperparameters (optional)
|
548 |
else:
|
|
|
576 |
'mosaic': (1, 0.0, 1.0), # image mixup (probability)
|
577 |
'mixup': (1, 0.0, 1.0)} # image mixup (probability)
|
578 |
|
579 |
+
with open(opt.hyp) as f:
|
580 |
+
hyp = yaml.safe_load(f) # load hyps dict
|
581 |
assert opt.local_rank == -1, 'DDP mode not implemented for --evolve'
|
582 |
opt.notest, opt.nosave = True, True # only test/save final epoch
|
583 |
# ei = [isinstance(x, (int, float)) for x in hyp.values()] # evolvable indices
|