Commit 2317f86
Parent: 5487451

Optimised Callback Class to Reduce Code and Fix Errors (#4688)

* added callbacks
* added back callback to main
* added save_dir to callback output
* reduced code count
* updated callbacks
* added default callback class to main, added missing parameters to on_model_save
* Glenn updates

Co-authored-by: Glenn Jocher <[email protected]>
Files changed:

- train.py +10 -10
- utils/callbacks.py +10 -113
- val.py +2 -2
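The core of the change, visible in the diffs below, is collapsing the many per-hook wrapper methods (`on_train_end()`, `on_model_save()`, ...) into a single `run(hook, *args, **kwargs)` dispatcher, and moving the default `Callbacks()` instance from `train()` into `main()`. A minimal, self-contained sketch of the resulting pattern follows: the `Callbacks` skeleton mirrors the diff (trimmed to two hooks), the `register_action` body beyond its assert is inferred from how `run()` reads `logger['callback']`, and the `print_results` handler is purely illustrative.

```python
# Sketch of the optimised callback pattern from this commit (not a verbatim copy).
class Callbacks:
    """Handles all registered callbacks for YOLOv5 Hooks (trimmed to two hooks)."""

    # Define the available callbacks
    _callbacks = {
        'on_pretrain_routine_end': [],
        'on_train_end': [],
    }

    def register_action(self, hook, name='', callback=None):
        """Register a new action to a callback hook."""
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        assert callable(callback), f"callback '{callback}' is not callable"
        # Stored layout inferred from run(), which reads logger['callback']
        self._callbacks[hook].append({'name': name, 'callback': callback})

    def run(self, hook, *args, **kwargs):
        """Loop through the registered actions and fire all callbacks for a hook."""
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        for logger in self._callbacks[hook]:
            logger['callback'](*args, **kwargs)


def print_results(last, best, plots, epoch):  # hypothetical handler, not part of YOLOv5
    print(f'Training finished at epoch {epoch}: last={last}, best={best}')


callbacks = Callbacks()
callbacks.register_action('on_train_end', name='print_results', callback=print_results)
callbacks.run('on_train_end', 'last.pt', 'best.pt', True, 299)  # fires every action on this hook
```

Because `run()` asserts that the hook exists in `_callbacks`, a mistyped hook name now fails immediately instead of silently doing nothing.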
train.py

@@ -56,7 +56,7 @@ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
 def train(hyp,  # path/to/hyp.yaml or hyp dictionary
           opt,
           device,
-          callbacks
+          callbacks
           ):
     save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
         Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
@@ -231,7 +231,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
         model.half().float()  # pre-reduce anchor precision

-        callbacks.on_pretrain_routine_end
+        callbacks.run('on_pretrain_routine_end')

    # DDP mode
    if cuda and RANK != -1:
@@ -333,7 +333,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                 mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G'  # (GB)
                 pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
                     f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
-                callbacks.on_train_batch_end
+                callbacks.run('on_train_batch_end', ni, model, imgs, targets, paths, plots, opt.sync_bn)
             # end batch ------------------------------------------------------------------------------------------------

         # Scheduler
@@ -342,7 +342,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary

         if RANK in [-1, 0]:
             # mAP
-            callbacks.on_train_epoch_end
+            callbacks.run('on_train_epoch_end', epoch=epoch)
             ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
             final_epoch = (epoch + 1 == epochs) or stopper.possible_stop
             if not noval or final_epoch:  # Calculate mAP
@@ -364,7 +364,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
             if fi > best_fitness:
                 best_fitness = fi
             log_vals = list(mloss) + list(results) + lr
-            callbacks.on_fit_epoch_end
+            callbacks.run('on_fit_epoch_end', log_vals, epoch, best_fitness, fi)

             # Save model
             if (not nosave) or (final_epoch and not evolve):  # if save
@@ -381,7 +381,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
                 if best_fitness == fi:
                     torch.save(ckpt, best)
                 del ckpt
-                callbacks.on_model_save
+                callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)

             # Stop Single-GPU
             if RANK == -1 and stopper(epoch=epoch, fitness=fi):
@@ -418,7 +418,7 @@ def train(hyp,  # path/to/hyp.yaml or hyp dictionary
         for f in last, best:
             if f.exists():
                 strip_optimizer(f)  # strip optimizers
-        callbacks.on_train_end
+        callbacks.run('on_train_end', last, best, plots, epoch)
         LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")

     torch.cuda.empty_cache()
@@ -467,7 +467,7 @@ def parse_opt(known=False):
     return opt


-def main(opt):
+def main(opt, callbacks=Callbacks()):
     # Checks
     set_logging(RANK)
     if RANK in [-1, 0]:
@@ -505,7 +505,7 @@ def main(opt):

     # Train
     if not opt.evolve:
-        train(opt.hyp, opt, device)
+        train(opt.hyp, opt, device, callbacks)
         if WORLD_SIZE > 1 and RANK == 0:
             _ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')]

@@ -585,7 +585,7 @@ def main(opt):
                 hyp[k] = round(hyp[k], 5)  # significant digits

             # Train mutation
-            results = train(hyp.copy(), opt, device)
+            results = train(hyp.copy(), opt, device, callbacks)

             # Write mutation results
             print_mutation(results, hyp.copy(), save_dir, opt.bucket)
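With `main(opt, callbacks=Callbacks())`, a wrapper script can now build its own `Callbacks` object, register actions on it, and pass it through to `train()`. A hedged usage sketch, assuming it runs from the repository root so that `train.py` and `utils/callbacks.py` (as shown in these diffs) are importable; the `notify_on_save` handler and its registration name are hypothetical.

```python
# Hypothetical wrapper script: inject a custom action into YOLOv5 training via the
# new main(opt, callbacks=...) signature. Only main(), parse_opt() and
# Callbacks.register_action() are taken from the diff; the handler is illustrative.
import train
from utils.callbacks import Callbacks


def notify_on_save(last, epoch, final_epoch, best_fitness, fi):
    # Signature matches the arguments train.py passes via
    # callbacks.run('on_model_save', last, epoch, final_epoch, best_fitness, fi)
    print(f'Checkpoint saved: {last} (epoch {epoch}, final={final_epoch}, fitness={fi})')


if __name__ == '__main__':
    opt = train.parse_opt()
    callbacks = Callbacks()
    callbacks.register_action('on_model_save', name='notify_on_save', callback=notify_on_save)
    train.main(opt, callbacks=callbacks)
```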
utils/callbacks.py

@@ -9,6 +9,7 @@ class Callbacks:
     Handles all registered callbacks for YOLOv5 Hooks
     """

+    # Define the available callbacks
     _callbacks = {
         'on_pretrain_routine_start': [],
         'on_pretrain_routine_end': [],
@@ -34,16 +35,13 @@ class Callbacks:
         'teardown': [],
     }

-    def __init__(self):
-        return
-
     def register_action(self, hook, name='', callback=None):
         """
         Register a new action to a callback hook

         Args:
             hook The callback hook name to register the action to
-            name The name of the action
+            name The name of the action for later reference
             callback The callback to fire
         """
         assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
@@ -62,118 +60,17 @@ class Callbacks:
         else:
             return self._callbacks

-    def run_callbacks(self, hook, *args, **kwargs):
+    def run(self, hook, *args, **kwargs):
         """
         Loop through the registered actions and fire all callbacks
-        """
-        for logger in self._callbacks[hook]:
-            # print(f"Running callbacks.{logger['callback'].__name__}()")
-            logger['callback'](*args, **kwargs)
-
-    def on_pretrain_routine_start(self, *args, **kwargs):
-        """
-        Fires all registered callbacks at the start of each pretraining routine
-        """
-        self.run_callbacks('on_pretrain_routine_start', *args, **kwargs)
-
-    def on_pretrain_routine_end(self, *args, **kwargs):
-        """
-        Fires all registered callbacks at the end of each pretraining routine
-        """
-        self.run_callbacks('on_pretrain_routine_end', *args, **kwargs)
-
-    def on_train_start(self, *args, **kwargs):
-        """
-        Fires all registered callbacks at the start of each training
-        """
-        self.run_callbacks('on_train_start', *args, **kwargs)
-
-    def on_train_epoch_start(self, *args, **kwargs):
-        """
-        Fires all registered callbacks at the start of each training epoch
-        """
-        self.run_callbacks('on_train_epoch_start', *args, **kwargs)
-
-    def on_train_batch_start(self, *args, **kwargs):
-        """
-        Fires all registered callbacks at the start of each training batch
-        """
-        self.run_callbacks('on_train_batch_start', *args, **kwargs)
-
-    def optimizer_step(self, *args, **kwargs):
-        """
-        Fires all registered callbacks on each optimizer step
-        """
-        self.run_callbacks('optimizer_step', *args, **kwargs)
-
-    def on_before_zero_grad(self, *args, **kwargs):
-        """
-        Fires all registered callbacks before zero grad
-        """
-        self.run_callbacks('on_before_zero_grad', *args, **kwargs)
-
-    def on_train_batch_end(self, *args, **kwargs):
-        """
-        Fires all registered callbacks at the end of each training batch
-        """
-        self.run_callbacks('on_train_batch_end', *args, **kwargs)
-
-    def on_train_epoch_end(self, *args, **kwargs):
-        """
-        Fires all registered callbacks at the end of each training epoch
-        """
-        self.run_callbacks('on_train_epoch_end', *args, **kwargs)
-
-    def on_val_start(self, *args, **kwargs):
-        """
-        Fires all registered callbacks at the start of the validation
-        """
-        self.run_callbacks('on_val_start', *args, **kwargs)
-
-    def on_val_batch_start(self, *args, **kwargs):
-        """
-        Fires all registered callbacks at the start of each validation batch
-        """
-        self.run_callbacks('on_val_batch_start', *args, **kwargs)
-
-    def on_val_image_end(self, *args, **kwargs):
-        """
-        Fires all registered callbacks at the end of each val image
-        """
-        self.run_callbacks('on_val_image_end', *args, **kwargs)
-
-    def on_val_batch_end(self, *args, **kwargs):
-        """
-        Fires all registered callbacks at the end of each validation batch
-        """
-        self.run_callbacks('on_val_batch_end', *args, **kwargs)
-
-    def on_val_end(self, *args, **kwargs):
-        """
-        Fires all registered callbacks at the end of the validation
-        """
-        self.run_callbacks('on_val_end', *args, **kwargs)
-
-    def on_fit_epoch_end(self, *args, **kwargs):
-        """
-        Fires all registered callbacks at the end of each fit (train+val) epoch
-        """
-        self.run_callbacks('on_fit_epoch_end', *args, **kwargs)
-
-    def on_model_save(self, *args, **kwargs):
-        """
-        Fires all registered callbacks after each model save
-        """
-        self.run_callbacks('on_model_save', *args, **kwargs)
-
-    def on_train_end(self, *args, **kwargs):
-        """
-        Fires all registered callbacks at the end of training
-        """
-        self.run_callbacks('on_train_end', *args, **kwargs)
-
-    def teardown(self, *args, **kwargs):
-        """
-        Fires all registered callbacks before teardown
-        """
-        self.run_callbacks('teardown', *args, **kwargs)
+
+        Args:
+            hook The name of the hook to check, defaults to all
+            args Arguments to receive from YOLOv5
+            kwargs Keyword Arguments to receive from YOLOv5
+        """
+
+        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
+
+        for logger in self._callbacks[hook]:
+            logger['callback'](*args, **kwargs)
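The net effect in `utils/callbacks.py` is that roughly a hundred lines of near-identical wrapper methods collapse into the single `run()` dispatcher, so adding a new hook only requires a new key in `_callbacks`. A quick behavioural check, assuming the module as shown above (the lambda action is illustrative):

```python
# Sanity check of the reduced interface; assumes utils/callbacks.py as in the diff.
from utils.callbacks import Callbacks

calls = []
cb = Callbacks()
cb.register_action('on_val_end', name='record', callback=lambda: calls.append('on_val_end'))

cb.run('on_val_end')      # fires every action registered for the hook
assert calls == ['on_val_end']

try:
    cb.run('not_a_hook')  # unknown hooks fail fast via the assert in run()
except AssertionError as e:
    print(e)
```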
val.py

@@ -216,7 +216,7 @@ def run(data,
                 save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
             if save_json:
                 save_one_json(predn, jdict, path, class_map)  # append to COCO-JSON dictionary
-            callbacks.on_val_image_end
+            callbacks.run('on_val_image_end', pred, predn, path, names, img[si])

         # Plot images
         if plots and batch_i < 3:
@@ -253,7 +253,7 @@ def run(data,
     # Plots
     if plots:
         confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
-    callbacks.on_val_end
+    callbacks.run('on_val_end')

     # Save JSON
     if save_json and len(jdict):