glenn-jocher
commited on
Commit
•
ec81c7b
1
Parent(s):
e1e3399
check_anchors() bug fix #90
Browse files- train.py +1 -1
- utils/utils.py +6 -6
train.py
CHANGED
@@ -199,7 +199,7 @@ def train(hyp):
|
|
199 |
tb_writer.add_histogram('classes', c, 0)
|
200 |
|
201 |
# Check anchors
|
202 |
-
|
203 |
|
204 |
# Exponential moving average
|
205 |
ema = torch_utils.ModelEMA(model)
|
|
|
199 |
tb_writer.add_histogram('classes', c, 0)
|
200 |
|
201 |
# Check anchors
|
202 |
+
check_anchors(dataset, anchors=model.model[-1].anchor_grid, thr=hyp['anchor_t'], imgsz=imgsz)
|
203 |
|
204 |
# Exponential moving average
|
205 |
ema = torch_utils.ModelEMA(model)
|
utils/utils.py
CHANGED
@@ -52,8 +52,9 @@ def check_img_size(img_size, s=32):
|
|
52 |
return make_divisible(img_size, s) # nearest gs-multiple
|
53 |
|
54 |
|
55 |
-
def
|
56 |
# Check best possible recall of dataset with current anchors
|
|
|
57 |
shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
|
58 |
wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh
|
59 |
ratio = wh[:, None] / anchors.view(-1, 2).cpu()[None] # ratio
|
@@ -62,7 +63,6 @@ def check_best_possible_recall(dataset, anchors, thr=4.0, imgsz=640):
|
|
62 |
mr = (m < thr).float().mean() # match ratio
|
63 |
print(('AutoAnchor labels:' + '%10s' * 6) % ('n', 'mean', 'min', 'max', 'matching', 'recall'))
|
64 |
print((' ' + '%10.4g' * 6) % (wh.shape[0], wh.mean(), wh.min(), wh.max(), mr, bpr))
|
65 |
-
|
66 |
assert bpr > 0.9, 'Best possible recall %.3g (BPR) below 0.9 threshold. Training cancelled. ' \
|
67 |
'Compute new anchors with utils.utils.kmeans_anchors() and update model before training.' % bpr
|
68 |
|
@@ -512,10 +512,10 @@ def build_targets(p, targets, model):
|
|
512 |
|
513 |
|
514 |
def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, classes=None, agnostic=False):
|
515 |
-
"""
|
516 |
-
|
517 |
-
Returns
|
518 |
-
|
519 |
"""
|
520 |
if prediction.dtype is torch.float16:
|
521 |
prediction = prediction.float() # to FP32
|
|
|
52 |
return make_divisible(img_size, s) # nearest gs-multiple
|
53 |
|
54 |
|
55 |
+
def check_anchors(dataset, model, thr=4.0, imgsz=640):
|
56 |
# Check best possible recall of dataset with current anchors
|
57 |
+
anchors = model.module.model[-1].anchor_grid if hasattr(model, 'module') else model.model[-1].anchor_grid
|
58 |
shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
|
59 |
wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh
|
60 |
ratio = wh[:, None] / anchors.view(-1, 2).cpu()[None] # ratio
|
|
|
63 |
mr = (m < thr).float().mean() # match ratio
|
64 |
print(('AutoAnchor labels:' + '%10s' * 6) % ('n', 'mean', 'min', 'max', 'matching', 'recall'))
|
65 |
print((' ' + '%10.4g' * 6) % (wh.shape[0], wh.mean(), wh.min(), wh.max(), mr, bpr))
|
|
|
66 |
assert bpr > 0.9, 'Best possible recall %.3g (BPR) below 0.9 threshold. Training cancelled. ' \
|
67 |
'Compute new anchors with utils.utils.kmeans_anchors() and update model before training.' % bpr
|
68 |
|
|
|
512 |
|
513 |
|
514 |
def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, classes=None, agnostic=False):
|
515 |
+
"""Performs Non-Maximum Suppression (NMS) on inference results
|
516 |
+
|
517 |
+
Returns:
|
518 |
+
detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
|
519 |
"""
|
520 |
if prediction.dtype is torch.float16:
|
521 |
prediction = prediction.float() # to FP32
|