Gregor commited on
Commit
6f21e67
1 Parent(s): 7eacff5

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -21,6 +21,7 @@ main_language_values = sorted([[name, code] for code, name in language_names.ite
21
 
22
  babel_imagenet = json.load(open("data/babel_imagenet-298.json", encoding="utf-8"))
23
  babelnet_images = json.load(open("data/images.json", encoding="utf-8"))
 
24
  no_image_idxs = [i for i, imgs in enumerate(babelnet_images) if len(imgs) == 0]
25
  IMG_HEIGHT, IMG_WIDTH = 512, 512
26
 
@@ -95,7 +96,7 @@ def prepare(raw_idx, lang, text_embeddings, class_order, randomize_images):
95
 
96
  img_idx = 0
97
  if randomize_images:
98
- img_idx = np.random.choice(len(babelnet_images[class_idx]))
99
  img_url = babelnet_images[class_idx][img_idx]["url"]
100
  class_labels = babel_imagenet[lang][1] if lang != "EN" else openai_en_classes
101
 
@@ -144,7 +145,7 @@ def reroll(raw_idx, lang, text_embeddings, class_order, randomize_images):
144
 
145
  img_idx = 0
146
  if randomize_images:
147
- img_idx = np.random.choice(len(babelnet_images[class_idx]))
148
  img_url = babelnet_images[class_idx][img_idx]["url"]
149
  class_labels = babel_imagenet[lang][1] if lang != "EN" else openai_en_classes
150
 
 
21
 
22
  babel_imagenet = json.load(open("data/babel_imagenet-298.json", encoding="utf-8"))
23
  babelnet_images = json.load(open("data/images.json", encoding="utf-8"))
24
+ max_image_choices = 10 # Currently up to 30 images but relevance degrades quickly in my experience. Limiting to 10
25
  no_image_idxs = [i for i, imgs in enumerate(babelnet_images) if len(imgs) == 0]
26
  IMG_HEIGHT, IMG_WIDTH = 512, 512
27
 
 
96
 
97
  img_idx = 0
98
  if randomize_images:
99
+ img_idx = np.random.choice(min(len(babelnet_images[class_idx]), max_image_choices))
100
  img_url = babelnet_images[class_idx][img_idx]["url"]
101
  class_labels = babel_imagenet[lang][1] if lang != "EN" else openai_en_classes
102
 
 
145
 
146
  img_idx = 0
147
  if randomize_images:
148
+ img_idx = np.random.choice(min(len(babelnet_images[class_idx]), max_image_choices))
149
  img_url = babelnet_images[class_idx][img_idx]["url"]
150
  class_labels = babel_imagenet[lang][1] if lang != "EN" else openai_en_classes
151