Kalen Michael glenn-jocher commited on
Commit
b74929c
1 Parent(s): d8f1883

Add `train.py` and `val.py` callbacks (#4220)

Browse files

* added callbacks

* Update callbacks.py

* Update train.py

* Update val.py

* Fix CamlCase add staticmethod

* Refactor logger into callbacks

* Cleanup

* New callback on_val_image_end()

* Add curves and results images to TensorBoard

Co-authored-by: Glenn Jocher <[email protected]>

Files changed (6) hide show
  1. train.py +19 -10
  2. utils/callbacks.py +176 -0
  3. utils/general.py +5 -0
  4. utils/loggers/__init__.py +24 -21
  5. utils/plots.py +1 -5
  6. val.py +5 -5
train.py CHANGED
@@ -34,7 +34,7 @@ from utils.autoanchor import check_anchors
34
  from utils.datasets import create_dataloader
35
  from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
36
  strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
37
- check_requirements, print_mutation, set_logging, one_cycle, colorstr
38
  from utils.downloads import attempt_download
39
  from utils.loss import ComputeLoss
40
  from utils.plots import plot_labels, plot_evolution
@@ -42,6 +42,7 @@ from utils.torch_utils import ModelEMA, select_device, intersect_dicts, torch_di
42
  from utils.loggers.wandb.wandb_utils import check_wandb_resume
43
  from utils.metrics import fitness
44
  from utils.loggers import Loggers
 
45
 
46
  LOGGER = logging.getLogger(__name__)
47
  LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
@@ -52,6 +53,7 @@ WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
52
  def train(hyp, # path/to/hyp.yaml or hyp dictionary
53
  opt,
54
  device,
 
55
  ):
56
  save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
57
  Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
@@ -77,12 +79,16 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
77
 
78
  # Loggers
79
  if RANK in [-1, 0]:
80
- loggers = Loggers(save_dir, weights, opt, hyp, LOGGER).start() # loggers dict
81
  if loggers.wandb:
82
  data_dict = loggers.wandb.data_dict
83
  if resume:
84
  weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp
85
 
 
 
 
 
86
  # Config
87
  plots = not evolve # create plots
88
  cuda = device.type != 'cpu'
@@ -215,13 +221,15 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
215
  # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
216
  # model._initialize_biases(cf.to(device))
217
  if plots:
218
- plot_labels(labels, names, save_dir, loggers)
219
 
220
  # Anchors
221
  if not opt.noautoanchor:
222
  check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
223
  model.half().float() # pre-reduce anchor precision
224
 
 
 
225
  # DDP mode
226
  if cuda and RANK != -1:
227
  model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
@@ -329,8 +337,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
329
  mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
330
  pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
331
  f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
332
- loggers.on_train_batch_end(ni, model, imgs, targets, paths, plots)
333
-
334
  # end batch ------------------------------------------------------------------------------------------------
335
 
336
  # Scheduler
@@ -339,7 +346,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
339
 
340
  if RANK in [-1, 0]:
341
  # mAP
342
- loggers.on_train_epoch_end(epoch)
343
  ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
344
  final_epoch = epoch + 1 == epochs
345
  if not noval or final_epoch: # Calculate mAP
@@ -353,14 +360,14 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
353
  save_json=is_coco and final_epoch,
354
  verbose=nc < 50 and final_epoch,
355
  plots=plots and final_epoch,
356
- loggers=loggers,
357
  compute_loss=compute_loss)
358
 
359
  # Update best mAP
360
  fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, [email protected], [email protected]]
361
  if fi > best_fitness:
362
  best_fitness = fi
363
- loggers.on_train_val_end(mloss, results, lr, epoch, best_fitness, fi)
364
 
365
  # Save model
366
  if (not nosave) or (final_epoch and not evolve): # if save
@@ -377,7 +384,7 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
377
  if best_fitness == fi:
378
  torch.save(ckpt, best)
379
  del ckpt
380
- loggers.on_model_save(last, epoch, final_epoch, best_fitness, fi)
381
 
382
  # end epoch ----------------------------------------------------------------------------------------------------
383
  # end training -----------------------------------------------------------------------------------------------------
@@ -400,7 +407,8 @@ def train(hyp, # path/to/hyp.yaml or hyp dictionary
400
  for f in last, best:
401
  if f.exists():
402
  strip_optimizer(f) # strip optimizers
403
- loggers.on_train_end(last, best, plots)
 
404
 
405
  torch.cuda.empty_cache()
406
  return results
@@ -448,6 +456,7 @@ def parse_opt(known=False):
448
 
449
 
450
  def main(opt):
 
451
  set_logging(RANK)
452
  if RANK in [-1, 0]:
453
  print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
 
34
  from utils.datasets import create_dataloader
35
  from utils.general import labels_to_class_weights, increment_path, labels_to_image_weights, init_seeds, \
36
  strip_optimizer, get_latest_run, check_dataset, check_file, check_git_status, check_img_size, \
37
+ check_requirements, print_mutation, set_logging, one_cycle, colorstr, methods
38
  from utils.downloads import attempt_download
39
  from utils.loss import ComputeLoss
40
  from utils.plots import plot_labels, plot_evolution
 
42
  from utils.loggers.wandb.wandb_utils import check_wandb_resume
43
  from utils.metrics import fitness
44
  from utils.loggers import Loggers
45
+ from utils.callbacks import Callbacks
46
 
47
  LOGGER = logging.getLogger(__name__)
48
  LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
 
53
  def train(hyp, # path/to/hyp.yaml or hyp dictionary
54
  opt,
55
  device,
56
+ callbacks=Callbacks()
57
  ):
58
  save_dir, epochs, batch_size, weights, single_cls, evolve, data, cfg, resume, noval, nosave, workers, freeze, = \
59
  Path(opt.save_dir), opt.epochs, opt.batch_size, opt.weights, opt.single_cls, opt.evolve, opt.data, opt.cfg, \
 
79
 
80
  # Loggers
81
  if RANK in [-1, 0]:
82
+ loggers = Loggers(save_dir, weights, opt, hyp, LOGGER) # loggers instance
83
  if loggers.wandb:
84
  data_dict = loggers.wandb.data_dict
85
  if resume:
86
  weights, epochs, hyp = opt.weights, opt.epochs, opt.hyp
87
 
88
+ # Register actions
89
+ for k in methods(loggers):
90
+ callbacks.register_action(k, callback=getattr(loggers, k))
91
+
92
  # Config
93
  plots = not evolve # create plots
94
  cuda = device.type != 'cpu'
 
221
  # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
222
  # model._initialize_biases(cf.to(device))
223
  if plots:
224
+ plot_labels(labels, names, save_dir)
225
 
226
  # Anchors
227
  if not opt.noautoanchor:
228
  check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
229
  model.half().float() # pre-reduce anchor precision
230
 
231
+ callbacks.on_pretrain_routine_end()
232
+
233
  # DDP mode
234
  if cuda and RANK != -1:
235
  model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
 
337
  mem = f'{torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0:.3g}G' # (GB)
338
  pbar.set_description(('%10s' * 2 + '%10.4g' * 5) % (
339
  f'{epoch}/{epochs - 1}', mem, *mloss, targets.shape[0], imgs.shape[-1]))
340
+ callbacks.on_train_batch_end(ni, model, imgs, targets, paths, plots)
 
341
  # end batch ------------------------------------------------------------------------------------------------
342
 
343
  # Scheduler
 
346
 
347
  if RANK in [-1, 0]:
348
  # mAP
349
+ callbacks.on_train_epoch_end(epoch=epoch)
350
  ema.update_attr(model, include=['yaml', 'nc', 'hyp', 'names', 'stride', 'class_weights'])
351
  final_epoch = epoch + 1 == epochs
352
  if not noval or final_epoch: # Calculate mAP
 
360
  save_json=is_coco and final_epoch,
361
  verbose=nc < 50 and final_epoch,
362
  plots=plots and final_epoch,
363
+ callbacks=callbacks,
364
  compute_loss=compute_loss)
365
 
366
  # Update best mAP
367
  fi = fitness(np.array(results).reshape(1, -1)) # weighted combination of [P, R, [email protected], [email protected]]
368
  if fi > best_fitness:
369
  best_fitness = fi
370
+ callbacks.on_fit_epoch_end(mloss, results, lr, epoch, best_fitness, fi)
371
 
372
  # Save model
373
  if (not nosave) or (final_epoch and not evolve): # if save
 
384
  if best_fitness == fi:
385
  torch.save(ckpt, best)
386
  del ckpt
387
+ callbacks.on_model_save(last, epoch, final_epoch, best_fitness, fi)
388
 
389
  # end epoch ----------------------------------------------------------------------------------------------------
390
  # end training -----------------------------------------------------------------------------------------------------
 
407
  for f in last, best:
408
  if f.exists():
409
  strip_optimizer(f) # strip optimizers
410
+ callbacks.on_train_end(last, best, plots, epoch)
411
+ LOGGER.info(f"Results saved to {colorstr('bold', save_dir)}")
412
 
413
  torch.cuda.empty_cache()
414
  return results
 
456
 
457
 
458
  def main(opt):
459
+ # Checks
460
  set_logging(RANK)
461
  if RANK in [-1, 0]:
462
  print(colorstr('train: ') + ', '.join(f'{k}={v}' for k, v in vars(opt).items()))
