Kalen Michael glenn-jocher committed on
Commit
2317f86
1 Parent(s): 5487451

Optimised Callback Class to Reduce Code and Fix Errors (#4688)

Browse files

* added callbacks

* added back callback to main

* added save_dir to callback output

* reduced code count

* updated callbacks

* added default callback class to main, added missing parameters to on_model_save

* Glenn updates

Co-authored-by: Glenn Jocher <[email protected]>

Files changed (3) hide show
  1. train.py +10 -10
  2. utils/callbacks.py +10 -113
  3. val.py +2 -2
train.py CHANGED
@@ -56,7 +56,7 @@ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
56
  def train(hyp, # path/to/hyp.yaml or hyp dictionary
57
  opt,
58
  device,
59
- callbacks=Callbacks()
60
  ):
61
  save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
62
  Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
@@ -231,7 +231,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
231
  check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
232
  model.half().float() # pre-reduce anchor precision
233
 
234
- callbacks.on_pretrain_routine_end()
235
 
236
  # DDP mode
237
  if cuda and RANK != -1:
@@ -333,7 +333,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
333
  mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
334
  pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
335
  f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
336
- callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots, opt.sync_bn)
337
  # end batch ------------------------------------------------------------------------------------------------
338
 
339
  # Scheduler
@@ -342,7 +342,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
342
 
343
  if RANK in [-1, 0]:
344
  # mAP
345
- callbacks.on_train_epoch_end(epoch=epoch)
346
  ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
347
  final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
348
  if not noval or final_epoch: # Calculate mAP
@@ -364,7 +364,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
364
  if fi > best_fitness:
365
  best_fitness = fi
366
  log_vals = list(mloss) + list(results) + lr
367
- callbacks.on_fit_epoch_end(log_vals, epoch, best_fitness, fi)
368
 
369
  # Save model
370
  if (not nosave) or (final_epoch and not evolve): # if save
@@ -381,7 +381,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
381
  if best_fitness == fi:
382
  torch.save(ckpt, best)
383
  del ckpt
384
- callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi)
385
 
386
  # Stop Single-GPU
387
  if RANK == -1 and stopper(epoch=epoch, fitness=fi):
@@ -418,7 +418,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
418
  for f in last, best:
419
  if f.exists():
420
  strip_optimizer(f) # strip optimizers
421
- callbacks.on_train_end(last, best, plots, epoch)
422
  LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
423
 
424
  torch.cuda.empty_cache()
@@ -467,7 +467,7 @@ def parse_opt(known=False):
467
  return opt
468
 
469
 
470
- def main(opt):
471
  # Checks
472
  set_logging(RANK)
473
  if RANK in [-1, 0]:
@@ -505,7 +505,7 @@ def main(opt):
505
 
506
  # Train
507
  if not opt.evolve:
508
- train(opt.hyp, opt, device)
509
  if WORLD_SIZE > 1 and RANK == 0:
510
  _ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')]
511
 
@@ -585,7 +585,7 @@ def main(opt):
585
  hyp[k] = round(hyp[k], 5) # significant digits
586
 
587
  # Train mutation
588
- results = train(hyp.copy(), opt, device)
589
 
590
  # Write mutation results
591
  print_mutation(results, hyp.copy(), save_dir, opt.bucket)
 
56
  def train(hyp, # path/to/hyp.yaml or hyp dictionary
57
  opt,
58
  device,
59
+ callbacks
60
  ):
61
  save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
62
  Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
 
231
  check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
232
  model.half().float() # pre-reduce anchor precision
233
 
234
+ callbacks.run('on_pretrain_routine_end')
235
 
236
  # DDP mode
237
  if cuda and RANK != -1:
 
333
  mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
334
  pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
335
  f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
336
+ callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots, opt.sync_bn)
337
  # end batch ------------------------------------------------------------------------------------------------
338
 
339
  # Scheduler
 
342
 
343
  if RANK in [-1, 0]:
344
  # mAP
345
+ callbacks.run('on_train_epoch_end', epoch=epoch)
346
  ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
347
  final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
348
  if not noval or final_epoch: # Calculate mAP
 
364
  if fi > best_fitness:
365
  best_fitness = fi
366
  log_vals = list(mloss) + list(results) + lr
367
+ callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)
368
 
369
  # Save model
