Naman Gupta glenn-jocher commited on
Commit
6f3db5e
1 Parent(s): c5969f7

Remove autoanchor and class checks on resumed training (#889)

Browse files

* Class frequency not calculated on resuming training

Calculation of class frequency is not needed when resuming training.
Anchors can still be recalculated whether resuming or not.

* Check rank for autoanchor

* Update train.py

no autoanchor checks on resume

Co-authored-by: Glenn Jocher <[email protected]>

Files changed (1) hide show
  1. train.py +4 -4
train.py CHANGED
@@ -185,18 +185,18 @@ def train(hyp, opt, device, tb_writer=None):
185
  model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
186
  model.names = names
187
 
188
- # Class frequency
189
- if rank in [-1, 0]:
190
  labels = np.concatenate(dataset.labels, 0)
191
  c = torch.tensor(labels[:, 0]) # classes
192
- # cf = torch.bincount(c.long(), minlength=nc) + 1.
193
  # model._initialize_biases(cf.to(device))
194
  plot_labels(labels, save_dir=log_dir)
195
  if tb_writer:
196
  # tb_writer.add_hparams(hyp, {}) # causes duplicate https://github.com/ultralytics/yolov5/pull/384
197
  tb_writer.add_histogram('classes', c, 0)
198
 
199
- # Check anchors
200
  if not opt.noautoanchor:
201
  check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
202
 
 
185
  model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
186
  model.names = names
187
 
188
+ # Classes and Anchors
189
+ if rank in [-1, 0] and not opt.resume:
190
  labels = np.concatenate(dataset.labels, 0)
191
  c = torch.tensor(labels[:, 0]) # classes
192
+ # cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
193
  # model._initialize_biases(cf.to(device))
194
  plot_labels(labels, save_dir=log_dir)
195
  if tb_writer:
196
  # tb_writer.add_hparams(hyp, {}) # causes duplicate https://github.com/ultralytics/yolov5/pull/384
197
  tb_writer.add_histogram('classes', c, 0)
198
 
199
+ # Anchors
200
  if not opt.noautoanchor:
201
  check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
202