JoJosmin committed
Commit ccc944f
Parent: eff6d6b

Update app.py

Files changed (1): app.py (+16, -43)
app.py CHANGED
@@ -8,6 +8,7 @@ import time
 import json
 import numpy as np
 import onnxruntime as ort
+from ultralytics import YOLO
 import cv2
 import chromadb
 
@@ -22,13 +23,11 @@ def load_clip_model():
 
 clip_model, preprocess_val, tokenizer, device = load_clip_model()
 
-# Load the ONNX model
 @st.cache_resource
-def load_onnx_model():
-    session = ort.InferenceSession("./accessary_weights.onnx")
-    return session
+def load_yolo_model():
+    return YOLO("./best.pt")
 
-onnx_session = load_onnx_model()
+yolo_model = load_yolo_model()
 
 # Load an image from a URL
 def load_image_from_url(url, max_retries=3):
@@ -48,7 +47,6 @@ def load_image_from_url(url, max_retries=3):
 client = chromadb.PersistentClient(path="./accessaryDB")
 collection = client.get_collection(name="accessary_items_ver2")
 
-# Extract a CLIP image embedding
 def get_image_embedding(image):
     image_tensor = preprocess_val(image).unsqueeze(0).to(device)
     with torch.no_grad():
@@ -56,7 +54,6 @@ def get_image_embedding(image):
     image_features /= image_features.norm(dim=-1, keepdim=True)
     return image_features.cpu().numpy()
 
-# Extract a CLIP text embedding
 def get_text_embedding(text):
     text_tokens = tokenizer([text]).to(device)
     with torch.no_grad():
@@ -64,17 +61,14 @@ def get_text_embedding(text):
     text_features /= text_features.norm(dim=-1, keepdim=True)
     return text_features.cpu().numpy()
 
-# Fetch all embeddings from the collection
 def get_all_embeddings_from_collection(collection):
     all_embeddings = collection.get(include=['embeddings'])['embeddings']
     return np.array(all_embeddings)
 
-# Fetch metadata by ID
 def get_metadata_from_ids(collection, ids):
     results = collection.get(ids=ids)
     return results['metadatas']
-
-# Find similar images
+
 def find_similar_images(query_embedding, collection, top_k=5):
     database_embeddings = get_all_embeddings_from_collection(collection)
     similarities = np.dot(database_embeddings, query_embedding.T).squeeze()
@@ -92,40 +86,19 @@ def find_similar_images(query_embedding, collection, top_k=5):
         })
     return results
 
-onnx_model_labels = ['Bracelets', 'Broches', 'belt', 'earring', 'maangtika', 'necklace', 'nose ring', 'ring', 'tiara']
-
-# Preprocessing tailored to the ONNX model
-def preprocess_for_onnx(image, input_size=(640, 640)):
-    resized_image = image.resize(input_size)
-    image_np = np.array(resized_image).astype(np.float32) / 255.0
-    image_np = np.transpose(image_np, (2, 0, 1))
-    input_tensor = np.expand_dims(image_np, axis=0)
-    return input_tensor
-
-# Detect clothing
-def detect_clothing_onnx(image):
-    input_tensor = preprocess_for_onnx(image)  # call the preprocessing function
-    outputs = onnx_session.run(None, {onnx_session.get_inputs()[0].name: input_tensor})
-
-    detections = outputs[0]  # assume the first output holds the detections
+def detect_clothing(image):
+    results = yolo_model(image)
+    detections = results[0].boxes.data.cpu().numpy()
     categories = []
-
     for detection in detections:
-        # pull the needed values out of the detection
-        x1, y1, x2, y2, conf, cls = detection[:6]
-
-        # if conf is an array, use its maximum value
-        if isinstance(conf, np.ndarray):
-            conf = np.max(conf)  # highest confidence in the array
-
-        if conf > 0.3:  # confidence threshold
-            category = onnx_model_labels[int(cls)]
+        x1, y1, x2, y2, conf, cls = detection
+        category = yolo_model.names[int(cls)]
+        if category in ['Bracelets', 'Broches', 'bag', 'belt', 'earring', 'maangtika', 'necklace', 'nose ring', 'ring', 'tiara']:
             categories.append({
                 'category': category,
-                'bbox': [x1, y1, x2, y2],
+                'bbox': [int(x1), int(y1), int(x2), int(y2)],
                 'confidence': conf
             })
-
     return categories
 
 # Crop the image
@@ -143,21 +116,21 @@ if 'selected_category' not in st.session_state:
     st.session_state.selected_category = None
 
 # Streamlit app
-st.title("Advanced Fashion Search App")
+st.title("Accessary Search App")
 
 # Step-by-step flow
 if st.session_state.step == 'input':
    st.session_state.query_image_url = st.text_input("Enter image URL:", st.session_state.query_image_url)
-    if st.button("Detect Clothing"):
+    if st.button("Detect acsseary"):
        if st.session_state.query_image_url:
            query_image = load_image_from_url(st.session_state.query_image_url)
            if query_image is not None:
                st.session_state.query_image = query_image
-                st.session_state.detections = detect_clothing_onnx(query_image)
+                st.session_state.detections = detect_clothing(query_image)
                if st.session_state.detections:
                    st.session_state.step = 'select_category'
                else:
-                    st.warning("No clothing items detected in the image.")
+                    st.warning("No items detected in the image.")
            else:
                st.error("Failed to load the image. Please try another URL.")
        else:
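
The new detect_clothing() leans on the ultralytics Results API: results[0].boxes.data is an (N, 6) array with one row per detection, laid out as [x1, y1, x2, y2, confidence, class_id], and yolo_model.names maps class ids back to labels. Below is a minimal standalone sketch of that contract, assuming the ultralytics package is installed and ./best.pt is the checkpoint this commit references; the 0.3 confidence cut is carried over from the removed ONNX path (the new code filters only by category) and is an assumption here, not part of the commit.

from PIL import Image
from ultralytics import YOLO

model = YOLO("./best.pt")                  # detector checkpoint referenced in the diff
results = model(Image.open("query.jpg"))   # hypothetical local test image; YOLO accepts PIL images

# Each row of boxes.data: [x1, y1, x2, y2, confidence, class_id]
for x1, y1, x2, y2, conf, cls in results[0].boxes.data.cpu().numpy():
    if conf > 0.3:  # threshold from the removed ONNX path; the committed code omits it
        print(model.names[int(cls)], float(conf), (int(x1), int(y1), int(x2), int(y2)))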
 
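
Taken together, the flow after this commit is: load the query image from a URL, run YOLO detection, crop the chosen box, embed the crop with CLIP, and search ChromaDB. A hedged sketch of that pipeline using the functions in this file; crop_image() is hypothetical, since the diff shows only the "# Crop the image" comment and not the function body, and the URL is a placeholder.

def crop_image(image, bbox):
    # hypothetical helper: PIL's crop takes an (x1, y1, x2, y2) box,
    # matching the int coordinates stored in detect_clothing()
    return image.crop(tuple(bbox))

query_image = load_image_from_url("https://example.com/look.jpg")  # placeholder URL
for det in detect_clothing(query_image):
    crop = crop_image(query_image, det['bbox'])
    embedding = get_image_embedding(crop)                 # normalized CLIP vector
    matches = find_similar_images(embedding, collection)  # top-5 items from ChromaDB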