# Refactor app.py: Update imports, add get_image_url function, and optimize search functionality
# 90f3fab
""" | |
streamlit run app.py --server.address 0.0.0.0 | |
""" | |
from __future__ import annotations | |
import os | |
from time import time | |
from typing import Literal | |
import streamlit as st | |
import torch | |
from open_clip import create_model_and_transforms, get_tokenizer | |
from openai import OpenAI | |
from qdrant_client import QdrantClient | |
from qdrant_client.http import models | |
if os.getenv("SPACE_ID"): | |
USE_HF_SPACE = True | |
os.environ["HF_HOME"] = "/data/.huggingface" | |
os.environ["HF_DATASETS_CACHE"] = "/data/.huggingface" | |
else: | |
USE_HF_SPACE = False | |
# for tokenizer | |
os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") | |
QDRANT_API_ENDPOINT = os.environ.get("QDRANT_API_ENDPOINT") | |
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY") | |
BASE_IMAGE_URL = "https://storage.googleapis.com/secons-site-images/photo/" | |
TargetImageType = Literal["xsmall", "small", "medium", "large"] | |
if not QDRANT_API_ENDPOINT or not QDRANT_API_KEY: | |
raise ValueError("env: QDRANT_API_ENDPOINT or QDRANT_API_KEY is not set.") | |
def get_image_url(image_name: str, image_type: TargetImageType = "xsmall") -> str:
    """Build the public URL for an image at the given size variant.

    Args:
        image_name: Image file stem (no extension).
        image_type: Size-variant directory under BASE_IMAGE_URL.

    Returns:
        Full URL of the .webp image.
    """
    return f"{BASE_IMAGE_URL}{image_type}/{image_name}.webp"
@st.cache_resource
def get_model_preprocess_tokenizer(
    target_model: str = "xlm-roberta-base-ViT-B-32",
    pretrained: str = "laion5B-s13B-b90k",
):
    """Load the CLIP model, preprocessing transform, and tokenizer.

    Decorated with ``st.cache_resource`` so the (large) model is loaded
    once per server process instead of on every Streamlit rerun — ``app()``
    calls this up-front precisely to warm this cache, and
    ``get_text_features`` calls it again on every query.

    Args:
        target_model: open_clip model architecture name.
        pretrained: Pretrained weights tag for that architecture.

    Returns:
        Tuple of (model, preprocess, tokenizer).
    """
    model, _, preprocess = create_model_and_transforms(
        target_model, pretrained=pretrained
    )
    tokenizer = get_tokenizer(target_model)
    return model, preprocess, tokenizer
def get_qdrant_client():
    """Create a Qdrant client from the QDRANT_API_* environment settings."""
    return QdrantClient(
        url=QDRANT_API_ENDPOINT,
        api_key=QDRANT_API_KEY,
    )
def get_text_features(text: str):
    """Encode *text* into an L2-normalized CLIP text embedding.

    Args:
        text: Query string to embed.

    Returns:
        The embedding as a plain list of floats, suitable as a Qdrant
        query vector.
    """
    # preprocess (image transform) is unused for text encoding.
    model, _, tokenizer = get_model_preprocess_tokenizer()
    text_tokenized = tokenizer([text])
    with torch.no_grad():  # inference only; no gradients needed
        text_features = model.encode_text(text_tokenized)  # type: ignore
        # Normalize so cosine similarity reduces to a dot product.
        text_features /= text_features.norm(dim=-1, keepdim=True)
    # tensor to list
    return text_features[0].tolist()
def app():
    """Streamlit page: text box -> CLIP text embedding -> Qdrant image search."""
    _, _, _ = get_model_preprocess_tokenizer()  # for cache
    st.title("secon.dev site search")
    search_text = st.text_input("Search", key="search_text")
    if not search_text:
        return  # nothing typed yet; render the empty page
    st.write("searching...")
    start = time()
    qdrant_client = get_qdrant_client()
    text_features = get_text_features(search_text)
    search_results = qdrant_client.search(
        collection_name="images-clip",
        query_vector=text_features,
        limit=20,
    )
    elapsed = time() - start
    st.write(f"elapsed: {elapsed:.2f} sec")
    st.write(f"total: {len(search_results)}")
    for r in search_results:
        score = r.score
        # Payload may be missing on a point; fall back to a placeholder name.
        name = r.payload["name"] if r.payload else "unknown"
        image_url = get_image_url(name, image_type="xsmall")
        st.write(f"score: {score:.2f}")
        st.image(image_url, width=200)
        st.write("---")
if __name__ == "__main__": | |
st.set_page_config( | |
layout="wide", page_icon="https://secon.dev/images/profile_usa.png" | |
) | |
app() | |