370
  if (not nosave) or (final_epoch and not evolve): # if save
 
381
  if best_fitness == fi:
382
  torch.save(ckpt, best)
383
  del ckpt
384
+ callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
385
 
386
  # Stop Single-GPU
387
  if RANK == -1 and stopper(epoch=epoch, fitness=fi):
 
418
  for f in last, best:
419
  if f.exists():
420
  strip_optimizer(f) # strip optimizers
421
+ callbacks.run('on_train_end', last, best, plots, epoch)
422
  LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
423
 
424
  torch.cuda.empty_cache()
 
467
  return opt
468
 
469
 
470
+ def main(opt, callbacks=Callbacks()):
471
  # Checks
472
  set_logging(RANK)
473
  if RANK in [-1, 0]:
 
505
 
506
  # Train
507
  if not opt.evolve:
508
+ train(opt.hyp, opt, device, callbacks)
509
  if WORLD_SIZE > 1 and RANK == 0:
510
  _ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')]
511
 
 
585
  hyp[k] = round(hyp[k], 5) # significant digits
586
 
587
  # Train mutation
588
+ results = train(hyp.copy(), opt, device, callbacks)
589
 
590
  # Write mutation results
591
  print_mutation(results, hyp.copy(), save_dir, opt.bucket)
utils/callbacks.py CHANGED
@@ -9,6 +9,7 @@ class Callbacks:
9
  Handles all registered callbacks for YOLOv5 Hooks
10
  """
11
 
 
12
  _callbacks = {
13
  'on_pretrain_routine_start': [],
14
  'on_pretrain_routine_end': [],
@@ -34,16 +35,13 @@ class Callbacks:
34
  'teardown': [],
35
  }
36
 
37
- def __init__(self):
38
- return
39
-
40
  def register_action(self, hook, name='', callback=None):
41
  """
42
  Register a new action to a callback hook
43
 
44
  Args:
45
  hook The callback hook name to register the action to
46
- name The name of the action
47
  callback The callback to fire
48
  """
49
  assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
@@ -62,118 +60,17 @@ class Callbacks:
62
  else:
63
  return self._callbacks
64
 
65
- def run_callbacks(self, hook, *args, **kwargs):
66
  """
67
  Loop through the registered actions and fire all callbacks
