Spaces:
Runtime error
Runtime error
Vivien
commited on
Commit
•
563e3ef
1
Parent(s):
8d4b675
Initial commit
Browse files- README.md +2 -2
- app.py +294 -0
- data.csv +0 -0
- data2.csv +0 -0
- embeddings-vit-base-patch32.npy +3 -0
- embeddings-vit-large-patch14-336.npy +3 -0
- embeddings2-vit-base-patch32.npy +3 -0
- embeddings2-vit-large-patch14-336.npy +3 -0
- requirements.txt +7 -0
README.md
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
colorFrom: indigo
|
5 |
colorTo: red
|
6 |
sdk: streamlit
|
|
|
1 |
---
|
2 |
+
title: Search and Detect (CLIP/Owl-ViT)
|
3 |
+
emoji: 🦉
|
4 |
colorFrom: indigo
|
5 |
colorTo: red
|
6 |
sdk: streamlit
|
app.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from html import escape
|
2 |
+
import requests
|
3 |
+
from io import BytesIO
|
4 |
+
import base64
|
5 |
+
from multiprocessing.dummy import Pool
|
6 |
+
from PIL import Image, ImageDraw
|
7 |
+
import streamlit as st
|
8 |
+
import pandas as pd, numpy as np
|
9 |
+
import torch
|
10 |
+
from transformers import CLIPProcessor, CLIPModel
|
11 |
+
from transformers import OwlViTProcessor, OwlViTForObjectDetection
|
12 |
+
from transformers.image_utils import ImageFeatureExtractionMixin
|
13 |
+
import tokenizers
|
14 |
+
|
15 |
+
DEBUG = True
|
16 |
+
if DEBUG:
|
17 |
+
MODEL = "vit-base-patch32"
|
18 |
+
OWL_MODEL = f"google/owlvit-base-patch32"
|
19 |
+
else:
|
20 |
+
MODEL = "vit-large-patch14-336"
|
21 |
+
OWL_MODEL = f"google/owlvit-large-path14"
|
22 |
+
CLIP_MODEL = f"openai/clip-{MODEL}"
|
23 |
+
|
24 |
+
if not DEBUG and torch.cuda.is_available():
|
25 |
+
device = torch.device("cuda")
|
26 |
+
else:
|
27 |
+
device = torch.device("cpu")
|
28 |
+
|
29 |
+
HEIGHT = 200
|
30 |
+
N_RESULTS = 6
|
31 |
+
|
32 |
+
color = st.get_option("theme.primaryColor")
|
33 |
+
if color is None:
|
34 |
+
color = (255, 75, 75)
|
35 |
+
else:
|
36 |
+
color = tuple(int(color.lstrip("#")[i : i + 2], 16) for i in (0, 2, 4))
|
37 |
+
|
38 |
+
|
39 |
+
@st.cache(allow_output_mutation=True)
|
40 |
+
def load():
|
41 |
+
df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
|
42 |
+
clip_model = CLIPModel.from_pretrained(CLIP_MODEL)
|
43 |
+
clip_model.to(device)
|
44 |
+
clip_model.eval()
|
45 |
+
clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL)
|
46 |
+
owl_model = OwlViTForObjectDetection.from_pretrained(OWL_MODEL)
|
47 |
+
owl_model.to(device)
|
48 |
+
owl_model.eval()
|
49 |
+
owl_processor = OwlViTProcessor.from_pretrained(OWL_MODEL)
|
50 |
+
embeddings = {
|
51 |
+
0: np.load(f"embeddings-{MODEL}.npy"),
|
52 |
+
1: np.load(f"embeddings2-{MODEL}.npy"),
|
53 |
+
}
|
54 |
+
for k in [0, 1]:
|
55 |
+
embeddings[k] = embeddings[k] / np.linalg.norm(
|
56 |
+
embeddings[k], axis=1, keepdims=True
|
57 |
+
)
|
58 |
+
return clip_model, clip_processor, owl_model, owl_processor, df, embeddings
|
59 |
+
|
60 |
+
|
61 |
+
clip_model, clip_processor, owl_model, owl_processor, df, embeddings = load()
|
62 |
+
mixin = ImageFeatureExtractionMixin()
|
63 |
+
source = {0: "\nSource: Unsplash", 1: "\nSource: The Movie Database (TMDB)"}
|
64 |
+
|
65 |
+
|
66 |
+
def compute_text_embeddings(list_of_strings):
|
67 |
+
inputs = clip_processor(text=list_of_strings, return_tensors="pt", padding=True).to(
|
68 |
+
device
|
69 |
+
)
|
70 |
+
with torch.no_grad():
|
71 |
+
result = clip_model.get_text_features(**inputs).detach().cpu().numpy()
|
72 |
+
return result / np.linalg.norm(result, axis=1, keepdims=True)
|
73 |
+
|
74 |
+
|
75 |
+
def image_search(query, corpus, n_results=N_RESULTS):
|
76 |
+
query_embedding = compute_text_embeddings([query])
|
77 |
+
corpus_id = 0 if corpus == "Unsplash" else 1
|
78 |
+
dot_product = (embeddings[corpus_id] @ query_embedding.T)[:, 0]
|
79 |
+
results = np.argsort(dot_product)[-1 : -n_results - 1 : -1]
|
80 |
+
return [
|
81 |
+
(
|
82 |
+
df[corpus_id].iloc[i].path,
|
83 |
+
df[corpus_id].iloc[i].tooltip + source[corpus_id],
|
84 |
+
df[corpus_id].iloc[i].link,
|
85 |
+
)
|
86 |
+
for i in results
|
87 |
+
]
|
88 |
+
|
89 |
+
|
90 |
+
def make_square(img, fill_color=(255, 255, 255)):
|
91 |
+
x, y = img.size
|
92 |
+
size = max(x, y)
|
93 |
+
new_img = Image.new("RGB", (size, size), fill_color)
|
94 |
+
new_img.paste(img, (int((size - x) / 2), int((size - y) / 2)))
|
95 |
+
return new_img, x, y
|
96 |
+
|
97 |
+
|
98 |
+
@st.cache(allow_output_mutation=True, show_spinner=False)
|
99 |
+
def get_images(paths):
|
100 |
+
def process_image(path):
|
101 |
+
return make_square(Image.open(BytesIO(requests.get(path).content)))
|
102 |
+
|
103 |
+
processed = Pool(N_RESULTS).map(process_image, paths)
|
104 |
+
imgs, xs, ys = [], [], []
|
105 |
+
for img, x, y in processed:
|
106 |
+
imgs.append(img)
|
107 |
+
xs.append(x)
|
108 |
+
ys.append(y)
|
109 |
+
return imgs, xs, ys
|
110 |
+
|
111 |
+
|
112 |
+
@st.cache(
|
113 |
+
hash_funcs={
|
114 |
+
tokenizers.Tokenizer: lambda x: None,
|
115 |
+
tokenizers.AddedToken: lambda x: None,
|
116 |
+
torch.nn.parameter.Parameter: lambda x: None,
|
117 |
+
},
|
118 |
+
allow_output_mutation=True,
|
119 |
+
show_spinner=False,
|
120 |
+
)
|
121 |
+
def apply_owl_model(owl_queries, images):
|
122 |
+
inputs = owl_processor(text=owl_queries, images=images, return_tensors="pt").to(
|
123 |
+
device
|
124 |
+
)
|
125 |
+
with torch.no_grad():
|
126 |
+
results = owl_model(**inputs)
|
127 |
+
target_sizes = torch.Tensor([img.size[::-1] for img in images]).to(device)
|
128 |
+
return owl_processor.post_process(outputs=results, target_sizes=target_sizes)
|
129 |
+
|
130 |
+
|
131 |
+
def keep_best_boxes(boxes, scores, score_threshold=0.1, max_iou=0.8):
|
132 |
+
candidates = []
|
133 |
+
for box, score in zip(boxes, scores):
|
134 |
+
box = [round(i, 0) for i in box.tolist()]
|
135 |
+
if score >= score_threshold:
|
136 |
+
candidates.append((box, float(score)))
|
137 |
+
|
138 |
+
to_ignore = set()
|
139 |
+
for i in range(len(candidates) - 1):
|
140 |
+
if i in to_ignore:
|
141 |
+
continue
|
142 |
+
for j in range(i + 1, len(candidates)):
|
143 |
+
if j in to_ignore:
|
144 |
+
continue
|
145 |
+
xmin1, ymin1, xmax1, ymax1 = candidates[i][0]
|
146 |
+
xmin2, ymin2, xmax2, ymax2 = candidates[j][0]
|
147 |
+
if xmax1 < xmin2 or xmax2 < xmin1 or ymax1 < ymin2 or ymax2 < ymin1:
|
148 |
+
continue
|
149 |
+
else:
|
150 |
+
xmin_inter, xmax_inter = sorted([xmin1, xmax1, xmin2, xmax2])[1:3]
|
151 |
+
ymin_inter, ymax_inter = sorted([ymin1, ymax1, ymin2, ymax2])[1:3]
|
152 |
+
area_inter = (xmax_inter - xmin_inter) * (ymax_inter - ymin_inter)
|
153 |
+
area1 = (xmax1 - xmin1) * (ymax1 - ymin1)
|
154 |
+
area2 = (xmax2 - xmin2) * (ymax2 - ymin2)
|
155 |
+
iou = area_inter / (area1 + area2 - area_inter)
|
156 |
+
if iou > max_iou:
|
157 |
+
if candidates[i][1] > candidates[j][1]:
|
158 |
+
to_ignore.add(j)
|
159 |
+
else:
|
160 |
+
to_ignore.add(i)
|
161 |
+
break
|
162 |
+
else:
|
163 |
+
if area_inter / area1 > 0.9:
|
164 |
+
if candidates[i][1] < 1.1 * candidates[j][1]:
|
165 |
+
to_ignore.add(i)
|
166 |
+
if area_inter / area2 > 0.9:
|
167 |
+
if 1.1 * candidates[i][1] > candidates[j][1]:
|
168 |
+
to_ignore.add(j)
|
169 |
+
return [candidates[i][0] for i in range(len(candidates)) if i not in to_ignore]
|
170 |
+
|
171 |
+
|
172 |
+
def convert_pil_to_base64(image):
|
173 |
+
img_buffer = BytesIO()
|
174 |
+
image.save(img_buffer, format="JPEG")
|
175 |
+
byte_data = img_buffer.getvalue()
|
176 |
+
base64_str = base64.b64encode(byte_data)
|
177 |
+
return base64_str
|
178 |
+
|
179 |
+
|
180 |
+
def draw_reshape_encode(img, boxes, x, y):
|
181 |
+
image = img.copy()
|
182 |
+
draw = ImageDraw.Draw(image)
|
183 |
+
new_x, new_y = int(x * HEIGHT / y), HEIGHT
|
184 |
+
for box in boxes:
|
185 |
+
draw.rectangle(
|
186 |
+
(tuple(box[:2]), tuple(box[2:])), outline=color, width=2 * int(y / HEIGHT)
|
187 |
+
)
|
188 |
+
if x > y:
|
189 |
+
image = image.crop((0, (x - y) / 2, x, x - (x - y) / 2))
|
190 |
+
else:
|
191 |
+
image = image.crop(((y - x) / 2, 0, y - (y - x) / 2, y))
|
192 |
+
return convert_pil_to_base64(image.resize((new_x, new_y)))
|
193 |
+
|
194 |
+
|
195 |
+
def get_html(url_list, encoded_images):
|
196 |
+
html = "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
|
197 |
+
for i in range(len(url_list)):
|
198 |
+
title, link, encoded = url_list[i][1], url_list[i][2], encoded_images[i]
|
199 |
+
html2 = f"<img title='{escape(title)}' style='height: {HEIGHT}px; margin: 5px' src='data:image/jpeg;base64,{encoded.decode()}'>"
|
200 |
+
if len(link) > 0:
|
201 |
+
html2 = f"<a href='{escape(link)}' target='_blank'>" + html2 + "</a>"
|
202 |
+
html = html + html2
|
203 |
+
html += "</div>"
|
204 |
+
return html
|
205 |
+
|
206 |
+
|
207 |
+
description = """
|
208 |
+
# Search and Detect
|
209 |
+
|
210 |
+
This demo illustrates how you can both retrieve images containing certain objects and locate these objects with a simple natural language query.
|
211 |
+
|
212 |
+
**Enter your query and hit enter**
|
213 |
+
|
214 |
+
**Tip 1**: if your query includes "/", the part left (resp. right) of "/" will be used to retrieve images (resp. locate objects). For example, if you want to retrieve pictures with several cats but locate individual cats, you can type "cats / cat".
|
215 |
+
|
216 |
+
**Tip 2**: change the score threshold below to adjust the sensitivity of the object detection.
|
217 |
+
|
218 |
+
*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model and Google's [Owl-ViT](https://arxiv.org/abs/2205.06230) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*
|
219 |
+
|
220 |
+
"""
|
221 |
+
|
222 |
+
div_style = {
|
223 |
+
"display": "flex",
|
224 |
+
"justify-content": "center",
|
225 |
+
"flex-wrap": "wrap",
|
226 |
+
}
|
227 |
+
|
228 |
+
|
229 |
+
def main():
|
230 |
+
st.markdown(
|
231 |
+
"""
|
232 |
+
<style>
|
233 |
+
.block-container{
|
234 |
+
max-width: 1200px;
|
235 |
+
}
|
236 |
+
div.row-widget.stRadio > div{
|
237 |
+
flex-direction:row;
|
238 |
+
display: flex;
|
239 |
+
justify-content: center;
|
240 |
+
}
|
241 |
+
div.row-widget.stRadio > div > label{
|
242 |
+
margin-left: 5px;
|
243 |
+
margin-right: 5px;
|
244 |
+
}
|
245 |
+
.row-widget {
|
246 |
+
margin-top: -25px;
|
247 |
+
}
|
248 |
+
section>div:first-child {
|
249 |
+
padding-top: 30px;
|
250 |
+
}
|
251 |
+
div.reportview-container > section:first-child{
|
252 |
+
max-width: 320px;
|
253 |
+
}
|
254 |
+
#MainMenu {
|
255 |
+
visibility: hidden;
|
256 |
+
}
|
257 |
+
footer {
|
258 |
+
visibility: hidden;
|
259 |
+
}
|
260 |
+
</style>""",
|
261 |
+
unsafe_allow_html=True,
|
262 |
+
)
|
263 |
+
st.sidebar.markdown(description)
|
264 |
+
score_threshold = st.sidebar.slider(
|
265 |
+
"Score threshold", min_value=0.01, max_value=0.3, value=0.1, step=0.01
|
266 |
+
)
|
267 |
+
|
268 |
+
_, c, _ = st.columns((1, 3, 1))
|
269 |
+
query = c.text_input("", value="clouds at sunset")
|
270 |
+
corpus = st.radio("", ["Unsplash", "Movies"])
|
271 |
+
|
272 |
+
if len(query) > 0:
|
273 |
+
if "/" in query:
|
274 |
+
queries = query.split("/")
|
275 |
+
clip_query, owl_query = ("/").join(queries[:-1]), queries[-1]
|
276 |
+
else:
|
277 |
+
clip_query, owl_query = query, query
|
278 |
+
retrieved = image_search(clip_query, corpus)
|
279 |
+
imgs, xs, ys = get_images([x[0] for x in retrieved])
|
280 |
+
results = apply_owl_model([[owl_query]] * len(imgs), imgs)
|
281 |
+
encoded_images = []
|
282 |
+
for image_idx in range(len(imgs)):
|
283 |
+
img0, x, y = imgs[image_idx], xs[image_idx], ys[image_idx]
|
284 |
+
boxes = keep_best_boxes(
|
285 |
+
results[image_idx]["boxes"],
|
286 |
+
results[image_idx]["scores"],
|
287 |
+
score_threshold=score_threshold,
|
288 |
+
)
|
289 |
+
encoded_images.append(draw_reshape_encode(img0, boxes, x, y))
|
290 |
+
st.markdown(get_html(retrieved, encoded_images), unsafe_allow_html=True)
|
291 |
+
|
292 |
+
|
293 |
+
if __name__ == "__main__":
|
294 |
+
main()
|
data.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
data2.csv
ADDED
The diff for this file is too large to render.
See raw diff
|
|
embeddings-vit-base-patch32.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3f7ebdff24079665faf58d07045056a63b5499753e3ffbda479691d53de3ab38
|
3 |
+
size 51200128
|
embeddings-vit-large-patch14-336.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f79f10ebe267b4ee7acd553dfe0ee31df846123630058a6d58c04bf22e0ad068
|
3 |
+
size 76800128
|
embeddings2-vit-base-patch32.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e7d545bed86121dac1cedcc1de61ea5295f5840c1eb751637e6628ac54faef81
|
3 |
+
size 16732288
|
embeddings2-vit-large-patch14-336.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1e66eb377465fbfaa56cec079aa3e214533ceac43646f2ca78028ae4d8ad6d03
|
3 |
+
size 25098368
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
transformers
|
3 |
+
tokenizers
|
4 |
+
Pillow
|
5 |
+
ftfy
|
6 |
+
numpy
|
7 |
+
pandas
|