# clip-rsicd-demo / dashboard_text2image.py
import matplotlib.pyplot as plt
import nmslib
import numpy as np
import os
import streamlit as st
from transformers import CLIPProcessor, FlaxCLIPModel
import utils
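
# utils.py (imported above) provides load_index() and load_model(); its code is
# not shown in this file. load_index() is assumed to read a TSV of
# "filename<TAB>vector" rows (the vector stored as comma-separated floats) and
# to return the list of filenames plus an NMSLib index built over the vectors,
# roughly along these lines (a sketch under those assumptions, not the actual
# implementation):
#
#   def load_index(vector_file):
#       filenames, vectors = [], []
#       with open(vector_file, "r") as fin:
#           for line in fin:
#               filename, vec_str = line.strip().split("\t")
#               filenames.append(filename)
#               vectors.append([float(v) for v in vec_str.split(",")])
#       index = nmslib.init(method="hnsw", space="cosinesimil")
#       index.addDataPointBatch(np.asarray(vectors))
#       index.createIndex({"post": 2})
#       return filenames, index
#
# load_model() is assumed to return the fine-tuned FlaxCLIPModel and a
# CLIPProcessor (the second argument suggesting the baseline checkpoint
# supplies the processor).
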
BASELINE_MODEL = "openai/clip-vit-base-patch32"
# MODEL_PATH = "/home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1"
MODEL_PATH = "flax-community/clip-rsicd-v2"
# IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-baseline.tsv"
# IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
# IMAGES_DIR = "/home/shared/data/rsicd_images"
IMAGES_DIR = "./images"
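

# The vectors in IMAGE_VECTOR_FILE were generated offline with the fine-tuned
# CLIP model (see the description rendered by app() below). The helper below is
# a hypothetical reference sketch of that offline step, using the same assumed
# TSV layout as above; it is not called anywhere in this dashboard.
def build_image_vectors(model, processor, image_dir, output_tsv):
    with open(output_tsv, "w") as fout:
        for filename in sorted(os.listdir(image_dir)):
            image = plt.imread(os.path.join(image_dir, filename))
            inputs = processor(images=[image], return_tensors="jax")
            # project the image into the shared CLIP embedding space
            vec = np.asarray(model.get_image_features(**inputs))[0]
            fout.write("{:s}\t{:s}\n".format(
                filename, ",".join("{:.6e}".format(v) for v in vec)))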


def app():
    # load the filenames and NMSLib index over the precomputed image vectors,
    # and the fine-tuned CLIP model together with its processor
    filenames, index = utils.load_index(IMAGE_VECTOR_FILE)
    model, processor = utils.load_model(MODEL_PATH, BASELINE_MODEL)
    st.title("Text to Image Retrieval")
    st.markdown("""
The CLIP model from OpenAI is trained in a self-supervised manner using
contrastive learning to project images and caption text onto a common
embedding space. We have fine-tuned the model using the RSICD dataset
(10k images and ~50k captions from the remote sensing domain).

This demo shows the text to image retrieval capabilities of this model, i.e.,
given a text query, we use our fine-tuned CLIP model to project the text query
to the image/caption embedding space and search for nearby images (by
cosine similarity) in this space.

The image vectors for the demo were generated beforehand with our fine-tuned
CLIP model, and NMSLib is used for fast nearest-neighbor search over them.

Some suggested queries to start you off with -- `ships`, `school house`,
`military installations`, `mountains`, `beaches`, `airports`, `lakes`, etc.
""")
    query = st.text_input("Text Query:")
    if st.button("Query"):
        # project the text query into the shared image/text embedding space
        inputs = processor(text=[query], images=None, return_tensors="jax", padding=True)
        query_vec = model.get_text_features(**inputs)
        query_vec = np.asarray(query_vec)
        # retrieve the 10 nearest images; the score shown is 1 - distance
        # (i.e. cosine similarity, assuming a cosine-distance index)
        ids, distances = index.knnQuery(query_vec, k=10)
        result_filenames = [filenames[i] for i in ids]
        images, captions = [], []
        for result_filename, score in zip(result_filenames, distances):
            images.append(
                plt.imread(os.path.join(IMAGES_DIR, result_filename)))
            captions.append("{:s} (score: {:.3f})".format(result_filename, 1.0 - score))
        # display the top 10 results in rows of (up to) three images
        st.image(images[0:3], caption=captions[0:3])
        st.image(images[3:6], caption=captions[3:6])
        st.image(images[6:9], caption=captions[6:9])
        st.image(images[9:], caption=captions[9:])
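

# Note: this module is assumed to be imported by the demo's top-level Streamlit
# script, which calls app() when the text-to-image retrieval page is selected.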