patrickvonplaten committed on
Commit
273867b
1 Parent(s): 3d59359
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -1,5 +1,5 @@
1
  from datasets import load_dataset
2
- from collections import Counter
3
  from random import sample, shuffle
4
  import datasets
5
  from pandas import DataFrame
@@ -65,10 +65,19 @@ def start():
65
 
66
  # sort by count
67
  ids = sorted(ids.items(), key=lambda x: x[1])
68
- ids = [i[0] for i in ids]
 
 
 
 
 
 
 
 
 
69
 
70
  # get lowest count ids
71
- id_candidates = ids[: (10 * NUM_QUESTIONS)]
72
 
73
  # get random `NUM_QUESTIONS` ids to check
74
  image_ids = sample(id_candidates, k=NUM_QUESTIONS)
 
1
  from datasets import load_dataset
2
+ from collections import Counter, defaultdict
3
  from random import sample, shuffle
4
  import datasets
5
  from pandas import DataFrame
 
65
 
66
  # sort by count
67
  ids = sorted(ids.items(), key=lambda x: x[1])
68
+ freq_ids = defaultdict(list)
69
+ for k, v in ids:
70
+ freq_ids[v].append(k)
71
+
72
+ # shuffle in-between categories
73
+ for k, v_list in freq_ids.items():
74
+ shuffle(v_list)
75
+ freq_ids[v] = v_list
76
+
77
+ shuffled_ids = sum(list(freq_ids.values()), [])
78
 
79
  # get lowest count ids
80
+ id_candidates = shuffled_ids[: (10 * NUM_QUESTIONS)]
81
 
82
  # get random `NUM_QUESTIONS` ids to check
83
  image_ids = sample(id_candidates, k=NUM_QUESTIONS)