Spaces:
Runtime error
Runtime error
Initial code
Browse files- .gitattributes +33 -0
- .github/workflows/hf_sync.yml +20 -0
- .gitignore +2 -0
- README.md +10 -0
- app.py +80 -0
- images.db +3 -0
- requirements.txt +2 -0
- text.db +3 -0
- utils/config.py +1 -0
- utils/frontend.py +10 -0
- utils/haystack.py +79 -0
.gitattributes
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bin.* filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
20 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
26 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.db filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.faiss filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.pdf filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
.github/workflows/hf_sync.yml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Sync to Hugging Face hub
|
2 |
+
on:
|
3 |
+
push:
|
4 |
+
branches: [main]
|
5 |
+
|
6 |
+
# to run this workflow manually from the Actions tab
|
7 |
+
workflow_dispatch:
|
8 |
+
|
9 |
+
jobs:
|
10 |
+
sync-to-hub:
|
11 |
+
runs-on: ubuntu-latest
|
12 |
+
steps:
|
13 |
+
- uses: actions/checkout@v2
|
14 |
+
with:
|
15 |
+
fetch-depth: 0
|
16 |
+
lfs: true
|
17 |
+
- name: Push to hub
|
18 |
+
env:
|
19 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
20 |
+
run: git push --force https://Tuana:$HF_TOKEN@huggingface.co/spaces/Tuana/find-the-animal main
|
.gitignore
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
.DS_Store
|
README.md
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: MultiModalRetrieval for Image Search
|
3 |
+
emoji: 😽
|
4 |
+
colorFrom: green
|
5 |
+
colorTo: yellow
|
6 |
+
sdk: streamlit
|
7 |
+
sdk_version: 1.2.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
---
|
app.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import time
|
3 |
+
import logging
|
4 |
+
from json import JSONDecodeError
|
5 |
+
from PIL import Image
|
6 |
+
|
7 |
+
|
8 |
+
from utils.haystack import query
|
9 |
+
from utils.frontend import reset_results, set_state_if_absent
|
10 |
+
|
11 |
+
def main():
    """Render the search page and, when a query is submitted, show its results.

    The page consists of a title, a text input pre-filled from
    ``st.session_state.statement`` and a "Run" button. A query is executed when
    the button is pressed or the input text differs from the stored statement.
    The retrieved documents are displayed as images (each document's
    ``content`` is opened with PIL), followed by the answers.

    Errors raised while querying or rendering are logged and reported to the
    user with ``st.error``; the function then returns early.
    """
    set_state_if_absent("statement", "What is the fastest animal?")
    set_state_if_absent("results", None)

    st.write("# Look for images with MultiModalRetrieval 🐅")
    st.write()
    st.markdown(
        """
        ##### Enter a question about animals
        """
    )
    # Search bar
    statement = st.text_input(
        "", value=st.session_state.statement, max_chars=100, on_change=reset_results
    )

    # Only the first column is used; it hosts the full-width "Run" button.
    col1, _ = st.columns(2)
    col1.markdown(
        "<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True
    )

    run_pressed = col1.button("Run")

    run_query = run_pressed or statement != st.session_state.statement

    # Get results for query
    if run_query and statement:
        time_start = time.time()
        reset_results()
        st.session_state.statement = statement
        with st.spinner("🧠 Performing neural search on documents..."):
            try:
                docs = query(statement)
                st.write(docs["documents"])
                for doc in docs["documents"]:
                    # Close the underlying file as soon as the image is rendered.
                    with Image.open(doc.content) as image:
                        st.image(image)
                for answer in docs["answers"]:
                    st.write(answer)
                print(f"S: {statement}")
                time_end = time.time()
                print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
                print(f"elapsed time: {time_end - time_start}")
            except JSONDecodeError:
                # Log the traceback so this failure is as diagnosable as the
                # generic one below.
                logging.exception("Failed to decode the query results")
                st.error(
                    "👓 An error occurred reading the results. Is the document store working?"
                )
                return
            except Exception as e:
                logging.exception(e)
                st.error("🐞 An error occurred during the request.")
                return
|
66 |
+
|
67 |
+
# if st.session_state.results:
|
68 |
+
# st.write("Got some results")
|
69 |
+
# print("GOT RESTULTS")
|
70 |
+
# st.write("Received Results")
|
71 |
+
# results = st.session_state.results
|
72 |
+
# print(results)
|
73 |
+
# docs = results["documents"]
|
74 |
+
# st.write(results)
|
75 |
+
# # show different messages depending on entailment results
|
76 |
+
# for doc in docs:
|
77 |
+
# image = Image(filename=doc.content)
|
78 |
+
# st.image(image)
|
79 |
+
|
80 |
+
# Entry point: the page is rendered by calling main() unconditionally at module level.
main()
|
images.db
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1df1fbb3cd45b562b6561acc3b159ea637ee538bc2f8ce2c59fa959dbc7b2538
|
3 |
+
size 200704
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
farm-haystack[faiss]==1.11.1
|
2 |
+
streamlit==1.12.0
|
text.db
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0fa05449dc61ca9bb83ad4053fb8797eb0b4deba7a31e2b2b15d0d73f97c3095
|
3 |
+
size 4464640
|
utils/config.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# Directory from which utils/haystack.py loads the FAISS index/config files
# (text.faiss, text.json, images.faiss, images.json) and copies text.db / images.db.
# NOTE(review): this commit adds text.db and images.db at the repository root,
# not under this directory — confirm the files exist here at runtime.
INDEX_DIR = "data/index"
|
utils/frontend.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
|
3 |
+
def set_state_if_absent(key, value):
    """Store ``value`` under ``key`` in ``st.session_state`` unless the key is already set."""
    if key in st.session_state:
        # An existing entry is left untouched.
        return
    st.session_state[key] = value
