awinml's picture
Upload 16 files
8046e72
def get_bm25_search_hits(corpus, sparse_scores, top_n=50):
    """Return the first ``top_n`` entries of ``sparse_scores`` as Python ints.

    ``sparse_scores`` is iterated in order and each element is treated as an
    index; presumably it is already a ranking of corpus positions (e.g. an
    argsort of BM25 scores) — confirm against callers.

    Args:
        corpus: Unused; kept so existing callers keep working.
        sparse_scores: Iterable of index-like values (anything ``int()`` accepts).
        top_n: Maximum number of indices to return. Values <= 0 yield ``[]``.

    Returns:
        A list of at most ``top_n`` ints, in the order they appear in
        ``sparse_scores``.
    """
    hits = []
    for idx in sparse_scores:
        # Stop as soon as we have top_n hits (the old `<=` check returned
        # top_n + 1 and kept scanning the whole input).
        if len(hits) >= top_n:
            break
        hits.append(int(idx))
    return hits
# Years searched when ``year == "All"`` (hard-coded; order kept from the
# original filters).
_QP_ALL_YEARS = (2020, 2019, 2018, 2017, 2016)
# Quarters searched when ``quarter == "All"``.
_QP_ALL_QUARTERS = ("Q1", "Q2", "Q3", "Q4")


def _query_pinecone_filter(year, quarter, ticker, participant_type, keywords, indices):
    """Build the Pinecone metadata filter used by ``query_pinecone``."""
    participant = "Answer" if participant_type == "Company Speaker" else "Question"
    metadata_filter = {
        # "All" expands to the hard-coded year list; a single year is passed
        # as a bare int, as before.
        "Year": {"$in": list(_QP_ALL_YEARS)} if year == "All" else int(year),
        # "All" must expand to every quarter for any year selection; an
        # {"$eq": "All"} filter would match nothing.
        "Quarter": (
            {"$in": list(_QP_ALL_QUARTERS)} if quarter == "All" else {"$eq": quarter}
        ),
        "Ticker": {"$eq": ticker},
        "QA_Flag": {"$eq": participant},
    }
    if keywords is not None:
        metadata_filter["Keywords"] = {"$in": keywords}
    if indices is not None:
        metadata_filter["index"] = {"$in": indices}
    return metadata_filter


def query_pinecone(
    dense_vec,
    top_k,
    index,
    year,
    quarter,
    ticker,
    participant_type,
    keywords=None,
    indices=None,
    threshold=0.25,
):
    """Run a dense-vector query with metadata filters and drop low-score matches.

    Args:
        dense_vec: Query embedding, passed to ``index.query`` as ``vector``.
        top_k: Number of matches requested from the index.
        index: Object with a ``query(vector=, top_k=, filter=, include_metadata=)``
            method whose result supports ``result["matches"]`` get/set
            (a Pinecone index in practice).
        year: A year (anything ``int()`` accepts) or ``"All"`` for 2016-2020.
        quarter: ``"Q1"``..``"Q4"`` or ``"All"`` for every quarter.
        ticker: Required value of the ``Ticker`` metadata field.
        participant_type: ``"Company Speaker"`` filters ``QA_Flag == "Answer"``;
            any other value filters ``QA_Flag == "Question"``.
        keywords: If not None, matches must have a ``Keywords`` value in it.
        indices: If not None, matches must have an ``index`` value in it.
        threshold: Matches whose ``score`` is below this are removed.

    Returns:
        The query result, with ``["matches"]`` replaced by only the matches
        whose score is >= ``threshold``.
    """
    metadata_filter = _query_pinecone_filter(
        year, quarter, ticker, participant_type, keywords, indices
    )
    xc = index.query(
        vector=dense_vec,
        top_k=top_k,
        filter=metadata_filter,
        include_metadata=True,
    )
    # filter the context passages based on the score threshold
    xc["matches"] = [match for match in xc["matches"] if match["score"] >= threshold]
    return xc
def query_pinecone_sparse(
    dense_vec,
    sparse_vec,
    top_k,
    index,
    year,
    quarter,
    ticker,
    participant_type,
    keywords=None,
    indices=None,
    threshold=0.25,
):
    """Run a hybrid (dense + sparse) query with metadata filters.

    Args:
        dense_vec: Dense query embedding, passed to ``index.query`` as ``vector``.
        sparse_vec: Sparse query representation, passed as ``sparse_vector``.
        top_k: Number of matches requested from the index.
        index: Object with a ``query(vector=, sparse_vector=, top_k=, filter=,
            include_metadata=)`` method whose result supports
            ``result["matches"]`` get/set (a Pinecone index in practice).
        year: A year (anything ``int()`` accepts) or ``"All"`` for 2016-2020.
        quarter: ``"Q1"``..``"Q4"`` or ``"All"`` for every quarter.
        ticker: Required value of the ``Ticker`` metadata field.
        participant_type: ``"Company Speaker"`` filters ``QA_Flag == "Answer"``;
            any other value filters ``QA_Flag == "Question"``.
        keywords: If not None, matches must have a ``Keywords`` value in it.
            When None, no keyword filter is sent (previously ``{"$in": None}``
            was sent, which is not a valid filter).
        indices: If not None, matches must have an ``index`` value in it
            (previously this argument was ignored).
        threshold: Matches whose ``score`` is below this are removed.

    Returns:
        The query result, with ``["matches"]`` replaced by only the matches
        whose score is >= ``threshold``.
    """
    participant = "Answer" if participant_type == "Company Speaker" else "Question"
    all_years = [2020, 2019, 2018, 2017, 2016]
    all_quarters = ["Q1", "Q2", "Q3", "Q4"]
    metadata_filter = {
        "Year": {"$in": all_years} if year == "All" else int(year),
        # "All" must expand to every quarter for any year selection; an
        # {"$eq": "All"} filter would match nothing.
        "Quarter": {"$in": all_quarters} if quarter == "All" else {"$eq": quarter},
        "Ticker": {"$eq": ticker},
        "QA_Flag": {"$eq": participant},
    }
    if keywords is not None:
        metadata_filter["Keywords"] = {"$in": keywords}
    if indices is not None:
        metadata_filter["index"] = {"$in": indices}

    xc = index.query(
        vector=dense_vec,
        sparse_vector=sparse_vec,
        top_k=top_k,
        filter=metadata_filter,
        include_metadata=True,
    )
    # filter the context passages based on the score threshold
    xc["matches"] = [match for match in xc["matches"] if match["score"] >= threshold]
    return xc
