# Spaces: Build error  (page-status text captured together with this file; not part of the module)
import os | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import streamlit as st | |
from utils import load_index, load_model | |
# Markdown blurb shown under the page title. Kept flush-left so the rendered
# text does not depend on how leading whitespace inside the literal is treated.
_DESCRIPTION = """
This demonstration explores capability of KoCLIP as a Korean-language Image search engine. Embeddings for each of
5000 images from [MSCOCO](https://cocodataset.org/#home) 2017 validation set was generated using trained KoCLIP
vision model. They are ranked based on cosine similarity distance from input Text query embeddings and top 10 images
are displayed below.
KoCLIP is a retraining of OpenAI's CLIP model using 82,783 images from [MSCOCO](https://cocodataset.org/#home) dataset and
Korean caption annotations. Korean translation of caption annotations were obtained from [AI Hub](https://aihub.or.kr/keti_data_board/visual_intelligence).
Base model `koclip` uses `klue/roberta` as text encoder and `openai/clip-vit-base-patch32` as image encoder.
Larger model `koclip-large` uses `klue/roberta` as text encoder and bigger `google/vit-large-patch16-224` as image encoder.
Example Queries : 컴퓨터하는 고양이(Cat playing on a computer), 길 위에서 달리는 자동차(Car running on the road),
"""

_TOP_K = 10  # number of nearest neighbours requested from the index
_IMAGES_PER_ROW = 3  # thumbnails passed to each st.image call
_THUMBNAIL_WIDTH = 200  # pixel width handed to st.image


def app(model_name):
    """Render the Korean text-to-image search page for one KoCLIP model.

    Loads the feature index at ``features/val2017/{model_name}.tsv`` and the
    model/processor pair ``koclip/{model_name}``, draws the title, description
    and query box, and -- once the query button is pressed -- embeds the query
    text, asks the index for its ``_TOP_K`` nearest entries and displays the
    matching images from ``images/val2017`` in rows of ``_IMAGES_PER_ROW``.

    Args:
        model_name: Model identifier used to build both the feature-file path
            and the ``koclip/...`` model name.

    Returns:
        None. All output goes to the Streamlit page.
    """
    images_directory = "images/val2017"
    features_directory = f"features/val2017/{model_name}.tsv"

    files, index = load_index(features_directory)
    model, processor = load_model(f"koclip/{model_name}")

    st.title("Text to Image Search Engine")
    st.markdown(_DESCRIPTION)

    # The Korean literals below were restored from encoding-corrupted text
    # (UTF-8 bytes rendered through a single-byte code page).
    query = st.text_input("한글 질문을 적어주세요 (Korean Text Query) :", value="아파트")
    if not st.button("질문 (Query)"):
        return

    proc = processor(text=[query], images=None, return_tensors="jax", padding=True)
    vec = np.asarray(model.get_text_features(**proc))
    # NOTE(review): knnQuery looks like the nmslib index API returning
    # (ids, distances) -- confirm against utils.load_index.
    ids, dists = index.knnQuery(vec, k=_TOP_K)

    result_imgs, result_captions = [], []
    for idx, dist in zip(ids, dists):
        file = files[idx]
        result_imgs.append(plt.imread(os.path.join(images_directory, file)))
        # The caption shows 1 - distance under the label "유사도" (similarity).
        result_captions.append(f"{file:s} (유사도: {1.0 - dist:.3f})")

    # Emit one st.image call per row. Chunking by the actual number of hits
    # (instead of four fixed slices) avoids passing an empty list when the
    # index returns fewer than _TOP_K results.
    for start in range(0, len(result_imgs), _IMAGES_PER_ROW):
        stop = start + _IMAGES_PER_ROW
        st.image(
            result_imgs[start:stop],
            caption=result_captions[start:stop],
            width=_THUMBNAIL_WIDTH,
        )