neonwatty commited on
Commit
db14014
1 Parent(s): bc066dc

Upload 25 files

Browse files
data/.DS_Store ADDED
Binary file (6.15 kB). View file
 
data/dbs/memes.db ADDED
Binary file (28.7 kB). View file
 
data/dbs/memes.faiss ADDED
Binary file (393 kB). View file
 
data/dbs/placeholder ADDED
File without changes
data/input/test_meme_1.jpg ADDED
data/input/test_meme_2.jpg ADDED
data/input/test_meme_3.jpg ADDED
data/input/test_meme_4.jpg ADDED
data/input/test_meme_5.jpg ADDED
data/input/test_meme_6.jpg ADDED
data/input/test_meme_7.jpg ADDED
data/input/test_meme_8.jpg ADDED
data/input/test_meme_9.jpg ADDED
meme_search/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
import os

# location of this package on disk and of the repository root one level up
base_dir = os.path.dirname(os.path.abspath(__file__))
meme_search_root_dir = os.path.dirname(base_dir)
abs_dir = "."

# on-disk locations of the faiss vector index and the sqlite chunk lookup table
vector_db_path = f"{meme_search_root_dir}/data/dbs/memes.faiss"
sqlite_db_path = f"{meme_search_root_dir}/data/dbs/memes.db"
meme_search/app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import time
from meme_search import base_dir
from meme_search.utilities.query import complete_query
from meme_search.utilities.create import process
import streamlit as st

# Single-page Streamlit app: a search box over the indexed memes plus a
# button that re-syncs the index with the input directory.
st.set_page_config(page_title="Meme Search")


# search bar taken from --> https://discuss.streamlit.io/t/creating-a-nicely-formatted-search-field/1804/2
def local_css(file_name):
    # inject the contents of a local css file into the page
    with open(file_name) as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)


def remote_css(url):
    # inject a remote stylesheet link into the page
    st.markdown(f'<link href="{url}" rel="stylesheet">', unsafe_allow_html=True)


local_css(base_dir + "/style.css")
remote_css("https://fonts.googleapis.com/icon?family=Material+Icons")

# icon("search")
with st.container():
    with st.container(border=True):
        input_col, button_col = st.columns([6, 2])

        with button_col:
            st.empty()
            # process() returns True only when images were added/removed,
            # so the two branches below show "updated" vs "no refresh needed"
            refresh_index_button = st.button("refresh index", type="primary")
            if refresh_index_button:
                process_start = st.warning("refreshing...")
                val = process()
                if val:
                    process_start.empty()
                    success = st.success("index updated!")
                    time.sleep(2)
                    process_start.empty()
                    success.empty()
                else:
                    process_start.empty()
                    warning = st.warning("no refresh needed!")
                    time.sleep(2)
                    warning.empty()

        selected = input_col.text_input(label="meme search", placeholder="search for your meme", label_visibility="collapsed")
        if selected:
            # one entry per unique image, best (closest) chunk first
            results = complete_query(selected)
            img_paths = [v["img_path"] for v in results]  # NOTE(review): unused - candidate for removal
            with st.container(border=True):
                for result in results:
                    with st.container(border=True):
                        st.image(
                            result["img_path"],
                            output_format="auto",
                            caption=f'{result["full_description"]} (query distance = {result["distance"]})',
                        )
meme_search/style.css ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
/* streamlit theme overrides for the meme search page */
body {
    color: #fff;
    background-color: #4F8BF9;
}

/* .stButton>button {
    color: #4F8BF9;
    border-radius: 50%;
    height: 3em;
    width: 3em;
} */

/* match the search input's text color to the page accent color */
.stTextInput>div>div>input {
    color: #4F8BF9;
}
meme_search/utilities/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from sentence_transformers import SentenceTransformer

# shared sentence-embedding model used by both indexing (add) and querying
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# directory layout: <root>/meme_search/utilities/__init__.py
utilities_base_dir = os.path.dirname(os.path.abspath(__file__))
meme_search_dir = os.path.dirname(utilities_base_dir)
meme_search_root_dir = os.path.dirname(meme_search_dir)