|
6 |
+
|
7 |
+
def reset_results(*args):
    """Clear the ``answer`` and ``results`` entries of ``st.session_state``.

    Any positional arguments are accepted and ignored, so the function can be
    called directly or passed as a widget callback.
    """
    import logging

    # This used to write "Called reset" onto the page itself; keep the trace
    # for debugging but send it to the log instead of the user-facing UI.
    logging.getLogger(__name__).debug("Called reset")
    st.session_state.answer = None
    st.session_state.results = None
|
utils/haystack.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import shutil
|
2 |
+
from haystack.document_stores import FAISSDocumentStore
|
3 |
+
from haystack.nodes.retriever import EmbeddingRetriever, MultiModalRetriever
|
4 |
+
from haystack.nodes.reader import FARMReader
|
5 |
+
from haystack import Pipeline
|
6 |
+
from utils.config import (INDEX_DIR)
|
7 |
+
from typing import List
|
8 |
+
from haystack import BaseComponent, Answer
|
9 |
+
import streamlit as st
|
10 |
+
|
11 |
+
|
12 |
+
|
13 |
+
class AnswerToQuery(BaseComponent):
    """Pipeline node that replaces the query with the text of the first answer.

    It has a single outgoing edge and emits ``{"query": ...}`` so that the
    next node receives the answer text as its query.
    """

    outgoing_edges = 1

    def run(self, query: str, answers: List[Answer]):
        """Return the first answer's text as the new query.

        :param query: The query that produced ``answers``; used as a fallback.
        :param answers: Answers from the preceding node, first one is used.
        :return: ``({"query": <text>}, "output_1")``.
        """
        # With no answers (or an empty answer string) there is nothing to
        # forward, so keep the incoming query rather than raising IndexError.
        first_answer = answers[0].answer if answers else None
        return {"query": first_answer or query}, "output_1"

    def run_batch(self):
        """Batch execution is not supported by this node."""
        raise NotImplementedError()
|
22 |
+
|
23 |
+
# Cached so that the index and the models are loaded only once, at start.
@st.cache(
    hash_funcs={"builtins.SwigPyObject": lambda _: None}, allow_output_mutation=True
)
def start_haystack():
    """
    Load both document stores, the retrievers and the reader, then wire them
    into a pipeline:

        Query -> text_retriever -> text_reader -> answer2query -> image_retriever

    Returns the assembled ``Pipeline``.
    """
    # Bring the .db files into the working directory before opening the stores
    # (presumably where the saved FAISS configs expect them — TODO confirm).
    shutil.copy(f"{INDEX_DIR}/text.db", ".")
    shutil.copy(f"{INDEX_DIR}/images.db", ".")

    text_store = FAISSDocumentStore(
        faiss_index_path=f"{INDEX_DIR}/text.faiss",
        faiss_config_path=f"{INDEX_DIR}/text.json",
    )
    image_store = FAISSDocumentStore(
        faiss_index_path=f"{INDEX_DIR}/images.faiss",
        faiss_config_path=f"{INDEX_DIR}/images.json",
    )

    text_retriever = EmbeddingRetriever(
        document_store=text_store,
        embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
        model_format="sentence_transformers",
    )
    text_reader = FARMReader(
        model_name_or_path="deepset/roberta-base-squad2", use_gpu=True
    )
    image_retriever = MultiModalRetriever(
        document_store=image_store,
        query_embedding_model="sentence-transformers/clip-ViT-B-32",
        query_type="text",
        document_embedding_models={"image": "sentence-transformers/clip-ViT-B-32"},
    )
    answer2query = AnswerToQuery()

    # (component, node name, inputs) in the order the nodes are chained.
    stages = (
        (text_retriever, "text_retriever", ["Query"]),
        (text_reader, "text_reader", ["text_retriever"]),
        (answer2query, "answer2query", ["text_reader"]),
        (image_retriever, "image_retriever", ["answer2query"]),
    )
    pipeline = Pipeline()
    for component, node_name, node_inputs in stages:
        pipeline.add_node(component, name=node_name, inputs=node_inputs)

    return pipeline
|
71 |
+
|
72 |
+
# Build the pipeline at import time; query() runs against this module-level instance.
pipe = start_haystack()
|
73 |
+
|
74 |
+
@st.cache(allow_output_mutation=True)
def query(statement: str, text_retriever_top_k: int = 5, image_retriever_top_k = 1):
    """Run ``statement`` through the module-level pipeline and return its result.

    ``text_retriever_top_k`` and ``image_retriever_top_k`` are passed as the
    ``top_k`` parameter of the "text_retriever" and "image_retriever" nodes.
    The value returned by ``pipe.run`` is handed back unchanged.
    """
    node_params = {
        "image_retriever": {"top_k": image_retriever_top_k},
        "text_retriever": {"top_k": text_retriever_top_k},
    }
    return pipe.run(statement, params=node_params)
|