""" streamlit run app.py --server.address 0.0.0.0 """ from __future__ import annotations import os from time import time from typing import Literal import streamlit as st import torch from open_clip import create_model_and_transforms, get_tokenizer from openai import OpenAI from qdrant_client import QdrantClient from qdrant_client.http import models if os.getenv("SPACE_ID"): USE_HF_SPACE = True os.environ["HF_HOME"] = "/data/.huggingface" os.environ["HF_DATASETS_CACHE"] = "/data/.huggingface" else: USE_HF_SPACE = False # for tokenizer os.environ["TOKENIZERS_PARALLELISM"] = "false" OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") QDRANT_API_ENDPOINT = os.environ.get("QDRANT_API_ENDPOINT") QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY") BASE_IMAGE_URL = "https://storage.googleapis.com/secons-site-images/photo/" TargetImageType = Literal["xsmall", "small", "medium", "large"] if not QDRANT_API_ENDPOINT or not QDRANT_API_KEY: raise ValueError("env: QDRANT_API_ENDPOINT or QDRANT_API_KEY is not set.") def get_image_url(image_name: str, image_type: TargetImageType = "xsmall") -> str: return f"{BASE_IMAGE_URL}{image_type}/{image_name}.webp" @st.cache_resource def get_model_preprocess_tokenizer( target_model: str = "xlm-roberta-base-ViT-B-32", pretrained: str = "laion5B-s13B-b90k", ): model, _, preprocess = create_model_and_transforms( target_model, pretrained=pretrained ) tokenizer = get_tokenizer(target_model) return model, preprocess, tokenizer @st.cache_resource def get_qdrant_client(): qdrant_client = QdrantClient( url=QDRANT_API_ENDPOINT, api_key=QDRANT_API_KEY, ) return qdrant_client @st.cache_data def get_text_features(text: str): model, preprocess, tokenizer = get_model_preprocess_tokenizer() text_tokenized = tokenizer([text]) with torch.no_grad(): text_features = model.encode_text(text_tokenized) # type: ignore text_features /= text_features.norm(dim=-1, keepdim=True) # tensor to list return text_features[0].tolist() def app(): _, _, _ = get_model_preprocess_tokenizer() # for cache st.title("secon.dev site search") search_text = st.text_input("Search", key="search_text") if search_text: st.write("searching...") start = time() qdrant_client = get_qdrant_client() text_features = get_text_features(search_text) search_results = qdrant_client.search( collection_name="images-clip", query_vector=text_features, limit=20, ) elapsed = time() - start st.write(f"elapsed: {elapsed:.2f} sec") st.write(f"total: {len(search_results)}") for r in search_results: score = r.score if payload := r.payload: name = payload["name"] else: name = "unknown" image_url = get_image_url(name, image_type="xsmall") st.write(f"score: {score:.2f}") st.image(image_url, width=200) st.write("---") if __name__ == "__main__": st.set_page_config( layout="wide", page_icon="https://secon.dev/images/profile_usa.png" ) app()