import gradio as gr
import torch
import pickle
import numpy as np
import pandas as pd
from transformers import CLIPProcessor, CLIPModel
from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor
from sklearn.metrics.pairwise import cosine_similarity
import csv
from PIL import Image

# Model checkpoints and their precomputed image-embedding files
model_path_rclip = "kaveh/rclip"
embeddings_file_rclip = './image_embeddings_rclip.pkl'

model_path_pubmedclip = "flaviagiammarino/pubmed-clip-vit-base-patch32"
embeddings_file_pubmedclip = './image_embeddings_pubmedclip.pkl'

csv_path = "./captions.txt"
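# captions.txt is assumed to hold one tab-separated record per line, i.e.
# "<image_id>\t<caption>", with the corresponding image stored as
# ./images/<image_id>.jpg (this layout is inferred from load_image_ids() and
# main() below, not stated explicitly anywhere in this file).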

def load_image_ids(csv_file):
    """Read the tab-separated captions file and return parallel lists of image IDs and captions."""
    ids = []
    captions = []
    with open(csv_file, 'r') as f:
        reader = csv.reader(f, delimiter='\t')
        for row in reader:
            ids.append(row[0])
            captions.append(row[1])
    return ids, captions

def load_embeddings(embeddings_file):
    """Load the precomputed image embeddings (expected to be an (N, D) array) from a pickle file."""
    with open(embeddings_file, 'rb') as f:
        image_embeddings = pickle.load(f)
    return image_embeddings
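

# A minimal sketch of how such an embeddings pickle could be (re)generated for the
# RCLIP checkpoint. The real preprocessing script is not part of this file, so the
# image paths, one-image-at-a-time batching, and output format are assumptions.
def build_image_embeddings(model_path=model_path_rclip, output_file=embeddings_file_rclip):
    model = VisionTextDualEncoderModel.from_pretrained(model_path)
    processor = VisionTextDualEncoderProcessor.from_pretrained(model_path)
    ids, _ = load_image_ids(csv_path)
    features = []
    with torch.no_grad():
        for image_id in ids:
            image = Image.open(f"./images/{image_id}.jpg")
            inputs = processor(images=image, return_tensors="pt")
            features.append(model.get_image_features(**inputs)[0].numpy())
    with open(output_file, 'wb') as f:
        pickle.dump(np.stack(features), f)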


def find_similar_images(query_embedding, image_embeddings, k=2):
    """Return the indices and cosine-similarity scores of the k images closest to the query."""
    similarities = cosine_similarity(query_embedding.reshape(1, -1), image_embeddings)
    closest_indices = np.argsort(similarities[0])[::-1][:k]
    scores = similarities[0][closest_indices]
    return closest_indices, scores
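# Illustrative usage (shapes are assumptions, not taken from the real data): with a
# (512,) text embedding q and an (N, 512) embedding matrix embs,
# find_similar_images(q, embs, k=5) returns the 5 best row indices and their cosine
# scores, both ordered from most to least similar.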


def main(query, model_id="RCLIP", k=2):
    """Embed the text query, retrieve the k most similar images, and return them with their metadata."""
    k = int(k)  # the Gradio slider may pass the value as a float

    if model_id == "RCLIP":
        # Load the RCLIP model and its precomputed image embeddings
        model = VisionTextDualEncoderModel.from_pretrained(model_path_rclip)
        processor = VisionTextDualEncoderProcessor.from_pretrained(model_path_rclip)
        image_embeddings = load_embeddings(embeddings_file_rclip)
    elif model_id == "PubMedCLIP":
        # Load the PubMedCLIP model and its precomputed image embeddings
        model = CLIPModel.from_pretrained(model_path_pubmedclip)
        processor = CLIPProcessor.from_pretrained(model_path_pubmedclip)
        image_embeddings = load_embeddings(embeddings_file_pubmedclip)
    else:
        raise ValueError(f"Unknown model_id: {model_id}")

    # Embed the query text
    inputs = processor(text=query, images=None, return_tensors="pt", padding=True)
    with torch.no_grad():
        query_embedding = model.get_text_features(**inputs)[0].numpy()

    # Get image IDs and captions
    ids, captions = load_image_ids(csv_path)

    # Find the most similar images
    similar_image_indices, scores = find_similar_images(query_embedding, image_embeddings, k=k)

    # Collect the results
    similar_image_names = [f"./images/{ids[index]}.jpg" for index in similar_image_indices]
    similar_image_captions = [captions[index] for index in similar_image_indices]
    similar_images = [Image.open(i) for i in similar_image_names]

    results = pd.DataFrame({
        "#": range(1, len(similar_image_names) + 1),
        "path": similar_image_names,
        "caption": similar_image_captions,
        "score": scores,
    })
    return similar_images, results
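# Note: main() reloads the selected model and its embeddings on every query, which
# keeps the code simple but is slow. A long-running demo could instead cache both
# models and embedding sets at module level; that would be an optimization on top
# of the original app, not something it already does.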


# Define the Gradio interface
examples = [
            ["Chest X-ray photos", "RCLIP", 10],
            ["Chest X-ray photos", "PubMedCLIP", 10],
            ["Orthopantogram (OPG)", "RCLIP", 10],
            ["Brain MRI", "RCLIP", 10], 
            ["Ultrasound", "RCLIP", 10],
]

title="RCLIP Image Retrieval"
description = "CLIP model fine-tuned on the ROCO dataset"

with gr.Blocks(title=title) as demo:
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown("# "+title)
            gr.Markdown(description)
        
        # gr.Image(Image.open("./data/teesside university logo.png"), height=70, show_label=False, container=False)
    with gr.Row(variant="compact"):
        query = gr.Textbox(value="Chest X-Ray Photos", label="Enter your query", show_label=False, placeholder="Enter your query", scale=5)
        btn = gr.Button("Search query", variant="primary", scale=1)
         
    with gr.Row(variant="compact"):   
        model_id = gr.Dropdown(["RCLIP", "PubMedCLIP"], value="RCLIP", label="Model", type="value", scale=1)
        n_s = gr.Slider(2, 10, label='Number of Top Results', value=10, step=1.0, show_label=True, scale=1)
        

    with gr.Column(variant="compact"):
        gr.Markdown("## Results")
        gallery = gr.Gallery(label="found images", show_label=True, elem_id="gallery", columns=[2], rows=[4], object_fit="contain", height="400px", preview=True)
        gr.Markdown("Information of the found images")
        df = gr.DataFrame()
    btn.click(main, [query, model_id, n_s], [gallery, df])
    
    with gr.Column(variant="compact"):
        gr.Markdown("## Examples")
        gr.Examples(examples, [query, model_id, n_s])
    
    
demo.launch(debug=True)