Spaces:
Running
Running
Ashmi Banerjee
commited on
Commit
•
89cd5d5
1
Parent(s):
ac20456
refactored the vectordb
Browse files- app.py +1 -1
- src/__init__.py +2 -2
- src/information_retrieval/info_retrieval.py +11 -7
- src/vectordb/create_db.py +3 -1
- src/vectordb/helpers.py +15 -1
- src/vectordb/ingest.py +61 -0
- src/vectordb/{lancedb_init.py → schema.py} +0 -0
- src/vectordb/search.py +97 -0
- src/vectordb/vectordb.py +0 -190
app.py
CHANGED
@@ -56,7 +56,7 @@ def create_ui():
|
|
56 |
" ")
|
57 |
|
58 |
with gr.Group():
|
59 |
-
countries = gr.Dropdown(choices=list(df.country), multiselect=False, label="
|
60 |
starting_point = gr.Dropdown(choices=[], multiselect=False,
|
61 |
label="Select your starting point for the trip!")
|
62 |
|
|
|
56 |
" ")
|
57 |
|
58 |
with gr.Group():
|
59 |
+
countries = gr.Dropdown(choices=list(df.country.unique()), multiselect=False, label="Country")
|
60 |
starting_point = gr.Dropdown(choices=[], multiselect=False,
|
61 |
label="Select your starting point for the trip!")
|
62 |
|
src/__init__.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
-
from src.vectordb.
|
2 |
from src.vectordb.helpers import *
|
3 |
-
from src.vectordb.
|
4 |
|
5 |
from src.sustainability.s_fairness import *
|
6 |
from src.information_retrieval.info_retrieval import *
|
|
|
1 |
+
from src.vectordb.search import *
|
2 |
from src.vectordb.helpers import *
|
3 |
+
from src.vectordb.schema import *
|
4 |
|
5 |
from src.sustainability.s_fairness import *
|
6 |
from src.information_retrieval.info_retrieval import *
|
src/information_retrieval/info_retrieval.py
CHANGED
@@ -2,8 +2,11 @@ import sys
|
|
2 |
import re
|
3 |
import os
|
4 |
import json
|
|
|
|
|
|
|
5 |
sys.path.append("../")
|
6 |
-
from src.vectordb import
|
7 |
from src.sustainability import s_fairness
|
8 |
import logging
|
9 |
|
@@ -12,6 +15,7 @@ logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
|
|
12 |
|
13 |
from src.helpers.data_loaders import load_scores
|
14 |
|
|
|
15 |
def get_travel_months(query):
|
16 |
"""
|
17 |
|
@@ -66,7 +70,7 @@ def get_wikivoyage_context(query, limit=10, reranking=0):
|
|
66 |
# limit = params['limit']
|
67 |
# reranking = params['reranking']
|
68 |
|
69 |
-
docs =
|
70 |
logger.info("Finished getting chunked wikivoyage docs.")
|
71 |
|
72 |
results = {}
|
@@ -76,7 +80,7 @@ def get_wikivoyage_context(query, limit=10, reranking=0):
|
|
76 |
|
77 |
cities = [result['city'] for result in docs]
|
78 |
|
79 |
-
listings =
|
80 |
logger.info("Finished getting wikivoyage listings.")
|
81 |
# logger.info(type(docs), type(listings))
|
82 |
|
@@ -92,7 +96,7 @@ def get_wikivoyage_context(query, limit=10, reranking=0):
|
|
92 |
return results
|
93 |
|
94 |
|
95 |
-
def get_sustainability_scores(starting_point: str
|
96 |
"""
|
97 |
|
98 |
Function to get the s-fairness scores for each destination for the given month (or the ideal month of travel if the user hasn't provided a month).
|
@@ -164,7 +168,7 @@ def get_cities(context: dict):
|
|
164 |
"""
|
165 |
|
166 |
recommended_cities = []
|
167 |
-
|
168 |
for city, info in context.items():
|
169 |
city_info = {
|
170 |
'city': city,
|
@@ -242,8 +246,8 @@ def test():
|
|
242 |
# print(cities)
|
243 |
except FileNotFoundError as e:
|
244 |
try:
|
245 |
-
|
246 |
-
|
247 |
|
248 |
try:
|
249 |
context = get_context(query, sustainability=1)
|
|
|
2 |
import re
|
3 |
import os
|
4 |
import json
|
5 |
+
|
6 |
+
from src.vectordb.ingest import create_wikivoyage_docs_db_and_add_data, create_wikivoyage_listings_db_and_add_data
|
7 |
+
|
8 |
sys.path.append("../")
|
9 |
+
from src.vectordb.search import search_wikivoyage_listings, search_wikivoyage_docs
|
10 |
from src.sustainability import s_fairness
|
11 |
import logging
|
12 |
|
|
|
15 |
|
16 |
from src.helpers.data_loaders import load_scores
|
17 |
|
18 |
+
|
19 |
def get_travel_months(query):
|
20 |
"""
|
21 |
|
|
|
70 |
# limit = params['limit']
|
71 |
# reranking = params['reranking']
|
72 |
|
73 |
+
docs = search_wikivoyage_docs(query, limit, reranking)
|
74 |
logger.info("Finished getting chunked wikivoyage docs.")
|
75 |
|
76 |
results = {}
|
|
|
80 |
|
81 |
cities = [result['city'] for result in docs]
|
82 |
|
83 |
+
listings = search_wikivoyage_listings(query, cities, limit, reranking)
|
84 |
logger.info("Finished getting wikivoyage listings.")
|
85 |
# logger.info(type(docs), type(listings))
|
86 |
|
|
|
96 |
return results
|
97 |
|
98 |
|
99 |
+
def get_sustainability_scores(starting_point: str, query: str, destinations: list):
|
100 |
"""
|
101 |
|
102 |
Function to get the s-fairness scores for each destination for the given month (or the ideal month of travel if the user hasn't provided a month).
|
|
|
168 |
"""
|
169 |
|
170 |
recommended_cities = []
|
171 |
+
info = context[list(context.keys())[0]]
|
172 |
for city, info in context.items():
|
173 |
city_info = {
|
174 |
'city': city,
|
|
|
246 |
# print(cities)
|
247 |
except FileNotFoundError as e:
|
248 |
try:
|
249 |
+
create_wikivoyage_docs_db_and_add_data()
|
250 |
+
create_wikivoyage_listings_db_and_add_data()
|
251 |
|
252 |
try:
|
253 |
context = get_context(query, sustainability=1)
|
src/vectordb/create_db.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1 |
-
from src.vectordb.
|
2 |
import logging
|
3 |
|
4 |
logger = logging.getLogger(__name__)
|
5 |
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
|
6 |
|
|
|
|
|
7 |
|
8 |
def run():
|
9 |
logging.info("Creating database for Wikivoyage Documents")
|
|
|
1 |
+
from src.vectordb.search import *
|
2 |
import logging
|
3 |
|
4 |
logger = logging.getLogger(__name__)
|
5 |
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
|
6 |
|
7 |
+
from src.vectordb.ingest import create_wikivoyage_docs_db_and_add_data, create_wikivoyage_listings_db_and_add_data
|
8 |
+
|
9 |
|
10 |
def run():
|
11 |
logging.info("Creating database for Wikivoyage Documents")
|
src/vectordb/helpers.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
import pandas as pd
|
2 |
import os
|
3 |
import re
|
@@ -7,7 +9,7 @@ import sys
|
|
7 |
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
8 |
sys.path.append(os.path.dirname(SCRIPT_DIR))
|
9 |
|
10 |
-
from data_directories import *
|
11 |
|
12 |
|
13 |
def create_chunks(city, country, text):
|
@@ -148,3 +150,15 @@ def embed_query(query):
|
|
148 |
# vector_dimension = model.get_sentence_embedding_dimension()
|
149 |
embedding = model.encode(query).tolist()
|
150 |
return embedding
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
import pandas as pd
|
4 |
import os
|
5 |
import re
|
|
|
9 |
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
|
10 |
sys.path.append(os.path.dirname(SCRIPT_DIR))
|
11 |
|
12 |
+
from src.data_directories import *
|
13 |
|
14 |
|
15 |
def create_chunks(city, country, text):
|
|
|
150 |
# vector_dimension = model.get_sentence_embedding_dimension()
|
151 |
embedding = model.encode(query).tolist()
|
152 |
return embedding
|
153 |
+
|
154 |
+
|
155 |
+
def set_uri(run_local: Optional[bool] = False):
|
156 |
+
if run_local:
|
157 |
+
uri = database_dir
|
158 |
+
current_dir = os.path.split(os.getcwd())[1]
|
159 |
+
|
160 |
+
if "src" or "tests" in current_dir: # hacky way to get the correct path
|
161 |
+
uri = uri.replace("../../", "../")
|
162 |
+
else:
|
163 |
+
uri = os.environ["BUCKET_NAME"]
|
164 |
+
return uri
|
src/vectordb/ingest.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Callable
|
2 |
+
import logging
|
3 |
+
logger = logging.getLogger(__name__)
|
4 |
+
logging.basicConfig(encoding='utf-8', level=logging.DEBUG)
|
5 |
+
from src.vectordb.helpers import read_docs, read_listings, preprocess_df
|
6 |
+
from src.vectordb.schema import WikivoyageDocuments, WikivoyageListings
|
7 |
+
from src.vectordb.helpers import set_uri
|
8 |
+
import lancedb
|
9 |
+
|
10 |
+
|
11 |
+
def _create_table_and_ingest_data(table_name: str, schema: object, data_fetcher: Callable,
|
12 |
+
preprocessor: Optional[Callable] = None):
|
13 |
+
"""
|
14 |
+
Generalized function to create a table and ingest data into the database.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
- table_name: str, name of the table to create.
|
18 |
+
- schema: object, schema of the table.
|
19 |
+
- data_fetcher: Callable, function to fetch the data.
|
20 |
+
- preprocessor: Optional[Callable], function to preprocess the data (default is None).
|
21 |
+
"""
|
22 |
+
uri = set_uri()
|
23 |
+
|
24 |
+
db = lancedb.connect(uri)
|
25 |
+
logger.info(f"Connected to DB. Reading data for table {table_name} now...")
|
26 |
+
|
27 |
+
df = data_fetcher()
|
28 |
+
|
29 |
+
if preprocessor:
|
30 |
+
df = preprocessor(df)
|
31 |
+
|
32 |
+
logger.info(f"Finished reading data for {table_name}, attempting to create table and ingest the data...")
|
33 |
+
|
34 |
+
db.drop_table(table_name, ignore_missing=True)
|
35 |
+
table = db.create_table(table_name, schema=schema)
|
36 |
+
|
37 |
+
table.add(df.to_dict('records'))
|
38 |
+
logger.info(f"Completed ingestion for {table_name}.")
|
39 |
+
|
40 |
+
|
41 |
+
def create_wikivoyage_docs_db_and_add_data():
|
42 |
+
"""
|
43 |
+
Creates the Wikivoyage documents table and ingests data.
|
44 |
+
"""
|
45 |
+
_create_table_and_ingest_data(
|
46 |
+
table_name="wikivoyage_documents",
|
47 |
+
schema=WikivoyageDocuments,
|
48 |
+
data_fetcher=read_docs,
|
49 |
+
preprocessor=preprocess_df
|
50 |
+
)
|
51 |
+
|
52 |
+
|
53 |
+
def create_wikivoyage_listings_db_and_add_data():
|
54 |
+
"""
|
55 |
+
Creates the Wikivoyage listings table and ingests data.
|
56 |
+
"""
|
57 |
+
_create_table_and_ingest_data(
|
58 |
+
table_name="wikivoyage_listings",
|
59 |
+
schema=WikivoyageListings,
|
60 |
+
data_fetcher=read_listings
|
61 |
+
)
|
src/vectordb/{lancedb_init.py → schema.py}
RENAMED
File without changes
|
src/vectordb/search.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# from src import *
|
2 |
+
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import lancedb
|
6 |
+
from lancedb.rerankers import ColbertReranker
|
7 |
+
|
8 |
+
import sys
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
from typing import Optional
|
12 |
+
from src.vectordb.helpers import set_uri
|
13 |
+
|
14 |
+
|
15 |
+
# db = lancedb.connect("/tmp/db")
|
16 |
+
|
17 |
+
|
18 |
+
def search(query: str, table_name: str, filter_condition: Optional[str] = None,
|
19 |
+
category: str = "docs", limit: int = 10, reranking: int = 0,
|
20 |
+
run_local: Optional[bool] = False) -> list | None:
|
21 |
+
"""
|
22 |
+
Generalized function to search a database table, with optional filters and reranking.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
- query: str, search query.
|
26 |
+
- table_name: str, name of the table to search.
|
27 |
+
- filter_condition: Optional[str], optional SQL-like condition for filtering results.
|
28 |
+
- category: str, type of category (default is 'docs').
|
29 |
+
- limit: int, number of results (default is 10).
|
30 |
+
- reranking: int (0 or 1), if activated, ColbertReranker is used.
|
31 |
+
- run_local: Optional[bool], whether to run in a local environment.
|
32 |
+
|
33 |
+
Returns:
|
34 |
+
A list of the most relevant documents or listings based on the category.
|
35 |
+
"""
|
36 |
+
uri = set_uri(run_local)
|
37 |
+
|
38 |
+
try:
|
39 |
+
db = lancedb.connect(uri)
|
40 |
+
except Exception as e:
|
41 |
+
logger.error(f"Error while connecting to DB: {e}")
|
42 |
+
return None
|
43 |
+
|
44 |
+
logger.info(f"Connected to {table_name} DB.")
|
45 |
+
table = db.open_table(table_name)
|
46 |
+
|
47 |
+
search_query = table.search(query).metric('cosine')
|
48 |
+
|
49 |
+
if filter_condition:
|
50 |
+
search_query = search_query.where(filter_condition)
|
51 |
+
|
52 |
+
if reranking:
|
53 |
+
try:
|
54 |
+
column = 'description' if category == 'listings' else 'text'
|
55 |
+
reranker = ColbertReranker(column=column)
|
56 |
+
results = search_query.rerank(reranker=reranker).limit(limit).to_list()
|
57 |
+
except Exception as e:
|
58 |
+
exc_type, exc_obj, exc_tb = sys.exc_info()
|
59 |
+
fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
|
60 |
+
logger.error(f"Error while reranking results: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
|
61 |
+
return None
|
62 |
+
else:
|
63 |
+
try:
|
64 |
+
results = search_query.limit(limit).to_list()
|
65 |
+
except Exception as e:
|
66 |
+
exc_type, exc_obj, exc_tb = sys.exc_info()
|
67 |
+
fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
|
68 |
+
logger.error(f"Error while searching: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
|
69 |
+
return None
|
70 |
+
|
71 |
+
logger.info("Found the most relevant documents.")
|
72 |
+
|
73 |
+
if category == "docs":
|
74 |
+
return [{"city": r['city'], "country": r['country'], "section": r['section'], "text": r['text']} for r in
|
75 |
+
results]
|
76 |
+
else:
|
77 |
+
return [{"city": r['city'], "country": r['country'], "type": r['type'], "title": r['title'],
|
78 |
+
"description": r['description']} for r in results]
|
79 |
+
|
80 |
+
|
81 |
+
def search_wikivoyage_docs(query: str, limit: int = 10, reranking: int = 0,
|
82 |
+
run_local: Optional[bool] = False) -> list | None:
|
83 |
+
"""
|
84 |
+
Function to search documents in the Wikivoyage database.
|
85 |
+
"""
|
86 |
+
return search(query=query, table_name="wikivoyage_documents", category="docs",
|
87 |
+
limit=limit, reranking=reranking, run_local=run_local)
|
88 |
+
|
89 |
+
|
90 |
+
def search_wikivoyage_listings(query: str, cities: list, limit: int = 10, reranking: int = 0,
|
91 |
+
run_local: Optional[bool] = False) -> list | None:
|
92 |
+
"""
|
93 |
+
Function to search listings in the Wikivoyage database, post-filtered by cities.
|
94 |
+
"""
|
95 |
+
cities_filter = f"city IN {tuple(cities)}"
|
96 |
+
return search(query=query, table_name="wikivoyage_listings", filter_condition=cities_filter,
|
97 |
+
category="listings", limit=limit, reranking=reranking, run_local=run_local)
|
src/vectordb/vectordb.py
DELETED
@@ -1,190 +0,0 @@
|
|
1 |
-
# from src import *
|
2 |
-
from src.vectordb.helpers import *
|
3 |
-
from src.vectordb.lancedb_init import *
|
4 |
-
import logging
|
5 |
-
import os
|
6 |
-
import lancedb
|
7 |
-
from lancedb.rerankers import ColbertReranker
|
8 |
-
|
9 |
-
import sys
|
10 |
-
logger = logging.getLogger(__name__)
|
11 |
-
from typing import Optional
|
12 |
-
|
13 |
-
# db = lancedb.connect("/tmp/db")
|
14 |
-
|
15 |
-
def create_wikivoyage_docs_db_and_add_data():
|
16 |
-
"""
|
17 |
-
|
18 |
-
Creates wikivoyage documents table and ingests data
|
19 |
-
|
20 |
-
"""
|
21 |
-
uri = database_dir
|
22 |
-
current_dir = os.path.split(os.getcwd())[1]
|
23 |
-
|
24 |
-
if "src" or "tests" in current_dir: # hacky way to get the correct path
|
25 |
-
uri = uri.replace("../../", "../")
|
26 |
-
|
27 |
-
db = lancedb.connect(uri)
|
28 |
-
logger.info("Connected to DB. Reading data now...")
|
29 |
-
df = read_docs()
|
30 |
-
filtered_df = preprocess_df(df)
|
31 |
-
logger.info("Finished reading data, attempting to create table and ingest the data...")
|
32 |
-
|
33 |
-
db.drop_table("wikivoyage_documents", ignore_missing=True)
|
34 |
-
table = db.create_table("wikivoyage_documents", schema=WikivoyageDocuments)
|
35 |
-
|
36 |
-
table.add(filtered_df.to_dict('records'))
|
37 |
-
logger.info("Completed ingestion.")
|
38 |
-
|
39 |
-
|
40 |
-
def create_wikivoyage_listings_db_and_add_data():
|
41 |
-
"""
|
42 |
-
|
43 |
-
Creates wikivoyage listings table and ingests data
|
44 |
-
|
45 |
-
"""
|
46 |
-
uri = database_dir
|
47 |
-
current_dir = os.path.split(os.getcwd())[1]
|
48 |
-
|
49 |
-
if "src" or "tests" in current_dir: # hacky way to get the correct path
|
50 |
-
uri = uri.replace("../../", "../")
|
51 |
-
|
52 |
-
db = lancedb.connect(uri)
|
53 |
-
logger.info("Connected to DB. Reading data now...")
|
54 |
-
df = read_listings()
|
55 |
-
logger.info("Finished reading data, attempting to create table and ingest the data...")
|
56 |
-
# filtered_df = preprocess_df(df)
|
57 |
-
|
58 |
-
db.drop_table("wikivoyage_listings", ignore_missing=True)
|
59 |
-
table = db.create_table("wikivoyage_listings", schema=WikivoyageListings)
|
60 |
-
|
61 |
-
table.add(df.astype('str').to_dict('records'))
|
62 |
-
logger.info("Completed ingestion.")
|
63 |
-
|
64 |
-
|
65 |
-
def search_wikivoyage_docs(query: str, limit: int = 10, reranking: int = 0, run_local: Optional[bool] = False):
|
66 |
-
"""
|
67 |
-
|
68 |
-
Function to search the wikivoyage database an return most relevant chunked docs.
|
69 |
-
|
70 |
-
Args:
|
71 |
-
- query: str
|
72 |
-
- limit: number of results (default is 10)
|
73 |
-
- reranking: bool (0 or 1), if activated, CrossEncoderReranker is used.
|
74 |
-
|
75 |
-
"""
|
76 |
-
if run_local:
|
77 |
-
uri = database_dir
|
78 |
-
current_dir = os.path.split(os.getcwd())[1]
|
79 |
-
|
80 |
-
if "src" or "tests" in current_dir: # hacky way to get the correct path
|
81 |
-
uri = uri.replace("../../", "../")
|
82 |
-
else:
|
83 |
-
uri = os.environ["BUCKET_NAME"]
|
84 |
-
# print(uri)
|
85 |
-
try:
|
86 |
-
db = lancedb.connect(uri)
|
87 |
-
except Exception as e:
|
88 |
-
logger.error(f"Error while connecting to DB: {e}")
|
89 |
-
|
90 |
-
logger.info("Connected to Wikivoyage DB.")
|
91 |
-
print("Tablenames: ", db.table_names())
|
92 |
-
|
93 |
-
# query_embedding = embed_query(query)
|
94 |
-
table = db.open_table("wikivoyage_documents")
|
95 |
-
|
96 |
-
if reranking:
|
97 |
-
try:
|
98 |
-
reranker = ColbertReranker(column='text')
|
99 |
-
results = table.search(query) \
|
100 |
-
.metric('cosine') \
|
101 |
-
.rerank(reranker=reranker) \
|
102 |
-
.limit(limit) \
|
103 |
-
.to_list()
|
104 |
-
except Exception as e:
|
105 |
-
exc_type, exc_obj, exc_tb = sys.exc_info()
|
106 |
-
fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
|
107 |
-
logger.error(f"Error while getting context: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
|
108 |
-
|
109 |
-
else:
|
110 |
-
try:
|
111 |
-
results = table.search(query) \
|
112 |
-
.limit(limit) \
|
113 |
-
.metric('cosine') \
|
114 |
-
.to_list()
|
115 |
-
except Exception as e:
|
116 |
-
exc_type, exc_obj, exc_tb = sys.exc_info()
|
117 |
-
fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
|
118 |
-
logger.error(f"Error while getting context: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
|
119 |
-
|
120 |
-
logger.info("Found the most relevant documents.")
|
121 |
-
city_lists = [{"city": r['city'], "country": r['country'], "section": r['section'], "text": r['text']} for r in
|
122 |
-
results]
|
123 |
-
|
124 |
-
# context = [f"city: {r['city']}, country: {r['country']}, name: {r['title']}, description: {r['description']}"
|
125 |
-
# for r in results]
|
126 |
-
|
127 |
-
return city_lists
|
128 |
-
|
129 |
-
|
130 |
-
def search_wikivoyage_listings(query:str, cities: list, limit: int=10, reranking: int = 0, run_local: Optional[bool] = False):
|
131 |
-
"""
|
132 |
-
|
133 |
-
Function to search the wikivoyage database an return most relevant listings, post-filtered by the recommended
|
134 |
-
cities provided by wikivoyage_documents table.
|
135 |
-
|
136 |
-
Args:
|
137 |
-
- query: str
|
138 |
-
- cities: list
|
139 |
-
- limit: number of results (default is 10)
|
140 |
-
- reranking: bool (0 or 1), if activated, CrossEncoderReranker is used.
|
141 |
-
|
142 |
-
"""
|
143 |
-
if run_local:
|
144 |
-
uri = database_dir
|
145 |
-
current_dir = os.path.split(os.getcwd())[1]
|
146 |
-
|
147 |
-
if "src" or "tests" in current_dir: # hacky way to get the correct path
|
148 |
-
uri = uri.replace("../../", "../")
|
149 |
-
else:
|
150 |
-
uri = os.environ["BUCKET_NAME"]
|
151 |
-
|
152 |
-
db = lancedb.connect(uri)
|
153 |
-
logger.info("Connected to Wikivoyage Listings DB.")
|
154 |
-
|
155 |
-
table = db.open_table("wikivoyage_listings")
|
156 |
-
|
157 |
-
cities_filter = f"city IN {tuple(cities)}"
|
158 |
-
|
159 |
-
if reranking:
|
160 |
-
try:
|
161 |
-
reranker = ColbertReranker(column='description')
|
162 |
-
results = table.search(query) \
|
163 |
-
.where(cities_filter) \
|
164 |
-
.metric('cosine') \
|
165 |
-
.rerank(reranker=reranker) \
|
166 |
-
.limit(limit) \
|
167 |
-
.to_list()
|
168 |
-
|
169 |
-
except Exception as e:
|
170 |
-
exc_type, exc_obj, exc_tb = sys.exc_info()
|
171 |
-
fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
|
172 |
-
logger.error(f"Error while getting context: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
|
173 |
-
|
174 |
-
else:
|
175 |
-
try:
|
176 |
-
results = table.search(query) \
|
177 |
-
.where(cities_filter) \
|
178 |
-
.metric('cosine') \
|
179 |
-
.limit(limit) \
|
180 |
-
.to_list()
|
181 |
-
except Exception as e:
|
182 |
-
exc_type, exc_obj, exc_tb = sys.exc_info()
|
183 |
-
fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
|
184 |
-
logger.error(f"Error while getting context: {e}, {(exc_type, fname, exc_tb.tb_lineno)}")
|
185 |
-
|
186 |
-
logger.info("Found the most relevant documents.")
|
187 |
-
city_listings = [{"city": r['city'], "country": r['country'], "type": r['type'], "title": r['title'],
|
188 |
-
"description": r['description']} for r in results]
|
189 |
-
|
190 |
-
return city_listings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|