Tuana committed on
Commit
75128dd
1 Parent(s): 71a2d59

Initial code

Browse files
Files changed (11) hide show
  1. .gitattributes +33 -0
  2. .github/workflows/hf_sync.yml +20 -0
  3. .gitignore +2 -0
  4. README.md +10 -0
  5. app.py +80 -0
  6. images.db +3 -0
  7. requirements.txt +2 -0
  8. text.db +3 -0
  9. utils/config.py +1 -0
  10. utils/frontend.py +10 -0
  11. utils/haystack.py +79 -0
.gitattributes ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pt filter=lfs diff=lfs merge=lfs -text
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
22
+ *.tflite filter=lfs diff=lfs merge=lfs -text
23
+ *.tgz filter=lfs diff=lfs merge=lfs -text
24
+ *.wasm filter=lfs diff=lfs merge=lfs -text
25
+ *.xz filter=lfs diff=lfs merge=lfs -text
26
+ *.zip filter=lfs diff=lfs merge=lfs -text
27
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
28
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
29
+ *.db filter=lfs diff=lfs merge=lfs -text
30
+ *.faiss filter=lfs diff=lfs merge=lfs -text
31
+ *.pdf filter=lfs diff=lfs merge=lfs -text
32
+ *.jpg filter=lfs diff=lfs merge=lfs -text
33
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
.github/workflows/hf_sync.yml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Sync to Hugging Face hub
2
+ on:
3
+ push:
4
+ branches: [main]
5
+
6
+ # to run this workflow manually from the Actions tab
7
+ workflow_dispatch:
8
+
9
+ jobs:
10
+ sync-to-hub:
11
+ runs-on: ubuntu-latest
12
+ steps:
13
+ - uses: actions/checkout@v2
14
+ with:
15
+ fetch-depth: 0
16
+ lfs: true
17
+ - name: Push to hub
18
+ env:
19
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
20
+ run: git push --force https://Tuana:[email protected]/spaces/Tuana/find-the-animal main
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__/
2
+ .DS_Store
README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: MultiModalRetrieval for Image Search
3
+ emoji: 😽
4
+ colorFrom: green
5
+ colorTo: yellow
6
+ sdk: streamlit
7
+ sdk_version: 1.12.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import time
3
+ import logging
4
+ from json import JSONDecodeError
5
+ from PIL import Image
6
+
7
+
8
+ from utils.haystack import query
9
+ from utils.frontend import reset_results, set_state_if_absent
10
+
def main():
    """Streamlit entry point: render the animal-search UI, run the Haystack
    query pipeline on the entered question, and display the retrieved images
    and extracted answers."""

    # Seed session state on first run so reruns keep the user's input/results.
    set_state_if_absent("statement", "What is the fastest animal?")
    set_state_if_absent("results", None)

    st.write("# Look for images with MultiModalRetrieval 🐅")
    st.markdown(
        """
        ##### Enter a question about animals
        """
    )
    # Search bar
    statement = st.text_input(
        "", value=st.session_state.statement, max_chars=100, on_change=reset_results
    )

    col1, col2 = st.columns(2)
    col1.markdown(
        "<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True
    )

    run_pressed = col1.button("Run")

    # Re-run the pipeline when the button is pressed or the input changed.
    run_query = run_pressed or statement != st.session_state.statement

    # Get results for query
    if run_query and statement:
        time_start = time.time()
        reset_results()
        st.session_state.statement = statement
        with st.spinner("🧠 &nbsp;&nbsp; Performing neural search on documents..."):
            try:
                docs = query(statement)
                st.write(docs["documents"])
                for doc in docs["documents"]:
                    # doc.content holds the file path of a retrieved image.
                    image = Image.open(doc.content)
                    st.image(image)
                for answer in docs["answers"]:
                    st.write(answer)
                time_end = time.time()
                # Use the logging module instead of bare print() for diagnostics.
                logging.info("S: %s", statement)
                logging.info("elapsed time: %s", time_end - time_start)
            except JSONDecodeError:
                st.error(
                    "👓 &nbsp;&nbsp; An error occurred reading the results. Is the document store working?"
                )
                return
            except Exception as e:
                logging.exception(e)
                st.error("🐞 &nbsp;&nbsp; An error occurred during the request.")
                return


# Streamlit executes this script top-to-bottom on every interaction.
main()
images.db ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1df1fbb3cd45b562b6561acc3b159ea637ee538bc2f8ce2c59fa959dbc7b2538
3
+ size 200704
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ farm-haystack[faiss]==1.11.1
2
+ streamlit==1.12.0
text.db ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0fa05449dc61ca9bb83ad4053fb8797eb0b4deba7a31e2b2b15d0d73f97c3095
3
+ size 4464640
utils/config.py ADDED
@@ -0,0 +1 @@
 
 
# Directory containing the prebuilt FAISS index files and the SQLite DBs
# (text/images) that utils/haystack.py loads at start-up.
INDEX_DIR = "data/index"
utils/frontend.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
def set_state_if_absent(key, value):
    """Initialise ``st.session_state[key]`` to *value* only on first use.

    Leaves any value already stored under *key* untouched, so user input
    survives Streamlit reruns.
    """
    if key in st.session_state:
        return
    st.session_state[key] = value
