Spaces:
Sleeping
Sleeping
import functools
import pickle

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from torch import nn
from transformers import CLIPProcessor, CLIPModel
model_checkpoint = "openai/clip-vit-base-patch32"


@functools.lru_cache(maxsize=None)
def _load_clip(checkpoint):
    """Load the CLIP model and processor for *checkpoint* once and reuse them.

    Loading the weights is the expensive part of an embedding request, so it
    is memoized per checkpoint instead of being repeated on every call.
    """
    model = CLIPModel.from_pretrained(checkpoint)
    processor = CLIPProcessor.from_pretrained(checkpoint)
    return model, processor


def _as_pil_image(input_data):
    """Return *input_data* as a PIL image, opening it first if it is a path string."""
    if isinstance(input_data, str):
        # Force the (lazy) pixel load and close the file handle immediately;
        # copy() keeps the original image mode unchanged.
        with Image.open(input_data) as img:
            return img.copy()
    if isinstance(input_data, Image.Image):
        return input_data
    raise ValueError("For image input, provide either a file path or a PIL Image object")


def get_clip_embeddings(input_data, input_type='text'):
    """Embed text or an image with the CLIP checkpoint named by ``model_checkpoint``.

    Args:
        input_data: For ``input_type='text'``, the text handed to the CLIP
            processor. For ``input_type='image'``, a file path (``str``) or a
            ``PIL.Image.Image``.
        input_type: ``'text'`` or ``'image'``.

    Returns:
        A NumPy array holding the output of ``get_text_features`` or
        ``get_image_features``. NOTE(review): presumably shaped
        (batch, embed_dim) — confirm against the transformers version in use.

    Raises:
        ValueError: if ``input_type`` is not ``'text'`` or ``'image'``, or if an
            image input is neither a path string nor a PIL image.
    """
    # Check everything that does not need the model first, so bad calls fail
    # fast instead of after a model load.
    if input_type == 'image':
        image = _as_pil_image(input_data)
    elif input_type != 'text':
        raise ValueError("Invalid input_type. Choose 'text' or 'image'")

    model, processor = _load_clip(model_checkpoint)

    if input_type == 'text':
        inputs = processor(text=input_data, return_tensors="pt", padding=True, truncation=True)
        encode = model.get_text_features
    else:
        inputs = processor(images=image, return_tensors="pt")
        encode = model.get_image_features

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        embeddings = encode(**inputs)
    return embeddings.numpy()
# Cosine similarity taken along dim 1; used to score a query embedding
# against the stored image embeddings.
cos = nn.CosineSimilarity(dim=1, eps=1e-6)

# Image dataset from the Hugging Face Hub; search results are read from its
# 'train' split.
veggies = load_dataset('vojtam/vegetables')

# Precomputed image embeddings, read from 'img_embeddings.pkl' in the current
# working directory. pickle.load can execute arbitrary code, so this file must
# be trusted. NOTE(review): positions are used as row indices into
# veggies['train'], so the two presumably share an order — confirm.
with open('img_embeddings.pkl', mode='rb') as file:
    img_embeddings = pickle.load(file)
def get_similar_images(text, n=4):
    """Return up to ``n`` dataset images that best match the description ``text``.

    The query is embedded with ``get_clip_embeddings`` and compared with the
    precomputed ``img_embeddings`` via ``cos``. The highest-scoring positions
    are then fetched from ``veggies['train']``.

    Args:
        text: Search description. ``None``, empty, or whitespace-only input
            returns no results.
        n: Maximum number of images to return; values <= 0 return nothing.

    Returns:
        A list of images ordered from most to least similar.
    """
    if not text or not text.strip():
        return []

    query = torch.as_tensor(get_clip_embeddings(text, input_type='text'))
    # as_tensor reuses the stored array's memory when it can, instead of
    # copying the whole embedding matrix on every search like torch.tensor.
    sims = cos(query, torch.as_tensor(img_embeddings))

    # topk returns the best scores already sorted in descending order;
    # clamp k so asking for more images than exist does not raise.
    k = max(0, min(n, len(sims)))
    best = torch.topk(sims, k).indices.tolist()

    train = veggies['train']
    return [train[i]['image'] for i in best]
# Custom stylesheet handed to gr.Blocks(css=...). Elements with the
# "full-height-gallery" class take the viewport height minus 250px and scroll
# vertically. The element with id "submit-btn" gets an orange (#ff5b00)
# background with white text.
css = """
.full-height-gallery {
height: calc(100vh - 250px);
overflow-y: auto;
}
#submit-btn {
background-color: #ff5b00;
color: #ffffff;
}
"""
# Search UI: a query textbox, Submit/Clear buttons, and a results gallery.
with gr.Blocks(css=css) as intf:
    with gr.Row():
        text_input = gr.Textbox(
            label="Enter the description of the images you want to search for",
            placeholder='Your text goes here',
        )
    with gr.Row():
        submit_btn = gr.Button("Submit", elem_id="submit-btn")
        clear_btn = gr.Button("Clear")
    with gr.Row():
        gallery = gr.Gallery(
            label="Similar Images",
            show_label=False,
            elem_classes=["full-height-gallery"],
        )

    # Run the search on a button click and also when Enter is pressed in the
    # textbox (previously Enter did nothing).
    submit_btn.click(fn=get_similar_images, inputs=text_input, outputs=gallery)
    text_input.submit(fn=get_similar_images, inputs=text_input, outputs=gallery)
    # Clear the query text and empty the gallery.
    clear_btn.click(fn=lambda: [None, []], inputs=None, outputs=[text_input, gallery])

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    intf.launch(share=True)