File size: 5,775 Bytes
24bea5e b74929c 4103ce9 b74929c 4103ce9 b74929c 4103ce9 b74929c 4103ce9 b74929c 4103ce9 b74929c 4103ce9 b74929c 4103ce9 b74929c 4103ce9 b74929c 4103ce9 b74929c 4103ce9 b74929c 4103ce9 b74929c 4103ce9 b74929c 4103ce9 b74929c 4103ce9 b74929c 4103ce9 b74929c 4103ce9 b74929c 4103ce9 b74929c 4103ce9 b74929c 4103ce9 b74929c 4103ce9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
"""
Callback utils
"""
class Callbacks:
    """
    Handles all registered callbacks for YOLOv5 Hooks

    Each instance keeps its own registry mapping a hook name to the list of
    actions registered on that hook. An action is stored as a dict with the
    keys 'name' and 'callback'; actions fire in the order they were registered.
    """

    # Names of every supported hook. Kept immutable at class level; the mutable
    # per-hook action lists are created per instance in __init__ so that
    # separate Callbacks objects never share registered actions.
    _HOOKS = (
        'on_pretrain_routine_start',
        'on_pretrain_routine_end',
        'on_train_start',
        'on_train_epoch_start',
        'on_train_batch_start',
        'optimizer_step',
        'on_before_zero_grad',
        'on_train_batch_end',
        'on_train_epoch_end',
        'on_val_start',
        'on_val_batch_start',
        'on_val_image_end',
        'on_val_batch_end',
        'on_val_end',
        'on_fit_epoch_end',  # fit = train + val
        'on_model_save',
        'on_train_end',
        'teardown',
    )

    def __init__(self):
        """
        Create an empty action list for every supported hook
        """
        # A fresh dict of fresh lists per instance: a class-level dict would be
        # shared, so actions registered on one instance would fire on all.
        self._callbacks = {hook: [] for hook in self._HOOKS}

    def register_action(self, hook, name='', callback=None):
        """
        Register a new action to a callback hook

        Args:
            hook        The callback hook name to register the action to
            name        The name of the action
            callback    The callback to fire

        Raises:
            AssertionError  If hook is not a known hook name or callback is not callable
        """
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        assert callable(callback), f"callback '{callback}' is not callable"
        self._callbacks[hook].append({'name': name, 'callback': callback})

    def get_registered_actions(self, hook=None):
        """
        Returns all the registered actions by callback hook

        Args:
            hook    The name of the hook to check, defaults to all

        Returns:
            The list of action dicts for hook when hook is given (truthy),
            otherwise the full dict mapping every hook name to its action list.
            The internal containers are returned directly, not copies.

        Raises:
            KeyError    If hook is given but is not a known hook name
        """
        return self._callbacks[hook] if hook else self._callbacks

    def run_callbacks(self, hook, *args, **kwargs):
        """
        Loop through the registered actions and fire all callbacks

        Args:
            hook    The name of the hook whose actions should be fired
            args    Positional arguments forwarded to every callback
            kwargs  Keyword arguments forwarded to every callback

        Raises:
            KeyError    If hook is not a known hook name
        """
        for action in self._callbacks[hook]:
            action['callback'](*args, **kwargs)

    def on_pretrain_routine_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of each pretraining routine
        """
        self.run_callbacks('on_pretrain_routine_start', *args, **kwargs)

    def on_pretrain_routine_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each pretraining routine
        """
        self.run_callbacks('on_pretrain_routine_end', *args, **kwargs)

    def on_train_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of each training
        """
        self.run_callbacks('on_train_start', *args, **kwargs)

    def on_train_epoch_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of each training epoch
        """
        self.run_callbacks('on_train_epoch_start', *args, **kwargs)

    def on_train_batch_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of each training batch
        """
        self.run_callbacks('on_train_batch_start', *args, **kwargs)

    def optimizer_step(self, *args, **kwargs):
        """
        Fires all registered callbacks on each optimizer step
        """
        self.run_callbacks('optimizer_step', *args, **kwargs)

    def on_before_zero_grad(self, *args, **kwargs):
        """
        Fires all registered callbacks before zero grad
        """
        self.run_callbacks('on_before_zero_grad', *args, **kwargs)

    def on_train_batch_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each training batch
        """
        self.run_callbacks('on_train_batch_end', *args, **kwargs)

    def on_train_epoch_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each training epoch
        """
        self.run_callbacks('on_train_epoch_end', *args, **kwargs)

    def on_val_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of the validation
        """
        self.run_callbacks('on_val_start', *args, **kwargs)

    def on_val_batch_start(self, *args, **kwargs):
        """
        Fires all registered callbacks at the start of each validation batch
        """
        self.run_callbacks('on_val_batch_start', *args, **kwargs)

    def on_val_image_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each val image
        """
        self.run_callbacks('on_val_image_end', *args, **kwargs)

    def on_val_batch_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each validation batch
        """
        self.run_callbacks('on_val_batch_end', *args, **kwargs)

    def on_val_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of the validation
        """
        self.run_callbacks('on_val_end', *args, **kwargs)

    def on_fit_epoch_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of each fit (train+val) epoch
        """
        self.run_callbacks('on_fit_epoch_end', *args, **kwargs)

    def on_model_save(self, *args, **kwargs):
        """
        Fires all registered callbacks after each model save
        """
        self.run_callbacks('on_model_save', *args, **kwargs)

    def on_train_end(self, *args, **kwargs):
        """
        Fires all registered callbacks at the end of training
        """
        self.run_callbacks('on_train_end', *args, **kwargs)

    def teardown(self, *args, **kwargs):
        """
        Fires all registered callbacks before teardown
        """
        self.run_callbacks('teardown', *args, **kwargs)
|