Sujit Pal committed on
Commit f58917e
1 Parent(s): 5de821f

fix: changing output format to include caption

Files changed (3)
  1. dashboard_image2image.py +14 -13
  2. dashboard_text2image.py +14 -16
  3. utils.py +15 -0
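The change drops the flat st.image grids and renders each search hit as its own row: a rank number, the retrieved image captioned with its filename and score, and the image's RSICD captions as a bulleted list. Below is a minimal standalone sketch of that row layout; the filename, score, and caption strings are placeholders, and st.beta_columns is the pre-1.0 Streamlit columns API these dashboards already use.

import streamlit as st
from PIL import Image

# Placeholders standing in for one nmslib search hit.
rank = 0
result_filename = "example.jpg"      # placeholder image file on disk
score = 0.123                        # placeholder nmslib distance for this hit
rsicd_captions = ["placeholder caption one.", "placeholder caption two."]

caption = "{:s} (score: {:.3f})".format(result_filename, 1.0 - score)
col1, col2, col3 = st.beta_columns([2, 10, 10])   # narrow rank column, two wide columns
col1.markdown("{:d}.".format(rank + 1))
col2.image(Image.open(result_filename), caption=caption)
col3.markdown("".join("* {:s}\n".format(c) for c in rsicd_captions))
st.markdown("---")                                # horizontal rule between result rows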
dashboard_image2image.py CHANGED
@@ -12,11 +12,9 @@ import utils
 
 BASELINE_MODEL = "openai/clip-vit-base-patch32"
 MODEL_PATH = "flax-community/clip-rsicd-v2"
-
 IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
-
 IMAGES_DIR = "./images"
-
+CAPTIONS_FILE = os.path.join(IMAGES_DIR, "dataset_rsicd.json")
 
 @st.cache(allow_output_mutation=True)
 def load_example_images():
@@ -62,6 +60,7 @@ def download_and_prepare_image(image_url):
 def app():
     filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
     model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
+    image2caption = utils.load_captions(CAPTIONS_FILE)
 
     example_image_list = load_example_images()
 
@@ -150,17 +149,19 @@ def app():
     query_vec = np.asarray(query_vec)
     ids, distances = index.knnQuery(query_vec, k=11)
     result_filenames = [filenames[id] for id in ids]
-    images, captions = [], []
+    rank = 0
     for result_filename, score in zip(result_filenames, distances):
         if image_name is not None and result_filename == image_name:
             continue
-        images.append(
-            plt.imread(os.path.join(IMAGES_DIR, result_filename)))
-        captions.append("{:s} (score: {:.3f})".format(result_filename, 1.0 - score))
-    images = images[0:10]
-    captions = captions[0:10]
-    st.image(images[0:3], caption=captions[0:3])
-    st.image(images[3:6], caption=captions[3:6])
-    st.image(images[6:9], caption=captions[6:9])
-    st.image(images[9:], caption=captions[9:])
+        caption = "{:s} (score: {:.3f})".format(result_filename, 1.0 - score)
+        col1, col2, col3 = st.beta_columns([2, 10, 10])
+        col1.markdown("{:d}.".format(rank + 1))
+        col2.image(Image.open(os.path.join(IMAGES_DIR, result_filename)),
+                   caption=caption)
+        caption_text = []
+        for caption in image2caption[result_filename]:
+            caption_text.append("* {:s}\n".format(caption))
+        col3.markdown("".join(caption_text))
+        rank += 1
+        st.markdown("---")
     suggest_idx = -1
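Both dashboards display the score as 1.0 minus the distance returned by nmslib's knnQuery; the image-to-image dashboard also asks for k=11 and skips the query image itself, which is why it keeps an explicit rank counter instead of using enumerate. A stand-in sketch of the knnQuery / score pattern, assuming a cosine-similarity nmslib index (utils.load_index is not part of this commit, so the space is an assumption) and random placeholder vectors rather than the real CLIP embeddings:

import nmslib
import numpy as np

# Placeholder embeddings; the real dashboards use CLIP image/text vectors.
image_vectors = np.random.rand(100, 512).astype(np.float32)
index = nmslib.init(method="hnsw", space="cosinesimil")   # assumed cosine space
index.addDataPointBatch(image_vectors)
index.createIndex({"post": 2}, print_progress=False)

query_vec = np.random.rand(512).astype(np.float32)        # placeholder query embedding
ids, distances = index.knnQuery(query_vec, k=10)
for rank, (idx, dist) in enumerate(zip(ids, distances)):
    # in a cosine space, distance = 1 - similarity, so 1.0 - distance recovers the similarity
    print("{:d}. vector {:d} (score: {:.3f})".format(rank + 1, idx, 1.0 - dist))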
dashboard_text2image.py CHANGED
@@ -4,25 +4,21 @@ import numpy as np
 import os
 import streamlit as st
 
+from PIL import Image
 from transformers import CLIPProcessor, FlaxCLIPModel
 
 import utils
 
 BASELINE_MODEL = "openai/clip-vit-base-patch32"
-# MODEL_PATH = "/home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1"
 MODEL_PATH = "flax-community/clip-rsicd-v2"
-
-# IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-baseline.tsv"
-# IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
 IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
-
-# IMAGES_DIR = "/home/shared/data/rsicd_images"
 IMAGES_DIR = "./images"
-
+CAPTIONS_FILE = os.path.join(IMAGES_DIR, "dataset_rsicd.json")
 
 def app():
     filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
     model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
+    image2caption = utils.load_captions(CAPTIONS_FILE)
 
     st.title("Retrieve Images given Text")
     st.markdown("""
@@ -78,13 +74,15 @@ def app():
     query_vec = np.asarray(query_vec)
     ids, distances = index.knnQuery(query_vec, k=10)
     result_filenames = [filenames[id] for id in ids]
-    images, captions = [], []
-    for result_filename, score in zip(result_filenames, distances):
-        images.append(
-            plt.imread(os.path.join(IMAGES_DIR, result_filename)))
-        captions.append("{:s} (score: {:.3f})".format(result_filename, 1.0 - score))
-    st.image(images[0:3], caption=captions[0:3])
-    st.image(images[3:6], caption=captions[3:6])
-    st.image(images[6:9], caption=captions[6:9])
-    st.image(images[9:], caption=captions[9:])
+    for rank, (result_filename, score) in enumerate(zip(result_filenames, distances)):
+        caption = "{:s} (score: {:.3f})".format(result_filename, 1.0 - score)
+        col1, col2, col3 = st.beta_columns([2, 10, 10])
+        col1.markdown("{:d}.".format(rank + 1))
+        col2.image(Image.open(os.path.join(IMAGES_DIR, result_filename)),
+                   caption=caption)
+        caption_text = []
+        for caption in image2caption[result_filename]:
+            caption_text.append("* {:s}\n".format(caption))
+        col3.markdown("".join(caption_text))
+        st.markdown("---")
     suggest_idx = -1
utils.py CHANGED
@@ -1,3 +1,4 @@
+import json
 import matplotlib.pyplot as plt
 import nmslib
 import numpy as np
@@ -31,3 +32,17 @@ def load_model(model_path, baseline_model):
     # processor = CLIPProcessor.from_pretrained(baseline_model)
     processor = CLIPProcessor.from_pretrained(model_path)
     return model, processor
+
+
+@st.cache(allow_output_mutation=True)
+def load_captions(caption_file):
+    image2caption = {}
+    with open(caption_file, "r") as fcap:
+        data = json.loads(fcap.read())
+        for image in data["images"]:
+            filename = image["filename"]
+            captions = []
+            for sentence in image["sentences"]:
+                captions.append(sentence["raw"])
+            image2caption[filename] = captions
+    return image2caption
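load_captions relies on the structure of dataset_rsicd.json: an "images" list whose entries carry a "filename" and a list of "sentences", each holding the raw caption text. Below is a stand-in sketch of that structure and of how the returned mapping is consumed, using a hand-made placeholder file rather than the real RSICD annotations.

import json

# Placeholder file mimicking the structure that load_captions reads.
sample = {
    "images": [
        {
            "filename": "example_00001.jpg",
            "sentences": [
                {"raw": "a placeholder caption for this image."},
                {"raw": "another placeholder caption for this image."},
            ],
        }
    ]
}
with open("/tmp/dataset_rsicd_sample.json", "w") as fout:
    json.dump(sample, fout)

# Same parsing as utils.load_captions above, minus the Streamlit cache decorator.
def load_captions(caption_file):
    image2caption = {}
    with open(caption_file, "r") as fcap:
        data = json.loads(fcap.read())
        for image in data["images"]:
            image2caption[image["filename"]] = [s["raw"] for s in image["sentences"]]
    return image2caption

image2caption = load_captions("/tmp/dataset_rsicd_sample.json")
print(image2caption["example_00001.jpg"])   # the list of raw captions for that filename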