Update app.py
app.py
CHANGED
@@ -8,6 +8,7 @@ import time
 import json
 import numpy as np
 import onnxruntime as ort
+from ultralytics import YOLO
 import cv2
 import chromadb
 
@@ -22,13 +23,11 @@ def load_clip_model():
 
 clip_model, preprocess_val, tokenizer, device = load_clip_model()
 
-# Load the ONNX model
 @st.cache_resource
-def 
-
-    return session
+def load_yolo_model():
+    return YOLO("./best.pt")
 
-
+yolo_model = load_yolo_model()
 
 # Load an image from a URL
 def load_image_from_url(url, max_retries=3):
@@ -48,7 +47,6 @@ def load_image_from_url(url, max_retries=3):
 client = chromadb.PersistentClient(path="./accessaryDB")
 collection = client.get_collection(name="accessary_items_ver2")
 
-# Extract a CLIP image embedding
 def get_image_embedding(image):
     image_tensor = preprocess_val(image).unsqueeze(0).to(device)
     with torch.no_grad():
@@ -56,7 +54,6 @@ def get_image_embedding(image):
     image_features /= image_features.norm(dim=-1, keepdim=True)
     return image_features.cpu().numpy()
 
-# Extract a CLIP text embedding
 def get_text_embedding(text):
     text_tokens = tokenizer([text]).to(device)
     with torch.no_grad():
@@ -64,17 +61,14 @@ def get_text_embedding(text):
     text_features /= text_features.norm(dim=-1, keepdim=True)
     return text_features.cpu().numpy()
 
-# Fetch all embeddings from the collection
 def get_all_embeddings_from_collection(collection):
     all_embeddings = collection.get(include=['embeddings'])['embeddings']
     return np.array(all_embeddings)
 
-# Fetch metadata by ID
 def get_metadata_from_ids(collection, ids):
     results = collection.get(ids=ids)
     return results['metadatas']
-
-# Find similar images
+
 def find_similar_images(query_embedding, collection, top_k=5):
     database_embeddings = get_all_embeddings_from_collection(collection)
     similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
@@ -92,40 +86,19 @@ def find_similar_images(query_embedding, collection, top_k=5):
         })
     return results
 
-
-
-
-def preprocess_for_onnx(image, input_size=(640, 640)):
-    resized_image = image.resize(input_size)
-    image_np = np.array(resized_image).astype(np.float32) / 255.0
-    image_np = np.transpose(image_np, (2, 0, 1))
-    input_tensor = np.expand_dims(image_np, axis=0)
-    return input_tensor
-
-# Detect clothing
-def detect_clothing_onnx(image):
-    input_tensor = preprocess_for_onnx(image)  # call the preprocessing helper
-    outputs = onnx_session.run(None, {onnx_session.get_inputs()[0].name: input_tensor})
-
-    detections = outputs[0]  # assume the first output holds the detections
+def detect_clothing(image):
+    results = yolo_model(image)
+    detections = results[0].boxes.data.cpu().numpy()
     categories = []
-
     for detection in detections:
-
-
-
-        # If conf is an array, use its maximum value
-        if isinstance(conf, np.ndarray):
-            conf = np.max(conf)  # maximum confidence in the array
-
-        if conf > 0.3:  # confidence threshold
-            category = onnx_model_labels[int(cls)]
+        x1, y1, x2, y2, conf, cls = detection
+        category = yolo_model.names[int(cls)]
+        if category in ['Bracelets', 'Broches', 'bag', 'belt', 'earring', 'maangtika', 'necklace', 'nose ring', 'ring', 'tiara']:
             categories.append({
                 'category': category,
-                'bbox': [x1, y1, x2, y2],
+                'bbox': [int(x1), int(y1), int(x2), int(y2)],
                 'confidence': conf
             })
-
     return categories
 
 # Crop the image
@@ -143,21 +116,21 @@ if 'selected_category' not in st.session_state:
     st.session_state.selected_category = None
 
 # Streamlit app
-st.title("
+st.title("Accessary Search App")
 
 # Step-by-step processing
 if st.session_state.step == 'input':
     st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
-    if st.button("Detect
+    if st.button("Detect acsseary"):
         if st.session_state.query_image_url:
             query_image = load_image_from_url(st.session_state.query_image_url)
            if query_image is not None:
                 st.session_state.query_image = query_image
-                st.session_state.detections = 
+                st.session_state.detections = detect_clothing(query_image)
                 if st.session_state.detections:
                     st.session_state.step = 'select_category'
                 else:
-                    st.warning("No
+                    st.warning("No items detected in the image.")
             else:
                 st.error("Failed to load the image. Please try another URL.")
         else: