Spaces:
Runtime error
Runtime error
File size: 1,739 Bytes
0d59440 cf9d9e0 d6c88ae 0d59440 d6c88ae 42f52e5 b54da20 0d59440 d6c88ae e3bc95e d6c88ae e3bc95e dec5315 b047033 dec5315 7318e38 dec5315 0d59440 dec5315 d6c88ae 489b7f2 d6c88ae 489b7f2 d6c88ae e3bc95e 0a6c9ba 482cb2b 82496ef d6c88ae f352d7e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import streamlit as st
# Page chrome. set_page_config stays the very first Streamlit call, ahead of
# every other st.* command in this script.
st.set_page_config(
    page_title="T2I",
    page_icon="🧊",
    layout="centered",
)
st.title("Text To Image Retrieval for KaggleX BPIOC Mentorship Program")
import torch
from transformers import AutoTokenizer, AutoModel
import faiss
import numpy as np
from PIL import Image
from sentence_transformers import SentenceTransformer
import json
import zipfile
# Load the caption file. Its keys are image file names: search() opens each
# one as "Images/<key>" inside Images.zip. The values are presumably the
# captions for that image (format not checked here).
image_map_name = 'captions.json'
# JSON is UTF-8 by specification; don't depend on the platform's default
# locale encoding, which can reject non-ASCII captions.
with open(image_map_name, 'r', encoding='utf-8') as f:
    caption_dict = json.load(f)
# Position i in these lists is used as the row id in the FAISS index, so
# the order must match the rows of the precomputed embedding file
# (assumed — confirm the .npy was built from the same key ordering).
image_list = list(caption_dict.keys())
caption_list = list(caption_dict.values())
# Kept open for the life of the script: image files are read from it on
# every search.
zip_path = "Images.zip"
zip_file = zipfile.ZipFile(zip_path)
model_name = "sentence-transformers/all-distilroberta-v1"

# Streamlit re-executes this whole script on every widget interaction, so
# without caching the transformer and the FAISS index would be rebuilt on
# every click. st.cache_resource keeps one shared copy per process; on
# Streamlit versions that lack it, fall back to the old uncached behaviour.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_model(name):
    """Load the SentenceTransformer used to embed search queries."""
    return SentenceTransformer(name)


@_cache_resource
def _build_index(features_path):
    """Load caption embeddings from *features_path* into an inner-product index.

    The vectors are L2-normalized before being added, so inner-product
    search over them ranks by cosine similarity.

    Returns:
        A ``(vectors, index)`` tuple: the normalized float32 matrix and the
        ``faiss.IndexFlatIP`` holding it.
    """
    # faiss only accepts C-contiguous float32 arrays; this is a no-op when
    # the file is already stored that way.
    feats = np.ascontiguousarray(np.load(features_path), dtype=np.float32)
    faiss.normalize_L2(feats)  # in place
    flat_ip = faiss.IndexFlatIP(feats.shape[1])
    flat_ip.add(feats)
    return feats, flat_ip


model = _load_model(model_name)
# Caption embeddings are loaded from disk rather than computed at startup
# (presumably produced offline with model.encode(caption_list) — confirm).
vectors, index = _build_index("./sbert_text_features.npy")
vector_dimension = vectors.shape[1]
def search(query, k=4):
    """Display the images whose captions best match *query*.

    The query is embedded with the module-level ``model``, L2-normalized and
    searched in the inner-product FAISS ``index``; because the indexed
    caption vectors are normalized at load time, this ranks by cosine
    similarity. Each hit's file is read from ``Images/<image id>`` inside
    ``zip_file`` and rendered with ``st.image`` at width 600.

    Args:
        query: Free-text search string.
        k: Maximum number of images to show.

    Returns:
        The image ids that were displayed, best match first. The list is
        empty when ``k`` is not positive.
    """
    if k <= 0:
        return []

    # faiss wants a 2-D C-contiguous float32 batch; this is a batch of one.
    query_vector = np.asarray([model.encode(query)], dtype=np.float32)
    faiss.normalize_L2(query_vector)

    # IndexFlatIP is an exhaustive search, so there is no nprobe to tune.
    _scores, ids = index.search(query_vector, k)

    shown = []
    for row in ids[0]:
        # faiss pads the result with -1 when the index holds fewer than k
        # vectors; image_list[-1] would otherwise show the wrong image.
        if row < 0:
            continue
        image_id = str(image_list[row])
        # Render while the zip member is open (PIL reads lazily), then close it.
        with zip_file.open(f"Images/{image_id}") as member:
            st.image(Image.open(member), width=600)
        shown.append(image_id)
    return shown
# Search form: run the retrieval only when the button is pressed with a
# non-blank query.
query = st.text_input("Enter your search query here:")
if st.button("Search"):
    cleaned = query.strip()
    if cleaned:
        search(cleaned)
    else:
        # Previously a blank query did nothing, with no feedback.
        st.warning("Please enter a search query.")