# vojtam's picture
# Update app.py
# 48a1a00 verified
import pickle
from functools import lru_cache

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from datasets import load_dataset
from PIL import Image
from torch import nn
from transformers import CLIPProcessor, CLIPModel
model_checkpoint = "openai/clip-vit-base-patch32"
def get_clip_embeddings(input_data, input_type='text'):
# Load the CLIP model and processor
model = CLIPModel.from_pretrained(model_checkpoint)
processor = CLIPProcessor.from_pretrained(model_checkpoint)
# Prepare the input based on the type
if input_type == 'text':
inputs = processor(text=input_data, return_tensors="pt", padding=True, truncation=True)
elif input_type == 'image':
if isinstance(input_data, str):
image = Image.open(input_data)
elif isinstance(input_data, Image.Image):
image = input_data
else:
raise ValueError("For image input, provide either a file path or a PIL Image object")
inputs = processor(images=image, return_tensors="pt")
else:
raise ValueError("Invalid input_type. Choose 'text' or 'image'")
# Get the embeddings
with torch.no_grad():
if input_type == 'text':
embeddings = model.get_text_features(**inputs)
else:
embeddings = model.get_image_features(**inputs)
return embeddings.numpy()
# Cosine similarity computed along dim 1, used to score the stored image
# embeddings against a query embedding.
cos = nn.CosineSimilarity(dim=1, eps=1e-6)

# Image dataset; search results index into its 'train' split.
veggies = load_dataset('vojtam/vegetables')

# Precomputed image embeddings, presumably one row per image of
# veggies['train'] in the same order -- TODO confirm against the script that
# produced the pickle.
# NOTE(review): pickle.load can execute arbitrary code. This is acceptable only
# because the file ships with the app; never load pickles from user input.
with open('img_embeddings.pkl', mode='rb') as file:
    img_embeddings = pickle.load(file)
def get_similar_images(text, n=4):
if text:
text_embedding = get_clip_embeddings(text, input_type='text')
sims = cos(torch.tensor(text_embedding), torch.tensor(img_embeddings))
top_n = np.argsort(np.array(sims))[::-1][:n]
imgs = []
for index in top_n:
imgs.append(veggies['train'][index.item()]['image'])
return imgs
return []
# Custom styles: the gallery takes the viewport height minus 250px and scrolls
# vertically when it overflows; the submit button gets an orange background.
css = """
.full-height-gallery {
height: calc(100vh - 250px);
overflow-y: auto;
}
#submit-btn {
background-color: #ff5b00;
color: #ffffff;
}
"""


def _reset_search():
    """Return cleared values for the (text_input, gallery) outputs."""
    return [None, []]


with gr.Blocks(css=css) as intf:
    # Layout, top to bottom: query box, action buttons, results gallery.
    with gr.Row():
        text_input = gr.Textbox(
            label="Enter the description of the images you want to search for",
            placeholder='Your text goes here',
        )
    with gr.Row():
        submit_btn = gr.Button(value="Submit", elem_id="submit-btn")
        clear_btn = gr.Button(value="Clear")
    with gr.Row():
        gallery = gr.Gallery(
            label="Similar Images",
            show_label=False,
            elem_classes=["full-height-gallery"],
        )

    # Event wiring: search on Submit; empty both the query and the gallery on Clear.
    submit_btn.click(fn=get_similar_images, inputs=text_input, outputs=gallery)
    clear_btn.click(fn=_reset_search, outputs=[text_input, gallery])

intf.launch(share=True)