Commit
•
6f3db5e
1
Parent(s):
c5969f7
Remove autoanchor and class checks on resumed training (#889)
Browse files* Class frequency not calculated on resuming training
Calculation of class frequency is not needed when resuming training.
Anchors can still be recalculated whether resuming or not.
* Check rank for autoanchor
* Update train.py
no autoanchor checks on resume
Co-authored-by: Glenn Jocher <[email protected]>
train.py
CHANGED
@@ -185,18 +185,18 @@ def train(hyp, opt, device, tb_writer=None):
|
|
185 |
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
186 |
model.names = names
|
187 |
|
188 |
-
#
|
189 |
-
if rank in [-1, 0]:
|
190 |
labels = np.concatenate(dataset.labels, 0)
|
191 |
c = torch.tensor(labels[:, 0]) # classes
|
192 |
-
# cf = torch.bincount(c.long(), minlength=nc) + 1.
|
193 |
# model._initialize_biases(cf.to(device))
|
194 |
plot_labels(labels, save_dir=log_dir)
|
195 |
if tb_writer:
|
196 |
# tb_writer.add_hparams(hyp, {}) # causes duplicate https://github.com/ultralytics/yolov5/pull/384
|
197 |
tb_writer.add_histogram('classes', c, 0)
|
198 |
|
199 |
-
#
|
200 |
if not opt.noautoanchor:
|
201 |
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
|
202 |
|
|
|
185 |
model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) # attach class weights
|
186 |
model.names = names
|
187 |
|
188 |
+
# Classes and Anchors
|
189 |
+
if rank in [-1, 0] and not opt.resume:
|
190 |
labels = np.concatenate(dataset.labels, 0)
|
191 |
c = torch.tensor(labels[:, 0]) # classes
|
192 |
+
# cf = torch.bincount(c.long(), minlength=nc) + 1. # frequency
|
193 |
# model._initialize_biases(cf.to(device))
|
194 |
plot_labels(labels, save_dir=log_dir)
|
195 |
if tb_writer:
|
196 |
# tb_writer.add_hparams(hyp, {}) # causes duplicate https://github.com/ultralytics/yolov5/pull/384
|
197 |
tb_writer.add_histogram('classes', c, 0)
|
198 |
|
199 |
+
# Anchors
|
200 |
if not opt.noautoanchor:
|
201 |
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
|
202 |
|