Upload 25 files
- data/.DS_Store +0 -0
- data/dbs/memes.db +0 -0
- data/dbs/memes.faiss +0 -0
- data/dbs/placeholder +0 -0
- data/input/test_meme_1.jpg +0 -0
- data/input/test_meme_2.jpg +0 -0
- data/input/test_meme_3.jpg +0 -0
- data/input/test_meme_4.jpg +0 -0
- data/input/test_meme_5.jpg +0 -0
- data/input/test_meme_6.jpg +0 -0
- data/input/test_meme_7.jpg +0 -0
- data/input/test_meme_8.jpg +0 -0
- data/input/test_meme_9.jpg +0 -0
- meme_search/__init__.py +8 -0
- meme_search/app.py +57 -0
- meme_search/style.css +15 -0
- meme_search/utilities/__init__.py +11 -0
- meme_search/utilities/add.py +67 -0
- meme_search/utilities/chunks.py +68 -0
- meme_search/utilities/create.py +19 -0
- meme_search/utilities/imgs.py +18 -0
- meme_search/utilities/query.py +84 -0
- meme_search/utilities/remove.py +59 -0
- meme_search/utilities/status.py +28 -0
- meme_search/utilities/text_extraction.py +38 -0
data/.DS_Store
ADDED
Binary file (6.15 kB)
data/dbs/memes.db
ADDED
Binary file (28.7 kB)
data/dbs/memes.faiss
ADDED
Binary file (393 kB)
data/dbs/placeholder
ADDED
File without changes
data/input/test_meme_1.jpg
ADDED
data/input/test_meme_2.jpg
ADDED
data/input/test_meme_3.jpg
ADDED
data/input/test_meme_4.jpg
ADDED
data/input/test_meme_5.jpg
ADDED
data/input/test_meme_6.jpg
ADDED
data/input/test_meme_7.jpg
ADDED
data/input/test_meme_8.jpg
ADDED
data/input/test_meme_9.jpg
ADDED
meme_search/__init__.py
ADDED
@@ -0,0 +1,8 @@
import os

base_dir = os.path.dirname(os.path.abspath(__file__))
meme_search_root_dir = os.path.dirname(base_dir)
abs_dir = "."

vector_db_path = meme_search_root_dir + "/data/dbs/memes.faiss"
sqlite_db_path = meme_search_root_dir + "/data/dbs/memes.db"
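(Note that vector_db_path and sqlite_db_path are computed again, identically, in meme_search/utilities/__init__.py below; app.py imports base_dir from here, while the utility modules import the db paths from their own package.)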
meme_search/app.py
ADDED
@@ -0,0 +1,57 @@
import time
from meme_search import base_dir
from meme_search.utilities.query import complete_query
from meme_search.utilities.create import process
import streamlit as st

st.set_page_config(page_title="Meme Search")


# search bar taken from --> https://discuss.streamlit.io/t/creating-a-nicely-formatted-search-field/1804/2
def local_css(file_name):
    with open(file_name) as f:
        st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)


def remote_css(url):
    st.markdown(f'<link href="{url}" rel="stylesheet">', unsafe_allow_html=True)


local_css(base_dir + "/style.css")
remote_css("https://fonts.googleapis.com/icon?family=Material+Icons")

# icon("search")
with st.container():
    with st.container(border=True):
        input_col, button_col = st.columns([6, 2])

        with button_col:
            st.empty()
            refresh_index_button = st.button("refresh index", type="primary")
            if refresh_index_button:
                process_start = st.warning("refreshing...")
                val = process()
                if val:
                    process_start.empty()
                    success = st.success("index updated!")
                    time.sleep(2)
                    process_start.empty()
                    success.empty()
                else:
                    process_start.empty()
                    warning = st.warning("no refresh needed!")
                    time.sleep(2)
                    warning.empty()

        selected = input_col.text_input(label="meme search", placeholder="search for your meme", label_visibility="collapsed")
        if selected:
            results = complete_query(selected)
            img_paths = [v["img_path"] for v in results]
            with st.container(border=True):
                for result in results:
                    with st.container(border=True):
                        st.image(
                            result["img_path"],
                            output_format="auto",
                            caption=f'{result["full_description"]} (query distance = {result["distance"]})',
                        )
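To try the app locally, one would typically launch it from the repository root with "streamlit run meme_search/app.py" (assuming streamlit and the other dependencies are installed); the relative "./data/input/..." image paths stored in the index resolve from there.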
meme_search/style.css
ADDED
@@ -0,0 +1,15 @@
body {
    color: #fff;
    background-color: #4F8BF9;
}

/* .stButton>button {
    color: #4F8BF9;
    border-radius: 50%;
    height: 3em;
    width: 3em;
} */

.stTextInput>div>div>input {
    color: #4F8BF9;
}
meme_search/utilities/__init__.py
ADDED
@@ -0,0 +1,11 @@
import os
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
utilities_base_dir = os.path.dirname(os.path.abspath(__file__))
meme_search_dir = os.path.dirname(utilities_base_dir)
meme_search_root_dir = os.path.dirname(meme_search_dir)

img_dir = meme_search_root_dir + "/data/input/"
vector_db_path = meme_search_root_dir + "/data/dbs/memes.faiss"
sqlite_db_path = meme_search_root_dir + "/data/dbs/memes.db"
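A quick sanity check of the shared embedding model (a sketch with an arbitrary sample caption; the weights are downloaded from the Hugging Face Hub on first use):

    # minimal sketch: embed a sample caption with the shared model defined above
    from meme_search.utilities import model

    vec = model.encode("a cat wearing sunglasses")
    print(vec.shape)  # (384,) -- all-MiniLM-L6-v2 embeds text into 384-dim vectors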
meme_search/utilities/add.py
ADDED
@@ -0,0 +1,67 @@
import os
import sqlite3
import faiss
from meme_search.utilities import model
from meme_search.utilities.text_extraction import extract_text_from_imgs
from meme_search.utilities.chunks import create_all_img_chunks


def add_to_chunk_db(img_chunks: list, sqlite_db_path: str) -> None:
    # Create a lookup table for chunks
    conn = sqlite3.connect(sqlite_db_path)
    cursor = conn.cursor()

    # Create the table
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS chunks_reverse_lookup (
            img_path TEXT,
            chunk TEXT
        );
    """)

    # Insert data into the table
    for chunk_index, entry in enumerate(img_chunks):
        img_path = entry["img_path"]
        chunk = entry["chunk"]
        cursor.execute(
            "INSERT INTO chunks_reverse_lookup (img_path, chunk) VALUES (?, ?)",
            (img_path, chunk),
        )

    conn.commit()
    conn.close()


def add_to_vector_db(chunks: list, vector_db_path: str) -> None:
    # embed inputs
    embeddings = model.encode(chunks)

    # dump all_embeddings to faiss index
    if os.path.exists(vector_db_path):
        index = faiss.read_index(vector_db_path)
    else:
        index = faiss.IndexFlatL2(embeddings.shape[1])

    index.add(embeddings)
    faiss.write_index(index, vector_db_path)


def add_to_dbs(img_chunks: list, sqlite_db_path: str, vector_db_path: str) -> None:
    try:
        print("STARTING: add_to_dbs")

        # add to db for img_chunks
        add_to_chunk_db(img_chunks, sqlite_db_path)

        # create vector embedding db for chunks
        chunks = [v["chunk"] for v in img_chunks]
        add_to_vector_db(chunks, vector_db_path)
        print("SUCCESS: add_to_dbs succeeded")
    except Exception as e:
        print(f"FAILURE: add_to_dbs failed with exception {e}")


def add(new_imgs_to_be_indexed: list, sqlite_db_path: str, vector_db_path: str) -> None:
    moondream_answers = extract_text_from_imgs(new_imgs_to_be_indexed)
    img_chunks = create_all_img_chunks(new_imgs_to_be_indexed, moondream_answers)
    add_to_dbs(img_chunks, sqlite_db_path, vector_db_path)
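For example, a one-off indexing run might look like the following sketch (the two test images are ones uploaded in this commit):

    # sketch: describe, chunk, and index two of the test memes added above
    from meme_search.utilities import sqlite_db_path, vector_db_path
    from meme_search.utilities.add import add

    add(["./data/input/test_meme_1.jpg", "./data/input/test_meme_2.jpg"],
        sqlite_db_path, vector_db_path)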
meme_search/utilities/chunks.py
ADDED
@@ -0,0 +1,68 @@
import re


def clean_word(text: str) -> str:
    # clean input text - keeping only lower case letters, numbers, punctuation, and single quote symbols
    return re.sub(" +", " ", re.compile("[^a-z0-9,.!?']").sub(" ", text.lower().strip()))


def chunk_text(text: str) -> list:
    # split and clean input text
    text_split = clean_word(text).split(" ")
    text_split = [v for v in text_split if len(v) > 0]

    # use two pointers to create chunks
    chunk_size = 4
    overlap_size = 2

    # create each chunk by sliding the window until the end of the word sequence is reached
    left_pointer = 0
    right_pointer = chunk_size - 1
    chunks = []

    if right_pointer >= len(text_split):
        chunks = [" ".join(text_split)]
    else:
        while right_pointer < len(text_split):
            # create chunk
            chunk = text_split[left_pointer : right_pointer + 1]

            # move left pointer
            left_pointer += chunk_size - overlap_size

            # move right pointer
            right_pointer += chunk_size - overlap_size

            # store chunk
            chunks.append(" ".join(chunk))

        # check if there is a final chunk
        if len(text_split[left_pointer:]) > 0:
            last_chunk = text_split[left_pointer:]
            chunks.append(" ".join(last_chunk))

    # insert the full text
    if len(chunks) > 1:
        chunks.insert(0, text.lower())
    return chunks


# loop over each meme's moondream based text descriptor and create a short dict containing its full and chunked text
def create_all_img_chunks(img_paths: list, answers: list) -> list:
    try:
        print("STARTING: create_all_img_chunks")
        img_chunks = []
        for ind, img_path in enumerate(img_paths):
            moondream_meme_text = answers[ind]
            moondream_chunks = chunk_text(moondream_meme_text)
            for chunk in moondream_chunks:
                entry = {}
                entry["img_path"] = img_path
                entry["chunk"] = chunk
                img_chunks.append(entry)
        print("SUCCESS: create_all_img_chunks ran successfully")
        return img_chunks
    except Exception as e:
        print(f"FAILURE: create_all_img_chunks failed with exception {e}")
        raise e
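To make the sliding-window behaviour concrete, here is what chunk_text returns for a sample caption (a worked example, not part of the diff): with chunk_size 4 and overlap_size 2 the window advances two words at a time, the trailing remainder becomes its own chunk, and the full lower-cased text is prepended whenever more than one chunk is produced.

    from meme_search.utilities.chunks import chunk_text

    print(chunk_text("A dog sitting at a desk typing on a laptop"))
    # ['a dog sitting at a desk typing on a laptop',
    #  'a dog sitting at', 'sitting at a desk',
    #  'a desk typing on', 'typing on a laptop', 'a laptop']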
meme_search/utilities/create.py
ADDED
@@ -0,0 +1,19 @@
from meme_search.utilities.status import get_input_directory_status
from meme_search.utilities.remove import remove
from meme_search.utilities.add import add
from meme_search.utilities import img_dir, sqlite_db_path, vector_db_path


def process() -> bool:
    old_imgs_to_be_removed, new_imgs_to_be_indexed = get_input_directory_status(img_dir, sqlite_db_path)
    if len(old_imgs_to_be_removed) == 0 and len(new_imgs_to_be_indexed) == 0:
        return False
    if len(old_imgs_to_be_removed) > 0:
        remove(old_imgs_to_be_removed, sqlite_db_path, vector_db_path)
    if len(new_imgs_to_be_indexed) > 0:
        add(new_imgs_to_be_indexed, sqlite_db_path, vector_db_path)
    return True


if __name__ == "__main__":
    process()
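Because process() is also exposed behind a __main__ guard, the index can be refreshed outside the app, e.g. with "python -m meme_search.utilities.create" from the repository root (assuming the package is importable from there).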
meme_search/utilities/imgs.py
ADDED
@@ -0,0 +1,18 @@
import os

allowable_extensions = ["jpg", "jpeg", "png"]


def collect_img_paths(img_dir: str) -> list:
    try:
        print("STARTING: collect_img_paths")

        all_img_paths = [os.path.join(img_dir, name) for name in os.listdir(img_dir) if name.split(".")[-1] in allowable_extensions]
        all_img_paths = sorted(all_img_paths)
        all_img_paths = ["./data/input/" + v.split("/")[-1] for v in all_img_paths]

        print(f"SUCCESS: collect_img_paths ran successfully - image paths loaded from '{img_dir}'")
        return all_img_paths
    except Exception as e:
        print(f"FAILURE: collect_img_paths failed with img_dir {img_dir} with exception {e}")
        raise e
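Note the returned paths are rewritten to "./data/input/..." stubs regardless of where img_dir actually lives. A sketch of the expected output, assuming the nine test memes uploaded above are in place:

    from meme_search.utilities import img_dir
    from meme_search.utilities.imgs import collect_img_paths

    print(collect_img_paths(img_dir)[:2])
    # ['./data/input/test_meme_1.jpg', './data/input/test_meme_2.jpg']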
meme_search/utilities/query.py
ADDED
@@ -0,0 +1,84 @@
import faiss
import sqlite3
import numpy as np
from typing import Tuple, Union
import argparse
from meme_search.utilities import model
from meme_search.utilities import vector_db_path, sqlite_db_path


def query_vector_db(query: str, db_file_path: str, k: int = 10) -> Tuple[list, list]:
    # connect to db
    faiss_index = faiss.read_index(db_file_path)

    # embed the input query
    encoded_query = np.expand_dims(model.encode(query), axis=0)

    # query db
    distances, indices = faiss_index.search(encoded_query, k)
    distances = distances.tolist()[0]
    indices = indices.tolist()[0]
    return distances, indices


def query_for_indices(indices: list) -> list:
    conn = sqlite3.connect(sqlite_db_path)
    cursor = conn.cursor()
    # join manually: interpolating a one-element Python tuple would leave a trailing comma in the SQL
    query = f"SELECT rowid, * FROM chunks_reverse_lookup WHERE rowid IN ({','.join(str(v) for v in indices)})"
    cursor.execute(query)
    rows = cursor.fetchall()
    rows = [{"index": row[0], "img_path": row[1], "chunk": row[2]} for row in rows]
    rows = sorted(rows, key=lambda x: indices.index(x["index"]))  # re-sort rows according to input indices
    for row in rows:
        query = f"SELECT rowid, * FROM chunks_reverse_lookup WHERE rowid=(SELECT MIN(rowid) FROM chunks_reverse_lookup WHERE img_path='{row['img_path']}')"
        cursor.execute(query)
        full_description_row = cursor.fetchall()
        row["full_description"] = full_description_row[0][2]
    conn.close()
    return rows


def query_for_all() -> list:
    conn = sqlite3.connect(sqlite_db_path)
    cursor = conn.cursor()
    query = "SELECT rowid, * FROM chunks_reverse_lookup"
    cursor.execute(query)
    rows = cursor.fetchall()
    rows = [{"index": row[0], "img_path": row[1], "chunk": row[2]} for row in rows]
    return rows


def complete_query(query: str, k: int = 10) -> Union[list, None]:
    try:
        if len(query.strip()) > 1:
            print("STARTING: complete_query")

            # query vector_db, first converting input query to embedding
            distances, indices = query_vector_db(query, vector_db_path, k=k)

            # use indices to query sqlite db containing chunk data,
            # bumping each index up by 1 since sqlite rowids start at 1 not 0
            img_chunks = query_for_indices([v + 1 for v in indices])

            # map indices back to correct image in img_chunks
            imgs_seen = []
            unique_img_entries = []
            for ind, entry in enumerate(img_chunks):
                if entry["img_path"] in imgs_seen:
                    continue
                else:
                    entry["distance"] = round(distances[ind], 2)
                    unique_img_entries.append(entry)
                    imgs_seen.append(entry["img_path"])
            print("SUCCESS: complete_query succeeded")
            return unique_img_entries
    except Exception as e:
        print(f"FAILURE: complete_query failed with exception {e}")
        raise e


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--query", dest="query", type=str, help="Add query")
    args = parser.parse_args()
    results = complete_query(args.query)  # db paths are read from meme_search.utilities, not passed in
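Usage, either via the __main__ block or programmatically (a sketch with a made-up query; the result dicts carry the keys built up in query_for_indices and complete_query):

    #   python -m meme_search.utilities.query --query "dog at a laptop"
    # or, in code:
    from meme_search.utilities.query import complete_query

    for hit in complete_query("dog at a laptop", k=5):
        print(hit["img_path"], hit["distance"], hit["chunk"])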
meme_search/utilities/remove.py
ADDED
@@ -0,0 +1,59 @@
import sqlite3
import faiss
import numpy as np


def collect_removal_rowids(old_imgs_to_be_removed: list, sqlite_db_path: str) -> list:
    try:
        if len(old_imgs_to_be_removed) > 0:
            conn = sqlite3.connect(sqlite_db_path)
            cursor = conn.cursor()
            query = f"""SELECT rowid FROM chunks_reverse_lookup WHERE img_path IN ({','.join(['"' + v + '"' for v in old_imgs_to_be_removed])})"""
            cursor.execute(query)
            rows = cursor.fetchall()
            rowids = [v[0] for v in rows]
            conn.close()
            return rowids
        else:
            return []
    except Exception as e:
        raise ValueError(f"FAILURE: collect_removal_rowids failed with exception {e}")


def delete_removal_rowids_from_reverse_lookup(rowids: list, sqlite_db_path: str) -> None:
    try:
        if len(rowids) > 0:
            conn = sqlite3.connect(sqlite_db_path)
            cursor = conn.cursor()
            query = f"""DELETE FROM chunks_reverse_lookup WHERE rowid IN ({','.join([str(v) for v in rowids])})"""
            cursor.execute(query)
            conn.commit()
            conn.close()

            # VACUUM re-packs the remaining rowids so they stay aligned with the faiss index positions
            conn = sqlite3.connect(sqlite_db_path)
            cursor = conn.cursor()
            cursor.execute("VACUUM;")
            conn.commit()
            conn.close()
    except Exception as e:
        raise ValueError(f"FAILURE: delete_removal_rowids_from_reverse_lookup failed with exception {e}")


def delete_removal_rowids_from_vector_db(rowids: list, vector_db_path: str) -> None:
    try:
        if len(rowids) > 0:
            index = faiss.read_index(vector_db_path)
            # sqlite rowids start at 1 while faiss ids start at 0, so shift down by 1
            remove_set = np.array(rowids, dtype=np.int64) - 1
            index.remove_ids(remove_set)
            faiss.write_index(index, vector_db_path)
    except Exception as e:
        raise ValueError(f"FAILURE: delete_removal_rowids_from_vector_db failed with exception {e}")


def remove(old_imgs_to_be_removed: list, sqlite_db_path: str, vector_db_path: str) -> None:
    row_ids = collect_removal_rowids(old_imgs_to_be_removed, sqlite_db_path)
    delete_removal_rowids_from_reverse_lookup(row_ids, sqlite_db_path)
    delete_removal_rowids_from_vector_db(row_ids, vector_db_path)
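The deletion path leans on a positional invariant: sqlite rowid = faiss id + 1. faiss.IndexFlatL2 compacts its ids after remove_ids, and the VACUUM above similarly re-packs the sqlite rowids, which appears to be how the two stores stay aligned across removals; it does mean neither database should be modified by other tools in between.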
meme_search/utilities/status.py
ADDED
@@ -0,0 +1,28 @@
import sqlite3
from meme_search.utilities.imgs import collect_img_paths


def get_current_indexed_img_names(sqlite_db_path: str):
    try:
        print("STARTING: collecting currently indexed names")
        conn = sqlite3.connect(sqlite_db_path)
        cursor = conn.cursor()
        query = "SELECT DISTINCT(img_path) FROM chunks_reverse_lookup"
        cursor.execute(query)
        rows = cursor.fetchall()
        rows = [v[0] for v in rows]
        conn.close()
        print("SUCCESS: get_current_indexed_img_names ran successfully")
        return rows
    except Exception as e:
        raise ValueError(f"FAILURE: get_current_indexed_img_names failed with exception {e}")


def get_input_directory_status(img_dir: str, sqlite_db_path: str):
    all_img_paths = collect_img_paths(img_dir)
    all_img_paths_stubs = ["./" + "/".join(v.split("/")[-3:]).strip() for v in all_img_paths]
    current_indexed_names = get_current_indexed_img_names(sqlite_db_path)

    old_imgs_to_be_removed = list(set(current_indexed_names) - set(all_img_paths_stubs))
    new_imgs_to_be_indexed = list(set(all_img_paths_stubs) - set(current_indexed_names))
    return old_imgs_to_be_removed, new_imgs_to_be_indexed
meme_search/utilities/text_extraction.py
ADDED
@@ -0,0 +1,38 @@
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import transformers

transformers.logging.set_verbosity_error()


def prompt_moondream(img_path: str, prompt: str) -> str:
    # copied from moondream demo readme --> https://github.com/vikhyat/moondream/tree/main
    model_id = "vikhyatk/moondream2"
    revision = "2024-05-20"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        revision=revision,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
    image = Image.open(img_path)
    enc_image = model.encode_image(image)
    moondream_response = model.answer_question(enc_image, prompt, tokenizer)
    return moondream_response


def extract_text_from_imgs(img_paths: list) -> list:
    try:
        print("STARTING: extract_text_from_imgs")
        prompt = "Describe this image."
        answers = []
        for img_path in img_paths:
            print(f"INFO: prompting moondream for a description of image: '{img_path}'")
            answer = prompt_moondream(img_path, prompt)
            answers.append(answer)
            print("DONE!")
        print("SUCCESS: extract_text_from_imgs succeeded")
        return answers
    except Exception as e:
        print(f"FAILURE: extract_text_from_imgs failed with exception {e}")
        raise e
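One caveat worth noting: prompt_moondream calls from_pretrained on every image, so the moondream2 weights are re-loaded once per meme. Hoisting the model and tokenizer out of the loop (to module level, as utilities/__init__.py already does for the sentence-transformer) would likely speed up batch indexing considerably.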