glenn-jocher commited on
Commit
5afc9c2
1 Parent(s): d133968

Implement `--save-period` locally (#5047)

Browse files

This PR adds a new training argument `--save-period` to save training checkpoints every `x` epochs. To save training every 50 epochs for example:
```
python train.py --save-period 50 # saves epoch50.pt, epoch100.pt, epoch150.pt, ... etc.
```

This saved checkpoints in addition to existing last.pt and best.pt checkpoints and does not affect their behavior. Default value is -1, i.e. disabled.

Files changed (1) hide show
  1. train.py +12 -7
train.py CHANGED
@@ -382,6 +382,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
382
  torch.save(ckpt, last)
383
  if best_fitness == fi:
384
  torch.save(ckpt, best)
 
 
385
  del ckpt
386
  callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
387
 
@@ -453,20 +455,23 @@ def parse_opt(known=False):
453
  parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
454
  parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
455
  parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
456
- parser.add_argument('--entity', default=None, help='W&B entity')
457
  parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
458
  parser.add_argument('--name', default='exp', help='save to project/name')
459
  parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
460
  parser.add_argument('--quad', action='store_true', help='quad dataloader')
461
  parser.add_argument('--linear-lr', action='store_true', help='linear LR')
462
  parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
463
- parser.add_argument('--upload_dataset', action='store_true', help='Upload dataset as W&B artifact table')
464
- parser.add_argument('--bbox_interval', type=int, default=-1, help='Set bounding-box image logging interval for W&B')
465
- parser.add_argument('--save_period', type=int, default=-1, help='Log model after every "save_period" epoch')
466
- parser.add_argument('--artifact_alias', type=str, default="latest", help='version of dataset artifact to be used')
467
- parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
468
- parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24')
469
  parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
 
 
 
 
 
 
 
 
 
 
470
  opt = parser.parse_known_args()[0] if known else parser.parse_args()
471
  return opt
472
 
 
382
  torch.save(ckpt, last)
383
  if best_fitness == fi:
384
  torch.save(ckpt, best)
385
+ if (epoch > 0) and (opt.save_period > 0) and (epoch % opt.save_period == 0):
386
+ torch.save(ckpt, w / f'epoch{epoch}.pt')
387
  del ckpt
388
  callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
389
 
 
455
  parser.add_argument('--adam', action='store_true', help='use torch.optim.Adam() optimizer')
456
  parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
457
  parser.add_argument('--workers', type=int, default=8, help='maximum number of dataloader workers')
 
458
  parser.add_argument('--project', default=ROOT / 'runs/train', help='save to project/name')
459
  parser.add_argument('--name', default='exp', help='save to project/name')
460
  parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
461
  parser.add_argument('--quad', action='store_true', help='quad dataloader')
462
  parser.add_argument('--linear-lr', action='store_true', help='linear LR')
463
  parser.add_argument('--label-smoothing', type=float, default=0.0, help='Label smoothing epsilon')
 
 
 
 
 
 
464
  parser.add_argument('--patience', type=int, default=100, help='EarlyStopping patience (epochs without improvement)')
465
+ parser.add_argument('--freeze', type=int, default=0, help='Number of layers to freeze. backbone=10, all=24')
466
+ parser.add_argument('--save-period', type=int, default=-1, help='Save checkpoint every x epochs (disabled if < 1)')
467
+ parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
468
+
469
+ # Weights & Biases arguments
470
+ parser.add_argument('--entity', default=None, help='W&B: Entity')
471
+ parser.add_argument('--upload_dataset', action='store_true', help='W&B: Upload dataset as artifact table')
472
+ parser.add_argument('--bbox_interval', type=int, default=-1, help='W&B: Set bounding-box image logging interval')
473
+ parser.add_argument('--artifact_alias', type=str, default='latest', help='W&B: Version of dataset artifact to use')
474
+
475
  opt = parser.parse_known_args()[0] if known else parser.parse_args()
476
  return opt
477