utils/callbacks.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ class Callbacks:
4
+ """"
5
+ Handles all registered callbacks for YOLOv5 Hooks
6
+ """
7
+
8
+ _callbacks = {
9
+ 'on_pretrain_routine_start': [],
10
+ 'on_pretrain_routine_end': [],
11
+
12
+ 'on_train_start': [],
13
+ 'on_train_epoch_start': [],
14
+ 'on_train_batch_start': [],
15
+ 'optimizer_step': [],
16
+ 'on_before_zero_grad': [],
17
+ 'on_train_batch_end': [],
18
+ 'on_train_epoch_end': [],
19
+
20
+ 'on_val_start': [],
21
+ 'on_val_batch_start': [],
22
+ 'on_val_image_end': [],
23
+ 'on_val_batch_end': [],
24
+ 'on_val_end': [],
25
+
26
+ 'on_fit_epoch_end': [], # fit = train + val
27
+ 'on_model_save': [],
28
+ 'on_train_end': [],
29
+
30
+ 'teardown': [],
31
+ }
32
+
33
+ def __init__(self):
34
+ return
35
+
36
+ def register_action(self, hook, name='', callback=None):
37
+ """
38
+ Register a new action to a callback hook
39
+
40
+ Args:
41
+ hook The callback hook name to register the action to
42
+ name The name of the action
43
+ callback The callback to fire
44
+ """
45
+ assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
46
+ assert callable(callback), f"callback '{callback}' is not callable"
47
+ self._callbacks[hook].append({'name': name, 'callback': callback})
48
+
49
+ def get_registered_actions(self, hook=None):
50
+ """"
51
+ Returns all the registered actions by callback hook
52
+
53
+ Args:
54
+ hook The name of the hook to check, defaults to all
55
+ """
56
+ if hook:
57
+ return self._callbacks[hook]
58
+ else:
59
+ return self._callbacks
60
+
61
+ @staticmethod
62
+ def run_callbacks(register, *args, **kwargs):
63
+ """
64
+ Loop through the registered actions and fire all callbacks
65
+ """
66
+ for logger in register:
67
+ # print(f"Running callbacks.{logger['callback'].__name__}()")
68
+ logger['callback'](*args, **kwargs)
69
+
70
+ def on_pretrain_routine_start(self, *args, **kwargs):
71
+ """
72
+ Fires all registered callbacks at the start of each pretraining routine
73
+ """
74
+ self.run_callbacks(self._callbacks['on_pretrain_routine_start'], *args, **kwargs)
75
+
76
+ def on_pretrain_routine_end(self, *args, **kwargs):
77
+ """
78
+ Fires all registered callbacks at the end of each pretraining routine
79
+ """
80
+ self.run_callbacks(self._callbacks['on_pretrain_routine_end'], *args, **kwargs)
81
+
82
+ def on_train_start(self, *args, **kwargs):
83
+ """
84
+ Fires all registered callbacks at the start of each training
85
+ """
86
+ self.run_callbacks(self._callbacks['on_train_start'], *args, **kwargs)
87
+
88
+ def on_train_epoch_start(self, *args, **kwargs):
89
+ """
90
+ Fires all registered callbacks at the start of each training epoch
91
+ """
92
+ self.run_callbacks(self._callbacks['on_train_epoch_start'], *args, **kwargs)
93
+
94
+ def on_train_batch_start(self, *args, **kwargs):
95
+ """
96
+ Fires all registered callbacks at the start of each training batch
97
+ """
98
+ self.run_callbacks(self._callbacks['on_train_batch_start'], *args, **kwargs)
99
+
100
+ def optimizer_step(self, *args, **kwargs):
101
+ """
102
+ Fires all registered callbacks on each optimizer step
103
+ """
104
+ self.run_callbacks(self._callbacks['optimizer_step'], *args, **kwargs)
105
+
106
+ def on_before_zero_grad(self, *args, **kwargs):
107
+ """
108
+ Fires all registered callbacks before zero grad
109
+ """
110
+ self.run_callbacks(self._callbacks['on_before_zero_grad'], *args, **kwargs)
111
+
112
+ def on_train_batch_end(self, *args, **kwargs):
113
+ """
114
+ Fires all registered callbacks at the end of each training batch
115
+ """
116
+ self.run_callbacks(self._callbacks['on_train_batch_end'], *args, **kwargs)
117
+
118
+ def on_train_epoch_end(self, *args, **kwargs):
119
+ """
120
+ Fires all registered callbacks at the end of each training epoch
121
+ """
122
+ self.run_callbacks(self._callbacks['on_train_epoch_end'], *args, **kwargs)
123
+
124
+ def on_val_start(self, *args, **kwargs):
125
+ """
126
+ Fires all registered callbacks at the start of the validation
127
+ """
128
+ self.run_callbacks(self._callbacks['on_val_start'], *args, **kwargs)
129
+
130
+ def on_val_batch_start(self, *args, **kwargs):
131
+ """
132
+ Fires all registered callbacks at the start of each validation batch
133
+ """
134
+ self.run_callbacks(self._callbacks['on_val_batch_start'], *args, **kwargs)
135
+
136
+ def on_val_image_end(self, *args, **kwargs):
137
+ """
138
+ Fires all registered callbacks at the end of each val image
139
+ """
140
+ self.run_callbacks(self._callbacks['on_val_image_end'], *args, **kwargs)
141
+
142
+ def on_val_batch_end(self, *args, **kwargs):
143
+ """
144
+ Fires all registered callbacks at the end of each validation batch
145
+ """
146
+ self.run_callbacks(self._callbacks['on_val_batch_end'], *args, **kwargs)
147
+
148
+ def on_val_end(self, *args, **kwargs):
149
+ """
150
+ Fires all registered callbacks at the end of the validation
151
+ """
152
+ self.run_callbacks(self._callbacks['on_val_end'], *args, **kwargs)
153
+
154
+ def on_fit_epoch_end(self, *args, **kwargs):
155
+ """
156
+ Fires all registered callbacks at the end of each fit (train+val) epoch
157
+ """
158
+ self.run_callbacks(self._callbacks['on_fit_epoch_end'], *args, **kwargs)
159
+
160
+ def on_model_save(self, *args, **kwargs):
161
+ """
162
+ Fires all registered callbacks after each model save
163
+ """
164
+ self.run_callbacks(self._callbacks['on_model_save'], *args, **kwargs)
165
+
166
+ def on_train_end(self, *args, **kwargs):
167
+ """
168
+ Fires all registered callbacks at the end of training
169
+ """
170
+ self.run_callbacks(self._callbacks['on_train_end'], *args, **kwargs)
171
+
172
+ def teardown(self, *args, **kwargs):
173
+ """
174
+ Fires all registered callbacks before teardown
175
+ """
176
+ self.run_callbacks(self._callbacks['teardown'], *args, **kwargs)
utils/general.py CHANGED
@@ -67,6 +67,11 @@ def try_except(func):
67
  return handler
68
 
69
 
 
 
 
 
 
70
  def set_logging(rank=-1, verbose=True):
71
  logging.basicConfig(
72
  format="%(message)s",
 
67
  return handler
68
 
69
 
70
+ def methods(instance):
71
+ # Get class/instance methods
72
+ return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
73
+
74
+
75
  def set_logging(rank=-1, verbose=True):
76
  logging.basicConfig(
77
  format="%(message)s",
utils/loggers/__init__.py CHANGED
@@ -29,10 +29,12 @@ class Loggers():
29
  self.hyp = hyp
30
  self.logger = logger # for printing results to console
31
  self.include = include
 
 
 
 
32
  for k in LOGGERS:
33
  setattr(self, k, None) # init empty logger dictionary
34
-
35
- def start(self):
36
  self.csv = True # always log to csv
37
 
38
  # Message
@@ -57,7 +59,11 @@ class Loggers():
57
  else:
58
  self.wandb = None
59
 
60
- return self
 
 
 
 
61
 
62
  def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
63
  # Callback runs on train batch end
@@ -78,8 +84,8 @@ class Loggers():
78
  if self.wandb:
79
  self.wandb.current_epoch = epoch + 1
80
 
81
- def on_val_batch_end(self, pred, predn, path, names, im):
82
- # Callback runs on train batch end
83
  if self.wandb:
84
  self.wandb.val_one_image(pred, predn, path, names, im)
85
 
@@ -89,25 +95,20 @@ class Loggers():
89
  files = sorted(self.save_dir.glob('val*.jpg'))
90
  self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})
91
 
92
- def on_train_val_end(self, mloss, results, lr, epoch, best_fitness, fi):
93
- # Callback runs on val end during training
94
  vals = list(mloss) + list(results) + lr
95
- keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
96
- 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', # metrics
97
- 'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
98
- 'x/lr0', 'x/lr1', 'x/lr2'] # params
99
- x = {k: v for k, v in zip(keys, vals)} # dict
100
-
101
  if self.csv:
102
  file = self.save_dir / 'results.csv'
103
  n = len(x) + 1 # number of cols
104
- s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + keys)).rstrip(',') + '\n') # add header
105
  with open(file, 'a') as f:
106
  f.write(s + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')
107
 
108
  if self.tb:
109
  for k, v in x.items():
110
- self.tb.add_scalar(k, v, epoch) # TensorBoard
111
 
112
  if self.wandb:
113
  self.wandb.log(x)
@@ -119,20 +120,22 @@ class Loggers():
119
  if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
120
  self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)
121
 
122
- def on_train_end(self, last, best, plots):
123
  # Callback runs on training end
124
  if plots:
125
  plot_results(dir=self.save_dir) # save results.png
126
  files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
127
  files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter
 
 
 
 
 
 
 
128
  if self.wandb:
129
  wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
130
  wandb.log_artifact(str(best if best.exists() else last), type='model',
131
  name='run_' + self.wandb.wandb_run.id + '_model',
132
  aliases=['latest', 'best', 'stripped'])
133
  self.wandb.finish_run()
134
-
135
- def log_images(self, paths):
136
- # Log images
137
- if self.wandb:
138
- self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
 
29
  self.hyp = hyp
30
  self.logger = logger # for printing results to console
31
  self.include = include
32
+ self.keys = ['train/box_loss', 'train/obj_loss', 'train/cls_loss', # train loss
33
+ 'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5', 'metrics/mAP_0.5:0.95', # metrics
34
+ 'val/box_loss', 'val/obj_loss', 'val/cls_loss', # val loss
35
+ 'x/lr0', 'x/lr1', 'x/lr2'] # params
36
  for k in LOGGERS:
37
  setattr(self, k, None) # init empty logger dictionary
 
 
38
  self.csv = True # always log to csv
39
 
40
  # Message
 
59
  else:
60
  self.wandb = None
61
 
62
+ def on_pretrain_routine_end(self):
63
+ # Callback runs on pre-train routine end
64
+ paths = self.save_dir.glob('*labels*.jpg') # training labels
65
+ if self.wandb:
66
+ self.wandb.log({"Labels": [wandb.Image(str(x), caption=x.name) for x in paths]})
67
 
68
  def on_train_batch_end(self, ni, model, imgs, targets, paths, plots):
69
  # Callback runs on train batch end
 
84
  if self.wandb:
85
  self.wandb.current_epoch = epoch + 1
86
 
87
+ def on_val_image_end(self, pred, predn, path, names, im):
88
+ # Callback runs on val image end
89
  if self.wandb:
90
  self.wandb.val_one_image(pred, predn, path, names, im)
91
 
 
95
  files = sorted(self.save_dir.glob('val*.jpg'))
96
  self.wandb.log({"Validation": [wandb.Image(str(f), caption=f.name) for f in files]})
97
 
98
+ def on_fit_epoch_end(self, mloss, results, lr, epoch, best_fitness, fi):
99
+ # Callback runs at the end of each fit (train+val) epoch
100
  vals = list(mloss) + list(results) + lr
101
+ x = {k: v for k, v in zip(self.keys, vals)} # dict
 
 
 
 
 
102
  if self.csv:
103
  file = self.save_dir / 'results.csv'
104
  n = len(x) + 1 # number of cols
105
+ s = '' if file.exists() else (('%20s,' * n % tuple(['epoch'] + self.keys)).rstrip(',') + '\n') # add header
106
  with open(file, 'a') as f:
107
  f.write(s + ('%20.5g,' * n % tuple([epoch] + vals)).rstrip(',') + '\n')
108
 
109
  if self.tb:
110
  for k, v in x.items():
111
+ self.tb.add_scalar(k, v, epoch)
112
 
113
  if self.wandb:
114
  self.wandb.log(x)
 
120
  if ((epoch + 1) % self.opt.save_period == 0 and not final_epoch) and self.opt.save_period != -1:
121
  self.wandb.log_model(last.parent, self.opt, epoch, fi, best_model=best_fitness == fi)
122
 
123
+ def on_train_end(self, last, best, plots, epoch):
124
  # Callback runs on training end
125
  if plots:
126
  plot_results(dir=self.save_dir) # save results.png
127
  files = ['results.png', 'confusion_matrix.png', *[f'{x}_curve.png' for x in ('F1', 'PR', 'P', 'R')]]
128
  files = [(self.save_dir / f) for f in files if (self.save_dir / f).exists()] # filter
129
+
130
+ if self.tb:
131
+ from PIL import Image
132
+ import numpy as np
133
+ for f in files:
134
+ self.tb.add_image(f.stem, np.asarray(Image.open(f)), epoch, dataformats='HWC')
135
+
136
  if self.wandb:
137
  wandb.log({"Results": [wandb.Image(str(f), caption=f.name) for f in files]})
138
  wandb.log_artifact(str(best if best.exists() else last), type='model',
139
  name='run_' + self.wandb.wandb_run.id + '_model',
140
  aliases=['latest', 'best', 'stripped'])
141
  self.wandb.finish_run()
 
 
 
 
 
utils/plots.py CHANGED
@@ -281,7 +281,7 @@ def plot_study_txt(path='', x=None): # from utils.plots import *; plot_study_tx
281
  plt.savefig(str(Path(path).name) + '.png', dpi=300)
