import csv
import functools
import pickle

import gradio as gr
import numpy as np
import pandas as pd
import torch
from PIL import Image
from sklearn.metrics.pairwise import cosine_similarity
from transformers import CLIPModel, CLIPProcessor
from transformers import VisionTextDualEncoderModel, VisionTextDualEncoderProcessor

# Hugging Face model ids and the precomputed image-embedding pickles that go with them.
model_path_rclip = "kaveh/rclip"
embeddings_file_rclip = './image_embeddings_rclip.pkl'
model_path_pubmedclip = "flaviagiammarino/pubmed-clip-vit-base-patch32"
embeddings_file_pubmedclip = './image_embeddings_pubmedclip.pkl'
# Tab-separated file read by load_image_ids: column 0 = image id, column 1 = caption.
csv_path = "./captions.txt"


def load_image_ids(csv_file):
    """Read a tab-separated file of ``id<TAB>caption`` rows.

    Blank lines are skipped.

    Returns:
        (ids, captions): two parallel lists of strings, in file order.
    """
    ids = []
    captions = []
    # newline='' is what the csv module expects; utf-8 so non-ASCII captions decode
    # the same way on every platform.
    with open(csv_file, 'r', encoding='utf-8', newline='') as f:
        reader = csv.reader(f, delimiter='\t')
        for row in reader:
            if not row:  # e.g. a trailing empty line; row[0] would raise IndexError
                continue
            ids.append(row[0])
            captions.append(row[1])
    return ids, captions


def load_embeddings(embeddings_file):
    """Unpickle and return the object stored in ``embeddings_file``.

    NOTE(review): pickle can execute arbitrary code; only load files you produced.
    """
    with open(embeddings_file, 'rb') as f:
        image_embeddings = pickle.load(f)
    return image_embeddings


def find_similar_images(query_embedding, image_embeddings, k=2):
    """Rank stored image embeddings by cosine similarity to a query embedding.

    Args:
        query_embedding: a single embedding vector (flattened to one row here).
        image_embeddings: 2-D array-like, one embedding per row.
        k: number of best matches to return.

    Returns:
        (closest_indices, scores): row indices of the top ``k`` matches in
        descending similarity order, and the matching similarity values.
    """
    similarities = cosine_similarity(
        np.asarray(query_embedding).reshape(1, -1), image_embeddings
    )[0]
    closest_indices = np.argsort(similarities)[::-1][:k]
    # Read the scores off the same ordering instead of sorting a second time.
    scores = similarities[closest_indices].tolist()
    return closest_indices, scores


@functools.lru_cache(maxsize=None)
def _load_model(model_id):
    """Load and cache (model, processor, image_embeddings) for ``model_id``.

    Raises:
        ValueError: if ``model_id`` is not "RCLIP" or "PubMedCLIP".
    """
    if model_id == "RCLIP":
        model = VisionTextDualEncoderModel.from_pretrained(model_path_rclip)
        processor = VisionTextDualEncoderProcessor.from_pretrained(model_path_rclip)
        image_embeddings = load_embeddings(embeddings_file_rclip)
    elif model_id == "PubMedCLIP":
        model = CLIPModel.from_pretrained(model_path_pubmedclip)
        processor = CLIPProcessor.from_pretrained(model_path_pubmedclip)
        image_embeddings = load_embeddings(embeddings_file_pubmedclip)
    else:
        raise ValueError(
            f"Unknown model_id {model_id!r}; expected 'RCLIP' or 'PubMedCLIP'"
        )
    return model, processor, image_embeddings


@functools.lru_cache(maxsize=None)
def _load_captions(csv_file):
    """Cached wrapper around load_image_ids (the caption file is read once)."""
    return load_image_ids(csv_file)


def _open_image(path):
    """Return a fully-decoded copy of the image at ``path`` and close the file."""
    with Image.open(path) as img:
        return img.copy()


def main(query, model_id="RCLIP", k=2):
    """Retrieve the ``k`` images whose embeddings best match a text query.

    Args:
        query: free-text search string.
        model_id: "RCLIP" or "PubMedCLIP".
        k: number of results. It is coerced with int() because the Gradio
            slider may deliver it as a float.

    Returns:
        (images, table): a list of PIL images, and a DataFrame with
        columns "#", "path", "caption" and "score".
    """
    k = int(k)
    model, processor, image_embeddings = _load_model(model_id)

    # Embed the query with the text tower of the selected model.
    inputs = processor(text=query, images=None, return_tensors="pt", padding=True)
    with torch.no_grad():
        query_embedding = model.get_text_features(**inputs)[0].numpy()

    # NOTE(review): assumes row i of the embeddings pickle corresponds to line i
    # of captions.txt — confirm against the script that produced the pickles.
    ids, captions = _load_captions(csv_path)

    similar_image_indices, scores = find_similar_images(
        query_embedding, image_embeddings, k=k
    )

    similar_image_names = [f"./images/{ids[index]}.jpg" for index in similar_image_indices]
    similar_image_captions = [captions[index] for index in similar_image_indices]
    similar_images = [_open_image(path) for path in similar_image_names]
    # Rank column sized to the hits actually returned (may be < k for small sets).
    ranks = list(range(1, len(similar_image_names) + 1))
    table = pd.DataFrame(
        [ranks, similar_image_names, similar_image_captions, scores],
        index=["#", "path", "caption", "score"],
    ).T
    return similar_images, table


# Define the Gradio interface
examples = [
    ["Chest X-ray photos", "RCLIP", 5],
    ["Chest X-ray photos", "PubMedCLIP", 5],
    ["Orthopantogram (OPG)", "RCLIP", 5],
    ["Brain MRI", "RCLIP", 5],
    ["Ultrasound", "RCLIP", 5],
]

title = "RCLIP Image Retrieval"
description = "CLIP model fine-tuned on the ROCO dataset"

with gr.Blocks(title=title) as demo:
    with gr.Row():
        with gr.Column(scale=5):
            gr.Markdown("# " + title)
            gr.Markdown(description)
        # Disabled logo: Image.open("./data/teesside university logo.png"), height=70, show_label=False, container=False)
    with gr.Row(variant="compact"):
        query = gr.Textbox(value="Chest X-Ray Photos", label="Enter your query", show_label=False, placeholder="Enter your query", scale=5)
        btn = gr.Button("Search query", variant="primary", scale=1)
    with gr.Row(variant="compact"):
        model_id = gr.Dropdown(["RCLIP", "PubMedCLIP"], value="RCLIP", label="Model", type="value", scale=1)
        n_s = gr.Slider(2, 10, label='Number of Top Results', value=5, step=1.0, show_label=True, scale=1)
    with gr.Column(variant="compact"):
        gr.Markdown("## Results")
        gallery = gr.Gallery(label="found images", show_label=True, elem_id="gallery", columns=[2], rows=[4], object_fit="contain", height="400px", preview=True)
        gr.Markdown("Information of the found images")
        df = gr.DataFrame()
    btn.click(main, [query, model_id, n_s], [gallery, df])
    with gr.Column(variant="compact"):
        gr.Markdown("## Examples")
        gr.Examples(examples, [query, model_id, n_s])

if __name__ == "__main__":
    demo.launch(debug=True)