6
+
def reset_results(*args):
    """Callback that clears stale answers/results from the session state.

    Wired to the search input's ``on_change`` and invoked before each new
    query, so old results never linger next to a changed question.
    """
    # Removed the leftover debug st.write("Called reset"), which rendered
    # stray text into the UI on every input change.
    st.session_state.answer = None
    st.session_state.results = None
utils/haystack.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ from haystack.document_stores import FAISSDocumentStore
3
+ from haystack.nodes.retriever import EmbeddingRetriever, MultiModalRetriever
4
+ from haystack.nodes.reader import FARMReader
5
+ from haystack import Pipeline
6
+ from utils.config import (INDEX_DIR)
7
+ from typing import List
8
+ from haystack import BaseComponent, Answer
9
+ import streamlit as st
10
+
11
+
12
+
class AnswerToQuery(BaseComponent):
    """Pipeline node that turns the reader's top answer into a new query.

    Bridges the text-QA stage to the image retriever: the best extractive
    answer (e.g. an animal name) becomes the query string for the
    downstream MultiModalRetriever.
    """

    # Single output edge: everything flows on to the image retriever.
    outgoing_edges = 1

    def run(self, query: str, answers: List[Answer]):
        # Forward only the highest-ranked answer as the new query.
        # NOTE(review): assumes `answers` is non-empty — the reader always
        # returns at least one candidate; confirm for edge-case questions.
        return {"query": answers[0].answer}, "output_1"

    def run_batch(self):
        # Batch execution is not needed for this single-query demo app.
        raise NotImplementedError()
22
+
# Cached so the FAISS indexes and the models are loaded only once at app start.
@st.cache(
    hash_funcs={"builtins.SwigPyObject": lambda _: None}, allow_output_mutation=True
)
def start_haystack():
    """
    Load the FAISS document stores, the text/image retrievers and the QA
    reader, and assemble them into a single query pipeline.
    """
    # The FAISSDocumentStore looks for its SQLite DB in the working directory,
    # so copy the prebuilt DB files out of the index directory first.
    shutil.copy(f"{INDEX_DIR}/text.db", ".")
    shutil.copy(f"{INDEX_DIR}/images.db", ".")

    document_store_text = FAISSDocumentStore(
        faiss_index_path=f"{INDEX_DIR}/text.faiss",
        faiss_config_path=f"{INDEX_DIR}/text.json",
    )

    document_store_images = FAISSDocumentStore(
        faiss_index_path=f"{INDEX_DIR}/images.faiss",
        faiss_config_path=f"{INDEX_DIR}/images.json",
    )

    # Dense retriever over the text documents about animals.
    retriever_text = EmbeddingRetriever(
        document_store=document_store_text,
        embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
        model_format="sentence_transformers",
    )

    # Extractive QA reader that pulls the answer span (e.g. an animal name).
    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=True)

    # CLIP-based retriever matching a text query against image documents.
    retriever_images = MultiModalRetriever(
        document_store=document_store_images,
        query_embedding_model = "sentence-transformers/clip-ViT-B-32",
        query_type="text",
        document_embedding_models = {
            "image": "sentence-transformers/clip-ViT-B-32"
        }
    )

    answer_to_query = AnswerToQuery()

    pipe = Pipeline()

    # Topology: Query -> text retriever -> reader -> answer-as-query ->
    # image retriever. The reader's answer drives the image search.
    pipe.add_node(retriever_text, name="text_retriever", inputs=["Query"])
    pipe.add_node(reader, name="text_reader", inputs=["text_retriever"])
    pipe.add_node(answer_to_query, name="answer2query", inputs=["text_reader"])
    pipe.add_node(retriever_images, name="image_retriever", inputs=["answer2query"])

    return pipe

# Built once at import time; reruns reuse the cached pipeline.
pipe = start_haystack()
73
+
@st.cache(allow_output_mutation=True)
def query(statement: str, text_retriever_top_k: int = 5, image_retriever_top_k = 1):
    """Run the full pipeline for *statement* and return its raw results dict.

    ``text_retriever_top_k`` / ``image_retriever_top_k`` cap how many text
    and image documents the respective retriever nodes return.
    """
    node_params = {
        "text_retriever": {"top_k": text_retriever_top_k},
        "image_retriever": {"top_k": image_retriever_top_k},
    }
    return pipe.run(statement, params=node_params)