glenn-jocher commited on
Commit
8b26e89
1 Parent(s): 8fa3724

AutoAnchor bug fix #72

Browse files
Files changed (2) hide show
  1. train.py +1 -2
  2. utils/utils.py +6 -4
train.py CHANGED
@@ -4,7 +4,6 @@ import torch.distributed as dist
4
  import torch.nn.functional as F
5
  import torch.optim as optim
6
  import torch.optim.lr_scheduler as lr_scheduler
7
- import yaml
8
  from torch.utils.tensorboard import SummaryWriter
9
 
10
  import test # import test.py to get mAP after each epoch
@@ -200,7 +199,7 @@ def train(hyp):
200
  tb_writer.add_histogram('classes', c, 0)
201
 
202
  # Check anchors
203
- check_best_possible_recall(dataset, anchors=model.model[-1].anchor_grid, thr=hyp['anchor_t'])
204
 
205
  # Exponential moving average
206
  ema = torch_utils.ModelEMA(model)
 
4
  import torch.nn.functional as F
5
  import torch.optim as optim
6
  import torch.optim.lr_scheduler as lr_scheduler
 
7
  from torch.utils.tensorboard import SummaryWriter
8
 
9
  import test # import test.py to get mAP after each epoch
 
199
  tb_writer.add_histogram('classes', c, 0)
200
 
201
  # Check anchors
202
+ check_best_possible_recall(dataset, anchors=model.model[-1].anchor_grid, thr=hyp['anchor_t'], imgsz=imgsz)
203
 
204
  # Exponential moving average
205
  ema = torch_utils.ModelEMA(model)
utils/utils.py CHANGED
@@ -52,15 +52,17 @@ def check_img_size(img_size, s=32):
52
  return make_divisible(img_size, s) # nearest gs-multiple
53
 
54
 
55
- def check_best_possible_recall(dataset, anchors, thr):
56
  # Check best possible recall of dataset with current anchors
57
- wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(dataset.shapes, dataset.labels)])).float() # wh
 
58
  ratio = wh[:, None] / anchors.view(-1, 2).cpu()[None] # ratio
59
  m = torch.max(ratio, 1. / ratio).max(2)[0] # max ratio
60
  bpr = (m.min(1)[0] < thr).float().mean() # best possible recall
61
  mr = (m < thr).float().mean() # match ratio
62
- print(('Label width-height:' + '%10s' * 6) % ('n', 'mean', 'min', 'max', 'matching', 'recall'))
63
- print((' ' + '%10.4g' * 6) % (wh.shape[0], wh.mean(), wh.min(), wh.max(), mr, bpr))
 
64
  assert bpr > 0.9, 'Best possible recall %.3g (BPR) below 0.9 threshold. Training cancelled. ' \
65
  'Compute new anchors with utils.utils.kmeans_anchors() and update model before training.' % bpr
66
 
 
52
  return make_divisible(img_size, s) # nearest gs-multiple
53
 
54
 
55
+ def check_best_possible_recall(dataset, anchors, thr=4.0, imgsz=640):
56
  # Check best possible recall of dataset with current anchors
57
+ shapes = imgsz * dataset.shapes / dataset.shapes.max(1, keepdims=True)
58
+ wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh
59
  ratio = wh[:, None] / anchors.view(-1, 2).cpu()[None] # ratio
60
  m = torch.max(ratio, 1. / ratio).max(2)[0] # max ratio
61
  bpr = (m.min(1)[0] < thr).float().mean() # best possible recall
62
  mr = (m < thr).float().mean() # match ratio
63
+ print(('AutoAnchor labels:' + '%10s' * 6) % ('n', 'mean', 'min', 'max', 'matching', 'recall'))
64
+ print((' ' + '%10.4g' * 6) % (wh.shape[0], wh.mean(), wh.min(), wh.max(), mr, bpr))
65
+
66
  assert bpr > 0.9, 'Best possible recall %.3g (BPR) below 0.9 threshold. Training cancelled. ' \
67
  'Compute new anchors with utils.utils.kmeans_anchors() and update model before training.' % bpr
68