hotchpotch commited on
Commit
90f3fab
1 Parent(s): cdd2f2d

Refactor app.py: Update imports, add get_image_url function, and optimize search functionality

Browse files
Files changed (1) hide show
  1. app.py +56 -6
app.py CHANGED
@@ -6,11 +6,11 @@ from __future__ import annotations
6
 
7
  import os
8
  from time import time
 
9
 
10
- import faiss
11
- import pandas as pd
12
  import streamlit as st
13
- from open_clip import create_model_and_transforms
 
14
  from openai import OpenAI
15
  from qdrant_client import QdrantClient
16
  from qdrant_client.http import models
@@ -29,16 +29,27 @@ OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
29
  QDRANT_API_ENDPOINT = os.environ.get("QDRANT_API_ENDPOINT")
30
  QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
31
 
 
 
 
32
  if not QDRANT_API_ENDPOINT or not QDRANT_API_KEY:
33
  raise ValueError("env: QDRANT_API_ENDPOINT or QDRANT_API_KEY is not set.")
34
 
35
 
 
 
 
 
36
  @st.cache_resource
37
- def get_model_preprocess():
 
 
 
38
  model, _, preprocess = create_model_and_transforms(
39
- "xlm-roberta-base-ViT-B-32", pretrained="laion5B-s13B-b90k"
40
  )
41
- return model, preprocess
 
42
 
43
 
44
  @st.cache_resource
@@ -50,9 +61,48 @@ def get_qdrant_client():
50
  return qdrant_client
51
 
52
 
 
 
 
 
 
 
 
 
 
 
 
53
  def app():
 
54
  st.title("secon.dev site search")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
 
57
  if __name__ == "__main__":
 
 
 
58
  app()
 
6
 
7
  import os
8
  from time import time
9
+ from typing import Literal
10
 
 
 
11
  import streamlit as st
12
+ import torch
13
+ from open_clip import create_model_and_transforms, get_tokenizer
14
  from openai import OpenAI
15
  from qdrant_client import QdrantClient
16
  from qdrant_client.http import models
 
29
  QDRANT_API_ENDPOINT = os.environ.get("QDRANT_API_ENDPOINT")
30
  QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
31
 
32
+ BASE_IMAGE_URL = "https://storage.googleapis.com/secons-site-images/photo/"
33
+ TargetImageType = Literal["xsmall", "small", "medium", "large"]
34
+
35
  if not QDRANT_API_ENDPOINT or not QDRANT_API_KEY:
36
  raise ValueError("env: QDRANT_API_ENDPOINT or QDRANT_API_KEY is not set.")
37
 
38
 
39
+ def get_image_url(image_name: str, image_type: TargetImageType = "xsmall") -> str:
40
+ return f"{BASE_IMAGE_URL}{image_type}/{image_name}.webp"
41
+
42
+
43
  @st.cache_resource
44
+ def get_model_preprocess_tokenizer(
45
+ target_model: str = "xlm-roberta-base-ViT-B-32",
46
+ pretrained: str = "laion5B-s13B-b90k",
47
+ ):
48
  model, _, preprocess = create_model_and_transforms(
49
+ target_model, pretrained=pretrained
50
  )
51
+ tokenizer = get_tokenizer(target_model)
52
+ return model, preprocess, tokenizer
53
 
54
 
55
  @st.cache_resource
 
61
  return qdrant_client
62
 
63
 
64
+ @st.cache_data
65
+ def get_text_features(text: str):
66
+ model, preprocess, tokenizer = get_model_preprocess_tokenizer()
67
+ text_tokenized = tokenizer([text])
68
+ with torch.no_grad():
69
+ text_features = model.encode_text(text_tokenized) # type: ignore
70
+ text_features /= text_features.norm(dim=-1, keepdim=True)
71
+ # tensor to list
72
+ return text_features[0].tolist()
73
+
74
+
75
  def app():
76
+ _, _, _ = get_model_preprocess_tokenizer() # for cache
77
  st.title("secon.dev site search")
78
+ search_text = st.text_input("Search", key="search_text")
79
+ if search_text:
80
+ st.write("searching...")
81
+ start = time()
82
+ qdrant_client = get_qdrant_client()
83
+ text_features = get_text_features(search_text)
84
+ search_results = qdrant_client.search(
85
+ collection_name="images-clip",
86
+ query_vector=text_features,
87
+ limit=20,
88
+ )
89
+ elapsed = time() - start
90
+ st.write(f"elapsed: {elapsed:.2f} sec")
91
+ st.write(f"total: {len(search_results)}")
92
+ for r in search_results:
93
+ score = r.score
94
+ if payload := r.payload:
95
+ name = payload["name"]
96
+ else:
97
+ name = "unknown"
98
+ image_url = get_image_url(name, image_type="xsmall")
99
+ st.write(f"score: {score:.2f}")
100
+ st.image(image_url, width=200)
101
+ st.write("---")
102
 
103
 
104
  if __name__ == "__main__":
105
+ st.set_page_config(
106
+ layout="wide", page_icon="https://secon.dev/images/profile_usa.png"
107
+ )
108
  app()