glenn-jocher commited on
Commit
7a565f1
1 Parent(s): 4984cf5

Update `dataset_stats()` (#3593)

Browse files

@KalenMike this is a PR to add image filenames and labels to our stats dictionary and to save the dictionary to JSON. Save location is next to the train labels.cache file. The single JSON contains all stats for entire dataset.

Usage example:
```python
from utils.datasets import *

dataset_stats('coco128.yaml', verbose=True)
```

Files changed (1) hide show
  1. utils/datasets.py +12 -3
utils/datasets.py CHANGED
@@ -2,6 +2,7 @@
2
 
3
  import glob
4
  import hashlib
 
5
  import logging
6
  import math
7
  import os
@@ -1105,12 +1106,20 @@ def dataset_stats(path='coco128.yaml', autodownload=False, verbose=False):
1105
  continue
1106
  x = []
1107
  dataset = LoadImagesAndLabels(data[split], augment=False, rect=True) # load dataset
 
 
1108
  for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
1109
  x.append(np.bincount(label[:, 0].astype(int), minlength=nc))
1110
  x = np.array(x) # shape(128x80)
1111
- stats[split] = {'instances': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
1112
- 'images': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
1113
- 'per_class': (x > 0).sum(0).tolist()}}
 
 
 
 
 
1114
  if verbose:
1115
  print(yaml.dump([stats], sort_keys=False, default_flow_style=False))
 
1116
  return stats
 
2
 
3
  import glob
4
  import hashlib
5
+ import json
6
  import logging
7
  import math
8
  import os
 
1106
  continue
1107
  x = []
1108
  dataset = LoadImagesAndLabels(data[split], augment=False, rect=True) # load dataset
1109
+ if split == 'train':
1110
+ cache_path = Path(dataset.label_files[0]).parent.with_suffix('.cache') # *.cache path
1111
  for label in tqdm(dataset.labels, total=dataset.n, desc='Statistics'):
1112
  x.append(np.bincount(label[:, 0].astype(int), minlength=nc))
1113
  x = np.array(x) # shape(128x80)
1114
+ stats[split] = {'instance_stats': {'total': int(x.sum()), 'per_class': x.sum(0).tolist()},
1115
+ 'image_stats': {'total': dataset.n, 'unlabelled': int(np.all(x == 0, 1).sum()),
1116
+ 'per_class': (x > 0).sum(0).tolist()},
1117
+ 'labels': {str(Path(k).name): v.tolist() for k, v in zip(dataset.img_files, dataset.labels)}}
1118
+
1119
+ # Save, print and return
1120
+ with open(cache_path.with_suffix('.json'), 'w') as f:
1121
+ json.dump(stats, f) # save stats *.json
1122
  if verbose:
1123
  print(yaml.dump([stats], sort_keys=False, default_flow_style=False))
1124
+ # print(json.dumps(stats, indent=2, sort_keys=False))
1125
  return stats