68
- """
69
- for logger in self._callbacks[hook]:
70
- # print(f"Running callbacks.{logger['callback'].__name__}()")
71
- logger['callback'](*args, **kwargs)
72
-
73
- def on_pretrain_routine_start(self, *args, **kwargs):
74
- """
75
- Fires all registered callbacks at the start of each pretraining routine
76
- """
77
- self.run_callbacks('on_pretrain_routine_start', *args, **kwargs)
78
-
79
- def on_pretrain_routine_end(self, *args, **kwargs):
80
- """
81
- Fires all registered callbacks at the end of each pretraining routine
82
- """
83
- self.run_callbacks('on_pretrain_routine_end', *args, **kwargs)
84
-
85
- def on_train_start(self, *args, **kwargs):
86
- """
87
- Fires all registered callbacks at the start of each training
88
- """
89
- self.run_callbacks('on_train_start', *args, **kwargs)
90
-
91
- def on_train_epoch_start(self, *args, **kwargs):
92
- """
93
- Fires all registered callbacks at the start of each training epoch
94
- """
95
- self.run_callbacks('on_train_epoch_start', *args, **kwargs)
96
-
97
- def on_train_batch_start(self, *args, **kwargs):
98
- """
99
- Fires all registered callbacks at the start of each training batch
100
- """
101
- self.run_callbacks('on_train_batch_start', *args, **kwargs)
102
 
103
- def optimizer_step(self, *args, **kwargs):
104
- """
105
- Fires all registered callbacks on each optimizer step
106
- """
107
- self.run_callbacks('optimizer_step', *args, **kwargs)
108
-
109
- def on_before_zero_grad(self, *args, **kwargs):
110
- """
111
- Fires all registered callbacks before zero grad
112
- """
113
- self.run_callbacks('on_before_zero_grad', *args, **kwargs)
114
-
115
- def on_train_batch_end(self, *args, **kwargs):
116
- """
117
- Fires all registered callbacks at the end of each training batch
118
- """
119
- self.run_callbacks('on_train_batch_end', *args, **kwargs)
120
-
121
- def on_train_epoch_end(self, *args, **kwargs):
122
- """
123
- Fires all registered callbacks at the end of each training epoch
124
- """
125
- self.run_callbacks('on_train_epoch_end', *args, **kwargs)
126
-
127
- def on_val_start(self, *args, **kwargs):
128
- """
129
- Fires all registered callbacks at the start of the validation
130
- """
131
- self.run_callbacks('on_val_start', *args, **kwargs)
132
-
133
- def on_val_batch_start(self, *args, **kwargs):
134
- """
135
- Fires all registered callbacks at the start of each validation batch
136
- """
137
- self.run_callbacks('on_val_batch_start', *args, **kwargs)
138
-
139
- def on_val_image_end(self, *args, **kwargs):
140
- """
141
- Fires all registered callbacks at the end of each val image
142
- """
143
- self.run_callbacks('on_val_image_end', *args, **kwargs)
144
-
145
- def on_val_batch_end(self, *args, **kwargs):
146
- """
147
- Fires all registered callbacks at the end of each validation batch
148
- """
149
- self.run_callbacks('on_val_batch_end', *args, **kwargs)
150
-
151
- def on_val_end(self, *args, **kwargs):
152
- """
153
- Fires all registered callbacks at the end of the validation
154
- """
155
- self.run_callbacks('on_val_end', *args, **kwargs)
156
-
157
- def on_fit_epoch_end(self, *args, **kwargs):
158
- """
159
- Fires all registered callbacks at the end of each fit (train+val) epoch
160
- """
161
- self.run_callbacks('on_fit_epoch_end', *args, **kwargs)
162
-
163
- def on_model_save(self, *args, **kwargs):
164
- """
165
- Fires all registered callbacks after each model save
166
  """
167
- self.run_callbacks('on_model_save', *args, **kwargs)
168
 
169
- def on_train_end(self, *args, **kwargs):
170
- """
171
- Fires all registered callbacks at the end of training
172
- """
173
- self.run_callbacks('on_train_end', *args, **kwargs)
174
 
175
- def teardown(self, *args, **kwargs):
176
- """
177
- Fires all registered callbacks before teardown
178
- """
179
- self.run_callbacks('teardown', *args, **kwargs)
 
9
  Handles all registered callbacks for YOLOv5 Hooks
10
  """
11
 
12
+ # Define the available callbacks
13
  _callbacks = {
14
  'on_pretrain_routine_start': [],
15
  'on_pretrain_routine_end': [],
 
35
  'teardown': [],
36
  }
37
 
 
 
 
38
  def register_action(self, hook, name='', callback=None):
39
  """
40
  Register a new action to a callback hook
41
 
42
  Args:
43
  hook The callback hook name to register the action to
44
+ name The name of the action for later reference
45
  callback The callback to fire
46
  """
47
  assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
 
60
  else:
61
  return self._callbacks
62
 
63
+ def run(self, hook, *args, **kwargs):
64
  """
65
  Loop through the registered actions and fire all callbacks
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
+ Args:
68
+ hook The name of the hook whose registered callbacks are fired (required; must be a known hook)
69
+ args Arguments to receive from YOLOv5
70
+ kwargs Keyword Arguments to receive from YOLOv5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  """
 
72
 
73
+ assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
 
 
 
 
74
 
75
+ for logger in self._callbacks[hook]:
76
+ logger['callback'](*args, **kwargs)
 
 
 
val.py CHANGED
@@ -216,7 +216,7 @@ def run(data,
216
  save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
217
  if save_json:
218
  save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary
219
- callbacks.on_val_image_end(pred, predn, path, names, img[si])
220
 
221
  # Plot images
222
  if plots and batch_i < 3:
@@ -253,7 +253,7 @@ def run(data,
253
  # Plots
254
  if plots:
255
  confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
256
- callbacks.on_val_end()
257
 
258
  # Save JSON
259
  if save_json and len(jdict):
 
216
  save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
217
  if save_json:
218
  save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary
219
+ callbacks.run('on_val_image_end', pred, predn, path, names, img[si])
220
 
221
  # Plot images
222
  if plots and batch_i < 3:
 
253
  # Plots
254
  if plots:
255
  confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
256
+ callbacks.run('on_val_end')
257
 
258
  # Save JSON
259
  if save_json and len(jdict):