glenn-jocher committed
Commit ec81c7b
Parent: e1e3399

check_anchors() bug fix #90

Files changed (2)
  1. train.py +1 -1
  2. utils/utils.py +6 -6
train.py CHANGED
@@ -199,7 +199,7 @@ def train(hyp):
             tb_writer.add_histogram('classes', c, 0)
 
         # Check anchors
-        check_best_possible_recall(dataset, anchors=model.model[-1].anchor_grid, thr=hyp['anchor_t'], imgsz=imgsz)
+        check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)
 
         # Exponential moving average
         ema = torch_utils.ModelEMA(model)
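
The renamed call now passes the whole model to check_anchors() instead of reading model.model[-1].anchor_grid at the call site. The hasattr(model, 'module') guard added in utils/utils.py below suggests why: when the model is wrapped, for example in torch.nn.DataParallel, its layers are only reachable through .module, so the old attribute access fails. Below is a minimal, self-contained illustration of that unwrap pattern; the toy Detect and Model classes are hypothetical stand-ins, not the repository's classes.

import torch
import torch.nn as nn


class Detect(nn.Module):                          # hypothetical stand-in for the detection layer
    def __init__(self):
        super().__init__()
        self.anchor_grid = torch.ones(3, 1, 3, 1, 1, 2)  # dummy anchor grid


class Model(nn.Module):                           # hypothetical stand-in for the full model
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(Detect())


plain = Model()
wrapped = nn.DataParallel(Model())                # multi-GPU wrapper: layers now live under .module

# wrapped.model[-1].anchor_grid would raise AttributeError: 'DataParallel' object has no attribute 'model'
for m in (plain, wrapped):
    anchors = m.module.model[-1].anchor_grid if hasattr(m, 'module') else m.model[-1].anchor_grid
    print(anchors.shape)                          # same anchor grid either way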
utils/utils.py CHANGED
@@ -52,8 +52,9 @@ def check_img_size(img_size, s=32):
     return make_divisible(img_size, s)  # nearest gs-multiple
 
 
-def check_best_possible_recall(dataset, anchors, thr=4.0, imgsz=640):
+def check_anchors(dataset, model, thr=4.0, imgsz=640):
     # Check best possible recall of dataset with current anchors
+    anchors = model.module.model[-1].anchor_grid if hasattr(model, 'module') else model.model[-1].anchor_grid
     shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
     wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float()  # wh
     ratio = wh[:, None] / anchors.view(-1, 2).cpu()[None]  # ratio
@@ -62,7 +63,6 @@ def check_best_possible_recall(dataset, anchors, thr=4.0, imgsz=640):
     mr = (m < thr).float().mean()  # match ratio
     print(('AutoAnchor labels:' + '%10s' * 6) % ('n', 'mean', 'min', 'max', 'matching', 'recall'))
     print(('                  ' + '%10.4g' * 6) % (wh.shape[0], wh.mean(), wh.min(), wh.max(), mr, bpr))
-
     assert bpr > 0.9, 'Best possible recall %.3g (BPR) below 0.9 threshold. Training cancelled. ' \
                       'Compute new anchors with utils.utils.kmeans_anchors() and update model before training.' % bpr
 
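
The hunks above skip the lines between ratio and mr, so the computation of m and bpr is not visible in this diff. As a reading aid, here is a sketch of the whole check assembled from the visible lines; the two lines marked ASSUMED (a max width/height ratio compared against thr) are a guess at the elided metric, and the name check_anchors_sketch signals that this is not the repository's code verbatim.

import numpy as np
import torch


def check_anchors_sketch(dataset, model, thr=4.0, imgsz=640):
    # Check best possible recall of dataset with current anchors (sketch assembled from the hunks above)
    anchors = model.module.model[-1].anchor_grid if hasattr(model, 'module') else model.model[-1].anchor_grid
    shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)  # scale longest side to imgsz
    wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float()  # label wh
    ratio = wh[:, None] / anchors.view(-1, 2).cpu()[None]  # label-to-anchor wh ratios
    m = torch.max(ratio, 1. / ratio).max(2)[0]    # ASSUMED: worst-axis ratio for every label-anchor pair
    bpr = (m.min(1)[0] < thr).float().mean()      # ASSUMED: best possible recall (best anchor within thr)
    mr = (m < thr).float().mean()  # match ratio
    print(('AutoAnchor labels:' + '%10s' * 6) % ('n', 'mean', 'min', 'max', 'matching', 'recall'))
    print(('                  ' + '%10.4g' * 6) % (wh.shape[0], wh.mean(), wh.min(), wh.max(), mr, bpr))
    assert bpr > 0.9, 'Best possible recall %.3g (BPR) below 0.9 threshold. Training cancelled. ' \
                      'Compute new anchors with utils.utils.kmeans_anchors() and update model before training.' % bpr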
 
 
@@ -512,10 +512,10 @@ def build_targets(p, targets, model):
 
 
 def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, fast=False, classes=None, agnostic=False):
-    """
-    Performs Non-Maximum Suppression on inference results
-    Returns detections with shape:
-        nx6 (x1, y1, x2, y2, conf, cls)
+    """Performs Non-Maximum Suppression (NMS) on inference results
+
+    Returns:
+         detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
     """
     if prediction.dtype is torch.float16:
         prediction = prediction.float()  # to FP32
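
To make the new docstring concrete, a hedged usage sketch follows. It assumes that a loaded model and a preprocessed image batch img already exist, that the raw predictions are the first element of the model's output, and that non_max_suppression() returns one detection tensor per image; none of those details are stated in this diff, which only restates the per-detection layout (x1, y1, x2, y2, conf, cls).

import torch

from utils.utils import non_max_suppression

# ASSUMED: `model` and `img` already exist (a loaded detection model and a preprocessed batch)
with torch.no_grad():
    pred = model(img)[0]                         # ASSUMED: raw predictions are the first output element

# Drop boxes below 0.4 confidence, suppress overlaps above 0.5 IoU
det = non_max_suppression(pred, conf_thres=0.4, iou_thres=0.5, classes=None, agnostic=False)

for d in det:                                    # ASSUMED: one (n, 6) tensor (or None) per image
    if d is not None and len(d):
        for *xyxy, conf, cls in d:               # x1, y1, x2, y2, conf, cls per the docstring
            print([float(x) for x in xyxy], float(conf), int(cls))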