yellowdolphin and glenn-jocher committed
Commit 3974d72
Parent: 5e976a2

Fix warmup `accumulate` (#3722)


* gradient accumulation during warmup in train.py

Context:
`accumulate` is the number of batches/gradients accumulated before calling the next optimizer.step().
During warmup, it is ramped up from 1 to the final value nbs / batch_size.
Although I have not seen this in other libraries, I like the idea: during warmup, gradients are large, so overly large steps are more of an issue than the gradient noise introduced by small steps.
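
For illustration, a minimal sketch of the ramp; the np.interp-based interpolation is assumed to mirror train.py's warmup, and the concrete numbers (nbs=64, batch_size=4, nw=100) are only examples:

```python
import numpy as np

# Illustrative values only (nbs, batch_size, nw stand in for the training variables).
nbs = 64         # nominal batch size
batch_size = 4   # small per-iteration batch size -> accumulate changes many times during warmup
nw = 100         # number of warmup iterations

for ni in range(nw + 1):  # ni = number of integrated batches since training start
    # assumed to mirror train.py's np.interp-based warmup ramp
    accumulate = max(1, round(np.interp(ni, [0, nw], [1, nbs / batch_size])))
    # accumulate ramps from 1 up to nbs / batch_size (here 16) over the warmup iterations
```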

The bug:
The condition for performing the optimizer step is wrong:
> if ni % accumulate == 0:
This produces irregularly spaced optimizer steps whenever `accumulate` is not constant. The effect becomes relevant when `batch_size` is small and `accumulate` changes many times during warmup.

This demo also shows the proposed solution, which is to use a `>=` condition instead:
https://colab.research.google.com/drive/1MA2z2eCXYB_BC5UZqgXueqL_y1Tz_XVq?usp=sharing
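
Independent of the Colab notebook, a self-contained sketch of the comparison (reusing the illustrative ramp from above; not the exact train.py code):

```python
import numpy as np

nbs, batch_size, nw = 64, 4, 100  # illustrative values, as above

def step_iterations(condition):
    """Collect the iterations at which the optimizer step condition fires during warmup."""
    steps, last_opt_step = [], -1
    for ni in range(nw + 1):
        accumulate = max(1, round(np.interp(ni, [0, nw], [1, nbs / batch_size])))
        if condition(ni, last_opt_step, accumulate):
            steps.append(ni)
            last_opt_step = ni
    return steps

old = step_iterations(lambda ni, last, acc: ni % acc == 0)     # current condition
new = step_iterations(lambda ni, last, acc: ni - last >= acc)  # proposed ">=" condition

print('old step gaps:', np.diff(old))  # irregular: mixes very short and very long gaps
print('new step gaps:', np.diff(new))  # grows smoothly along with accumulate
```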

Further, I propose not restricting the number of warmup iterations to >= 1000. If the user changes hyp['warmup_epochs'], this floor causes unexpected behavior. It also makes evolution unstable if this parameter were to be optimized.
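
To make the concern concrete (illustrative numbers, not from a real dataset):

```python
# Illustrative numbers only: with a small dataset, the 1000-iteration floor
# silently overrides hyp['warmup_epochs'].
nb = 100               # batches per epoch
warmup_epochs = 1.0    # user lowers warmup from the default 3.0
nw = max(round(warmup_epochs * nb), 1000)  # -> 1000, i.e. 10 epochs of warmup regardless
```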

* replace last_opt_step tracking by do_step(ni)

* add docstrings

* move down nw

* Update train.py

* revert math import move

Co-authored-by: Glenn Jocher <[email protected]>

Files changed (1)
  1. train.py +3 -1
train.py CHANGED

@@ -270,6 +270,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
     t0 = time.time()
     nw = max(round(hyp['warmup_epochs'] * nb), 1000)  # number of warmup iterations, max(3 epochs, 1k iterations)
     # nw = min(nw, (epochs - start_epoch) / 2 * nb)  # limit warmup to < 1/2 of training
+    last_opt_step = -1
     maps = np.zeros(nc)  # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0)  # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
     scheduler.last_epoch = start_epoch - 1  # do not move
@@ -344,12 +345,13 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             scaler.scale(loss).backward()

             # Optimize
-            if ni % accumulate == 0:
+            if ni - last_opt_step >= accumulate:
                 scaler.step(optimizer)  # optimizer.step
                 scaler.update()
                 optimizer.zero_grad()
                 if ema:
                     ema.update(model)
+                last_opt_step = ni

             # Print
             if RANK in [-1, 0]: