glenn-jocher
commited on
Commit
•
9fdb0fb
1
Parent(s):
8b6f582
AutoAnchor bug fix # 117
Browse files- train.py +3 -1
- utils/utils.py +8 -5
train.py
CHANGED
@@ -200,7 +200,8 @@ def train(hyp):
|
|
200 |
tb_writer.add_histogram('classes', c, 0)
|
201 |
|
202 |
# Check anchors
|
203 |
-
|
|
|
204 |
|
205 |
# Exponential moving average
|
206 |
ema = torch_utils.ModelEMA(model)
|
@@ -374,6 +375,7 @@ if __name__ == '__main__':
|
|
374 |
parser.add_argument('--resume', action='store_true', help='resume training from last.pt')
|
375 |
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
|
376 |
parser.add_argument('--notest', action='store_true', help='only test final epoch')
|
|
|
377 |
parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
|
378 |
parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
|
379 |
parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
|
|
|
200 |
tb_writer.add_histogram('classes', c, 0)
|
201 |
|
202 |
# Check anchors
|
203 |
+
if not opt.noautoanchor:
|
204 |
+
check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
|
205 |
|
206 |
# Exponential moving average
|
207 |
ema = torch_utils.ModelEMA(model)
|
|
|
375 |
parser.add_argument('--resume', action='store_true', help='resume training from last.pt')
|
376 |
parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
|
377 |
parser.add_argument('--notest', action='store_true', help='only test final epoch')
|
378 |
+
parser.add_argument('--noautoanchor', action='store_true', help='disable autoanchor check')
|
379 |
parser.add_argument('--evolve', action='store_true', help='evolve hyperparameters')
|
380 |
parser.add_argument('--bucket', type=str, default='', help='gsutil bucket')
|
381 |
parser.add_argument('--cache-images', action='store_true', help='cache images for faster training')
|
utils/utils.py
CHANGED
@@ -56,7 +56,7 @@ def check_img_size(img_size, s=32):
|
|
56 |
def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
57 |
# Check anchor fit to data, recompute if necessary
|
58 |
print('\nAnalyzing anchors... ', end='')
|
59 |
-
|
60 |
shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
|
61 |
wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh
|
62 |
|
@@ -66,14 +66,17 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
|
66 |
best = x.max(1)[0] # best_x
|
67 |
return (best > 1. / thr).float().mean() # best possible recall
|
68 |
|
69 |
-
bpr = metric(
|
70 |
print('Best Possible Recall (BPR) = %.4f' % bpr, end='')
|
71 |
if bpr < 0.99: # threshold to recompute
|
72 |
print('. Attempting to generate improved anchors, please wait...' % bpr)
|
73 |
-
|
|
|
74 |
new_bpr = metric(new_anchors.reshape(-1, 2))
|
75 |
-
if new_bpr > bpr:
|
76 |
-
|
|
|
|
|
77 |
print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
|
78 |
else:
|
79 |
print('Original anchors better than new anchors. Proceeding with original anchors.')
|
|
|
56 |
def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
57 |
# Check anchor fit to data, recompute if necessary
|
58 |
print('\nAnalyzing anchors... ', end='')
|
59 |
+
m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
|
60 |
shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
|
61 |
wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh
|
62 |
|
|
|
66 |
best = x.max(1)[0] # best_x
|
67 |
return (best > 1. / thr).float().mean() # best possible recall
|
68 |
|
69 |
+
bpr = metric(m.anchor_grid.clone().cpu().view(-1, 2))
|
70 |
print('Best Possible Recall (BPR) = %.4f' % bpr, end='')
|
71 |
if bpr < 0.99: # threshold to recompute
|
72 |
print('. Attempting to generate improved anchors, please wait...' % bpr)
|
73 |
+
na = m.anchor_grid.numel() // 2 # number of anchors
|
74 |
+
new_anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
|
75 |
new_bpr = metric(new_anchors.reshape(-1, 2))
|
76 |
+
if new_bpr > bpr: # replace anchors
|
77 |
+
new_anchors = torch.tensor(new_anchors, device=m.anchors.device).type_as(m.anchors)
|
78 |
+
m.anchor_grid[:] = new_anchors.clone().view_as(m.anchor_grid) # for inference
|
79 |
+
m.anchors[:] = new_anchors.clone().view_as(m.anchors) / m.stride.to(m.anchors.device).view(-1, 1, 1) # loss
|
80 |
print('New anchors saved to model. Update model *.yaml to use these anchors in the future.')
|
81 |
else:
|
82 |
print('Original anchors better than new anchors. Proceeding with original anchors.')
|