glenn-jocher commited on
Commit
bafbc65
1 Parent(s): 57a0ae3

AutoAnchor bug fix

Browse files
Files changed (1) hide show
  1. utils/utils.py +5 -5
utils/utils.py CHANGED
@@ -719,7 +719,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
719
  return x, x.max(1)[0] # x, best_x
720
 
721
  def fitness(k): # mutation fitness
722
- _, best = metric(k)
723
  return (best * (best > thr).float()).mean() # fitness
724
 
725
  def print_results(k):
@@ -743,8 +743,8 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
743
 
744
  # Get label wh
745
  shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
746
- wh = torch.tensor(np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)])).float() # wh
747
- wh = wh[(wh > 2.0).all(1)].numpy() # filter > 2 pixels
748
 
749
  # Kmeans calculation
750
  from scipy.cluster.vq import kmeans
@@ -752,7 +752,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
752
  s = wh.std(0) # sigmas for whitening
753
  k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
754
  k *= s
755
- wh = torch.tensor(wh)
756
  k = print_results(k)
757
 
758
  # Plot
@@ -771,7 +771,7 @@ def kmean_anchors(path='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen=10
771
  # Evolve
772
  npr = np.random
773
  f, sh, mp, s = fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
774
- for _ in tqdm(range(gen), desc='Evolving anchors with Genetic Algorithm:'):
775
  v = np.ones(sh)
776
  while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
777
  v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)
 
719
  return x, x.max(1)[0] # x, best_x
720
 
721
  def fitness(k): # mutation fitness
722
+ _, best = metric(torch.tensor(k, dtype=torch.float32))
723
  return (best * (best > thr).float()).mean() # fitness
724
 
725
  def print_results(k):
 
743
 
744
  # Get label wh
745
  shapes = img_size * dataset.shapes / dataset.shapes.max(1, keepdims=True)
746
+ wh = np.concatenate([l[:, 3:5] * s for s, l in zip(shapes, dataset.labels)]) # wh
747
+ wh = wh[(wh > 2.0).all(1)] # filter > 2 pixels
748
 
749
  # Kmeans calculation
750
  from scipy.cluster.vq import kmeans
 
752
  s = wh.std(0) # sigmas for whitening
753
  k, dist = kmeans(wh / s, n, iter=30) # points, mean distance
754
  k *= s
755
+ wh = torch.tensor(wh, dtype=torch.float32)
756
  k = print_results(k)
757
 
758
  # Plot
 
771
  # Evolve
772
  npr = np.random
773
  f, sh, mp, s = fitness(k), k.shape, 0.9, 0.1 # fitness, generations, mutation prob, sigma
774
+ for _ in tqdm(range(gen), desc='Evolving anchors with Genetic Algorithm'):
775
  v = np.ones(sh)
776
  while (v == 1).all(): # mutate until a change occurs (prevent duplicates)
777
  v = ((npr.random(sh) < mp) * npr.random() * npr.randn(*sh) * s + 1).clip(0.3, 3.0)