282
 
283
 
284
- def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
285
  # plot dataset labels
286
  print('Plotting labels... ')
287
  c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
@@ -324,10 +324,6 @@ def plot_labels(labels, names=(), save_dir=Path(''), loggers=None):
324
  matplotlib.use('Agg')
325
  plt.close()
326
 
327
- # loggers
328
- if loggers:
329
- loggers.log_images(save_dir.glob('*labels*.jpg'))
330
-
331
 
332
  def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
333
  # Plot hyperparameter evolution results in evolve.txt
 
281
  plt.savefig(str(Path(path).name) + '.png', dpi=300)
282
 
283
 
284
+ def plot_labels(labels, names=(), save_dir=Path('')):
285
  # plot dataset labels
286
  print('Plotting labels... ')
287
  c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
 
324
  matplotlib.use('Agg')
325
  plt.close()
326
 
 
 
 
 
327
 
328
  def plot_evolution(yaml_file='data/hyp.finetune.yaml'): # from utils.plots import *; plot_evolution()
329
  # Plot hyperparameter evolution results in evolve.txt
val.py CHANGED
@@ -25,7 +25,7 @@ from utils.general import coco80_to_coco91_class, check_dataset, check_file, che
25
  from utils.metrics import ap_per_class, ConfusionMatrix
26
  from utils.plots import plot_images, output_to_target, plot_study_txt
27
  from utils.torch_utils import select_device, time_sync
28
- from utils.loggers import Loggers
29
 
30
 
31
  def save_one_txt(predn, save_conf, shape, file):
@@ -97,7 +97,7 @@ def run(data,
97
  dataloader=None,
98
  save_dir=Path(''),
99
  plots=True,
100
- loggers=Loggers(),
101
  compute_loss=None,
102
  ):
103
  # Initialize/load model and set device
@@ -213,7 +213,7 @@ def run(data,
213
  save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
214
  if save_json:
215
  save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary
216
- loggers.on_val_batch_end(pred, predn, path, names, img[si])
217
 
218
  # Plot images
219
  if plots and batch_i < 3:
@@ -250,7 +250,7 @@ def run(data,
250
  # Plots
251
  if plots:
252
  confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
253
- loggers.on_val_end()
254
 
255
  # Save JSON
256
  if save_json and len(jdict):
@@ -282,7 +282,7 @@ def run(data,
282
  model.float() # for training
283
  if not training:
284
  s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
285
- print(f"Results saved to {save_dir}{s}")
286
  maps = np.zeros(nc) + map
287
  for i, c in enumerate(ap_class):
288
  maps[c] = ap[i]
 
25
  from utils.metrics import ap_per_class, ConfusionMatrix
26
  from utils.plots import plot_images, output_to_target, plot_study_txt
27
  from utils.torch_utils import select_device, time_sync
28
+ from utils.callbacks import Callbacks
29
 
30
 
31
  def save_one_txt(predn, save_conf, shape, file):
 
97
  dataloader=None,
98
  save_dir=Path(''),
99
  plots=True,
100
+ callbacks=Callbacks(),
101
  compute_loss=None,
102
  ):
103
  # Initialize/load model and set device
 
213
  save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
214
  if save_json:
215
  save_one_json(predn, jdict, path, class_map) # append to COCO-JSON dictionary
216
+ callbacks.on_val_image_end(pred, predn, path, names, img[si])
217
 
218
  # Plot images
219
  if plots and batch_i < 3:
 
250
  # Plots
251
  if plots:
252
  confusion_matrix.plot(save_dir=save_dir, names=list(names.values()))
253
+ callbacks.on_val_end()
254
 
255
  # Save JSON
256
  if save_json and len(jdict):
 
282
  model.float() # for training
283
  if not training:
284
  s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
285
+ print(f"Results saved to {colorstr('bold', save_dir)}{s}")
286
  maps = np.zeros(nc) + map
287
  for i, c in enumerate(ap_class):
288
  maps[c] = ap[i]