def format_query(query_results):
    """Collect the ``Text`` metadata field of each match.

    Args:
        query_results: Mapping with a ``"matches"`` sequence; each match must
            carry ``match["metadata"]["Text"]``.

    Returns:
        List of passage texts in the same order as the matches.
    """
    passages = []
    for match in query_results["matches"]:
        passages.append(match["metadata"]["Text"])
    return passages
def sentence_id_combine(data, query_results, lag=1):
    """Expand each matched sentence id by ``lag`` neighbours and join the text.

    Every ``Sentence_id`` from the matches is widened to the range
    ``[id - lag, id + lag]``. The union of those ids is deduplicated and
    sorted, then cut into consecutive chunks of ``2 * lag + 1`` ids. For each
    chunk, the ``Text`` values of the rows of ``data`` whose ``Sentence_id``
    is in the chunk are joined with single spaces.

    NOTE(review): the chunks are fixed-size slices of the sorted union, so
    when neighbouring windows overlap a chunk is not necessarily centred on a
    matched sentence (e.g. hits 1 and 2 with lag=1 give chunks [0, 1, 2] and
    [3]).

    Args:
        data: DataFrame with ``Sentence_id`` and ``Text`` columns.
        query_results: Mapping with ``"matches"``; each match carries
            ``match["metadata"]["Sentence_id"]``.
        lag: Number of neighbouring sentences to include on each side.

    Returns:
        List of joined context strings, one per chunk.
    """
    chunk_size = lag * 2 + 1
    hit_ids = [match["metadata"]["Sentence_id"] for match in query_results["matches"]]
    widened = sorted(
        {hit + offset for hit in hit_ids for offset in range(-lag, lag + 1)}
    )
    contexts = []
    for start in range(0, len(widened), chunk_size):
        chunk = widened[start : start + chunk_size]
        in_chunk = data["Sentence_id"].isin(chunk)
        contexts.append(" ".join(data.loc[in_chunk, "Text"].to_list()))
    return contexts
def text_lookup(data, sentence_ids):
    """Join the values at the given positions of ``data`` with ``". "``.

    Args:
        data: pandas Series of strings, indexed positionally via ``.iloc``.
        sentence_ids: Positions to select (anything ``.iloc`` accepts).

    Returns:
        The selected strings joined by ``". "``.
    """
    selected = data.iloc[sentence_ids]
    separator = ". "
    return separator.join(selected.tolist())
def year_quarter_range(start_quarter, start_year, end_quarter, end_year):
    """Return every (year, quarter) pair from the start quarter to the end quarter.

    Both endpoints are inclusive and the pairs are in chronological order, e.g.
    ``year_quarter_range("Q3", 2019, "Q2", 2020)`` returns
    ``[(2019, "Q3"), (2019, "Q4"), (2020, "Q1"), (2020, "Q2")]``.

    Args:
        start_quarter: One of ``"Q1"``..``"Q4"``.
        start_year: Anything ``int()`` accepts, e.g. ``"2019"`` or ``2019``.
        end_quarter: One of ``"Q1"``..``"Q4"``.
        end_year: Anything ``int()`` accepts.

    Returns:
        List of ``(int year, str quarter)`` tuples; empty when the start comes
        after the end.

    Raises:
        ValueError: If a quarter label is not ``"Q1"``..``"Q4"`` or a year
            cannot be converted with ``int()``.
    """
    labels = ("Q1", "Q2", "Q3", "Q4")

    def _to_ordinal(quarter, year):
        # Map (year, quarter) to a running quarter count so the inclusive
        # range becomes a plain integer span.
        if quarter not in labels:
            raise ValueError(f"quarter must be one of {labels}, got {quarter!r}")
        return int(year) * 4 + labels.index(quarter)

    first = _to_ordinal(start_quarter, start_year)
    last = _to_ordinal(end_quarter, end_year)
    return [(ordinal // 4, labels[ordinal % 4]) for ordinal in range(first, last + 1)]
def multi_document_query(
    dense_query_embedding,
    sparse_query_embedding,
    num_results,
    pinecone_index,
    start_quarter,
    start_year,
    end_quarter,
    end_year,
    ticker,
    participant_type,
    threshold,
):
    """Placeholder for a query spanning a range of (year, quarter) documents.

    Not implemented: every argument is currently ignored and the function
    always returns ``None``.
    """
    # TODO: implement; this is a stub.
    return None