File size: 3,597 Bytes
cdd2f2d
 
 
 
 
 
 
 
90f3fab
cdd2f2d
 
90f3fab
 
cdd2f2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90f3fab
 
 
cdd2f2d
 
 
 
90f3fab
 
 
 
cdd2f2d
90f3fab
 
 
 
cdd2f2d
90f3fab
cdd2f2d
90f3fab
 
cdd2f2d
 
 
 
 
 
 
 
 
 
 
90f3fab
 
 
 
 
 
 
 
 
 
 
cdd2f2d
90f3fab
cdd2f2d
90f3fab
 
 
 
 
 
 
 
 
2a47dd7
90f3fab
 
 
 
2a47dd7
 
90f3fab
 
 
 
 
 
 
2a47dd7
 
 
 
 
 
 
 
 
 
 
cdd2f2d
 
 
90f3fab
 
 
cdd2f2d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
streamlit run app.py --server.address 0.0.0.0
"""

from __future__ import annotations

import os
from time import time
from typing import Literal

import streamlit as st
import torch
from open_clip import create_model_and_transforms, get_tokenizer
from openai import OpenAI
from qdrant_client import QdrantClient
from qdrant_client.http import models

if os.getenv("SPACE_ID"):
    USE_HF_SPACE = True
    os.environ["HF_HOME"] = "/data/.huggingface"
    os.environ["HF_DATASETS_CACHE"] = "/data/.huggingface"
else:
    USE_HF_SPACE = False

# for tokenizer
os.environ["TOKENIZERS_PARALLELISM"] = "false"

OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
QDRANT_API_ENDPOINT = os.environ.get("QDRANT_API_ENDPOINT")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")

BASE_IMAGE_URL = "https://storage.googleapis.com/secons-site-images/photo/"
TargetImageType = Literal["xsmall", "small", "medium", "large"]

if not QDRANT_API_ENDPOINT or not QDRANT_API_KEY:
    raise ValueError("env: QDRANT_API_ENDPOINT or QDRANT_API_KEY is not set.")


def get_image_url(image_name: str, image_type: TargetImageType = "xsmall") -> str:
    """Build the public URL of a stored photo at the requested size variant."""
    return f"{BASE_IMAGE_URL}{image_type}/{image_name}.webp"


@st.cache_resource
def get_model_preprocess_tokenizer(
    target_model: str = "xlm-roberta-base-ViT-B-32",
    pretrained: str = "laion5B-s13B-b90k",
):
    """Load the CLIP model, its preprocessing transform, and tokenizer.

    Cached by Streamlit so the (expensive) model load happens once per process.
    """
    model, _, preprocess = create_model_and_transforms(
        target_model,
        pretrained=pretrained,
    )
    tokenizer = get_tokenizer(target_model)
    return model, preprocess, tokenizer


@st.cache_resource
def get_qdrant_client():
    """Return a process-wide Qdrant client built from env-configured credentials."""
    return QdrantClient(url=QDRANT_API_ENDPOINT, api_key=QDRANT_API_KEY)


@st.cache_data
def get_text_features(text: str) -> list[float]:
    """Encode *text* into an L2-normalized CLIP embedding.

    Cached by Streamlit per input string, so repeated searches for the same
    query skip the model forward pass.

    Args:
        text: The search query to embed.

    Returns:
        The embedding as a plain list of floats (JSON-serializable, suitable
        as a Qdrant query vector).
    """
    # Fix: the original unpacked `preprocess` but never used it; bind to `_`.
    model, _, tokenizer = get_model_preprocess_tokenizer()
    text_tokenized = tokenizer([text])
    with torch.no_grad():  # inference only — no gradients needed
        text_features = model.encode_text(text_tokenized)  # type: ignore
        # Normalize so dot-product similarity equals cosine similarity.
        text_features /= text_features.norm(dim=-1, keepdim=True)
    # Convert the (1, dim) tensor's single row to a plain Python list.
    return text_features[0].tolist()


def app():
    """Render the search page: embed the query text and show matching images."""
    # Touch the cached loader up front so the first search doesn't pay for it.
    get_model_preprocess_tokenizer()
    st.title("secon.dev site search")

    search_text = st.text_input("Search", key="search_text")
    if not search_text:
        return

    st.write("searching...")
    start = time()
    vector = get_text_features(search_text)
    hits = get_qdrant_client().search(
        collection_name="images-clip",
        query_vector=vector,
        limit=50,
    )
    elapsed = time() - start
    st.write(f"elapsed: {elapsed:.2f} sec")
    st.write(f"total: {len(hits)}")

    # Build parallel url/caption lists; a missing payload falls back to "unknown".
    images: list[str] = []
    captions: list[str] = []
    for hit in hits:
        payload = hit.payload
        name = payload["name"] if payload else "unknown"
        images.append(get_image_url(name, image_type="xsmall"))
        captions.append(f"{name} ({hit.score:.4f})")

    # Display results in rows of six thumbnails.
    group_size = 6
    for offset in range(0, len(images), group_size):
        st.image(
            images[offset : offset + group_size],
            caption=captions[offset : offset + group_size],
            width=160,
        )


if __name__ == "__main__":
    st.set_page_config(
        layout="wide", page_icon="https://secon.dev/images/profile_usa.png"
    )
    app()