File size: 3,534 Bytes
357b0b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import matplotlib.pyplot as plt
import nmslib
import numpy as np
import os
import streamlit as st

from PIL import Image
from transformers import CLIPProcessor, FlaxCLIPModel


# Baseline OpenAI CLIP checkpoint. load_model() loads only the CLIPProcessor
# (text tokenization + image preprocessing) from here, not the weights.
BASELINE_MODEL = "openai/clip-vit-base-patch32"
# Fine-tuned CLIP weights used to embed the query image in app().
# Alternative, presumably a local training checkpoint (kept for reference):
# MODEL_PATH = "/home/shared/models/clip-rsicd/bs128x8-lr5e-6-adam/ckpt-1"
MODEL_PATH = "flax-community/clip-rsicd"

# Precomputed image vectors parsed by load_index(): one row per image of the
# form "<filename><TAB><comma-separated floats>". Relative paths resolve
# against the current working directory. Alternatives kept for reference:
# IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-baseline.tsv"
# IMAGE_VECTOR_FILE = "/home/shared/data/vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"
IMAGE_VECTOR_FILE = "./vectors/test-bs128x8-lr5e-6-adam-ckpt-1.tsv"

# Directory from which app() reads both the query image and the retrieved
# images (filenames come from IMAGE_VECTOR_FILE). Alternative for reference:
# IMAGES_DIR = "/home/shared/data/rsicd_images"
IMAGES_DIR = "./images"


@st.cache(allow_output_mutation=True)
def load_index():
    """Build an NMSLib HNSW index over the precomputed image vectors.

    IMAGE_VECTOR_FILE is read as a tab-separated file in which each non-blank
    row is ``<filename><TAB><comma-separated floats>``. The vectors are
    indexed for cosine-similarity search. ``st.cache`` memoizes the result,
    so the file is parsed and the index built once rather than on every
    Streamlit rerun.

    Returns:
        tuple: ``(filenames, index)`` where ``filenames[i]`` names the image
        whose vector is data point ``i`` in ``index``.

    Raises:
        ValueError: if IMAGE_VECTOR_FILE contains no vectors.
    """
    filenames, image_vecs = [], []
    # `with` guarantees the handle is closed, even if parsing fails part-way
    # (the previous bare open() never closed it).
    with open(IMAGE_VECTOR_FILE, "r", encoding="utf-8") as fvec:
        for line in fvec:
            line = line.strip()
            if not line:
                # A blank row would otherwise raise IndexError on cols[1].
                continue
            cols = line.split('\t')
            filenames.append(cols[0])
            image_vecs.append(np.array([float(x) for x in cols[1].split(',')]))
    if not image_vecs:
        # Fail loudly instead of handing nmslib an empty, shapeless batch.
        raise ValueError("no image vectors found in {:s}".format(IMAGE_VECTOR_FILE))
    vectors = np.array(image_vecs)
    index = nmslib.init(method='hnsw', space='cosinesimil')
    # Data point ids default to 0..n-1, matching positions in `filenames`.
    index.addDataPointBatch(vectors)
    index.createIndex({'post': 2}, print_progress=True)
    return filenames, index


@st.cache(allow_output_mutation=True)
def load_model():
    """Load the fine-tuned CLIP model together with its input processor.

    The Flax weights come from MODEL_PATH, while the CLIPProcessor (text
    tokenization and image preprocessing) is taken from the baseline
    BASELINE_MODEL checkpoint. ``st.cache`` memoizes the pair across reruns.

    Returns:
        tuple: ``(model, processor)``.
    """
    weights_source, processor_source = MODEL_PATH, BASELINE_MODEL
    # Tuple elements are evaluated left to right, so the model still loads first.
    return (
        FlaxCLIPModel.from_pretrained(weights_source),
        CLIPProcessor.from_pretrained(processor_source),
    )


def app():
    """Render the image-to-image retrieval page.

    The user enters the file name of an image in IMAGES_DIR. On "Query", that
    image is embedded with the fine-tuned CLIP model and its nearest
    neighbours are looked up in the NMSLib index from load_index(). Up to 10
    matches, excluding the query image itself, are shown in rows of three,
    each captioned with its cosine-similarity score. An empty or unknown
    file name produces an error message instead of a crash.
    """
    filenames, index = load_index()
    model, processor = load_model()

    st.title("Image to Image Retrieval")
    st.markdown("""
        The CLIP model from OpenAI is trained in a self-supervised manner using 
        contrastive learning to project images and caption text onto a common 
        embedding space. We have fine-tuned the model using the RSICD dataset
        (10k images and ~50k captions from the remote sensing domain).
        
        This demo shows the image to image retrieval capabilities of this model, i.e., 
        given an image file name as a query (we suggest copy pasting the file name
        from the result of a text to image query), we use our fine-tuned CLIP model 
        to project the query image to the image/caption embedding space and search 
        for nearby images (by cosine similarity) in this space.
        
        Our fine-tuned CLIP model was previously used to generate image vectors for 
        our demo, and NMSLib was used for fast vector access.
    """)

    num_results = 10     # neighbours to display
    images_per_row = 3   # images per st.image row

    image_file = st.text_input("Image Query (filename):")
    if st.button("Query"):
        # Copy-pasted names often carry stray whitespace; without stripping,
        # the path lookup and the self-match filter below would both miss.
        image_file = image_file.strip()
        if not image_file:
            st.error("Please enter an image file name.")
            return
        image_path = os.path.join(IMAGES_DIR, image_file)
        if not os.path.isfile(image_path):
            st.error("Image file not found: {:s}".format(image_file))
            return

        # Decode with PIL and force RGB. plt.imread returns float arrays for
        # PNGs, which Image.fromarray cannot turn back into an RGB image.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        inputs = processor(images=image, return_tensors="jax")
        # get_image_features returns a (batch, dim) array for a batch of one
        # image; flatten it to the single query vector knnQuery expects.
        query_vec = np.asarray(model.get_image_features(**inputs)).reshape(-1)
        # Request one extra neighbour: if the query image is in the index, it
        # is normally its own nearest match and is skipped below.
        ids, distances = index.knnQuery(query_vec, k=num_results + 1)

        images, captions = [], []
        for idx, distance in zip(ids, distances):
            result_filename = filenames[idx]
            if result_filename == image_file:
                continue
            images.append(
                plt.imread(os.path.join(IMAGES_DIR, result_filename)))
            # NMSLib's 'cosinesimil' distance is 1 - cosine similarity.
            captions.append("{:s} (score: {:.3f})".format(result_filename, 1.0 - distance))
            if len(images) == num_results:
                break  # don't read images that would be discarded

        # Lay results out in rows, emitting no empty st.image call when
        # fewer than num_results neighbours come back.
        for start in range(0, len(images), images_per_row):
            end = start + images_per_row
            st.image(images[start:end], caption=captions[start:end])