# NOTE: The lines below are Hugging Face file-viewer header residue carried
# over when this file was copied from the Space page; kept as a comment so the
# file parses as Python.
# hotchpotch — "Refactor app.py: Update imports, add get_image_url function,
# and optimize search functionality" — commit 90f3fab (raw / history / blame,
# 3.24 kB)
"""
streamlit run app.py --server.address 0.0.0.0
"""
from __future__ import annotations
import os
from time import time
from typing import Literal
import streamlit as st
import torch
from open_clip import create_model_and_transforms, get_tokenizer
from openai import OpenAI
from qdrant_client import QdrantClient
from qdrant_client.http import models
# SPACE_ID is set by the Hugging Face Spaces runtime; when present, point the
# Hugging Face caches at the persistent /data volume of the Space.
if os.getenv("SPACE_ID"):
    USE_HF_SPACE = True
    os.environ["HF_HOME"] = "/data/.huggingface"
    os.environ["HF_DATASETS_CACHE"] = "/data/.huggingface"
else:
    USE_HF_SPACE = False
# Silence the huggingface tokenizers fork-parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# OPENAI_API_KEY is read but not validated here; presumably used elsewhere.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
QDRANT_API_ENDPOINT = os.environ.get("QDRANT_API_ENDPOINT")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
# Base URL for pre-resized site photos; see get_image_url for the layout.
BASE_IMAGE_URL = "https://storage.googleapis.com/secons-site-images/photo/"
# Allowed size buckets for hosted images.
TargetImageType = Literal["xsmall", "small", "medium", "large"]
# Fail fast at import time: the app cannot search without Qdrant credentials.
if not QDRANT_API_ENDPOINT or not QDRANT_API_KEY:
    raise ValueError("env: QDRANT_API_ENDPOINT or QDRANT_API_KEY is not set.")
def get_image_url(image_name: str, image_type: TargetImageType = "xsmall") -> str:
    """Build the public URL of a site photo in the given size bucket (.webp)."""
    return "".join((BASE_IMAGE_URL, image_type, "/", image_name, ".webp"))
@st.cache_resource
def get_model_preprocess_tokenizer(
    target_model: str = "xlm-roberta-base-ViT-B-32",
    pretrained: str = "laion5B-s13B-b90k",
):
    """Load the multilingual CLIP model, its preprocess, and tokenizer.

    Cached via st.cache_resource so the weights are loaded once per process.
    Returns a (model, preprocess, tokenizer) tuple.
    """
    clip_model, _, image_preprocess = create_model_and_transforms(
        target_model, pretrained=pretrained
    )
    return clip_model, image_preprocess, get_tokenizer(target_model)
@st.cache_resource
def get_qdrant_client():
    """Return a process-wide cached Qdrant client for the configured endpoint."""
    return QdrantClient(url=QDRANT_API_ENDPOINT, api_key=QDRANT_API_KEY)
@st.cache_data
def get_text_features(text: str) -> list[float]:
    """Encode *text* with the CLIP text tower into a unit-norm vector.

    The result is L2-normalized (so Qdrant cosine/dot scoring works directly)
    and converted to a plain list of floats, which st.cache_data can hash and
    Qdrant accepts as a query vector.
    """
    # The image preprocess is not needed for text encoding; discard it.
    model, _, tokenizer = get_model_preprocess_tokenizer()
    text_tokenized = tokenizer([text])
    with torch.no_grad():
        text_features = model.encode_text(text_tokenized)  # type: ignore
        # Normalize to unit length along the embedding dimension.
        text_features /= text_features.norm(dim=-1, keepdim=True)
    # Single input -> single vector; convert tensor to a plain Python list.
    return text_features[0].tolist()
def app():
    """Render the search page: text box -> CLIP embedding -> Qdrant lookup."""
    # Touch the model loader up front so the heavy download happens (and is
    # cached) before the first query, not during it.
    _, _, _ = get_model_preprocess_tokenizer()
    st.title("secon.dev site search")
    query = st.text_input("Search", key="search_text")
    if not query:
        return
    st.write("searching...")
    started_at = time()
    client = get_qdrant_client()
    query_vector = get_text_features(query)
    hits = client.search(
        collection_name="images-clip",
        query_vector=query_vector,
        limit=20,
    )
    st.write(f"elapsed: {time() - started_at:.2f} sec")
    st.write(f"total: {len(hits)}")
    for hit in hits:
        payload = hit.payload
        # Payload may be absent on a point; fall back to a placeholder name.
        image_name = payload["name"] if payload else "unknown"
        st.write(f"score: {hit.score:.2f}")
        st.image(get_image_url(image_name, image_type="xsmall"), width=200)
        st.write("---")
if __name__ == "__main__":
    # Page config must be the first Streamlit call; wide layout + site favicon.
    st.set_page_config(
        page_icon="https://secon.dev/images/profile_usa.png",
        layout="wide",
    )
    app()