# earnings-calls-qa/utils/retriever.py
import numpy as np

from utils.models import get_bm25_model, preprocess_text

# BM25 Filtering and Retrieval


def filter_data_docs(data, ticker, quarter, year):
    # Restrict the transcript DataFrame to a single earnings call;
    # a string year is coerced to int before comparison.
    year_int = int(year)
data_subset = data[
(data["Year"] == year_int)
& (data["Quarter"] == quarter)
& (data["Ticker"] == ticker)
]
return data_subset
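
# A minimal usage sketch, assuming a transcript DataFrame with "Ticker",
# "Quarter", and "Year" columns (the sample values are hypothetical):
#
#     import pandas as pd
#
#     data = pd.DataFrame({
#         "Ticker": ["AAPL", "AAPL", "MSFT"],
#         "Quarter": ["Q1", "Q2", "Q1"],
#         "Year": [2020, 2020, 2020],
#         "Text": ["Revenue grew.", "Margins held.", "Cloud expanded."],
#     })
#     filter_data_docs(data, "AAPL", "Q1", "2020")  # -> the single AAPL Q1 row
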
def get_bm25_search_hits(corpus, sparse_scores, top_n=50):
    # Return the indices of the top_n highest-ranked documents.
    # `sparse_scores` holds document indices already sorted by descending
    # BM25 score; `corpus` is kept in the signature for compatibility,
    # since callers only need the indices. (The original loop used
    # `<= top_n` and so returned top_n + 1 hits; fixed here.)
    indices = [int(idx) for idx in sparse_scores[:top_n]]
    return indices
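
# Hedged example of the expected inputs, assuming a rank_bm25-style scorer
# (the corpus and scores below are made up):
#
#     corpus = ["revenue grew strongly", "guidance was raised", "costs fell"]
#     scores = np.array([1.2, 0.1, 0.7])     # e.g. bm25.get_scores(tokens)
#     ranked = np.argsort(scores)[::-1]      # -> array([0, 2, 1])
#     get_bm25_search_hits(corpus, ranked, top_n=2)  # -> [0, 2]
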
# BM25 Filtering
def get_indices_bm25(
    data, query, ticker=None, quarter=None, year=None, num_candidates=50
):
    # Build the BM25 index over the full corpus, or over a single call's
    # transcript when ticker, quarter, and year are all provided.
    if ticker is None or quarter is None or year is None:
        corpus, bm25 = get_bm25_model(data)
    else:
        filtered_data = filter_data_docs(data, ticker, quarter, year)
        corpus, bm25 = get_bm25_model(filtered_data)
    # Rank document indices by descending BM25 score for the query.
    tokenized_query = preprocess_text(query).split()
    sparse_scores = np.argsort(bm25.get_scores(tokenized_query), axis=0)[::-1]
indices_hits = get_bm25_search_hits(corpus, sparse_scores, num_candidates)
return indices_hits
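
# Sketch of sparse candidate generation, assuming get_bm25_model returns a
# (corpus, bm25) pair whose bm25 object exposes a rank_bm25-compatible
# get_scores method (the question text is illustrative):
#
#     candidate_ids = get_indices_bm25(
#         data, "What drove revenue growth?",
#         ticker="AAPL", quarter="Q1", year="2020", num_candidates=50,
#     )
#     # candidate_ids can be passed as `indices` to query_pinecone below
#     # so dense retrieval is restricted to the BM25 candidates.
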
def query_pinecone(
dense_vec,
top_k,
index,
year=None,
quarter=None,
ticker=None,
keywords=None,
indices=None,
threshold=0.25,
):
    # Always restrict retrieval to answer passages; the optional
    # arguments add metadata filters for a specific call, keyword set,
    # or set of BM25 candidate indices.
    filter_dict = {
        "QA_Flag": {"$eq": "Answer"},
    }
    if year is not None:
        filter_dict["Year"] = {"$eq": int(year)}
if quarter is not None:
filter_dict["Quarter"] = {"$eq": quarter}
if ticker is not None:
filter_dict["Ticker"] = {"$eq": ticker}
if keywords is not None:
filter_dict["Keywords"] = {"$in": keywords}
if indices is not None:
filter_dict["index"] = {"$in": indices}
xc = index.query(
vector=dense_vec,
top_k=top_k,
filter=filter_dict,
include_metadata=True,
)
    # Keep only context passages scoring at or above the threshold
    xc["matches"] = [
        match for match in xc["matches"] if match["score"] >= threshold
    ]
return xc
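
# Hedged usage sketch, assuming the classic pinecone-client API and an
# embedding helper matching the model used to build the index (the index
# name, credentials, and embed_query helper are hypothetical):
#
#     import pinecone
#
#     pinecone.init(api_key="...", environment="...")
#     index = pinecone.Index("earnings-calls")
#     dense_vec = embed_query("What drove revenue growth?")
#     results = query_pinecone(
#         dense_vec, top_k=10, index=index,
#         ticker="AAPL", quarter="Q1", year="2020",
#         indices=candidate_ids, threshold=0.25,
#     )
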
def sentence_id_combine(data, query_results, lag=1):
# Extract sentence IDs from query results
    # Extract sentence IDs from query results (Pinecone returns numeric
    # metadata as floats, so coerce them back to int)
    ids = [
        int(result["metadata"]["Sentence_id"])
        for result in query_results["matches"]
    ]
    # Expand each ID into a window of `lag` sentences on either side
    new_ids = [sent_id + i for sent_id in ids for i in range(-lag, lag + 1)]
# Remove duplicates and sort the new IDs
new_ids = sorted(set(new_ids))
# Create a list of lookup IDs by grouping the new IDs in groups of lag*2+1
lookup_ids = [
new_ids[i : i + (lag * 2 + 1)]
for i in range(0, len(new_ids), lag * 2 + 1)
]
# Create a list of context sentences by joining the sentences
# corresponding to the lookup IDs
context_list = [
" ".join(
data.loc[data["Sentence_id"].isin(lookup_id), "Text"].to_list()
)
for lookup_id in lookup_ids
]
context = " ".join(context_list).strip()
return context
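
# End-to-end sketch of the retrieval pipeline under the same assumptions
# as above: BM25 candidates -> filtered dense search -> context expansion.
#
#     candidate_ids = get_indices_bm25(data, question, ticker, quarter, year)
#     results = query_pinecone(
#         dense_vec, top_k=10, index=index,
#         year=year, quarter=quarter, ticker=ticker, indices=candidate_ids,
#     )
#     context = sentence_id_combine(data, results, lag=1)
#     # `context` is a single string of neighboring sentences around each
#     # match, ready to drop into a downstream QA prompt.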