# canonical data locations used throughout the utilities package
img_dir = f"{meme_search_root_dir}/data/input/"
vector_db_path = f"{meme_search_root_dir}/data/dbs/memes.faiss"
sqlite_db_path = f"{meme_search_root_dir}/data/dbs/memes.db"
meme_search/utilities/add.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sqlite3
3
+ import faiss
4
+ from meme_search.utilities import model
5
+ from meme_search.utilities.text_extraction import extract_text_from_imgs
6
+ from meme_search.utilities.chunks import create_all_img_chunks
7
+
8
+
9
def add_to_chunk_db(img_chunks: list, sqlite_db_path: str) -> None:
    """Append (img_path, chunk) rows to the chunks_reverse_lookup table.

    Args:
        img_chunks: dicts with "img_path" and "chunk" keys, one per text chunk.
        sqlite_db_path: path to the sqlite database file (created if missing).
    """
    # Create a lookup table for chunks
    conn = sqlite3.connect(sqlite_db_path)
    cursor = conn.cursor()

    # Create the table
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS chunks_reverse_lookup (
            img_path TEXT,
            chunk TEXT
        );
    """)

    # Batch-insert all rows at once; the original issued one execute per row
    # and carried an unused enumerate index.
    cursor.executemany(
        "INSERT INTO chunks_reverse_lookup (img_path, chunk) VALUES (?, ?)",
        [(entry["img_path"], entry["chunk"]) for entry in img_chunks],
    )

    conn.commit()
    conn.close()
33
+
34
+
35
def add_to_vector_db(chunks: list, vector_db_path: str) -> None:
    """Embed the text chunks and append them to the faiss index on disk."""
    # encode all chunks into an (n_chunks, dim) embedding matrix
    embeddings = model.encode(chunks)

    # start a fresh L2 index when none exists yet, otherwise append to the
    # index already on disk
    if not os.path.exists(vector_db_path):
        index = faiss.IndexFlatL2(embeddings.shape[1])
    else:
        index = faiss.read_index(vector_db_path)

    index.add(embeddings)
    faiss.write_index(index, vector_db_path)
47
+
48
+
49
def add_to_dbs(img_chunks: list, sqlite_db_path: str, vector_db_path: str) -> None:
    """Store img_chunks in both the sqlite lookup table and the faiss index."""
    try:
        print("STARTING: add_to_dbs")

        # sqlite side: reverse lookup rows (img_path, chunk)
        add_to_chunk_db(img_chunks, sqlite_db_path)

        # faiss side: one embedding per chunk, row ids align with sqlite rowids
        chunk_texts = [entry["chunk"] for entry in img_chunks]
        add_to_vector_db(chunk_texts, vector_db_path)

        print("SUCCESS: add_to_dbs succeeded")
    except Exception as e:
        print(f"FAILURE: add_to_dbs failed with exception {e}")
62
+
63
+
64
def add(new_imgs_to_be_indexed: list, sqlite_db_path: str, vector_db_path: str) -> None:
    """Describe, chunk, and index a batch of newly discovered images."""
    # 1) image -> text via moondream, 2) text -> overlapping chunks, 3) persist
    descriptions = extract_text_from_imgs(new_imgs_to_be_indexed)
    chunk_entries = create_all_img_chunks(new_imgs_to_be_indexed, descriptions)
    add_to_dbs(chunk_entries, sqlite_db_path, vector_db_path)
meme_search/utilities/chunks.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+
4
def clean_word(text: str) -> str:
    """Lowercase *text*, keep only [a-z0-9,.!?'] characters, collapse spaces."""
    lowered = text.lower().strip()
    # swap every disallowed character for a space, then squeeze repeated spaces
    filtered = re.sub(r"[^a-z0-9,.!?']", " ", lowered)
    return re.sub(r" +", " ", filtered)
7
+
8
+
9
def chunk_text(text: str) -> list:
    """Split *text* into overlapping word windows for embedding.

    Windows are 4 words wide with a 2-word overlap. When more than one
    window is produced, the full lowercased text is prepended as an extra
    "chunk" so whole-description matches are indexed too.
    """
    # split and clean input text
    text_split = clean_word(text).split(" ")
    text_split = [v for v in text_split if len(v) > 0]

    # two-pointer sliding window over the word list
    chunk_size = 4
    overlap_size = 2

    # slide the window until the right pointer runs off the end of the word list
    left_pointer = 0
    right_pointer = chunk_size - 1
    chunks = []

    if right_pointer >= len(text_split):
        # text is shorter than a single window - the whole text is the only chunk
        chunks = [" ".join(text_split)]
    else:
        while right_pointer < len(text_split):
            # take the current window
            chunk = text_split[left_pointer : right_pointer + 1]

            # advance the window by the stride (chunk_size - overlap_size)
            left_pointer += chunk_size - overlap_size

            right_pointer += chunk_size - overlap_size

            # store chunk
            chunks.append(" ".join(chunk))

        # any trailing words that never filled a full window form a final chunk
        # (it overlaps the previous window, matching the overlap scheme)
        if len(text_split[left_pointer:]) > 0:
            last_chunk = text_split[left_pointer:]
            chunks.append(" ".join(last_chunk))

    # prepend the full lowercased (uncleaned) text so exact phrasing is indexed
    if len(chunks) > 1:
        chunks.insert(0, text.lower())
    return chunks
49
+
50
+
51
+ # loop over each meme's moondream based text descriptor and create a short dict containing its full and chunked text
52
+ def create_all_img_chunks(img_paths: list, answers: list) -> list:
53
+ try:
54
+ print("STARTING: create_all_img_chunks")
55
+ img_chunks = []
56
+ for ind, img_path in enumerate(img_paths):
57
+ moondream_meme_text = answers[ind]
58
+ moondream_chunks = chunk_text(moondream_meme_text)
59
+ for chunk in moondream_chunks:
60
+ entry = {}
61
+ entry["img_path"] = img_path
62
+ entry["chunk"] = chunk
63
+ img_chunks.append(entry)
64
+ print("SUCCESS: create_all_img_chunks ran successfully")
65
+ return img_chunks
66
+ except Exception as e:
67
+ print(f"FAILURE: create_all_img_chunks failed with exception {e}")
68
+ raise e
meme_search/utilities/create.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from meme_search.utilities.status import get_input_directory_status
2
+ from meme_search.utilities.remove import remove
3
+ from meme_search.utilities.add import add
4
+ from meme_search.utilities import img_dir, sqlite_db_path, vector_db_path
5
+
6
+
7
def process() -> bool:
    """Sync the dbs with the input directory; True when anything changed."""
    old_imgs, new_imgs = get_input_directory_status(img_dir, sqlite_db_path)

    # index already matches the directory contents - nothing to do
    if not old_imgs and not new_imgs:
        return False

    if old_imgs:
        remove(old_imgs, sqlite_db_path, vector_db_path)
    if new_imgs:
        add(new_imgs, sqlite_db_path, vector_db_path)
    return True
16
+
17
+
18
# allow running the index refresh directly: python meme_search/utilities/create.py
if __name__ == "__main__":
    process()
meme_search/utilities/imgs.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
# image extensions (lowercase) accepted for indexing
allowable_extensions = ["jpg", "jpeg", "png"]


def collect_img_paths(img_dir: str) -> list:
    """Return sorted "./data/input/<name>" stubs for every image in img_dir.

    Files are kept when their extension matches jpg/jpeg/png, compared
    case-insensitively. Re-raises any listing error after logging it.
    """
    try:
        print("STARTING: collect_img_paths")

        # fix: lowercase the extension so "photo.JPG" is indexed too
        all_img_paths = [
            os.path.join(img_dir, name)
            for name in os.listdir(img_dir)
            if name.split(".")[-1].lower() in allowable_extensions
        ]
        all_img_paths = sorted(all_img_paths)
        # normalize to the repo-relative stub stored in the sqlite db;
        # os.path.basename is separator-safe, unlike split("/")
        all_img_paths = ["./data/input/" + os.path.basename(v) for v in all_img_paths]

        print(f"SUCCESS: collect_img_paths ran successfully - image paths loaded from '{img_dir}'")
        return all_img_paths
    except Exception as e:
        print(f"FAILURE: collect_img_paths failed with img_dir {img_dir} with exception {e}")
        raise e
meme_search/utilities/query.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import faiss
2
+ import sqlite3
3
+ import numpy as np
4
+ from typing import Tuple, Union
5
+ import argparse
6
+ from meme_search.utilities import model
7
+ from meme_search.utilities import vector_db_path, sqlite_db_path
8
+
9
+
10
def query_vector_db(query: str, db_file_path: str, k: int = 10) -> Tuple[list, list]:
    """Embed *query* and return (distances, indices) of its k nearest chunks."""
    # load the on-disk faiss index
    faiss_index = faiss.read_index(db_file_path)

    # model.encode returns a 1-d vector; faiss expects an (n_queries, dim) batch
    encoded_query = np.expand_dims(model.encode(query), axis=0)

    # nearest-neighbour search; flatten the single-query result rows to lists
    distances, indices = faiss_index.search(encoded_query, k)
    return distances.tolist()[0], indices.tolist()[0]
22
+
23
+
24
def query_for_indices(indices: list, db_path: str = None) -> list:
    """Fetch chunk rows by sqlite rowid, ordered to match *indices*.

    Each returned dict has "index", "img_path", "chunk", and
    "full_description" (the first chunk stored for that image, which is the
    complete moondream description).

    Args:
        indices: sqlite rowids, typically from the faiss search.
        db_path: sqlite db file; defaults to the module-level sqlite_db_path.
    """
    if db_path is None:
        db_path = sqlite_db_path
    if not indices:
        return []
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    # fix: parameterized query - the old f"... IN {tuple(indices)}" produced
    # invalid SQL ("IN (5,)") for a single index and interpolated raw values
    placeholders = ",".join("?" * len(indices))
    cursor.execute(
        f"SELECT rowid, * FROM chunks_reverse_lookup WHERE rowid IN ({placeholders})",
        indices,
    )
    rows = [{"index": row[0], "img_path": row[1], "chunk": row[2]} for row in cursor.fetchall()]
    rows = sorted(rows, key=lambda x: indices.index(x["index"]))  # re-sort rows according to input indices
    for row in rows:
        # the lowest rowid for an image holds its full description chunk
        cursor.execute(
            "SELECT chunk FROM chunks_reverse_lookup WHERE rowid=(SELECT MIN(rowid) FROM chunks_reverse_lookup WHERE img_path=?)",
            (row["img_path"],),
        )
        row["full_description"] = cursor.fetchone()[0]
    conn.close()
    return rows
39
+
40
+
41
def query_for_all(db_path: str = None) -> list:
    """Return every chunks_reverse_lookup row as a dict with index/img_path/chunk.

    Args:
        db_path: sqlite db file; defaults to the module-level sqlite_db_path.
    """
    if db_path is None:
        db_path = sqlite_db_path
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    cursor.execute("SELECT rowid, * FROM chunks_reverse_lookup")
    rows = [{"index": row[0], "img_path": row[1], "chunk": row[2]} for row in cursor.fetchall()]
    conn.close()  # fix: the original never closed this connection
    return rows
49
+
50
+
51
def complete_query(query: str, k: int = 10) -> Union[list, None]:
    """Semantic search over indexed memes.

    Returns one dict per unique image (closest chunk first, with a rounded
    "distance" field), or None when the stripped query is <= 1 character.
    """
    try:
        if len(query.strip()) > 1:
            print("STARTING: complete_query")

            # nearest chunks in embedding space
            distances, indices = query_vector_db(query, vector_db_path, k=k)

            # hydrate chunk rows from sqlite, ordered to match `indices`
            img_chunks = query_for_indices(indices)

            # keep only the first (best) hit per image
            imgs_seen = []
            unique_img_entries = []
            for ind, entry in enumerate(img_chunks):
                if entry["img_path"] in imgs_seen:
                    continue
                entry["distance"] = round(distances[ind], 2)
                unique_img_entries.append(entry)
                imgs_seen.append(entry["img_path"])
            print("SUCCESS: complete_query succeeded")
            return unique_img_entries
    except Exception as e:
        print(f"FAILURE: complete_query failed with exception {e}")
        raise e
77
+
78
+
79
# CLI entry point: python meme_search/utilities/query.py --query "..."
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--query", dest="query", type=str, help="Add query")
    args = parser.parse_args()
    query = args.query
    # fix: complete_query takes (query, k) - the old call passed vector_db_path
    # and sqlite_db_path positionally, raising TypeError at runtime
    results = complete_query(query)
meme_search/utilities/remove.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import faiss
3
+ import numpy as np
4
+
5
+
6
def collect_removal_rowids(old_imgs_to_be_removed: list, sqlite_db_path: str) -> list:
    """Return the rowids of chunks belonging to any of the given image paths.

    Raises:
        ValueError: wrapping any underlying sqlite failure.
    """
    try:
        if len(old_imgs_to_be_removed) > 0:
            conn = sqlite3.connect(sqlite_db_path)
            cursor = conn.cursor()
            # fix: parameterized query - the original interpolated raw,
            # double-quoted paths into the SQL string
            placeholders = ",".join("?" * len(old_imgs_to_be_removed))
            cursor.execute(
                f"SELECT rowid FROM chunks_reverse_lookup WHERE img_path IN ({placeholders})",
                old_imgs_to_be_removed,
            )
            rowids = [v[0] for v in cursor.fetchall()]
            conn.close()
            return rowids
        else:
            return []
    except Exception as e:
        raise ValueError(f"FAILURE: collect_removal_rowids failed with exception {e}")
21
+
22
+
23
def delete_removal_rowids_from_reverse_lookup(rowids: list, sqlite_db_path: str) -> None:
    """Delete the given rowids from chunks_reverse_lookup, then VACUUM.

    No-op when rowids is empty. Raises ValueError on any sqlite failure.
    """
    try:
        if len(rowids) > 0:
            conn = sqlite3.connect(sqlite_db_path)
            cursor = conn.cursor()
            # fix: one parameterized statement handles one or many rowids;
            # the original special-cased len == 1 and interpolated values
            placeholders = ",".join("?" * len(rowids))
            cursor.execute(
                f"DELETE FROM chunks_reverse_lookup WHERE rowid IN ({placeholders})",
                rowids,
            )
            conn.commit()
            conn.close()

            # reclaim disk space; VACUUM runs on a fresh connection so no
            # transaction is open
            conn = sqlite3.connect(sqlite_db_path)
            cursor = conn.cursor()
            cursor.execute("VACUUM;")
            conn.commit()
            conn.close()
    except Exception as e:
        raise ValueError(f"FAILURE: delete_removal_rowids failed with exception {e}")
43
+
44
+
45
def delete_removal_rowids_from_vector_db(rowids: list, vector_db_path: str) -> None:
    """Drop the embeddings whose ids match *rowids* from the on-disk faiss index."""
    try:
        if len(rowids) == 0:
            return
        # faiss ids mirror the sqlite rowids, so removal keeps the dbs in sync
        index = faiss.read_index(vector_db_path)
        index.remove_ids(np.array(rowids, dtype=np.int64))
        faiss.write_index(index, vector_db_path)
    except Exception as e:
        raise ValueError(f"FAILURE: delete_removal_rowids failed with exception {e}")
54
+
55
+
56
def remove(old_imgs_to_be_removed: list, sqlite_db_path: str, vector_db_path: str) -> None:
    """Purge de-indexed images from both the sqlite lookup table and the faiss index."""
    # find the rowids once, then delete them from both stores
    stale_rowids = collect_removal_rowids(old_imgs_to_be_removed, sqlite_db_path)
    delete_removal_rowids_from_reverse_lookup(stale_rowids, sqlite_db_path)
    delete_removal_rowids_from_vector_db(stale_rowids, vector_db_path)
meme_search/utilities/status.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from meme_search.utilities.imgs import collect_img_paths
3
+
4
+
5
def get_current_indexed_img_names(sqlite_db_path: str):
    """Return the distinct img_path values currently stored in the lookup table.

    Raises:
        ValueError: wrapping any underlying sqlite failure (e.g. missing table).
    """
    try:
        print("STARTING: collecting currently indexed names")
        conn = sqlite3.connect(sqlite_db_path)
        cursor = conn.cursor()
        # plain string literal - the original used an f-string with nothing
        # to interpolate
        query = "SELECT DISTINCT(img_path) FROM chunks_reverse_lookup"
        cursor.execute(query)
        rows = cursor.fetchall()
        rows = [v[0] for v in rows]
        conn.close()
        print("SUCCESS: get_current_indexed_img_names ran successfully")
        return rows
    except Exception as e:
        raise ValueError(f"FAILURE: get_current_indexed_img_names failed with exception {e}")
19
+
20
+
21
def get_input_directory_status(img_dir: str, sqlite_db_path: str):
    """Diff the input directory against the index.

    Returns (old_imgs_to_be_removed, new_imgs_to_be_indexed) as lists of
    "./data/input/<name>" stubs.
    """
    all_img_paths = collect_img_paths(img_dir)
    # normalize to the "./data/input/<name>" stub form stored in the db
    all_img_paths_stubs = ["./" + "/".join(v.split("/")[-3:]).strip() for v in all_img_paths]
    current_indexed_names = get_current_indexed_img_names(sqlite_db_path)

    indexed = set(current_indexed_names)
    on_disk = set(all_img_paths_stubs)
    old_imgs_to_be_removed = list(indexed - on_disk)
    new_imgs_to_be_indexed = list(on_disk - indexed)
    return old_imgs_to_be_removed, new_imgs_to_be_indexed
meme_search/utilities/text_extraction.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoModelForCausalLM, AutoTokenizer
2
+ from PIL import Image
3
+ import transformers
4
+
5
+ transformers.logging.set_verbosity_error()
6
+
7
+
8
# cache the loaded model + tokenizer so repeated calls don't re-download /
# re-load the full weights for every single image
_moondream_cache: dict = {}


def prompt_moondream(img_path: str, prompt: str) -> str:
    """Ask the moondream2 vision-language model *prompt* about the image at *img_path*."""
    # copied from moondream demo readme --> https://github.com/vikhyat/moondream/tree/main
    model_id = "vikhyatk/moondream2"
    revision = "2024-05-20"

    # fix: the original re-loaded the model and tokenizer on every call;
    # load them once and reuse (revision is pinned, so output is unchanged)
    if "model" not in _moondream_cache:
        _moondream_cache["model"] = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            revision=revision,
        )
        _moondream_cache["tokenizer"] = AutoTokenizer.from_pretrained(model_id, revision=revision)

    model = _moondream_cache["model"]
    tokenizer = _moondream_cache["tokenizer"]

    image = Image.open(img_path)
    enc_image = model.encode_image(image)
    moondream_response = model.answer_question(enc_image, prompt, tokenizer)
    return moondream_response
22
+
23
+
24
def extract_text_from_imgs(img_paths: list) -> list:
    """Generate a moondream text description for every image path given."""
    try:
        print("STARTING: extract_text_from_imgs")
        prompt = "Describe this image."
        answers = []
        for img_path in img_paths:
            # one moondream round-trip per image
            print(f"INFO: prompting moondream for a description of image: '{img_path}'")
            answers.append(prompt_moondream(img_path, prompt))
            print("DONE!")
        print("SUCCESS: extract_text_from_imgs succeeded")
        return answers
    except Exception as e:
        print(f"FAILURE: extract_text_from_imgs failed with exception {e}")
        raise e